# Demo: Training Progress GIF and Checkpoint Save/Load

This short demo shows how to embed the generated `training_progress.gif` and provides a runnable PyTorch example that creates, saves, and loads a small checkpoint. The notebook is self-contained and does not require the Unity environment to run.

In [None]:
# 1) Import Required Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import Image, display
from pathlib import Path
import json

print('torch version:', getattr(torch, '__version__', 'unknown'))

## Embed training_progress.gif

If you have already run `python3 generate_plot.py`, the GIF will be at `checkpoints/demos/training_progress.gif`. If it is missing, point `gif_path` in the cell below at another GIF in the workspace.

In [None]:
# 2) Display the GIF if present
gif_path = Path('checkpoints/demos/training_progress.gif')
if gif_path.exists():
    display(Image(filename=str(gif_path)))
else:
    print('GIF not found at', gif_path)
    print('Run: python3 generate_plot.py --checkpoints checkpoints --out checkpoints/demos')
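The markdown above suggests pointing the path at another GIF when the default one is missing. A minimal sketch of automating that fallback is below. It assumes any `*.gif` under the working directory is acceptable. `find_gif` is a hypothetical helper and is not part of `generate_plot.py`.

```python
from pathlib import Path
from typing import Optional


def find_gif(preferred: Path, root: Path = Path('.')) -> Optional[Path]:
    """Return `preferred` if it exists, else the first GIF found under `root` (or None)."""
    if preferred.exists():
        return preferred
    # Sort so the choice is deterministic when several GIFs exist.
    candidates = sorted(root.rglob('*.gif'))
    return candidates[0] if candidates else None
```

For example, call `gif = find_gif(gif_path)`, then `display(Image(filename=str(gif)))` when `gif` is not `None`. `rglob` walks the whole tree under `root`, which can be slow in a large workspace.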

## Minimal PyTorch model and dummy input

We define a tiny two-layer MLP (8 → 16 → 4 units with a ReLU in between), an Adam optimizer, and a random dummy input for the forward pass.

In [None]:
# 3) Define a tiny MLP, its optimizer, and a dummy input
class TinyMLP(nn.Module):
    def __init__(self, input_dim=8, hidden=16, output_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim)
        )

    def forward(self, x):
        return self.net(x)

input_dim = 8
model = TinyMLP(input_dim=input_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

dummy_input = torch.randn(1, input_dim)
print('Model parameter count:', sum(p.numel() for p in model.parameters()))
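The printed parameter count can be checked by hand. Each `nn.Linear(in, out)` holds an `out × in` weight matrix plus `out` bias terms, and `nn.ReLU` has no parameters. The helpers below have hypothetical names and simply encode that arithmetic for the default `TinyMLP` sizes.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    # nn.Linear stores an (out x in) weight matrix plus an out-sized bias vector.
    return in_features * out_features + (out_features if bias else 0)


def tiny_mlp_params(input_dim: int = 8, hidden: int = 16, output_dim: int = 4) -> int:
    # ReLU has no parameters, so only the two Linear layers contribute.
    return linear_params(input_dim, hidden) + linear_params(hidden, output_dim)


print(tiny_mlp_params())  # 8*16 + 16 + 16*4 + 4 = 212
```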

## Create and save a sample checkpoint

This cell runs one forward/backward pass and one optimizer step, so the weights (and Adam's internal state) differ from their initial values. It then saves a checkpoint named `checkpoint.pth` in the notebook's working directory.

In [None]:
# 4) Single training step and save checkpoint
model.train()
optimizer.zero_grad()
output = model(dummy_input)
# Dummy loss (mean squared output) so no labels are needed for the demo
loss = output.pow(2).mean()
loss.backward()
optimizer.step()

ckpt = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 1
}

ckpt_path = Path('checkpoint.pth')
torch.save(ckpt, str(ckpt_path))
print('Saved checkpoint to', ckpt_path)
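`torch.save` writes directly to `checkpoint.pth`, so an interrupted save can leave a truncated file behind. One common safeguard is to write to a temporary file in the same directory and then rename it over the target with `os.replace`. The sketch below shows this as an optional variant, not what this notebook does. `atomic_torch_save` is a hypothetical helper name.

```python
import os
import tempfile
from pathlib import Path

import torch


def atomic_torch_save(obj, path) -> None:
    """Write `obj` with torch.save to a temp file, then rename it over `path`."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file next to the target so the rename stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            torch.save(obj, f)
        # Replace the target in a single rename step.
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the partial temp file if anything went wrong, then re-raise.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

In the save cell, `torch.save(ckpt, str(ckpt_path))` could then become `atomic_torch_save(ckpt, ckpt_path)`. Python's documentation notes that a successful `os.replace` is atomic on POSIX systems. Keeping the temporary file in the target's directory avoids a cross-filesystem rename.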

## Load checkpoint and run inference (runnable example)

This cell creates a fresh model, loads the saved checkpoint, and runs inference on the dummy input.

In [None]:
# 5) Load the checkpoint and run inference
try:
    ckpt = torch.load('checkpoint.pth', map_location='cpu')
    model2 = TinyMLP(input_dim=input_dim)
    model2.load_state_dict(ckpt['model_state_dict'])
    model2.eval()
    with torch.no_grad():
        out = model2(dummy_input)
    print('Inference output (tensor):', out)
except FileNotFoundError:
    print('checkpoint.pth not found. Run the previous cell to create one.')
except Exception as e:
    print('Error loading or running model:', e)
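The load cell restores only the model weights, which is all inference needs. To resume training, you would also restore the optimizer, because Adam keeps per-parameter moment estimates. It is also worth confirming that the restored model reproduces the original outputs. The sketch below assumes the checkpoint layout used above and round-trips through an in-memory buffer instead of `checkpoint.pth`. `roundtrip_check` is a hypothetical helper, and `TinyMLP` is redefined identically so the block runs on its own.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


class TinyMLP(nn.Module):
    # Identical to the TinyMLP defined earlier in this notebook.
    def __init__(self, input_dim=8, hidden=16, output_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        )

    def forward(self, x):
        return self.net(x)


def roundtrip_check(model, optimizer, x, epoch=1, lr=1e-3):
    """Save model and optimizer state, restore both into fresh objects, compare outputs."""
    # Serialize to an in-memory buffer; a file path works the same way.
    buf = io.BytesIO()
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location='cpu')

    # Restore weights into a freshly constructed model.
    model2 = TinyMLP(input_dim=x.shape[-1])
    model2.load_state_dict(ckpt['model_state_dict'])

    # Build the optimizer over model2's parameters before loading its state,
    # so Adam's moment estimates attach to the new parameters.
    optimizer2 = optim.Adam(model2.parameters(), lr=lr)
    optimizer2.load_state_dict(ckpt['optimizer_state_dict'])

    # TinyMLP has no dropout or batch norm, so train/eval mode does not affect outputs.
    with torch.no_grad():
        same = torch.equal(model(x), model2(x))

    # Training would resume at the epoch after the saved one.
    return same, model2, optimizer2, ckpt['epoch'] + 1
```

In this notebook, `roundtrip_check(model, optimizer, dummy_input)` should return `True` as its first element once the training-step cell has run.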

## Capture output summary (JSON)

This final cell prints a compact JSON summary of the loaded epoch and the output shape, so the results are easy to inspect in the cell output (for example, in VS Code's notebook view).

In [None]:
# 6) Output summary and execution hint
try:
    summary = {
        'loaded_epoch': int(ckpt.get('epoch', -1)),
        'output_shape': list(out.shape)
    }
    print(json.dumps(summary))
except Exception as e:
    print('Could not create summary:', e)

# Hint: to execute this notebook from a terminal:
# jupyter nbconvert --to notebook --execute demo.ipynb --output demo_executed.ipynb
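The summary reports only the output shape because `json.dumps` cannot serialize a `torch.Tensor` directly. To include the values as well, one option is to convert the tensor to plain Python lists first. `tensor_summary` below is a hypothetical helper that sketches this approach.

```python
import json

import torch


def tensor_summary(t: torch.Tensor, decimals: int = 4) -> dict:
    """Convert a tensor into JSON-friendly fields (json.dumps cannot encode tensors)."""
    t = t.detach().cpu()
    return {
        'shape': list(t.shape),
        'dtype': str(t.dtype),
        # Flatten to a list of Python floats and round so the printed JSON stays short.
        'values': [round(v, decimals) for v in t.flatten().tolist()],
    }
```

For example: `print(json.dumps({'loaded_epoch': int(ckpt.get('epoch', -1)), 'output': tensor_summary(out)}))`.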