# Test model reload
## Author: G. Erlebacher
We will perform a very simple experiment: 
- a single layer with one input and one output, with two parameters to train. 
- Initialize the model and save the initial state, `model0`
- Run the `model0` for a single epoch and save it to `model1`
- Run `model0` for two epochs and save the results to `model2a`
- load `model1` and run `model1` for a single epoch. Save this to `model2b`
- Compare `model2a` and `model2b`. They should be identical.

In [1]:
import torch
import numpy as np

In [4]:
x = np.random.rand(10)
ylab = .3 * x + .1 + np.random.rand(10) * .02

# reshape for batch size with all the points
x = torch.as_tensor(x, dtype=torch.float32).view(-1, 1)
ylab = torch.as_tensor(ylab, dtype=torch.float32).view(-1, 1)

In [6]:
model = torch.nn.Linear(1, 1, dtype=torch.float32)
loss_fct = torch.nn.MSELoss(reduce='mean')
opt = torch.optim.SGD(model.parameters(), lr=1e-3)



In [7]:
def save_model(model, opt):
    file_name = "linear_model.pth"
    torch.save({
        "model": model, 
        "opt": opt
    }, file_name)

def load_model():
    file_name = "linear_model.pth"
    checkpoint = torch.load(file_name)
    return checkpoint["model"], checkpoint["opt"]

In [37]:
def print_weight_norms(model, msg):
    norm = torch.sqrt(sum([torch.norm(w[0], p=2) ** 2 for w in model.parameters()]))
    print(f"==> {msg}, {norm}")

In [38]:
save_model(model, opt)
model1, opt1 = load_model()

In [42]:
def train_model(model, opt):
    n_epochs = 3
    for _ in range(n_epochs):
        model.train()
        y_pred = model(x)
        loss = loss_fct(y_pred, ylab)
        print("loss: ", loss.item())
        print(model.state_dict())
        print_weight_norms(model, "Weight norms: ")
        if model.training:
            loss.backward()
            opt.step()

In [43]:
train_model(model, opt)

loss:  0.006478526629507542
OrderedDict([('weight', tensor([[0.2338]])), ('bias', tensor([0.0592]))])
==> Weight norms: , 0.2412053942680359
loss:  0.005174466408789158
OrderedDict([('weight', tensor([[0.2374]])), ('bias', tensor([0.0662]))])
==> Weight norms: , 0.24643538892269135
loss:  0.003998951055109501
OrderedDict([('weight', tensor([[0.2410]])), ('bias', tensor([0.0733]))])
==> Weight norms: , 0.2519116997718811


In [44]:
train_model(model1, opt1)

loss:  0.01692776195704937
OrderedDict([('weight', tensor([[0.2133]])), ('bias', tensor([0.0190]))])
==> Weight norms: , 0.21418455243110657
loss:  0.016600437462329865
OrderedDict([('weight', tensor([[0.2139]])), ('bias', tensor([0.0200]))])
==> Weight norms: , 0.21480286121368408
loss:  0.01619659550487995
OrderedDict([('weight', tensor([[0.2145]])), ('bias', tensor([0.0213]))])
==> Weight norms: , 0.21558034420013428
