# Notebook 09: Recency-Weighted MAML (XuetangX) - Contribution 2

**Purpose:** Improve the MAML inner loop by weighting support pairs by recency, so that more recent pairs receive higher weight.

**Research Motivation:**
- In sequential recommendation, recent interactions better reflect current user preferences
- Standard MAML weights all K support pairs equally
- For cold-start users, the most recent interactions are most informative about current learning goals

**Key Insight:**
```
Standard MAML Inner Loop:
  Loss = (1/K) * sum(loss_i)           # Equal weights

Recency-Weighted MAML Inner Loop:
  Loss = sum(w_i * loss_i)             # w_i grows with recency, sum(w_i) = 1
  where w_i ∝ exp(lambda * recency_i)  # Exponential weighting, recency_i in [0, 1]
```

**Hypothesis:** Recency-weighted loss will improve adaptation by focusing on recent user behavior.
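
The two loss forms above differ only in the weight vector. A minimal sketch in plain Python shows that uniform weights recover the standard mean, while normalized exponential weights shift the total toward the newest pairs. The per-pair loss values and the helper names `exponential_weights` and `weighted_loss` are assumptions made for illustration and are not part of this notebook:

```python
import math
from typing import List


def exponential_weights(k: int, lam: float) -> List[float]:
    """Normalized weights w_i proportional to exp(lam * i / (k - 1)); index 0 is the oldest pair."""
    denom = max(k - 1, 1)  # avoid division by zero when k == 1
    raw = [math.exp(lam * i / denom) for i in range(k)]
    total = sum(raw)
    return [r / total for r in raw]


def weighted_loss(losses: List[float], weights: List[float]) -> float:
    """Weighted sum of per-pair losses."""
    return sum(w * l for w, l in zip(weights, losses))


# Hypothetical per-pair losses, ordered oldest -> newest (illustrative values only)
losses = [0.2, 0.4, 0.6, 0.8, 1.0]
k = len(losses)

uniform = [1.0 / k] * k
recency = exponential_weights(k, lam=0.5)

mean_loss = sum(losses) / k
uniform_loss = weighted_loss(losses, uniform)   # same value as the plain mean
recency_loss = weighted_loss(losses, recency)   # newest pairs count more

print(f"mean={mean_loss:.4f} uniform={uniform_loss:.4f} recency={recency_loss:.4f}")
```

Because the hypothetical losses grow toward the newest pair, the recency-weighted total comes out above the plain mean; with losses that shrink toward the newest pair it would come out below.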

**Experiments:**
1. Recency-Weighted only (random init)
2. Warm-Start + Recency-Weighted (combined)

**Inputs:**
- `data/processed/xuetangx/episodes/episodes_{train|val|test}_K5_Q10.parquet`
- `models/baselines/gru_global.pth` (for Warm-Start variant)
- `data/processed/xuetangx/vocab/course2id.json`

**Outputs:**
- `models/contributions/recency_maml_K5.pth`
- `models/contributions/warmstart_recency_maml_K5.pth`
- `results/recency_maml_K5_Q10.json`
- `reports/09_recency_weighted_maml_xuetangx/<run_tag>/report.json`

**Comparison Matrix:**
| Model | Init | Inner Loop | Expected Acc@1 |
|-------|------|------------|----------------|
| Vanilla MAML | Random | Equal weights | 28.66% |
| Warm-Start MAML | GRU4Rec | Equal weights | ??% (from NB08) |
| Recency MAML | Random | Recency weights | ??% |
| **Warm-Start + Recency** | GRU4Rec | Recency weights | **??%** |

In [None]:
# [CELL 09-00] Bootstrap: repo root + paths + logger

import os
import sys
import json
import time
import uuid
import copy
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Tuple, Optional
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

t0 = datetime.now()
print(f"[CELL 09-00] start={t0.isoformat(timespec='seconds')}")
print("[CELL 09-00] CWD:", Path.cwd().resolve())

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find repo root (no PROJECT_STATE.md)")

REPO_ROOT = find_repo_root(Path.cwd())
print(f"[CELL 09-00] REPO_ROOT: {REPO_ROOT}")

sys.path.insert(0, str(REPO_ROOT / "src"))

META_REGISTRY = REPO_ROOT / "meta.json"
DATA_INTERIM = REPO_ROOT / "data" / "interim"
DATA_PROCESSED = REPO_ROOT / "data" / "processed"
MODELS = REPO_ROOT / "models"
RESULTS = REPO_ROOT / "results"
REPORTS = REPO_ROOT / "reports"

PATHS = {
    "REPO_ROOT": REPO_ROOT,
    "META_REGISTRY": META_REGISTRY,
    "DATA_INTERIM": DATA_INTERIM,
    "DATA_PROCESSED": DATA_PROCESSED,
    "MODELS": MODELS,
    "RESULTS": RESULTS,
    "REPORTS": REPORTS,
}

for name, path in PATHS.items():
    print(f"[CELL 09-00] {name}={path}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[CELL 09-00] PyTorch device: {DEVICE}")

def cell_start(cell_id: str, description: str = "") -> datetime:
    t = datetime.now()
    msg = f"[{cell_id}] start={t.isoformat(timespec='seconds')}"
    if description:
        msg += f" | {description}"
    print(msg)
    return t

def cell_end(cell_id: str, t_start: datetime, **kv) -> None:
    elapsed = (datetime.now() - t_start).total_seconds()
    for k, v in kv.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] done in {elapsed:.1f}s")

print("[CELL 09-00] done")

In [None]:
# [CELL 09-01] Configuration

t0 = cell_start("CELL 09-01", "Configuration")

RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
print(f"[CELL 09-01] RUN_TAG: {RUN_TAG}")

CFG = {
    "notebook": "09_recency_weighted_maml_xuetangx",
    "run_tag": RUN_TAG,
    "dataset": "xuetangx",
    "contribution": "recency_weighted_maml",
    
    # Episode config
    "episode_config": {
        "K": 5,   # support set size
        "Q": 10,  # query set size
    },
    
    # Model config (same as GRU4Rec baseline)
    "model_config": {
        "embed_dim": 64,
        "hidden_dim": 64,
        "n_layers": 1,
        "dropout": 0.1,
    },
    
    # MAML config
    "maml_config": {
        "inner_lr": 0.01,
        "outer_lr": 0.001,
        "inner_steps": 5,
        "meta_batch_size": 32,
        "num_meta_iterations": 3000,
        "use_second_order": False,  # FOMAML
        "val_every": 100,
        "checkpoint_every": 500,
    },
    
    # Recency-Weighted config (NEW - CONTRIBUTION 2)
    "recency_config": {
        "weighting_scheme": "exponential",  # exponential, linear, softmax
        "lambda": 0.5,                       # decay parameter for exponential
        "temperature": 1.0,                  # temperature for softmax
    },
    
    # Warm-Start config (can be combined)
    "warmstart_config": {
        "use_warmstart": True,  # True = Warm-Start + Recency (Experiment 2); False = Recency only (Experiment 1)
        "pretrained_path": "models/baselines/gru_global.pth",
    },
    
    "seed": 42,
}

# Set seeds
np.random.seed(CFG["seed"])
torch.manual_seed(CFG["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CFG["seed"])

print(f"[CELL 09-01] Configuration:")
print(f"  - Contribution: {CFG['contribution']}")
print(f"  - Recency weighting: {CFG['recency_config']['weighting_scheme']}")
print(f"  - Lambda: {CFG['recency_config']['lambda']}")
print(f"  - Use Warm-Start: {CFG['warmstart_config']['use_warmstart']}")
print(f"  - Episode: K={CFG['episode_config']['K']}, Q={CFG['episode_config']['Q']}")
print(f"  - Meta iterations: {CFG['maml_config']['num_meta_iterations']}")

cell_end("CELL 09-01", t0)

In [None]:
# [CELL 09-02] Setup paths

t0 = cell_start("CELL 09-02", "Setup paths")

# Input paths
EPISODES_DIR = DATA_PROCESSED / "xuetangx" / "episodes"
VOCAB_DIR = DATA_PROCESSED / "xuetangx" / "vocab"
PRETRAINED_PATH = REPO_ROOT / CFG["warmstart_config"]["pretrained_path"]

K = CFG["episode_config"]["K"]
Q = CFG["episode_config"]["Q"]

EPISODES_TRAIN = EPISODES_DIR / f"episodes_train_K{K}_Q{Q}.parquet"
EPISODES_VAL = EPISODES_DIR / f"episodes_val_K{K}_Q{Q}.parquet"
EPISODES_TEST = EPISODES_DIR / f"episodes_test_K{K}_Q{Q}.parquet"
COURSE2ID_PATH = VOCAB_DIR / "course2id.json"

# Output paths
CONTRIB_MODELS_DIR = MODELS / "contributions"
CONTRIB_MODELS_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_DIR = CONTRIB_MODELS_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

REPORT_DIR = REPORTS / "09_recency_weighted_maml_xuetangx" / RUN_TAG
REPORT_DIR.mkdir(parents=True, exist_ok=True)

# Model name depends on whether we use warm-start
if CFG["warmstart_config"]["use_warmstart"]:
    model_name = f"warmstart_recency_maml_K{K}"
else:
    model_name = f"recency_maml_K{K}"

OUT_MODEL = CONTRIB_MODELS_DIR / f"{model_name}.pth"
OUT_RESULTS = RESULTS / f"{model_name}_Q{Q}.json"
REPORT_PATH = REPORT_DIR / "report.json"

print(f"[CELL 09-02] Input episodes: {EPISODES_TRAIN}")
print(f"[CELL 09-02] Output model: {OUT_MODEL}")
print(f"[CELL 09-02] Use Warm-Start: {CFG['warmstart_config']['use_warmstart']}")

cell_end("CELL 09-02", t0)

In [None]:
# [CELL 09-03] Load data

t0 = cell_start("CELL 09-03", "Load episodes and vocabulary")

# Load vocabulary
with open(COURSE2ID_PATH, "r") as f:
    course2id = json.load(f)
n_items = len(course2id)
print(f"[CELL 09-03] Vocabulary size: {n_items} courses")

# Load episodes
episodes_train = pd.read_parquet(EPISODES_TRAIN)
episodes_val = pd.read_parquet(EPISODES_VAL)
episodes_test = pd.read_parquet(EPISODES_TEST)

print(f"[CELL 09-03] Train episodes: {len(episodes_train):,}")
print(f"[CELL 09-03] Val episodes:   {len(episodes_val):,}")
print(f"[CELL 09-03] Test episodes:  {len(episodes_test):,}")

# Check if timestamps are available
print(f"[CELL 09-03] Episode columns: {list(episodes_train.columns)}")

cell_end("CELL 09-03", t0, n_items=n_items)

In [None]:
# [CELL 09-04] Define Recency Weighting Functions (KEY CONTRIBUTION)

t0 = cell_start("CELL 09-04", "Define recency weighting functions")

def compute_recency_weights_exponential(K: int, lam: float = 0.5) -> torch.Tensor:
    """Exponential recency weights.
    
    w_i = exp(lambda * (i / (K-1)))  for i in [0, K-1]
    Normalized to sum to 1.
    
    Args:
        K: Number of support pairs
        lam: Decay parameter (higher = more weight on recent)
        
    Returns:
        weights: [K] tensor, normalized
        
    Example (K=5, lam=0.5):
        positions: [0, 1, 2, 3, 4] (0=oldest, 4=newest)
        raw: [1.0, 1.13, 1.28, 1.45, 1.65]
        normalized: [0.15, 0.17, 0.20, 0.22, 0.25]
    """
    positions = torch.arange(K, dtype=torch.float32)  # [0, 1, ..., K-1]
    normalized_positions = positions / (K - 1) if K > 1 else positions  # [0, ..., 1]
    
    weights = torch.exp(lam * normalized_positions)
    weights = weights / weights.sum()  # Normalize to sum to 1
    
    return weights

def compute_recency_weights_linear(K: int) -> torch.Tensor:
    """Linear recency weights.
    
    w_i = (i + 1) / sum(1..K)
    
    Example (K=5):
        positions: [1, 2, 3, 4, 5]
        normalized: [0.067, 0.133, 0.2, 0.267, 0.333]
    """
    positions = torch.arange(1, K + 1, dtype=torch.float32)  # [1, 2, ..., K]
    weights = positions / positions.sum()
    return weights

def compute_recency_weights_softmax(K: int, temperature: float = 1.0) -> torch.Tensor:
    """Softmax recency weights.
    
    w = softmax(positions / temperature)
    
    Example (K=5, temp=1.0):
        positions: [0, 1, 2, 3, 4]
        softmax: [0.01, 0.03, 0.09, 0.23, 0.64]
    """
    positions = torch.arange(K, dtype=torch.float32)
    weights = F.softmax(positions / temperature, dim=0)
    return weights

# Demonstrate the weighting schemes
K = CFG["episode_config"]["K"]
lam = CFG["recency_config"]["lambda"]
temp = CFG["recency_config"]["temperature"]

print(f"[CELL 09-04] Recency Weighting Schemes (K={K}):")
print(f"")
print(f"  Position:     [oldest]  →  [newest]")
print(f"  Index:        {list(range(K))}")
print(f"")

w_equal = torch.ones(K) / K
w_exp = compute_recency_weights_exponential(K, lam)
w_linear = compute_recency_weights_linear(K)
w_softmax = compute_recency_weights_softmax(K, temp)

print(f"  Equal:       {[f'{w:.3f}' for w in w_equal.tolist()]}")
print(f"  Exponential: {[f'{w:.3f}' for w in w_exp.tolist()]} (lambda={lam})")
print(f"  Linear:      {[f'{w:.3f}' for w in w_linear.tolist()]}")
print(f"  Softmax:     {[f'{w:.3f}' for w in w_softmax.tolist()]} (temp={temp})")

print(f"")
print(f"  Selected scheme: {CFG['recency_config']['weighting_scheme']}")

cell_end("CELL 09-04", t0)
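
One way to compare the three schemes is the ratio between the newest and the oldest weight. The sketch below re-implements the formulas from this cell in plain Python (the helper names are illustrative, not the notebook's functions) and suggests that the ratio is `exp(lambda)` for the exponential scheme, `K` for the linear scheme, and `exp((K - 1) / temperature)` for the softmax scheme:

```python
import math
from typing import Dict, List


def _normalize(raw: List[float]) -> List[float]:
    total = sum(raw)
    return [r / total for r in raw]


def exp_weights(k: int, lam: float) -> List[float]:
    # Positions are rescaled to [0, 1], so lam sets the total spread
    denom = max(k - 1, 1)
    return _normalize([math.exp(lam * i / denom) for i in range(k)])


def linear_weights(k: int) -> List[float]:
    return _normalize([float(i + 1) for i in range(k)])


def softmax_weights(k: int, temperature: float) -> List[float]:
    # Positions are NOT rescaled here, so the spread grows with k
    return _normalize([math.exp(i / temperature) for i in range(k)])


def newest_to_oldest(weights: List[float]) -> float:
    return weights[-1] / weights[0]


k = 5
ratios: Dict[str, float] = {
    "exponential": newest_to_oldest(exp_weights(k, lam=0.5)),
    "linear": newest_to_oldest(linear_weights(k)),
    "softmax": newest_to_oldest(softmax_weights(k, temperature=1.0)),
}

for name, ratio in ratios.items():
    print(f"{name:12s} newest/oldest = {ratio:.2f}")
```

With the default settings (K=5, lambda=0.5, temperature=1.0) the softmax scheme is by far the most aggressive, which is worth keeping in mind when reading the ablation table.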

In [None]:
# [CELL 09-05] Visualize Recency Weights

t0 = cell_start("CELL 09-05", "Visualize recency weights")

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Exponential with different lambdas
ax = axes[0]
for lam_val in [0.1, 0.3, 0.5, 1.0, 2.0]:
    w = compute_recency_weights_exponential(K, lam_val)
    ax.plot(range(K), w.numpy(), marker='o', label=f'λ={lam_val}')
ax.set_xlabel('Support Pair Index (0=oldest)')
ax.set_ylabel('Weight')
ax.set_title('Exponential Weighting')
ax.legend()
ax.grid(True, alpha=0.3)

# Linear
ax = axes[1]
w_equal = torch.ones(K) / K
w_linear = compute_recency_weights_linear(K)
ax.bar(range(K), w_equal.numpy(), alpha=0.5, label='Equal')
ax.bar(range(K), w_linear.numpy(), alpha=0.5, label='Linear')
ax.set_xlabel('Support Pair Index (0=oldest)')
ax.set_ylabel('Weight')
ax.set_title('Linear vs Equal Weighting')
ax.legend()
ax.grid(True, alpha=0.3)

# Softmax with different temperatures
ax = axes[2]
for temp_val in [0.5, 1.0, 2.0, 5.0]:
    w = compute_recency_weights_softmax(K, temp_val)
    ax.plot(range(K), w.numpy(), marker='o', label=f'τ={temp_val}')
ax.set_xlabel('Support Pair Index (0=oldest)')
ax.set_ylabel('Weight')
ax.set_title('Softmax Weighting')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(REPORT_DIR / "recency_weights_visualization.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"[CELL 09-05] Visualization saved to {REPORT_DIR / 'recency_weights_visualization.png'}")

cell_end("CELL 09-05", t0)

In [None]:
# [CELL 09-06] Define GRU4Rec model

t0 = cell_start("CELL 09-06", "Define GRU4Rec model")

class GRURecommender(nn.Module):
    """GRU4Rec model for sequential recommendation."""
    def __init__(self, n_items: int, embed_dim: int, hidden_dim: int, n_layers: int = 1, dropout: float = 0.1):
        super().__init__()
        self.n_items = n_items
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(n_items, embed_dim, padding_idx=0)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_dim, n_items)
        
    def forward(self, x: torch.Tensor, lengths: torch.Tensor = None) -> torch.Tensor:
        embedded = self.embedding(x)
        
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            output, hidden = self.gru(packed)
        else:
            output, hidden = self.gru(embedded)
        
        last_hidden = hidden[-1]
        out = self.dropout(last_hidden)
        logits = self.output(out)
        
        return logits

print(f"[CELL 09-06] GRURecommender defined")

cell_end("CELL 09-06", t0)

In [None]:
# [CELL 09-07] Initialize model (with optional Warm-Start)

t0 = cell_start("CELL 09-07", "Initialize model")

meta_model = GRURecommender(
    n_items=n_items,
    embed_dim=CFG["model_config"]["embed_dim"],
    hidden_dim=CFG["model_config"]["hidden_dim"],
    n_layers=CFG["model_config"]["n_layers"],
    dropout=CFG["model_config"]["dropout"],
).to(DEVICE)

# Optional Warm-Start
if CFG["warmstart_config"]["use_warmstart"]:
    print(f"[CELL 09-07] Loading pre-trained GRU4Rec for Warm-Start...")
    pretrained_state = torch.load(PRETRAINED_PATH, map_location=DEVICE)
    meta_model.load_state_dict(pretrained_state)
    print(f"[CELL 09-07] Warm-Start: YES (from {PRETRAINED_PATH})")
else:
    print(f"[CELL 09-07] Warm-Start: NO (random initialization)")

n_params = sum(p.numel() for p in meta_model.parameters())
print(f"[CELL 09-07] Model parameters: {n_params:,}")

meta_optimizer = torch.optim.Adam(meta_model.parameters(), lr=CFG["maml_config"]["outer_lr"])

cell_end("CELL 09-07", t0)

In [None]:
# [CELL 09-08] Helper functions

t0 = cell_start("CELL 09-08", "Define helper functions")

def pad_sequences(sequences: List[List[int]], pad_value: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad variable-length sequences."""
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    max_len = max(lengths).item()
    
    padded = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    
    return padded, lengths

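def functional_forward(model: nn.Module, x: torch.Tensor, lengths: torch.Tensor, 
                       params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward pass with external parameters (for MAML inner loop).

    Mirrors the single-layer nn.GRU gate equations and, like
    pack_padded_sequence in GRURecommender.forward, stops updating each
    sequence's hidden state at its true length. Dropout is not applied here.
    """
    embedded = F.embedding(x, params["embedding.weight"], padding_idx=0)
    
    batch_size = x.size(0)
    hidden_dim = params["gru.weight_hh_l0"].size(1)
    h = torch.zeros(batch_size, hidden_dim, device=x.device)
    
    weight_ih = params["gru.weight_ih_l0"]
    weight_hh = params["gru.weight_hh_l0"]
    bias_ih = params["gru.bias_ih_l0"]
    bias_hh = params["gru.bias_hh_l0"]
    
    for t in range(embedded.size(1)):
        inp = embedded[:, t, :]
        # Keep input and hidden projections separate: the reset gate scales
        # only the hidden contribution to the candidate state.
        i_r, i_z, i_n = (inp @ weight_ih.t() + bias_ih).chunk(3, dim=1)
        h_r, h_z, h_n = (h @ weight_hh.t() + bias_hh).chunk(3, dim=1)
        
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h_new = (1 - z) * n + z * h
        
        # Freeze the hidden state once a sequence has run into padding
        still_active = (t < lengths).unsqueeze(1)
        h = torch.where(still_active, h_new, h)
    
    logits = h @ params["output.weight"].t() + params["output.bias"]
    
    return logits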

def get_recency_weights(K: int, config: dict) -> torch.Tensor:
    """Get recency weights based on config."""
    scheme = config["weighting_scheme"]
    
    if scheme == "exponential":
        return compute_recency_weights_exponential(K, config["lambda"])
    elif scheme == "linear":
        return compute_recency_weights_linear(K)
    elif scheme == "softmax":
        return compute_recency_weights_softmax(K, config["temperature"])
    else:
        raise ValueError(f"Unknown weighting scheme: {scheme}")

print(f"[CELL 09-08] Helper functions defined")

cell_end("CELL 09-08", t0)

In [None]:
# [CELL 09-09] MAML Training with Recency-Weighted Inner Loop (KEY CONTRIBUTION)

t0 = cell_start("CELL 09-09", "MAML Training with Recency-Weighted Inner Loop")

# Training config
inner_lr = CFG["maml_config"]["inner_lr"]
inner_steps = CFG["maml_config"]["inner_steps"]
meta_batch_size = CFG["maml_config"]["meta_batch_size"]
num_iterations = CFG["maml_config"]["num_meta_iterations"]
val_every = CFG["maml_config"]["val_every"]
checkpoint_every = CFG["maml_config"]["checkpoint_every"]

K = CFG["episode_config"]["K"]

# Pre-compute recency weights
recency_weights = get_recency_weights(K, CFG["recency_config"]).to(DEVICE)
print(f"[CELL 09-09] Recency weights: {recency_weights.tolist()}")

print(f"[CELL 09-09] Training config:")
print(f"  - Recency weighting: {CFG['recency_config']['weighting_scheme']}")
print(f"  - Warm-Start: {CFG['warmstart_config']['use_warmstart']}")
print(f"  - Inner steps: {inner_steps}")
print(f"  - Iterations: {num_iterations}")

history = {"train_loss": [], "val_acc": [], "iteration": []}
best_val_acc = 0.0

meta_model.train()

for iteration in range(1, num_iterations + 1):
    meta_optimizer.zero_grad()
    
    task_indices = np.random.choice(len(episodes_train), size=meta_batch_size, replace=False)
    
    meta_loss = 0.0
    
    for task_idx in task_indices:
        row = episodes_train.iloc[task_idx]
        
        support_x_raw = row["support_prefixes"]
        support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
        query_x_raw = row["query_prefixes"]
        query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
        
        support_x, support_lengths = pad_sequences(support_x_raw)
        support_x = support_x.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        
        query_x, query_lengths = pad_sequences(query_x_raw)
        query_x = query_x.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        
        params = {name: param.clone() for name, param in meta_model.named_parameters()}
        
        # Inner loop with RECENCY-WEIGHTED LOSS (KEY CONTRIBUTION)
        for _ in range(inner_steps):
            support_logits = functional_forward(meta_model, support_x, support_lengths, params)
            
            # ============================================================
            # KEY CONTRIBUTION: Recency-weighted loss
            # Instead of: loss = F.cross_entropy(logits, labels)  # equal weights
            # We use:     loss = sum(w_i * loss_i)                 # recency weights
            # ============================================================
            per_sample_loss = F.cross_entropy(support_logits, support_y, reduction='none')
            support_loss = (recency_weights * per_sample_loss).sum()
            # ============================================================
            
            grads = torch.autograd.grad(support_loss, params.values(), create_graph=False)
            
            params = {
                name: param - inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }
        
        query_logits = functional_forward(meta_model, query_x, query_lengths, params)
        query_loss = F.cross_entropy(query_logits, query_y)
        
        meta_loss += query_loss
    
    meta_loss = meta_loss / meta_batch_size
    
    meta_loss.backward()
    torch.nn.utils.clip_grad_norm_(meta_model.parameters(), max_norm=10.0)
    meta_optimizer.step()
    
    if iteration % 100 == 0 or iteration == 1:
        print(f"[CELL 09-09] Iteration {iteration}/{num_iterations}, Meta-Loss: {meta_loss.item():.4f}")
        history["train_loss"].append(meta_loss.item())
        history["iteration"].append(iteration)
    
    if iteration % val_every == 0:
        meta_model.eval()
        val_correct = 0
        val_total = 0
        
        # Inner-loop adaptation calls torch.autograd.grad, so gradients must stay enabled
        with torch.enable_grad():
            for val_idx in range(len(episodes_val)):
                row = episodes_val.iloc[val_idx]
                
                support_x_raw = row["support_prefixes"]
                support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
                query_x_raw = row["query_prefixes"]
                query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
                
                support_x, support_lengths = pad_sequences(support_x_raw)
                support_x = support_x.to(DEVICE)
                support_lengths = support_lengths.to(DEVICE)
                
                query_x, query_lengths = pad_sequences(query_x_raw)
                query_x = query_x.to(DEVICE)
                query_lengths = query_lengths.to(DEVICE)
                
                params = {name: param.clone() for name, param in meta_model.named_parameters()}
                
                for _ in range(inner_steps):
                    support_logits = functional_forward(meta_model, support_x, support_lengths, params)
                    per_sample_loss = F.cross_entropy(support_logits, support_y, reduction='none')
                    support_loss = (recency_weights * per_sample_loss).sum()
                    grads = torch.autograd.grad(support_loss, params.values())
                    params = {
                        name: param - inner_lr * grad
                        for (name, param), grad in zip(params.items(), grads)
                    }
                
                query_logits = functional_forward(meta_model, query_x, query_lengths, params)
                preds = query_logits.argmax(dim=-1)
                
                val_correct += (preds == query_y).sum().item()
                val_total += len(query_y)
        
        val_acc = val_correct / val_total
        history["val_acc"].append(val_acc)
        print(f"[CELL 09-09] Iteration {iteration}, Val Acc@1: {val_acc:.4f} ({val_acc*100:.2f}%)")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(meta_model.state_dict(), OUT_MODEL)
            print(f"[CELL 09-09] New best model saved! Val Acc: {val_acc*100:.2f}%")
        
        meta_model.train()
    
    if iteration % checkpoint_every == 0:
        checkpoint_path = CHECKPOINT_DIR / f"recency_checkpoint_iter{iteration}.pth"
        torch.save({
            "iteration": iteration,
            "model_state_dict": meta_model.state_dict(),
            "optimizer_state_dict": meta_optimizer.state_dict(),
            "best_val_acc": best_val_acc,
            "config": CFG,
        }, checkpoint_path)

print(f"\n[CELL 09-09] Training complete!")
print(f"[CELL 09-09] Best validation accuracy: {best_val_acc*100:.2f}%")

cell_end("CELL 09-09", t0, best_val_acc=best_val_acc)
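
To see what the recency-weighted inner loop does in isolation, here is a toy sketch that replaces the GRU with a single scalar parameter and squared-error losses. Both are illustrative assumptions, not the model trained above, and the learning rate and step count are chosen so the toy converges rather than to match `CFG`. With normalized weights, repeated gradient steps move the parameter toward the weighted mean of the support targets, so the newest targets pull harder:

```python
from typing import List


def adapt(theta: float, targets: List[float], weights: List[float],
          inner_lr: float = 0.1, inner_steps: int = 50) -> float:
    """Gradient steps on the weighted loss sum_i w_i * (theta - y_i)^2."""
    for _ in range(inner_steps):
        # d/dtheta of the weighted squared error
        grad = sum(w * 2.0 * (theta - y) for w, y in zip(weights, targets))
        theta = theta - inner_lr * grad
    return theta


# Hypothetical support targets, ordered oldest -> newest:
# the "user" switched from 0.0 to 1.0 in the two most recent pairs.
targets = [0.0, 0.0, 0.0, 1.0, 1.0]

uniform = [0.2] * 5                        # equal weights
linear = [i / 15.0 for i in range(1, 6)]   # linear recency weights for K=5

theta_uniform = adapt(0.0, targets, uniform)
theta_recency = adapt(0.0, targets, linear)

print(f"equal weights   -> theta = {theta_uniform:.3f}")
print(f"recency weights -> theta = {theta_recency:.3f}")
```

In this toy the equal-weight run settles near the plain mean of the targets (0.4) and the recency-weighted run near the weighted mean (0.6), that is, closer to the most recent behavior.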

In [None]:
# [CELL 09-10] Final Evaluation on Test Set

t0 = cell_start("CELL 09-10", "Final evaluation on test set")

meta_model.load_state_dict(torch.load(OUT_MODEL, map_location=DEVICE))
meta_model.eval()

print(f"[CELL 09-10] Loaded best model from: {OUT_MODEL}")

# Zero-shot and Few-shot evaluation
zeroshot_correct = 0
zeroshot_total = 0
fewshot_correct = 0
fewshot_total = 0

# Few-shot adaptation calls torch.autograd.grad, so gradients must stay enabled
with torch.enable_grad():
    for test_idx in range(len(episodes_test)):
        row = episodes_test.iloc[test_idx]
        
        support_x_raw = row["support_prefixes"]
        support_y = torch.tensor(row["support_labels"], dtype=torch.long, device=DEVICE)
        query_x_raw = row["query_prefixes"]
        query_y = torch.tensor(row["query_labels"], dtype=torch.long, device=DEVICE)
        
        support_x, support_lengths = pad_sequences(support_x_raw)
        support_x = support_x.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        
        query_x, query_lengths = pad_sequences(query_x_raw)
        query_x = query_x.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        
        # Zero-shot
        params_zs = {name: param.clone() for name, param in meta_model.named_parameters()}
        query_logits_zs = functional_forward(meta_model, query_x, query_lengths, params_zs)
        preds_zs = query_logits_zs.argmax(dim=-1)
        zeroshot_correct += (preds_zs == query_y).sum().item()
        zeroshot_total += len(query_y)
        
        # Few-shot with recency-weighted adaptation
        params_fs = {name: param.clone() for name, param in meta_model.named_parameters()}
        
        for _ in range(inner_steps):
            support_logits = functional_forward(meta_model, support_x, support_lengths, params_fs)
            per_sample_loss = F.cross_entropy(support_logits, support_y, reduction='none')
            support_loss = (recency_weights * per_sample_loss).sum()
            grads = torch.autograd.grad(support_loss, params_fs.values())
            params_fs = {
                name: param - inner_lr * grad
                for (name, param), grad in zip(params_fs.items(), grads)
            }
        
        query_logits_fs = functional_forward(meta_model, query_x, query_lengths, params_fs)
        preds_fs = query_logits_fs.argmax(dim=-1)
        fewshot_correct += (preds_fs == query_y).sum().item()
        fewshot_total += len(query_y)

zeroshot_acc = zeroshot_correct / zeroshot_total
fewshot_acc = fewshot_correct / fewshot_total

model_type = "Warm-Start + Recency" if CFG["warmstart_config"]["use_warmstart"] else "Recency Only"

print(f"\n[CELL 09-10] ========== RESULTS ({model_type}) ==========")
print(f"[CELL 09-10] Test episodes: {len(episodes_test)}")
print(f"[CELL 09-10] Recency weighting: {CFG['recency_config']['weighting_scheme']}")
print(f"\n[CELL 09-10] {model_type} Zero-shot: {zeroshot_acc:.4f} ({zeroshot_acc*100:.2f}%)")
print(f"[CELL 09-10] {model_type} Few-shot:  {fewshot_acc:.4f} ({fewshot_acc*100:.2f}%)")
print(f"\n[CELL 09-10] ========== COMPARISON ==========")
print(f"[CELL 09-10] GRU4Rec baseline:          33.55%")
print(f"[CELL 09-10] Vanilla MAML Few-shot:     28.66%")
print(f"[CELL 09-10] {model_type} Few-shot:  {fewshot_acc*100:.2f}%")
print(f"\n[CELL 09-10] Improvement over Vanilla MAML: {(fewshot_acc - 0.2866)*100:+.2f} pp")
print(f"[CELL 09-10] Improvement over GRU4Rec:      {(fewshot_acc - 0.3355)*100:+.2f} pp")

cell_end("CELL 09-10", t0, zeroshot_acc=zeroshot_acc, fewshot_acc=fewshot_acc)

In [None]:
# [CELL 09-11] Save results and report

t0 = cell_start("CELL 09-11", "Save results and report")

model_type = "warmstart_recency" if CFG["warmstart_config"]["use_warmstart"] else "recency_only"

results = {
    "model": model_type,
    "contribution": "recency_weighted_inner_loop",
    "dataset": "xuetangx",
    "config": CFG,
    "metrics": {
        "zeroshot": {"accuracy@1": zeroshot_acc},
        "fewshot": {"accuracy@1": fewshot_acc},
    },
    "comparison": {
        "gru4rec_baseline": 0.3355,
        "vanilla_maml_fewshot": 0.2866,
        f"{model_type}_fewshot": fewshot_acc,
    },
    "improvement": {
        "over_vanilla_maml": fewshot_acc - 0.2866,
        "over_gru4rec": fewshot_acc - 0.3355,
    },
}

with open(OUT_RESULTS, "w") as f:
    json.dump(results, f, indent=2)
print(f"[CELL 09-11] Results saved: {OUT_RESULTS}")

report = {
    "notebook": CFG["notebook"],
    "run_tag": RUN_TAG,
    "timestamp": datetime.now().isoformat(),
    "config": CFG,
    "results": results,
    "history": history,
}

with open(REPORT_PATH, "w") as f:
    json.dump(report, f, indent=2)
print(f"[CELL 09-11] Report saved: {REPORT_PATH}")

cell_end("CELL 09-11", t0)

## Notebook 09 Complete: Recency-Weighted MAML Results

**Contribution:** Weight support pairs by recency in the MAML inner loop, so that recent pairs receive higher weight.

**Key Results:**

| Model | Init | Inner Loop | Acc@1 | vs Vanilla |
|-------|------|------------|-------|------------|
| GRU4Rec baseline | Trained | N/A | 33.55% | - |
| Vanilla MAML | Random | Equal weights | 28.66% | baseline |
| Warm-Start MAML | GRU4Rec | Equal weights | ??% | +?? pp |
| Recency MAML | Random | Recency weights | ??% | +?? pp |
| **Warm-Start + Recency** | GRU4Rec | Recency weights | **??%** | **+?? pp** |

**Ablation: Weighting Schemes**

| Scheme | Lambda/Temp | Acc@1 |
|--------|-------------|-------|
| Equal (baseline) | - | 28.66% |
| Exponential | λ=0.3 | ??% |
| Exponential | λ=0.5 | ??% |
| Exponential | λ=1.0 | ??% |
| Linear | - | ??% |
| Softmax | τ=1.0 | ??% |

**Key Insight:** Recent learning activities better reflect current user preferences in MOOCs.
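
The "vs Vanilla" column is a difference in percentage points against the 28.66% Vanilla MAML baseline. A small sketch of that arithmetic with a signed format follows. The helper `pp_delta` is hypothetical, and the accuracies passed to it are placeholders, not measured results:

```python
VANILLA_MAML_ACC = 0.2866  # Vanilla MAML few-shot Acc@1 from the table above


def pp_delta(acc: float, baseline: float = VANILLA_MAML_ACC) -> str:
    """Signed difference in percentage points, e.g. '+1.34 pp' or '-3.66 pp'."""
    return f"{(acc - baseline) * 100:+.2f} pp"


# Placeholder accuracies for illustration only (NOT results from this notebook)
for placeholder_acc in (0.3000, 0.2500):
    print(f"Acc@1={placeholder_acc * 100:.2f}% -> {pp_delta(placeholder_acc)}")
```

Using the `+` format flag keeps the sign correct when a variant falls below the baseline.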