In [1]:
import warnings
# Suppress every warning for the rest of the session: no category, module or
# message filter is given, so the 'ignore' action applies to all of them.
warnings.filterwarnings('ignore')

In [2]:
import torch
from torch import nn, optim

In [3]:
# Define model
class ConvModel(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The first linear layer is sized for 16 feature maps of 5x5. With 5x5
    kernels (default stride 1, no padding) and 2x2 max-pooling this corresponds
    to a single-channel 32x32 input: 32 -> 28 -> 14 -> 10 -> 5.
    The network produces 10 output values per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(2, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: input batch of shape (N, 1, 32, 32).

        Returns:
            Tensor of shape (N, 10) with the raw (un-normalised) outputs of fc3.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not reduce to 5x5
                after the two conv/pool stages (shape mismatch at ``fc1``).
        """
        # `F` was never imported in this notebook; go through `nn.functional`,
        # which is reachable from the existing `from torch import nn`.
        relu = nn.functional.relu

        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # a wrong input size fails at fc1 instead of silently changing the
        # batch dimension.
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network
model = ConvModel()

# Build the optimizer: SGD with momentum over all parameters of the model
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [4]:
# Print every entry of the model's state_dict together with its tensor shape.
print("Model's state_dict:")
# Build the state_dict once and walk its (name, tensor) pairs instead of
# re-creating the whole dictionary for every key.
for param_name, param_tensor in model.state_dict().items():
    print(param_name, "\t", param_tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


# Save the whole serialised model (not recommended)

In [5]:
# Pickle the entire model object (serialised) to 'model.pt'
torch.save(obj=model, f='model.pt')

In [6]:
# 'model.pt' holds a fully pickled nn.Module, not just tensors. Since
# PyTorch 2.6 torch.load defaults to weights_only=True, which rejects such
# files, so opt out explicitly.
# NOTE: full unpickling can execute arbitrary code - only load files you trust.
model_restored = torch.load('model.pt', weights_only=False)

In [7]:
# Print every entry of the restored model's state_dict with its tensor shape.
print("Model's state_dict:")
# Build the state_dict once and walk its (name, tensor) pairs instead of
# re-creating the whole dictionary for every key.
for param_name, param_tensor in model_restored.state_dict().items():
    print(param_name, "\t", param_tensor.shape)

Model's state_dict:
conv1.weight 	 torch.Size([2, 1, 5, 5])
conv1.bias 	 torch.Size([2])
conv2.weight 	 torch.Size([16, 2, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


# Save parameters of the model (recommended way)

In [8]:
# Persist only the parameters (the state_dict), not the model object itself
torch.save(obj=model.state_dict(), f='model_params.pt')

In [9]:
# Create a second ConvModel instance with the same architecture; it starts with
# its own freshly initialised weights and will receive the saved parameters.
model_2 = ConvModel()

In [10]:
# First parameter of model_2 (conv1.weight) before the saved weights are loaded
next(iter(model_2.parameters()))

Parameter containing:
tensor([[[[-0.0250, -0.1927, -0.1721,  0.1006, -0.1552],
          [ 0.1121,  0.0398,  0.0953, -0.0219,  0.0270],
          [-0.1435,  0.0705,  0.0162,  0.1826,  0.0134],
          [-0.0992,  0.1626,  0.0632, -0.1126, -0.1139],
          [-0.1561,  0.1838,  0.0267, -0.1809, -0.1873]]],


        [[[-0.1042, -0.0234,  0.1420,  0.1122,  0.1865],
          [ 0.0618,  0.0830, -0.0818, -0.0553,  0.0511],
          [ 0.0754, -0.1321, -0.0070, -0.0998, -0.1283],
          [ 0.0968,  0.1002,  0.1660,  0.0684,  0.0074],
          [ 0.0332, -0.0946,  0.1185,  0.0040,  0.0939]]]], requires_grad=True)

In [11]:
# Copy the saved parameters into model_2. The file contains only a state_dict,
# so restrict unpickling to tensors and plain containers (weights_only=True);
# this is already the default from PyTorch 2.6 and is made explicit here so
# earlier releases that support the flag behave the same way.
model_2.load_state_dict(torch.load('model_params.pt', weights_only=True))

<All keys matched successfully>

In [12]:
# First parameter of model_2 (conv1.weight) after loading the saved weights
next(iter(model_2.parameters()))

Parameter containing:
tensor([[[[ 0.0676, -0.1965,  0.0124,  0.1616, -0.0505],
          [-0.1171, -0.1013,  0.0044, -0.1624, -0.0323],
          [-0.1984,  0.1492, -0.1321,  0.0816, -0.0149],
          [ 0.1143,  0.0813,  0.1371,  0.0167, -0.1863],
          [-0.0123,  0.1452,  0.1941, -0.0260, -0.1003]]],


        [[[ 0.1184, -0.0809,  0.0626, -0.1644, -0.1212],
          [-0.1548, -0.1173, -0.1422,  0.1562, -0.1439],
          [ 0.0282, -0.1586,  0.1682, -0.1287,  0.0714],
          [ 0.0440,  0.1154,  0.1684,  0.0696,  0.1641],
          [ 0.0574, -0.1678,  0.0597, -0.0298,  0.0234]]]], requires_grad=True)

In [13]:
# First parameter of the original model (conv1.weight), for comparison with model_2
next(iter(model.parameters()))

Parameter containing:
tensor([[[[ 0.0676, -0.1965,  0.0124,  0.1616, -0.0505],
          [-0.1171, -0.1013,  0.0044, -0.1624, -0.0323],
          [-0.1984,  0.1492, -0.1321,  0.0816, -0.0149],
          [ 0.1143,  0.0813,  0.1371,  0.0167, -0.1863],
          [-0.0123,  0.1452,  0.1941, -0.0260, -0.1003]]],


        [[[ 0.1184, -0.0809,  0.0626, -0.1644, -0.1212],
          [-0.1548, -0.1173, -0.1422,  0.1562, -0.1439],
          [ 0.0282, -0.1586,  0.1682, -0.1287,  0.0714],
          [ 0.0440,  0.1154,  0.1684,  0.0696,  0.1641],
          [ 0.0574, -0.1678,  0.0597, -0.0298,  0.0234]]]], requires_grad=True)