In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
from torch import nn, optim

In [3]:
# Define model
class ConvModel(nn.Module):
    """Small convolutional classifier: two conv + max-pool stages, then three linear layers.

    The first linear layer expects 16 * 5 * 5 = 400 features per sample. With the
    5x5 convolutions and 2x2 pooling used here, that is what a single-channel
    32x32 input produces (32 -> 28 -> 14 -> 10 -> 5). The network outputs 10 raw
    scores per sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 5)     # 1 input channel -> 2 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)      # shared by both conv stages; has no parameters
        self.conv2 = nn.Conv2d(2, 16, 5)    # 2 -> 16 feature maps, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with 10 scores per sample."""
        # Reach ReLU through the already-imported `nn` package: the notebook never
        # imports `torch.nn.functional as F`, so `F.relu` would raise NameError.
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        # Flatten each sample's feature maps into one vector for the linear layers
        x = x.view(-1, 16 * 5 * 5)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, then an SGD optimizer (with momentum) over its parameters
model = ConvModel()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [4]:
# List every state_dict entry together with the shape of its tensor
print("Model's state_dict:")
for entry_name, entry_tensor in model.state_dict().items():
    print(entry_name, "\t", entry_tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


# Save the whole model, serialised (not recommended)

In [5]:
# Serialise the entire model object to disk
torch.save(obj=model, f='model.pt')

In [6]:
# 'model.pt' holds a pickled nn.Module rather than a plain state_dict, so
# weights-only loading (the default from PyTorch 2.6 onward) must be switched off.
# NOTE: full unpickling can execute arbitrary code - only load files you trust.
model_restored = torch.load('model.pt', weights_only=False)

In [7]:
# List every state_dict entry of the restored model with the shape of its tensor
print("Model's state_dict:")
for entry_name, entry_tensor in model_restored.state_dict().items():
    print(entry_name, "\t", entry_tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


# Save parameters of the model (recommended way)

In [8]:
# Write only the model's state_dict (its parameter tensors) to disk
torch.save(obj=model.state_dict(), f='model_params.pt')

In [9]:
# Create a second, independent instance of the same architecture. Its parameters
# are freshly initialised, so they differ from `model`'s until the saved
# state_dict is loaded into it.
model_2 = ConvModel()

In [10]:
# First parameter tensor of model_2 (conv1.weight), before the saved weights are loaded
next(iter(model_2.parameters()))

Parameter containing:
tensor([[[[-0.0444,  0.1570, -0.0186, -0.0965,  0.1197],
          [ 0.1386, -0.1475, -0.1814, -0.1548, -0.0949],
          [-0.0817,  0.1686, -0.1831,  0.0099,  0.1208],
          [-0.1084,  0.0096, -0.0999, -0.0679,  0.1540],
          [-0.0448, -0.0826,  0.1859,  0.0273, -0.0661]]],


        [[[ 0.1576,  0.1604, -0.0548, -0.1326, -0.0544],
          [ 0.0026, -0.0719,  0.1970,  0.1336,  0.1556],
          [ 0.0206, -0.0461,  0.0766, -0.0481, -0.0120],
          [-0.1711,  0.1747, -0.0507,  0.0919,  0.0737],
          [-0.0929,  0.1062, -0.0069,  0.1425, -0.0714]]]], requires_grad=True)

In [11]:
# Read the saved parameters back from disk, then copy them into model_2
saved_params = torch.load('model_params.pt')
model_2.load_state_dict(saved_params)

<All keys matched successfully>

In [12]:
# First parameter tensor of model_2 (conv1.weight), after the saved weights were loaded
next(iter(model_2.parameters()))

Parameter containing:
tensor([[[[-0.0530, -0.0431, -0.1500,  0.0498, -0.1820],
          [-0.1012,  0.1697, -0.0537,  0.0496, -0.1558],
          [-0.0728, -0.1747, -0.1867,  0.0351,  0.0357],
          [-0.1214,  0.0509,  0.1317, -0.0915,  0.1183],
          [ 0.1287,  0.1430,  0.0432, -0.1135, -0.1910]]],


        [[[-0.0130,  0.1356,  0.0087, -0.1009, -0.1176],
          [-0.0286, -0.1337, -0.0197,  0.1211, -0.0805],
          [ 0.1709, -0.1395,  0.1944, -0.0823, -0.1315],
          [ 0.1996, -0.0503,  0.0823,  0.0854, -0.1392],
          [-0.1482,  0.0882, -0.1148,  0.1722,  0.1766]]]], requires_grad=True)

In [13]:
# First parameter tensor of the original model (conv1.weight), for comparison with model_2
next(iter(model.parameters()))

Parameter containing:
tensor([[[[-0.0530, -0.0431, -0.1500,  0.0498, -0.1820],
          [-0.1012,  0.1697, -0.0537,  0.0496, -0.1558],
          [-0.0728, -0.1747, -0.1867,  0.0351,  0.0357],
          [-0.1214,  0.0509,  0.1317, -0.0915,  0.1183],
          [ 0.1287,  0.1430,  0.0432, -0.1135, -0.1910]]],


        [[[-0.0130,  0.1356,  0.0087, -0.1009, -0.1176],
          [-0.0286, -0.1337, -0.0197,  0.1211, -0.0805],
          [ 0.1709, -0.1395,  0.1944, -0.0823, -0.1315],
          [ 0.1996, -0.0503,  0.0823,  0.0854, -0.1392],
          [-0.1482,  0.0882, -0.1148,  0.1722,  0.1766]]]], requires_grad=True)