# Sample Code to Save Model Weights


we COULD always binarize the whole model object in python, but that can break easily -- here we use the torch method of saving just the *weights* and reading them back in to a different model

In [4]:
# Imports + Sample Model Architecture
import torch
import torch.nn as nn


# This Architecture NEEDS to be present to read in the model weights 
# how else will torch know where to put this Agent's Neurons?
device = "cuda" if torch.cuda.is_available() else "cpu" #nice flag for setting device (will always work even if GPU availability is shifty, ehm ehm Google Colab)

class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()
        self.flatten = nn.Flatten()
        # Actual Architecture here is irrelevant, use a real one defined downstream
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(3, 10),
            nn.ReLU(),
            nn.Linear(10, 3),
            nn.Hardswish(),
            nn.Linear(3, 2)
        )

    def forward(self, x):
        x = self.flatten(x) 
        logits = self.linear_relu_stack(x)
        return logits

# `model` is the specific instance of this architecture we will call save on (we can load weights in to different var/instance of this same architecture)
model = BaseNet().to(device) #Instantiate the model and send to our device (cpu or gpu if 'cuda')
print(model)


BaseNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=3, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=3, bias=True)
    (3): Hardswish()
    (4): Linear(in_features=3, out_features=2, bias=True)
  )
)


In [5]:
# Saving a Model

save_path = "../models/misc_save_model.pth"
torch.save(model.state_dict(), save_path) #only the weights get pickled!

In [25]:
# Loading in a Saved Model

save_path = "../models/misc_save_model.pth"
saved_model = BaseNet() #make sure model architecture is same!
model.load_state_dict(torch.load(save_path, map_location=device)) #load in trained weights & move to device if available

<All keys matched successfully>

In [26]:
# Sample Forward Pass on Random Data
test_vec = torch.rand(3)
print(test_vec.shape)
test_vec = torch.unsqueeze(test_vec, 0).to(device) #get matmul dim right & move to GPU
print(test_vec.shape)

model.forward(test_vec)

torch.Size([3])
torch.Size([1, 3])


tensor([[-0.0559,  0.0858]], grad_fn=<AddmmBackward0>)