The learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model's parameters, accessed with `model.parameters()`.

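A `state_dict` is a Python dictionary that maps each parameter name (for example `conv1.weight`) to its tensor. Only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batch norm's `running_mean`) have entries in the model's `state_dict`.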

Because `state_dict` objects are Python dictionaries, they can be easily saved, updated, altered, and restored.

In [1]:
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class TheModelClass(nn.Module):
    """Small CNN used to demonstrate saving and loading a ``state_dict``.

    Architecture: two 5x5 convolutions (3 -> 6 and 6 -> 16 channels), each
    followed by ReLU and 2x2 max-pooling, then three fully connected layers
    (400 -> 120 -> 84 -> 10).

    The flatten size ``16 * 5 * 5`` corresponds to 3-channel 32x32 inputs:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return 10 outputs per sample for an input of shape (N, 3, 32, 32).

        An unbatched (3, 32, 32) input (where the installed PyTorch's Conv2d
        accepts one) yields a (1, 10) result.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each sample separately. ``x.view(-1, 400)`` would silently
        # reinterpret a wrongly sized feature map as extra batch rows; keeping
        # the batch dimension makes fc1 reject such inputs instead.
        if x.dim() == 3:
            # Unbatched conv output (C, H, W): produce a single row.
            x = x.reshape(1, -1)
        else:
            x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
model = TheModelClass()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Build each state_dict once. Calling model.state_dict() inside the loop
# would rebuild the whole dictionary on every iteration.
model_state = model.state_dict()
optimizer_state = optimizer.state_dict()

print('model state_dict: ')
for param_tensor, value in model_state.items():
    print(param_tensor, "\t", value)

print('optimizer state_dict: ')
for var_name, value in optimizer_state.items():
    print(var_name, "\t", value)
print('--'*30)
print(model_state)

model state_dict: 
conv1.weight 	 tensor([[[[ 0.0438, -0.0183, -0.0235,  0.0087, -0.0452],
          [-0.0382, -0.0011, -0.0887,  0.0852, -0.0593],
          [ 0.0992,  0.1070, -0.0183, -0.0653,  0.0153],
          [ 0.0768, -0.0683, -0.0779,  0.1074,  0.0938],
          [ 0.0995, -0.0449,  0.0872, -0.0703, -0.0776]],

         [[ 0.0970, -0.1146, -0.0131,  0.0087,  0.0698],
          [ 0.0237,  0.0810, -0.0786,  0.0112, -0.0811],
          [-0.0137,  0.0680, -0.0106, -0.0592, -0.0131],
          [ 0.0390,  0.0849, -0.0602, -0.0058,  0.0092],
          [-0.0253,  0.0891, -0.0086, -0.0714,  0.1044]],

         [[ 0.0129, -0.1045,  0.0937,  0.0702,  0.0146],
          [-0.0943, -0.0690,  0.0924, -0.0877, -0.0782],
          [ 0.0375, -0.1132,  0.0034,  0.1036,  0.0392],
          [-0.1100,  0.0281, -0.0812, -0.0073,  0.1020],
          [-0.0978, -0.0880, -0.0282,  0.0379, -0.0983]]],


        [[[-0.1136, -0.0506, -0.1049, -0.0512,  0.0679],
          [ 0.0542,  0.0934, -0.0596, -0.0224,

Save & load a model via its `state_dict`

In [10]:
# Persist only the learned tensors (the state_dict); .pt and .pth are the
# conventional file extensions.
torch.save(obj=model.state_dict(), f='saved_model.pth')

# Recreate the architecture, then copy the saved tensors into it.
model = TheModelClass()
model.load_state_dict(state_dict=torch.load(f='saved_model.pth'))
# Switch dropout / batch-normalization layers to evaluation behaviour.
model.eval()

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

Save & load a general checkpoint (for inference and/or resuming training)

In [None]:
# Template cell: `epoch`, `loss`, `args`, `kwargs` and `TheOptimizerClass`
# are not defined in the cells above -- substitute your own training-loop
# values and optimizer class before running it.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... add any other items needed to resume training
}, 'checkpoint.tar')   # use the .tar file extension

# load: rebuild the model and optimizer first, then restore their states.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load('checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# ... restore any other saved items here

# Call only ONE of these: a later mode call overrides an earlier one.
# model.eval()   # for inference
model.train()    # to resume training


Save multiple models in one file

In [None]:
# Template cell: model1/model2, optimizer1/optimizer2, TheModel1Class,
# TheModel2Class, TheOptimizer1Class, TheOptimizer2Class, args and kwargs
# are placeholders that are not defined in the cells above.
torch.save({
    'model1_state_dict': model1.state_dict(),
    'model2_state_dict': model2.state_dict(),
    'optimizer1_state_dict': optimizer1.state_dict(),
    'optimizer2_state_dict': optimizer2.state_dict(),
    # ... add any other items
}, 'multiple_models.tar')  # .tar file

model1 = TheModel1Class()
model2 = TheModel2Class()
optimizer1 = TheOptimizer1Class(*args, **kwargs)
optimizer2 = TheOptimizer2Class(*args, **kwargs)

# The file name must match the one passed to torch.save above.
checkpoint = torch.load('multiple_models.tar')

model1.load_state_dict(checkpoint['model1_state_dict'])
model2.load_state_dict(checkpoint['model2_state_dict'])
optimizer1.load_state_dict(checkpoint['optimizer1_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer2_state_dict'])

# Call only ONE pair: a later mode call overrides an earlier one.
# model1.eval()   # for inference
# model2.eval()
model1.train()    # to resume training
model2.train()


Using parameters from a different model (warm-starting)

You can set the `strict` argument of `load_state_dict()` to `False` to ignore non-matching keys (keys missing from the checkpoint or not present in the model).

In [None]:
# Persist the source model's parameters.
torch.save(obj=modelA.state_dict(), f=PATH)

# With strict=False, keys missing from or unexpected in the checkpoint are
# skipped rather than raising. The returned value lists them; tensors whose
# shapes differ still raise an error.
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(state_dict=torch.load(f=PATH), strict=False)

Save & load a model across devices

In [None]:
# Save on GPU, load on CPU.
torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
# map_location remaps the saved tensors onto the CPU while loading.
# load_state_dict is a method that fills the model in place; its return value
# is a report of mismatched keys, not the model, so it is not assigned back.
model.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Save on GPU, load on GPU.
torch.save(obj=model.state_dict(), f=PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(state_dict=torch.load(f=PATH))
# Move the parameters to the GPU; input tensors fed to the model must be moved
# as well, e.g. input = input.to(device).
model.to(device)

In [None]:
# Save on CPU, load on GPU.
torch.save(obj=model.state_dict(), f=PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# map_location places the loaded tensors on the chosen GPU ("cuda:0" here;
# pick whichever device index you need).
model.load_state_dict(state_dict=torch.load(f=PATH, map_location="cuda:0"))
model.to(device)

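When loading a model on a GPU that was trained and saved on the CPU, set the `map_location` argument of `torch.load()` to `cuda:device_id` so the tensors are loaded onto that GPU. Then call `model.to(torch.device('cuda'))` so the model's parameters become CUDA tensors.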

