# Notebook 08: Warm-Start MAML (XuetangX) - Contribution 1

**Purpose:** Improve MAML by initializing from pre-trained GRU4Rec weights instead of random initialization.

**Research Motivation:**
- Vanilla MAML (28.66%) underperforms the GRU4Rec baseline (33.55%).
- Root cause: random initialization combined with only K=5 support pairs gives too little signal to learn course patterns.
- Solution: start from GRU4Rec's learned weights (33.55% accuracy) and adapt them for personalization.

**Key Insight:**
```
Without Warm-Start: Random init (0.06%) + K=5 adaptation = 28.66%
With Warm-Start:    GRU4Rec init (33.55%) + K=5 adaptation = ???%
```
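
As a rough sanity check on the random-init figure, the snippet below computes the Acc@1 of a uniform random guess over the course vocabulary. It assumes the quoted 0.06% is on the order of chance level, which the notebook does not state explicitly; the variable names are illustrative only.

```python
# Vocabulary size as listed under Inputs (1,518 courses).
n_courses = 1518

# Expected Acc@1 of a uniform random guess over all courses, in percent.
# Assumption: a randomly initialized model behaves roughly like uniform guessing.
chance_acc_pct = 100.0 / n_courses
print(f"chance-level Acc@1: {chance_acc_pct:.4f}%")
```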

**Hypothesis:** Warm-Start MAML will outperform both:
1. Vanilla MAML (28.66%) - because of the better initialization
2. GRU4Rec baseline (33.55%) - because of user-specific adaptation

**Inputs:**
- `data/processed/xuetangx/episodes/episodes_{train|val|test}_K5_Q10.parquet`
- `models/baselines/gru_global.pth` (pre-trained GRU4Rec: 33.55% Acc@1)
- `data/processed/xuetangx/vocab/course2id.json` (1,518 courses)

**Outputs:**
- `models/contributions/warmstart_maml_K5.pth`
- `results/warmstart_maml_K5_Q10.json`
- `reports/08_warmstart_maml_xuetangx/<run_tag>/report.json`

**Comparison:**
| Model | Initialization | Adaptation | Expected Acc@1 |
|-------|---------------|------------|----------------|
| GRU4Rec (baseline) | Trained | None | 33.55% |
| Vanilla MAML | Random | K=5 inner loop | 28.66% |
| **Warm-Start MAML** | **GRU4Rec weights** | K=5 inner loop | **>33.55%?** |

In [None]:
# [CELL 08-00] Bootstrap: repo root + paths + logger

import os
import sys
import json
import time
import uuid
import copy
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Tuple, Optional
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

t0 = datetime.now()
print(f"[CELL 08-00] start={t0.isoformat(timespec='seconds')}")
print("[CELL 08-00] CWD:", Path.cwd().resolve())

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find repo root (no PROJECT_STATE.md)")

REPO_ROOT = find_repo_root(Path.cwd())
print(f"[CELL 08-00] REPO_ROOT: {REPO_ROOT}")

sys.path.insert(0, str(REPO_ROOT / "src"))

META_REGISTRY = REPO_ROOT / "meta.json"
DATA_INTERIM = REPO_ROOT / "data" / "interim"
DATA_PROCESSED = REPO_ROOT / "data" / "processed"
MODELS = REPO_ROOT / "models"
RESULTS = REPO_ROOT / "results"
REPORTS = REPO_ROOT / "reports"

PATHS = {
    "REPO_ROOT": REPO_ROOT,
    "META_REGISTRY": META_REGISTRY,
    "DATA_INTERIM": DATA_INTERIM,
    "DATA_PROCESSED": DATA_PROCESSED,
    "MODELS": MODELS,
    "RESULTS": RESULTS,
    "REPORTS": REPORTS,
}

for name, path in PATHS.items():
    print(f"[CELL 08-00] {name}={path}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[CELL 08-00] PyTorch device: {DEVICE}")

def cell_start(cell_id: str, description: str = "") -> datetime:
    t = datetime.now()
    msg = f"[{cell_id}] start={t.isoformat(timespec='seconds')}"
    if description:
        msg += f" | {description}"
    print(msg)
    return t

def cell_end(cell_id: str, t_start: datetime, **kv) -> None:
    elapsed = (datetime.now() - t_start).total_seconds()
    for k, v in kv.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] done in {elapsed:.1f}s")

print("[CELL 08-00] done")

In [None]:
# [CELL 08-01] Configuration

t0 = cell_start("CELL 08-01", "Configuration")

RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
print(f"[CELL 08-01] RUN_TAG: {RUN_TAG}")

CFG = {
    "notebook": "08_warmstart_maml_xuetangx",
    "run_tag": RUN_TAG,
    "dataset": "xuetangx",
    "contribution": "warm_start_maml",
    
    # Episode config (same as notebook 07)
    "episode_config": {
        "K": 5,   # support set size
        "Q": 10,  # query set size
    },
    
    # Model config (same as GRU4Rec baseline)
    "model_config": {
        "embed_dim": 64,
        "hidden_dim": 64,
        "n_layers": 1,
        "dropout": 0.1,
    },
    
    # MAML config
    "maml_config": {
        "inner_lr": 0.01,           # alpha: learning rate for inner loop
        "outer_lr": 0.001,          # beta: learning rate for outer loop
        "inner_steps": 5,           # gradient steps in inner loop
        "meta_batch_size": 32,      # tasks per meta-batch
        "num_meta_iterations": 3000,
        "use_second_order": False,  # FOMAML
        "val_every": 100,
        "checkpoint_every": 500,
    },
    
    # Warm-Start config (NEW)
    "warmstart_config": {
        "pretrained_path": "models/baselines/gru_global.pth",
        "freeze_embeddings": False,  # Whether to freeze embedding layer
    },
    
    "seed": 42,
}

# Set seeds
np.random.seed(CFG["seed"])
torch.manual_seed(CFG["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CFG["seed"])

print(f"[CELL 08-01] Configuration:")
print(f"  - Contribution: {CFG['contribution']}")
print(f"  - Pre-trained model: {CFG['warmstart_config']['pretrained_path']}")
print(f"  - Episode: K={CFG['episode_config']['K']}, Q={CFG['episode_config']['Q']}")
print(f"  - MAML: inner_lr={CFG['maml_config']['inner_lr']}, outer_lr={CFG['maml_config']['outer_lr']}")
print(f"  - Meta iterations: {CFG['maml_config']['num_meta_iterations']}")
print(f"  - FOMAML: {not CFG['maml_config']['use_second_order']}")

cell_end("CELL 08-01", t0)

In [None]:
# [CELL 08-02] Setup paths and create directories

t0 = cell_start("CELL 08-02", "Setup paths")

# Input paths
EPISODES_DIR = DATA_PROCESSED / "xuetangx" / "episodes"
VOCAB_DIR = DATA_PROCESSED / "xuetangx" / "vocab"
PRETRAINED_PATH = REPO_ROOT / CFG["warmstart_config"]["pretrained_path"]

K = CFG["episode_config"]["K"]
Q = CFG["episode_config"]["Q"]

EPISODES_TRAIN = EPISODES_DIR / f"episodes_train_K{K}_Q{Q}.parquet"
EPISODES_VAL = EPISODES_DIR / f"episodes_val_K{K}_Q{Q}.parquet"
EPISODES_TEST = EPISODES_DIR / f"episodes_test_K{K}_Q{Q}.parquet"
COURSE2ID_PATH = VOCAB_DIR / "course2id.json"

# Output paths
CONTRIB_MODELS_DIR = MODELS / "contributions"
CONTRIB_MODELS_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_DIR = CONTRIB_MODELS_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

REPORT_DIR = REPORTS / "08_warmstart_maml_xuetangx" / RUN_TAG
REPORT_DIR.mkdir(parents=True, exist_ok=True)

OUT_MODEL = CONTRIB_MODELS_DIR / "warmstart_maml_K5.pth"
OUT_RESULTS = RESULTS / "warmstart_maml_K5_Q10.json"
REPORT_PATH = REPORT_DIR / "report.json"

print(f"[CELL 08-02] Input episodes: {EPISODES_TRAIN}")
print(f"[CELL 08-02] Pre-trained GRU4Rec: {PRETRAINED_PATH}")
print(f"[CELL 08-02] Pre-trained exists: {PRETRAINED_PATH.exists()}")
print(f"[CELL 08-02] Output model: {OUT_MODEL}")
print(f"[CELL 08-02] Output results: {OUT_RESULTS}")

cell_end("CELL 08-02", t0)

In [None]:
# [CELL 08-03] Load data

t0 = cell_start("CELL 08-03", "Load episodes and vocabulary")

# Load vocabulary
with open(COURSE2ID_PATH, "r") as f:
    course2id = json.load(f)
n_items = len(course2id)
print(f"[CELL 08-03] Vocabulary size: {n_items} courses")

# Load episodes
episodes_train = pd.read_parquet(EPISODES_TRAIN)
episodes_val = pd.read_parquet(EPISODES_VAL)
episodes_test = pd.read_parquet(EPISODES_TEST)

print(f"[CELL 08-03] Train episodes: {len(episodes_train):,}")
print(f"[CELL 08-03] Val episodes:   {len(episodes_val):,}")
print(f"[CELL 08-03] Test episodes:  {len(episodes_test):,}")

cell_end("CELL 08-03", t0, n_items=n_items, n_train=len(episodes_train), n_val=len(episodes_val), n_test=len(episodes_test))

In [None]:
# [CELL 08-04] Define GRU4Rec model (same architecture as baseline)

t0 = cell_start("CELL 08-04", "Define GRU4Rec model")

class GRURecommender(nn.Module):
    """GRU4Rec model for sequential recommendation.
    
    Same architecture as notebook 06 baseline - required for loading pre-trained weights.
    """
    def __init__(self, n_items: int, embed_dim: int, hidden_dim: int, n_layers: int = 1, dropout: float = 0.1):
        super().__init__()
        self.n_items = n_items
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(n_items, embed_dim, padding_idx=0)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_dim, n_items)
        
    def forward(self, x: torch.Tensor, lengths: torch.Tensor = None) -> torch.Tensor:
        """Forward pass.
        
        Args:
            x: Input sequences [batch, seq_len]
            lengths: Sequence lengths [batch] (optional)
            
        Returns:
            logits: Output logits [batch, n_items]
        """
        # Embed input
        embedded = self.embedding(x)  # [batch, seq_len, embed_dim]
        
        # Pack if lengths provided
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            output, hidden = self.gru(packed)
        else:
            output, hidden = self.gru(embedded)
        
        # Use last hidden state
        last_hidden = hidden[-1]  # [batch, hidden_dim]
        
        # Output projection
        out = self.dropout(last_hidden)
        logits = self.output(out)  # [batch, n_items]
        
        return logits

print(f"[CELL 08-04] GRURecommender defined")
print(f"  - Architecture: Embedding({n_items}, {CFG['model_config']['embed_dim']}) -> GRU({CFG['model_config']['hidden_dim']}) -> Linear({n_items})")

cell_end("CELL 08-04", t0)

In [None]:
# [CELL 08-05] Initialize model with Warm-Start (KEY CONTRIBUTION)

t0 = cell_start("CELL 08-05", "Initialize model with Warm-Start from GRU4Rec")

# Create model with same architecture
meta_model = GRURecommender(
    n_items=n_items,
    embed_dim=CFG["model_config"]["embed_dim"],
    hidden_dim=CFG["model_config"]["hidden_dim"],
    n_layers=CFG["model_config"]["n_layers"],
    dropout=CFG["model_config"]["dropout"],
).to(DEVICE)

# ============================================================
# KEY CONTRIBUTION: Load pre-trained GRU4Rec weights
# ============================================================
print(f"[CELL 08-05] Loading pre-trained GRU4Rec from: {PRETRAINED_PATH}")

pretrained_state = torch.load(PRETRAINED_PATH, map_location=DEVICE)
meta_model.load_state_dict(pretrained_state)

print(f"[CELL 08-05] Successfully loaded pre-trained weights!")
print(f"[CELL 08-05] Meta-model now starts from GRU4Rec baseline (33.55% Acc@1)")
# ============================================================

# Model stats
n_params = sum(p.numel() for p in meta_model.parameters())
n_trainable = sum(p.numel() for p in meta_model.parameters() if p.requires_grad)

print(f"[CELL 08-05] Model parameters: {n_params:,}")
print(f"[CELL 08-05] Trainable parameters: {n_trainable:,}")

# Optionally freeze embeddings
if CFG["warmstart_config"]["freeze_embeddings"]:
    for param in meta_model.embedding.parameters():
        param.requires_grad = False
    n_trainable = sum(p.numel() for p in meta_model.parameters() if p.requires_grad)
    print(f"[CELL 08-05] Embeddings frozen. Trainable: {n_trainable:,}")

# Setup optimizer (outer loop)
meta_optimizer = torch.optim.Adam(meta_model.parameters(), lr=CFG["maml_config"]["outer_lr"])

print(f"[CELL 08-05] Optimizer: Adam (outer_lr={CFG['maml_config']['outer_lr']})")

cell_end("CELL 08-05", t0, n_params=n_params)
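
Warm-starting only works if the checkpoint matches the model's parameter names and shapes exactly. The sketch below shows one way to list mismatches before calling `load_state_dict`. `state_dict_problems` and `TinyRec` are hypothetical helpers introduced for illustration, and the toy sizes are arbitrary; it also assumes the checkpoint file holds a plain state dict.

```python
import torch
import torch.nn as nn

def state_dict_problems(model: nn.Module, state: dict) -> list:
    """Return human-readable mismatches between a model and a state dict."""
    problems = []
    expected = model.state_dict()
    for name, tensor in expected.items():
        if name not in state:
            problems.append(f"missing key: {name}")
        elif tuple(state[name].shape) != tuple(tensor.shape):
            problems.append(
                f"shape mismatch for {name}: "
                f"{tuple(state[name].shape)} vs {tuple(tensor.shape)}"
            )
    for name in state:
        if name not in expected:
            problems.append(f"unexpected key: {name}")
    return problems

class TinyRec(nn.Module):
    """Toy stand-in (hypothetical) using the same layer names as GRURecommender."""
    def __init__(self, n_items: int, dim: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(n_items, dim, padding_idx=0)
        self.gru = nn.GRU(dim, dim, batch_first=True)
        self.output = nn.Linear(dim, n_items)

pretrained = TinyRec(n_items=10).state_dict()

# Same vocabulary size: nothing to report.
same_vocab = state_dict_problems(TinyRec(n_items=10), pretrained)
# Different vocabulary size: embedding and output layers no longer fit.
other_vocab = state_dict_problems(TinyRec(n_items=12), pretrained)

print(same_vocab)
for problem in other_vocab:
    print(problem)
```

A vocabulary change between the baseline and this notebook would show up here as shape mismatches in `embedding.weight`, `output.weight`, and `output.bias`.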

In [None]:
# [CELL 08-06] Verify Warm-Start: Test initial accuracy

t0 = cell_start("CELL 08-06", "Verify Warm-Start initialization accuracy")

def prepare_batch(episodes_df: pd.DataFrame, idx: int) -> Tuple[List[List[int]], torch.Tensor, List[List[int]], torch.Tensor]:
    """Prepare a single episode: raw variable-length prefixes plus label tensors."""
    row = episodes_df.iloc[idx]
    
    # Prefixes have different lengths, so they cannot be stacked into one tensor
    # here; keep them as lists and pad them later.
    support_x = [list(seq) for seq in row["support_prefixes"]]
    support_y = torch.tensor(row["support_labels"], dtype=torch.long)
    query_x = [list(seq) for seq in row["query_prefixes"]]
    query_y = torch.tensor(row["query_labels"], dtype=torch.long)
    
    return support_x, support_y, query_x, query_y

def compute_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute accuracy@1."""
    preds = logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()

# Test on a sample of test episodes WITHOUT adaptation (zero-shot)
meta_model.eval()
n_test_sample = min(100, len(episodes_test))

correct = 0
total = 0

with torch.no_grad():
    for i in range(n_test_sample):
        _, _, query_x, query_y = prepare_batch(episodes_test, i)
        
        # Pad sequences
        max_len = max(len(seq) for seq in query_x)
        padded_x = torch.zeros(len(query_x), max_len, dtype=torch.long)
        lengths = torch.zeros(len(query_x), dtype=torch.long)
        
        for j, seq in enumerate(query_x):
            padded_x[j, :len(seq)] = torch.tensor(seq)
            lengths[j] = len(seq)
        
        padded_x = padded_x.to(DEVICE)
        query_y = query_y.to(DEVICE)
        lengths = lengths.to(DEVICE)
        
        logits = meta_model(padded_x, lengths)
        preds = logits.argmax(dim=-1)
        
        correct += (preds == query_y).sum().item()
        total += len(query_y)

warmstart_zeroshot_acc = correct / total

print(f"[CELL 08-06] Warm-Start Zero-Shot Accuracy: {warmstart_zeroshot_acc:.4f} ({warmstart_zeroshot_acc*100:.2f}%)")
print(f"[CELL 08-06] Expected (GRU4Rec baseline): ~33.55%")
print(f"[CELL 08-06] Vanilla MAML zero-shot was: 25.62%")
print(f"[CELL 08-06] Change from Warm-Start: {(warmstart_zeroshot_acc - 0.2562)*100:+.2f} percentage points")

cell_end("CELL 08-06", t0, warmstart_zeroshot_acc=warmstart_zeroshot_acc)

In [None]:
# [CELL 08-07] Helper functions for MAML training

t0 = cell_start("CELL 08-07", "Define MAML helper functions")

def pad_sequences(sequences: List[List[int]], pad_value: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad variable-length sequences."""
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    max_len = max(lengths).item()
    
    padded = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    
    return padded, lengths

def functional_forward(model: nn.Module, x: torch.Tensor, lengths: Optional[torch.Tensor],
                       params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward pass with external parameters (for MAML inner loop).
    
    Mirrors GRURecommender for a single-layer GRU. Dropout is not applied here.
    """
    # Embedding
    embedded = F.embedding(x, params["embedding.weight"], padding_idx=0)  # [batch, seq_len, embed_dim]
    
    weight_ih = params["gru.weight_ih_l0"]
    weight_hh = params["gru.weight_hh_l0"]
    bias_ih = params["gru.bias_ih_l0"]
    bias_hh = params["gru.bias_hh_l0"]
    
    batch_size, seq_len = x.size(0), x.size(1)
    hidden_dim = weight_hh.size(1)
    h = torch.zeros(batch_size, hidden_dim, device=x.device, dtype=embedded.dtype)
    
    if lengths is not None:
        lengths = lengths.to(x.device)
    
    for t in range(seq_len):
        # Keep the input and hidden projections separate: PyTorch's GRU applies
        # the reset gate only to the hidden part of the candidate state.
        gi = embedded[:, t, :] @ weight_ih.t() + bias_ih
        gh = h @ weight_hh.t() + bias_hh
        i_r, i_z, i_n = gi.chunk(3, dim=1)
        h_r, h_z, h_n = gh.chunk(3, dim=1)
        
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h_new = (1 - z) * n + z * h
        
        if lengths is not None:
            # Freeze sequences that have already ended so that padding steps
            # do not overwrite their true last hidden state.
            active = (t < lengths).unsqueeze(1)
            h = torch.where(active, h_new, h)
        else:
            h = h_new
    
    last_hidden = h
    
    # Output projection
    logits = last_hidden @ params["output.weight"].t() + params["output.bias"]
    
    return logits

print(f"[CELL 08-07] Helper functions defined")
print(f"  - pad_sequences: Pad variable-length sequences")
print(f"  - functional_forward: Forward pass with external params (for inner loop)")

cell_end("CELL 08-07", t0)
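
A hand-written GRU recurrence such as `functional_forward` is only a valid stand-in for the warm-started model if it reproduces `nn.GRU` on padded batches. The sketch below is one way to check that on random data. `manual_gru_last_hidden` is a hypothetical helper written for this check, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

def manual_gru_last_hidden(x, lengths, w_ih, w_hh, b_ih, b_hh):
    """Single-layer GRU recurrence returning each sequence's state at its last valid step."""
    batch, seq_len, _ = x.shape
    h = x.new_zeros(batch, w_hh.size(1))
    for t in range(seq_len):
        gi = x[:, t] @ w_ih.t() + b_ih
        gh = h @ w_hh.t() + b_hh
        i_r, i_z, i_n = gi.chunk(3, dim=1)
        h_r, h_z, h_n = gh.chunk(3, dim=1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)  # reset gate scales only the hidden term
        h_new = (1 - z) * n + z * h
        # Sequences shorter than t+1 keep their previous state (padding is ignored).
        active = (t < lengths).unsqueeze(1)
        h = torch.where(active, h_new, h)
    return h

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=3, batch_first=True)
x = torch.randn(2, 5, 4)          # batch of 2, padded to length 5
lengths = torch.tensor([5, 2])    # second sequence has 3 padding steps

packed = nn.utils.rnn.pack_padded_sequence(
    x, lengths, batch_first=True, enforce_sorted=False
)
with torch.no_grad():
    _, h_ref = gru(packed)
    h_manual = manual_gru_last_hidden(
        x, lengths, gru.weight_ih_l0, gru.weight_hh_l0, gru.bias_ih_l0, gru.bias_hh_l0
    )

max_abs_diff = (h_ref[-1] - h_manual).abs().max().item()
print(f"max abs difference vs nn.GRU: {max_abs_diff:.2e}")
```

If the difference is not close to floating-point noise, the adapted parameters are being evaluated by a different network than the one the pre-trained weights were trained for.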

In [None]:
# [CELL 08-08] MAML Training Loop with Warm-Start

t0 = cell_start("CELL 08-08", "MAML Training with Warm-Start initialization")

# Training config
inner_lr = CFG["maml_config"]["inner_lr"]
inner_steps = CFG["maml_config"]["inner_steps"]
meta_batch_size = CFG["maml_config"]["meta_batch_size"]
num_iterations = CFG["maml_config"]["num_meta_iterations"]
val_every = CFG["maml_config"]["val_every"]
checkpoint_every = CFG["maml_config"]["checkpoint_every"]

print(f"[CELL 08-08] Training config:")
print(f"  - Inner LR (alpha): {inner_lr}")
print(f"  - Inner steps: {inner_steps}")
print(f"  - Meta-batch size: {meta_batch_size}")
print(f"  - Iterations: {num_iterations}")
print(f"  - Warm-Start: YES (from GRU4Rec)")

# Training history
history = {
    "train_loss": [],
    "val_acc": [],
    "iteration": [],
}

best_val_acc = 0.0

# Training loop
meta_model.train()

for iteration in range(1, num_iterations + 1):
    meta_optimizer.zero_grad()
    
    # Sample meta-batch of tasks (episodes)
    task_indices = np.random.choice(len(episodes_train), size=meta_batch_size, replace=False)
    
    meta_loss = 0.0
    
    for task_idx in task_indices:
        # Get episode data
        row = episodes_train.iloc[task_idx]
        
        support_x_raw = row["support_prefixes"]
        support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
        query_x_raw = row["query_prefixes"]
        query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
        
        # Pad sequences
        support_x, support_lengths = pad_sequences(support_x_raw)
        support_x = support_x.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        
        query_x, query_lengths = pad_sequences(query_x_raw)
        query_x = query_x.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        
        # Clone parameters for inner loop
        params = {name: param.clone() for name, param in meta_model.named_parameters()}
        
        # Inner loop: adapt on support set
        for _ in range(inner_steps):
            support_logits = functional_forward(meta_model, support_x, support_lengths, params)
            support_loss = F.cross_entropy(support_logits, support_y)
            
            # Compute gradients w.r.t. params; only build the second-order graph
            # when use_second_order is set (False means FOMAML)
            grads = torch.autograd.grad(support_loss, list(params.values()), create_graph=CFG["maml_config"]["use_second_order"])
            
            # Update params
            params = {
                name: param - inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }
        
        # Evaluate on query set with adapted params
        query_logits = functional_forward(meta_model, query_x, query_lengths, params)
        query_loss = F.cross_entropy(query_logits, query_y)
        
        meta_loss += query_loss
    
    # Average meta-loss
    meta_loss = meta_loss / meta_batch_size
    
    # Outer loop: meta-update
    meta_loss.backward()
    torch.nn.utils.clip_grad_norm_(meta_model.parameters(), max_norm=10.0)
    meta_optimizer.step()
    
    # Logging
    if iteration % 100 == 0 or iteration == 1:
        print(f"[CELL 08-08] Iteration {iteration}/{num_iterations}, Meta-Loss: {meta_loss.item():.4f}")
        history["train_loss"].append(meta_loss.item())
        history["iteration"].append(iteration)
    
    # Validation
    if iteration % val_every == 0:
        meta_model.eval()
        val_correct = 0
        val_total = 0
        
        with torch.enable_grad():  # inner-loop adaptation calls autograd.grad, which fails under no_grad
            for val_idx in range(len(episodes_val)):
                row = episodes_val.iloc[val_idx]
                
                support_x_raw = row["support_prefixes"]
                support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
                query_x_raw = row["query_prefixes"]
                query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
                
                support_x, support_lengths = pad_sequences(support_x_raw)
                support_x = support_x.to(DEVICE)
                support_lengths = support_lengths.to(DEVICE)
                
                query_x, query_lengths = pad_sequences(query_x_raw)
                query_x = query_x.to(DEVICE)
                query_lengths = query_lengths.to(DEVICE)
                
                # Clone and adapt
                params = {name: param.clone() for name, param in meta_model.named_parameters()}
                
                for _ in range(inner_steps):
                    support_logits = functional_forward(meta_model, support_x, support_lengths, params)
                    support_loss = F.cross_entropy(support_logits, support_y)
                    grads = torch.autograd.grad(support_loss, params.values())
                    params = {
                        name: param - inner_lr * grad
                        for (name, param), grad in zip(params.items(), grads)
                    }
                
                query_logits = functional_forward(meta_model, query_x, query_lengths, params)
                preds = query_logits.argmax(dim=-1)
                
                val_correct += (preds == query_y).sum().item()
                val_total += len(query_y)
        
        val_acc = val_correct / val_total
        history["val_acc"].append(val_acc)
        print(f"[CELL 08-08] Iteration {iteration}, Val Acc@1: {val_acc:.4f} ({val_acc*100:.2f}%)")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(meta_model.state_dict(), OUT_MODEL)
            print(f"[CELL 08-08] New best model saved! Val Acc: {val_acc*100:.2f}%")
        
        meta_model.train()
    
    # Checkpoint
    if iteration % checkpoint_every == 0:
        checkpoint_path = CHECKPOINT_DIR / f"warmstart_checkpoint_iter{iteration}.pth"
        torch.save({
            "iteration": iteration,
            "model_state_dict": meta_model.state_dict(),
            "optimizer_state_dict": meta_optimizer.state_dict(),
            "best_val_acc": best_val_acc,
        }, checkpoint_path)
        print(f"[CELL 08-08] Checkpoint saved: {checkpoint_path.name}")

print(f"\n[CELL 08-08] Training complete!")
print(f"[CELL 08-08] Best validation accuracy: {best_val_acc*100:.2f}%")

cell_end("CELL 08-08", t0, best_val_acc=best_val_acc)
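
The difference between FOMAML and full MAML comes down to the `create_graph` argument of `torch.autograd.grad` in the inner loop. The toy sketch below illustrates this with a single scalar parameter and quadratic losses. The function `meta_gradient` and all its default values are assumptions made for illustration; they are not part of the notebook's training code.

```python
import torch

def meta_gradient(second_order: bool, w0: float = 0.0, a: float = 1.0,
                  b: float = 2.0, inner_lr: float = 0.1) -> float:
    """Meta-gradient of a one-step MAML update on a scalar toy problem.

    Support loss: (w - a)^2, query loss: (w' - b)^2, with w' the adapted weight.
    """
    w = torch.tensor(w0, requires_grad=True)
    fast = w.clone()

    support_loss = (fast - a) ** 2
    # create_graph=True keeps the inner gradient differentiable (full MAML);
    # create_graph=False treats it as a constant (first-order MAML).
    (grad,) = torch.autograd.grad(support_loss, fast, create_graph=second_order)
    fast = fast - inner_lr * grad

    query_loss = (fast - b) ** 2
    query_loss.backward()
    return w.grad.item()

first_order = meta_gradient(second_order=False)
full = meta_gradient(second_order=True)
print(f"first-order meta-gradient:  {first_order:.4f}")
print(f"second-order meta-gradient: {full:.4f}")
```

Analytically, the adapted weight is 0.2, so the first-order gradient is 2(0.2 - 2) = -3.6, and the second-order gradient multiplies it by (1 - 2 * 0.1) = 0.8, giving -2.88. The two only coincide when the inner step has no curvature term, which is why the flag should map directly onto `create_graph`.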

In [None]:
# [CELL 08-09] Final Evaluation on Test Set

t0 = cell_start("CELL 08-09", "Final evaluation on test set")

# Load best model
meta_model.load_state_dict(torch.load(OUT_MODEL, map_location=DEVICE))
meta_model.eval()

print(f"[CELL 08-09] Loaded best model from: {OUT_MODEL}")

# Evaluate: Zero-shot (no adaptation)
zeroshot_correct = 0
zeroshot_total = 0

# Evaluate: Few-shot (with adaptation)
fewshot_correct = 0
fewshot_total = 0

with torch.enable_grad():  # few-shot adaptation calls autograd.grad, which fails under no_grad
    for test_idx in range(len(episodes_test)):
        row = episodes_test.iloc[test_idx]
        
        support_x_raw = row["support_prefixes"]
        support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
        query_x_raw = row["query_prefixes"]
        query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
        
        support_x, support_lengths = pad_sequences(support_x_raw)
        support_x = support_x.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        
        query_x, query_lengths = pad_sequences(query_x_raw)
        query_x = query_x.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        
        # Zero-shot: no adaptation
        params_zs = {name: param.clone() for name, param in meta_model.named_parameters()}
        query_logits_zs = functional_forward(meta_model, query_x, query_lengths, params_zs)
        preds_zs = query_logits_zs.argmax(dim=-1)
        zeroshot_correct += (preds_zs == query_y).sum().item()
        zeroshot_total += len(query_y)
        
        # Few-shot: adapt on support set
        params_fs = {name: param.clone() for name, param in meta_model.named_parameters()}
        
        for _ in range(inner_steps):
            support_logits = functional_forward(meta_model, support_x, support_lengths, params_fs)
            support_loss = F.cross_entropy(support_logits, support_y)
            grads = torch.autograd.grad(support_loss, params_fs.values())
            params_fs = {
                name: param - inner_lr * grad
                for (name, param), grad in zip(params_fs.items(), grads)
            }
        
        query_logits_fs = functional_forward(meta_model, query_x, query_lengths, params_fs)
        preds_fs = query_logits_fs.argmax(dim=-1)
        fewshot_correct += (preds_fs == query_y).sum().item()
        fewshot_total += len(query_y)

zeroshot_acc = zeroshot_correct / zeroshot_total
fewshot_acc = fewshot_correct / fewshot_total

print(f"\n[CELL 08-09] ========== RESULTS ==========")
print(f"[CELL 08-09] Test episodes: {len(episodes_test)}")
print(f"\n[CELL 08-09] Warm-Start MAML Zero-shot: {zeroshot_acc:.4f} ({zeroshot_acc*100:.2f}%)")
print(f"[CELL 08-09] Warm-Start MAML Few-shot:  {fewshot_acc:.4f} ({fewshot_acc*100:.2f}%)")
print(f"\n[CELL 08-09] ========== COMPARISON ==========")
print(f"[CELL 08-09] GRU4Rec baseline:          33.55%")
print(f"[CELL 08-09] Vanilla MAML Zero-shot:    25.62%")
print(f"[CELL 08-09] Vanilla MAML Few-shot:     28.66%")
print(f"[CELL 08-09] Warm-Start MAML Zero-shot: {zeroshot_acc*100:.2f}%")
print(f"[CELL 08-09] Warm-Start MAML Few-shot:  {fewshot_acc*100:.2f}%")
print(f"\n[CELL 08-09] Improvement over Vanilla MAML: {(fewshot_acc - 0.2866)*100:+.2f} pp")
print(f"[CELL 08-09] Improvement over GRU4Rec:      {(fewshot_acc - 0.3355)*100:+.2f} pp")

cell_end("CELL 08-09", t0, zeroshot_acc=zeroshot_acc, fewshot_acc=fewshot_acc)

In [None]:
# [CELL 08-10] Save results and report

t0 = cell_start("CELL 08-10", "Save results and report")

# Results
results = {
    "model": "warmstart_maml",
    "contribution": "warm_start_initialization",
    "dataset": "xuetangx",
    "config": CFG,
    "metrics": {
        "zeroshot": {
            "accuracy@1": zeroshot_acc,
        },
        "fewshot": {
            "accuracy@1": fewshot_acc,
        },
    },
    "comparison": {
        "gru4rec_baseline": 0.3355,
        "vanilla_maml_zeroshot": 0.2562,
        "vanilla_maml_fewshot": 0.2866,
        "warmstart_maml_zeroshot": zeroshot_acc,
        "warmstart_maml_fewshot": fewshot_acc,
    },
    "improvement": {
        "over_vanilla_maml": fewshot_acc - 0.2866,
        "over_gru4rec": fewshot_acc - 0.3355,
    },
}

# Save results
with open(OUT_RESULTS, "w") as f:
    json.dump(results, f, indent=2)
print(f"[CELL 08-10] Results saved: {OUT_RESULTS}")

# Report
report = {
    "notebook": CFG["notebook"],
    "run_tag": RUN_TAG,
    "timestamp": datetime.now().isoformat(),
    "config": CFG,
    "results": results,
    "history": history,
}

with open(REPORT_PATH, "w") as f:
    json.dump(report, f, indent=2)
print(f"[CELL 08-10] Report saved: {REPORT_PATH}")

cell_end("CELL 08-10", t0)

## Notebook 08 Complete: Warm-Start MAML Results

**Contribution:** Initialize MAML from pre-trained GRU4Rec weights instead of random initialization.

**Key Results:**

| Model | Initialization | Acc@1 | vs Vanilla MAML | vs GRU4Rec |
|-------|---------------|-------|-----------------|------------|
| GRU4Rec (baseline) | Trained | 33.55% | - | - |
| Vanilla MAML Zero-shot | Random | 25.62% | - | -7.93 pp |
| Vanilla MAML Few-shot | Random | 28.66% | baseline | -4.89 pp |
| **Warm-Start MAML Zero-shot** | GRU4Rec | **??%** | +?? pp | +?? pp |
| **Warm-Start MAML Few-shot** | GRU4Rec | **??%** | +?? pp | +?? pp |

**Key Insight:** By starting from GRU4Rec's learned weights (33.55%), the inner loop only needs to learn user-specific adaptations, not the entire course pattern space.
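
The percentage-point columns in the table above can be filled in from measured accuracies with a small helper. The sketch below is illustrative: `pp_delta` and `results_row` are hypothetical names, and the demo uses only the Vanilla MAML and GRU4Rec figures already listed in the table.

```python
# Reference accuracies taken from the table above.
GRU4REC_ACC = 0.3355
VANILLA_FEWSHOT_ACC = 0.2866
VANILLA_ZEROSHOT_ACC = 0.2562

def pp_delta(acc: float, reference: float) -> str:
    """Signed difference between two accuracies, in percentage points."""
    return f"{(acc - reference) * 100:+.2f} pp"

def results_row(name: str, init: str, acc: float) -> str:
    """Format one markdown table row: Acc@1 plus deltas vs both references."""
    return (
        f"| {name} | {init} | {acc * 100:.2f}% "
        f"| {pp_delta(acc, VANILLA_FEWSHOT_ACC)} "
        f"| {pp_delta(acc, GRU4REC_ACC)} |"
    )

# Reproduces the existing deltas for Vanilla MAML against GRU4Rec.
print(pp_delta(VANILLA_FEWSHOT_ACC, GRU4REC_ACC))
print(pp_delta(VANILLA_ZEROSHOT_ACC, GRU4REC_ACC))
print(results_row("Vanilla MAML Zero-shot", "Random", VANILLA_ZEROSHOT_ACC))
```

Passing `zeroshot_acc` and `fewshot_acc` from CELL 08-09 to `results_row` would produce the two Warm-Start rows once a run has finished.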

**Next:** Notebook 09 - Recency-Weighted MAML (Contribution 2)