## Saving and Loading Models

```python
import torch
import torchvision.models as models
```

#### Saving the Weights of a Model to Disk

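```python
# Pretrained ResNet-50 (torchvision >= 0.13; older versions used pretrained=True).
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Save only the parameters and buffers (the state_dict), not the model object.
torch.save(resnet50.state_dict(), 'resnet50_weights.pth')
```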

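The saved state_dict is an ordered dictionary that maps each parameter and buffer name to a tensor. The sketch below uses a small stand-in model (a hypothetical architecture, chosen so nothing has to be downloaded) to show what ends up in the file. Note that BatchNorm's running statistics are stored alongside the learnable weights.

```python
import torch.nn as nn

# Small stand-in for resnet50 (hypothetical architecture, no download needed).
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

sd = model.state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 3, 3, 3)
# 0.bias (8,)
# 1.weight (8,)
# 1.bias (8,)
# 1.running_mean (8,)
# 1.running_var (8,)
# 1.num_batches_tracked ()
```

ReLU has no parameters or buffers, so it contributes no entries.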


#### Loading the Weights from a .pth File into a Model

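```python
# Build an untrained instance of the same architecture to receive the weights.
my_resnet = models.resnet50()
my_resnet.load_state_dict(torch.load('resnet50_weights.pth'))
# Switch to inference mode (BatchNorm uses running stats, Dropout is disabled).
# This does not freeze the weights: their requires_grad flags are unchanged.
my_resnet.eval()
```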

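A minimal round trip, again with a small hypothetical model in place of resnet50, illustrates two points from the cell above: the receiving model must have the same architecture (so the state_dict keys match), and `eval()` only changes layer behaviour without freezing anything. The `map_location="cpu"` argument is an optional extra that lets weights saved on a GPU load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small architecture standing in for resnet50.
    return nn.Sequential(nn.Linear(4, 3), nn.Dropout(p=0.5), nn.Linear(3, 2))


src = make_model()
path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(src.state_dict(), path)

# Rebuild the same architecture, then copy the saved tensors into it.
dst = make_model()
# map_location="cpu" remaps tensors that were saved on a GPU onto the CPU.
result = dst.load_state_dict(torch.load(path, map_location="cpu"))
print(result.missing_keys, result.unexpected_keys)  # [] [] when the keys match

# The loaded tensors are identical to the saved ones.
same = all(
    torch.equal(a, b)
    for a, b in zip(src.state_dict().values(), dst.state_dict().values())
)

# eval() switches Dropout/BatchNorm to inference behaviour only.
dst.eval()
print(dst.training)                                    # False
print(all(p.requires_grad for p in dst.parameters()))  # True: still trainable

# To actually freeze the weights, turn off gradient tracking explicitly.
for p in dst.parameters():
    p.requires_grad_(False)
```

With the default `strict=True`, a mismatched architecture makes `load_state_dict` raise an error instead of silently loading a partial set of weights.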

#### Alternatively, save the whole model object along with the weights

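```python
# Pickles the entire model object; its class must be importable when loading.
torch.save(my_resnet, 'my_resnet.pth')
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full model objects.
arbitrary_model = torch.load('my_resnet.pth', weights_only=False)
```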
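The same idea can be checked with a model built only from `torch.nn` classes (a stand-in chosen so the example is self-contained). `torch.save` pickles the module and records its classes by import path rather than as source code, and loading it back on recent PyTorch requires opting out of `weights_only`.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
path = os.path.join(tempfile.mkdtemp(), "full_model.pth")

# Pickles the whole module object. Classes are stored by import path
# (here torch.nn.Sequential, torch.nn.Linear, ...), not as source code.
torch.save(model, path)

# weights_only=False permits arbitrary pickled objects: use it only for trusted files.
restored = torch.load(path, weights_only=False)

x = torch.randn(5, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(type(restored).__name__, same_output)  # Sequential True
```

For a custom `nn.Module` subclass, the module that defines it must be importable under the same name wherever the file is loaded, which is why state_dict files are usually the more portable choice.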

### Recap
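* Use `torch.save(model_obj, 'filename.pth')` to save the whole model object (architecture and weights). The file stores a reference to the model's class, so that class must be importable when loading.
* Use `torch.save(model_obj.state_dict(), 'file.pth')` to save only the weights.
* `torch.load('filename.pth')` returns whatever object was saved in the .pth file.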

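If the file contains the whole model, `model = torch.load('filename.pth', weights_only=False)` recreates the model with its saved architecture and weights. Pass `weights_only=False` only for files you trust, since unpickling can execute arbitrary code.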

If the file contains only a state_dict, first create a model of the same class yourself, then load the weights with `model.load_state_dict(torch.load('filename.pth'))`.
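Both recap cases can be handled by one function that inspects what `torch.load` returned. `load_model` and `make_model` below are hypothetical names for illustration, not PyTorch APIs, and the sketch assumes the files are trusted (it passes `weights_only=False`).

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_model(path, make_model):
    """Hypothetical helper: return a model from either kind of .pth file.

    make_model is a zero-argument callable that builds the architecture;
    it is only used when the file holds a bare state_dict.
    """
    # weights_only=False is needed for full model objects; trusted files only.
    obj = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(obj, nn.Module):
        return obj  # whole model: architecture and weights came from the file
    model = make_model()  # state_dict only: rebuild the architecture first
    model.load_state_dict(obj)
    return model


def make_model():
    # Hypothetical small architecture used for the demonstration.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


original = make_model()
tmp = tempfile.mkdtemp()
full_path = os.path.join(tmp, "full.pth")
weights_path = os.path.join(tmp, "weights.pth")
torch.save(original, full_path)                   # architecture + weights
torch.save(original.state_dict(), weights_path)   # weights only

x = torch.randn(2, 4)
with torch.no_grad():
    expected = original(x)
    out_full = load_model(full_path, make_model)(x)
    out_weights = load_model(weights_path, make_model)(x)
print(torch.equal(expected, out_full), torch.equal(expected, out_weights))  # True True
```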