# Notebook 07f: MAML with Residual Meta-Loss and Layer-Specific Adaptation (XuetangX)

This notebook implements **Residual MAML** with three critical fixes:

1. **Higher Inner LR (α=0.05)**: Increased from 0.01 based on sweep results
2. **Residual Meta-Loss (λ=0.1)**: Blend adapted and unadapted query losses
3. **Layer-Specific Adaptation**: Only adapt embedding + FC layers, freeze GRU

**Key Changes from 07e:**

- ✅ Inner learning rate increased to 0.05
- ✅ Residual MAML meta-loss with λ=0.1
- ✅ Layer-specific adaptation (freeze GRU)
- ✅ Config variable extraction to prevent NameError

**Dataset**: XuetangX MOOC
**Model**: GRU-based next-course recommendation
**Meta-learning**: First-order MAML (FOMAML), a variant of Model-Agnostic Meta-Learning
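
The residual meta-loss in fix 2 blends the query loss measured after inner-loop adaptation with the query loss of the unadapted meta-parameters. A minimal sketch in plain Python, assuming scalar floats in place of the tensor losses used in the training cell; `residual_meta_loss` is an illustrative helper name, not a function defined in this notebook:

```python
def residual_meta_loss(loss_adapted: float, loss_unadapted: float, lam: float = 0.1) -> float:
    """Blend the two query losses: (1 - lam) * adapted + lam * unadapted."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return (1.0 - lam) * loss_adapted + lam * loss_unadapted


# lam = 0 recovers the plain adapted query loss; lam = 1 ignores adaptation entirely.
blended = residual_meta_loss(2.0, 3.0, lam=0.1)
```

With λ=0.1 the unadapted term contributes 10% of the meta-loss, so the meta-parameters are also pushed, mildly, to predict well before any adaptation.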


In [1]:
# [CELL 07f-00] Bootstrap: repo root + paths + logger

import os
import sys
import json
import time
import uuid
import pickle
import hashlib
import copy
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Tuple, Optional
from collections import Counter, OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

t0 = datetime.now()
print(f"[CELL 07f-00] start={t0.isoformat(timespec='seconds')}")
print("[CELL 07f-00] CWD:", Path.cwd().resolve())

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find PROJECT_STATE.md. Open notebook from within the repo.")

REPO_ROOT = find_repo_root(Path.cwd())
print("[CELL 07f-00] REPO_ROOT:", REPO_ROOT)

PATHS = {
    "META_REGISTRY": REPO_ROOT / "meta.json",
    "DATA_INTERIM": REPO_ROOT / "data" / "interim",
    "DATA_PROCESSED": REPO_ROOT / "data" / "processed",
    "MODELS": REPO_ROOT / "models",
    "RESULTS": REPO_ROOT / "results",
    "REPORTS": REPO_ROOT / "reports",
}
for k, v in PATHS.items():
    print(f"[CELL 07f-00] {k}={v}")

def cell_start(cell_id: str, title: str, **kwargs: Any) -> float:
    t = time.time()
    print(f"\n[{cell_id}] {title}")
    print(f"[{cell_id}] start={datetime.now().isoformat(timespec='seconds')}")
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    return t

def cell_end(cell_id: str, t0: float, **kwargs: Any) -> None:
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] elapsed={time.time()-t0:.2f}s")
    print(f"[{cell_id}] done")

# Check GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[CELL 07f-00] PyTorch device: {DEVICE}")
print("[CELL 07f-00] done")

[CELL 07e-00] start=2026-01-12T15:11:22
[CELL 07e-00] CWD: C:\anonymous-users-mooc-session-meta\notebooks
[CELL 07e-00] REPO_ROOT: C:\anonymous-users-mooc-session-meta
[CELL 07e-00] META_REGISTRY=C:\anonymous-users-mooc-session-meta\meta.json
[CELL 07e-00] DATA_INTERIM=C:\anonymous-users-mooc-session-meta\data\interim
[CELL 07e-00] DATA_PROCESSED=C:\anonymous-users-mooc-session-meta\data\processed
[CELL 07e-00] MODELS=C:\anonymous-users-mooc-session-meta\models
[CELL 07e-00] RESULTS=C:\anonymous-users-mooc-session-meta\results
[CELL 07e-00] REPORTS=C:\anonymous-users-mooc-session-meta\reports
[CELL 07e-00] PyTorch device: cuda
[CELL 07e-00] done


In [2]:
# [CELL 07f-01] Reproducibility: seed everything

t0 = cell_start("CELL 07f-01", "Seed everything")

GLOBAL_SEED = 20260107

def seed_everything(seed: int) -> None:
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(GLOBAL_SEED)

cell_end("CELL 07f-01", t0, seed=GLOBAL_SEED)


[CELL 07e-01] Seed everything
[CELL 07e-01] start=2026-01-12T15:11:24
[CELL 07e-01] seed=20260107
[CELL 07e-01] elapsed=0.01s
[CELL 07e-01] done


In [3]:
# [CELL 07f-02] JSON/Pickle IO + hashing helpers

t0 = cell_start("CELL 07f-02", "IO helpers")

def write_json_atomic(path: Path, obj: Any, indent: int = 2) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp_{uuid.uuid4().hex}")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=indent)
    tmp.replace(path)

def read_json(path: Path) -> Any:
    if not path.exists():
        raise RuntimeError(f"Missing JSON file: {path}")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

def save_pickle(path: Path, obj: Any) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)

def load_pickle(path: Path) -> Any:
    with path.open("rb") as f:
        return pickle.load(f)

def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    h = hashlib.sha256()
    with path.open("rb") as f:
        while True:
            b = f.read(chunk_size)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

cell_end("CELL 07f-02", t0)


[CELL 07e-02] IO helpers
[CELL 07e-02] start=2026-01-12T15:11:26
[CELL 07e-02] elapsed=0.00s
[CELL 07e-02] done


In [4]:
# [CELL 07f-03] Run tagging + config + meta.json
t0 = cell_start("CELL 07f-03", "Start run + init files")
NOTEBOOK_NAME = "07f_maml_residual_xuetangx"
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_ID = uuid.uuid4().hex
OUT_DIR = PATHS["REPORTS"] / NOTEBOOK_NAME / RUN_TAG
OUT_DIR.mkdir(parents=True, exist_ok=True)
REPORT_PATH = OUT_DIR / "report.json"
CONFIG_PATH = OUT_DIR / "config.json"
MANIFEST_PATH = OUT_DIR / "manifest.json"
# Paths
EPISODES_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "episodes"
PAIRS_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs"
VOCAB_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "vocab"
MODELS_DIR = PATHS["MODELS"] / "maml"
CHECKPOINTS_DIR = MODELS_DIR / "checkpoints_warmstart"
RESULTS_DIR = PATHS["RESULTS"]
MODELS_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
# GRU baseline checkpoint path
GRU_BASELINE_PATH = PATHS["MODELS"] / "baselines" / "gru_global.pth"
# K-shot config
K, Q = 5, 10
CFG = {
    "notebook": NOTEBOOK_NAME,
    "run_id": RUN_ID,
    "run_tag": RUN_TAG,
    "seed": GLOBAL_SEED,
    "device": str(DEVICE),
    "k_shot_config": {"K": K, "Q": Q},
    "inputs": {
        "episodes_train": str(EPISODES_DIR / f"episodes_train_K{K}_Q{Q}.parquet"),
        "episodes_val": str(EPISODES_DIR / f"episodes_val_K{K}_Q{Q}.parquet"),
        "episodes_test": str(EPISODES_DIR / f"episodes_test_K{K}_Q{Q}.parquet"),
        "pairs_train": str(PAIRS_DIR / "pairs_train.parquet"),
        "pairs_val": str(PAIRS_DIR / "pairs_val.parquet"),
        "pairs_test": str(PAIRS_DIR / "pairs_test.parquet"),
        "vocab": str(VOCAB_DIR / "course2id.json"),
        "gru_baseline": str(GRU_BASELINE_PATH),  # ← WARM-START FROM HERE
    },
    "gru_config": {
        "embedding_dim": 64,
        "hidden_dim": 128,
        "num_layers": 1,
        "dropout": 0.2,
        "max_seq_len": 50,
    },
    "maml_config": {
        "inner_lr": 0.05,           # α: learning rate for inner loop (task adaptation); increased from 0.01 (Fix #1)
        "outer_lr": 0.001,          # β: learning rate for outer loop (meta-update)
        "num_inner_steps": 5,       # number of gradient steps for adaptation
        "lambda_residual": 0.1,     # λ: weight of the unadapted query loss in the Residual MAML meta-loss (Fix #2)
        "meta_batch_size": 32,      # number of tasks (users) per meta-batch
        "num_meta_iterations": 10000,  # total meta-training iterations
        "checkpoint_interval": 1000,   # save checkpoint every N iterations
        "eval_interval": 500,          # evaluate on val set every N iterations
        "use_second_order": False,     # True: MAML (2nd order), False: FOMAML (1st order); the training loop uses create_graph=False
        "warm_start": True,            # ← Initialize from GRU baseline
    },
    "ablation_configs": {
        "support_set_sizes": [1, 3, 5, 10],
        "adaptation_steps": [1, 3, 5, 10],
    },
    "metrics": ["accuracy@1", "recall@5", "recall@10", "mrr"],
    "outputs": {
        "models_dir": str(MODELS_DIR),
        "checkpoints_dir": str(CHECKPOINTS_DIR),
        "results": str(RESULTS_DIR / f"maml_warmstart_K{K}_Q{Q}.json"),
        "out_dir": str(OUT_DIR),
    }
}
write_json_atomic(CONFIG_PATH, CFG)
report = {
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "repo_root": str(REPO_ROOT),
    "metrics": {},
    "key_findings": [],
    "sanity_samples": {},
    "data_fingerprints": {},
    "notes": [],
}
write_json_atomic(REPORT_PATH, report)
manifest = {"run_id": RUN_ID, "notebook": NOTEBOOK_NAME, "run_tag": RUN_TAG, "artifacts": []}
write_json_atomic(MANIFEST_PATH, manifest)
# meta.json
META_PATH = PATHS["META_REGISTRY"]
if not META_PATH.exists():
    write_json_atomic(META_PATH, {"schema_version": 1, "runs": []})
meta = read_json(META_PATH)
meta["runs"].append({
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "out_dir": str(OUT_DIR),
    "created_at": datetime.now().isoformat(timespec="seconds"),
})
write_json_atomic(META_PATH, meta)
print(f"[CELL 07f-03] K={K}, Q={Q}")
print(f"[CELL 07f-03] MAML config: α={CFG['maml_config']['inner_lr']}, β={CFG['maml_config']['outer_lr']}, "
      f"inner_steps={CFG['maml_config']['num_inner_steps']}, meta_batch={CFG['maml_config']['meta_batch_size']}")
print(f"[CELL 07f-03] ⭐ WARM-START: Initializing from GRU baseline")
print(f"[CELL 07f-03] GRU baseline: {GRU_BASELINE_PATH}")
cell_end("CELL 07f-03", t0, out_dir=str(OUT_DIR))


[CELL 07e-03] Start run + init files
[CELL 07e-03] start=2026-01-12T15:11:30
[CELL 07e-03] K=5, Q=10
[CELL 07e-03] MAML config: α=0.01, β=0.001, inner_steps=5, meta_batch=32
[CELL 07e-03] ⭐ WARM-START: Initializing from GRU baseline
[CELL 07e-03] GRU baseline: C:\anonymous-users-mooc-session-meta\models\baselines\gru_global.pth
[CELL 07e-03] out_dir=C:\anonymous-users-mooc-session-meta\reports\07e_maml_warmstart_xuetangx\20260112_151130
[CELL 07e-03] elapsed=0.02s
[CELL 07e-03] done


In [None]:
# [CELL 07f-03b] Extract config variables to prevent NameError in evaluation cells

t0 = cell_start("CELL 07f-03b", "Extract config variables")

# MAML config
inner_lr = CFG['maml_config']['inner_lr']
num_inner_steps = CFG['maml_config']['num_inner_steps']
outer_lr = CFG['maml_config']['outer_lr']
meta_batch_size = CFG['maml_config']['meta_batch_size']
num_meta_iterations = CFG['maml_config']['num_meta_iterations']
lambda_residual = CFG['maml_config']['lambda_residual']

# GRU config
embedding_dim = CFG['gru_config']['embedding_dim']
hidden_dim = CFG['gru_config']['hidden_dim']
num_layers = CFG['gru_config']['num_layers']
dropout = CFG['gru_config']['dropout']
max_seq_len = CFG['gru_config']['max_seq_len']

# K-shot config
K = CFG['k_shot_config']['K']
Q = CFG['k_shot_config']['Q']

# Loss criterion
criterion = nn.CrossEntropyLoss()

print(f"[CELL 07f-03b] Extracted variables:")
print(f"  inner_lr (α) = {inner_lr} (increased from 0.01)")
print(f"  lambda_residual (λ) = {lambda_residual} (new)")
print(f"  num_inner_steps = {num_inner_steps}")
print(f"  hidden_dim = {hidden_dim}")
print(f"  max_seq_len = {max_seq_len}")
print(f"  K = {K}, Q = {Q}")

cell_end("CELL 07f-03b", t0)

In [5]:
# [CELL 07f-04] Load data: episodes, pairs, vocab

t0 = cell_start("CELL 07f-04", "Load data")

# Vocab
course2id = read_json(Path(CFG["inputs"]["vocab"]))
id2course = {int(v): k for k, v in course2id.items()}
n_items = len(course2id)
print(f"[CELL 07f-04] Vocabulary: {n_items} courses")

# Episodes
episodes_train = pd.read_parquet(CFG["inputs"]["episodes_train"])
episodes_val = pd.read_parquet(CFG["inputs"]["episodes_val"])
episodes_test = pd.read_parquet(CFG["inputs"]["episodes_test"])

print(f"[CELL 07f-04] Episodes train: {len(episodes_train):,} episodes ({episodes_train['user_id'].nunique():,} users)")
print(f"[CELL 07f-04] Episodes val:   {len(episodes_val):,} episodes ({episodes_val['user_id'].nunique():,} users)")
print(f"[CELL 07f-04] Episodes test:  {len(episodes_test):,} episodes ({episodes_test['user_id'].nunique():,} users)")

# Pairs
pairs_train = pd.read_parquet(CFG["inputs"]["pairs_train"])
pairs_val = pd.read_parquet(CFG["inputs"]["pairs_val"])
pairs_test = pd.read_parquet(CFG["inputs"]["pairs_test"])

print(f"[CELL 07f-04] Pairs train: {len(pairs_train):,} pairs")
print(f"[CELL 07f-04] Pairs val:   {len(pairs_val):,} pairs")
print(f"[CELL 07f-04] Pairs test:  {len(pairs_test):,} pairs")

cell_end("CELL 07f-04", t0)


[CELL 07e-04] Load data
[CELL 07e-04] start=2026-01-12T15:11:37
[CELL 07e-04] Vocabulary: 343 courses
[CELL 07e-04] Episodes train: 66,187 episodes (3,006 users)
[CELL 07e-04] Episodes val:   340 episodes (340 users)
[CELL 07e-04] Episodes test:  346 episodes (346 users)
[CELL 07e-04] Pairs train: 212,923 pairs
[CELL 07e-04] Pairs val:   24,698 pairs
[CELL 07e-04] Pairs test:  26,608 pairs
[CELL 07e-04] elapsed=0.39s
[CELL 07e-04] done


In [6]:
# [CELL 07f-05] Evaluation metrics (same as Notebook 07)

t0 = cell_start("CELL 07f-05", "Define evaluation metrics")

def compute_metrics(predictions: np.ndarray, labels: np.ndarray, k_values: Tuple[int, ...] = (5, 10)) -> Dict[str, float]:
    """
    Compute ranking metrics.
    
    Args:
        predictions: (n_samples, n_items) score matrix
        labels: (n_samples,) true item indices
        k_values: list of k for Recall@k
    
    Returns:
        dict with accuracy@1, recall@k, mrr
    """
    n_samples = len(labels)
    
    # Get top-k predictions (indices)
    max_k = max(k_values)
    top_k_preds = np.argsort(-predictions, axis=1)[:, :max_k]  # descending order
    
    # Accuracy@1
    top1_preds = top_k_preds[:, 0]
    acc1 = (top1_preds == labels).mean()
    
    # Recall@k
    recall_k = {}
    for k in k_values:
        hits = np.array([labels[i] in top_k_preds[i, :k] for i in range(n_samples)])
        recall_k[f"recall@{k}"] = hits.mean()
    
    # MRR (Mean Reciprocal Rank)
    ranks = []
    for i in range(n_samples):
        # Find rank of true label (1-indexed)
        rank_idx = np.where(top_k_preds[i] == labels[i])[0]
        if len(rank_idx) > 0:
            ranks.append(1.0 / (rank_idx[0] + 1))  # reciprocal rank
        else:
            # Not in top-k, check full ranking
            full_rank = np.where(np.argsort(-predictions[i]) == labels[i])[0][0]
            ranks.append(1.0 / (full_rank + 1))
    mrr = np.mean(ranks)
    
    return {
        "accuracy@1": float(acc1),
        **{k: float(v) for k, v in recall_k.items()},
        "mrr": float(mrr),
    }

print("[CELL 07f-05] Metrics: accuracy@1, recall@5, recall@10, mrr")

cell_end("CELL 07f-05", t0)


[CELL 07e-05] Define evaluation metrics
[CELL 07e-05] start=2026-01-12T15:11:40
[CELL 07e-05] Metrics: accuracy@1, recall@5, recall@10, mrr
[CELL 07e-05] elapsed=0.00s
[CELL 07e-05] done
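
The ranking metrics defined in this cell all reduce to the 1-indexed rank of the true item in the descending score order. A small worked sketch in plain Python, assuming three samples over a three-item vocabulary; `rank_of_label` and `toy_metrics` are illustrative names and are not part of the notebook:

```python
from typing import Dict, List, Sequence


def rank_of_label(scores: Sequence[float], label: int) -> int:
    """1-indexed rank of `label` when items are sorted by descending score."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    return order.index(label) + 1


def toy_metrics(score_rows: List[Sequence[float]], labels: List[int],
                ks: Sequence[int] = (1, 2)) -> Dict[str, float]:
    """Recall@k and MRR computed from per-sample ranks."""
    ranks = [rank_of_label(row, y) for row, y in zip(score_rows, labels)]
    out = {f"recall@{k}": sum(r <= k for r in ranks) / len(ranks) for k in ks}
    out["mrr"] = sum(1.0 / r for r in ranks) / len(ranks)
    return out


# True items land at ranks 1, 2 and 3 respectively.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
result = toy_metrics(scores, labels)
```

With a single true item per sample, accuracy@1 coincides with recall@1, and MRR is the mean of 1/rank over samples.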


In [7]:
# [CELL 07f-06] Define GRU model (exact same as Notebook 06 & 07)

t0 = cell_start("CELL 07f-06", "Define GRU model")

class GRURecommender(nn.Module):
    def __init__(self, n_items: int, embedding_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        
        self.embedding = nn.Embedding(n_items, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_dim, n_items)
    
    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq: (batch, max_len) padded sequences
            lengths: (batch,) actual lengths
        Returns:
            logits: (batch, n_items)
        """
        # Embed
        emb = self.embedding(seq)  # (batch, max_len, embed_dim)
        
        # Pack for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # GRU
        _, hidden = self.gru(packed)  # hidden: (num_layers, batch, hidden_dim)
        
        # Use last layer hidden state
        h = hidden[-1]  # (batch, hidden_dim)
        
        # Predict
        logits = self.fc(h)  # (batch, n_items)
        return logits

print("[CELL 07f-06] GRU model defined")
print(f"  - Embedding dim: {CFG['gru_config']['embedding_dim']}")
print(f"  - Hidden dim: {CFG['gru_config']['hidden_dim']}")
print(f"  - Num layers: {CFG['gru_config']['num_layers']}")

cell_end("CELL 07f-06", t0)


[CELL 07e-06] Define GRU model
[CELL 07e-06] start=2026-01-12T15:11:45
[CELL 07e-06] GRU model defined
  - Embedding dim: 64
  - Hidden dim: 128
  - Num layers: 1
[CELL 07e-06] elapsed=0.00s
[CELL 07e-06] done


In [8]:
# [CELL 07f-07] ⭐ Initialize meta-model with GRU baseline (WARM-START)

t0 = cell_start("CELL 07f-07", "Initialize meta-model with warm-start")

# Create meta-model
meta_model = GRURecommender(
    n_items=n_items,
    embedding_dim=CFG["gru_config"]["embedding_dim"],
    hidden_dim=CFG["gru_config"]["hidden_dim"],
    num_layers=CFG["gru_config"]["num_layers"],
    dropout=CFG["gru_config"]["dropout"],
).to(DEVICE)

# ⭐ WARM-START: Load GRU baseline checkpoint
if not GRU_BASELINE_PATH.exists():
    raise FileNotFoundError(f"GRU baseline not found: {GRU_BASELINE_PATH}")

print(f"[CELL 07f-07] Loading GRU baseline from: {GRU_BASELINE_PATH.name}")
baseline_checkpoint = torch.load(GRU_BASELINE_PATH, map_location=DEVICE)

# Load state dict
if "model_state_dict" in baseline_checkpoint:
    meta_model.load_state_dict(baseline_checkpoint["model_state_dict"])
    print("[CELL 07f-07] Loaded from 'model_state_dict' key")
else:
    meta_model.load_state_dict(baseline_checkpoint)
    print("[CELL 07f-07] Loaded directly from checkpoint")

# Verify baseline performance
if "metrics" in baseline_checkpoint:
    baseline_acc = baseline_checkpoint["metrics"].get("test_accuracy@1", "N/A")
    print(f"[CELL 07f-07] GRU baseline Acc@1: {baseline_acc}")
else:
    print("[CELL 07f-07] GRU baseline Acc@1: 33.73% (from NB 06)")

print(f"\n[CELL 07f-07] ✅ WARM-START COMPLETE")
print(f"[CELL 07f-07] Meta-model initialized from strong GRU baseline")
print(f"[CELL 07f-07] Model parameters: {sum(p.numel() for p in meta_model.parameters()):,}")
print(f"[CELL 07f-07] Now will meta-train to make it more adaptable!")

cell_end("CELL 07f-07", t0)


[CELL 07e-07] Initialize meta-model with warm-start
[CELL 07e-07] start=2026-01-12T15:11:49
[CELL 07e-07] Loading GRU baseline from: gru_global.pth
[CELL 07e-07] Loaded directly from checkpoint
[CELL 07e-07] GRU baseline Acc@1: 33.73% (from NB 06)

[CELL 07e-07] ✅ WARM-START COMPLETE
[CELL 07e-07] Meta-model initialized from strong GRU baseline
[CELL 07e-07] Model parameters: 140,695
[CELL 07e-07] Now will meta-train to make it more adaptable!
[CELL 07e-07] elapsed=0.16s
[CELL 07e-07] done


## 🔥 Key Difference from Notebook 07

**Notebook 07 (Random Init)**:
```python
meta_model = GRURecommender(...)  # Random initialization
# Meta-train from scratch → 30.52% Acc@1
```

**Notebook 07e and 07f (Warm-Start)**:
```python
meta_model = GRURecommender(...)
meta_model.load_state_dict(gru_baseline)  # ← Load GRU baseline (33.73%)
# Meta-train from here → Expected: 35-38% Acc@1 ✅
```

**Why this works**:
- GRU baseline already knows how to recommend courses (33.73%)
- MAML meta-training refines it to adapt better to new users
- Combines: Strong task initialization + Meta-learned adaptation

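**From here onwards, the notebook keeps the structure of Notebook 07** (meta-batch sampling, inner/outer loop updates, evaluation), with two 07f changes in the training loop:
- The inner loop adapts only the embedding and FC parameters; the GRU weights stay frozen
- The meta-loss blends the adapted and unadapted query losses with weight λ

**Starting point**: the warm-started GRU baseline rather than a random initialization.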
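
The warm-start cell above accepts two checkpoint layouts: a bare state dict, or a dict that nests the state dict under `model_state_dict`. A minimal sketch of that branching in plain Python, assuming ordinary dicts in place of tensors; `unwrap_state_dict` is a hypothetical helper name and the toy keys are for illustration only:

```python
from typing import Any, Mapping


def unwrap_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the parameter mapping whether or not it is nested under 'model_state_dict'."""
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


# Toy stand-ins: lists instead of tensors, one parameter instead of a full model.
bare = {"fc.bias": [0.0]}
wrapped = {"model_state_dict": {"fc.bias": [0.0]}, "epoch": 3}

state_from_bare = unwrap_state_dict(bare)
state_from_wrapped = unwrap_state_dict(wrapped)
```

Either way the result can be passed to `load_state_dict`, which is what the two branches in cell 07f-07 do directly.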

In [9]:
# [CELL 07f-08] Helper functions (same as Notebook 07)

t0 = cell_start("CELL 07f-08", "Define helper functions")

def get_episode_data(episode_row, pairs_df):
    """Extract support and query pairs for an episode."""
    support_pair_ids = episode_row["support_pair_ids"]
    query_pair_ids = episode_row["query_pair_ids"]

    support_pairs = pairs_df[pairs_df["pair_id"].isin(support_pair_ids)].sort_values("label_ts_epoch")
    query_pairs = pairs_df[pairs_df["pair_id"].isin(query_pair_ids)].sort_values("label_ts_epoch")

    return support_pairs, query_pairs

def pairs_to_batch(pairs_df, max_len):
    """Convert pairs to batched tensors."""
    prefixes = []
    labels = []
    lengths = []

    for _, row in pairs_df.iterrows():
        prefix = row["prefix"]
        if len(prefix) > max_len:
            prefix = prefix[-max_len:]
        prefixes.append(prefix)
        labels.append(row["label"])
        lengths.append(len(prefix))

    # Pad sequences
    max_l = max(lengths)
    padded = []
    for seq in prefixes:
        padded.append(list(seq) + [0] * (max_l - len(seq)))

    return (
        torch.LongTensor(padded).to(DEVICE),
        torch.LongTensor(labels).to(DEVICE),
        torch.LongTensor(lengths).to(DEVICE),
    )

def functional_forward(seq, lengths, params, hidden_dim, n_items):
    """Functional forward pass using explicit parameters."""
    batch_size = seq.size(0)
    
    # 1. Embedding
    emb = F.embedding(seq, params["embedding.weight"], padding_idx=0)
    
    # 2. GRU (manual implementation)
    h = torch.zeros(batch_size, hidden_dim, device=seq.device)
    w_ih = params["gru.weight_ih_l0"]
    w_hh = params["gru.weight_hh_l0"]
    b_ih = params["gru.bias_ih_l0"]
    b_hh = params["gru.bias_hh_l0"]
    
    for t in range(emb.size(1)):
        x_t = emb[:, t, :]
        gi = F.linear(x_t, w_ih, b_ih)
        gh = F.linear(h, w_hh, b_hh)
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)
        
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h_new = (1 - z) * n + z * h
        
        mask = (lengths > t).unsqueeze(1).float()
        h = mask * h_new + (1 - mask) * h
    
    # 3. FC layer
    logits = F.linear(h, params["fc.weight"], params["fc.bias"])
    return logits

print("[CELL 07f-08] Helper functions defined")

cell_end("CELL 07f-08", t0)


[CELL 07e-08] Define helper functions
[CELL 07e-08] start=2026-01-12T15:11:54
[CELL 07e-08] Helper functions defined
[CELL 07e-08] elapsed=0.00s
[CELL 07e-08] done
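
For reference, the manual recurrence in `functional_forward` can be written out as equations. This is a reading of the code above, assuming the stacked GRU weights are ordered reset, update, new (which is how `chunk(3, 1)` splits them) and writing $m_t$ for the padding mask `lengths > t`:

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
\tilde{h}_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1} \\
h_t &= m_t \, \tilde{h}_t + (1 - m_t) \, h_{t-1}
\end{aligned}
```

The last line keeps the hidden state unchanged on padded positions, so the final $h_t$ equals the state at each sequence's true last step and plays the role of `hidden[-1]` in the packed `nn.GRU` forward pass.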


In [None]:
# [CELL 07f-08b] MAML Evaluation Function (with Layer-Specific Adaptation Support)

t0 = cell_start("CELL 07f-08b", "MAML Evaluation Function")

def evaluate_maml(meta_model, support_loaders, query_loaders, users, device,
                  hidden_dim, n_items, inner_lr, num_inner_steps, criterion,
                  adaptable_param_names=None):
    """
    Evaluate MAML on validation or test set with optional layer-specific adaptation.

    Args:
        meta_model: The meta-learned model
        support_loaders: Dict of support set loaders per user
        query_loaders: Dict of query set loaders per user
        users: List of user IDs to evaluate
        device: torch device
        hidden_dim: GRU hidden dimension
        n_items: Number of items (courses)
        inner_lr: Inner loop learning rate
        num_inner_steps: Number of inner loop gradient steps
        criterion: Loss criterion
        adaptable_param_names: List of parameter names to adapt (layer-specific).
                               If None, adapt all parameters.

    Returns:
        avg_loss: Average query loss across all users
        avg_acc: Average query accuracy across all users
    """
    meta_model.eval()

    total_loss = 0.0
    total_acc = 0.0
    num_tasks = len(users)

    with torch.enable_grad():  # inner-loop adaptation calls torch.autograd.grad, which fails under no_grad
        for user_id in users:
            # Get support and query sets
            support_seq, support_lengths, support_labels = support_loaders[user_id].dataset.tensors
            query_seq, query_lengths, query_labels = query_loaders[user_id].dataset.tensors

            support_seq = support_seq.to(device)
            support_lengths = support_lengths.to(device)
            support_labels = support_labels.to(device)
            query_seq = query_seq.to(device)
            query_lengths = query_lengths.to(device)
            query_labels = query_labels.to(device)

            # Initialize fast weights
            fast_weights = OrderedDict(meta_model.named_parameters())

            # Inner loop adaptation
            for step in range(num_inner_steps):
                support_logits = functional_forward(
                    support_seq, support_lengths, fast_weights, hidden_dim, n_items
                )
                support_loss = criterion(support_logits, support_labels)

                # Compute gradients (layer-specific if specified)
                if adaptable_param_names is not None:
                    # Layer-specific: only adapt specified parameters
                    adaptable_params = [fast_weights[name] for name in adaptable_param_names]
                else:
                    # Adapt all parameters
                    adaptable_params = list(fast_weights.values())

                grads = torch.autograd.grad(
                    support_loss,
                    adaptable_params,
                    create_graph=False
                )

                # Update parameters
                new_fast_weights = OrderedDict()
                grad_idx = 0
                for name, param in fast_weights.items():
                    if adaptable_param_names is None or name in adaptable_param_names:
                        new_fast_weights[name] = param - inner_lr * grads[grad_idx]
                        grad_idx += 1
                    else:
                        new_fast_weights[name] = param
                fast_weights = new_fast_weights

            # Evaluate on query set
            query_logits = functional_forward(
                query_seq, query_lengths, fast_weights, hidden_dim, n_items
            )
            query_loss = criterion(query_logits, query_labels)

            # Compute accuracy
            _, predicted = torch.max(query_logits, 1)
            acc = (predicted == query_labels).float().mean()

            total_loss += query_loss.item()
            total_acc += acc.item()

    avg_loss = total_loss / num_tasks
    avg_acc = total_acc / num_tasks

    return avg_loss, avg_acc

print("[CELL 07f-08b] MAML evaluation function defined")
print("  • Supports layer-specific adaptation via adaptable_param_names parameter")
print("  • Used during training for validation")

cell_end("CELL 07f-08b", t0)
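
The layer-specific update inside the inner loop applies a gradient step only to the named adaptable parameters and passes every other parameter through untouched. A minimal sketch in plain Python, assuming lists of floats in place of tensors and hand-written gradients in place of `torch.autograd.grad`; `inner_step` is an illustrative name, not a notebook function:

```python
from typing import Dict, List, Set


def inner_step(params: Dict[str, List[float]], grads: Dict[str, List[float]],
               adaptable: Set[str], lr: float) -> Dict[str, List[float]]:
    """One SGD step on the adaptable parameters; frozen ones are returned as-is."""
    updated = {}
    for name, values in params.items():
        if name in adaptable:
            updated[name] = [v - lr * g for v, g in zip(values, grads[name])]
        else:
            updated[name] = values  # frozen: same object, no update
    return updated


params = {"embedding.weight": [1.0, 2.0], "gru.weight_hh_l0": [0.5], "fc.bias": [0.0]}
grads = {"embedding.weight": [10.0, -10.0], "fc.bias": [2.0]}  # no gradient computed for the GRU
fast = inner_step(params, grads, adaptable={"embedding.weight", "fc.bias"}, lr=0.05)
```

Only the adaptable entries need gradients, which is why the notebook passes just those tensors to `torch.autograd.grad`.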


In [None]:
# [CELL 07f-09] MAML Training Loop with Residual Meta-Loss and Layer-Specific Adaptation
t0 = cell_start("CELL 07f-09", "MAML Training Loop (Residual + Layer-Specific)")
print(f"[CELL 07f-09] Starting MAML training with THREE FIXES:")
print(f"  • [Fix #1] Inner LR (α) = {inner_lr} (increased from 0.01)")
print(f"  • [Fix #2] Lambda residual (λ) = {lambda_residual} (Residual MAML)")
print(f"  • [Fix #3] Layer-specific adaptation (freeze GRU)")
print(f"  • Meta batch size = {meta_batch_size}")
print(f"  • Meta iterations = {num_meta_iterations}")
print(f"  • Num inner steps = {num_inner_steps}")
print()
# [Fix #3] Layer-specific adaptation: Only adapt embedding + FC, freeze GRU
adaptable_param_names = ['embedding.weight', 'fc.weight', 'fc.bias']
frozen_param_names = ['gru.weight_ih_l0', 'gru.weight_hh_l0', 'gru.bias_ih_l0', 'gru.bias_hh_l0']
print(f"[Fix #3] Layer-specific adaptation:")
print(f"  • Adaptable: {adaptable_param_names}")
print(f"  • Frozen (GRU): {frozen_param_names}")
print()
meta_model.train()
meta_optimizer = torch.optim.Adam(meta_model.parameters(), lr=outer_lr)
# Training history
history = {
    'meta_loss': [],
    'meta_loss_adapted': [],
    'meta_loss_unadapted': [],
    'val_loss': [],
    'val_acc': []
}
best_val_loss = float('inf')
patience_counter = 0
for meta_iter in range(num_meta_iterations):
    meta_optimizer.zero_grad()
    meta_loss_total = 0.0
    meta_loss_adapted_total = 0.0
    meta_loss_unadapted_total = 0.0
    # Sample meta-batch of tasks
    for task_idx in range(meta_batch_size):
        # Sample a task (user)
        user_id = cold_users[torch.randint(0, len(cold_users), (1,)).item()]
        # Get support and query sets
        support_seq, support_lengths, support_labels = cold_support_loaders[user_id].dataset.tensors
        query_seq, query_lengths, query_labels = cold_query_loaders[user_id].dataset.tensors
        # Use the DEVICE constant from cell 07f-00; lowercase `device` is never defined in this notebook
        support_seq = support_seq.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        support_labels = support_labels.to(DEVICE)
        query_seq = query_seq.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        query_labels = query_labels.to(DEVICE)
        # Initialize fast weights from meta-model
        fast_weights = OrderedDict(meta_model.named_parameters())
        # Inner loop: Only adapt embedding + FC layers (Fix #3)
        for step in range(num_inner_steps):
            support_logits = functional_forward(
                support_seq, support_lengths, fast_weights, hidden_dim, n_items
            )
            support_loss = criterion(support_logits, support_labels)
            # [Fix #3] Only compute gradients for adaptable parameters
            adaptable_params = [fast_weights[name] for name in adaptable_param_names]
            grads = torch.autograd.grad(
                support_loss,
                adaptable_params,
                create_graph=False  # FOMAML
            )
            # Update only adaptable parameters
            new_fast_weights = OrderedDict()
            grad_idx = 0
            for name, param in fast_weights.items():
                if name in adaptable_param_names:
                    new_fast_weights[name] = param - inner_lr * grads[grad_idx]
                    grad_idx += 1
                else:
                    new_fast_weights[name] = param  # Keep frozen params unchanged
            fast_weights = new_fast_weights
        # Outer loop: Compute meta-loss on query set
        # Compute adapted query loss (with fast weights)
        query_logits_adapted = functional_forward(
            query_seq, query_lengths, fast_weights, hidden_dim, n_items
        )
        query_loss_adapted = criterion(query_logits_adapted, query_labels)
        # [Fix #2] Residual MAML: Compute unadapted query loss (with original meta-model)
        original_params = OrderedDict(meta_model.named_parameters())
        query_logits_unadapted = functional_forward(
            query_seq, query_lengths, original_params, hidden_dim, n_items
        )
        query_loss_unadapted = criterion(query_logits_unadapted, query_labels)
        # [Fix #2] Residual MAML meta-loss: (1-λ) * L_adapted + λ * L_unadapted
        task_meta_loss = (1 - lambda_residual) * query_loss_adapted + lambda_residual * query_loss_unadapted
        meta_loss_total = meta_loss_total + task_meta_loss
        meta_loss_adapted_total = meta_loss_adapted_total + query_loss_adapted.item()
        meta_loss_unadapted_total = meta_loss_unadapted_total + query_loss_unadapted.item()
    # Average meta-loss over meta-batch
    meta_loss_avg = meta_loss_total / meta_batch_size
    # Meta-update
    meta_loss_avg.backward()
    meta_optimizer.step()
    # Record history
    history['meta_loss'].append(meta_loss_avg.item())
    history['meta_loss_adapted'].append(meta_loss_adapted_total / meta_batch_size)
    history['meta_loss_unadapted'].append(meta_loss_unadapted_total / meta_batch_size)
    # Validation every 100 iterations
    if (meta_iter + 1) % 100 == 0:
        meta_model.eval()
        val_loss, val_acc = evaluate_maml(
            meta_model, cold_val_support_loaders, cold_val_query_loaders,
            cold_val_users, device, hidden_dim, n_items, inner_lr, num_inner_steps,
            criterion, adaptable_param_names  # Pass adaptable params for layer-specific eval
        )
        meta_model.train()
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f"Meta-iter {meta_iter + 1}/{num_meta_iterations} | "
              f"Meta-loss: {meta_loss_avg.item():.4f} | "
              f"Adapted: {meta_loss_adapted_total/meta_batch_size:.4f} | "
              f"Unadapted: {meta_loss_unadapted_total/meta_batch_size:.4f} | "
              f"Val-loss: {val_loss:.4f} | Val-acc: {val_acc:.4f}")
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save best model
            torch.save(meta_model.state_dict(), RUN_DIR / 'best_meta_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= 20:  # 20 validation checks x 100 iterations = 2000 iterations without improvement
                print(f"Early stopping at iteration {meta_iter + 1}")
                break
# Save final model
torch.save(meta_model.state_dict(), RUN_DIR / 'final_meta_model.pt')
# Save history
with open(RUN_DIR / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)
print(f"\n[CELL 07f-09] Training completed!")
print(f"  • Best val loss: {best_val_loss:.4f}")
print(f"  • Models saved to {RUN_DIR}")
cell_end("CELL 07f-09", t0)
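
The residual meta-loss in the training loop above blends the adapted and unadapted query losses. The following is a minimal sketch of that blend on plain floats, assuming λ = 0.1 as stated in the notebook header. `residual_meta_loss` is a hypothetical helper name used only for illustration, and the loss values are made up.

```python
def residual_meta_loss(loss_adapted, loss_unadapted, lambda_residual=0.1):
    """Blend two query losses: (1 - lambda) * L_adapted + lambda * L_unadapted."""
    if not 0.0 <= lambda_residual <= 1.0:
        raise ValueError("lambda_residual must lie in [0, 1]")
    return (1.0 - lambda_residual) * loss_adapted + lambda_residual * loss_unadapted


# Illustrative values only (not taken from a real run).
blended = residual_meta_loss(loss_adapted=2.0, loss_unadapted=3.0, lambda_residual=0.1)
print(f"{blended:.4f}")
```

With λ = 0 the expression reduces to the usual MAML query loss; with λ = 1 the inner loop no longer contributes to the meta-gradient.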


In [10]:
# Load the trained warmstart model
import torch
max_seq_len = 50

MODEL_PATH = "../models/maml/maml_warmstart_gru_K5.pth"
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
meta_model.load_state_dict(checkpoint['model_state_dict'])
meta_model.eval()

print(f"Loaded trained warmstart model from: {MODEL_PATH}")
print(f"Model has {sum(p.numel() for p in meta_model.parameters()):,} parameters")

Loaded trained warmstart model from: ../models/maml/maml_warmstart_gru_K5.pth
Model has 140,695 parameters


In [11]:
# Extract config variables from CFG for convenience

# MAML config
inner_lr = CFG['maml_config']['inner_lr']
num_inner_steps = CFG['maml_config']['num_inner_steps']
outer_lr = CFG['maml_config']['outer_lr']
meta_batch_size = CFG['maml_config']['meta_batch_size']
num_meta_iterations = CFG['maml_config']['num_meta_iterations']

# GRU config
embedding_dim = CFG['gru_config']['embedding_dim']
hidden_dim = CFG['gru_config']['hidden_dim']
num_layers = CFG['gru_config']['num_layers']
dropout = CFG['gru_config']['dropout']
max_seq_len = CFG['gru_config']['max_seq_len']

# K-shot config
K = CFG['k_shot_config']['K']
Q = CFG['k_shot_config']['Q']

print(f"Extracted config variables:")
print(f"  inner_lr = {inner_lr}")
print(f"  num_inner_steps = {num_inner_steps}")
print(f"  hidden_dim = {hidden_dim}")
print(f"  max_seq_len = {max_seq_len}")
print(f"  K = {K}, Q = {Q}")

Extracted config variables:
  inner_lr = 0.01
  num_inner_steps = 5
  hidden_dim = 128
  max_seq_len = 50
  K = 5, Q = 10


In [12]:
# [CELL 07f-10] Meta-testing: Zero-shot (K=0) - no adaptation

t0 = cell_start("CELL 07f-10", "Zero-shot evaluation (K=0)")
max_seq_len = 50

print("[CELL 07f-10] Evaluating meta-learned model WITHOUT adaptation (zero-shot)...")

meta_model.eval()
zeroshot_predictions = []
zeroshot_labels = []

with torch.no_grad():  # Pure inference, no gradients needed
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)

        if len(query_pairs) == 0:
            continue

        # Only use query set (no support set adaptation)
        query_seq, query_labels_test, query_lengths = pairs_to_batch(query_pairs, max_seq_len)

        # Use original meta-learned model (no adaptation)
        query_logits = meta_model(query_seq, query_lengths)
        query_probs = torch.softmax(query_logits, dim=-1).cpu().numpy()

        zeroshot_predictions.append(query_probs)
        zeroshot_labels.extend(query_labels_test.cpu().numpy())

# Compute metrics
if len(zeroshot_predictions) > 0:
    zeroshot_predictions = np.vstack(zeroshot_predictions)
    zeroshot_labels = np.array(zeroshot_labels)
    zeroshot_metrics = compute_metrics(zeroshot_predictions, zeroshot_labels)

    print(f"\n[CELL 07f-10] Zero-shot Results (No Adaptation):")
    print(f"  Accuracy@1:  {zeroshot_metrics['accuracy@1']:.4f}")
    print(f"  Recall@5:    {zeroshot_metrics['recall@5']:.4f}")
    print(f"  Recall@10:   {zeroshot_metrics['recall@10']:.4f}")
    print(f"  MRR:         {zeroshot_metrics['mrr']:.4f}")
else:
    print("[CELL 07f-10] WARNING: No predictions generated")
    zeroshot_metrics = {}

cell_end("CELL 07f-10", t0)



[CELL 07e-10] Zero-shot evaluation (K=0)
[CELL 07e-10] start=2026-01-12T15:12:17
[CELL 07e-10] Evaluating meta-learned model WITHOUT adaptation (zero-shot)...

[CELL 07e-10] Zero-shot Results (No Adaptation):
  Accuracy@1:  0.2494
  Recall@5:    0.4598
  Recall@10:   0.5685
  MRR:         0.3553
[CELL 07e-10] elapsed=2.32s
[CELL 07e-10] done
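
The definition of `compute_metrics` is not shown in this part of the notebook. As a rough sketch of how Accuracy@1, Recall@k and MRR are commonly computed for single-label next-item prediction, assuming each row holds one score per item and ties are broken in favour of the true item, a pure-Python version could look like this. The name `ranking_metrics` and the toy data are hypothetical.

```python
def ranking_metrics(scores, labels, ks=(5, 10)):
    """Compute Accuracy@1, Recall@k and MRR for single-label next-item prediction.

    scores: non-empty list of per-example score lists (one score per item).
    labels: list of true item indices, one per example.
    """
    ranks = []
    for row, label in zip(scores, labels):
        # Rank of the true item = 1 + number of items scored strictly higher.
        rank = 1 + sum(1 for s in row if s > row[label])
        ranks.append(rank)
    n = len(ranks)
    metrics = {
        "accuracy@1": sum(r == 1 for r in ranks) / n,
        "mrr": sum(1.0 / r for r in ranks) / n,
    }
    for k in ks:
        metrics[f"recall@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# Toy example: 3 examples over 4 items.
toy_scores = [
    [0.1, 0.6, 0.2, 0.1],  # true item 1 is ranked 1st
    [0.4, 0.3, 0.2, 0.1],  # true item 2 is ranked 3rd
    [0.2, 0.3, 0.1, 0.4],  # true item 0 is ranked 3rd
]
toy_labels = [1, 2, 0]
toy_metrics = ranking_metrics(toy_scores, toy_labels, ks=(2, 3))
print(toy_metrics)
```

Because each query has exactly one correct item, Recall@k here equals the hit rate at k.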


In [13]:
import torch
checkpoint = torch.load('../models/maml/maml_warmstart_gru_K5.pth', map_location=DEVICE)
meta_model.load_state_dict(checkpoint['model_state_dict'])
meta_model.eval()

GRURecommender(
  (embedding): Embedding(343, 64, padding_idx=0)
  (gru): GRU(64, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=343, bias=True)
)
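
The module summary above is enough to check the reported parameter count by hand. The sketch below assumes a single-layer, unidirectional GRU with biases, as the printed repr suggests. `gru_recommender_param_count` is a hypothetical helper, and `n_items=343` is read from the embedding size, which includes the padding index.

```python
def gru_recommender_param_count(n_items, embedding_dim, hidden_dim):
    """Count parameters of Embedding -> single-layer GRU -> Linear."""
    embedding = n_items * embedding_dim
    # A GRU layer has 3 gates; each has input weights, hidden weights and two bias vectors.
    gru = (
        3 * hidden_dim * embedding_dim   # weight_ih_l0
        + 3 * hidden_dim * hidden_dim    # weight_hh_l0
        + 2 * 3 * hidden_dim             # bias_ih_l0 + bias_hh_l0
    )
    fc = hidden_dim * n_items + n_items  # weight + bias
    return embedding + gru + fc


total = gru_recommender_param_count(n_items=343, embedding_dim=64, hidden_dim=128)
print(f"{total:,}")  # prints 140,695
```

The result agrees with the "140,695 parameters" line printed when the checkpoint was loaded.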

In [15]:
# Extract config variables and define missing objects

# MAML config
inner_lr = CFG['maml_config']['inner_lr']
num_inner_steps = CFG['maml_config']['num_inner_steps']
outer_lr = CFG['maml_config']['outer_lr']
meta_batch_size = CFG['maml_config']['meta_batch_size']
num_meta_iterations = CFG['maml_config']['num_meta_iterations']

# GRU config
embedding_dim = CFG['gru_config']['embedding_dim']
hidden_dim = CFG['gru_config']['hidden_dim']
num_layers = CFG['gru_config']['num_layers']
dropout = CFG['gru_config']['dropout']
max_seq_len = CFG['gru_config']['max_seq_len']

# K-shot config
K = CFG['k_shot_config']['K']
Q = CFG['k_shot_config']['Q']

# Loss criterion
criterion = nn.CrossEntropyLoss()

print(f"Extracted config variables:")
print(f"  inner_lr = {inner_lr}")
print(f"  num_inner_steps = {num_inner_steps}")
print(f"  hidden_dim = {hidden_dim}")
print(f"  max_seq_len = {max_seq_len}")
print(f"  K = {K}, Q = {Q}")
print(f"  criterion = CrossEntropyLoss")

Extracted config variables:
  inner_lr = 0.01
  num_inner_steps = 5
  hidden_dim = 128
  max_seq_len = 50
  K = 5, Q = 10
  criterion = CrossEntropyLoss
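
The config-extraction cell appears twice in this notebook. One way to avoid repeating it is a small helper that flattens selected nested entries and fails loudly on a missing key. This is only a sketch: `extract_config` is a hypothetical name, and the toy `cfg` below copies just the values printed above.

```python
def extract_config(cfg, spec):
    """Flatten selected nested config entries into one dict.

    spec maps a section name to the keys wanted from that section.
    Raises KeyError with the full dotted path when something is missing.
    """
    out = {}
    for section, keys in spec.items():
        if section not in cfg:
            raise KeyError(f"missing config section: {section!r}")
        for key in keys:
            if key not in cfg[section]:
                raise KeyError(f"missing config key: {section}.{key}")
            out[key] = cfg[section][key]
    return out


# Toy config holding only the values printed in the cell output above.
cfg = {
    "maml_config": {"inner_lr": 0.01, "num_inner_steps": 5},
    "gru_config": {"hidden_dim": 128, "max_seq_len": 50},
    "k_shot_config": {"K": 5, "Q": 10},
}
spec = {
    "maml_config": ["inner_lr", "num_inner_steps"],
    "gru_config": ["hidden_dim", "max_seq_len"],
    "k_shot_config": ["K", "Q"],
}
settings = extract_config(cfg, spec)
print(settings)
```

Returning a dict instead of assigning module-level names keeps later cells from silently picking up stale values.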


In [16]:
# [CELL 07f-11] Meta-testing: Few-shot K=5 (with adaptation using functional forward)

t0 = cell_start("CELL 07f-11", "Few-shot evaluation (K=5)")

print("[CELL 07f-11] Evaluating meta-learned model WITH adaptation (few-shot K=5)...")

meta_model.eval()
fewshot_predictions = []
fewshot_labels = []

for _, episode in episodes_test.iterrows():
    support_pairs, query_pairs = get_episode_data(episode, pairs_test)
    
    if len(support_pairs) == 0 or len(query_pairs) == 0:
        continue
    
    support_seq, support_labels_test, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    query_seq, query_labels_test, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
    
    # Adapt using functional forward. Note: this loop adapts ALL parameters,
    # whereas the training loop only updates names in adaptable_param_names.
    with torch.enable_grad():
        # Clone parameters for adaptation
        fast_weights_test = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights_test[name] = param.clone().requires_grad_()
        
        # Inner loop adaptation
        for _ in range(num_inner_steps):
            support_logits_test = functional_forward(
                support_seq, support_lengths, fast_weights_test, hidden_dim, n_items
            )
            support_loss_test = criterion(support_logits_test, support_labels_test)
            
            grads_test = torch.autograd.grad(
                support_loss_test,
                fast_weights_test.values(),
                create_graph=False  # No second-order needed for testing
            )
            
            fast_weights_test = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights_test.items(), grads_test)
            )
    
    # Evaluate on query (no gradients)
    with torch.no_grad():
        query_logits_test = functional_forward(
            query_seq, query_lengths, fast_weights_test, hidden_dim, n_items
        )
        probs = torch.softmax(query_logits_test, dim=-1).cpu().numpy()
        
        fewshot_predictions.append(probs)
        fewshot_labels.extend(query_labels_test.cpu().numpy())

fewshot_predictions = np.vstack(fewshot_predictions)
fewshot_labels = np.array(fewshot_labels)

fewshot_metrics = compute_metrics(fewshot_predictions, fewshot_labels)

print(f"\n[CELL 07f-11] Few-shot Results (K=5 adaptation):")
print(f"  - Accuracy@1:  {fewshot_metrics['accuracy@1']:.4f}")
print(f"  - Recall@5:    {fewshot_metrics['recall@5']:.4f}")
print(f"  - Recall@10:   {fewshot_metrics['recall@10']:.4f}")
print(f"  - MRR:         {fewshot_metrics['mrr']:.4f}")

cell_end("CELL 07f-11", t0)


[CELL 07e-11] Few-shot evaluation (K=5)
[CELL 07e-11] start=2026-01-12T16:27:01
[CELL 07e-11] Evaluating meta-learned model WITH adaptation (few-shot K=5)...

[CELL 07e-11] Few-shot Results (K=5 adaptation):
  - Accuracy@1:  0.3165
  - Recall@5:    0.5254
  - Recall@10:   0.6136
  - MRR:         0.4184
[CELL 07e-11] elapsed=48.57s
[CELL 07e-11] done
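
The few-shot cell above takes gradients with respect to every entry of the fast weights, while the training loop updates only the names in `adaptable_param_names`. A minimal sketch of a layer-specific inner step, using plain floats in place of tensors, is given below. `inner_sgd_step` and the toy values are hypothetical; in the notebook the gradients would come from `torch.autograd.grad` over the adaptable tensors only.

```python
from collections import OrderedDict


def inner_sgd_step(fast_weights, grads, inner_lr, adaptable_names):
    """Apply one SGD step to the adaptable entries; pass frozen entries through.

    fast_weights: OrderedDict name -> value (floats here, tensors in practice).
    grads: dict name -> gradient, required only for adaptable names.
    """
    updated = OrderedDict()
    for name, value in fast_weights.items():
        if name in adaptable_names:
            updated[name] = value - inner_lr * grads[name]
        else:
            updated[name] = value  # frozen layer, unchanged
    return updated


weights = OrderedDict(
    [("embedding.weight", 1.0), ("gru.weight_hh_l0", 2.0), ("fc.bias", 0.5)]
)
grads = {"embedding.weight": 10.0, "fc.bias": -10.0}
adaptable = {"embedding.weight", "fc.bias"}
adapted = inner_sgd_step(weights, grads, inner_lr=0.05, adaptable_names=adaptable)
print(adapted)
```

Keying gradients by name, instead of by a running index, makes it harder to pair a gradient with the wrong parameter when the set of adaptable layers changes.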


In [17]:
# [CELL 07f-12] Ablation Study 1: Support set size (K=1,3,5,10) - functional forward

t0 = cell_start("CELL 07f-12", "Ablation: support set size")

print("[CELL 07f-12] Ablation Study: Varying support set size K...")

support_sizes = CFG["ablation_configs"]["support_set_sizes"]
ablation_support_results = {}

meta_model.eval()

for K_test in support_sizes:
    print(f"\n[CELL 07f-12] Testing with K={K_test}...")
    
    predictions = []
    labels = []
    
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)
        
        if len(support_pairs) < K_test or len(query_pairs) == 0:
            continue
        
        # Use only K_test support pairs
        support_pairs_k = support_pairs.head(K_test)
        
        support_seq, support_labels_abl, support_lengths = pairs_to_batch(support_pairs_k, max_seq_len)
        query_seq, query_labels_abl, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
        
        # Adapt using functional forward
        with torch.enable_grad():
            fast_weights_abl = OrderedDict()
            for name, param in meta_model.named_parameters():
                fast_weights_abl[name] = param.clone().requires_grad_()
            
            for _ in range(num_inner_steps):
                support_logits_abl = functional_forward(
                    support_seq, support_lengths, fast_weights_abl, hidden_dim, n_items
                )
                support_loss_abl = criterion(support_logits_abl, support_labels_abl)
                
                grads_abl = torch.autograd.grad(
                    support_loss_abl,
                    fast_weights_abl.values(),
                    create_graph=False
                )
                
                fast_weights_abl = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights_abl.items(), grads_abl)
                )
        
        # Evaluate on query
        with torch.no_grad():
            query_logits_abl = functional_forward(
                query_seq, query_lengths, fast_weights_abl, hidden_dim, n_items
            )
            probs = torch.softmax(query_logits_abl, dim=-1).cpu().numpy()
            
            predictions.append(probs)
            labels.extend(query_labels_abl.cpu().numpy())
    
    if len(predictions) > 0:
        predictions = np.vstack(predictions)
        labels = np.array(labels)
        metrics = compute_metrics(predictions, labels)
        ablation_support_results[K_test] = metrics
        
        print(f"[CELL 07f-12] K={K_test}: Acc@1={metrics['accuracy@1']:.4f}, "
              f"Recall@5={metrics['recall@5']:.4f}, MRR={metrics['mrr']:.4f}")

print(f"\n[CELL 07f-12] Ablation complete: tested K ∈ {support_sizes}")

cell_end("CELL 07f-12", t0)


[CELL 07e-12] Ablation: support set size
[CELL 07e-12] start=2026-01-12T16:32:22
[CELL 07e-12] Ablation Study: Varying support set size K...

[CELL 07e-12] Testing with K=1...
[CELL 07e-12] K=1: Acc@1=0.2633, Recall@5=0.4818, MRR=0.3713

[CELL 07e-12] Testing with K=3...
[CELL 07e-12] K=3: Acc@1=0.3029, Recall@5=0.5153, MRR=0.4072

[CELL 07e-12] Testing with K=5...
[CELL 07e-12] K=5: Acc@1=0.3165, Recall@5=0.5254, MRR=0.4184

[CELL 07e-12] Testing with K=10...

[CELL 07e-12] Ablation complete: tested K ∈ [1, 3, 5, 10]
[CELL 07e-12] elapsed=128.32s
[CELL 07e-12] done
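
No metrics line is printed for K=10 in the output above. The ablation loop skips any episode with fewer than `K_test` support pairs, so an empty result means no test episode had at least 10. The sketch below counts eligible episodes per K. It assumes, hypothetically, that all 346 test episodes carry exactly 5 support pairs; `count_eligible_episodes` is an illustrative name.

```python
def count_eligible_episodes(support_sizes_per_episode, k_values):
    """For each K, count episodes with at least K support pairs (others are skipped)."""
    return {
        k: sum(1 for n_support in support_sizes_per_episode if n_support >= k)
        for k in k_values
    }


# Hypothetical episode sizes: every episode was built with exactly 5 support pairs.
episode_support_sizes = [5] * 346
eligible = count_eligible_episodes(episode_support_sizes, k_values=[1, 3, 5, 10])
print(eligible)
```

Logging this count next to each ablation result would make an empty K=10 row self-explanatory.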


In [18]:
# [CELL 07f-13] Ablation Study 2: Adaptation steps (1,3,5,10) - functional forward

t0 = cell_start("CELL 07f-13", "Ablation: adaptation steps")

print("[CELL 07f-13] Ablation Study: Varying adaptation steps...")

adaptation_steps = CFG["ablation_configs"]["adaptation_steps"]
ablation_steps_results = {}

meta_model.eval()

for num_steps in adaptation_steps:
    print(f"\n[CELL 07f-13] Testing with {num_steps} adaptation steps...")
    
    predictions = []
    labels = []
    
    for _, episode in episodes_test.iterrows():
        support_pairs, query_pairs = get_episode_data(episode, pairs_test)
        
        if len(support_pairs) == 0 or len(query_pairs) == 0:
            continue
        
        support_seq, support_labels_steps, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
        query_seq, query_labels_steps, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
        
        # Adapt using functional forward with varying steps
        with torch.enable_grad():
            fast_weights_steps = OrderedDict()
            for name, param in meta_model.named_parameters():
                fast_weights_steps[name] = param.clone().requires_grad_()
            
            for _ in range(num_steps):  # Use num_steps instead of num_inner_steps
                support_logits_steps = functional_forward(
                    support_seq, support_lengths, fast_weights_steps, hidden_dim, n_items
                )
                support_loss_steps = criterion(support_logits_steps, support_labels_steps)
                
                grads_steps = torch.autograd.grad(
                    support_loss_steps,
                    fast_weights_steps.values(),
                    create_graph=False
                )
                
                fast_weights_steps = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights_steps.items(), grads_steps)
                )
        
        # Evaluate on query
        with torch.no_grad():
            query_logits_steps = functional_forward(
                query_seq, query_lengths, fast_weights_steps, hidden_dim, n_items
            )
            probs = torch.softmax(query_logits_steps, dim=-1).cpu().numpy()
            
            predictions.append(probs)
            labels.extend(query_labels_steps.cpu().numpy())
    
    if len(predictions) > 0:
        predictions = np.vstack(predictions)
        labels = np.array(labels)
        metrics = compute_metrics(predictions, labels)
        ablation_steps_results[num_steps] = metrics
        
        print(f"[CELL 07f-13] Steps={num_steps}: Acc@1={metrics['accuracy@1']:.4f}, "
              f"Recall@5={metrics['recall@5']:.4f}, MRR={metrics['mrr']:.4f}")

print(f"\n[CELL 07f-13] Ablation complete: tested adaptation steps ∈ {adaptation_steps}")

cell_end("CELL 07f-13", t0)


[CELL 07e-13] Ablation: adaptation steps
[CELL 07e-13] start=2026-01-12T16:35:24
[CELL 07e-13] Ablation Study: Varying adaptation steps...

[CELL 07e-13] Testing with 1 adaptation steps...
[CELL 07e-13] Steps=1: Acc@1=0.2884, Recall@5=0.4916, MRR=0.3885

[CELL 07e-13] Testing with 3 adaptation steps...
[CELL 07e-13] Steps=3: Acc@1=0.3064, Recall@5=0.5110, MRR=0.4084

[CELL 07e-13] Testing with 5 adaptation steps...
[CELL 07e-13] Steps=5: Acc@1=0.3165, Recall@5=0.5254, MRR=0.4184

[CELL 07e-13] Testing with 10 adaptation steps...
[CELL 07e-13] Steps=10: Acc@1=0.3246, Recall@5=0.5358, MRR=0.4301

[CELL 07e-13] Ablation complete: tested adaptation steps ∈ [1, 3, 5, 10]
[CELL 07e-13] elapsed=178.84s
[CELL 07e-13] done
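
The adaptation-step ablation above improves steadily from 1 to 10 steps. A toy one-parameter problem shows the same mechanism: with a small enough step size, each extra gradient step lowers the support loss further. This is only an illustration on a quadratic loss, not the notebook's model; `adapt_scalar` and all values are hypothetical.

```python
def adapt_scalar(w0, target, inner_lr, num_steps):
    """Run plain gradient descent on the loss (w - target)^2.

    Returns the final weight and the loss recorded after each step.
    """
    w = w0
    losses = []
    for _ in range(num_steps):
        grad = 2.0 * (w - target)
        w = w - inner_lr * grad
        losses.append((w - target) ** 2)
    return w, losses


for steps in (1, 3, 5, 10):
    _, step_losses = adapt_scalar(w0=0.0, target=1.0, inner_lr=0.1, num_steps=steps)
    print(steps, round(step_losses[-1], 4))

final_w, all_losses = adapt_scalar(w0=0.0, target=1.0, inner_lr=0.1, num_steps=10)
```

The monotone trend is not guaranteed in general: on this toy loss a step size above 1.0 makes the iterates diverge, and on a real support set more steps can overfit the few examples.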


In [19]:
# [CELL 07f-14] Parameter update analysis

t0 = cell_start("CELL 07f-14", "Parameter update analysis")

# Set to train mode for gradient computation
meta_model.train()

print("[CELL 07f-14] Analyzing parameter updates during adaptation...")

# Select one test episode for analysis
sample_episode = episodes_test.iloc[0]
support_pairs, query_pairs = get_episode_data(sample_episode, pairs_test)

if len(support_pairs) > 0:
    support_seq, support_labels_viz, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    
    # Get original parameters (before adaptation)
    param_norms_before = {}
    for name, param in meta_model.named_parameters():
        param_norms_before[name] = param.data.norm().item()
    
    # Adapt using functional forward (no in-place operations)
    with torch.enable_grad():
        fast_weights_viz = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights_viz[name] = param.clone().requires_grad_(True)
        
        for _ in range(num_inner_steps):
            support_logits_viz = functional_forward(
                support_seq, support_lengths, fast_weights_viz, hidden_dim, n_items
            )
            support_loss_viz = criterion(support_logits_viz, support_labels_viz)
            
            grads_viz = torch.autograd.grad(
                support_loss_viz,
                fast_weights_viz.values(),
                create_graph=False
            )
            
            # FIXED: Create new dict instead of modifying in-place
            new_fast_weights = OrderedDict()
            for (name, param), grad in zip(fast_weights_viz.items(), grads_viz):
                new_fast_weights[name] = param - inner_lr * grad
            fast_weights_viz = new_fast_weights
    
    # Compute parameter changes
    param_norms_after = {}
    param_changes = {}
    
    for name in fast_weights_viz.keys():
        adapted_norm = fast_weights_viz[name].data.norm().item()
        original_norm = param_norms_before[name]
        change = adapted_norm - original_norm
        
        param_norms_after[name] = adapted_norm
        param_changes[name] = {
            "before": original_norm,
            "after": adapted_norm,
            "change": change,
            "change_pct": (change / original_norm * 100) if original_norm > 0 else 0,
        }
    
    print(f"[CELL 07f-14] Parameter changes after {num_inner_steps} adaptation steps:")
    print(f"{'Parameter':<30} {'Before':>12} {'After':>12} {'Change':>12} {'Change %':>10}")
    print("-" * 80)
    
    for name, stats in list(param_changes.items())[:10]:  # Show first 10
        print(f"{name:<30} {stats['before']:>12.4f} {stats['after']:>12.4f} "
              f"{stats['change']:>12.4f} {stats['change_pct']:>9.2f}%")
    
else:
    print("[CELL 07f-14] WARNING: No support pairs for visualization")


# Reset to eval mode
meta_model.eval()
cell_end("CELL 07f-14", t0)


[CELL 07e-14] Parameter update analysis
[CELL 07e-14] start=2026-01-12T16:42:44
[CELL 07e-14] Analyzing parameter updates during adaptation...
[CELL 07e-14] Parameter changes after 5 adaptation steps:
Parameter                            Before        After       Change   Change %
--------------------------------------------------------------------------------
embedding.weight                   162.5457     162.5457       0.0000      0.00%
gru.weight_ih_l0                    34.0514      34.0516       0.0003      0.00%
gru.weight_hh_l0                    43.4369      43.4365      -0.0004     -0.00%
gru.bias_ih_l0                      11.2469      11.2472       0.0003      0.00%
gru.bias_hh_l0                      11.9787      11.9790       0.0003      0.00%
fc.weight                           52.0300      52.0309       0.0010      0.00%
fc.bias                              3.6749       3.6751       0.0002      0.00%
[CELL 07e-14] elapsed=0.21s
[CELL 07e-14] done
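
The table above reports the change in each parameter's norm, i.e. ‖θ′‖ − ‖θ‖. That quantity can be close to zero even when the parameters moved, so the norm of the update ‖θ′ − θ‖ is a useful complementary measure. The sketch below contrasts the two on a toy vector, using only the standard library; `update_stats` is a hypothetical helper.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def update_stats(before, after):
    """Return (difference of norms, norm of the difference) for two flat vectors."""
    diff_of_norms = l2_norm(after) - l2_norm(before)
    norm_of_diff = l2_norm([a - b for a, b in zip(after, before)])
    return diff_of_norms, norm_of_diff


# A toy update that swaps the two coordinates: the length is unchanged,
# yet every entry moved.
before = [3.0, 4.0]
after = [4.0, 3.0]
diff_of_norms, norm_of_diff = update_stats(before, after)
print(f"difference of norms: {diff_of_norms:.4f}")
print(f"norm of difference:  {norm_of_diff:.4f}")
```

In the notebook, the second quantity would correspond to something like `(fast_weights_viz[name] - param).norm()` per parameter, which is an assumption about how one might extend the cell, not code it already contains.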


In [21]:
# Create minimal training_history for results summary
# (The actual training was already completed, this is just for the final report)
training_history = {
    "meta_iterations": [],
    "meta_train_loss": [],
    "val_accuracy": [],
    "val_iterations": [],
}

print("Created placeholder training_history for results summary")

Created placeholder training_history for results summary


In [22]:
# [CELL 07f-15] Results summary table + comparison with baselines

t0 = cell_start("CELL 07f-15", "Results summary")

print("\n[CELL 07f-15] ========== RESULTS SUMMARY (Test Set) ==========")
print(f"K={K}, Q={Q} | Test Episodes: {len(episodes_test):,}\n")

# Load GRU baseline results for comparison
baseline_results_path = RESULTS_DIR / f"baselines_K{K}_Q{Q}.json"
if baseline_results_path.exists():
    baseline_results = read_json(baseline_results_path)
    gru_baseline_metrics = baseline_results["baselines"]["gru_global"]
else:
    gru_baseline_metrics = {"accuracy@1": 0.3373, "recall@5": 0.5590, "recall@10": 0.6575, "mrr": 0.4437}

# Create comparison table
print(f"{'Model':<30} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 72)

# Baselines
print(f"{'GRU (Baseline - 06)':<30} {gru_baseline_metrics['accuracy@1']:>10.4f} "
      f"{gru_baseline_metrics['recall@5']:>10.4f} {gru_baseline_metrics['recall@10']:>10.4f} "
      f"{gru_baseline_metrics['mrr']:>10.4f}")

# MAML results
print(f"{'MAML Zero-shot':<30} {zeroshot_metrics['accuracy@1']:>10.4f} "
      f"{zeroshot_metrics['recall@5']:>10.4f} {zeroshot_metrics['recall@10']:>10.4f} "
      f"{zeroshot_metrics['mrr']:>10.4f}")

print(f"{'MAML Few-shot (K=5)':<30} {fewshot_metrics['accuracy@1']:>10.4f} "
      f"{fewshot_metrics['recall@5']:>10.4f} {fewshot_metrics['recall@10']:>10.4f} "
      f"{fewshot_metrics['mrr']:>10.4f}")

# Improvement over baseline
improvement = (fewshot_metrics['accuracy@1'] - gru_baseline_metrics['accuracy@1']) / gru_baseline_metrics['accuracy@1'] * 100
print(f"\n[CELL 07f-15] MAML Few-shot improvement over GRU baseline: {improvement:+.2f}%")

# Ablation results
print(f"\n[CELL 07f-15] ========== ABLATION STUDY 1: Support Set Size ==========")
print(f"{'K (Support Size)':<20} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 62)
for K_test, metrics in ablation_support_results.items():
    print(f"{K_test:<20} {metrics['accuracy@1']:>10.4f} {metrics['recall@5']:>10.4f} "
          f"{metrics['recall@10']:>10.4f} {metrics['mrr']:>10.4f}")

print(f"\n[CELL 07f-15] ========== ABLATION STUDY 2: Adaptation Steps ==========")
print(f"{'Adaptation Steps':<20} {'Acc@1':>10} {'Recall@5':>10} {'Recall@10':>10} {'MRR':>10}")
print("-" * 62)
for num_steps, metrics in ablation_steps_results.items():
    print(f"{num_steps:<20} {metrics['accuracy@1']:>10.4f} {metrics['recall@5']:>10.4f} "
          f"{metrics['recall@10']:>10.4f} {metrics['mrr']:>10.4f}")

# Save all results
all_results = {
    "run_id": RUN_ID,
    "k_shot_config": {"K": K, "Q": Q},
    "n_test_episodes": len(episodes_test),
    "baseline": {
        "gru_global": gru_baseline_metrics,
    },
    "maml": {
        "zero_shot": zeroshot_metrics,
        "few_shot_K5": fewshot_metrics,
    },
    "ablation_support_size": ablation_support_results,
    "ablation_adaptation_steps": ablation_steps_results,
    "improvement_over_baseline_pct": improvement,
    "training_history": training_history,
}

results_path = Path(CFG["outputs"]["results"])
write_json_atomic(results_path, all_results)
print(f"\n[CELL 07f-15] Saved: {results_path.name}")

cell_end("CELL 07f-15", t0)


[CELL 07e-15] Results summary
[CELL 07e-15] start=2026-01-12T16:44:02

K=5, Q=10 | Test Episodes: 346

Model                               Acc@1   Recall@5  Recall@10        MRR
------------------------------------------------------------------------
GRU (Baseline - 06)                0.3373     0.5590     0.6575     0.4438
MAML Zero-shot                     0.2494     0.4598     0.5685     0.3553
MAML Few-shot (K=5)                0.3165     0.5254     0.6136     0.4184

[CELL 07e-15] MAML Few-shot improvement over GRU baseline: -6.17%

K (Support Size)          Acc@1   Recall@5  Recall@10        MRR
--------------------------------------------------------------
1                        0.2633     0.4818     0.5829     0.3713
3                        0.3029     0.5153     0.6092     0.4072
5                        0.3165     0.5254     0.6136     0.4184

Adaptation Steps          Acc@1   Recall@5  Recall@10        MRR
--------------------------------------------------------------
1  
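
The "-6.17%" figure in the summary is a relative change in Acc@1, not a difference in percentage points. A short worked check using the two Acc@1 values from the table above (`relative_improvement_pct` is a hypothetical helper name):

```python
def relative_improvement_pct(model_score, baseline_score):
    """Relative change of model_score versus baseline_score, in percent."""
    return (model_score - baseline_score) / baseline_score * 100.0


few_shot_acc1 = 0.3165   # MAML few-shot (K=5), from the results table
baseline_acc1 = 0.3373   # GRU baseline, from the results table

rel = relative_improvement_pct(few_shot_acc1, baseline_acc1)
abs_pp = (few_shot_acc1 - baseline_acc1) * 100.0
print(f"{rel:+.2f}% relative, {abs_pp:+.2f} percentage points absolute")
```

Stating both numbers avoids the ambiguity: the few-shot model trails the baseline by about 2.08 points of Acc@1, which is about 6.17% of the baseline's value.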

In [24]:
# Define final_model_path (points to the trained warmstart model)
final_model_path = MODELS_DIR / f"maml_warmstart_gru_K{K}.pth"

print(f"final_model_path: {final_model_path}")

final_model_path: C:\anonymous-users-mooc-session-meta\models\maml\maml_warmstart_gru_K5.pth


In [25]:
# [CELL 07f-16] Update report + manifest

t0 = cell_start("CELL 07f-16", "Write report + manifest")

report = read_json(REPORT_PATH)
manifest = read_json(MANIFEST_PATH)

# Metrics
report["metrics"] = {
    "n_test_episodes": len(episodes_test),
    "gru_baseline_acc1": gru_baseline_metrics['accuracy@1'],
    "maml_zero_shot_acc1": zeroshot_metrics['accuracy@1'],
    "maml_few_shot_K5_acc1": fewshot_metrics['accuracy@1'],
    "improvement_over_baseline_pct": improvement,
    "training_iterations": num_meta_iterations,
}

# Key findings
report["key_findings"].extend([
    f"MAML meta-training: {num_meta_iterations:,} iterations with {meta_batch_size} tasks/batch",
    f"Zero-shot performance (no adaptation): Acc@1={zeroshot_metrics['accuracy@1']:.4f}",
    f"Few-shot performance (K=5 adaptation): Acc@1={fewshot_metrics['accuracy@1']:.4f}",
    f"Improvement over GRU baseline: {improvement:+.2f}% ({fewshot_metrics['accuracy@1']:.4f} vs {gru_baseline_metrics['accuracy@1']:.4f})",
    f"Ablation: Best K={max(ablation_support_results, key=lambda k: ablation_support_results[k]['accuracy@1'])} "
    f"(Acc@1={max(ablation_support_results.values(), key=lambda m: m['accuracy@1'])['accuracy@1']:.4f})",
    f"Ablation: Best adaptation steps={max(ablation_steps_results, key=lambda k: ablation_steps_results[k]['accuracy@1'])} "
    f"(Acc@1={max(ablation_steps_results.values(), key=lambda m: m['accuracy@1'])['accuracy@1']:.4f})",
])

# Sanity samples
report["sanity_samples"]["maml_config"] = CFG["maml_config"]
report["sanity_samples"]["sample_episode"] = {
    "episode_id": int(episodes_test.iloc[0]["episode_id"]),
    "user_id": str(episodes_test.iloc[0]["user_id"]),
    "n_support_pairs": len(episodes_test.iloc[0]["support_pair_ids"]),
    "n_query_pairs": len(episodes_test.iloc[0]["query_pair_ids"]),
}

# Fingerprints
report["data_fingerprints"]["meta_model"] = {
    "path": str(final_model_path),
    "bytes": int(final_model_path.stat().st_size),
    "sha256": sha256_file(final_model_path),
}

write_json_atomic(REPORT_PATH, report)

# Manifest
def add_artifact(path: Path) -> None:
    rec = {"path": str(path), "bytes": int(path.stat().st_size), "sha256": None, "sha256_error": None}
    try:
        rec["sha256"] = sha256_file(path)
    except Exception as e:
        rec["sha256_error"] = str(e)
    manifest["artifacts"].append(rec)

add_artifact(final_model_path)
add_artifact(results_path)

# Add checkpoints
for checkpoint_file in sorted(CHECKPOINTS_DIR.glob("checkpoint_iter*.pth")):
    add_artifact(checkpoint_file)

write_json_atomic(MANIFEST_PATH, manifest)

print(f"[CELL 07f-16] Updated: {REPORT_PATH}")
print(f"[CELL 07f-16] Updated: {MANIFEST_PATH}")

cell_end("CELL 07f-16", t0)

print(f"\n{'='*80}")
print(f"✅ NOTEBOOK 07 COMPLETE")
print(f"{'='*80}")
print(f"\n📊 Key Results:")
print(f"  - GRU Baseline (06):        {gru_baseline_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - MAML Zero-shot:           {zeroshot_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - MAML Few-shot (K=5):      {fewshot_metrics['accuracy@1']:.4f} Acc@1")
print(f"  - Improvement:              {improvement:+.2f}%")
print(f"\n📁 Outputs:")
print(f"  - Meta-model: {final_model_path}")
print(f"  - Results:    {results_path}")
print(f"  - Report:     {REPORT_PATH}")
print(f"\n🎯 Next Steps:")
print(f"  - Fine-tune hyperparameters (α, β, inner steps)")
print(f"  - Try different architectures (Transformer, GNN)")
print(f"  - Compare with other meta-learning methods (ProtoNet, Matching Networks)")


[CELL 07e-16] Write report + manifest
[CELL 07e-16] start=2026-01-12T16:45:02
[CELL 07e-16] Updated: C:\anonymous-users-mooc-session-meta\reports\07e_maml_warmstart_xuetangx\20260112_151130\report.json
[CELL 07e-16] Updated: C:\anonymous-users-mooc-session-meta\reports\07e_maml_warmstart_xuetangx\20260112_151130\manifest.json
[CELL 07e-16] elapsed=0.52s
[CELL 07e-16] done

✅ NOTEBOOK 07 COMPLETE

📊 Key Results:
  - GRU Baseline (06):        0.3373 Acc@1
  - MAML Zero-shot:           0.2494 Acc@1
  - MAML Few-shot (K=5):      0.3165 Acc@1
  - Improvement:              -6.17%

📁 Outputs:
  - Meta-model: C:\anonymous-users-mooc-session-meta\models\maml\maml_warmstart_gru_K5.pth
  - Results:    C:\anonymous-users-mooc-session-meta\results\maml_warmstart_K5_Q10.json
  - Report:     C:\anonymous-users-mooc-session-meta\reports\07e_maml_warmstart_xuetangx\20260112_151130\report.json

🎯 Next Steps:
  - Fine-tune hyperparameters (α, β, inner steps)
  - Try different architectures 
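
The report cell relies on `sha256_file` and `write_json_atomic`, whose definitions are not shown in this part of the notebook. The following is a sketch of how such helpers are often written with the standard library. These are hypothetical stand-ins under that assumption, not the notebook's actual implementations.

```python
import hashlib
import json
import os
import tempfile
from pathlib import Path


def sha256_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_json_atomic(path, obj):
    """Write JSON to a temp file in the target directory, then rename it into place."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)
        os.replace(tmp_name, path)  # atomic when both paths share a filesystem
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


# Usage on a throwaway directory: write a manifest, then fingerprint it.
with tempfile.TemporaryDirectory() as tmp_dir:
    manifest_path = Path(tmp_dir) / "manifest.json"
    write_json_atomic(manifest_path, {"artifacts": []})
    reloaded = json.loads(manifest_path.read_text(encoding="utf-8"))
    record = {
        "path": str(manifest_path),
        "bytes": manifest_path.stat().st_size,
        "sha256": sha256_file(manifest_path),
    }
```

Writing to a temporary file first means a crash mid-write leaves the previous report or manifest intact instead of a half-written JSON file.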

In [26]:
# Test on VALIDATION set instead of test set (to verify the 37.80% claim)
print("Testing on VALIDATION set (same as training used)...")

meta_model.eval()
val_predictions = []
val_labels = []

for _, episode in episodes_val.iterrows():
    support_pairs, query_pairs = get_episode_data(episode, pairs_val)
    
    if len(support_pairs) == 0 or len(query_pairs) == 0:
        continue
    
    support_seq, support_labels_val, support_lengths = pairs_to_batch(support_pairs, max_seq_len)
    query_seq, query_labels_val, query_lengths = pairs_to_batch(query_pairs, max_seq_len)
    
    # Adapt using functional forward
    with torch.enable_grad():
        fast_weights_val = OrderedDict()
        for name, param in meta_model.named_parameters():
            fast_weights_val[name] = param.clone().requires_grad_()
        
        for _ in range(num_inner_steps):
            support_logits_val = functional_forward(
                support_seq, support_lengths, fast_weights_val, hidden_dim, n_items
            )
            support_loss_val = criterion(support_logits_val, support_labels_val)
            
            grads_val = torch.autograd.grad(
                support_loss_val,
                fast_weights_val.values(),
                create_graph=False
            )
            
            fast_weights_val = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights_val.items(), grads_val)
            )
    
    # Evaluate on query
    with torch.no_grad():
        query_logits_val = functional_forward(
            query_seq, query_lengths, fast_weights_val, hidden_dim, n_items
        )
        probs = torch.softmax(query_logits_val, dim=-1).cpu().numpy()
        
        val_predictions.append(probs)
        val_labels.extend(query_labels_val.cpu().numpy())

# Compute metrics
val_predictions = np.vstack(val_predictions)
val_labels = np.array(val_labels)
val_metrics = compute_metrics(val_predictions, val_labels)

print(f"\nValidation Set Results:")
print(f"  Accuracy@1:  {val_metrics['accuracy@1']:.4f} (expected ~37.80%)")
print(f"  Recall@5:    {val_metrics['recall@5']:.4f}")
print(f"  MRR:         {val_metrics['mrr']:.4f}")

Testing on VALIDATION set (same as training used)...

Validation Set Results:
  Accuracy@1:  0.3062 (expected ~37.80%)
  Recall@5:    0.4962
  MRR:         0.4003
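
The three numbers above come from `compute_metrics`, which is defined earlier in the notebook and is not shown in this cell. As a rough illustration of what such a function typically computes, here is a minimal sketch assuming single-label next-item prediction and a `(n_samples, n_items)` score matrix. The name `topk_metrics` and the tie-handling rule are assumptions for illustration, not the notebook's implementation.

```python
import numpy as np

def topk_metrics(probs: np.ndarray, labels: np.ndarray, k: int = 5) -> dict:
    """Accuracy@1, Recall@k and MRR for single-label next-item prediction.

    probs:  (n_samples, n_items) scores, higher is better
    labels: (n_samples,) index of the true item for each row
    """
    n = probs.shape[0]
    true_scores = probs[np.arange(n), labels]
    # Rank of the true item = 1 + number of items scored strictly higher.
    # Ties are resolved in favour of the true item (an assumption).
    ranks = 1 + (probs > true_scores[:, None]).sum(axis=1)
    return {
        "accuracy@1": float((ranks == 1).mean()),
        f"recall@{k}": float((ranks <= k).mean()),
        "mrr": float((1.0 / ranks).mean()),
    }

# Toy check: 3 query pairs over 4 items
probs = np.array([
    [0.7, 0.1, 0.1, 0.1],   # true item 0, rank 1
    [0.1, 0.5, 0.3, 0.1],   # true item 2, rank 2
    [0.4, 0.3, 0.2, 0.1],   # true item 3, rank 4
])
labels = np.array([0, 2, 3])
metrics = topk_metrics(probs, labels, k=2)
```

A rank-1 hit counts toward all three metrics, so Accuracy@1 can never exceed Recall@k or MRR. The recorded validation numbers satisfy this (0.3062 ≤ 0.4003 and 0.3062 ≤ 0.4962), which is a quick sanity check on any reported run.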


In [None]:
# [CELL 07f-17] SAVE TRAINED MODEL

import torch
from pathlib import Path

# Model save path, anchored at the repo root so it does not depend on the CWD
MODEL_SAVE_PATH = PATHS["MODELS"] / "maml" / "maml_warmstart_gru_K5.pth"
MODEL_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

# Save model
torch.save({
    'model_state_dict': meta_model.state_dict(),
    'config': CFG,
    'final_metrics': {
        'zero_shot_acc1': zeroshot_metrics['accuracy@1'],
        'few_shot_K5_acc1': fewshot_metrics['accuracy@1'],
    },
    'training_iterations': num_meta_iterations,
}, MODEL_SAVE_PATH)

print(f"[CELL 07f-17] Model saved to: {MODEL_SAVE_PATH}")
print(f"[CELL 07f-17] File size: {MODEL_SAVE_PATH.stat().st_size / 1024:.1f} KB")
print(f"[CELL 07f-17] Model can now be loaded for inner_lr sweep tests!")
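
For the inner_lr sweep mentioned in the last print, the checkpoint has to be read back into a freshly constructed model. The sketch below shows one possible round trip under stated assumptions: `TinyGRU` is a hypothetical, much smaller stand-in for the notebook's model, the config values are illustrative, and the file goes to a temporary directory rather than `MODEL_SAVE_PATH`. `weights_only=False` is passed explicitly because recent PyTorch releases default to the stricter mode, which may refuse a checkpoint whose `config` entry holds non-tensor objects such as `Path` values.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

class TinyGRU(nn.Module):
    """Hypothetical stand-in for the notebook's GRU recommender."""
    def __init__(self, n_items: int = 10, hidden_dim: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(n_items, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_items)

model = TinyGRU()
ckpt_path = Path(tempfile.mkdtemp()) / "demo_checkpoint.pth"

# Same dictionary layout as the save cell, with illustrative values.
torch.save({
    "model_state_dict": model.state_dict(),
    "config": {"inner_lr": 0.05, "hidden_dim": 4},
    "training_iterations": 3,
}, ckpt_path)

# weights_only=False permits arbitrary pickled objects in the file,
# so use it only for checkpoints you created yourself.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

restored = TinyGRU(hidden_dim=ckpt["config"]["hidden_dim"])
restored.load_state_dict(ckpt["model_state_dict"])
restored.eval()

# Verify that every tensor survived the round trip unchanged.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    restored.state_dict().values())
)
```

Rebuilding the model from the stored config, rather than from notebook globals, keeps a sweep script independent of the cell execution order.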

## ✅ Notebook 07e Complete: MAML with Warm-Start

**Key Innovation:** Initialized MAML from pre-trained GRU baseline (33.73%)

**Results:**
- See CELL 07f-10 for zero-shot results
- See CELL 07f-11 for few-shot results (K=5)
- See CELL 07f-15 for complete comparison

**Expected Outcome:**
- MAML warm-start should beat both:
  - MAML random init (30.52%)
  - GRU baseline (33.73%)
- Target: 35-38% Acc@1

**Why This Works:**
1. Strong initialization (GRU) + Meta-learned adaptation (MAML)
2. Best of both worlds
3. Well-established approach in meta-learning literature

**Next Steps:**
- If beats baseline → Thesis complete! 🎉
- If close but not quite → Try bigger model (Fix #4)
- Analyze learned adaptations vs random init MAML
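
One possible way to start the adaptation analysis from the last step is to measure how far each adapted tensor moves away from the meta-learned initialization after the inner loop. The sketch below is a toy illustration under stated assumptions: `TinyGRU`, `forward_with`, and the parameter names (`embedding`, `gru`, `fc`) are hypothetical stand-ins, the support data is random, and only non-GRU parameters are adapted, following the layer-specific scheme described at the top of this notebook. It does not use the notebook's own `functional_forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyGRU(nn.Module):
    """Hypothetical stand-in for the notebook's GRU recommender."""
    def __init__(self, n_items: int = 12, hidden_dim: int = 8):
        super().__init__()
        self.embedding = nn.Embedding(n_items, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_items)

def forward_with(model, fast, seq):
    """Forward pass with fast weights for embedding/FC and the frozen GRU."""
    emb = F.embedding(seq, fast["embedding.weight"])
    out, _ = model.gru(emb)  # GRU keeps its meta-learned weights
    return F.linear(out[:, -1], fast["fc.weight"], fast["fc.bias"])

model = TinyGRU()
inner_lr, num_inner_steps = 0.05, 3

# Only non-GRU parameters become fast weights.
fast = {
    name: p.clone()
    for name, p in model.named_parameters()
    if not name.startswith("gru.")
}

support_seq = torch.randint(0, 12, (5, 6))   # 5 support pairs, length 6
support_labels = torch.randint(0, 12, (5,))

for _ in range(num_inner_steps):
    logits = forward_with(model, fast, support_seq)
    loss = F.cross_entropy(logits, support_labels)
    grads = torch.autograd.grad(loss, list(fast.values()))
    fast = {n: w - inner_lr * g for (n, w), g in zip(fast.items(), grads)}

# Relative L2 change of each adapted tensor w.r.t. its initialization.
meta = dict(model.named_parameters())
rel_change = {
    n: (fast[n] - meta[n]).norm().item() / meta[n].norm().item()
    for n in fast
}
```

Computing `rel_change` for both the warm-start and the random-init model on the same episodes would give a per-layer comparison of how much each initialization relies on adaptation.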