# Save and Load Neural Nets

There are two approaches for saving and loading models:
1. saving and loading the `state_dict`
2. saving and loading the entire model

## 1. Saving and loading the `state_dict`
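Saving a model's `state_dict` with `torch.save()` stores only the trained model's learned parameters (plus registered buffers, such as BatchNorm running statistics), not the model class itself. To load it, you recreate the model from its class and copy the saved values back in with `load_state_dict()`.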
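To see what a `state_dict` actually holds, here is a minimal sketch that prints it for a small stand-in model. `tiny` is a hypothetical `nn.Sequential`, not the `BasicNet` defined later; the keys are derived from the layer indices.

```python
import torch.nn as nn

# Hypothetical stand-in model: Linear -> ReLU -> Linear
tiny = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

sd = tiny.state_dict()  # an ordered mapping from parameter name to tensor
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```

The ReLU at index 1 has no parameters, so it does not appear. The file written by `torch.save()` contains just these named tensors.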

## 2. Saving and loading an entire model
This is the most convenient approach, but also the most limited.

Saving an entire model serializes the whole module with Python's `pickle` module (pretty terrible if you ask me).

***Main disadvantage***: 
- Serialized data is bound to the specific classes and the exact directory structure used when the model is saved. Python's pickle **does not** save the model class itself. Rather, it saves a path to the file containing the class, which is used at load time. If the class is renamed, moved, or not importable when loading, the load fails.
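The disadvantage above can be reproduced with plain `pickle`, which `torch.save()` uses under the hood for the model object. This is a minimal sketch: to stay self-contained it fakes a separate module named `model_defs` (a hypothetical name) and a toy `TinyNet` class standing in for a real `nn.Module`.

```python
import pickle
import sys
import types

# Hypothetical module, registered in sys.modules so pickle can import it
defs = types.ModuleType("model_defs")
sys.modules["model_defs"] = defs


class TinyNet:
    def __init__(self):
        self.weights = [0.1, 0.2]


# Pretend TinyNet is defined in model_defs.py
TinyNet.__module__ = "model_defs"
TinyNet.__qualname__ = "TinyNet"
defs.TinyNet = TinyNet

payload = pickle.dumps(TinyNet())

# The pickle stores the import path ("model_defs" + "TinyNet"), not the class code
print(b"model_defs" in payload, b"TinyNet" in payload)  # True True

# Simulate renaming or moving the class after the model was saved
del defs.TinyNet
try:
    pickle.loads(payload)
    load_failed = False
except (AttributeError, ImportError) as err:
    load_failed = True
    print("load failed:", err)
```

The saved bytes are intact, yet loading fails because the class can no longer be found at the recorded location.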

## Roadmap
To save and load a model, we first have to define a neural net.
1. Define the neural net class
2. Initialize model
3. Set optimizer
4. Save model
5. Load model

```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in BasicNet.forward
import torch.optim as optim
```

### 1. Define the neural net class

```python
class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

### 2. Initialize model

```python
net = BasicNet()
print(net)  # shows the layer structure
```

### 3. Set optimizer

```python
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```
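The optimizer has its own `state_dict` (for SGD with momentum, this includes the momentum buffers). If you want to resume training later rather than only run inference, a common pattern is to save both state dicts together in one dictionary. This is a minimal sketch: the small `nn.Linear` model stands in for `BasicNet`, the dictionary keys are arbitrary names, and an in-memory buffer replaces a file path.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Small stand-in model and optimizer (BasicNet works the same way)
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One training step so SGD's momentum buffers exist in the optimizer state
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Bundle everything needed to resume training; key names are arbitrary
checkpoint = {
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
buffer = io.BytesIO()  # in-memory file; a path such as "checkpoint.pt" also works
torch.save(checkpoint, buffer)

# Later: rebuild the objects first, then restore their states
buffer.seek(0)
restored = torch.load(buffer, weights_only=True)
new_model = nn.Linear(4, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
new_model.load_state_dict(restored["model_state_dict"])
new_optimizer.load_state_dict(restored["optimizer_state_dict"])
start_epoch = restored["epoch"] + 1
```

Because the checkpoint holds only tensors and plain Python values, it can be loaded with `weights_only=True`.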

### 4. Save model

#### a) Save `state_dict`

```python
# Specify a path
PATH_state_dict = "state_dict_model.pt"

# Save
torch.save(net.state_dict(), PATH_state_dict)
```
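A saved `state_dict` file is just named tensors, so it can be read back and inspected without the model class; the class is only needed to turn those tensors into a working model again. A minimal sketch, assuming a temporary directory and an `nn.Linear` stand-in for `BasicNet`:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for BasicNet

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state_dict_model.pt")
    torch.save(model.state_dict(), path)
    # No model class is involved in this step: the result is a dict of tensors
    loaded = torch.load(path, weights_only=True)

print(list(loaded.keys()))  # ['weight', 'bias']
```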

#### b) Save entire model

```python
PATH_entire_model = "entire_model.pt"

# Save (pickles the whole module object)
torch.save(net, PATH_entire_model)
```

### 5. Load model

#### a) Load `state_dict`

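```python
# Load: recreate the model from its class, then copy the saved parameters into it
model = BasicNet()
model.load_state_dict(torch.load(PATH_state_dict, weights_only=True))  # load_state_dict takes a state_dict as argument
model.eval()  # switch dropout/batchnorm layers to inference mode
```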
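Calling `model.eval()` matters for layers that behave differently during training, such as dropout and batch normalization. `BasicNet` has neither, so `eval()` changes nothing for it, but calling it after loading is a good habit. A minimal sketch using a standalone `nn.Dropout` layer (a stand-in, not part of `BasicNet`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

layer.train()
train_out = layer(x)  # roughly half the entries zeroed, survivors scaled by 1 / (1 - p) = 2

layer.eval()
eval_out = layer(x)  # dropout is an identity op in eval mode

print(train_out)
print(eval_out)  # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])
```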

#### b) Load entire model

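```python
# Load: the BasicNet class must be defined (or importable) here, just as when saving.
# PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a full module,
# so weights_only=False is passed explicitly (only do this for files you trust).
model = torch.load(PATH_entire_model, weights_only=False)
model.eval()
```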
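Either approach can run into device mismatches: tensors saved from a GPU run are restored onto the GPU by default. `torch.load` accepts a `map_location` argument to remap them, for example onto the CPU. A minimal sketch, assuming an `nn.Linear` stand-in for `BasicNet` and an in-memory buffer instead of a file:

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for BasicNet
if torch.cuda.is_available():
    model = model.to("cuda")  # simulate a model trained on a GPU

buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location moves every stored tensor onto the CPU while loading
state = torch.load(buffer, map_location="cpu", weights_only=True)
cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state)
print({k: v.device.type for k, v in state.items()})  # {'weight': 'cpu', 'bias': 'cpu'}
```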