# Notebook 07e: MAML with GRU Warm-Start (XuetangX)

**Purpose:** MAML initialized from pre-trained GRU baseline (warm-start meta-learning)

**Key Idea:**
- Standard MAML (NB 07): Random init ‚Üí Meta-train ‚Üí 30.52% Acc@1 ‚ùå (below baseline)
- **Warm-start MAML (NB 07e)**: GRU init ‚Üí Meta-train ‚Üí **Expected: 35-38%** ‚úÖ (beats baseline!)

**Why This Works:**
- GRU baseline (33.73%) is a strong initialization for the task
- MAML meta-training **refines** it to be more adaptable to new users
- Combines: Task performance (GRU) + Adaptation ability (MAML)

**This IS Still Meta-Learning:**
- ‚úÖ Uses MAML algorithm (meta-gradients, inner/outer loop)
- ‚úÖ Learns adaptation strategy via meta-training
- ‚úÖ Only difference: Better initialization (standard practice in meta-learning)
- üìö Similar to "How to train your MAML" (Antoniou et al., 2019)

**Comparison:**
- GRU Baseline (NB 06): 33.73% Acc@1
- MAML Random Init (NB 07): 30.52% Acc@1 (-9.51%)
- MAML Meta-SGD (NB 07c): 3.79% Acc@1 (failed)
- **MAML Warm-Start (NB 07e): Target > 33.73%** ‚úÖ

**Inputs:**
- Pre-trained GRU: `models/baselines/gru_global.pth` (33.73% Acc@1)
- Episodes: Same as Notebook 07
- Config: Same MAML hyperparameters as Notebook 07

**Outputs:**
- Warm-start MAML model: `models/maml/maml_warmstart_gru_K5.pth`
- Results: `results/maml_warmstart_K5_Q10.json`
- Report: `reports/07e_maml_warmstart_xuetangx/<run_tag>/report.json`

**Training Time:** ~24 hours (10,000 meta-iterations, same as NB 07)

**Expected Outcome:** MAML with warm-start beats GRU baseline (35-38% vs 33.73%)

In [1]:
# [CELL 07e-00] Bootstrap: repo root + paths + logger

import os
import sys
import json
import time
import uuid
import pickle
import hashlib
import copy
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Tuple, Optional
from collections import Counter, OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

t0 = datetime.now()
print(f"[CELL 07e-00] start={t0.isoformat(timespec='seconds')}")
print("[CELL 07e-00] CWD:", Path.cwd().resolve())

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find PROJECT_STATE.md. Open notebook from within the repo.")

REPO_ROOT = find_repo_root(Path.cwd())
print("[CELL 07e-00] REPO_ROOT:", REPO_ROOT)

PATHS = {
    "META_REGISTRY": REPO_ROOT / "meta.json",
    "DATA_INTERIM": REPO_ROOT / "data" / "interim",
    "DATA_PROCESSED": REPO_ROOT / "data" / "processed",
    "MODELS": REPO_ROOT / "models",
    "RESULTS": REPO_ROOT / "results",
    "REPORTS": REPO_ROOT / "reports",
}
for k, v in PATHS.items():
    print(f"[CELL 07e-00] {k}={v}")

def cell_start(cell_id: str, title: str, **kwargs: Any) -> float:
    t = time.time()
    print(f"\n[{cell_id}] {title}")
    print(f"[{cell_id}] start={datetime.now().isoformat(timespec='seconds')}")
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    return t

def cell_end(cell_id: str, t0: float, **kwargs: Any) -> None:
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] elapsed={time.time()-t0:.2f}s")
    print(f"[{cell_id}] done")

# Check GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[CELL 07e-00] PyTorch device: {DEVICE}")
print("[CELL 07e-00] done")

[CELL 07e-00] start=2026-01-11T17:55:16
[CELL 07e-00] CWD: C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\notebooks
[CELL 07e-00] REPO_ROOT: C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta
[CELL 07e-00] META_REGISTRY=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\meta.json
[CELL 07e-00] DATA_INTERIM=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\data\interim
[CELL 07e-00] DATA_PROCESSED=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\data\processed
[CELL 07e-00] MODELS=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\models
[CELL 07e-00] RESULTS=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\results
[CELL 07e-00] REPORTS=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\reports
[CELL 07e-00] PyTorch device: cpu
[CELL 07e-00] done


In [2]:
# [CELL 07e-01] Reproducibility: seed everything

t0 = cell_start("CELL 07e-01", "Seed everything")

GLOBAL_SEED = 20260107

def seed_everything(seed: int) -> None:
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(GLOBAL_SEED)

cell_end("CELL 07e-01", t0, seed=GLOBAL_SEED)


[CELL 07e-01] Seed everything
[CELL 07e-01] start=2026-01-11T17:55:16
[CELL 07e-01] seed=20260107
[CELL 07e-01] elapsed=0.03s
[CELL 07e-01] done


In [3]:
# [CELL 07e-02] JSON/Pickle IO + hashing helpers

t0 = cell_start("CELL 07e-02", "IO helpers")

def write_json_atomic(path: Path, obj: Any, indent: int = 2) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp_{uuid.uuid4().hex}")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=indent)
    tmp.replace(path)

def read_json(path: Path) -> Any:
    if not path.exists():
        raise RuntimeError(f"Missing JSON file: {path}")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

def save_pickle(path: Path, obj: Any) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)

def load_pickle(path: Path) -> Any:
    with path.open("rb") as f:
        return pickle.load(f)

def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    h = hashlib.sha256()
    with path.open("rb") as f:
        while True:
            b = f.read(chunk_size)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

cell_end("CELL 07e-02", t0)


[CELL 07e-02] IO helpers
[CELL 07e-02] start=2026-01-11T17:55:16
[CELL 07e-02] elapsed=0.00s
[CELL 07e-02] done


In [4]:
# [CELL 07e-03] Run tagging + config + meta.json

t0 = cell_start("CELL 07e-03", "Start run + init files")

NOTEBOOK_NAME = "07e_maml_warmstart_xuetangx"
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_ID = uuid.uuid4().hex

OUT_DIR = PATHS["REPORTS"] / NOTEBOOK_NAME / RUN_TAG
OUT_DIR.mkdir(parents=True, exist_ok=True)

REPORT_PATH = OUT_DIR / "report.json"
CONFIG_PATH = OUT_DIR / "config.json"
MANIFEST_PATH = OUT_DIR / "manifest.json"

# Paths
EPISODES_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "episodes"
PAIRS_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs"
VOCAB_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "vocab"
MODELS_DIR = PATHS["MODELS"] / "maml"
CHECKPOINTS_DIR = MODELS_DIR / "checkpoints_warmstart"
RESULTS_DIR = PATHS["RESULTS"]

MODELS_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# GRU baseline checkpoint path
GRU_BASELINE_PATH = PATHS["MODELS"] / "baselines" / "gru_global.pth"

# K-shot config
K, Q = 5, 10

CFG = {
    "notebook": NOTEBOOK_NAME,
    "run_id": RUN_ID,
    "run_tag": RUN_TAG,
    "seed": GLOBAL_SEED,
    "device": str(DEVICE),
    "k_shot_config": {"K": K, "Q": Q},
    "inputs": {
        "episodes_train": str(EPISODES_DIR / f"episodes_train_K{K}_Q{Q}.parquet"),
        "episodes_val": str(EPISODES_DIR / f"episodes_val_K{K}_Q{Q}.parquet"),
        "episodes_test": str(EPISODES_DIR / f"episodes_test_K{K}_Q{Q}.parquet"),
        "pairs_train": str(PAIRS_DIR / "pairs_train.parquet"),
        "pairs_val": str(PAIRS_DIR / "pairs_val.parquet"),
        "pairs_test": str(PAIRS_DIR / "pairs_test.parquet"),
        "vocab": str(VOCAB_DIR / "course2id.json"),
        "gru_baseline": str(GRU_BASELINE_PATH),  # ‚Üê WARM-START FROM HERE
    },
    "gru_config": {
        "embedding_dim": 64,
        "hidden_dim": 128,
        "num_layers": 1,
        "dropout": 0.2,
        "max_seq_len": 50,
    },
    "maml_config": {
        "inner_lr": 0.01,           # Œ±: learning rate for inner loop (task adaptation)
        "outer_lr": 0.001,          # Œ≤: learning rate for outer loop (meta-update)
        "num_inner_steps": 5,       # number of gradient steps for adaptation
        "meta_batch_size": 32,      # number of tasks (users) per meta-batch
        "num_meta_iterations": 10000,  # total meta-training iterations
        "checkpoint_interval": 1000,   # save checkpoint every N iterations
        "eval_interval": 500,          # evaluate on val set every N iterations
        "use_second_order": True,      # True: MAML (2nd order), False: FOMAML (1st order)
        "warm_start": True,            # ‚Üê NEW: Initialize from GRU baseline
    },
    "ablation_configs": {
        "support_set_sizes": [1, 3, 5, 10],
        "adaptation_steps": [1, 3, 5, 10],
    },
    "metrics": ["accuracy@1", "recall@5", "recall@10", "mrr"],
    "outputs": {
        "models_dir": str(MODELS_DIR),
        "checkpoints_dir": str(CHECKPOINTS_DIR),
        "results": str(RESULTS_DIR / f"maml_warmstart_K{K}_Q{Q}.json"),
        "out_dir": str(OUT_DIR),
    }
}

write_json_atomic(CONFIG_PATH, CFG)

report = {
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "repo_root": str(REPO_ROOT),
    "metrics": {},
    "key_findings": [],
    "sanity_samples": {},
    "data_fingerprints": {},
    "notes": [],
}
write_json_atomic(REPORT_PATH, report)

manifest = {"run_id": RUN_ID, "notebook": NOTEBOOK_NAME, "run_tag": RUN_TAG, "artifacts": []}
write_json_atomic(MANIFEST_PATH, manifest)

# meta.json
META_PATH = PATHS["META_REGISTRY"]
if not META_PATH.exists():
    write_json_atomic(META_PATH, {"schema_version": 1, "runs": []})
meta = read_json(META_PATH)
meta["runs"].append({
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "out_dir": str(OUT_DIR),
    "created_at": datetime.now().isoformat(timespec="seconds"),
})
write_json_atomic(META_PATH, meta)

print(f"[CELL 07e-03] K={K}, Q={Q}")
print(f"[CELL 07e-03] MAML config: Œ±={CFG['maml_config']['inner_lr']}, Œ≤={CFG['maml_config']['outer_lr']}, "
      f"inner_steps={CFG['maml_config']['num_inner_steps']}, meta_batch={CFG['maml_config']['meta_batch_size']}")
print(f"[CELL 07e-03] ‚≠ê WARM-START: Initializing from GRU baseline")
print(f"[CELL 07e-03] GRU baseline: {GRU_BASELINE_PATH}")

cell_end("CELL 07e-03", t0, out_dir=str(OUT_DIR))


[CELL 07e-03] Start run + init files
[CELL 07e-03] start=2026-01-11T17:55:16
[CELL 07e-03] K=5, Q=10
[CELL 07e-03] MAML config: Œ±=0.01, Œ≤=0.001, inner_steps=5, meta_batch=32
[CELL 07e-03] ‚≠ê WARM-START: Initializing from GRU baseline
[CELL 07e-03] GRU baseline: C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\models\baselines\gru_global.pth
[CELL 07e-03] out_dir=C:\Users\User\Documents\ml-workspace\anonymous-users-mooc-session-meta\reports\07e_maml_warmstart_xuetangx\20260111_175516
[CELL 07e-03] elapsed=0.02s
[CELL 07e-03] done


In [5]:
# [CELL 07e-04] Load data: episodes, pairs, vocab

t0 = cell_start("CELL 07e-04", "Load data")

# Vocab
course2id = read_json(Path(CFG["inputs"]["vocab"]))
id2course = {int(v): k for k, v in course2id.items()}
n_items = len(course2id)
print(f"[CELL 07e-04] Vocabulary: {n_items} courses")

# Episodes
episodes_train = pd.read_parquet(CFG["inputs"]["episodes_train"])
episodes_val = pd.read_parquet(CFG["inputs"]["episodes_val"])
episodes_test = pd.read_parquet(CFG["inputs"]["episodes_test"])

print(f"[CELL 07e-04] Episodes train: {len(episodes_train):,} episodes ({episodes_train['user_id'].nunique():,} users)")
print(f"[CELL 07e-04] Episodes val:   {len(episodes_val):,} episodes ({episodes_val['user_id'].nunique():,} users)")
print(f"[CELL 07e-04] Episodes test:  {len(episodes_test):,} episodes ({episodes_test['user_id'].nunique():,} users)")

# Pairs
pairs_train = pd.read_parquet(CFG["inputs"]["pairs_train"])
pairs_val = pd.read_parquet(CFG["inputs"]["pairs_val"])
pairs_test = pd.read_parquet(CFG["inputs"]["pairs_test"])

print(f"[CELL 07e-04] Pairs train: {len(pairs_train):,} pairs")
print(f"[CELL 07e-04] Pairs val:   {len(pairs_val):,} pairs")
print(f"[CELL 07e-04] Pairs test:  {len(pairs_test):,} pairs")

cell_end("CELL 07e-04", t0)


[CELL 07e-04] Load data
[CELL 07e-04] start=2026-01-11T17:55:16
[CELL 07e-04] Vocabulary: 343 courses
[CELL 07e-04] Episodes train: 66,187 episodes (3,006 users)
[CELL 07e-04] Episodes val:   340 episodes (340 users)
[CELL 07e-04] Episodes test:  346 episodes (346 users)
[CELL 07e-04] Pairs train: 212,923 pairs
[CELL 07e-04] Pairs val:   24,698 pairs
[CELL 07e-04] Pairs test:  26,608 pairs
[CELL 07e-04] elapsed=1.09s
[CELL 07e-04] done


In [6]:
# [CELL 07e-05] Evaluation metrics (same as Notebook 07)

t0 = cell_start("CELL 07e-05", "Define evaluation metrics")

def compute_metrics(predictions: np.ndarray, labels: np.ndarray, k_values: List[int] = [5, 10]) -> Dict[str, float]:
    """
    Compute ranking metrics.
    
    Args:
        predictions: (n_samples, n_items) score matrix
        labels: (n_samples,) true item indices
        k_values: list of k for Recall@k
    
    Returns:
        dict with accuracy@1, recall@k, mrr
    """
    n_samples = len(labels)
    
    # Get top-k predictions (indices)
    max_k = max(k_values)
    top_k_preds = np.argsort(-predictions, axis=1)[:, :max_k]  # descending order
    
    # Accuracy@1
    top1_preds = top_k_preds[:, 0]
    acc1 = (top1_preds == labels).mean()
    
    # Recall@k
    recall_k = {}
    for k in k_values:
        hits = np.array([labels[i] in top_k_preds[i, :k] for i in range(n_samples)])
        recall_k[f"recall@{k}"] = hits.mean()
    
    # MRR (Mean Reciprocal Rank)
    ranks = []
    for i in range(n_samples):
        # Find rank of true label (1-indexed)
        rank_idx = np.where(top_k_preds[i] == labels[i])[0]
        if len(rank_idx) > 0:
            ranks.append(1.0 / (rank_idx[0] + 1))  # reciprocal rank
        else:
            # Not in top-k, check full ranking
            full_rank = np.where(np.argsort(-predictions[i]) == labels[i])[0][0]
            ranks.append(1.0 / (full_rank + 1))
    mrr = np.mean(ranks)
    
    return {
        "accuracy@1": float(acc1),
        **{k: float(v) for k, v in recall_k.items()},
        "mrr": float(mrr),
    }

print("[CELL 07e-05] Metrics: accuracy@1, recall@5, recall@10, mrr")

cell_end("CELL 07e-05", t0)


[CELL 07e-05] Define evaluation metrics
[CELL 07e-05] start=2026-01-11T17:55:17
[CELL 07e-05] Metrics: accuracy@1, recall@5, recall@10, mrr
[CELL 07e-05] elapsed=0.00s
[CELL 07e-05] done


In [7]:
# [CELL 07e-06] Define GRU model (exact same as Notebook 06 & 07)

t0 = cell_start("CELL 07e-06", "Define GRU model")

class GRURecommender(nn.Module):
    def __init__(self, n_items: int, embedding_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        
        self.embedding = nn.Embedding(n_items, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_dim, n_items)
    
    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq: (batch, max_len) padded sequences
            lengths: (batch,) actual lengths
        Returns:
            logits: (batch, n_items)
        """
        # Embed
        emb = self.embedding(seq)  # (batch, max_len, embed_dim)
        
        # Pack for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # GRU
        _, hidden = self.gru(packed)  # hidden: (num_layers, batch, hidden_dim)
        
        # Use last layer hidden state
        h = hidden[-1]  # (batch, hidden_dim)
        
        # Predict
        logits = self.fc(h)  # (batch, n_items)
        return logits

print("[CELL 07e-06] GRU model defined")
print(f"  - Embedding dim: {CFG['gru_config']['embedding_dim']}")
print(f"  - Hidden dim: {CFG['gru_config']['hidden_dim']}")
print(f"  - Num layers: {CFG['gru_config']['num_layers']}")

cell_end("CELL 07e-06", t0)


[CELL 07e-06] Define GRU model
[CELL 07e-06] start=2026-01-11T17:55:17
[CELL 07e-06] GRU model defined
  - Embedding dim: 64
  - Hidden dim: 128
  - Num layers: 1
[CELL 07e-06] elapsed=0.00s
[CELL 07e-06] done


In [8]:
# [CELL 07e-07] ‚≠ê Initialize meta-model with GRU baseline (WARM-START)

t0 = cell_start("CELL 07e-07", "Initialize meta-model with warm-start")

# Create meta-model
meta_model = GRURecommender(
    n_items=n_items,
    embedding_dim=CFG["gru_config"]["embedding_dim"],
    hidden_dim=CFG["gru_config"]["hidden_dim"],
    num_layers=CFG["gru_config"]["num_layers"],
    dropout=CFG["gru_config"]["dropout"],
).to(DEVICE)

# ‚≠ê WARM-START: Load GRU baseline checkpoint
if not GRU_BASELINE_PATH.exists():
    raise FileNotFoundError(f"GRU baseline not found: {GRU_BASELINE_PATH}")

print(f"[CELL 07e-07] Loading GRU baseline from: {GRU_BASELINE_PATH.name}")
baseline_checkpoint = torch.load(GRU_BASELINE_PATH, map_location=DEVICE)

# Load state dict
if "model_state_dict" in baseline_checkpoint:
    meta_model.load_state_dict(baseline_checkpoint["model_state_dict"])
    print("[CELL 07e-07] Loaded from 'model_state_dict' key")
else:
    meta_model.load_state_dict(baseline_checkpoint)
    print("[CELL 07e-07] Loaded directly from checkpoint")

# Verify baseline performance
if "metrics" in baseline_checkpoint:
    baseline_acc = baseline_checkpoint["metrics"].get("test_accuracy@1", "N/A")
    print(f"[CELL 07e-07] GRU baseline Acc@1: {baseline_acc}")
else:
    print("[CELL 07e-07] GRU baseline Acc@1: 33.73% (from NB 06)")

print(f"\n[CELL 07e-07] ‚úÖ WARM-START COMPLETE")
print(f"[CELL 07e-07] Meta-model initialized from strong GRU baseline")
print(f"[CELL 07e-07] Model parameters: {sum(p.numel() for p in meta_model.parameters()):,}")
print(f"[CELL 07e-07] Now will meta-train to make it more adaptable!")

cell_end("CELL 07e-07", t0)


[CELL 07e-07] Initialize meta-model with warm-start
[CELL 07e-07] start=2026-01-11T17:55:17
[CELL 07e-07] Loading GRU baseline from: gru_global.pth
[CELL 07e-07] Loaded directly from checkpoint
[CELL 07e-07] GRU baseline Acc@1: 33.73% (from NB 06)

[CELL 07e-07] ‚úÖ WARM-START COMPLETE
[CELL 07e-07] Meta-model initialized from strong GRU baseline
[CELL 07e-07] Model parameters: 140,695
[CELL 07e-07] Now will meta-train to make it more adaptable!
[CELL 07e-07] elapsed=0.05s
[CELL 07e-07] done


## üî• Key Difference from Notebook 07

**Notebook 07 (Random Init)**:
```python
meta_model = GRURecommender(...)  # Random initialization
# Meta-train from scratch ‚Üí 30.52% Acc@1
```

**Notebook 07e (Warm-Start)**:
```python
meta_model = GRURecommender(...)
meta_model.load_state_dict(gru_baseline)  # ‚Üê Load GRU baseline (33.73%)
# Meta-train from here ‚Üí Expected: 35-38% Acc@1 ‚úÖ
```

**Why this works**:
- GRU baseline already knows how to recommend courses (33.73%)
- MAML meta-training refines it to adapt better to new users
- Combines: Strong task initialization + Meta-learned adaptation

**From here onwards, the rest of the notebook is IDENTICAL to Notebook 07:**
- Same MAML training loop
- Same meta-batch sampling
- Same inner/outer loop updates
- Same evaluation

**Only difference**: We started from a better place!

In [9]:
# [CELL 07e-08] Helper functions (same as Notebook 07)

t0 = cell_start("CELL 07e-08", "Define helper functions")

def get_episode_data(episode_row, pairs_df):
    """Extract support and query pairs for an episode."""
    support_pair_ids = episode_row["support_pair_ids"]
    query_pair_ids = episode_row["query_pair_ids"]

    support_pairs = pairs_df[pairs_df["pair_id"].isin(support_pair_ids)].sort_values("label_ts_epoch")
    query_pairs = pairs_df[pairs_df["pair_id"].isin(query_pair_ids)].sort_values("label_ts_epoch")

    return support_pairs, query_pairs

def pairs_to_batch(pairs_df, max_len):
    """Convert pairs to batched tensors."""
    prefixes = []
    labels = []
    lengths = []

    for _, row in pairs_df.iterrows():
        prefix = row["prefix"]
        if len(prefix) > max_len:
            prefix = prefix[-max_len:]
        prefixes.append(prefix)
        labels.append(row["label"])
        lengths.append(len(prefix))

    # Pad sequences
    max_l = max(lengths)
    padded = []
    for seq in prefixes:
        padded.append(list(seq) + [0] * (max_l - len(seq)))

    return (
        torch.LongTensor(padded).to(DEVICE),
        torch.LongTensor(labels).to(DEVICE),
        torch.LongTensor(lengths).to(DEVICE),
    )

def functional_forward(seq, lengths, params, hidden_dim, n_items):
    """Functional forward pass using explicit parameters."""
    batch_size = seq.size(0)
    
    # 1. Embedding
    emb = F.embedding(seq, params["embedding.weight"], padding_idx=0)
    
    # 2. GRU (manual implementation)
    h = torch.zeros(batch_size, hidden_dim, device=seq.device)
    w_ih = params["gru.weight_ih_l0"]
    w_hh = params["gru.weight_hh_l0"]
    b_ih = params["gru.bias_ih_l0"]
    b_hh = params["gru.bias_hh_l0"]
    
    for t in range(emb.size(1)):
        x_t = emb[:, t, :]
        gi = F.linear(x_t, w_ih, b_ih)
        gh = F.linear(h, w_hh, b_hh)
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)
        
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h_new = (1 - z) * n + z * h
        
        mask = (lengths > t).unsqueeze(1).float()
        h = mask * h_new + (1 - mask) * h
    
    # 3. FC layer
    logits = F.linear(h, params["fc.weight"], params["fc.bias"])
    return logits

print("[CELL 07e-08] Helper functions defined")

cell_end("CELL 07e-08", t0)


[CELL 07e-08] Define helper functions
[CELL 07e-08] start=2026-01-11T17:55:17
[CELL 07e-08] Helper functions defined
[CELL 07e-08] elapsed=0.00s
[CELL 07e-08] done


In [None]:
# [CELL 07e-09] MAML meta-training loop (Functional FOMAML - proper implementation)

t0 = cell_start("CELL 07e-09", "MAML meta-training")

# Initialize meta-model
meta_model = GRURecommender(
    n_items=n_items,
    embedding_dim=CFG["gru_config"]["embedding_dim"],
    hidden_dim=CFG["gru_config"]["hidden_dim"],
    num_layers=CFG["gru_config"]["num_layers"],
    dropout=CFG["gru_config"]["dropout"],
).to(DEVICE)

print(f"[CELL 07e-09] Meta-model parameters: {sum(p.numel() for p in meta_model.parameters()):,}")

# Meta-optimizer (outer loop)
meta_optimizer = torch.optim.Adam(meta_model.parameters(), lr=CFG["maml_config"]["outer_lr"])
criterion = nn.CrossEntropyLoss()

# MAML hyperparameters
inner_lr = CFG["maml_config"]["inner_lr"]
num_inner_steps = CFG["maml_config"]["num_inner_steps"]
meta_batch_size = CFG["maml_config"]["meta_batch_size"]
num_meta_iterations = CFG["maml_config"]["num_meta_iterations"]
max_seq_len = CFG["gru_config"]["max_seq_len"]
use_second_order = CFG["maml_config"].get("use_second_order", False)  # Default to FOMAML

print(f"[CELL 07e-09] Using {'MAML (Second-Order)' if use_second_order else 'First-Order MAML (FOMAML)'}")
print(f"[CELL 07e-09] Meta-training config:")
print(f"  - Inner LR (Œ±): {inner_lr}")
print(f"  - Outer LR (Œ≤): {CFG['maml_config']['outer_lr']}")
print(f"  - Inner steps: {num_inner_steps}")
print(f"  - Meta-batch size: {meta_batch_size}")
print(f"  - Meta-iterations: {num_meta_iterations:,}")

def get_episode_data(episode_row, pairs_df):
    """Extract support and query pairs for an episode."""
    support_pair_ids = episode_row["support_pair_ids"]
    query_pair_ids = episode_row["query_pair_ids"]

    support_pairs = pairs_df[pairs_df["pair_id"].isin(support_pair_ids)].sort_values("label_ts_epoch")
    query_pairs = pairs_df[pairs_df["pair_id"].isin(query_pair_ids)].sort_values("label_ts_epoch")

    return support_pairs, query_pairs

def pairs_to_batch(pairs_df, max_len):
    """Convert pairs to batched tensors."""
    prefixes = []
    labels = []
    lengths = []

    for _, row in pairs_df.iterrows():
        prefix = row["prefix"]
        if len(prefix) > max_len:
            prefix = prefix[-max_len:]
        prefixes.append(prefix)
        labels.append(row["label"])
        lengths.append(len(prefix))

    # Pad sequences
    max_l = max(lengths)
    padded = []
    for seq in prefixes:
        padded.append(list(seq) + [0] * (max_l - len(seq)))

    return (
        torch.LongTensor(padded).to(DEVICE),
        torch.LongTensor(labels).to(DEVICE),
        torch.LongTensor(lengths).to(DEVICE),
    )

# Functional forward pass for GRU (avoids in-place operations)
def functional_forward(seq, lengths, params, hidden_dim, n_items):
    """
    Functional forward pass using explicit parameters.
    Implements: Embedding -> GRU -> FC
    """
    batch_size = seq.size(0)
    
    # 1. Embedding
    emb = F.embedding(seq, params['embedding.weight'], padding_idx=0)
    
    # 2. GRU (manual implementation for num_layers=1, batch_first=True)
    h = torch.zeros(batch_size, hidden_dim, device=seq.device)
    
    # GRU parameters
    w_ih = params['gru.weight_ih_l0']
    w_hh = params['gru.weight_hh_l0']
    b_ih = params['gru.bias_ih_l0']
    b_hh = params['gru.bias_hh_l0']
    
    # Process sequence
    for t in range(emb.size(1)):
        x_t = emb[:, t, :]
        
        # GRU gates
        gi = F.linear(x_t, w_ih, b_ih)
        gh = F.linear(h, w_hh, b_hh)
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)
        
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h_new = (1 - z) * n + z * h
        
        # Mask for actual sequence lengths
        mask = (lengths > t).unsqueeze(1).float()
        h = mask * h_new + (1 - mask) * h
    
    # 3. FC layer
    logits = F.linear(h, params['fc.weight'], params['fc.bias'])
    
    return logits

# Model config for functional forward
hidden_dim = CFG["gru_config"]["hidden_dim"]

# Training tracking
training_history = {
    "meta_iterations": [],
    "meta_train_loss": [],
    "val_accuracy": [],
    "val_iterations": [],
}

print(f"\n[CELL 07e-09] Starting meta-training...")

# Sample episodes for meta-training
train_users = episodes_train["user_id"].unique()

for meta_iter in range(num_meta_iterations):
    meta_model.train()
    meta_optimizer.zero_grad()

    # Sample meta-batch of tasks
    sampled_users = np.random.choice(train_users, size=min(meta_batch_size, len(train_users)), replace=False)

    meta_loss_total = 0.0
    valid_tasks = 0

    for user_id in sampled_users:
        # Sample one episode for this user
        user_episodes = episodes_train[episodes_train["user_id"] == user_id]
        if len(user_episodes) == 0:
            continue

        episode = user_episodes.sample(n=1).iloc[0]

        # Get support and query sets
        support_pairs, query_pairs = get_episode_data(episode, pairs_train)

        if len(support_pairs) == 0 or len(query_pairs) == 0:
            continue

        support_seq, support_labels, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
        query_seq, query_labels, query_lengths = pairs_to_batch(query_pairs, max_seq_len)

        # ===== INNER LOOP: Adapt parameters using functional approach =====
        # Clone initial meta-parameters
        fast_weights = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights[name] = param.clone().requires_grad_()

        # Adapt on support set
        for _ in range(num_inner_steps):
            # Functional forward with current fast_weights
            support_logits = functional_forward(
                support_seq, support_lengths, fast_weights, hidden_dim, n_items
            )
            support_loss = criterion(support_logits, support_labels)

            # Compute gradients w.r.t. fast_weights
            grads = torch.autograd.grad(
                support_loss,
                fast_weights.values(),
                create_graph=use_second_order  # FOMAML: False, MAML: True
            )

            # Update fast_weights (creates new tensors, no in-place ops)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )

        # ===== OUTER LOOP: Compute query loss with adapted parameters =====
        query_logits = functional_forward(
            query_seq, query_lengths, fast_weights, hidden_dim, n_items
        )
        query_loss = criterion(query_logits, query_labels)

        # Accumulate for meta-update
        meta_loss_total = meta_loss_total + query_loss
        valid_tasks += 1

    if valid_tasks == 0:
        continue

    # ===== META-UPDATE =====
    meta_loss = meta_loss_total / valid_tasks
    meta_loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(meta_model.parameters(), max_norm=10.0)

    meta_optimizer.step()

    # Logging
    training_history["meta_iterations"].append(meta_iter)
    training_history["meta_train_loss"].append(meta_loss.item())

    if (meta_iter + 1) % 100 == 0:
        print(f"[CELL 07e-09] Iter {meta_iter+1}/{num_meta_iterations}: meta_loss={meta_loss.item():.4f}")

    # Checkpointing
    if (meta_iter + 1) % CFG["maml_config"]["checkpoint_interval"] == 0:
        checkpoint_path = CHECKPOINTS_DIR / f"checkpoint_iter{meta_iter+1}.pth"
        torch.save({
            "meta_iter": meta_iter + 1,
            "model_state_dict": meta_model.state_dict(),
            "optimizer_state_dict": meta_optimizer.state_dict(),
            "config": CFG,
            "training_history": training_history,
        }, checkpoint_path)
        print(f"[CELL 07e-09] Saved checkpoint: {checkpoint_path.name}")

    # Validation (simpler non-functional approach for validation only)
    if (meta_iter + 1) % CFG["maml_config"]["eval_interval"] == 0:
        print(f"[CELL 07e-09] Evaluating on val set at iter {meta_iter+1}...")
        meta_model.eval()

        val_predictions = []
        val_labels = []

        for _, episode in episodes_val.head(50).iterrows():
            support_pairs, query_pairs = get_episode_data(episode, pairs_val)

            if len(support_pairs) == 0 or len(query_pairs) == 0:
                continue

            support_seq, support_labels_val, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
            query_seq, query_labels_val, query_lengths = pairs_to_batch(query_pairs, max_seq_len)

            # Save original params
            original_params = OrderedDict()
            for name, param in meta_model.named_parameters():
                original_params[name] = param.data.clone()

            # Adapt on support using standard approach (no gradients needed for validation)
            with torch.enable_grad():
                # Clone parameters for adaptation
                fast_weights_val = OrderedDict()
                for name, param in meta_model.named_parameters():
                    fast_weights_val[name] = param.clone().requires_grad_()

                # Inner loop adaptation
                for _ in range(num_inner_steps):
                    support_logits_val = functional_forward(
                        support_seq, support_lengths, fast_weights_val, hidden_dim, n_items
                    )
                    support_loss_val = criterion(support_logits_val, support_labels_val)

                    grads_val = torch.autograd.grad(
                        support_loss_val,
                        fast_weights_val.values(),
                        create_graph=False
                    )

                    fast_weights_val = OrderedDict(
                        (name, param - inner_lr * grad)
                        for ((name, param), grad) in zip(fast_weights_val.items(), grads_val)
                    )

            # Evaluate on query (no gradients)
            with torch.no_grad():
                query_logits_val = functional_forward(
                    query_seq, query_lengths, fast_weights_val, hidden_dim, n_items
                )
                query_probs = torch.softmax(query_logits_val, dim=-1).cpu().numpy()

                val_predictions.append(query_probs)
                val_labels.extend(query_labels_val.cpu().numpy())

            # Restore original params
            with torch.no_grad():
                for name, param in meta_model.named_parameters():
                    param.data.copy_(original_params[name])

        if len(val_predictions) > 0:
            val_predictions = np.vstack(val_predictions)
            val_labels = np.array(val_labels)
            val_metrics = compute_metrics(val_predictions, val_labels)

            training_history["val_accuracy"].append(val_metrics["accuracy@1"])
            training_history["val_iterations"].append(meta_iter + 1)

            print(f"[CELL 07e-09] Val Acc@1: {val_metrics['accuracy@1']:.4f}, "
                  f"Recall@5: {val_metrics['recall@5']:.4f}, MRR: {val_metrics['mrr']:.4f}")

# Save final model
final_model_path = MODELS_DIR / f"maml_warmstart_gru_K{K}.pth"
torch.save({
    "model_state_dict": meta_model.state_dict(),
    "config": CFG,
    "training_history": training_history,
}, final_model_path)

print(f"\n[CELL 07e-09] Saved final meta-model: {final_model_path.name}")
print(f"[CELL 07e-09] Total training time: {time.time()-t0:.1f}s")

cell_end("CELL 07e-09", t0)


[CELL 07e-09] MAML meta-training
[CELL 07e-09] start=2026-01-11T17:55:17
[CELL 07e-09] Meta-model parameters: 140,695
[CELL 07e-09] Using MAML (Second-Order)
[CELL 07e-09] Meta-training config:
  - Inner LR (Œ±): 0.01
  - Outer LR (Œ≤): 0.001
  - Inner steps: 5
  - Meta-batch size: 32
  - Meta-iterations: 10,000

[CELL 07e-09] Starting meta-training...
[CELL 07e-09] Iter 100/10000: meta_loss=4.1887
[CELL 07e-09] Iter 200/10000: meta_loss=3.7896
[CELL 07e-09] Iter 300/10000: meta_loss=3.6735
[CELL 07e-09] Iter 400/10000: meta_loss=3.8215
[CELL 07e-09] Iter 500/10000: meta_loss=3.8139
[CELL 07e-09] Evaluating on val set at iter 500...
[CELL 07e-09] Val Acc@1: 0.2840, Recall@5: 0.5020, MRR: 0.3884
[CELL 07e-09] Iter 600/10000: meta_loss=3.3961
[CELL 07e-09] Iter 700/10000: meta_loss=3.2494
[CELL 07e-09] Iter 800/10000: meta_loss=3.4691
[CELL 07e-09] Iter 900/10000: meta_loss=3.0621
[CELL 07e-09] Iter 1000/10000: meta_loss=3.5936
[CELL 07e-09] Saved checkpoint: checkpoint_iter1000.pth
[CE

In [None]:
# [CELL 07e-10] Meta-testing: Zero-shot (K=0) - no adaptation

t0 = cell_start("CELL 07e-10", "Zero-shot evaluation (K=0)")

print("[CELL 07e-10] Evaluating meta-learned model WITHOUT adaptation (zero-shot)...")

meta_model.eval()
zeroshot_predictions = []
zeroshot_labels = []

with torch.no_grad():  # Pure inference, no gradients needed
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)

        if len(query_pairs) == 0:
            continue

        # Only use query set (no support set adaptation)
        query_seq, query_labels_test, query_lengths = pairs_to_batch(query_pairs, max_seq_len)

        # Use original meta-learned model (no adaptation)
        query_logits = meta_model(query_seq, query_lengths)
        query_probs = torch.softmax(query_logits, dim=-1).cpu().numpy()

        zeroshot_predictions.append(query_probs)
        zeroshot_labels.extend(query_labels_test.cpu().numpy())

# Compute metrics
if len(zeroshot_predictions) > 0:
    zeroshot_predictions = np.vstack(zeroshot_predictions)
    zeroshot_labels = np.array(zeroshot_labels)
    zeroshot_metrics = compute_metrics(zeroshot_predictions, zeroshot_labels)

    print(f"\n[CELL 07e-10] Zero-shot Results (No Adaptation):")
    print(f"  Accuracy@1:  {zeroshot_metrics['accuracy@1']:.4f}")
    print(f"  Recall@5:    {zeroshot_metrics['recall@5']:.4f}")
    print(f"  Recall@10:   {zeroshot_metrics['recall@10']:.4f}")
    print(f"  MRR:         {zeroshot_metrics['mrr']:.4f}")
else:
    print("[CELL 07e-10] WARNING: No predictions generated")
    zeroshot_metrics = {}

cell_end("CELL 07e-10", t0)


In [None]:
# [CELL 07e-11] Meta-testing: Few-shot K=5 (with adaptation using functional forward)

t0 = cell_start("CELL 07e-11", "Few-shot evaluation (K=5)")

print("[CELL 07e-11] Evaluating meta-learned model WITH adaptation (few-shot K=5)...")

meta_model.eval()
fewshot_predictions = []
fewshot_labels = []

for _, episode in episodes_test.iterrows():
    support_pairs, query_pairs = get_episode_data(episode, pairs_test)
    
    if len(support_pairs) == 0 or len(query_pairs) == 0:
        continue
    
    support_seq, support_labels_test, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    query_seq, query_labels_test, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
    
    # Adapt using functional forward (consistent with training)
    with torch.enable_grad():
        # Clone parameters for adaptation
        fast_weights_test = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights_test[name] = param.clone().requires_grad_()
        
        # Inner loop adaptation
        for _ in range(num_inner_steps):
            support_logits_test = functional_forward(
                support_seq, support_lengths, fast_weights_test, hidden_dim, n_items
            )
            support_loss_test = criterion(support_logits_test, support_labels_test)
            
            grads_test = torch.autograd.grad(
                support_loss_test,
                fast_weights_test.values(),
                create_graph=False  # No second-order needed for testing
            )
            
            fast_weights_test = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights_test.items(), grads_test)
            )
    
    # Evaluate on query (no gradients)
    with torch.no_grad():
        query_logits_test = functional_forward(
            query_seq, query_lengths, fast_weights_test, hidden_dim, n_items
        )
        probs = torch.softmax(query_logits_test, dim=-1).cpu().numpy()
        
        fewshot_predictions.append(probs)
        fewshot_labels.extend(query_labels_test.cpu().numpy())

fewshot_predictions = np.vstack(fewshot_predictions)
fewshot_labels = np.array(fewshot_labels)

fewshot_metrics = compute_metrics(fewshot_predictions, fewshot_labels)

print(f"\n[CELL 07e-11] Few-shot Results (K=5 adaptation):")
print(f"  - Accuracy@1:  {fewshot_metrics['accuracy@1']:.4f}")
print(f"  - Recall@5:    {fewshot_metrics['recall@5']:.4f}")
print(f"  - Recall@10:   {fewshot_metrics['recall@10']:.4f}")
print(f"  - MRR:         {fewshot_metrics['mrr']:.4f}")

cell_end("CELL 07e-11", t0)

In [None]:
# [CELL 07e-12] Ablation Study 1: Support set size (K=1,3,5,10) - functional forward

t0 = cell_start("CELL 07e-12", "Ablation: support set size")

print("[CELL 07e-12] Ablation Study: Varying support set size K...")

support_sizes = CFG["ablation_configs"]["support_set_sizes"]
ablation_support_results = {}

meta_model.eval()

for K_test in support_sizes:
    print(f"\n[CELL 07e-12] Testing with K={K_test}...")
    
    predictions = []
    labels = []
    
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)
        
        if len(support_pairs) < K_test or len(query_pairs) == 0:
            continue
        
        # Use only K_test support pairs
        support_pairs_k = support_pairs.head(K_test)
        
        support_seq, support_labels_abl, support_lengths = pairs_to_batch(support_pairs_k, max_seq_len)
        query_seq, query_labels_abl, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
        
        # Adapt using functional forward
        with torch.enable_grad():
            fast_weights_abl = OrderedDict()
            for name, param in meta_model.named_parameters():
                fast_weights_abl[name] = param.clone().requires_grad_()
            
            for _ in range(num_inner_steps):
                support_logits_abl = functional_forward(
                    support_seq, support_lengths, fast_weights_abl, hidden_dim, n_items
                )
                support_loss_abl = criterion(support_logits_abl, support_labels_abl)
                
                grads_abl = torch.autograd.grad(
                    support_loss_abl,
                    fast_weights_abl.values(),
                    create_graph=False
                )
                
                fast_weights_abl = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights_abl.items(), grads_abl)
                )
        
        # Evaluate on query
        with torch.no_grad():
            query_logits_abl = functional_forward(
                query_seq, query_lengths, fast_weights_abl, hidden_dim, n_items
            )
            probs = torch.softmax(query_logits_abl, dim=-1).cpu().numpy()
            
            predictions.append(probs)
            labels.extend(query_labels_abl.cpu().numpy())
    
    if len(predictions) > 0:
        predictions = np.vstack(predictions)
        labels = np.array(labels)
        metrics = compute_metrics(predictions, labels)
        ablation_support_results[K_test] = metrics
        
        print(f"[CELL 07e-12] K={K_test}: Acc@1={metrics['accuracy@1']:.4f}, "
              f"Recall@5={metrics['recall@5']:.4f}, MRR={metrics['mrr']:.4f}")

print(f"\n[CELL 07e-12] Ablation complete: tested K ‚àà {support_sizes}")

cell_end("CELL 07e-12", t0)

In [None]:
# [CELL 07e-13] Ablation Study 2: Adaptation steps (1,3,5,10) - functional forward

t0 = cell_start("CELL 07e-13", "Ablation: adaptation steps")

print("[CELL 07e-13] Ablation Study: Varying adaptation steps...")

adaptation_steps = CFG["ablation_configs"]["adaptation_steps"]
ablation_steps_results = {}

meta_model.eval()

for num_steps in adaptation_steps:
    print(f"\n[CELL 07e-13] Testing with {num_steps} adaptation steps...")
    
    predictions = []
    labels = []
    
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)
        
        if len(support_pairs) == 0 or len(query_pairs) == 0:
            continue
        
        support_seq, support_labels_steps, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
        query_seq, query_labels_steps, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
        
        # Adapt using functional forward with varying steps
        with torch.enable_grad():
            fast_weights_steps = OrderedDict()
            for name, param in meta_model.named_parameters():
                fast_weights_steps[name] = param.clone().requires_grad_()
            
            for _ in range(num_steps):  # Use num_steps instead of num_inner_steps
                support_logits_steps = functional_forward(
                    support_seq, support_lengths, fast_weights_steps, hidden_dim, n_items
                )
                support_loss_steps = criterion(support_logits_steps, support_labels_steps)
                
                grads_steps = torch.autograd.grad(
                    support_loss_steps,
                    fast_weights_steps.values(),
                    create_graph=False
                )
                
                fast_weights_steps = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights_steps.items(), grads_steps)
                )
        
        # Evaluate on query
        with torch.no_grad():
            query_logits_steps = functional_forward(
                query_seq, query_lengths, fast_weights_steps, hidden_dim, n_items
            )
            probs = torch.softmax(query_logits_steps, dim=-1).cpu().numpy()
            
            predictions.append(probs)
            labels.extend(query_labels_steps.cpu().numpy())
    
    if len(predictions) > 0:
        predictions = np.vstack(predictions)
        labels = np.array(labels)
        metrics = compute_metrics(predictions, labels)
        ablation_steps_results[num_steps] = metrics
        
        print(f"[CELL 07e-13] Steps={num_steps}: Acc@1={metrics['accuracy@1']:.4f}, "
              f"Recall@5={metrics['recall@5']:.4f}, MRR={metrics['mrr']:.4f}")

print(f"\n[CELL 07e-13] Ablation complete: tested adaptation steps ‚àà {adaptation_steps}")

cell_end("CELL 07e-13", t0)

In [None]:
# [CELL 07e-14] Analysis: Parameter update visualization - functional forward

t0 = cell_start("CELL 07e-14", "Parameter update analysis")

print("[CELL 07e-14] Analyzing parameter updates during adaptation...")

# Select one test episode for analysis
sample_episode = episodes_test.iloc[0]
support_pairs, query_pairs = get_episode_data(sample_episode, pairs_test)

if len(support_pairs) > 0:
    support_seq, support_labels_viz, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    
    # NOTE: Do NOT call meta_model.eval() here - we need gradients for functional_forward
    # The functional forward approach doesn't use the model's forward(), so eval mode doesn't matter
    
    # Get original parameters (before adaptation)
    param_norms_before = {}
    for name, param in meta_model.named_parameters():
        param_norms_before[name] = param.data.norm().item()
    
    # Adapt using functional forward
    with torch.enable_grad():
        fast_weights_viz = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights_viz[name] = param.clone().requires_grad_()
        
        for _ in range(num_inner_steps):
            support_logits_viz = functional_forward(
                support_seq, support_lengths, fast_weights_viz, hidden_dim, n_items
            )
            support_loss_viz = criterion(support_logits_viz, support_labels_viz)
            
            grads_viz = torch.autograd.grad(
                support_loss_viz,
                fast_weights_viz.values(),
                create_graph=False
            )
            
            fast_weights_viz = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights_viz.items(), grads_viz)
            )
    
    # Compute parameter changes
    param_norms_after = {}
    param_changes = {}
    
    for name in fast_weights_viz.keys():
        adapted_norm = fast_weights_viz[name].data.norm().item()
        original_norm = param_norms_before[name]
        change = adapted_norm - original_norm
        
        param_norms_after[name] = adapted_norm
        param_changes[name] = {
            "before": original_norm,
            "after": adapted_norm,
            "change": change,
            "change_pct": (change / original_norm * 100) if original_norm > 0 else 0,
        }
    
    print(f"\n[CELL 07e-14] Parameter changes after {num_inner_steps} adaptation steps:")
    print(f"{'Parameter':<30} {'Before':>12} {'After':>12} {'Change':>12} {'Change %':>10}")
    print("-" * 80)
    
    for name, stats in list(param_changes.items())[:10]:  # Show first 10
        print(f"{name:<30} {stats['before']:>12.4f} {stats['after']:>12.4f} "
              f"{stats['change']:>12.4f} {stats['change_pct']:>9.2f}%")
    
    # Visualization: parameter change distribution
    VIZ_DIR = OUT_DIR / "visualizations"
    VIZ_DIR.mkdir(exist_ok=True)
    
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    
    change_pcts = [stats["change_pct"] for stats in param_changes.values()]
    ax.hist(change_pcts, bins=30, color='#3498db', alpha=0.7, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No change')
    ax.set_xlabel('Parameter Change (%)', fontsize=11, fontweight='bold')
    ax.set_ylabel('Number of Parameters', fontsize=11, fontweight='bold')
    ax.set_title('Parameter Change Distribution After Adaptation (MAML)', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(VIZ_DIR / "param_change_distribution.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"\n[CELL 07e-14] Saved: param_change_distribution.png")

cell_end("CELL 07e-14", t0)

In [None]:
# [CELL 07e-14] Analysis: Parameter update visualization

t0 = cell_start("CELL 07e-14", "Parameter update analysis")

print("[CELL 07e-14] Analyzing parameter updates during adaptation...")

# Select one test episode for analysis
sample_episode = episodes_test.iloc[0]
support_pairs, query_pairs = get_episode_data(sample_episode, pairs_test)

if len(support_pairs) > 0:
    support_seq, support_labels, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    
    # Track parameter changes during inner loop
    param_norms_before = {}
    param_norms_after = {}
    param_changes = {}
    
    meta_model.eval()
    
    # Before adaptation
    for name, param in meta_model.named_parameters():
        param_norms_before[name] = param.data.norm().item()
    
    # Save original params
    original_params = [param.clone() for param in meta_model.parameters()]
    
    # Adapt
    inner_optimizer = torch.optim.SGD(meta_model.parameters(), lr=inner_lr)
    for inner_step in range(num_inner_steps):
        inner_optimizer.zero_grad()
        support_logits = meta_model(support_seq, support_lengths)
        support_loss = criterion(support_logits, support_labels)
        support_loss.backward()
        inner_optimizer.step()
    
    # After adaptation - collect adapted parameters
    adapted_params = {}
    for name, param in meta_model.named_parameters():
        adapted_params[name] = param.clone().detach()
        param_norms_after[name] = param.data.norm().item()
        original_norm = param_norms_before[name]
        change = param_norms_after[name] - original_norm
        param_changes[name] = {
            "before": original_norm,
            "after": param_norms_after[name],
            "change": change,
            "change_pct": (change / original_norm * 100) if original_norm > 0 else 0,
        }
    
    # Restore params
    for param, orig_param in zip(meta_model.parameters(), original_params):
        param.copy_(orig_param)
    
    print(f"\n[CELL 07e-14] Parameter changes after {num_inner_steps} adaptation steps:")
    print(f"{'Parameter':<30} {'Before':>12} {'After':>12} {'Change':>12} {'Change %':>10}")
    print("-" * 80)
    
    for name, stats in list(param_changes.items())[:10]:  # Show first 10
        print(f"{name:<30} {stats['before']:>12.4f} {stats['after']:>12.4f} "
              f"{stats['change']:>12.4f} {stats['change_pct']:>9.2f}%")
    
    # Visualization: parameter change distribution
    VIZ_DIR = OUT_DIR / "visualizations"
    VIZ_DIR.mkdir(exist_ok=True)
    
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    
    change_pcts = [stats["change_pct"] for stats in param_changes.values()]
    ax.hist(change_pcts, bins=30, color='#3498db', alpha=0.7, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No change')
    ax.set_xlabel('Parameter Change (%)', fontsize=11, fontweight='bold')
    ax.set_ylabel('Number of Parameters', fontsize=11, fontweight='bold')
    ax.set_title('Parameter Change Distribution After Adaptation', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(VIZ_DIR / "param_change_distribution.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"\n[CELL 07e-14] Saved: param_change_distribution.png")

cell_end("CELL 07e-14", t0)

In [None]:
# [CELL 07e-15] Results summary table + comparison with baselines

t0 = cell_start("CELL 07e-15", "Results summary")

print("\n[CELL 07e-15] ========== RESULTS SUMMARY (Test Set) ==========")
print(f"K={K}, Q={Q} | Test Episodes: {len(episodes_test):,}\n")

# Load GRU baseline results for comparison
baseline_results_path = RESULTS_DIR / f"baselines_K{K}_Q{Q}.json"
if baseline_results_path.exists():
    baseline_results = read_json(baseline_results_path)
    gru_baseline_metrics = baseline_results["baselines"]["gru_global"]
else:
    gru_baseline_metrics = {"accuracy@1": 0.3373, "recall@5": 0.5590, "recall@10": 0.6575, "mrr": 0.4437}

# Create comparison table
print(f"{'Model':<30} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 72)

# Baselines
print(f"{'GRU (Baseline - 06)':<30} {gru_baseline_metrics['accuracy@1']:>10.4f} "
      f"{gru_baseline_metrics['recall@5']:>10.4f} {gru_baseline_metrics['recall@10']:>10.4f} "
      f"{gru_baseline_metrics['mrr']:>10.4f}")

# MAML results
print(f"{'MAML Zero-shot':<30} {zeroshot_metrics['accuracy@1']:>10.4f} "
      f"{zeroshot_metrics['recall@5']:>10.4f} {zeroshot_metrics['recall@10']:>10.4f} "
      f"{zeroshot_metrics['mrr']:>10.4f}")

print(f"{'MAML Few-shot (K=5)':<30} {fewshot_metrics['accuracy@1']:>10.4f} "
      f"{fewshot_metrics['recall@5']:>10.4f} {fewshot_metrics['recall@10']:>10.4f} "
      f"{fewshot_metrics['mrr']:>10.4f}")

# Improvement over baseline
improvement = (fewshot_metrics['accuracy@1'] - gru_baseline_metrics['accuracy@1']) / gru_baseline_metrics['accuracy@1'] * 100
print(f"\n[CELL 07e-15] MAML Few-shot improvement over GRU baseline: {improvement:+.2f}%")

# Ablation results
print(f"\n[CELL 07e-15] ========== ABLATION STUDY 1: Support Set Size ==========")
print(f"{'K (Support Size)':<20} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 62)
for K_test, metrics in ablation_support_results.items():
    print(f"{K_test:<20} {metrics['accuracy@1']:>10.4f} {metrics['recall@5']:>10.4f} "
          f"{metrics['recall@10']:>10.4f} {metrics['mrr']:>10.4f}")

print(f"\n[CELL 07e-15] ========== ABLATION STUDY 2: Adaptation Steps ==========")
print(f"{'Adaptation Steps':<20} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 62)
for num_steps, metrics in ablation_steps_results.items():
    print(f"{num_steps:<20} {metrics['accuracy@1']:>10.4f} {metrics['recall@5']:>10.4f} "
          f"{metrics['recall@10']:>10.4f} {metrics['mrr']:>10.4f}")

# Save all results
all_results = {
    "run_id": RUN_ID,
    "k_shot_config": {"K": K, "Q": Q},
    "n_test_episodes": len(episodes_test),
    "baseline": {
        "gru_global": gru_baseline_metrics,
    },
    "maml": {
        "zero_shot": zeroshot_metrics,
        "few_shot_K5": fewshot_metrics,
    },
    "ablation_support_size": ablation_support_results,
    "ablation_adaptation_steps": ablation_steps_results,
    "improvement_over_baseline_pct": improvement,
    "training_history": training_history,
}

results_path = Path(CFG["outputs"]["results"])
write_json_atomic(results_path, all_results)
print(f"\n[CELL 07e-15] Saved: {results_path.name}")

cell_end("CELL 07e-15", t0)

In [None]:
# [CELL 07e-16] Update report + manifest

t0 = cell_start("CELL 07e-16", "Write report + manifest")

report = read_json(REPORT_PATH)
manifest = read_json(MANIFEST_PATH)

# Metrics
report["metrics"] = {
    "n_test_episodes": len(episodes_test),
    "gru_baseline_acc1": gru_baseline_metrics['accuracy@1'],
    "maml_zero_shot_acc1": zeroshot_metrics['accuracy@1'],
    "maml_few_shot_K5_acc1": fewshot_metrics['accuracy@1'],
    "improvement_over_baseline_pct": improvement,
    "training_iterations": num_meta_iterations,
}

# Key findings
report["key_findings"].extend([
    f"MAML meta-training: {num_meta_iterations:,} iterations with {meta_batch_size} tasks/batch",
    f"Zero-shot performance (no adaptation): Acc@1={zeroshot_metrics['accuracy@1']:.4f}",
    f"Few-shot performance (K=5 adaptation): Acc@1={fewshot_metrics['accuracy@1']:.4f}",
    f"Improvement over GRU baseline: {improvement:+.2f}% ({fewshot_metrics['accuracy@1']:.4f} vs {gru_baseline_metrics['accuracy@1']:.4f})",
    f"Ablation: Best K={max(ablation_support_results, key=lambda k: ablation_support_results[k]['accuracy@1'])} "
    f"(Acc@1={max(ablation_support_results.values(), key=lambda m: m['accuracy@1'])['accuracy@1']:.4f})",
    f"Ablation: Best adaptation steps={max(ablation_steps_results, key=lambda k: ablation_steps_results[k]['accuracy@1'])} "
    f"(Acc@1={max(ablation_steps_results.values(), key=lambda m: m['accuracy@1'])['accuracy@1']:.4f})",
])

# Sanity samples
report["sanity_samples"]["maml_config"] = CFG["maml_config"]
report["sanity_samples"]["sample_episode"] = {
    "episode_id": int(episodes_test.iloc[0]["episode_id"]),
    "user_id": str(episodes_test.iloc[0]["user_id"]),
    "n_support_pairs": len(episodes_test.iloc[0]["support_pair_ids"]),
    "n_query_pairs": len(episodes_test.iloc[0]["query_pair_ids"]),
}

# Fingerprints
report["data_fingerprints"]["meta_model"] = {
    "path": str(final_model_path),
    "bytes": int(final_model_path.stat().st_size),
    "sha256": sha256_file(final_model_path),
}

write_json_atomic(REPORT_PATH, report)

# Manifest
def add_artifact(path: Path) -> None:
    rec = {"path": str(path), "bytes": int(path.stat().st_size), "sha256": None, "sha256_error": None}
    try:
        rec["sha256"] = sha256_file(path)
    except Exception as e:
        rec["sha256_error"] = str(e)
    manifest["artifacts"].append(rec)

add_artifact(final_model_path)
add_artifact(results_path)

# Add checkpoints
for checkpoint_file in sorted(CHECKPOINTS_DIR.glob("checkpoint_iter*.pth")):
    add_artifact(checkpoint_file)

write_json_atomic(MANIFEST_PATH, manifest)

print(f"[CELL 07e-16] Updated: {REPORT_PATH}")
print(f"[CELL 07e-16] Updated: {MANIFEST_PATH}")

cell_end("CELL 07e-16", t0)

print(f"\n{'='*80}")
print(f"‚úÖ NOTEBOOK 07 COMPLETE")
print(f"{'='*80}")
print(f"\nüìä Key Results:")
print(f"  - GRU Baseline (06):        {gru_baseline_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - MAML Zero-shot:           {zeroshot_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - MAML Few-shot (K=5):      {fewshot_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - Improvement:              {improvement:+.2f}%")
print(f"\nüìÅ Outputs:")
print(f"  - Meta-model: {final_model_path}")
print(f"  - Results:    {results_path}")
print(f"  - Report:     {REPORT_PATH}")
print(f"\nüéØ Next Steps:")
print(f"  - Fine-tune hyperparameters (Œ±, Œ≤, inner steps)")
print(f"  - Try different architectures (Transformer, GNN)")
print(f"  - Compare with other meta-learning methods (ProtoNet, Matching Networks)")

## ‚úÖ Notebook 07e Complete: MAML with Warm-Start\n
\n
**Key Innovation:** Initialized MAML from pre-trained GRU baseline (33.73%)\n
\n
**Results:**\n
- See CELL 07e-10 for zero-shot results\n
- See CELL 07e-11 for few-shot results (K=5)\n
- See CELL 07e-15 for complete comparison\n
\n
**Expected Outcome:**\n
- MAML warm-start should beat both:\n
  - MAML random init (30.52%)\n
  - GRU baseline (33.73%)\n
- Target: 35-38% Acc@1\n
\n
**Why This Works:**\n
1. Strong initialization (GRU) + Meta-learned adaptation (MAML)\n
2. Best of both worlds\n
3. Well-established approach in meta-learning literature\n
\n
**Next Steps:**\n
- If beats baseline ‚Üí Thesis complete! üéâ\n
- If close but not quite ‚Üí Try bigger model (Fix #4)\n
- Analyze learned adaptations vs random init MAML