Imports + env info (SASRec baseline)

In [1]:
# [CELL 10-00] Imports + env info (SASRec baseline)

import os
import json
import math
import time
import random
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

print("[10-00] Imports OK")
print("[10-00] torch:", torch.__version__)
print("[10-00] pandas:", pd.__version__)
print("[10-00] numpy:", np.__version__)


[10-00] Imports OK
[10-00] torch: 2.9.1+cpu
[10-00] pandas: 2.3.3
[10-00] numpy: 2.4.0


Locate repo root + fixed upstream run tags + load protocol/config artifacts

In [2]:
# [CELL 10-01] Locate repo root + fixed upstream run tags + load protocol/config artifacts

def find_repo_root(start: Path) -> Path:
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists() and (p / "meta.json").exists():
            return p
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise FileNotFoundError("Could not locate repo root (expected PROJECT_STATE.md).")

REPO_ROOT = find_repo_root(Path.cwd().resolve())
print("[10-01] REPO_ROOT:", REPO_ROOT)

RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[10-01] RUN_TAG:", RUN_TAG)

# Fixed upstream run tags (do NOT change)
TARGET_TAG = "20251229_163357"
SOURCE_TAG = "20251229_232834"

def load_json(path: Path) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

cfg_path_repo = REPO_ROOT / "data/processed/supervised" / f"dataloader_config_{TARGET_TAG}_{SOURCE_TAG}.json"
sanity_path_repo = REPO_ROOT / "data/processed/supervised" / f"sanity_metrics_{TARGET_TAG}_{SOURCE_TAG}.json"
gaps_path_repo = REPO_ROOT / "data/processed/normalized_events" / "session_gap_thresholds.json"

dataloader_cfg = load_json(cfg_path_repo)
sanity_metrics = load_json(sanity_path_repo)
session_gaps = load_json(gaps_path_repo)

print("[10-01] Loaded dataloader_config keys:", list(dataloader_cfg.keys()))
print("[10-01] Loaded sanity_metrics keys:", list(sanity_metrics.keys()))
print("[10-01] Loaded session_gap_thresholds keys:", list(session_gaps.keys()))

# Enforce fixed decisions
assert session_gaps["target"]["primary_threshold_seconds"] == 1800, "Target gap must be 30m (1800s)."
assert session_gaps["source"]["primary_threshold_seconds"] == 600, "Source gap must be 10m (600s)."
print("[10-01] ✅ Session gaps confirmed: target=30m, source=10m")

proto = dataloader_cfg["protocol"]
K_LIST = [5, 10, 20]
MAX_K = max(K_LIST)

MAX_PREFIX_LEN = int(proto["max_prefix_len"])
CAP_ENABLED = bool(proto["source_long_session_policy"]["enabled"])
CAP_SESSION_LEN = int(proto["source_long_session_policy"]["cap_session_len"])
CAP_STRATEGY = str(proto["source_long_session_policy"]["cap_strategy"])

print("[10-01] Protocol from 06:")
print("  K_LIST:", K_LIST)
print("  MAX_PREFIX_LEN:", MAX_PREFIX_LEN)
print("  CAP_ENABLED:", CAP_ENABLED)
print("  CAP_SESSION_LEN:", CAP_SESSION_LEN)
print("  CAP_STRATEGY:", CAP_STRATEGY)

print("\n[10-01] CHECKPOINT A")
print("Confirm JSON loads + gap asserts passed.")


[10-01] REPO_ROOT: C:\mooc-coldstart-session-meta
[10-01] RUN_TAG: 20260102_233834
[10-01] Loaded dataloader_config keys: ['target', 'source', 'protocol']
[10-01] Loaded sanity_metrics keys: ['run_tag_target', 'run_tag_source', 'created_at', 'target', 'source', 'notes']
[10-01] Loaded session_gap_thresholds keys: ['generated_from_run_tag', 'generated_at', 'target', 'source', 'decision_notes']
[10-01] ✅ Session gaps confirmed: target=30m, source=10m
[10-01] Protocol from 06:
  K_LIST: [5, 10, 20]
  MAX_PREFIX_LEN: 20
  CAP_ENABLED: True
  CAP_SESSION_LEN: 200
  CAP_STRATEGY: take_last

[10-01] CHECKPOINT A
Confirm JSON loads + gap asserts passed.


Resolve artifact paths (target tensors + vocabs) + existence checks

In [3]:
# [CELL 10-02] Resolve artifact paths (target tensors + vocabs) + existence checks

def must_exist(p: Path, label: str):
    if not p.exists():
        raise FileNotFoundError(f"{label} not found: {p}")
    return p

TARGET_TENSOR_DIR = REPO_ROOT / "data/processed/tensor_target"
target_train_pt = TARGET_TENSOR_DIR / f"target_tensor_train_{TARGET_TAG}.pt"
target_val_pt   = TARGET_TENSOR_DIR / f"target_tensor_val_{TARGET_TAG}.pt"
target_test_pt  = TARGET_TENSOR_DIR / f"target_tensor_test_{TARGET_TAG}.pt"
target_vocab_json = TARGET_TENSOR_DIR / f"target_vocab_items_{TARGET_TAG}.json"

for p, lbl in [
    (target_train_pt, "target_train_pt"),
    (target_val_pt, "target_val_pt"),
    (target_test_pt, "target_test_pt"),
    (target_vocab_json, "target_vocab_json"),
    (cfg_path_repo, "dataloader_config"),
    (sanity_path_repo, "sanity_metrics"),
    (gaps_path_repo, "session_gap_thresholds"),
]:
    must_exist(p, lbl)

print("[10-02] ✅ All required artifacts exist")

print("\n[10-02] CHECKPOINT B")
print("If any artifact missing, STOP and paste the error.")


[10-02] ✅ All required artifacts exist

[10-02] CHECKPOINT B
If any artifact missing, STOP and paste the error.


Torch loader (PyTorch 2.6+) + vocab sizes + PAD/UNK

In [4]:
# [CELL 10-03] Torch loader (PyTorch 2.6+) + vocab sizes + PAD/UNK

def torch_load_repo_artifact(path, map_location="cpu"):
    path = str(path)
    try:
        obj = torch.load(path, map_location=map_location, weights_only=False)
        print(f"[10-03] torch.load OK (weights_only=False): {path}")
        return obj
    except TypeError:
        obj = torch.load(path, map_location=map_location)
        print(f"[10-03] torch.load OK (no weights_only arg): {path}")
        return obj

target_vocab = load_json(target_vocab_json)

def infer_vocab_size(vocab: dict, name: str) -> int:
    for k in ["vocab_size", "n_items", "num_items", "size"]:
        if k in vocab:
            vs = int(vocab[k])
            print(f"[10-03] {name}: vocab_size from key '{k}' = {vs}")
            return vs
    if "vocab" in vocab and isinstance(vocab["vocab"], dict):
        d = vocab["vocab"]
        if len(d) == 0:
            return 0
        sample_v = next(iter(d.values()))
        if isinstance(sample_v, int):
            ids = list(d.values())
            vs = max(ids) + 1 if len(ids) else 0
            print(f"[10-03] {name}: vocab_size from max(vocab values)+1 (token->id) = {vs}")
            return vs
    raise KeyError(f"[10-03] {name}: Could not infer vocab_size. Keys={list(vocab.keys())}")

vocab_size_target = infer_vocab_size(target_vocab, "TARGET")

def get_special_id(vocab_obj: dict, token_key: str, fallback: int) -> int:
    tok = vocab_obj.get(token_key, None)
    if tok is None:
        return fallback
    mapping = vocab_obj.get("vocab", {})
    if isinstance(mapping, dict) and tok in mapping and isinstance(mapping[tok], int):
        return int(mapping[tok])
    return fallback

PAD_ID_TARGET = get_special_id(target_vocab, "pad_token", 0)
UNK_ID_TARGET = get_special_id(target_vocab, "unk_token", 1)

print("[10-03] vocab_size_target:", vocab_size_target)
print("[10-03] PAD_ID_TARGET:", PAD_ID_TARGET, "| UNK_ID_TARGET:", UNK_ID_TARGET)
assert PAD_ID_TARGET == 0

print("\n[10-03] CHECKPOINT C")
print("Confirm vocab_size_target + PAD/UNK printed as expected.")


[10-03] TARGET: vocab_size from max(vocab values)+1 (token->id) = 747
[10-03] vocab_size_target: 747
[10-03] PAD_ID_TARGET: 0 | UNK_ID_TARGET: 1

[10-03] CHECKPOINT C
Confirm vocab_size_target + PAD/UNK printed as expected.


Metrics (same as 06): HR/MRR/NDCG @ K={5,10,20}

In [5]:
# [CELL 10-04] Metrics (same as 06): HR/MRR/NDCG @ K={5,10,20}

def init_metrics():
    return {f"{m}@{k}": 0.0 for m in ["HR", "MRR", "NDCG"] for k in K_LIST}

def update_metrics_from_rank(metrics: dict, rank0: int | None):
    if rank0 is None:
        return
    r = rank0 + 1
    for k in K_LIST:
        if r <= k:
            metrics[f"HR@{k}"] += 1.0
            metrics[f"MRR@{k}"] += 1.0 / r
            metrics[f"NDCG@{k}"] += 1.0 / math.log2(r + 1.0)

def finalize_metrics(metrics: dict, n: int) -> dict:
    return {k: (float(v / n) if n > 0 else 0.0) for k, v in metrics.items()}

print("[10-04] ✅ Metric functions ready")


[10-04] ✅ Metric functions ready


Load TARGET tensors (train/val/test) (from 05B artifacts)

In [6]:
# [CELL 10-05] Load TARGET tensors (train/val/test) (from 05B artifacts)

train_obj = torch_load_repo_artifact(target_train_pt, map_location="cpu")
val_obj   = torch_load_repo_artifact(target_val_pt, map_location="cpu")
test_obj  = torch_load_repo_artifact(target_test_pt, map_location="cpu")

def as_tensor_dict(obj: dict):
    return {
        "input_ids": torch.as_tensor(obj["input_ids"]).long(),
        "attn_mask": torch.as_tensor(obj["attn_mask"]).long(),
        "labels": torch.as_tensor(obj["labels"]).long(),
    }

target_train = as_tensor_dict(train_obj)
target_val   = as_tensor_dict(val_obj)
target_test  = as_tensor_dict(test_obj)

print("[10-05] TARGET train shapes:",
      tuple(target_train["input_ids"].shape),
      tuple(target_train["attn_mask"].shape),
      tuple(target_train["labels"].shape))
print("[10-05] TARGET val shapes:",
      tuple(target_val["input_ids"].shape),
      tuple(target_val["labels"].shape))
print("[10-05] TARGET test shapes:",
      tuple(target_test["input_ids"].shape),
      tuple(target_test["labels"].shape))

print("\n[10-05] CHECKPOINT D")
print("Confirm shapes match: train=(1944,20), val=(189,20), test=(200,20).")


[10-03] torch.load OK (weights_only=False): C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_train_20251229_163357.pt
[10-03] torch.load OK (weights_only=False): C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_val_20251229_163357.pt
[10-03] torch.load OK (weights_only=False): C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_test_20251229_163357.pt
[10-05] TARGET train shapes: (1944, 20) (1944, 20) (1944,)
[10-05] TARGET val shapes: (189, 20) (189,)
[10-05] TARGET test shapes: (200, 20) (200,)

[10-05] CHECKPOINT D
Confirm shapes match: train=(1944,20), val=(189,20), test=(200,20).


SASRec model (causal self-attention) + helpers

In [7]:
# [CELL 10-06] SASRec model (causal self-attention) + helpers

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def make_lengths(attn_mask: torch.Tensor) -> torch.Tensor:
    return attn_mask.sum(dim=1).long()

def build_causal_mask(T: int, device: torch.device) -> torch.Tensor:
    # float mask: [T,T] with -inf above diagonal
    m = torch.full((T, T), float("-inf"), device=device)
    m = torch.triu(m, diagonal=1)
    return m

class SASRec(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        d_ff: int = 256,
        dropout: float = 0.2,
        pad_id: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.pad_id = pad_id

        self.item_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor):
        # input_ids: [B,T], attn_mask: [B,T] (1 for real, 0 for pad)
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)  # [B,T]
        x = self.item_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(self.norm(x))

        # transformer masks
        causal = build_causal_mask(T, input_ids.device)     # [T,T]
        key_padding = (attn_mask == 0)                      # [B,T] True where pad

        h = self.encoder(x, mask=causal, src_key_padding_mask=key_padding)  # [B,T,D]

        lengths = make_lengths(attn_mask)                   # [B]
        last_idx = torch.clamp(lengths - 1, min=0)          # [B]
        h_last = h[torch.arange(B, device=h.device), last_idx]  # [B,D]
        logits = self.out(h_last)                           # [B,V]
        return logits

print("[10-06] ✅ SASRec defined")


[10-06] ✅ SASRec defined


SASRec (fixed masks): bool causal mask + norm_first=False

In [8]:
# [CELL 10-06B] SASRec (fixed masks): bool causal mask + norm_first=False

def build_causal_mask_bool(T: int, device: torch.device) -> torch.Tensor:
    # True where attention should be blocked (upper triangle)
    return torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1)

class SASRecFixed(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        d_ff: int = 256,
        dropout: float = 0.2,
        pad_id: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.pad_id = pad_id

        self.item_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=False,   # fixes nested-tensor warning
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor):
        # input_ids: [B,T], attn_mask: [B,T] (1=real, 0=pad)
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        x = self.item_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(self.norm(x))

        causal = build_causal_mask_bool(T, input_ids.device)  # bool [T,T]
        key_padding = (attn_mask == 0)                         # bool [B,T]

        h = self.encoder(x, mask=causal, src_key_padding_mask=key_padding)  # [B,T,D]

        lengths = make_lengths(attn_mask)
        last_idx = torch.clamp(lengths - 1, min=0)
        h_last = h[torch.arange(B, device=h.device), last_idx]  # [B,D]
        logits = self.out(h_last)                                 # [B,V]
        return logits

print("[10-06B] ✅ SASRecFixed defined (bool masks, norm_first=False)")


[10-06B] ✅ SASRecFixed defined (bool masks, norm_first=False)


Train SASRec on TARGET with early stopping (VAL HR@20)

In [9]:
# [CELL 10-07] Train SASRec on TARGET with early stopping (VAL HR@20)

SAS_CFG = {
    "d_model": 64,
    "n_heads": 4,
    "n_layers": 2,
    "d_ff": 256,
    "dropout": 0.2,
    "batch_size": 256,
    "max_epochs": 80,
    "lr": 1e-3,
    "weight_decay": 1e-6,
    "grad_clip": 1.0,
    "seed": 42,
    "early_stop_metric": "HR@20",
    "patience": 10,
    "min_delta": 1e-4,
}

print("[10-07] SAS_CFG:", SAS_CFG)

set_seed(SAS_CFG["seed"])
device = torch.device("cpu")

model_s = SASRec(
    vocab_size=vocab_size_target,
    max_len=MAX_PREFIX_LEN,
    d_model=SAS_CFG["d_model"],
    n_heads=SAS_CFG["n_heads"],
    n_layers=SAS_CFG["n_layers"],
    d_ff=SAS_CFG["d_ff"],
    dropout=SAS_CFG["dropout"],
    pad_id=PAD_ID_TARGET,
).to(device)

opt = torch.optim.Adam(
    model_s.parameters(),
    lr=SAS_CFG["lr"],
    weight_decay=SAS_CFG["weight_decay"],
)

def iter_batches(data: dict, batch_size: int, shuffle: bool = True):
    n = data["input_ids"].shape[0]
    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    for s in range(0, n, batch_size):
        b = idx[s:s+batch_size]
        yield (
            data["input_ids"][b].to(device),
            data["attn_mask"][b].to(device),
            data["labels"][b].to(device),
        )

def eval_sas(model: SASRec, data: dict) -> dict:
    model.eval()
    metrics = init_metrics()
    n = 0
    with torch.no_grad():
        for x, am, y in iter_batches(data, batch_size=SAS_CFG["batch_size"], shuffle=False):
            logits = model(x, am)          # [B,V]
            logits[:, PAD_ID_TARGET] = -1e9
            topk = torch.topk(logits, k=MAX_K, dim=1).indices.cpu().numpy()  # [B,K]
            y_np = y.cpu().numpy()
            for i in range(topk.shape[0]):
                if int(y_np[i]) == PAD_ID_TARGET:
                    continue
                pos = np.where(topk[i] == int(y_np[i]))[0]
                rank0 = int(pos[0]) if pos.size > 0 else None
                update_metrics_from_rank(metrics, rank0)
                n += 1
    out = finalize_metrics(metrics, n)
    out["_n_examples"] = int(n)
    return out

best_metric = -1.0
best_epoch = -1
best_state = None
bad_epochs = 0

train_losses = []
val_history = []

metric_name = SAS_CFG["early_stop_metric"]
patience = int(SAS_CFG["patience"])
min_delta = float(SAS_CFG["min_delta"])

print(f"[10-07] Early stopping on {metric_name} | patience={patience} | min_delta={min_delta}")

for epoch in range(1, int(SAS_CFG["max_epochs"]) + 1):
    model_s.train()
    t0 = time.time()
    total_loss = 0.0
    total_n = 0

    for x, am, y in iter_batches(target_train, SAS_CFG["batch_size"], shuffle=True):
        opt.zero_grad()
        logits = model_s(x, am)
        loss = F.cross_entropy(logits, y, ignore_index=PAD_ID_TARGET)
        loss.backward()
        nn.utils.clip_grad_norm_(model_s.parameters(), SAS_CFG["grad_clip"])
        opt.step()

        bs = x.shape[0]
        total_loss += float(loss.item()) * bs
        total_n += bs

    avg_loss = total_loss / max(1, total_n)
    train_losses.append(avg_loss)

    val_metrics = eval_sas(model_s, target_val)
    val_history.append(val_metrics)

    dt = time.time() - t0
    cur = float(val_metrics.get(metric_name, 0.0))

    improved = (cur > best_metric + min_delta)
    if improved:
        best_metric = cur
        best_epoch = epoch
        best_state = {k: v.detach().cpu().clone() for k, v in model_s.state_dict().items()}
        bad_epochs = 0
    else:
        bad_epochs += 1

    print(f"[10-07] epoch={epoch:03d} loss={avg_loss:.4f} time={dt:.1f}s | "
          f"VAL {metric_name}={cur:.6f} | best={best_metric:.6f} (epoch {best_epoch}) | bad_epochs={bad_epochs}")

    if bad_epochs >= patience:
        print(f"[10-07] ✅ Early stop triggered at epoch={epoch} (best epoch={best_epoch}, best {metric_name}={best_metric:.6f})")
        break

assert best_state is not None, "[10-07] best_state is None (no validation computed?)"
model_s.load_state_dict(best_state)
print(f"[10-07] ✅ Restored best model weights from epoch={best_epoch} with best {metric_name}={best_metric:.6f}")

print("\n[10-07] CHECKPOINT E")
print("Paste: best_epoch + best_metric + last 3 epoch log lines.")


[10-07] SAS_CFG: {'d_model': 64, 'n_heads': 4, 'n_layers': 2, 'd_ff': 256, 'dropout': 0.2, 'batch_size': 256, 'max_epochs': 80, 'lr': 0.001, 'weight_decay': 1e-06, 'grad_clip': 1.0, 'seed': 42, 'early_stop_metric': 'HR@20', 'patience': 10, 'min_delta': 0.0001}




[10-07] Early stopping on HR@20 | patience=10 | min_delta=0.0001




[10-07] epoch=001 loss=6.7688 time=1.3s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=0
[10-07] epoch=002 loss=6.5038 time=1.7s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=1
[10-07] epoch=003 loss=6.3187 time=1.6s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=2
[10-07] epoch=004 loss=6.1270 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=3
[10-07] epoch=005 loss=5.9433 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=4
[10-07] epoch=006 loss=5.8104 time=1.3s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=5
[10-07] epoch=007 loss=5.6770 time=1.1s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=6
[10-07] epoch=008 loss=5.5321 time=1.2s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=7
[10-07] epoch=009 loss=5.4174 time=1.1s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=8
[10-07] epoch=010 loss=5.3113 time=1.1s | VAL HR@20=0.010582 | best=0.010582 (epoc

Sanity diagnostics: lengths distribution + hit count explanation

In [10]:
# [CELL 10-07A] Sanity diagnostics: lengths distribution + hit count explanation

lengths_val = make_lengths(target_val["attn_mask"])
print("[10-07A] VAL lengths: min/median/p95/max =",
      int(lengths_val.min()),
      int(lengths_val.median()),
      int(torch.quantile(lengths_val.float(), 0.95).item()),
      int(lengths_val.max()))
print("[10-07A] VAL lengths==0:", int((lengths_val == 0).sum()), "/", int(lengths_val.numel()))

# Convert HR@20 to raw hit counts for interpretability
hr20 = 0.010582
print("[10-07A] HR@20=0.010582 implies hits ≈", hr20 * 189, "out of 189 (should be ~2).")


[10-07A] VAL lengths: min/median/p95/max = 1 3 16 20
[10-07A] VAL lengths==0: 0 / 189
[10-07A] HR@20=0.010582 implies hits ≈ 1.999998 out of 189 (should be ~2).


Final evaluation on TARGET (VAL + TEST) using best weights

Retrain SASRecFixed (new run tag) + early stopping on VAL HR@20

In [11]:
# [CELL 10-07B] Retrain SASRecFixed (new run tag) + early stopping on VAL HR@20

RUN_TAG_FIXED = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[10-07B] RUN_TAG_FIXED:", RUN_TAG_FIXED)

SAS_CFG_FIXED = dict(SAS_CFG)
# keep same cfg for first retry (only masking changed)
print("[10-07B] SAS_CFG_FIXED:", SAS_CFG_FIXED)

set_seed(SAS_CFG_FIXED["seed"])
device = torch.device("cpu")

model_sf = SASRecFixed(
    vocab_size=vocab_size_target,
    max_len=MAX_PREFIX_LEN,
    d_model=SAS_CFG_FIXED["d_model"],
    n_heads=SAS_CFG_FIXED["n_heads"],
    n_layers=SAS_CFG_FIXED["n_layers"],
    d_ff=SAS_CFG_FIXED["d_ff"],
    dropout=SAS_CFG_FIXED["dropout"],
    pad_id=PAD_ID_TARGET,
).to(device)

opt_sf = torch.optim.Adam(
    model_sf.parameters(),
    lr=SAS_CFG_FIXED["lr"],
    weight_decay=SAS_CFG_FIXED["weight_decay"],
)

def eval_sas_fixed(model: nn.Module, data: dict) -> dict:
    model.eval()
    metrics = init_metrics()
    n = 0
    with torch.no_grad():
        for x, am, y in iter_batches(data, batch_size=SAS_CFG_FIXED["batch_size"], shuffle=False):
            logits = model(x, am)
            logits[:, PAD_ID_TARGET] = -1e9
            topk = torch.topk(logits, k=MAX_K, dim=1).indices.cpu().numpy()
            y_np = y.cpu().numpy()
            for i in range(topk.shape[0]):
                yi = int(y_np[i])
                if yi == PAD_ID_TARGET:
                    continue
                pos = np.where(topk[i] == yi)[0]
                rank0 = int(pos[0]) if pos.size > 0 else None
                update_metrics_from_rank(metrics, rank0)
                n += 1
    out = finalize_metrics(metrics, n)
    out["_n_examples"] = int(n)
    return out

best_metric = -1.0
best_epoch = -1
best_state = None
bad_epochs = 0

train_losses_fixed = []
val_history_fixed = []

metric_name = SAS_CFG_FIXED["early_stop_metric"]
patience = int(SAS_CFG_FIXED["patience"])
min_delta = float(SAS_CFG_FIXED["min_delta"])

print(f"[10-07B] Early stopping on {metric_name} | patience={patience} | min_delta={min_delta}")

for epoch in range(1, int(SAS_CFG_FIXED["max_epochs"]) + 1):
    model_sf.train()
    t0 = time.time()
    total_loss = 0.0
    total_n = 0

    for x, am, y in iter_batches(target_train, SAS_CFG_FIXED["batch_size"], shuffle=True):
        opt_sf.zero_grad()
        logits = model_sf(x, am)
        loss = F.cross_entropy(logits, y, ignore_index=PAD_ID_TARGET)
        loss.backward()
        nn.utils.clip_grad_norm_(model_sf.parameters(), SAS_CFG_FIXED["grad_clip"])
        opt_sf.step()

        bs = x.shape[0]
        total_loss += float(loss.item()) * bs
        total_n += bs

    avg_loss = total_loss / max(1, total_n)
    train_losses_fixed.append(avg_loss)

    val_metrics = eval_sas_fixed(model_sf, target_val)
    val_history_fixed.append(val_metrics)

    dt = time.time() - t0
    cur = float(val_metrics.get(metric_name, 0.0))

    improved = (cur > best_metric + min_delta)
    if improved:
        best_metric = cur
        best_epoch = epoch
        best_state = {k: v.detach().cpu().clone() for k, v in model_sf.state_dict().items()}
        bad_epochs = 0
    else:
        bad_epochs += 1

    print(f"[10-07B] epoch={epoch:03d} loss={avg_loss:.4f} time={dt:.1f}s | "
          f"VAL {metric_name}={cur:.6f} | best={best_metric:.6f} (epoch {best_epoch}) | bad_epochs={bad_epochs}")

    if bad_epochs >= patience:
        print(f"[10-07B] ✅ Early stop at epoch={epoch} (best epoch={best_epoch}, best {metric_name}={best_metric:.6f})")
        break

assert best_state is not None
model_sf.load_state_dict(best_state)
print(f"[10-07B] ✅ Restored best weights from epoch={best_epoch} with best {metric_name}={best_metric:.6f}")

print("\n[10-07B] CHECKPOINT E2")
print("Paste: best_epoch + best_metric + last 3 log lines.")


[10-07B] RUN_TAG_FIXED: 20260102_233851
[10-07B] SAS_CFG_FIXED: {'d_model': 64, 'n_heads': 4, 'n_layers': 2, 'd_ff': 256, 'dropout': 0.2, 'batch_size': 256, 'max_epochs': 80, 'lr': 0.001, 'weight_decay': 1e-06, 'grad_clip': 1.0, 'seed': 42, 'early_stop_metric': 'HR@20', 'patience': 10, 'min_delta': 0.0001}
[10-07B] Early stopping on HR@20 | patience=10 | min_delta=0.0001
[10-07B] epoch=001 loss=6.7113 time=1.6s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=0
[10-07B] epoch=002 loss=6.4709 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=1
[10-07B] epoch=003 loss=6.3002 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=2
[10-07B] epoch=004 loss=6.1506 time=1.6s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=3
[10-07B] epoch=005 loss=6.0149 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=4
[10-07B] epoch=006 loss=5.9055 time=1.6s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=5
[10-07B]

In [12]:
# [CELL 10-08] Final evaluation on TARGET (VAL + TEST) using best weights

t_val_sas = eval_sas(model_s, target_val)
t_test_sas = eval_sas(model_s, target_test)

print("[10-08] TARGET VAL (SASRec):", t_val_sas)
print("[10-08] TARGET TEST (SASRec):", t_test_sas)

print("\n[10-08] CHECKPOINT F")
print("Paste TARGET VAL/TEST metrics before writing reports.")


[10-08] TARGET VAL (SASRec): {'HR@5': 0.005291005291005291, 'HR@10': 0.005291005291005291, 'HR@20': 0.010582010582010581, 'MRR@5': 0.005291005291005291, 'MRR@10': 0.005291005291005291, 'MRR@20': 0.005698005698005697, 'NDCG@5': 0.005291005291005291, 'NDCG@10': 0.005291005291005291, 'NDCG@20': 0.006680685370567162, '_n_examples': 189}
[10-08] TARGET TEST (SASRec): {'HR@5': 0.0, 'HR@10': 0.015, 'HR@20': 0.035, 'MRR@5': 0.0, 'MRR@10': 0.0018948412698412697, 'MRR@20': 0.00335454822954823, 'NDCG@5': 0.0, 'NDCG@10': 0.004749141028915217, 'NDCG@20': 0.00990542650333623, '_n_examples': 200}

[10-08] CHECKPOINT F
Paste TARGET VAL/TEST metrics before writing reports.


Final evaluation (VAL + TEST) using SASRecFixed best weights

In [13]:
# [CELL 10-08B] Final evaluation (VAL + TEST) using SASRecFixed best weights

t_val_sas_fixed = eval_sas_fixed(model_sf, target_val)
t_test_sas_fixed = eval_sas_fixed(model_sf, target_test)

print("[10-08B] TARGET VAL (SASRecFixed):", t_val_sas_fixed)
print("[10-08B] TARGET TEST (SASRecFixed):", t_test_sas_fixed)

print("\n[10-08B] CHECKPOINT F2")
print("Paste VAL/TEST metrics. If still flat, we run an overfit-on-256 sanity check next.")


[10-08B] TARGET VAL (SASRecFixed): {'HR@5': 0.005291005291005291, 'HR@10': 0.005291005291005291, 'HR@20': 0.010582010582010581, 'MRR@5': 0.005291005291005291, 'MRR@10': 0.005291005291005291, 'MRR@20': 0.005698005698005697, 'NDCG@5': 0.005291005291005291, 'NDCG@10': 0.005291005291005291, 'NDCG@20': 0.006680685370567162, '_n_examples': 189}
[10-08B] TARGET TEST (SASRecFixed): {'HR@5': 0.0, 'HR@10': 0.005, 'HR@20': 0.035, 'MRR@5': 0.0, 'MRR@10': 0.0007142857142857143, 'MRR@20': 0.002961871461871462, 'NDCG@5': 0.0, 'NDCG@10': 0.0016666666666666666, 'NDCG@20': 0.009467666869343328, '_n_examples': 200}

[10-08B] CHECKPOINT F2
Paste VAL/TEST metrics. If still flat, we run an overfit-on-256 sanity check next.


Write report artifacts to reports/10_sasrec_baseline/<RUN_TAG>/ + update meta.json

In [14]:
# [CELL 10-09] Write report artifacts to reports/10_sasrec_baseline/<RUN_TAG>/ + update meta.json

REPORT_DIR = REPO_ROOT / "reports" / "10_sasrec_baseline" / RUN_TAG
REPORT_DIR.mkdir(parents=True, exist_ok=True)

def save_json(obj: dict, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)

run_meta = {
    "run_tag": RUN_TAG,
    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "inputs": {
        "target_run_tag": TARGET_TAG,
        "source_run_tag": SOURCE_TAG,
        "target_train_pt": str(target_train_pt),
        "target_val_pt": str(target_val_pt),
        "target_test_pt": str(target_test_pt),
        "target_vocab_json": str(target_vocab_json),
        "dataloader_config": str(cfg_path_repo),
        "sanity_metrics": str(sanity_path_repo),
        "session_gap_thresholds": str(gaps_path_repo),
    },
    "protocol_reused_from_06": {
        "K_LIST": K_LIST,
        "MAX_PREFIX_LEN": MAX_PREFIX_LEN,
        "PAD_ID_TARGET": PAD_ID_TARGET,
        "pad_excluded_from_ranking": True,
        "causal_mask": True,
    },
    "model": {
        "name": "SASRec",
        "vocab_size": int(vocab_size_target),
        "d_model": int(SAS_CFG["d_model"]),
        "n_heads": int(SAS_CFG["n_heads"]),
        "n_layers": int(SAS_CFG["n_layers"]),
        "d_ff": int(SAS_CFG["d_ff"]),
        "dropout": float(SAS_CFG["dropout"]),
    },
    "train_cfg": SAS_CFG,
    "early_stopping": {
        "metric": SAS_CFG["early_stop_metric"],
        "best_epoch": int(best_epoch),
        "best_metric": float(best_metric),
        "patience": int(SAS_CFG["patience"]),
        "min_delta": float(SAS_CFG["min_delta"]),
    },
    "notes": [
        "This run trains/evaluates SASRec on TARGET only (Layer-1 baseline).",
        "Source training / transfer is handled later.",
    ],
}

results = {
    "target": {
        "val": t_val_sas,
        "test": t_test_sas,
        "train_losses": train_losses,
        "val_history": val_history,
    },
    "source": None,
}

save_json(run_meta, REPORT_DIR / "run_meta.json")
save_json(results, REPORT_DIR / "results.json")

ckpt = {
    "state_dict": model_s.state_dict(),
    "sas_cfg": SAS_CFG,
    "vocab_size_target": vocab_size_target,
    "pad_id": PAD_ID_TARGET,
    "best_epoch": best_epoch,
    "best_metric": best_metric,
}
torch.save(ckpt, REPORT_DIR / "model.pt")

# Update meta.json
meta_path = REPO_ROOT / "meta.json"
meta = load_json(meta_path) if meta_path.exists() else {"artifacts": {}}

meta.setdefault("artifacts", {})
meta["artifacts"].setdefault("sasrec_baseline", {})
meta["artifacts"]["sasrec_baseline"][RUN_TAG] = {
    "target_run_tag": TARGET_TAG,
    "source_run_tag": SOURCE_TAG,
    "report_dir": str(REPORT_DIR),
    "results_json": str(REPORT_DIR / "results.json"),
    "run_meta_json": str(REPORT_DIR / "run_meta.json"),
    "model_pt": str(REPORT_DIR / "model.pt"),
}
meta["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
save_json(meta, meta_path)

print("[10-09] ✅ Wrote report files under:", REPORT_DIR)
print("[10-09] ✅ Updated meta.json:", meta_path)

print("\n[10-09] CHECKPOINT G")
print("Paste: report dir + confirm meta.json updated.")


[10-09] ✅ Wrote report files under: C:\mooc-coldstart-session-meta\reports\10_sasrec_baseline\20260102_233834
[10-09] ✅ Updated meta.json: C:\mooc-coldstart-session-meta\meta.json

[10-09] CHECKPOINT G
Paste: report dir + confirm meta.json updated.


Footer summary

In [15]:
# [CELL 10-10] Footer summary

print("========== 10 SASRec Baseline Summary ==========")
print("RUN_TAG:", RUN_TAG)
print("--- TARGET ---")
print("VAL :", t_val_sas)
print("TEST:", t_test_sas)
print("Report dir:", REPORT_DIR)
print("================================================")


RUN_TAG: 20260102_233834
--- TARGET ---
VAL : {'HR@5': 0.005291005291005291, 'HR@10': 0.005291005291005291, 'HR@20': 0.010582010582010581, 'MRR@5': 0.005291005291005291, 'MRR@10': 0.005291005291005291, 'MRR@20': 0.005698005698005697, 'NDCG@5': 0.005291005291005291, 'NDCG@10': 0.005291005291005291, 'NDCG@20': 0.006680685370567162, '_n_examples': 189}
TEST: {'HR@5': 0.0, 'HR@10': 0.015, 'HR@20': 0.035, 'MRR@5': 0.0, 'MRR@10': 0.0018948412698412697, 'MRR@20': 0.00335454822954823, 'NDCG@5': 0.0, 'NDCG@10': 0.004749141028915217, 'NDCG@20': 0.00990542650333623, '_n_examples': 200}
Report dir: C:\mooc-coldstart-session-meta\reports\10_sasrec_baseline\20260102_233834


Diagnose padding side (CRITICAL): right-pad vs left-pad

In [16]:
# [CELL 10-10] Diagnose padding side (CRITICAL): right-pad vs left-pad

def detect_pad_side(input_ids: torch.Tensor, attn_mask: torch.Tensor, pad_id: int = 0, n_probe: int = 50):
    # Find rows with some padding
    lengths = attn_mask.sum(dim=1)
    idx = torch.where(lengths < attn_mask.shape[1])[0]
    if idx.numel() == 0:
        return "no_pad_detected"

    idx = idx[:n_probe]
    T = attn_mask.shape[1]
    right_votes = 0
    left_votes = 0

    for i in idx.tolist():
        am = attn_mask[i].cpu().numpy()
        # right-pad => zeros clustered at end
        if am[-1] == 0 and am[0] == 1:
            right_votes += 1
        # left-pad => zeros clustered at start
        if am[0] == 0 and am[-1] == 1:
            left_votes += 1

    if right_votes > left_votes:
        return "right"
    if left_votes > right_votes:
        return "left"
    return f"ambiguous(right={right_votes}, left={left_votes})"

pad_side = detect_pad_side(target_train["input_ids"], target_train["attn_mask"], pad_id=PAD_ID_TARGET)
print("[10-10] pad_side detected:", pad_side)

# Show one short example row for visual confirmation
lengths = make_lengths(target_train["attn_mask"])
short_idx = int(torch.where(lengths < MAX_PREFIX_LEN)[0][0].item())
print("[10-10] Example short row idx:", short_idx, "| len:", int(lengths[short_idx]))
print("[10-10] input_ids:", target_train["input_ids"][short_idx].tolist())
print("[10-10] attn_mask:", target_train["attn_mask"][short_idx].tolist())

print("\n[10-10] CHECKPOINT H")
print("Paste pad_side + the example input_ids/attn_mask row.")


[10-10] pad_side detected: left
[10-10] Example short row idx: 0 | len: 1
[10-10] input_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 416]
[10-10] attn_mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

[10-10] CHECKPOINT H
Paste pad_side + the example input_ids/attn_mask row.


Prediction-collapse check: what IDs does SASRec actually predict?

In [17]:
# [CELL 10-11] Prediction-collapse check: what IDs does SASRec actually predict?

def topk_id_histogram(model, data, n_batches=2, k=20):
    model.eval()
    hist = {}
    with torch.no_grad():
        nb = 0
        for x, am, y in iter_batches(data, batch_size=256, shuffle=False):
            logits = model(x, am)
            logits[:, PAD_ID_TARGET] = -1e9
            topk = torch.topk(logits, k=k, dim=1).indices.cpu().numpy()
            for row in topk:
                for iid in row:
                    hist[int(iid)] = hist.get(int(iid), 0) + 1
            nb += 1
            if nb >= n_batches:
                break
    items = sorted(hist.items(), key=lambda kv: kv[1], reverse=True)[:30]
    return items

print("[10-11] Top predicted IDs in VAL (first 2 batches):")
print(topk_id_histogram(model_s, target_val, n_batches=2, k=20))

print("\n[10-11] CHECKPOINT I")
print("Paste the top predicted IDs list. If it's dominated by 1-3 IDs, it's collapsing.")


[10-11] Top predicted IDs in VAL (first 2 batches):
[(93, 185), (1, 185), (2, 185), (3, 185), (4, 185), (5, 185), (6, 185), (7, 185), (8, 185), (9, 185), (10, 185), (11, 185), (12, 185), (13, 185), (14, 185), (15, 185), (16, 185), (17, 185), (18, 185), (19, 185), (303, 2), (679, 2), (58, 2), (159, 2), (718, 2), (571, 2), (331, 2), (24, 2), (36, 2), (269, 2)]

[10-11] CHECKPOINT I
Paste the top predicted IDs list. If it's dominated by 1-3 IDs, it's collapsing.


SASRecAdaptive: correct last-position extraction based on pad_side

In [20]:
# [CELL 10-12] SASRecAdaptive: correct last-position extraction based on pad_side

class SASRecAdaptive(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        d_ff: int = 256,
        dropout: float = 0.2,
        pad_id: int = 0,
        pad_side: str = "right",   # "right" or "left"
    ):
        super().__init__()
        assert pad_side in ["right", "left"], f"pad_side must be right/left, got {pad_side}"
        self.pad_side = pad_side
        self.pad_id = pad_id
        self.max_len = max_len

        self.item_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        x = self.item_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(self.norm(x))

        causal = torch.triu(torch.ones((T, T), dtype=torch.bool, device=input_ids.device), diagonal=1)
        key_padding = (attn_mask == 0)  # bool

        h = self.encoder(x, mask=causal, src_key_padding_mask=key_padding)  # [B,T,D]

        if self.pad_side == "right":
            lengths = make_lengths(attn_mask)
            last_idx = torch.clamp(lengths - 1, min=0)
        else:
            # left padding => last token is always real (attn_mask[-1]==1)
            last_idx = torch.full((B,), T - 1, device=h.device, dtype=torch.long)

        h_last = h[torch.arange(B, device=h.device), last_idx]
        logits = self.out(h_last)
        return logits

print("[10-12] ✅ SASRecAdaptive defined")


[10-12] ✅ SASRecAdaptive defined


Retrain SASRecAdaptive (new RUN_TAG) + early stopping on VAL HR@20

In [21]:
# [CELL 10-13] Retrain SASRecAdaptive (new RUN_TAG) + early stopping on VAL HR@20

RUN_TAG_ADAPT = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[10-13] RUN_TAG_ADAPT:", RUN_TAG_ADAPT)
print("[10-13] Using pad_side:", pad_side)

model_sa = SASRecAdaptive(
    vocab_size=vocab_size_target,
    max_len=MAX_PREFIX_LEN,
    d_model=SAS_CFG["d_model"],
    n_heads=SAS_CFG["n_heads"],
    n_layers=SAS_CFG["n_layers"],
    d_ff=SAS_CFG["d_ff"],
    dropout=SAS_CFG["dropout"],
    pad_id=PAD_ID_TARGET,
    pad_side=("left" if "left" in str(pad_side) else "right"),
).to(device)

opt_sa = torch.optim.Adam(model_sa.parameters(), lr=SAS_CFG["lr"], weight_decay=SAS_CFG["weight_decay"])

def eval_sas_adapt(model: nn.Module, data: dict) -> dict:
    model.eval()
    metrics = init_metrics()
    n = 0
    with torch.no_grad():
        for x, am, y in iter_batches(data, batch_size=SAS_CFG["batch_size"], shuffle=False):
            logits = model(x, am)
            logits[:, PAD_ID_TARGET] = -1e9
            topk = torch.topk(logits, k=MAX_K, dim=1).indices.cpu().numpy()
            y_np = y.cpu().numpy()
            for i in range(topk.shape[0]):
                yi = int(y_np[i])
                if yi == PAD_ID_TARGET:
                    continue
                pos = np.where(topk[i] == yi)[0]
                rank0 = int(pos[0]) if pos.size > 0 else None
                update_metrics_from_rank(metrics, rank0)
                n += 1
    out = finalize_metrics(metrics, n)
    out["_n_examples"] = int(n)
    return out

best_metric = -1.0
best_epoch = -1
best_state = None
bad_epochs = 0
train_losses_adapt = []
val_history_adapt = []

metric_name = SAS_CFG["early_stop_metric"]
patience = int(SAS_CFG["patience"])
min_delta = float(SAS_CFG["min_delta"])

for epoch in range(1, int(SAS_CFG["max_epochs"]) + 1):
    model_sa.train()
    t0 = time.time()
    total_loss = 0.0
    total_n = 0

    for x, am, y in iter_batches(target_train, SAS_CFG["batch_size"], shuffle=True):
        opt_sa.zero_grad()
        logits = model_sa(x, am)
        loss = F.cross_entropy(logits, y, ignore_index=PAD_ID_TARGET)
        loss.backward()
        nn.utils.clip_grad_norm_(model_sa.parameters(), SAS_CFG["grad_clip"])
        opt_sa.step()

        bs = x.shape[0]
        total_loss += float(loss.item()) * bs
        total_n += bs

    avg_loss = total_loss / max(1, total_n)
    train_losses_adapt.append(avg_loss)

    val_metrics = eval_sas_adapt(model_sa, target_val)
    val_history_adapt.append(val_metrics)

    cur = float(val_metrics.get(metric_name, 0.0))
    improved = (cur > best_metric + min_delta)
    if improved:
        best_metric = cur
        best_epoch = epoch
        best_state = {k: v.detach().cpu().clone() for k, v in model_sa.state_dict().items()}
        bad_epochs = 0
    else:
        bad_epochs += 1

    print(f"[10-13] epoch={epoch:03d} loss={avg_loss:.4f} time={time.time()-t0:.1f}s | "
          f"VAL {metric_name}={cur:.6f} | best={best_metric:.6f} (epoch {best_epoch}) | bad_epochs={bad_epochs}")

    if bad_epochs >= patience:
        print(f"[10-13] ✅ Early stop at epoch={epoch} (best epoch={best_epoch}, best {metric_name}={best_metric:.6f})")
        break

assert best_state is not None
model_sa.load_state_dict(best_state)
print(f"[10-13] ✅ Restored best weights from epoch={best_epoch} with best {metric_name}={best_metric:.6f}")

print("\n[10-13] CHECKPOINT J")
print("Paste: best_epoch + best_metric + last 3 epoch log lines.")


[10-13] RUN_TAG_ADAPT: 20260102_234522
[10-13] Using pad_side: left
[10-13] epoch=001 loss=6.7259 time=1.2s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=0
[10-13] epoch=002 loss=6.3992 time=1.3s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=1
[10-13] epoch=003 loss=6.1698 time=1.4s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=2
[10-13] epoch=004 loss=5.9537 time=1.4s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=3
[10-13] epoch=005 loss=5.7447 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=4
[10-13] epoch=006 loss=5.5472 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=5
[10-13] epoch=007 loss=5.3671 time=1.5s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=6
[10-13] epoch=008 loss=5.1847 time=1.4s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=7
[10-13] epoch=009 loss=5.0143 time=1.2s | VAL HR@20=0.010582 | best=0.010582 (epoch 1) | bad_epochs=8
[10-13] epoch=

Evaluate SASRecAdaptive (VAL + TEST)

In [22]:
# [CELL 10-14] Evaluate SASRecAdaptive (VAL + TEST)

t_val_adapt = eval_sas_adapt(model_sa, target_val)
t_test_adapt = eval_sas_adapt(model_sa, target_test)

print("[10-14] TARGET VAL (SASRecAdaptive):", t_val_adapt)
print("[10-14] TARGET TEST (SASRecAdaptive):", t_test_adapt)

print("\n[10-14] CHECKPOINT K")
print("Paste VAL/TEST metrics. We expect to beat MostPop HR@20 (~0.175) if the bug was padding-side.")


[10-14] TARGET VAL (SASRecAdaptive): {'HR@5': 0.005291005291005291, 'HR@10': 0.005291005291005291, 'HR@20': 0.010582010582010581, 'MRR@5': 0.005291005291005291, 'MRR@10': 0.005291005291005291, 'MRR@20': 0.005698005698005697, 'NDCG@5': 0.005291005291005291, 'NDCG@10': 0.005291005291005291, 'NDCG@20': 0.006680685370567162, '_n_examples': 189}
[10-14] TARGET TEST (SASRecAdaptive): {'HR@5': 0.0, 'HR@10': 0.015, 'HR@20': 0.04, 'MRR@5': 0.0, 'MRR@10': 0.0021031746031746033, 'MRR@20': 0.004017427017427018, 'NDCG@5': 0.0, 'NDCG@10': 0.004952852580526683, 'NDCG@20': 0.011503852783203346, '_n_examples': 200}

[10-14] CHECKPOINT K
Paste VAL/TEST metrics. We expect to beat MostPop HR@20 (~0.175) if the bug was padding-side.


On-the-fly LEFT->RIGHT padding transform (batch)

In [23]:
# [CELL 10-15] On-the-fly LEFT->RIGHT padding transform (batch)

def right_pad_from_left_padded(input_ids: torch.Tensor, attn_mask: torch.Tensor, pad_id: int = 0):
    """
    Convert left-padded [B,T] to right-padded [B,T] while preserving token order.
    Assumes attn_mask is 0/1 and PAD tokens live where attn_mask==0.
    """
    B, T = input_ids.shape
    lengths = attn_mask.sum(dim=1).long()  # [B]
    out_ids = torch.full_like(input_ids, pad_id)
    out_am = torch.zeros_like(attn_mask)

    for i in range(B):
        L = int(lengths[i].item())
        if L <= 0:
            continue
        # For left padding, real tokens are at the end; take the last L tokens
        seq = input_ids[i, T - L :].clone()
        out_ids[i, :L] = seq
        out_am[i, :L] = 1

    return out_ids, out_am, lengths

# quick sanity on a tiny batch
_x = target_train["input_ids"][:4].clone()
_am = target_train["attn_mask"][:4].clone()
_x2, _am2, _len2 = right_pad_from_left_padded(_x, _am, pad_id=PAD_ID_TARGET)

print("[10-15] Example transform check (row0):")
print("  left input_ids :", _x[0].tolist())
print("  left attn_mask :", _am[0].tolist())
print("  right input_ids:", _x2[0].tolist())
print("  right attn_mask:", _am2[0].tolist())
print("  len:", int(_len2[0].item()))

print("\n[10-15] CHECKPOINT L")
print("Paste row0 before/after to confirm left->right pad transform is correct.")


[10-15] Example transform check (row0):
  left input_ids : [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 416]
  left attn_mask : [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
  right input_ids: [416, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  right attn_mask: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  len: 1

[10-15] CHECKPOINT L
Paste row0 before/after to confirm left->right pad transform is correct.


SASRecRightPad: always right-pad inside forward

In [24]:
# [CELL 10-16] SASRecRightPad: always right-pad inside forward

class SASRecRightPad(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        d_ff: int = 256,
        dropout: float = 0.2,
        pad_id: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.pad_id = pad_id

        self.item_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor):
        # Convert to RIGHT-padded so position ids align with token order
        input_ids, attn_mask, lengths = right_pad_from_left_padded(
            input_ids, attn_mask, pad_id=self.pad_id
        )

        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        x = self.item_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(self.norm(x))

        # bool causal + key padding
        causal = torch.triu(torch.ones((T, T), dtype=torch.bool, device=input_ids.device), diagonal=1)
        key_padding = (attn_mask == 0)

        h = self.encoder(x, mask=causal, src_key_padding_mask=key_padding)  # [B,T,D]

        last_idx = torch.clamp(lengths - 1, min=0)  # RIGHT pad => last real token index
        h_last = h[torch.arange(B, device=h.device), last_idx]
        logits = self.out(h_last)
        return logits

print("[10-16] ✅ SASRecRightPad defined")


[10-16] ✅ SASRecRightPad defined


Train SASRecRightPad (NEW RUN TAG) + early stopping on VAL HR@20

In [25]:
# [CELL 10-17] Train SASRecRightPad (NEW RUN TAG) + early stopping on VAL HR@20

RUN_TAG_RPAD = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[10-17] RUN_TAG_RPAD:", RUN_TAG_RPAD)

SAS_CFG_RPAD = dict(SAS_CFG)
# keep same cfg first; we only changed padding semantics
print("[10-17] SAS_CFG_RPAD:", SAS_CFG_RPAD)

set_seed(SAS_CFG_RPAD["seed"])
device = torch.device("cpu")

model_rp = SASRecRightPad(
    vocab_size=vocab_size_target,
    max_len=MAX_PREFIX_LEN,
    d_model=SAS_CFG_RPAD["d_model"],
    n_heads=SAS_CFG_RPAD["n_heads"],
    n_layers=SAS_CFG_RPAD["n_layers"],
    d_ff=SAS_CFG_RPAD["d_ff"],
    dropout=SAS_CFG_RPAD["dropout"],
    pad_id=PAD_ID_TARGET,
).to(device)

opt_rp = torch.optim.Adam(model_rp.parameters(), lr=SAS_CFG_RPAD["lr"], weight_decay=SAS_CFG_RPAD["weight_decay"])

def eval_sas_rp(model: nn.Module, data: dict) -> dict:
    model.eval()
    metrics = init_metrics()
    n = 0
    with torch.no_grad():
        for x, am, y in iter_batches(data, batch_size=SAS_CFG_RPAD["batch_size"], shuffle=False):
            logits = model(x, am)
            logits[:, PAD_ID_TARGET] = -1e9
            topk = torch.topk(logits, k=MAX_K, dim=1).indices.cpu().numpy()
            y_np = y.cpu().numpy()
            for i in range(topk.shape[0]):
                yi = int(y_np[i])
                if yi == PAD_ID_TARGET:
                    continue
                pos = np.where(topk[i] == yi)[0]
                rank0 = int(pos[0]) if pos.size > 0 else None
                update_metrics_from_rank(metrics, rank0)
                n += 1
    out = finalize_metrics(metrics, n)
    out["_n_examples"] = int(n)
    return out

best_metric = -1.0
best_epoch = -1
best_state = None
bad_epochs = 0

train_losses_rp = []
val_history_rp = []

metric_name = SAS_CFG_RPAD["early_stop_metric"]
patience = int(SAS_CFG_RPAD["patience"])
min_delta = float(SAS_CFG_RPAD["min_delta"])

print(f"[10-17] Early stopping on {metric_name} | patience={patience} | min_delta={min_delta}")

for epoch in range(1, int(SAS_CFG_RPAD["max_epochs"]) + 1):
    model_rp.train()
    t0 = time.time()
    total_loss = 0.0
    total_n = 0

    for x, am, y in iter_batches(target_train, SAS_CFG_RPAD["batch_size"], shuffle=True):
        opt_rp.zero_grad()
        logits = model_rp(x, am)
        loss = F.cross_entropy(logits, y, ignore_index=PAD_ID_TARGET)
        loss.backward()
        nn.utils.clip_grad_norm_(model_rp.parameters(), SAS_CFG_RPAD["grad_clip"])
        opt_rp.step()

        bs = x.shape[0]
        total_loss += float(loss.item()) * bs
        total_n += bs

    avg_loss = total_loss / max(1, total_n)
    train_losses_rp.append(avg_loss)

    val_metrics = eval_sas_rp(model_rp, target_val)
    val_history_rp.append(val_metrics)

    cur = float(val_metrics.get(metric_name, 0.0))
    improved = (cur > best_metric + min_delta)
    if improved:
        best_metric = cur
        best_epoch = epoch
        best_state = {k: v.detach().cpu().clone() for k, v in model_rp.state_dict().items()}
        bad_epochs = 0
    else:
        bad_epochs += 1

    print(f"[10-17] epoch={epoch:03d} loss={avg_loss:.4f} time={time.time()-t0:.1f}s | "
          f"VAL {metric_name}={cur:.6f} | best={best_metric:.6f} (epoch {best_epoch}) | bad_epochs={bad_epochs}")

    if bad_epochs >= patience:
        print(f"[10-17] ✅ Early stop at epoch={epoch} (best epoch={best_epoch}, best {metric_name}={best_metric:.6f})")
        break

assert best_state is not None
model_rp.load_state_dict(best_state)
print(f"[10-17] ✅ Restored best weights from epoch={best_epoch} with best {metric_name}={best_metric:.6f}")

print("\n[10-17] CHECKPOINT M")
print("Paste: best_epoch + best_metric + last 3 epoch log lines.")


[10-17] RUN_TAG_RPAD: 20260102_234927
[10-17] SAS_CFG_RPAD: {'d_model': 64, 'n_heads': 4, 'n_layers': 2, 'd_ff': 256, 'dropout': 0.2, 'batch_size': 256, 'max_epochs': 80, 'lr': 0.001, 'weight_decay': 1e-06, 'grad_clip': 1.0, 'seed': 42, 'early_stop_metric': 'HR@20', 'patience': 10, 'min_delta': 0.0001}
[10-17] Early stopping on HR@20 | patience=10 | min_delta=0.0001
[10-17] epoch=001 loss=6.6921 time=1.4s | VAL HR@20=0.111111 | best=0.111111 (epoch 1) | bad_epochs=0
[10-17] epoch=002 loss=6.3934 time=1.4s | VAL HR@20=0.142857 | best=0.142857 (epoch 2) | bad_epochs=0
[10-17] epoch=003 loss=6.1864 time=1.4s | VAL HR@20=0.174603 | best=0.174603 (epoch 3) | bad_epochs=0
[10-17] epoch=004 loss=5.9666 time=1.3s | VAL HR@20=0.201058 | best=0.201058 (epoch 4) | bad_epochs=0
[10-17] epoch=005 loss=5.7746 time=1.5s | VAL HR@20=0.206349 | best=0.206349 (epoch 5) | bad_epochs=0
[10-17] epoch=006 loss=5.6028 time=1.6s | VAL HR@20=0.222222 | best=0.222222 (epoch 6) | bad_epochs=0
[10-17] epoch=007 l

Evaluate SASRecRightPad (VAL + TEST)

In [26]:
# [CELL 10-18] Evaluate SASRecRightPad (VAL + TEST)

t_val_rp = eval_sas_rp(model_rp, target_val)
t_test_rp = eval_sas_rp(model_rp, target_test)

print("[10-18] TARGET VAL (SASRecRightPad):", t_val_rp)
print("[10-18] TARGET TEST (SASRecRightPad):", t_test_rp)

print("\n[10-18] CHECKPOINT N")
print("Paste VAL/TEST metrics. We should now be >= MostPop on HR@20 (val ~0.175).")


[10-18] TARGET VAL (SASRecRightPad): {'HR@5': 0.4973544973544973, 'HR@10': 0.5132275132275133, 'HR@20': 0.5661375661375662, 'MRR@5': 0.4210758377425044, 'MRR@10': 0.42359536407155457, 'MRR@20': 0.42723215364745665, 'NDCG@5': 0.44003450264380367, 'NDCG@10': 0.4455675592975041, 'NDCG@20': 0.45890533525978944, '_n_examples': 189}
[10-18] TARGET TEST (SASRecRightPad): {'HR@5': 0.44, 'HR@10': 0.455, 'HR@20': 0.485, 'MRR@5': 0.39049999999999996, 'MRR@10': 0.3928809523809524, 'MRR@20': 0.3950561773752563, 'NDCG@5': 0.4028467478321102, 'NDCG@10': 0.40807548636985713, 'NDCG@20': 0.41577484972680745, '_n_examples': 200}

[10-18] CHECKPOINT N
Paste VAL/TEST metrics. We should now be >= MostPop on HR@20 (val ~0.175).


Write FIXED report artifacts under reports/10_sasrec_baseline/<RUN_TAG_RPAD>/ + meta.json update

In [27]:
# [CELL 10-19] Write FIXED report artifacts under reports/10_sasrec_baseline/<RUN_TAG_RPAD>/ + meta.json update

REPORT_DIR_RPAD = REPO_ROOT / "reports" / "10_sasrec_baseline" / RUN_TAG_RPAD
REPORT_DIR_RPAD.mkdir(parents=True, exist_ok=True)

def save_json(obj: dict, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)

run_meta_rp = {
    "run_tag": RUN_TAG_RPAD,
    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "inputs": {
        "target_run_tag": TARGET_TAG,
        "source_run_tag": SOURCE_TAG,
        "target_train_pt": str(target_train_pt),
        "target_val_pt": str(target_val_pt),
        "target_test_pt": str(target_test_pt),
        "target_vocab_json": str(target_vocab_json),
        "dataloader_config": str(cfg_path_repo),
        "sanity_metrics": str(sanity_path_repo),
        "session_gap_thresholds": str(gaps_path_repo),
    },
    "protocol_reused_from_06": {
        "K_LIST": K_LIST,
        "MAX_PREFIX_LEN": MAX_PREFIX_LEN,
        "PAD_ID_TARGET": PAD_ID_TARGET,
        "pad_excluded_from_ranking": True,
        "causal_mask": True,
    },
    "model": {
        "name": "SASRecRightPad",
        "vocab_size": int(vocab_size_target),
        "d_model": int(SAS_CFG_RPAD["d_model"]),
        "n_heads": int(SAS_CFG_RPAD["n_heads"]),
        "n_layers": int(SAS_CFG_RPAD["n_layers"]),
        "d_ff": int(SAS_CFG_RPAD["d_ff"]),
        "dropout": float(SAS_CFG_RPAD["dropout"]),
    },
    "train_cfg": SAS_CFG_RPAD,
    "early_stopping": {
        "metric": SAS_CFG_RPAD["early_stop_metric"],
        "best_epoch": int(best_epoch),
        "best_metric": float(best_metric),
        "patience": int(SAS_CFG_RPAD["patience"]),
        "min_delta": float(SAS_CFG_RPAD["min_delta"]),
    },
    "notes": [
        "CRITICAL FIX: target tensors are left-padded; SASRec is evaluated/trained with on-the-fly right-padding to align absolute position embeddings.",
        "This run supersedes earlier SASRec attempts that were effectively broken under left padding.",
    ],
}

results_rp = {
    "target": {
        "val": t_val_rp,
        "test": t_test_rp,
        "train_losses": train_losses_rp,
        "val_history": val_history_rp,
    },
    "source": None,
}

save_json(run_meta_rp, REPORT_DIR_RPAD / "run_meta.json")
save_json(results_rp, REPORT_DIR_RPAD / "results.json")

torch.save(
    {
        "state_dict": model_rp.state_dict(),
        "sas_cfg": SAS_CFG_RPAD,
        "vocab_size_target": vocab_size_target,
        "pad_id": PAD_ID_TARGET,
        "best_epoch": best_epoch,
        "best_metric": best_metric,
        "note": "SASRecRightPad (left->right pad inside forward)",
    },
    REPORT_DIR_RPAD / "model.pt",
)

meta_path = REPO_ROOT / "meta.json"
meta = load_json(meta_path) if meta_path.exists() else {"artifacts": {}}
meta.setdefault("artifacts", {})
meta["artifacts"].setdefault("sasrec_baseline", {})
meta["artifacts"]["sasrec_baseline"][RUN_TAG_RPAD] = {
    "target_run_tag": TARGET_TAG,
    "source_run_tag": SOURCE_TAG,
    "report_dir": str(REPORT_DIR_RPAD),
    "results_json": str(REPORT_DIR_RPAD / "results.json"),
    "run_meta_json": str(REPORT_DIR_RPAD / "run_meta.json"),
    "model_pt": str(REPORT_DIR_RPAD / "model.pt"),
    "note": "Fixed: right-pad-on-the-fly to correct left-padding artifact.",
}
meta["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
save_json(meta, meta_path)

print("[10-19] ✅ Wrote FIXED report files under:", REPORT_DIR_RPAD)
print("[10-19] ✅ Updated meta.json:", meta_path)

print("\n[10-19] CHECKPOINT O")
print("Paste: REPORT_DIR_RPAD + confirm meta.json updated.")


[10-19] ✅ Wrote FIXED report files under: C:\mooc-coldstart-session-meta\reports\10_sasrec_baseline\20260102_234927
[10-19] ✅ Updated meta.json: C:\mooc-coldstart-session-meta\meta.json

[10-19] CHECKPOINT O
Paste: REPORT_DIR_RPAD + confirm meta.json updated.
