In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn.functional as F
from torch import nn, optim

In [3]:
# Define model
class ConvModel(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three fully connected layers.

    The input has 1 channel (``conv1`` is ``Conv2d(1, ...)``). The spatial size must
    shrink to 5x5 after the two stages so that ``fc1`` receives 16 * 5 * 5 = 400
    features. A 32x32 image does: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5.
    The output is 10 raw scores per sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 5)    # 1 -> 2 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)     # 2x2 window, stride 2 (halves H and W); reused twice
        self.conv2 = nn.Conv2d(2, 16, 5)   # 2 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 output scores for ``x`` of shape (N, 1, H, W).

        Raises:
            RuntimeError: if H and W do not reduce to 5x5, because the flattened
                feature size then does not match ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten only the channel/height/width dims. Unlike x.view(-1, 400), this
        # keeps the batch dimension intact. A mis-sized input therefore fails loudly
        # in fc1 instead of being silently regrouped into a different number of rows.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Seed PyTorch's global RNG so the random weight initialisation below is
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Initialize model
model = ConvModel()

# Initialize optimizer: SGD with momentum over all of the model's parameters
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [4]:
# Print the name and shape of every entry in the model's state_dict.
# The state_dict is built once here, rather than once per loop iteration.
print("Model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


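# Save the whole model, serialised (not recommended)

`torch.save(model, ...)` pickles the entire module object. The file is therefore bound to the exact class definition and module path used when saving. Loading it also requires full unpickling, which can run arbitrary code. Saving only the `state_dict` (below) avoids both problems.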

In [5]:
# Save the whole model object (serialised with pickle), not just its parameters
torch.save(obj=model, f='model.pt')

In [6]:
# A whole-module checkpoint must be fully unpickled, so opt out of the
# weights_only loader explicitly. That loader is the default from PyTorch 2.6 on
# and rejects nn.Module objects. Full unpickling can execute arbitrary code,
# so only do this for files you trust. ConvModel must also be defined in this session.
model_restored = torch.load('model.pt', weights_only=False)

In [7]:
# Print the restored model's state_dict entries (name and shape).
# The state_dict is built once here, rather than once per loop iteration.
print("Model's state_dict:")
for param_tensor, tensor in model_restored.state_dict().items():
    print(param_tensor, "\t", tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


# Save only the model's parameters (`state_dict`) — the recommended way

In [8]:
# Persist only the model's state_dict (its named tensors), not the pickled module object
torch.save(f='model_params.pt', obj=model.state_dict())

In [9]:
# Create a second, freshly (randomly) initialised instance of the same architecture.
# Its tensors differ from `model`'s until the saved state_dict is loaded into it.
model_2 = ConvModel()

In [10]:
# First registered parameter (conv1.weight) of the fresh model, before loading
next(model_2.parameters())

Parameter containing:
tensor([[[[ 0.1238, -0.1151,  0.1551, -0.0914, -0.1724],
          [ 0.0418,  0.1676, -0.1460, -0.1062, -0.0197],
          [ 0.1974,  0.1027,  0.0787,  0.0060, -0.0075],
          [-0.0379, -0.1745, -0.1925,  0.1519, -0.0093],
          [ 0.1419, -0.0785, -0.0469, -0.0459, -0.1283]]],


        [[[-0.1950,  0.0287,  0.1162, -0.1438, -0.1971],
          [-0.1488, -0.1836,  0.1935, -0.1578,  0.1431],
          [-0.1527, -0.0799, -0.0696,  0.1878, -0.1473],
          [ 0.0102,  0.0118,  0.0059,  0.1160,  0.1397],
          [ 0.0977, -0.1775,  0.1203,  0.0436, -0.1443]]]], requires_grad=True)

In [11]:
# The file holds only a state_dict of tensors, so restrict unpickling to
# tensors and primitive containers (weights_only=True). This is safer than a full pickle load.
model_2.load_state_dict(torch.load('model_params.pt', weights_only=True))

<All keys matched successfully>

In [12]:
# The same parameter after loading: it now matches model's conv1.weight
model_2.conv1.weight

Parameter containing:
tensor([[[[ 0.1946, -0.1003,  0.0519, -0.1749,  0.1749],
          [-0.0279, -0.0186, -0.0408, -0.1266,  0.1106],
          [-0.0316,  0.1874,  0.1554,  0.0999, -0.1427],
          [-0.1352, -0.0623,  0.1385,  0.0913,  0.0545],
          [ 0.0167, -0.0181,  0.1557, -0.0579, -0.0235]]],


        [[[ 0.0760, -0.0011, -0.1318,  0.1222,  0.1512],
          [ 0.1367, -0.0802, -0.0815, -0.1521, -0.1479],
          [-0.1535, -0.0156,  0.1213,  0.0381,  0.0718],
          [ 0.0774, -0.0844,  0.0078,  0.1703,  0.0775],
          [ 0.0774, -0.0710, -0.0067,  0.1794, -0.1629]]]], requires_grad=True)

In [13]:
# The printed values show only one tensor, to 4 decimals. Also verify exactly
# that every tensor in model_2 matches the original model after loading.
restored_state = model_2.state_dict()
assert all(
    torch.equal(tensor, restored_state[name])
    for name, tensor in model.state_dict().items()
), "model_2 does not match model after load_state_dict"
next(model.parameters())

Parameter containing:
tensor([[[[ 0.1946, -0.1003,  0.0519, -0.1749,  0.1749],
          [-0.0279, -0.0186, -0.0408, -0.1266,  0.1106],
          [-0.0316,  0.1874,  0.1554,  0.0999, -0.1427],
          [-0.1352, -0.0623,  0.1385,  0.0913,  0.0545],
          [ 0.0167, -0.0181,  0.1557, -0.0579, -0.0235]]],


        [[[ 0.0760, -0.0011, -0.1318,  0.1222,  0.1512],
          [ 0.1367, -0.0802, -0.0815, -0.1521, -0.1479],
          [-0.1535, -0.0156,  0.1213,  0.0381,  0.0718],
          [ 0.0774, -0.0844,  0.0078,  0.1703,  0.0775],
          [ 0.0774, -0.0710, -0.0067,  0.1794, -0.1629]]]], requires_grad=True)