📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Consider An Initialized Model As Trained](#toc2_)    
- [Save & Load](#toc3_)    
    - [Save and Load ONLY Parameters](#toc3_1_1_)    
    - [Save & Load the ENTIRE Model](#toc3_1_2_)    
    - [Saving & Loading a General Checkpoint for Inference and/or Resuming Training](#toc3_1_3_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)

In [18]:
import torch
from torch import nn, optim
from torchinfo import summary

In [None]:
# set a seed for deterministic results
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# <a id='toc2_'></a>[Consider An Initialized Model As Trained](#toc0_)

In [None]:
trained_model = nn.Sequential(
    nn.Linear(4, 2),
    nn.Sigmoid(),
    nn.Linear(2, 1),
    nn.Sigmoid(),
)

# log
print(trained_model)

In [None]:
summary(trained_model, input_size=(16, 4), device="cpu")

In [None]:
# weights and biases per layer (using model.parameters())
for i, param in enumerate(trained_model.parameters()):
    if i % 2 == 0:  # weights of the model
        print(str(param).replace("Parameter containing:", f"weights (layer {i // 2 + 1}):"), end="\n\n")
    else:  # biases of the model
        print(str(param).replace("Parameter containing:", f"biases (layer {(i-1) // 2 + 1}):"), end="\n\n")

In [None]:
# weights and biases per layer (using model.state_dict())
for param in trained_model.state_dict().items():
    print(param)

# <a id='toc3_'></a>[Save & Load](#toc0_)
   - The extension `.pth` has no specific meaning to PyTorch internally.
   - `.pth` (or sometimes `.pt`) is used conventionally to indicate the file contains a PyTorch model or parameters.
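
The point about extensions can be checked with a small round-trip. The sketch below is illustrative only: the `payload` dictionary, the temporary directory, and the extensions `.bin` and `.anything` are assumptions made for the demo, not part of this notebook's files.

```python
import os
import tempfile

import torch

# a small tensor dictionary standing in for model parameters (illustrative only)
payload = {"weight": torch.arange(6, dtype=torch.float32).reshape(2, 3)}

with tempfile.TemporaryDirectory() as tmp_dir:
    loaded = {}
    # ".bin" and ".anything" are hypothetical extensions chosen for this demo
    for ext in (".pth", ".pt", ".bin", ".anything"):
        path = os.path.join(tmp_dir, f"demo{ext}")
        torch.save(obj=payload, f=path)
        loaded[ext] = torch.load(f=path, weights_only=True)

# every file round-trips to the same tensor regardless of its extension
all_equal = all(torch.equal(v["weight"], payload["weight"]) for v in loaded.values())
print(all_equal)
```

Sticking to `.pth` or `.pt` is still a good idea, since other people and tools recognize those names.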

📝 **Docs & Tutorials** 📚:
   - torch.save: [pytorch.org/docs/stable/generated/torch.save.html](https://pytorch.org/docs/stable/generated/torch.save.html)
   - torch.load: [pytorch.org/docs/stable/generated/torch.load.html](https://pytorch.org/docs/stable/generated/torch.load.html)
   - Saving and Loading Models: [pytorch.org/tutorials/beginner/saving_loading_models.html](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
   - Save and Load the Model: [pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)

### <a id='toc3_1_1_'></a>[Save and Load ONLY Parameters](#toc0_)
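   - This is the recommended approach.
   - The model architecture is defined in code, separately from the saved file; loading works as long as the parameter names and shapes still match.
   - The file stores only the tensors in the `state_dict`, so it is compact and can be loaded with `weights_only=True`.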
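
Because a `state_dict` is matched to the model by parameter name, it helps to see what happens when the architecture in code differs from the one that produced the file. The following is a minimal sketch, assuming a hypothetical `variant` model with one extra layer; it relies on the `strict` argument of `load_state_dict`, which reports mismatched keys instead of raising when set to `False`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# same layout as the notebook's model: Linear(4, 2) -> Sigmoid -> Linear(2, 1) -> Sigmoid
source = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())
state = source.state_dict()

# a hypothetical variant with one extra layer appended at index 4
variant = nn.Sequential(
    nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid(), nn.Linear(1, 1)
)

# strict=True (the default) raises because "4.weight" and "4.bias" are absent from `state`
try:
    variant.load_state_dict(state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest
result = variant.load_state_dict(state, strict=False)
print(strict_failed, result.missing_keys, result.unexpected_keys)
```

With `strict=False`, the layers that do match receive the saved values, while the extra layer keeps its fresh initialization.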

In [24]:
# get model parameters
trained_model_parameters = trained_model.state_dict()

# save
torch.save(obj=trained_model_parameters, f="../../assets/models/model_1.pth")

In [None]:
# load
weights = torch.load(f="../../assets/models/model_1.pth", weights_only=True)

# log
weights

In [None]:
# insert weights to the model
model_1 = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())

model_1.load_state_dict(weights)

# log
for param in model_1.state_dict().items():
    print(param)

### <a id='toc3_1_2_'></a>[Save & Load the ENTIRE Model](#toc0_)
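   - ✅ Easier to use since you don’t need to redefine the model architecture.
   - ⚠️ Less portable: the file is a pickle that references the model's classes by import path, so loading can break after the code is refactored or when the PyTorch version changes.
   - ⚠️ Loading requires `weights_only=False`, which unpickles arbitrary Python objects; only load files from sources you trust.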
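
As a rough illustration of the whole-model round-trip, the sketch below saves a module to an in-memory buffer instead of a file and checks that the restored object produces the same output. The `io.BytesIO` buffer and the input tensor `x` are assumptions made for the demo; the notebook itself writes to `../../assets/models/`.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())

# serialize the whole module (structure + parameters) into an in-memory buffer
buffer = io.BytesIO()
torch.save(obj=model, f=buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# weights_only=False unpickles arbitrary Python objects: use it only on trusted data
restored = torch.load(f=buffer, weights_only=False)

# illustrative input: a batch of 3 samples with 4 features each
x = torch.ones(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(type(restored).__name__, same_output)
```

No architecture had to be redefined before loading, which is the convenience this approach offers; the cost is that the class definitions must be importable in the loading environment.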

In [None]:
# save
torch.save(obj=trained_model, f="../../assets/models/model_2.pth")

# load
model_2 = torch.load(f="../../assets/models/model_2.pth", weights_only=False)

# log
model_2

In [None]:
# log
for param in model_2.state_dict().items():
    print(param)

### <a id='toc3_1_3_'></a>[Saving & Loading a General Checkpoint for Inference and/or Resuming Training](#toc0_)
   - A checkpoint can be saved at the end of every epoch (or at any other interval) during training, so that training can later resume from that point.
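
A per-epoch checkpoint might look like the sketch below. It is not the notebook's training code: the `build` helper, the synthetic data, the `momentum=0.9` setting (added so the optimizer has state worth restoring), and the temporary file path are all assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn, optim

torch.manual_seed(0)


def build():
    # hypothetical helper: rebuilds the same architecture and optimizer each time
    model = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


# synthetic data, for illustration only
x, y = torch.rand(16, 4), torch.rand(16, 1)
criterion = nn.MSELoss()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")

    model, optimizer = build()
    for epoch in range(1, 4):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # overwrite the checkpoint at the end of every epoch
        torch.save(
            obj={
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "loss": loss.item(),
            },
            f=path,
        )

    # resume: rebuild the objects, then restore both states
    checkpoint = torch.load(f=path, weights_only=True)
    resumed_model, resumed_optimizer = build()
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

print(start_epoch)
```

Storing the loss as a plain float rather than the criterion object keeps the checkpoint loadable with `weights_only=True`.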

In [29]:
epoch = 10
criterion = nn.MSELoss()
optimizer = optim.SGD(params=trained_model.parameters(), lr=0.01)

In [30]:
# save both model and optimizer state_dict for resuming training
torch.save(
    obj={
        "model_state_dict": trained_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,  # save the epoch to resume training from it
        "criterion": criterion,  # optional: the loss function object (not the last loss value)
    },
    f="../../assets/models/model_3.pth",
)

In [31]:
# load the checkpoint (weights_only=False because it contains a pickled nn.MSELoss object)
checkpoint = torch.load("../../assets/models/model_3.pth", weights_only=False)

# model
model_3 = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())

# optimizer
optimizer = optim.SGD(model_3.parameters(), lr=0.01)

# insert values
model_3.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
criterion = checkpoint["criterion"]
epoch = checkpoint["epoch"]

In [None]:
# log
for param in model_3.state_dict().items():
    print(param)

In [None]:
# log
print(optimizer)

In [None]:
# log
print(f"epoch : {epoch}")