In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

In [2]:
# Load the original Lua Torch network and print its layer list. The printed
# indices, such as (14) for the first nn.Linear, are how later cells refer to
# individual Lua layers.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load model files from a trusted source.
lua_model_path = 'torch_model.net'
m1 = torch.load(lua_model_path)
print(m1)

nn.Sequential {
  [input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> output]
  (0): nn.SpatialConvolution(3 -> 64, 11x11, 4, 4, 2, 2)
  (1): nn.ReLU
  (2): nn.SpatialMaxPooling(3x3, 2, 2)
  (3): nn.SpatialConvolution(64 -> 192, 5x5, 1, 1, 2, 2)
  (4): nn.ReLU
  (5): nn.SpatialMaxPooling(3x3, 2, 2)
  (6): nn.SpatialConvolution(192 -> 384, 3x3, 1, 1, 1, 1)
  (7): nn.ReLU
  (8): nn.SpatialConvolution(384 -> 256, 3x3, 1, 1, 1, 1)
  (9): nn.ReLU
  (10): nn.SpatialConvolution(256 -> 256, 3x3, 1, 1, 1, 1)
  (11): nn.ReLU
  (12): nn.SpatialMaxPooling(3x3, 2, 2)
  (13): nn.View(1, 9216)
  (14): nn.Linear(9216 -> 4096)
  (15): nn.ReLU
  (16): nn.Linear(4096 -> 4096)
  (17): nn.ReLU
  (18): nn.Linear(4096 -> 46)
  (19): nn.SoftMax
}


In [3]:
class ModelDef(nn.Module):
    """PyTorch re-implementation of the Lua Torch network printed above.

    Mirrors the Lua ``nn.Sequential`` layer for layer: a conv/ReLU/max-pool
    feature extractor followed by three fully connected layers. Because the
    order matches, the Lua weights can be copied across in sequence.

    The final Lua ``nn.SoftMax`` is intentionally not included. ``forward``
    returns unnormalized class scores (logits). To reproduce the Lua model's
    outputs, apply ``torch.nn.functional.softmax(out, dim=1)``.

    Args:
        num_classes: output size of the last Linear layer (46 in the Lua model).
    """

    def __init__(self, num_classes=46):
        super(ModelDef, self).__init__()
        # Same hyper-parameters as the Lua SpatialConvolution /
        # SpatialMaxPooling layers (kernel, stride, padding).
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # 9216 = 256 channels * 6 * 6 spatial positions (for e.g. 224x224 input).
        self.classifier = nn.Sequential(
            nn.Linear(9216, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: float tensor of shape (N, 3, H, W). The feature map must flatten
               to 9216 values per sample, which holds for e.g. 224x224 inputs.

        Returns:
            Tensor of shape (N, num_classes) with unnormalized class scores.
        """
        x = self.features(x)
        # Flatten (N, 256, 6, 6) -> (N, 9216). This is the role nn.View plays
        # in the Lua model. Without it, the first Linear is applied to the last
        # spatial dimension and fails.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


In [4]:
# Build the PyTorch counterpart. 46 classes matches the Lua model's final
# nn.Linear(4096 -> 46).
m2 = ModelDef(num_classes=46)

In [5]:
# Copy the weights from the Lua Torch model (m1) into the PyTorch model (m2).
# The parameterized layers (convolutions and linears) appear in the same order
# in both networks, so they are paired directly. This avoids a hand-maintained
# index that has to skip nn.View by its position.
# Shapes are checked explicitly, and values are copied into m2's own storage
# instead of aliasing m1's tensors via `.data = ...`.
lua_layers = [layer for layer in m1.modules
              if getattr(layer, 'weight', None) is not None]
torch_layers = [layer for layer in m2.modules()
                if isinstance(layer, (nn.Conv2d, nn.Linear))]
assert len(lua_layers) == len(torch_layers), \
    'layer count mismatch: %d Lua vs %d PyTorch' % (len(lua_layers), len(torch_layers))

for lua_layer, torch_layer in zip(lua_layers, torch_layers):
    for param_name in ('weight', 'bias'):
        src = getattr(lua_layer, param_name)
        dst = getattr(torch_layer, param_name).data
        assert tuple(src.size()) == tuple(dst.size()), \
            '%s shape mismatch for %s: Lua %s vs PyTorch %s' % (
                param_name, torch_layer, tuple(src.size()), tuple(dst.size()))
        dst.copy_(src)


In [6]:
# First Linear layer (9216 -> 4096) of the PyTorch classifier after the copy.
# Accessed via the public submodule attribute rather than the private _modules dict.
m2.classifier[0].weight

Parameter containing:
-0.0000 -0.0000 -0.0000  ...  -0.0001  0.0000  0.0000
 0.0000  0.0000  0.0000  ...   0.0000 -0.0000 -0.0000
 0.0002 -0.0001  0.0001  ...   0.0000 -0.0002  0.0169
          ...             ⋱             ...          
 0.0000 -0.0000 -0.0000  ...   0.0000 -0.0000  0.0000
-0.0000  0.0000  0.0000  ...   0.0001 -0.0000 -0.0000
 0.0000 -0.0000  0.0000  ...  -0.0000 -0.0000 -0.0000
[torch.FloatTensor of size 4096x9216]

In [7]:
# Lua layer (14) is the first nn.Linear. After the copy it must be identical to
# the PyTorch classifier's first Linear, so check this explicitly instead of by eye.
assert torch.equal(m1.modules[14].weight, m2.classifier[0].weight.data), \
    'Lua and PyTorch first Linear weights differ'
m1.modules[14].weight


-0.0000 -0.0000 -0.0000  ...  -0.0001  0.0000  0.0000
 0.0000  0.0000  0.0000  ...   0.0000 -0.0000 -0.0000
 0.0002 -0.0001  0.0001  ...   0.0000 -0.0002  0.0169
          ...             ⋱             ...          
 0.0000 -0.0000 -0.0000  ...   0.0000 -0.0000  0.0000
-0.0000  0.0000  0.0000  ...   0.0001 -0.0000 -0.0000
 0.0000 -0.0000  0.0000  ...  -0.0000 -0.0000 -0.0000
[torch.FloatTensor of size 4096x9216]

In [None]:
# Save only the parameters (state_dict), not the pickled module object.
# Reload with: ModelDef().load_state_dict(torch.load('pytorch_model.pth.tar'))
torch.save(obj=m2.state_dict(), f='pytorch_model.pth.tar')