In [1]:
# Convert a Lua Torch model (loaded with load_lua) into an equivalent PyTorch model

import torch
import torch.nn as nn
import torch.legacy.nn as nn1
from torch.utils.serialization import load_lua
from torch.autograd import Variable

In [2]:
# Load the serialized Lua Torch model.

# load_lua does not recognize SpatialConvolutionMM, so alias that name to the
# SpatialConvolution class before deserializing.
nn1.SpatialConvolutionMM = nn1.SpatialConvolution

m1 = load_lua('/Workspace/model.net')
# Switch the loaded network to evaluation mode before running it.
m1.evaluate()

In [3]:
def patch(m):
    """Recursively adjust a loaded legacy-nn module tree in place.

    Two fix-ups are applied to ``m`` and to every module reachable through
    its ``modules`` attribute:

    * a ``Padding`` module whose ``nInputDim`` equals 3 gets ``dim``
      incremented by one (presumably to account for a leading batch
      dimension -- TODO confirm against the model being converted);
    * a ``View`` module whose ``size`` has exactly one entry ``n`` gets it
      replaced by ``torch.Size([1, n])``.

    Args:
        m: the module to patch; containers are expected to expose their
           children as an iterable ``modules`` attribute.

    Returns:
        None. The modules are modified in place.
    """
    # Class name without the module path, e.g. 'Padding' or 'View'.
    kind = type(m).__name__

    if kind == 'Padding' and hasattr(m, 'nInputDim') and m.nInputDim == 3:
        m.dim = m.dim + 1
    if kind == 'View' and len(m.size) == 1:
        m.size = torch.Size([1, m.size[0]])

    if hasattr(m, 'modules'):
        # A separate loop name keeps the parameter ``m`` from being rebound.
        for child in m.modules:
            patch(child)


In [4]:
# Apply the in-place fix-ups to the loaded model, then list its layers.
patch(m1)
print(m1)

nn.Sequential {
  [input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> output]
  (0): nn.SpatialConvolution(3 -> 64, 11x11, 4, 4, 2, 2)
  (1): nn.ReLU
  (2): nn.SpatialMaxPooling(3x3, 2, 2)
  (3): nn.SpatialConvolution(64 -> 192, 5x5, 1, 1, 2, 2)
  (4): nn.ReLU
  (5): nn.SpatialMaxPooling(3x3, 2, 2)
  (6): nn.SpatialConvolution(192 -> 384, 3x3, 1, 1, 1, 1)
  (7): nn.ReLU
  (8): nn.SpatialConvolution(384 -> 256, 3x3, 1, 1, 1, 1)
  (9): nn.ReLU
  (10): nn.SpatialConvolution(256 -> 256, 3x3, 1, 1, 1, 1)
  (11): nn.ReLU
  (12): nn.SpatialMaxPooling(3x3, 2, 2)
  (13): nn.View(1, 9216)
  (14): nn.Linear(9216 -> 4096)
  (15): nn.ReLU
  (16): nn.Linear(4096 -> 4096)
  (17): nn.ReLU
  (18): nn.Linear(4096 -> 46)
  (19): nn.SoftMax
}


In [5]:
class ModelDef(nn.Module):
    """AlexNet-style classifier with the same layer sequence as the loaded Torch model.

    ``features`` holds the convolution / ReLU / max-pooling stack and
    ``classifier`` holds the three fully-connected layers. ``forward``
    returns the output of the last ``nn.Linear`` directly; no softmax is
    applied inside the module.

    Args:
        num_classes: output width of the final linear layer (default 46).
    """

    def __init__(self, num_classes=46):
        super(ModelDef, self).__init__()

        # One row per convolution:
        # (in_channels, out_channels, kernel, stride, padding, pool_after)
        conv_plan = (
            (3, 64, 11, 4, 2, True),
            (64, 192, 5, 1, 2, True),
            (192, 384, 3, 1, 1, False),
            (384, 256, 3, 1, 1, False),
            (256, 256, 3, 1, 1, True),
        )
        feature_layers = []
        for c_in, c_out, kernel, stride, pad, pool_after in conv_plan:
            feature_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad))
            feature_layers.append(nn.ReLU(inplace=True))
            if pool_after:
                feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
        self.features = nn.Sequential(*feature_layers)

        # Widths of the fully-connected stack before the output layer.
        hidden_widths = (256 * 6 * 6, 4096, 4096)
        head_layers = []
        for fan_in, fan_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            head_layers.append(nn.Linear(fan_in, fan_out))
            head_layers.append(nn.ReLU(inplace=True))
        head_layers.append(nn.Linear(hidden_widths[-1], num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the feature stack, flatten, then classify.

        The feature output is reshaped to ``(batch, 256 * 6 * 6)`` before it
        is handed to the fully-connected layers.
        """
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), 256 * 6 * 6)
        return self.classifier(flat)


In [6]:
# Instantiate the PyTorch network (no weights have been copied into it yet)
# and switch it to evaluation mode.
m2 = ModelDef()
m2.eval()
# ModelDef ends with a Linear layer, while the Torch model ends with
# nn.SoftMax; this softmax is applied on top of m2 so both outputs are comparable.
m = nn.Softmax()

In [7]:
# Push one random image-sized input through both networks.
x1 = torch.randn(1, 3, 224, 224)
x1_var = Variable(x1)
y1 = m1.forward(x1)
y2 = m(m2(x1_var))
# m2 has not received the Torch weights yet, so the two columns printed
# below (Torch on the left, PyTorch on the right) are expected to differ.
for k in range(len(y1[0])):
    print('\t'.join([str(y1[0][k]), str(y2.data[0][k])]))

1.3432990044748294e-06	0.022018155083060265
0.0008719050674699247	0.022039148956537247
0.007889222353696823	0.021509507670998573
0.0004464764497242868	0.021576479077339172
0.0003359577094670385	0.021639961749315262
1.7650879672800879e-12	0.02152116224169731
0.00433177687227726	0.021732622757554054
3.5346817139902953e-10	0.02132371813058853
5.946105777020674e-17	0.022074732929468155
7.42116170772157e-18	0.021620875224471092
2.153080686184694e-06	0.021924156695604324
0.010166157968342304	0.02187509462237358
0.14177590608596802	0.021905595436692238
4.358494152256753e-06	0.022067096084356308
2.4869538736118643e-18	0.02138614095747471
0.008647882379591465	0.021801117807626724
8.150720376409737e-11	0.02167879231274128
8.782055260780908e-07	0.021408220753073692
0.012028225697577	0.02191108837723732
0.010035747662186623	0.02156299166381359
0.11148055642843246	0.02182559110224247
7.24380515748635e-05	0.02188880741596222
0.000648920948151499	0.021966079249978065
0.00041553800110705197	0.02156689

In [8]:
# Copy weights from the Torch model into the PyTorch model.
#
# Layers are paired by position among those that own parameters, so the
# parameter-free layers on either side (ReLU, pooling, nn.View, nn.SoftMax)
# need no hard-coded index bookkeeping. Only the top-level entries of
# m1.modules are considered, which matches the flat nn.Sequential printed above.
torch_layers = [layer for layer in m1.modules
                if getattr(layer, 'weight', None) is not None]
pytorch_layers = [layer for layer in m2.modules()
                  if not list(layer.children()) and len(layer.state_dict()) > 0]

if len(torch_layers) != len(pytorch_layers):
    raise ValueError('parameterized layer count mismatch: torch model has %d, '
                     'pytorch model has %d' % (len(torch_layers), len(pytorch_layers)))

for src_layer, dst_layer in zip(torch_layers, pytorch_layers):
    for param_name in ('weight', 'bias'):
        src_tensor = getattr(src_layer, param_name)
        dst_param = getattr(dst_layer, param_name)
        if tuple(src_tensor.size()) != tuple(dst_param.size()):
            raise ValueError('%s shape mismatch: torch %s vs pytorch %s'
                             % (param_name, tuple(src_tensor.size()),
                                tuple(dst_param.size())))
        # copy_ duplicates the values; assigning the tensor to .data would
        # leave m2 sharing storage with m1.
        dst_param.data.copy_(src_tensor)


In [9]:
# Weight of the first fully-connected layer of the PyTorch model after the copy.
m2.classifier[0].weight

Parameter containing:
-0.0000 -0.0000 -0.0000  ...  -0.0001  0.0000  0.0000
 0.0000  0.0000  0.0000  ...   0.0000 -0.0000 -0.0000
 0.0002 -0.0001  0.0001  ...   0.0000 -0.0002  0.0169
          ...             ⋱             ...          
 0.0000 -0.0000 -0.0000  ...   0.0000 -0.0000  0.0000
-0.0000  0.0000  0.0000  ...   0.0001 -0.0000 -0.0000
 0.0000 -0.0000  0.0000  ...  -0.0000 -0.0000 -0.0000
[torch.FloatTensor of size 4096x9216]

In [10]:
# Weight of the Torch model's first Linear layer (entry 14 in the printed
# layer list); it should match the PyTorch classifier weight displayed above.
m1.modules[14].weight


-0.0000 -0.0000 -0.0000  ...  -0.0001  0.0000  0.0000
 0.0000  0.0000  0.0000  ...   0.0000 -0.0000 -0.0000
 0.0002 -0.0001  0.0001  ...   0.0000 -0.0002  0.0169
          ...             ⋱             ...          
 0.0000 -0.0000 -0.0000  ...   0.0000 -0.0000  0.0000
-0.0000  0.0000  0.0000  ...   0.0001 -0.0000 -0.0000
 0.0000 -0.0000  0.0000  ...  -0.0000 -0.0000 -0.0000
[torch.FloatTensor of size 4096x9216]

In [11]:
# Re-run both networks on the same input as before.
y1 = m1.forward(x1)
y2 = m(m2(x1_var))
# The outputs of both networks should now be the same because they hold the
# same weights (Torch on the left, PyTorch on the right).
for i in range(len(y1[0])):
    print(str(y1[0][i]) + '\t' + str(y2.data[0][i]))

# Check the agreement numerically instead of relying on reading the printed
# columns, so a faulty conversion fails loudly.
max_abs_diff = float((y1 - y2.data).abs().max())
assert max_abs_diff < 1e-5, 'converted model output differs by %g' % max_abs_diff

1.3432990044748294e-06	1.3432990044748294e-06
0.0008719050674699247	0.0008719050674699247
0.007889222353696823	0.007889222353696823
0.0004464764497242868	0.0004464764497242868
0.0003359577094670385	0.0003359577094670385
1.7650879672800879e-12	1.7650879672800879e-12
0.00433177687227726	0.00433177687227726
3.5346817139902953e-10	3.5346817139902953e-10
5.946105777020674e-17	5.946105777020674e-17
7.42116170772157e-18	7.42116170772157e-18
2.153080686184694e-06	2.153080686184694e-06
0.010166157968342304	0.010166157968342304
0.14177590608596802	0.14177590608596802
4.358494152256753e-06	4.358494152256753e-06
2.4869538736118643e-18	2.4869538736118643e-18
0.008647882379591465	0.008647882379591465
8.150720376409737e-11	8.150720376409737e-11
8.782055260780908e-07	8.782055260780908e-07
0.012028225697577	0.012028225697577
0.010035747662186623	0.010035747662186623
0.11148055642843246	0.11148055642843246
7.24380515748635e-05	7.24380515748635e-05
0.000648920948151499	0.000648920948151499
0.000415538001

In [12]:
# The Torch -> PyTorch conversion is complete; save the converted model.
# Only the state_dict (parameter tensors keyed by name) is written, not the
# ModelDef object itself.
torch.save(m2.state_dict(), '/Workspace/pytorch_model.pth.tar')