In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AffNetFast(nn.Module):
    """AffNet (fast variant) for single-channel image patches.

    ``forward`` maps a ``(batch, 1, H, W)`` patch tensor to a ``(batch, 3)``
    tensor: the tanh-bounded, spatially averaged output of the final conv,
    shifted by the constant ``[1, 0, 1]`` held in the ``const`` buffer.
    NOTE(review): the +1 offsets on entries 0 and 2 suggest a near-identity
    affine-shape parameterisation -- confirm against the training code.

    The 8x8 head conv uses no padding, so the map after the two stride-2
    convs must be at least 8x8; 32x32 patches give exactly 8x8.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) for each Conv -> BN -> ReLU stage.
        # Module order/indices must stay fixed: they form the state_dict keys
        # (features.0 ... features.21) that the pretrained checkpoint uses.
        trunk_spec = [
            (1, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
        ]
        layers = []
        for c_in, c_out, stride in trunk_spec:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out, affine=False),
                nn.ReLU(),
            ]
        # Head: 3 channels, bounded by tanh, then averaged to one value per channel.
        layers += [
            nn.Dropout(0.),
            nn.Conv2d(64, 3, kernel_size=8, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.features = nn.Sequential(*layers)

        offset = torch.tensor([1., 0., 1.], dtype=torch.float).reshape(1, 3)
        self.register_buffer('const', offset)

    def forward(self, input):
        """Return ``features(input_norm(input))`` flattened to ``(batch, 3)`` plus ``const``."""
        normalized = self.input_norm(input)
        pooled = self.features(normalized)  # (batch, 3, 1, 1) after the adaptive pool
        return pooled.view(-1, 3) + self.const

    def input_norm(self, x):
        """Standardise each (sample, channel) plane over dims 2 and 3.

        Mean and (unbiased) std are detached so no gradient flows through the
        statistics; ``1e-7`` guards against division by zero on flat patches.
        """
        std, mean = torch.std_mean(x, dim=[2, 3])
        centre = mean.detach().unsqueeze(-1).unsqueeze(-1)
        scale = std.detach().unsqueeze(-1).unsqueeze(-1) + 1e-7
        return (x - centre) / scale


model = AffNetFast()
checkpoint = '../pretrained/AffNet.pth'
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
# strict=False is presumably needed because the checkpoint has no entry for the
# `const` buffer registered in AffNetFast.__init__. Non-strict loading would
# also silently skip any *other* mismatch (e.g. prefixed/renamed keys), leaving
# layers randomly initialised and exporting a garbage model -- so check the
# reported key differences explicitly instead of ignoring them.
load_result = model.load_state_dict(state_dict, strict=False)
# `num_batches_tracked` may be absent from checkpoints saved by old PyTorch
# versions; BatchNorm does not use it in eval mode, so tolerate its absence.
critical_missing = [
    key for key in load_result.missing_keys
    if key != 'const' and not key.endswith('num_batches_tracked')
]
assert not critical_missing, f"checkpoint is missing weights: {critical_missing}"
assert not load_result.unexpected_keys, f"unexpected checkpoint keys: {load_result.unexpected_keys}"
model.eval()

AffNetFast(
  (features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (14): 

In [2]:
# Export AffNet to TorchScript by tracing it on a dummy 1x1x32x32 batch.
# Relies on `model` from the loading cell, already switched to eval mode.
# AffNetFast.forward has no data-dependent branching, so the random values do
# not affect the recorded graph -- only the input shape is used.
example = torch.rand((1, 1, 32, 32))
traced_script_module = torch.jit.trace(model, example_inputs=example)
torch.jit.save(traced_script_module, "AffNetJIT.pt")

In [3]:
# Check that the traced module reproduces the eager model. Assert rather than
# only printing, so a mismatch fails the run. Also test a batch size different
# from the traced one (1), to catch shapes baked in during tracing.
inp1 = torch.rand(1, 1, 32, 32)
inp4 = torch.rand(4, 1, 32, 32)
with torch.no_grad():  # inference only; no autograd graph needed
    out_jit = traced_script_module(inp1)
    out_pytorch = model(inp1)
    out_jit4 = traced_script_module(inp4)
    out_pytorch4 = model(inp4)
print(out_jit - out_pytorch)
assert torch.allclose(out_jit, out_pytorch, atol=1e-6), "traced AffNet differs from eager model (batch 1)"
assert out_jit4.shape == out_pytorch4.shape, (out_jit4.shape, out_pytorch4.shape)
assert torch.allclose(out_jit4, out_pytorch4, atol=1e-6), "traced AffNet differs from eager model (batch 4)"

tensor([[0., 0., 0.]], grad_fn=<SubBackward0>)


In [4]:
class OriNetFast(nn.Module):
    """OriNet (fast variant) for single-channel image patches.

    ``forward`` maps a ``(batch, 1, H, W)`` patch tensor to a ``(batch, 2)``
    tensor of tanh-bounded, spatially averaged head outputs (no offset added).
    """

    def __init__(self):
        super().__init__()
        trunk = []
        # Each stage: 3x3 Conv (no bias) -> affine-free BatchNorm -> ReLU.
        # The module order fixes the state_dict indices (features.0 ... features.21).
        for c_in, c_out, stride in ((1, 16, 1), (16, 16, 1), (16, 32, 2),
                                    (32, 32, 1), (32, 64, 2), (64, 64, 1)):
            trunk.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out, affine=False),
                nn.ReLU(),
            ))
        # NOTE(review): padding=1 on the 8x8 head conv (AffNetFast uses 0) --
        # presumably matches how the pretrained OriNet weights were trained;
        # confirm before changing. 32x32 patches give an 8x8 map, hence a 3x3
        # head output that AdaptiveAvgPool2d reduces to 1x1.
        head = (
            nn.Dropout(0.25),
            nn.Conv2d(64, 2, kernel_size=8, stride=1, padding=1, bias=True),
            nn.Tanh(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.features = nn.Sequential(*trunk, *head)

    def input_norm(self, x):
        """Standardise each (sample, channel) plane over dims 2 and 3.

        Statistics (mean, unbiased std) are detached from the autograd graph;
        ``1e-7`` avoids division by zero for constant patches.
        """
        stats = torch.std_mean(x, dim=[2, 3])
        std, mean = (s.detach()[..., None, None] for s in stats)
        return (x - mean) / (std + 1e-7)

    def forward(self, input):
        """Return the pooled head output flattened to ``(batch, 2)``."""
        normed = self.input_norm(input)
        pooled = self.features(normed)  # (batch, 2, 1, 1)
        return pooled.view(-1, 2)

In [5]:
model = OriNetFast()
checkpoint = '../pretrained/OriNet.pth'
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict'])
# eval() must precede tracing: it makes BatchNorm use its running statistics and
# disables the Dropout(0.25) layer, which would otherwise be traced in train mode.
model.eval()

#Converting to JIT
example = torch.rand(1,1,32,32)
traced_script_module = torch.jit.trace(model, example)

# Verify the traced module against the eager model before writing it to disk:
# once on the traced batch size and once on a different one, to catch shapes
# baked in during tracing.
with torch.no_grad():
    for check_batch in (torch.rand(1, 1, 32, 32), torch.rand(4, 1, 32, 32)):
        jit_out = traced_script_module(check_batch)
        eager_out = model(check_batch)
        assert jit_out.shape == eager_out.shape, (jit_out.shape, eager_out.shape)
        assert torch.allclose(jit_out, eager_out, atol=1e-6), \
            f"traced OriNet differs from eager model (batch {check_batch.shape[0]})"

traced_script_module.save("OriNetJIT.pt")