In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        
        # Regressor for the 3x2 affine matrix
        self.fc_loc = nn.Sequential(
                nn.Linear(10 * 3 * 3, 32),
                nn.ReLU(True),
                nn.Linear(32, 3*2)
                )
        
        # Initialize the weights/bias with identity transformation 
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0], dtype=torch.float))
            
    def affine_grid(self, theta, size, align_corners=False):
        N, C, H, W = size
        grid = self.create_grid(N, C, H, W).to(theta.device)
        grid = grid.view(N, H * W, 3).bmm(theta.transpose(1, 2))
        grid = grid.view(N, H, W, 2)
        return grid

    def create_grid(self, N, C, H, W):
        grid = torch.empty((N, H, W, 3), dtype=torch.float32)
        grid.select(-1, 0).copy_(self.linspace_from_neg_one(W))
        grid.select(-1, 1).copy_(self.linspace_from_neg_one(H).unsqueeze_(-1))
        grid.select(-1, 2).fill_(1)
        return grid
    
    def linspace_from_neg_one(self, num_steps, dtype=torch.float32):
        r = torch.linspace(-1, 1, num_steps, dtype=torch.float32)
        r = r * (num_steps - 1) / num_steps
        return r
    
    def grid_sample(self, im, grid, align_corners=False):
        # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/point_sample.py
        
        n, c, h, w = im.shape
        gn, gh, gw, _ = grid.shape
        # assert n == gn

        x = grid[:, :, :, 0]
        y = grid[:, :, :, 1]

        if align_corners:
            x = ((x + 1) / 2) * (w - 1)
            y = ((y + 1) / 2) * (h - 1)
        else:
            x = ((x + 1) * w - 1) / 2
            y = ((y + 1) * h - 1) / 2

        x = x.view(n, -1)
        y = y.view(n, -1)

        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
        wb = ((x1 - x) * (y - y0)).unsqueeze(1)
        wc = ((x - x0) * (y1 - y)).unsqueeze(1)
        wd = ((x - x0) * (y - y0)).unsqueeze(1)

        # Apply default for grid_sample function zero padding
        im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
        padded_h = h + 2
        padded_w = w + 2
        
        # save points positions after padding
        x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

        # Clip coordinates to padded image size
        zero = torch.tensor(0).to(x0.device)
        
        x0 = torch.where(x0 < 0, zero, x0)
        x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x0)
        
        x1 = torch.where(x1 < 0, zero, x1)
        x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x1)
        
        y0 = torch.where(y0 < 0, zero, y0)
        y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y0)
        
        y1 = torch.where(y1 < 0, zero, y1)
        y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y1)

        im_padded = im_padded.view(n, c, -1)

        x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

        Ia = torch.gather(im_padded, 2, x0_y0)
        Ib = torch.gather(im_padded, 2, x0_y1)
        Ic = torch.gather(im_padded, 2, x1_y0)
        Id = torch.gather(im_padded, 2, x1_y1)

        return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
    
    def forward(self, x):
        N, C, H, W = x.shape
        
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1,2,3)
        
        # grid = F.affine_grid(theta, (N, C, H, W))
        grid = self.affine_grid(theta, (N, C, H, W))

        # x = torch.nn.functional.grid_sample(x, grid, align_corners=False)
        x = self.grid_sample(x, grid, align_corners=False)
        
        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)


model = Net().to('cuda:0')

In [3]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [4]:
epochs = 2

# Training loop
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(epochs):
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to('cuda:0'), target.to('cuda:0')

        optimizer.zero_grad()
        output = model(data)

        # print(output.shape)
        # print(target.shape)

        loss = torch.nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()

1875it [00:16, 112.43it/s]
1875it [00:16, 114.45it/s]


In [5]:
# Export the model to ONNX format
dummy_input = torch.randn(1, 1, 28, 28)  # Example input shape
torch.onnx.export(model, dummy_input.to('cuda:0'), "stn_mnist.onnx", verbose=False)

  zero = torch.tensor(0).to(x0.device)
  x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x0)
  x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x0)
  x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x1)
  x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).to(x0.device), x1)
  y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y0)
  y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y0)
  y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y1)
  y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).to(x0.device), y1)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
