In [1]:
from torch.nn.modules.module import Module
import torch
from torch.autograd import Variable
import numpy as np

In [2]:
class AffineGridGenV2(Module):
    """Generate a sampling grid from one global 2x3 affine matrix per sample.

    A fixed base grid of homogeneous coordinates ``(r, c, 1)`` is built once
    at construction time: ``r`` varies along the height axis and ``c`` along
    the width axis, both taking the values ``-1, -1 + 2/n, ..., 1 - 2/n``.
    ``forward`` maps every grid point through each sample's affine matrix.

    Args:
        height: number of grid rows.
        width: number of grid columns.
        lr: stored as ``self.lr``; not used by this class.
        aux_loss: stored as ``self.aux_loss``; not used by this class.
    """

    def __init__(self, height, width, lr=1, aux_loss=False):
        super(AffineGridGenV2, self).__init__()
        self.height, self.width = height, width
        self.aux_loss = aux_loss
        self.lr = lr

        # Same values as np.arange(-1, 1, 2.0/n), but derived from an integer
        # range so the length is always exactly n (a float-step arange can
        # produce n + 1 points through rounding, which breaks the assignment).
        rows = -1.0 + 2.0 * np.arange(self.height) / self.height
        cols = -1.0 + 2.0 * np.arange(self.width) / self.width

        grid = np.empty([self.height, self.width, 3], dtype=np.float32)
        grid[:, :, 0] = rows[:, np.newaxis]  # constant along each row
        grid[:, :, 1] = cols[np.newaxis, :]  # constant down each column
        grid[:, :, 2] = 1.0                  # homogeneous coordinate
        # (height, width, 3) float32 CPU tensor, kept as a plain attribute.
        self.grid = torch.from_numpy(grid)

    def forward(self, input1):
        """Apply per-sample affine matrices to the base grid.

        Args:
            input1: tensor of shape (B, 2, 3), one affine matrix per sample
                (the shape implied by the transpose/bmm/view below).

        Returns:
            Tensor of shape (B, height, width, 2) with
            ``output[b, i, j] = input1[b] @ grid[i, j]``.

        Side effect: stores the (B, height, width, 3) batched grid in
        ``self.batchgrid``.
        """
        batch_size = input1.size(0)
        # Follow the input's dtype/device; a hard-coded CPU float32 grid made
        # bmm fail for float64 or CUDA inputs.
        grid = self.grid.type_as(input1.data)
        self.batchgrid = Variable(
            grid.unsqueeze(0).expand(batch_size, self.height, self.width, 3).contiguous()
        )
        output = torch.bmm(
            self.batchgrid.view(-1, self.height * self.width, 3),
            torch.transpose(input1, 1, 2),
        )
        return output.view(-1, self.height, self.width, 2)

In [27]:
class DenseAffineGridGen(Module):
    """Generate a sampling grid from a separate 2x3 affine matrix at every grid point.

    The base grid of homogeneous coordinates ``(r, c, 1)`` is the same as in
    the global-affine generator: ``r`` varies along the height axis and ``c``
    along the width axis, both taking ``-1, -1 + 2/n, ..., 1 - 2/n``.

    Args:
        height: number of grid rows.
        width: number of grid columns.
        lr: stored as ``self.lr``; not used by this class.
        aux_loss: stored as ``self.aux_loss``; not used by this class.
    """

    def __init__(self, height, width, lr=1, aux_loss=False):
        super(DenseAffineGridGen, self).__init__()
        self.height, self.width = height, width
        self.aux_loss = aux_loss
        self.lr = lr

        # Same values as np.arange(-1, 1, 2.0/n), but derived from an integer
        # range so the length is always exactly n (a float-step arange can
        # produce n + 1 points through rounding, which breaks the assignment).
        rows = -1.0 + 2.0 * np.arange(self.height) / self.height
        cols = -1.0 + 2.0 * np.arange(self.width) / self.width

        grid = np.empty([self.height, self.width, 3], dtype=np.float32)
        grid[:, :, 0] = rows[:, np.newaxis]  # constant along each row
        grid[:, :, 1] = cols[np.newaxis, :]  # constant down each column
        grid[:, :, 2] = 1.0                  # homogeneous coordinate
        # (height, width, 3) float32 CPU tensor, kept as a plain attribute.
        self.grid = torch.from_numpy(grid)

    def forward(self, input1):
        """Apply a per-point affine transform to the base grid.

        Args:
            input1: tensor of shape (B, height, width, 6); channels 0:3 are
                the first row of each point's affine matrix and channels 3:6
                the second row (layout implied by the slicing below).

        Returns:
            Tensor of shape (B, height, width, 2) where output channel k is
            the dot product of ``(r, c, 1)`` with ``input1[..., 3k:3k+3]``.

        Side effect: stores the (B, height, width, 3) batched grid in
        ``self.batchgrid``.
        """
        batch_size = input1.size(0)
        # Follow the input's dtype/device; a hard-coded CPU float32 grid made
        # the multiplication fail for float64 or CUDA inputs.
        grid = self.grid.type_as(input1.data)
        self.batchgrid = Variable(
            grid.unsqueeze(0).expand(batch_size, self.height, self.width, 3).contiguous()
        )

        first_terms = torch.mul(self.batchgrid, input1[:, :, :, 0:3])
        second_terms = torch.mul(self.batchgrid, input1[:, :, :, 3:6])

        # keepdim=True keeps the reduced axis so the two sums can be joined
        # along it; with the current default (keepdim=False) the sums are
        # (B, H, W) and cat(..., 3) raises "dimension out of range".
        return torch.cat(
            [torch.sum(first_terms, 3, keepdim=True),
             torch.sum(second_terms, 3, keepdim=True)],
            3,
        )

In [28]:
# Global-affine grid generator producing a 3 x 5 grid.
g = AffineGridGenV2(height=3, width=5)

In [29]:
# One affine matrix (batch of 1), tracked so its gradient can be inspected.
input = Variable(
    torch.from_numpy(np.array([[[1.0, 0.5, 0.0],
                                [0.5, 1.0, 0.0]]], dtype=np.float32)),
    requires_grad=True,
)

In [30]:
# Show the affine matrix, then generate the grid from it.
# Function-call print: the Python 2 statement form is a SyntaxError on Python 3.
print(input)
out = g(input)

Variable containing:
(0 ,.,.) = 
  1.0000  0.5000  0.0000
  0.5000  1.0000  0.0000
[torch.FloatTensor of size 1x2x3]



In [31]:
# Seed first so the random upstream gradient, and therefore the resulting
# input gradients, are reproducible across kernel restarts.
torch.manual_seed(0)
rnd = torch.rand((1, 3, 5, 2))

In [32]:
# Backpropagate the random upstream gradient through the grid generator.
out.backward(gradient=rnd)

In [33]:
# Gradient of sum(out * rnd) with respect to the affine matrix (same shape as input).
input.grad

Variable containing:
(0 ,.,.) = 
 -3.0555 -2.0090  7.6003
 -1.2794 -1.7730  5.7409
[torch.FloatTensor of size 1x2x3]

In [34]:
# Per-pixel affine grid generator producing a 3 x 5 grid.
g = DenseAffineGridGen(height=3, width=5)

In [35]:
# Random per-pixel affine parameters (batch, height, width, 6), values in [0, 1).
input = Variable(torch.rand(1, 3, 5, 6), requires_grad=True)

In [38]:
# Forward through the dense generator, then reuse the same upstream gradient.
out = g(input)
out.backward(gradient=rnd)

In [39]:
# Gradient of sum(out * rnd) with respect to the per-pixel parameters (same shape as input).
input.grad

Variable containing:
(0 ,0 ,.,.) = 
 -0.7175 -0.7175  0.7175 -0.1054 -0.1054  0.1054
 -0.1534 -0.0920  0.1534 -0.0064 -0.0039  0.0064
 -0.7153 -0.1431  0.7153 -0.4376 -0.0875  0.4376
 -0.2002  0.0400  0.2002 -0.7665  0.1533  0.7665
 -0.7103  0.4262  0.7103 -0.1890  0.1134  0.1890

(0 ,1 ,.,.) = 
 -0.2301 -0.6903  0.6903 -0.1463 -0.4389  0.4389
 -0.2595 -0.4671  0.7785 -0.0549 -0.0989  0.1648
 -0.2411 -0.1446  0.7232 -0.2792 -0.1675  0.8376
 -0.1337  0.0802  0.4012 -0.0671  0.0402  0.2012
 -0.2656  0.4781  0.7968 -0.0457  0.0822  0.1370

(0 ,2 ,.,.) = 
  0.3117 -0.9350  0.9350  0.2860 -0.8580  0.8580
  0.0414 -0.0746  0.1243  0.2675 -0.4816  0.8026
  0.0226 -0.0135  0.0677  0.1060 -0.0636  0.3180
  0.0897  0.0538  0.2692  0.1197  0.0718  0.3590
  0.1058  0.1904  0.3173  0.0396  0.0713  0.1188
[torch.FloatTensor of size 1x3x5x6]