# Notebook 11: Reliability-Weighted MAML (XuetangX)

## Contribution 3: Session Reliability for Meta-Learning Adaptation

**Core Idea:** Sessions with higher reliability provide higher-confidence evidence for MAML adaptation, so each support sample's inner-loop loss is weighted by the reliability score of its session.

### Key Modification
```python
# Standard MAML (NB07):
support_loss = criterion(logits, labels).mean()

# Reliability-Weighted MAML (this notebook):
per_sample_loss = criterion_none(logits, labels)  # reduction='none'
support_loss = (reliability_weights * per_sample_loss).sum() / reliability_weights.sum()
```
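
The snippet above is schematic, since `logits`, `labels`, and `reliability_weights` are not defined there. A minimal runnable sketch of the same normalization, using toy tensors that are assumptions for illustration only, shows that uniform weights recover the standard mean loss and that a one-hot weight vector reduces to a single sample's loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 4)              # toy data: K=5 support samples, 4 classes
labels = torch.tensor([0, 1, 2, 3, 0])  # toy targets

per_sample_loss = nn.CrossEntropyLoss(reduction="none")(logits, labels)  # shape (5,)
mean_loss = nn.CrossEntropyLoss()(logits, labels)                        # scalar


def weighted_loss(weights: torch.Tensor) -> torch.Tensor:
    """Normalized weighted average of the per-sample losses."""
    return (weights * per_sample_loss).sum() / weights.sum()


uniform_loss = weighted_loss(torch.ones(5))                       # equals mean_loss
one_hot_loss = weighted_loss(torch.tensor([1.0, 0, 0, 0, 0]))     # equals per_sample_loss[0]
```

Because the weights are normalized by their sum, only their relative sizes within one support set matter, not their absolute scale.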

### Reliability Score Definition
```
reliability = (intensity + extent + composition) / 3

where:
  intensity   = min(n_events / 100, 1.0)      # Event count
  extent      = min(duration_sec / 1800, 1.0) # Session duration (capped at 30 min)
  composition = n_action_types / 8            # Behavioral diversity
```
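
As a sketch of how the definition above could be computed for one session: the function name and keyword arguments below are hypothetical (the actual NB03b implementation is not shown here), and the divisor of 8 assumes eight possible action types.

```python
def session_reliability(
    n_events: int,
    duration_sec: float,
    n_action_types: int,
    max_events: int = 100,            # cap from the definition above
    max_duration_sec: float = 1800.0, # 30 minutes
    n_possible_action_types: int = 8, # assumed size of the action vocabulary
) -> float:
    """Average of the three capped components, each in [0, 1]."""
    intensity = min(n_events / max_events, 1.0)
    extent = min(duration_sec / max_duration_sec, 1.0)
    composition = n_action_types / n_possible_action_types
    return (intensity + extent + composition) / 3


full = session_reliability(n_events=250, duration_sec=3600, n_action_types=8)
half = session_reliability(n_events=50, duration_sec=900, n_action_types=4)
short = session_reliability(n_events=5, duration_sec=60, n_action_types=1)
```

A session that saturates all three components scores 1.0, while a very short session with one action type lands near the bottom of the observed range.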

### Results: Comparison with Vanilla MAML (NB07)

| Method | Test HR@10 | Test NDCG@10 |
|--------|------------|--------------|
| Vanilla MAML (NB07) | 47.35% | 37.41% |
| **Reliability-Weighted MAML (NB11)** | **48.34%** | **37.71%** |
| **Improvement (percentage points)** | **+0.99** | **+0.30** |

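**Finding:** In this run, reliability weighting improves cold-start recommendation on the test set by 0.99 points of HR@10 and 0.30 points of NDCG@10 over vanilla MAML, consistent with giving more weight to high-quality sessions during adaptation.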

### Dataset (XuetangX)
- Training episodes: 47,357 | Validation: 341 | Test: 313
- Pairs with reliability: 281,979
- Reliability score range: [0.0485, 1.0000]
- Vocabulary: 1,518 courses

### Inputs
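- `data/processed/xuetangx/pairs_with_reliability/pairs.parquet` (from NB03b)
- `data/processed/xuetangx/episodes/episodes_*.parquet` (from NB05)
- `data/processed/xuetangx/vocab/course2id.json`
- `models/gru/gru4rec_pretrained.pt` (optional; the model falls back to random initialization if absent)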

### Outputs
- `models/maml/reliability_weighted_maml.pt`
- `reports/11_reliability_weighted_maml_xuetangx/<run_tag>/config.json`
- `reports/11_reliability_weighted_maml_xuetangx/<run_tag>/report.json`

In [1]:
# [CELL 11-00] Bootstrap

import os
import sys
import json
import time
import uuid
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from tqdm.auto import tqdm

t0 = datetime.now()
print(f"[CELL 11-00] start={t0.isoformat(timespec='seconds')}")
print(f"[CELL 11-00] PyTorch: {torch.__version__}")
print(f"[CELL 11-00] CUDA available: {torch.cuda.is_available()}")

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find PROJECT_STATE.md.")

REPO_ROOT = find_repo_root(Path.cwd())
print(f"[CELL 11-00] REPO_ROOT: {REPO_ROOT}")

PATHS = {
    "DATA_PROCESSED": REPO_ROOT / "data" / "processed",
    "MODELS": REPO_ROOT / "models",
    "REPORTS": REPO_ROOT / "reports",
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[CELL 11-00] Device: {DEVICE}")

def cell_start(cell_id: str, title: str, **kwargs) -> float:
    t = time.time()
    print(f"\n[{cell_id}] {title}")
    print(f"[{cell_id}] start={datetime.now().isoformat(timespec='seconds')}")
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    return t

def cell_end(cell_id: str, t0: float, **kwargs) -> None:
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] elapsed={time.time()-t0:.2f}s")
    print(f"[{cell_id}] done")

print("[CELL 11-00] done")

[CELL 11-00] start=2026-02-04T03:16:58
[CELL 11-00] PyTorch: 2.10.0+cu128
[CELL 11-00] CUDA available: True
[CELL 11-00] REPO_ROOT: /workspace/anonymous-users-mooc-session-meta
[CELL 11-00] Device: cuda
[CELL 11-00] done


In [2]:
# [CELL 11-01] Reproducibility + JSON helpers

t0 = cell_start("CELL 11-01", "Seed everything")

GLOBAL_SEED = 20260107

def seed_everything(seed: int) -> None:
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(GLOBAL_SEED)

def write_json_atomic(path: Path, obj: Any, indent: int = 2) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp_{uuid.uuid4().hex}")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=indent)
    tmp.replace(path)

def read_json(path: Path) -> Any:
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

cell_end("CELL 11-01", t0, seed=GLOBAL_SEED)


[CELL 11-01] Seed everything
[CELL 11-01] start=2026-02-04T03:16:59
[CELL 11-01] seed=20260107
[CELL 11-01] elapsed=0.00s
[CELL 11-01] done


In [3]:
# [CELL 11-02] Configuration

t0 = cell_start("CELL 11-02", "Configuration")

NOTEBOOK_NAME = "11_reliability_weighted_maml_xuetangx"
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_ID = uuid.uuid4().hex

OUT_DIR = PATHS["REPORTS"] / NOTEBOOK_NAME / RUN_TAG
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Episode configuration
K = 5   # Support set size
Q = 10  # Query set size

# Paths
EPISODES_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "episodes"
PAIRS_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs_with_reliability"  # NEW: Use reliability pairs
VOCAB_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "vocab"
MODELS_DIR = PATHS["MODELS"] / "maml"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

CFG = {
    "notebook": NOTEBOOK_NAME,
    "run_id": RUN_ID,
    "run_tag": RUN_TAG,
    "seed": GLOBAL_SEED,
    "K": K,
    "Q": Q,
    "inputs": {
        "episodes_train": str(EPISODES_DIR / f"episodes_train_K{K}_Q{Q}.parquet"),
        "episodes_val": str(EPISODES_DIR / f"episodes_val_K{K}_Q{Q}.parquet"),
        "episodes_test": str(EPISODES_DIR / f"episodes_test_K{K}_Q{Q}.parquet"),
        "pairs_with_reliability": str(PAIRS_DIR / "pairs.parquet"),
        "session_reliability": str(PAIRS_DIR / "session_reliability.parquet"),
        "course2id": str(VOCAB_DIR / "course2id.json"),
        "pretrained_gru": str(PATHS["MODELS"] / "gru" / "gru4rec_pretrained.pt"),
    },
    "outputs": {
        "model": str(MODELS_DIR / "reliability_weighted_maml.pt"),
        "report": str(OUT_DIR / "report.json"),
    },
    "gru_config": {
        "embedding_dim": 64,
        "hidden_dim": 128,
        "num_layers": 1,
        "dropout": 0.2,
        "max_seq_len": 50,
    },
    "maml_config": {
        "inner_lr": 0.01,
        "outer_lr": 0.001,
        "num_inner_steps": 5,
        "meta_batch_size": 32,
        "num_meta_iterations": 3000,
        "use_second_order": False,  # FOMAML for efficiency
        "use_reliability_weighting": True,  # KEY: Enable reliability weighting
    },
    "eval_config": {
        "eval_every": 100,
        "patience": 10,
    },
}

write_json_atomic(OUT_DIR / "config.json", CFG)

print(f"[CELL 11-02] K={K}, Q={Q}")
print(f"[CELL 11-02] Reliability weighting: {CFG['maml_config']['use_reliability_weighting']}")
print(f"[CELL 11-02] Output dir: {OUT_DIR}")

cell_end("CELL 11-02", t0)


[CELL 11-02] Configuration
[CELL 11-02] start=2026-02-04T03:16:59
[CELL 11-02] K=5, Q=10
[CELL 11-02] Reliability weighting: True
[CELL 11-02] Output dir: /workspace/anonymous-users-mooc-session-meta/reports/11_reliability_weighted_maml_xuetangx/20260204_031659
[CELL 11-02] elapsed=0.00s
[CELL 11-02] done


In [4]:
# [CELL 11-03] Load data with reliability scores

t0 = cell_start("CELL 11-03", "Load data with reliability")

# Load episodes
episodes_train = pd.read_parquet(CFG["inputs"]["episodes_train"])
episodes_val = pd.read_parquet(CFG["inputs"]["episodes_val"])
episodes_test = pd.read_parquet(CFG["inputs"]["episodes_test"])

print(f"[CELL 11-03] Episodes train: {len(episodes_train):,}")
print(f"[CELL 11-03] Episodes val: {len(episodes_val):,}")
print(f"[CELL 11-03] Episodes test: {len(episodes_test):,}")

# Load pairs with reliability (from NB03b)
pairs_path = Path(CFG["inputs"]["pairs_with_reliability"])
if pairs_path.exists():
    pairs_all = pd.read_parquet(pairs_path)
    HAS_RELIABILITY = 'session_reliability' in pairs_all.columns
    print(f"[CELL 11-03] Loaded pairs with reliability: {len(pairs_all):,}")
    if HAS_RELIABILITY:
        print(f"[CELL 11-03] Reliability range: [{pairs_all['session_reliability'].min():.4f}, {pairs_all['session_reliability'].max():.4f}]")
    else:
        # File exists but lacks the column: use uniform weights so the mapping below does not raise KeyError
        print("[CELL 11-03] WARNING: 'session_reliability' column missing, using uniform weights")
        pairs_all['session_reliability'] = 1.0
else:
    # Fallback to original pairs without reliability
    print("[CELL 11-03] WARNING: Pairs with reliability not found, using original pairs")
    pairs_all = pd.read_parquet(PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs" / "pairs.parquet")
    pairs_all['session_reliability'] = 1.0  # Default uniform weight
    HAS_RELIABILITY = False

# Load vocab
course2id = read_json(Path(CFG["inputs"]["course2id"]))
n_items = len(course2id)
print(f"[CELL 11-03] Vocabulary: {n_items} items")

# Create pair_id → reliability mapping
pair_reliability_map = pairs_all.set_index('pair_id')['session_reliability'].to_dict()

cell_end("CELL 11-03", t0, has_reliability=HAS_RELIABILITY)


[CELL 11-03] Load data with reliability
[CELL 11-03] start=2026-02-04T03:16:59
[CELL 11-03] Episodes train: 47,357
[CELL 11-03] Episodes val: 341
[CELL 11-03] Episodes test: 313
[CELL 11-03] Loaded pairs with reliability: 281,979
[CELL 11-03] Reliability range: [0.0485, 1.0000]
[CELL 11-03] Vocabulary: 1518 items
[CELL 11-03] has_reliability=True
[CELL 11-03] elapsed=0.28s
[CELL 11-03] done


In [5]:
# [CELL 11-04] GRU4Rec Model (same as NB07)

t0 = cell_start("CELL 11-04", "Define GRU4Rec model")

class GRU4Rec(nn.Module):
    """GRU-based sequential recommendation model."""
    
    def __init__(self, n_items: int, embedding_dim: int, hidden_dim: int,
                 num_layers: int = 1, dropout: float = 0.0):
        super().__init__()
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        self.item_embedding = nn.Embedding(n_items, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )
        self.output_layer = nn.Linear(hidden_dim, n_items)
        
    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len) item indices
            lengths: (batch_size,) actual sequence lengths
        Returns:
            logits: (batch_size, n_items)
        """
        embedded = self.item_embedding(x)  # (batch, seq, emb)
        
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, hidden = self.gru(packed)
        else:
            _, hidden = self.gru(embedded)
        
        # hidden: (num_layers, batch, hidden_dim) -> take last layer
        last_hidden = hidden[-1]  # (batch, hidden_dim)
        logits = self.output_layer(last_hidden)  # (batch, n_items)
        
        return logits
    
    def get_named_parameters(self) -> OrderedDict:
        """Get parameters as OrderedDict for MAML."""
        return OrderedDict(self.named_parameters())

print(f"[CELL 11-04] GRU4Rec model defined")
cell_end("CELL 11-04", t0)


[CELL 11-04] Define GRU4Rec model
[CELL 11-04] start=2026-02-04T03:16:59
[CELL 11-04] GRU4Rec model defined
[CELL 11-04] elapsed=0.00s
[CELL 11-04] done
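
The `lengths` argument in `GRU4Rec.forward` matters because sequences are right-padded with index 0. The following sketch, using a toy embedding and GRU whose sizes are arbitrary assumptions, illustrates that packing yields the hidden state at the last real item, whereas running the GRU over the padded tensor keeps updating the state on padding steps:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(10, 4, padding_idx=0)          # toy sizes, not the notebook's
gru = nn.GRU(input_size=4, hidden_size=6, batch_first=True)

padded = torch.tensor([[3, 5, 7, 0, 0]])          # real length 3, right-padded with 0
lengths = torch.tensor([3])                       # must live on the CPU

with torch.no_grad():
    packed = nn.utils.rnn.pack_padded_sequence(
        emb(padded), lengths, batch_first=True, enforce_sorted=False
    )
    _, h_packed = gru(packed)                     # stops after the 3 real items
    _, h_exact = gru(emb(padded[:, :3]))          # reference: unpadded input
    _, h_naive = gru(emb(padded))                 # also steps over the 2 pad positions
```

Here `h_packed` matches `h_exact`, while `h_naive` differs because the GRU biases still move the hidden state when the input embedding is zero.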


In [6]:
# [CELL 11-05] Initialize model

t0 = cell_start("CELL 11-05", "Initialize model")

gru_cfg = CFG["gru_config"]

meta_model = GRU4Rec(
    n_items=n_items,
    embedding_dim=gru_cfg["embedding_dim"],
    hidden_dim=gru_cfg["hidden_dim"],
    num_layers=gru_cfg["num_layers"],
    dropout=gru_cfg["dropout"]
).to(DEVICE)

# Load pretrained weights if available
pretrained_path = Path(CFG["inputs"]["pretrained_gru"])
if pretrained_path.exists():
    state_dict = torch.load(pretrained_path, map_location=DEVICE)
    meta_model.load_state_dict(state_dict)
    print(f"[CELL 11-05] Loaded pretrained GRU from {pretrained_path}")
else:
    print(f"[CELL 11-05] No pretrained model found, using random init")

n_params = sum(p.numel() for p in meta_model.parameters())
print(f"[CELL 11-05] Model parameters: {n_params:,}")

cell_end("CELL 11-05", t0)


[CELL 11-05] Initialize model
[CELL 11-05] start=2026-02-04T03:16:59
[CELL 11-05] No pretrained model found, using random init
[CELL 11-05] Model parameters: 367,470
[CELL 11-05] elapsed=0.23s
[CELL 11-05] done


In [7]:
# [CELL 11-06] Episode data extraction helpers

t0 = cell_start("CELL 11-06", "Define episode helpers")

def get_episode_data_with_reliability(episode_row, pairs_df, pair_reliability_map):
    """
    Extract support and query pairs for an episode WITH reliability scores.
    
    Returns:
        support_pairs: DataFrame with support set
        query_pairs: DataFrame with query set
        support_reliability: list of reliability scores for support pairs
    """
    support_pair_ids = episode_row["support_pair_ids"]
    query_pair_ids = episode_row["query_pair_ids"]
    
    support_pairs = pairs_df[pairs_df["pair_id"].isin(support_pair_ids)].sort_values("label_ts_epoch")
    query_pairs = pairs_df[pairs_df["pair_id"].isin(query_pair_ids)].sort_values("label_ts_epoch")
    
    # Get reliability scores for support pairs
    support_reliability = [pair_reliability_map.get(pid, 1.0) for pid in support_pairs['pair_id']]
    
    return support_pairs, query_pairs, support_reliability


def prepare_batch(pairs_df, max_seq_len: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare batch tensors from pairs DataFrame.
    
    Returns:
        sequences: (batch, max_seq_len) padded sequences
        lengths: (batch,) actual lengths
        labels: (batch,) target labels
    """
    batch_size = len(pairs_df)
    sequences = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
    lengths = torch.zeros(batch_size, dtype=torch.long)
    labels = torch.zeros(batch_size, dtype=torch.long)
    
    for i, (_, row) in enumerate(pairs_df.iterrows()):
        prefix = row["prefix"]
        seq_len = min(len(prefix), max_seq_len)
        
        # Take last max_seq_len items (most recent)
        if len(prefix) > max_seq_len:
            prefix = prefix[-max_seq_len:]
        
        sequences[i, :seq_len] = torch.tensor(prefix[:seq_len])
        lengths[i] = seq_len
        labels[i] = row["label"]
    
    return sequences, lengths, labels

print(f"[CELL 11-06] Episode helpers defined")
cell_end("CELL 11-06", t0)


[CELL 11-06] Define episode helpers
[CELL 11-06] start=2026-02-04T03:16:59
[CELL 11-06] Episode helpers defined
[CELL 11-06] elapsed=0.00s
[CELL 11-06] done
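
To make the padding and truncation behaviour of `prepare_batch` concrete, here is a compact stand-in (the name `pad_prefixes` and the toy DataFrame are hypothetical) that follows the same rules: keep the most recent `max_seq_len` items and right-pad with zeros.

```python
import pandas as pd
import torch


def pad_prefixes(pairs_df: pd.DataFrame, max_seq_len: int):
    """Right-pad each prefix with 0 and keep only its most recent items."""
    n = len(pairs_df)
    sequences = torch.zeros(n, max_seq_len, dtype=torch.long)
    lengths = torch.zeros(n, dtype=torch.long)
    for i, prefix in enumerate(pairs_df["prefix"]):
        recent = list(prefix)[-max_seq_len:]      # most recent items only
        sequences[i, : len(recent)] = torch.tensor(recent, dtype=torch.long)
        lengths[i] = len(recent)
    labels = torch.tensor(pairs_df["label"].to_numpy(), dtype=torch.long)
    return sequences, lengths, labels


toy = pd.DataFrame({"prefix": [[4, 9], [1, 2, 3, 4, 5]], "label": [7, 6]})
seqs, lens, labels = pad_prefixes(toy, max_seq_len=3)
# seqs -> [[4, 9, 0], [3, 4, 5]]; lens -> [2, 3]; labels -> [7, 6]
```

The first prefix is shorter than the limit and gets one padding zero; the second is truncated to its last three items.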


In [8]:
# [CELL 11-07] Reliability-Weighted MAML Inner Loop (KEY CONTRIBUTION) - FIXED

t0 = cell_start("CELL 11-07", "Define Reliability-Weighted MAML")

import copy

def reliability_weighted_inner_loop(
    model: nn.Module,
    support_seqs: torch.Tensor,
    support_lengths: torch.Tensor,
    support_labels: torch.Tensor,
    support_reliability: torch.Tensor,
    inner_lr: float,
    num_inner_steps: int,
    use_second_order: bool = False,
) -> None:
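    """
    MAML inner loop with reliability-weighted loss.
    Updates model parameters IN-PLACE.
    
    KEY MODIFICATION: Weight each support sample's loss by its reliability score.
    
    Standard MAML:   L = mean(loss_i)
    Reliability:     L = sum(w_i * loss_i) / sum(w_i)
    
    Note: the update runs under torch.no_grad(), so the adapted parameters
    carry no graph back to their pre-update values. The loop is therefore
    first-order only; use_second_order=True builds a graph for each gradient
    call, but that graph is never used by the outer loop.
    """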
    # Criterion WITHOUT reduction (to get per-sample losses)
    criterion_none = nn.CrossEntropyLoss(reduction='none')
    
    for step in range(num_inner_steps):
        # Forward pass using current model parameters
        logits = model(support_seqs, support_lengths)
        
        # Per-sample cross-entropy loss
        per_sample_loss = criterion_none(logits, support_labels)  # (K,)
        
        # RELIABILITY-WEIGHTED LOSS (KEY CONTRIBUTION)
        # Higher reliability → higher weight → more influence on adaptation
        weighted_loss = (support_reliability * per_sample_loss).sum() / (support_reliability.sum() + 1e-8)
        
        # Compute gradients w.r.t. model parameters directly
        grads = torch.autograd.grad(
            weighted_loss,
            model.parameters(),
            create_graph=use_second_order,
            allow_unused=True  # Some params may not be used
        )
        
        # Update model parameters in-place
        with torch.no_grad():
            for param, grad in zip(model.parameters(), grads):
                if grad is not None:
                    param.sub_(inner_lr * grad)


def standard_inner_loop(
    model: nn.Module,
    support_seqs: torch.Tensor,
    support_lengths: torch.Tensor,
    support_labels: torch.Tensor,
    inner_lr: float,
    num_inner_steps: int,
    use_second_order: bool = False,
) -> None:
    """
    Standard MAML inner loop (for baseline comparison).
    Updates model parameters IN-PLACE.
    """
    criterion = nn.CrossEntropyLoss()
    
    for step in range(num_inner_steps):
        logits = model(support_seqs, support_lengths)
        loss = criterion(logits, support_labels)
        
        grads = torch.autograd.grad(
            loss,
            model.parameters(),
            create_graph=use_second_order,
            allow_unused=True
        )
        
        with torch.no_grad():
            for param, grad in zip(model.parameters(), grads):
                if grad is not None:
                    param.sub_(inner_lr * grad)


print(f"[CELL 11-07] Reliability-Weighted MAML inner loop defined (FIXED)")
print(f"[CELL 11-07] Key modification: weighted_loss = (reliability * per_sample_loss).sum() / reliability.sum()")
print(f"[CELL 11-07] Now computes gradients w.r.t. model.parameters() directly")

cell_end("CELL 11-07", t0)


[CELL 11-07] Define Reliability-Weighted MAML
[CELL 11-07] start=2026-02-04T03:16:59
[CELL 11-07] Reliability-Weighted MAML inner loop defined (FIXED)
[CELL 11-07] Key modification: weighted_loss = (reliability * per_sample_loss).sum() / reliability.sum()
[CELL 11-07] Now computes gradients w.r.t. model.parameters() directly
[CELL 11-07] elapsed=0.00s
[CELL 11-07] done
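
One way to sanity-check the weighting is to confirm that a support sample with weight 0 has no effect on the adapted parameters. The sketch below uses a toy linear classifier and a single-step helper (`weighted_inner_step` is a hypothetical name, not a function from this notebook) that mirrors the update rule above:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def weighted_inner_step(model, x, y, weights, inner_lr):
    """One in-place SGD step on the normalized, weighted cross-entropy."""
    per_sample = F.cross_entropy(model(x), y, reduction="none")
    loss = (weights * per_sample).sum() / (weights.sum() + 1e-8)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    with torch.no_grad():
        for param, grad in zip(model.parameters(), grads):
            param.sub_(inner_lr * grad)


torch.manual_seed(0)
base = nn.Linear(3, 4)        # toy model standing in for GRU4Rec
x = torch.randn(2, 3)
y = torch.tensor([1, 2])

# Adapt on both samples, but give the second one zero weight.
model_a = copy.deepcopy(base)
weighted_inner_step(model_a, x, y, torch.tensor([1.0, 0.0]), inner_lr=0.1)

# Adapt on the first sample only.
model_b = copy.deepcopy(base)
weighted_inner_step(model_b, x[:1], y[:1], torch.tensor([1.0]), inner_lr=0.1)
```

Both adapted models end up with the same parameters, which is the limiting case of the claim that low-reliability samples have less influence on adaptation.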


In [9]:
# [CELL 11-08] Meta-training setup

t0 = cell_start("CELL 11-08", "Meta-training setup")

maml_cfg = CFG["maml_config"]

meta_optimizer = torch.optim.Adam(meta_model.parameters(), lr=maml_cfg["outer_lr"])
criterion = nn.CrossEntropyLoss()

inner_lr = maml_cfg["inner_lr"]
num_inner_steps = maml_cfg["num_inner_steps"]
meta_batch_size = maml_cfg["meta_batch_size"]
num_meta_iterations = maml_cfg["num_meta_iterations"]
max_seq_len = CFG["gru_config"]["max_seq_len"]
use_second_order = maml_cfg["use_second_order"]
use_reliability_weighting = maml_cfg["use_reliability_weighting"]

print(f"[CELL 11-08] Inner LR: {inner_lr}")
print(f"[CELL 11-08] Outer LR: {maml_cfg['outer_lr']}")
print(f"[CELL 11-08] Inner steps: {num_inner_steps}")
print(f"[CELL 11-08] Meta batch size: {meta_batch_size}")
print(f"[CELL 11-08] Meta iterations: {num_meta_iterations}")
print(f"[CELL 11-08] Use reliability weighting: {use_reliability_weighting}")
print(f"[CELL 11-08] Use second order: {use_second_order}")

cell_end("CELL 11-08", t0)


[CELL 11-08] Meta-training setup
[CELL 11-08] start=2026-02-04T03:16:59
[CELL 11-08] Inner LR: 0.01
[CELL 11-08] Outer LR: 0.001
[CELL 11-08] Inner steps: 5
[CELL 11-08] Meta batch size: 32
[CELL 11-08] Meta iterations: 3000
[CELL 11-08] Use reliability weighting: True
[CELL 11-08] Use second order: False
[CELL 11-08] elapsed=0.72s
[CELL 11-08] done


In [10]:
# [CELL 11-09] Evaluation function - FIXED

t0 = cell_start("CELL 11-09", "Define evaluation")

def evaluate_maml(
    model: nn.Module,
    episodes_df: pd.DataFrame,
    pairs_df: pd.DataFrame,
    pair_reliability_map: dict,
    inner_lr: float,
    num_inner_steps: int,
    max_seq_len: int,
    use_reliability_weighting: bool,
    max_episodes: int = 100,
) -> Dict[str, float]:
    """
    Evaluate MAML on validation/test episodes.
    
    Returns:
        metrics: Dict with HR@10, NDCG@10, etc.
    """
    model.eval()
    
    all_hr10 = []
    all_ndcg10 = []
    
    episodes_sample = episodes_df.sample(min(max_episodes, len(episodes_df)), random_state=GLOBAL_SEED)
    
    for _, episode in tqdm(episodes_sample.iterrows(), total=len(episodes_sample), desc="Evaluating"):
        # Save original weights (deep copy)
        original_state = copy.deepcopy(model.state_dict())
        
        # Get episode data
        support_pairs, query_pairs, support_reliability = get_episode_data_with_reliability(
            episode, pairs_df, pair_reliability_map
        )
        
        if len(support_pairs) == 0 or len(query_pairs) == 0:
            model.load_state_dict(original_state)
            continue
        
        # Prepare support batch
        support_seqs, support_lengths, support_labels = prepare_batch(support_pairs, max_seq_len)
        support_seqs = support_seqs.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        support_labels = support_labels.to(DEVICE)
        support_reliability_tensor = torch.tensor(support_reliability, dtype=torch.float32).to(DEVICE)
        
        # Inner loop adaptation (updates model in-place)
        model.train()
        if use_reliability_weighting:
            reliability_weighted_inner_loop(
                model, support_seqs, support_lengths, support_labels,
                support_reliability_tensor, inner_lr, num_inner_steps, False
            )
        else:
            standard_inner_loop(
                model, support_seqs, support_lengths, support_labels,
                inner_lr, num_inner_steps, False
            )
        model.eval()
        
        # Evaluate on query set
        query_seqs, query_lengths, query_labels = prepare_batch(query_pairs, max_seq_len)
        query_seqs = query_seqs.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        query_labels = query_labels.to(DEVICE)
        
        with torch.no_grad():
            query_logits = model(query_seqs, query_lengths)
            
            # Compute HR@10, NDCG@10
            _, top10_indices = query_logits.topk(10, dim=1)
            
            for i, label in enumerate(query_labels):
                hit = (top10_indices[i] == label).any().item()
                all_hr10.append(float(hit))
                
                if hit:
                    rank = (top10_indices[i] == label).nonzero(as_tuple=True)[0].item() + 1
                    ndcg = 1.0 / np.log2(rank + 1)
                else:
                    ndcg = 0.0
                all_ndcg10.append(ndcg)
        
        # Restore original weights
        model.load_state_dict(original_state)
    
    model.train()
    
    return {
        "HR@10": np.mean(all_hr10) * 100 if all_hr10 else 0.0,
        "NDCG@10": np.mean(all_ndcg10) * 100 if all_ndcg10 else 0.0,
        "n_queries": len(all_hr10),
    }

print(f"[CELL 11-09] Evaluation function defined (FIXED)")
cell_end("CELL 11-09", t0)


[CELL 11-09] Define evaluation
[CELL 11-09] start=2026-02-04T03:17:00
[CELL 11-09] Evaluation function defined (FIXED)
[CELL 11-09] elapsed=0.00s
[CELL 11-09] done
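
The metric computation inside `evaluate_maml` can be checked by hand on a small case. This sketch generalizes the cutoff to `k` so that a four-item toy example works; the helper name and the logits are assumptions for illustration. With one relevant item per query, NDCG reduces to `1 / log2(rank + 1)` for a hit and 0 otherwise.

```python
import math

import torch


def hit_and_ndcg_at_k(logits: torch.Tensor, labels: torch.Tensor, k: int):
    """Mean HR@k and NDCG@k when each query has exactly one target item."""
    _, topk = logits.topk(k, dim=1)
    hits, ndcgs = [], []
    for row, label in zip(topk, labels):
        match = (row == label).nonzero(as_tuple=True)[0]
        if match.numel() > 0:
            rank = match.item() + 1               # 1-based position in the top-k list
            hits.append(1.0)
            ndcgs.append(1.0 / math.log2(rank + 1))
        else:
            hits.append(0.0)
            ndcgs.append(0.0)
    return sum(hits) / len(hits), sum(ndcgs) / len(ndcgs)


logits = torch.tensor([[0.1, 0.9, 0.5, 0.2],      # ranking: items 1, 2, 3, 0
                       [0.8, 0.1, 0.3, 0.6]])     # ranking: items 0, 3, 2, 1
labels = torch.tensor([2, 1])
hr, ndcg = hit_and_ndcg_at_k(logits, labels, k=2)
```

The first query hits at rank 2 and the second misses, so HR@2 is 0.5 and NDCG@2 is `0.5 / log2(3)`.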


In [11]:
# [CELL 11-10] Meta-training loop - FIXED for FOMAML

t0 = cell_start("CELL 11-10", "Meta-training loop (FOMAML)")

import warnings

# Filter out the RNN contiguous memory warning (expected with deepcopy)
warnings.filterwarnings('ignore', message='RNN module weights are not part of single contiguous chunk of memory')

train_users = episodes_train["user_id"].unique()
eval_every = CFG["eval_config"]["eval_every"]
patience = CFG["eval_config"]["patience"]

best_val_hr10 = 0.0
best_iteration = 0
patience_counter = 0

training_log = []

meta_model.train()

for meta_iter in tqdm(range(num_meta_iterations), desc="Meta-training"):
    # Sample meta-batch of episodes
    meta_batch_users = np.random.choice(train_users, size=min(meta_batch_size, len(train_users)), replace=False)
    meta_batch_episodes = episodes_train[episodes_train["user_id"].isin(meta_batch_users)]
    
    if len(meta_batch_episodes) == 0:
        continue
    
    meta_batch_episodes = meta_batch_episodes.sample(min(meta_batch_size, len(meta_batch_episodes)))
    
    meta_optimizer.zero_grad()  # clears every meta_model gradient (set to None by default in PyTorch >= 2.0)
    
    total_query_loss = 0.0
    n_valid_episodes = 0
    
    for _, episode in meta_batch_episodes.iterrows():
        # Get episode data with reliability
        support_pairs, query_pairs, support_reliability = get_episode_data_with_reliability(
            episode, pairs_all, pair_reliability_map
        )
        
        if len(support_pairs) == 0 or len(query_pairs) == 0:
            continue
        
        # Prepare batches
        support_seqs, support_lengths, support_labels = prepare_batch(support_pairs, max_seq_len)
        support_seqs = support_seqs.to(DEVICE)
        support_lengths = support_lengths.to(DEVICE)
        support_labels = support_labels.to(DEVICE)
        support_reliability_tensor = torch.tensor(support_reliability, dtype=torch.float32).to(DEVICE)
        
        query_seqs, query_lengths, query_labels = prepare_batch(query_pairs, max_seq_len)
        query_seqs = query_seqs.to(DEVICE)
        query_lengths = query_lengths.to(DEVICE)
        query_labels = query_labels.to(DEVICE)
        
        # FOMAML: Deep copy model for this episode
        adapted_model = copy.deepcopy(meta_model)
        adapted_model.train()
        
        # Flatten GRU parameters to avoid memory warning and improve performance
        adapted_model.gru.flatten_parameters()
        
        # Inner loop with reliability weighting (updates adapted_model in-place)
        if use_reliability_weighting:
            reliability_weighted_inner_loop(
                adapted_model, support_seqs, support_lengths, support_labels,
                support_reliability_tensor, inner_lr, num_inner_steps, use_second_order=False
            )
        else:
            standard_inner_loop(
                adapted_model, support_seqs, support_lengths, support_labels,
                inner_lr, num_inner_steps, use_second_order=False
            )
        
        # Query loss with adapted model
        query_logits = adapted_model(query_seqs, query_lengths)
        query_loss = criterion(query_logits, query_labels)
        
        # Backward on adapted model to get gradients
        query_loss.backward()
        
        # FOMAML: Transfer gradients from adapted_model to meta_model
        with torch.no_grad():
            for meta_p, adapted_p in zip(meta_model.parameters(), adapted_model.parameters()):
                if adapted_p.grad is not None:
                    if meta_p.grad is None:
                        meta_p.grad = adapted_p.grad.clone()
                    else:
                        meta_p.grad.add_(adapted_p.grad)
        
        total_query_loss += query_loss.item()
        n_valid_episodes += 1
        
        # Clean up adapted model
        del adapted_model
    
    if n_valid_episodes > 0:
        # Average gradients across episodes
        for p in meta_model.parameters():
            if p.grad is not None:
                p.grad.div_(n_valid_episodes)
        
        # Gradient step
        meta_optimizer.step()
        
        avg_loss = total_query_loss / n_valid_episodes
    else:
        avg_loss = 0.0
    
    # Evaluation
    if (meta_iter + 1) % eval_every == 0:
        val_metrics = evaluate_maml(
            meta_model, episodes_val, pairs_all, pair_reliability_map,
            inner_lr, num_inner_steps, max_seq_len, use_reliability_weighting,
            max_episodes=50
        )
        
        log_entry = {
            "iteration": meta_iter + 1,
            "train_loss": avg_loss,
            "val_HR@10": val_metrics["HR@10"],
            "val_NDCG@10": val_metrics["NDCG@10"],
        }
        training_log.append(log_entry)
        
        print(f"\n[Iter {meta_iter+1}] Loss: {avg_loss:.4f} | Val HR@10: {val_metrics['HR@10']:.2f}% | Val NDCG@10: {val_metrics['NDCG@10']:.2f}%")
        
        # Early stopping
        if val_metrics["HR@10"] > best_val_hr10:
            best_val_hr10 = val_metrics["HR@10"]
            best_iteration = meta_iter + 1
            patience_counter = 0
            
            # Save best model
            torch.save(meta_model.state_dict(), CFG["outputs"]["model"])
            print(f"  -> New best! Saved model.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping at iteration {meta_iter+1}")
                break

print(f"\n[CELL 11-10] Training complete!")
print(f"[CELL 11-10] Best Val HR@10: {best_val_hr10:.2f}% at iteration {best_iteration}")

cell_end("CELL 11-10", t0, best_val_hr10=f"{best_val_hr10:.2f}%")


[CELL 11-10] Meta-training loop (FOMAML)
[CELL 11-10] start=2026-02-04T03:17:00


[Iter 100] Loss: 5.8223 | Val HR@10: 32.00% | Val NDCG@10: 23.09%
  -> New best! Saved model.
[Iter 200] Loss: 5.2474 | Val HR@10: 34.80% | Val NDCG@10: 26.02%
  -> New best! Saved model.
[Iter 300] Loss: 5.7127 | Val HR@10: 36.60% | Val NDCG@10: 26.88%
  -> New best! Saved model.
[Iter 400] Loss: 5.2427 | Val HR@10: 40.00% | Val NDCG@10: 28.75%
  -> New best! Saved model.
[Iter 500] Loss: 5.4818 | Val HR@10: 41.60% | Val NDCG@10: 29.95%
  -> New best! Saved model.
[Iter 600] Loss: 4.9888 | Val HR@10: 44.00% | Val NDCG@10: 32.37%
  -> New best! Saved model.
[Iter 700] Loss: 4.5757 | Val HR@10: 45.00% | Val NDCG@10: 33.05%
  -> New best! Saved model.
[Iter 800] Loss: 4.1929 | Val HR@10: 47.60% | Val NDCG@10: 34.28%
  -> New best! Saved model.
[Iter 900] Loss: 4.9273 | Val HR@10: 49.00% | Val NDCG@10: 35.96%
  -> New best! Saved model.
[Iter 1000] Loss: 4.3856 | Val HR@10: 49.80% | Val NDCG@10: 36.97%
  -> New best! Saved model.
[Iter 1100] Loss: 4.9536 | Val HR@10: 50.80% | Val NDCG@10: 36.95%
  -> New best! Saved model.
[Iter 1200] Loss: 4.5574 | Val HR@10: 50.00% | Val NDCG@10: 37.03%
[Iter 1300] Loss: 4.3743 | Val HR@10: 49.60% | Val NDCG@10: 37.37%
[Iter 1400] Loss: 4.0337 | Val HR@10: 49.80% | Val NDCG@10: 37.53%
[Iter 1500] Loss: 3.8936 | Val HR@10: 50.00% | Val NDCG@10: 38.35%
[Iter 1600] Loss: 3.2242 | Val HR@10: 48.80% | Val NDCG@10: 37.93%
[Iter 1700] Loss: 3.3212 | Val HR@10: 50.80% | Val NDCG@10: 39.36%
[Iter 1800] Loss: 3.2605 | Val HR@10: 51.20% | Val NDCG@10: 39.58%
  -> New best! Saved model.
[Iter 1900] Loss: 3.7080 | Val HR@10: 51.60% | Val NDCG@10: 40.49%
  -> New best! Saved model.
[Iter 2000] Loss: 3.2254 | Val HR@10: 50.80% | Val NDCG@10: 39.68%
[Iter 2100] Loss: 3.2472 | Val HR@10: 50.80% | Val NDCG@10: 40.03%
[Iter 2200] Loss: 3.5952 | Val HR@10: 51.60% | Val NDCG@10: 40.37%
[Iter 2300] Loss: 3.3283 | Val HR@10: 51.60% | Val NDCG@10: 41.30%
[Iter 2400] Loss: 2.7498 | Val HR@10: 51.60% | Val NDCG@10: 41.04%
[Iter 2500] Loss: 2.9639 | Val HR@10: 52.60% | Val NDCG@10: 41.46%
  -> New best! Saved model.
[Iter 2600] Loss: 3.3968 | Val HR@10: 52.40% | Val NDCG@10: 41.79%
[Iter 2700] Loss: 2.4927 | Val HR@10: 51.60% | Val NDCG@10: 41.32%
[Iter 2800] Loss: 2.9464 | Val HR@10: 52.20% | Val NDCG@10: 41.56%
[Iter 2900] Loss: 2.9473 | Val HR@10: 52.20% | Val NDCG@10: 41.81%
[Iter 3000] Loss: 2.5238 | Val HR@10: 52.00% | Val NDCG@10: 41.68%

[CELL 11-10] Training complete!
[CELL 11-10] Best Val HR@10: 52.60% at iteration 2500
[CELL 11-10] best_val_hr10=52.60%
[CELL 11-10] elapsed=2348.10s
[CELL 11-10] done


In [12]:
# [CELL 11-11] Final evaluation on test set

t0 = cell_start("CELL 11-11", "Final test evaluation")

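# Load the best checkpoint saved during training (selected on validation HR@10)
best_model_path = Path(CFG["outputs"]["model"])
if best_model_path.exists():
    meta_model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    print(f"[CELL 11-11] Loaded best model from {best_model_path}")
else:
    # Without a checkpoint, the in-memory weights from the last iteration are evaluated
    print(f"[CELL 11-11] WARNING: {best_model_path} not found; evaluating in-memory weights")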

# Evaluate on test set
test_metrics = evaluate_maml(
    meta_model, episodes_test, pairs_all, pair_reliability_map,
    inner_lr, num_inner_steps, max_seq_len, use_reliability_weighting,
    max_episodes=len(episodes_test)  # Evaluate on all test episodes
)

print("\n[CELL 11-11] ===== TEST RESULTS =====")
print(f"[CELL 11-11] Test HR@10:   {test_metrics['HR@10']:.2f}%")
print(f"[CELL 11-11] Test NDCG@10: {test_metrics['NDCG@10']:.2f}%")
print(f"[CELL 11-11] N queries:    {test_metrics['n_queries']}")
print("[CELL 11-11] ===========================")

cell_end("CELL 11-11", t0, test_HR10=f"{test_metrics['HR@10']:.2f}%")


[CELL 11-11] Final test evaluation
[CELL 11-11] start=2026-02-04T03:56:08
[CELL 11-11] Loaded best model from /workspace/anonymous-users-mooc-session-meta/models/maml/reliability_weighted_maml.pt


Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]


[CELL 11-11] ===== TEST RESULTS =====
[CELL 11-11] Test HR@10:   48.34%
[CELL 11-11] Test NDCG@10: 37.71%
[CELL 11-11] N queries:    3130
[CELL 11-11] test_HR10=48.34%
[CELL 11-11] elapsed=6.99s
[CELL 11-11] done
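
The body of `evaluate_maml` is not shown in this part of the notebook, so the following is only a minimal sketch of how HR@10 and NDCG@10 are commonly computed. It assumes each query has exactly one target item and that its 1-based rank in the model's scored list is already known. The function name `hr_ndcg_at_k` and the toy ranks are hypothetical.

```python
import math
from typing import Iterable, Tuple


def hr_ndcg_at_k(ranks: Iterable[int], k: int = 10) -> Tuple[float, float]:
    """Return (HR@k, NDCG@k) in percent.

    ranks: 1-based rank of the single target item for each query
           (assumption: one relevant item per query).
    """
    ranks = list(ranks)
    if not ranks:
        raise ValueError("at least one query is required")
    # A hit counts when the target appears in the top k.
    hits = [1.0 if r <= k else 0.0 for r in ranks]
    # With one relevant item the ideal DCG is 1, so NDCG reduces to 1 / log2(rank + 1).
    gains = [1.0 / math.log2(r + 1) if r <= k else 0.0 for r in ranks]
    n = len(ranks)
    return 100.0 * sum(hits) / n, 100.0 * sum(gains) / n


# Toy ranks for illustration only (not taken from the test set).
toy_ranks = [1, 3, 12, 2]
hr10, ndcg10 = hr_ndcg_at_k(toy_ranks)
print(f"HR@10: {hr10:.2f}% | NDCG@10: {ndcg10:.2f}%")
```

Under this definition NDCG@10 can never exceed HR@10, which is consistent with the test results reported above.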


In [13]:
# [CELL 11-12] Save report

t0 = cell_start("CELL 11-12", "Save report")

report = {
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "config": CFG,
    "results": {
        "best_val_HR@10": best_val_hr10,
        "best_iteration": best_iteration,
        "test_HR@10": test_metrics["HR@10"],
        "test_NDCG@10": test_metrics["NDCG@10"],
        "test_n_queries": test_metrics["n_queries"],
    },
    "training_log": training_log,
    "key_findings": [
        f"Reliability-Weighted MAML achieved {test_metrics['HR@10']:.2f}% HR@10 on test set.",
        f"Best validation HR@10: {best_val_hr10:.2f}% at iteration {best_iteration}.",
        f"Reliability weighting {'enabled' if use_reliability_weighting else 'disabled'}.",
    ],
}

write_json_atomic(OUT_DIR / "report.json", report)

print(f"[CELL 11-12] Report saved to {OUT_DIR / 'report.json'}")

cell_end("CELL 11-12", t0)


[CELL 11-12] Save report
[CELL 11-12] start=2026-02-04T03:56:15
[CELL 11-12] Report saved to /workspace/anonymous-users-mooc-session-meta/reports/11_reliability_weighted_maml_xuetangx/20260204_031659/report.json
[CELL 11-12] elapsed=0.00s
[CELL 11-12] done


## Notebook 11 Complete

### Main Result: Reliability-Weighted MAML vs Vanilla MAML

| Method | Test HR@10 | Test NDCG@10 | Source |
|--------|------------|--------------|--------|
| Vanilla MAML | 47.35% | 37.41% | NB07 |
| **Reliability-Weighted MAML** | **48.34%** | **37.71%** | NB11 |
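| **Improvement (percentage points)** | **+0.99** | **+0.30** | |

**Conclusion:** Weighting the MAML inner loop by session reliability raises test HR@10 by 0.99 percentage points and test NDCG@10 by 0.30 percentage points over vanilla MAML, measured on all 313 test episodes (3,130 queries).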
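
The improvement row can be read either as an absolute difference in percentage points or as a change relative to the vanilla baseline. A short sketch that computes both from the table values (the variable names are illustrative):

```python
# Test metrics copied from the comparison table above (values in percent).
vanilla = {"HR@10": 47.35, "NDCG@10": 37.41}
weighted = {"HR@10": 48.34, "NDCG@10": 37.71}

gains = {}
for metric, base in vanilla.items():
    abs_gain = weighted[metric] - base   # difference in percentage points
    rel_gain = 100.0 * abs_gain / base   # change relative to the baseline, in percent
    gains[metric] = (abs_gain, rel_gain)
    print(f"{metric}: +{abs_gain:.2f} pp absolute, +{rel_gain:.2f}% relative")
```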

---

### Training Details

| Parameter | Value |
|-----------|-------|
| Meta-iterations | 3,000 |
| Inner LR | 0.01 |
| Outer LR | 0.001 |
| Inner steps | 5 |
| Meta batch size | 32 |
| Best Val HR@10 | **52.60%** (iteration 2500) |
| Training time | 2,348s (~39 min) |

### Training Progression
| Iteration | Val HR@10 |
|-----------|-----------|
| 100 | 32.00% |
| 500 | 41.60% |
| 1000 | 49.80% |
| 1500 | 50.00% |
| 2000 | 50.80% |
| **2500** | **52.60%** (best) |
| 3000 | 52.00% |
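
The checkpoint selection can be replayed from the logged validation values. The sketch below assumes a strict-improvement rule (`>`), which is not shown in this part of the notebook but matches the log: ties at iterations 1700 and 2200 did not trigger a save. Only iterations 500 to 3000 from the training output are included.

```python
# (iteration, Val HR@10 in percent), copied from the CELL 11-10 training output.
val_log = [
    (500, 41.60), (600, 44.00), (700, 45.00), (800, 47.60), (900, 49.00),
    (1000, 49.80), (1100, 50.80), (1200, 50.00), (1300, 49.60), (1400, 49.80),
    (1500, 50.00), (1600, 48.80), (1700, 50.80), (1800, 51.20), (1900, 51.60),
    (2000, 50.80), (2100, 50.80), (2200, 51.60), (2300, 51.60), (2400, 51.60),
    (2500, 52.60), (2600, 52.40), (2700, 51.60), (2800, 52.20), (2900, 52.20),
    (3000, 52.00),
]

best_iter, best_hr = None, float("-inf")
new_best_iters = []  # iterations at which a checkpoint would be saved
for iteration, hr in val_log:
    if hr > best_hr:  # assumed rule: save only on strict improvement
        best_iter, best_hr = iteration, hr
        new_best_iters.append(iteration)

print(f"best Val HR@10 = {best_hr:.2f}% at iteration {best_iter}")
print("checkpoints saved at:", new_best_iters)
```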

### Outputs
- Model: `models/maml/reliability_weighted_maml.pt`
- Report: `reports/11_reliability_weighted_maml_xuetangx/20260204_031659/report.json`

### Key Contribution (Contribution 3)

**Reliability-Weighted Inner Loop:**
```python
weighted_loss = (reliability * per_sample_loss).sum() / reliability.sum()
```
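
A framework-free sketch of the normalization in the line above, assuming the per-sample losses and reliability weights are plain Python floats rather than tensors. The function name `reliability_weighted_mean` and the toy values are hypothetical. It illustrates two properties of dividing by `reliability.sum()`: equal weights recover the ordinary mean, and rescaling all weights by a constant leaves the loss unchanged.

```python
from typing import Sequence


def reliability_weighted_mean(losses: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean of per-sample losses, normalized by the total weight."""
    if len(losses) != len(weights):
        raise ValueError("losses and weights must have the same length")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    return sum(w * l for w, l in zip(weights, losses)) / total_weight


# Toy values for illustration only (not taken from the notebook's data).
losses = [2.0, 4.0, 6.0]
uniform = reliability_weighted_mean(losses, [1.0, 1.0, 1.0])    # same as the plain mean
skewed = reliability_weighted_mean(losses, [1.0, 0.5, 0.1])     # first sample dominates
rescaled = reliability_weighted_mean(losses, [10.0, 5.0, 1.0])  # same ratios as `skewed`
print(uniform, skewed, rescaled)
```

Because of the normalization, only the relative reliability of the support samples within an episode matters, not the absolute scale of the scores.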

**Reliability Score:**
```
reliability = (intensity + extent + composition) / 3
  - intensity:   min(n_events / 100, 1.0)
  - extent:      min(duration_sec / 1800, 1.0)  
  - composition: n_action_types / 8
```
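
The definition above translates directly into a small function. This is a sketch using the constants from the definition (100 events, 1800 seconds, 8 action types); the function name `session_reliability` and the example sessions are hypothetical. As in the definition, `composition` is not clamped, so it assumes a session has at most 8 distinct action types.

```python
def session_reliability(n_events: int, duration_sec: float, n_action_types: int) -> float:
    """Average of the three reliability components, each intended to lie in [0, 1]."""
    intensity = min(n_events / 100, 1.0)     # event count, capped at 100 events
    extent = min(duration_sec / 1800, 1.0)   # duration, capped at 30 minutes
    composition = n_action_types / 8         # assumes at most 8 distinct action types
    return (intensity + extent + composition) / 3


# Hypothetical sessions for illustration only.
short_session = session_reliability(n_events=10, duration_sec=180, n_action_types=2)
long_session = session_reliability(n_events=250, duration_sec=3600, n_action_types=8)
print(f"short: {short_session:.3f} | long: {long_session:.3f}")
```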

**Intuition:** Sessions with more events, longer duration, and more diverse actions provide more reliable signals for learning user preferences. These high-reliability sessions should therefore have more influence on the inner-loop loss during MAML adaptation.

### Validation (from NB10)
- Point-biserial correlation between session reliability and user return: r = 0.217 (p < 0.001)
- Users who return have higher mean reliability than users who do not: 0.279 vs 0.192
- Reliability is positively, though weakly, associated with user return, which supports its use as a session-quality signal
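
NB10's actual computation is not shown here. As a minimal sketch, the point-biserial correlation between a continuous score and a 0/1 flag can be written with the standard library alone; it is numerically the same as the Pearson correlation with the flag treated as a number. The function name `point_biserial` and the toy values are hypothetical and are not the XuetangX data.

```python
import math
from typing import Sequence


def point_biserial(values: Sequence[float], flags: Sequence[int]) -> float:
    """Point-biserial correlation between a continuous variable and a 0/1 flag."""
    if len(values) != len(flags):
        raise ValueError("values and flags must have the same length")
    group1 = [v for v, f in zip(values, flags) if f == 1]
    group0 = [v for v, f in zip(values, flags) if f == 0]
    n1, n0, n = len(group1), len(group0), len(values)
    if n1 == 0 or n0 == 0:
        raise ValueError("both groups must be non-empty")
    mean_all = sum(values) / n
    # Population standard deviation over all observations.
    std_all = math.sqrt(sum((v - mean_all) ** 2 for v in values) / n)
    if std_all == 0:
        raise ValueError("values must not all be identical")
    mean1, mean0 = sum(group1) / n1, sum(group0) / n0
    return (mean1 - mean0) / std_all * math.sqrt(n1 * n0 / n**2)


# Synthetic toy data for illustration only.
toy_reliability = [0.10, 0.20, 0.15, 0.40, 0.35, 0.30]
toy_returned = [0, 0, 0, 1, 1, 1]
r_pb = point_biserial(toy_reliability, toy_returned)
print(f"r_pb = {r_pb:.4f}")
```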