# Storing and Loading Models & Weights in PyTorch

This notebook demonstrates the two primary ways to save and load your models in PyTorch:

1.  **Saving/Loading the `state_dict` (Recommended)**: This saves only the model's parameters (weights and biases) and registered buffers, not the Python object itself. It is the more flexible and robust method: the file does not reference your class by module path, so it survives renaming or moving the code, as long as the model you load into has matching parameter names and shapes.
2.  **Saving/Loading the Entire Model**: This saves the entire model object using Python's `pickle` module. While simple, the file stores only a reference to the class (its module path and name), so loading can break if you refactor your code or use the model in a different project.

## 1. Imports and Model Definition

First, let's import the necessary libraries and define the same simple CNN model we used before.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

# Define the same CNN model architecture
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

print("Libraries imported and CNN model defined.")

## 2. Method 1: Saving/Loading the `state_dict` (Recommended)

A model's `state_dict` is a Python dictionary (an `OrderedDict`) that maps each parameter or buffer name, such as `conv1.weight` or `fc3.bias`, to its tensor.
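
To make the structure of a `state_dict` concrete, here is a minimal sketch that lists each entry's name and tensor shape. The `tiny` model and the `describe_state_dict` helper are hypothetical names introduced only for this illustration; the same helper should work on the `CNN` defined above.

```python
import torch
import torch.nn as nn

# Hypothetical two-layer model, used only to keep the printed output short
tiny = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

def describe_state_dict(model):
    """Return {entry name: tensor shape as a tuple} for a model's state_dict."""
    return {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}

# ReLU has no parameters, so only the two Linear layers contribute entries
for name, shape in describe_state_dict(tiny).items():
    print(f"{name:<10} {shape}")
```

Layers without learnable parameters or buffers (here, `ReLU`) do not appear in the dictionary at all.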

In [None]:
# === Saving the state_dict ===

# 1. Initialize a model
model_to_save = CNN()

# (Optional) Train the model here...
# For this example, we'll just use its initial random weights.

# 2. Define a path
WEIGHTS_PATH = 'model_weights.pth'

# 3. Save the state_dict
torch.save(model_to_save.state_dict(), WEIGHTS_PATH)
print(f"Model weights saved to {WEIGHTS_PATH}")

# === Loading the state_dict ===

# 1. Initialize a *new* instance of the model
# You MUST create an instance of the model *before* you can load weights
model_to_load = CNN()

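# 2. Load the state_dict from the file
# weights_only=True restricts unpickling to tensors and basic containers (PyTorch >= 1.13);
# map_location='cpu' lets a file saved on a GPU load on a CPU-only machine
loaded_state_dict = torch.load(WEIGHTS_PATH, map_location='cpu', weights_only=True)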

# 3. Load the state_dict into the model
model_to_load.load_state_dict(loaded_state_dict)

# 4. Set the model to evaluation mode (important if you have dropout, batchnorm, etc.)
model_to_load.eval()

print(f"Model weights loaded successfully from {WEIGHTS_PATH}")

# You can verify by checking that every entry in the two state_dicts is identical
original_state = model_to_save.state_dict()
loaded_state = model_to_load.state_dict()
all_equal = all(torch.equal(original_state[k], loaded_state[k]) for k in original_state)
print(f"Weights are identical: {all_equal}")
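
Loading a `state_dict` only works when the names in the file match the parameter names of the instance you load into. The sketch below illustrates what happens when they do not. `SmallNet` and `RenamedNet` are hypothetical classes invented for this example; the point is the behaviour of `load_state_dict` with its default `strict=True` versus `strict=False`.

```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

class RenamedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.head = nn.Linear(3, 2)  # same shape as fc2, but a different name

saved_state = SmallNet().state_dict()

# With the default strict=True, any key mismatch raises a RuntimeError
try:
    RenamedNet().load_state_dict(saved_state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# With strict=False, matching keys are loaded and mismatches are reported
result = RenamedNet().load_state_dict(saved_state, strict=False)
print("Strict load failed:", strict_failed)
print("Missing keys:", sorted(result.missing_keys))
print("Unexpected keys:", sorted(result.unexpected_keys))
```

`strict=False` can be handy when reusing part of a model, but the parameters listed under missing keys keep their random initial values, so it is worth inspecting the report rather than ignoring it.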

## 3. Method 2: Saving/Loading the Entire Model

This method saves the entire model object using Python's `pickle`.

In [None]:
# === Saving the entire model ===

# 1. Initialize a model
model_to_save_full = CNN()

# 2. Define a path
FULL_MODEL_PATH = 'full_model.pth'

# 3. Save the entire model object
torch.save(model_to_save_full, FULL_MODEL_PATH)
print(f"Full model object saved to {FULL_MODEL_PATH}")

# === Loading the entire model ===

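# 1. Load the entire object from the file
# You do not need to create an instance of the model first, but the CNN class
# must still be importable, because pickle stores only a reference to it.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled model objects,
# so pass weights_only=False explicitly. Only do this for files you trust.
model_to_load_full = torch.load(FULL_MODEL_PATH, weights_only=False)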

# 2. Set the model to evaluation mode
model_to_load_full.eval()

print(f"Full model object loaded successfully from {FULL_MODEL_PATH}")

# Verify
print(model_to_load_full)

## 4. Saving in a Federated Learning Loop

In your main training loop, you would typically save the `state_dict` of your **global model** at the end of each round or at the very end of training.

In [None]:
def example_fl_loop():
    # --- Inside your main training loop (from the previous notebook) ---
    
    # Initialize the global model
    global_model = CNN() # .to(device)
    
    num_rounds = 5 # Example: 5 rounds
    
    for round_idx in range(num_rounds):
        print(f"Running round {round_idx + 1}...")
        # ... (client training and server aggregation logic) ...
        # global_model.load_state_dict(new_aggregated_weights)
        
        # --- Save a checkpoint --- 
        # You could save at intermediate rounds
        if (round_idx + 1) % 2 == 0:
            checkpoint_path = f"global_model_round_{round_idx + 1}.pth"
            torch.save(global_model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
            
    # --- Save the final model --- 
    FINAL_MODEL_PATH = "global_model_final.pth"
    torch.save(global_model.state_dict(), FINAL_MODEL_PATH)
    print(f"\nFinal model weights saved to {FINAL_MODEL_PATH}")

# Run the example loop
example_fl_loop()
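
Once per-round checkpoints exist, a natural follow-up is resuming from the most recent one. The sketch below is one possible way to do that, assuming the `global_model_round_{N}.pth` naming used in the loop above. The `find_latest_checkpoint` helper is a hypothetical name, and an `nn.Linear` stands in for the global model so the example stays self-contained.

```python
import glob
import os
import re
import tempfile

import torch
import torch.nn as nn

# Matches the checkpoint file names written by the loop above
CHECKPOINT_RE = re.compile(r"global_model_round_(\d+)\.pth$")

def find_latest_checkpoint(directory):
    """Return (round_number, path) for the highest-numbered checkpoint, or None."""
    candidates = []
    for path in glob.glob(os.path.join(directory, "global_model_round_*.pth")):
        match = CHECKPOINT_RE.search(os.path.basename(path))
        if match:
            # Compare round numbers as integers so that round 10 sorts after round 4
            candidates.append((int(match.group(1)), path))
    return max(candidates) if candidates else None

# Demo in a temporary directory with a stand-in model (an assumption of this sketch)
with tempfile.TemporaryDirectory() as tmp_dir:
    saved_models = {}
    for round_number in (2, 4, 10):
        model = nn.Linear(4, 2)
        path = os.path.join(tmp_dir, f"global_model_round_{round_number}.pth")
        torch.save(model.state_dict(), path)
        saved_models[round_number] = model

    latest_round, latest_path = find_latest_checkpoint(tmp_dir)

    # Load the newest weights into a fresh instance, as in Method 1
    resumed_model = nn.Linear(4, 2)
    resumed_model.load_state_dict(
        torch.load(latest_path, map_location="cpu", weights_only=True)
    )

print(f"Resuming after round {latest_round}")
```

Parsing the round number out of the file name avoids the pitfall of sorting paths as strings, where `round_10` would come before `round_2`.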