Bootstrap: locate repo root (Windows-safe) + env info

In [1]:
# [CELL 06-00] Bootstrap + env info (Windows-safe)

import os, sys, json, time, random
from pathlib import Path

import numpy as np
import pandas as pd
import torch

CWD = Path.cwd().resolve()
print("CWD:", CWD)
print("Python:", sys.version.split()[0])
print("Platform:", sys.platform)
print("torch:", torch.__version__)
print("numpy:", np.__version__)
print("pandas:", pd.__version__)

def find_repo_root(start: Path) -> Path:
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists() or (p / ".git").exists():
            return p
    raise FileNotFoundError("Could not find repo root (PROJECT_STATE.md or .git)")

REPO_ROOT = find_repo_root(CWD)
print("REPO_ROOT:", REPO_ROOT)

DATA_DIR = REPO_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"
print("PROC_DIR:", PROC_DIR)


CWD: C:\mooc-coldstart-session-meta\notebooks
Python: 3.11.14
Platform: win32
torch: 2.9.1+cpu
numpy: 2.4.0
pandas: 2.3.3
REPO_ROOT: C:\mooc-coldstart-session-meta
PROC_DIR: C:\mooc-coldstart-session-meta\data\processed


Config: RUN_TAG + load tensor packs + metadata (PyTorch 2.6+ safe)

In [3]:
# [CELL 06-01] Config (REAL run tags + paths)

# Target tensors (from 05B)
RUN_TAG_TARGET = "20251229_163357"
TENSOR_TGT_DIR = PROC_DIR / "tensor_target"

TGT_TRAIN_PT = TENSOR_TGT_DIR / f"target_tensor_train_{RUN_TAG_TARGET}.pt"
TGT_VAL_PT   = TENSOR_TGT_DIR / f"target_tensor_val_{RUN_TAG_TARGET}.pt"
TGT_TEST_PT  = TENSOR_TGT_DIR / f"target_tensor_test_{RUN_TAG_TARGET}.pt"
TGT_META_JSON= TENSOR_TGT_DIR / f"target_tensor_metadata_{RUN_TAG_TARGET}.json"

# Source session sequences (from 05C)
RUN_TAG_SOURCE = "20251229_232834"
SEQ_SRC_DIR = PROC_DIR / "session_sequences" / f"source_sessions_{RUN_TAG_SOURCE}"

SRC_SEQ_TRAIN_GLOB = SEQ_SRC_DIR / "train" / "sessions_b*.parquet"
SRC_SEQ_VAL_GLOB   = SEQ_SRC_DIR / "val"   / "sessions_b*.parquet"
SRC_SEQ_TEST_GLOB  = SEQ_SRC_DIR / "test"  / "sessions_b*.parquet"

print("RUN_TAG_TARGET:", RUN_TAG_TARGET)
print("RUN_TAG_SOURCE:", RUN_TAG_SOURCE)
print("\nTarget PT exists?",
      TGT_TRAIN_PT.exists(), TGT_VAL_PT.exists(), TGT_TEST_PT.exists(), TGT_META_JSON.exists())
print("Source sequences dir exists?", SEQ_SRC_DIR.exists())
print("Source train glob example:", SRC_SEQ_TRAIN_GLOB)


RUN_TAG_TARGET: 20251229_163357
RUN_TAG_SOURCE: 20251229_232834

Target PT exists? True True True True
Source sequences dir exists? True
Source train glob example: C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\train\sessions_b*.parquet


Protocol config + deterministic seeds

In [4]:
# [CELL 06-02] 

PROTO = {
    "max_prefix_len": 20,          # must match your target tensor max_len
    "source_vocab_mode": "source_local",  # keep source vocab separate from target
    "source_pair_rule": "for t=1..L-1: input=last max_len of items[:t], label=items[t]",
    "source_long_session_policy": {
        "enabled": True,
        "cap_session_len": 200,     # cap very long sessions for speed/stability
        "cap_strategy": "take_last"
    },
    "dataloader": {
        "batch_size_train": 256,
        "batch_size_eval": 512,
        "num_workers": 0,           # Windows-safe; increase later if stable
        "pin_memory": False,
        "shuffle_train": True,
        "drop_last_train": False
    },
    "seeds": {
        "python": 20251229,
        "numpy": 20251229,
        "torch": 20251229
    }
}

def set_all_seeds(seed_py: int, seed_np: int, seed_torch: int):
    random.seed(seed_py)
    np.random.seed(seed_np)
    torch.manual_seed(seed_torch)
    # If later you use CUDA:
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_torch)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_all_seeds(PROTO["seeds"]["python"], PROTO["seeds"]["numpy"], PROTO["seeds"]["torch"])
print("Seeds set:", PROTO["seeds"])
print("PROTO:", json.dumps(PROTO, indent=2)[:500], "...")
print("\nCHECKPOINT 1 ✅ If max_prefix_len != 20 or you want different caps, stop here and adjust.")


Seeds set: {'python': 20251229, 'numpy': 20251229, 'torch': 20251229}
PROTO: {
  "max_prefix_len": 20,
  "source_vocab_mode": "source_local",
  "source_pair_rule": "for t=1..L-1: input=last max_len of items[:t], label=items[t]",
  "source_long_session_policy": {
    "enabled": true,
    "cap_session_len": 200,
    "cap_strategy": "take_last"
  },
  "dataloader": {
    "batch_size_train": 256,
    "batch_size_eval": 512,
    "num_workers": 0,
    "pin_memory": false,
    "shuffle_train": true,
    "drop_last_train": false
  },
  "seeds": {
    "python": 20251229,
    "num ...

CHECKPOINT 1 ✅ If max_prefix_len != 20 or you want different caps, stop here and adjust.


Load target tensors + safe torch.load

In [5]:
# [CELL 06-03] Load target tensors (safe torch.load for torch>=2.6)

from torch.serialization import safe_globals
import numpy as _np

def torch_load_trusted(path: Path):
    # These files are produced by YOU, so safe to load with weights_only=False.
    return torch.load(path, map_location="cpu", weights_only=False)

tgt_train = torch_load_trusted(TGT_TRAIN_PT)
tgt_val   = torch_load_trusted(TGT_VAL_PT)
tgt_test  = torch_load_trusted(TGT_TEST_PT)

print("Target train keys:", list(tgt_train.keys()))
print("train input_ids:", tuple(tgt_train["input_ids"].shape), "labels:", tuple(tgt_train["labels"].shape))
print("val   input_ids:", tuple(tgt_val["input_ids"].shape),   "labels:", tuple(tgt_val["labels"].shape))
print("test  input_ids:", tuple(tgt_test["input_ids"].shape),  "labels:", tuple(tgt_test["labels"].shape))

with open(TGT_META_JSON, "r", encoding="utf-8") as f:
    tgt_meta = json.load(f)

print("\nTarget meta:", {k: tgt_meta.get(k) for k in ["run_tag","max_len","vocab_size","pad_id","unk_id"]})
PAD_ID_TGT = int(tgt_meta["pad_id"])
UNK_ID_TGT = int(tgt_meta["unk_id"])
VOCAB_SIZE_TGT = int(tgt_meta["vocab_size"])

print("\nCHECKPOINT 2 ✅ Paste shapes + meta snippet if anything looks off.")


Target train keys: ['input_ids', 'attn_mask', 'labels', 'session_id', 'user_id', 't', 'split']
train input_ids: (1944, 20) labels: (1944,)
val   input_ids: (189, 20) labels: (189,)
test  input_ids: (200, 20) labels: (200,)

Target meta: {'run_tag': '20251229_163357', 'max_len': 20, 'vocab_size': 747, 'pad_id': 0, 'unk_id': 1}

CHECKPOINT 2 ✅ Paste shapes + meta snippet if anything looks off.


Target TensorDataset + loaders

In [6]:
# [CELL 06-04] Target TensorDataset + loaders

from torch.utils.data import Dataset, DataLoader

class TargetTensorDataset(Dataset):
    def __init__(self, blob: dict):
        self.blob = blob
        self.n = int(blob["input_ids"].shape[0])

    def __len__(self): return self.n

    def __getitem__(self, idx):
        return {
            "input_ids": self.blob["input_ids"][idx],
            "attn_mask": self.blob["attn_mask"][idx],
            "labels": self.blob["labels"][idx],
            "session_id": self.blob["session_id"][idx],
            "user_id": self.blob["user_id"][idx],
            "t": self.blob["t"][idx],
            "split": self.blob["split"][idx],
        }

tgt_train_ds = TargetTensorDataset(tgt_train)
tgt_val_ds   = TargetTensorDataset(tgt_val)
tgt_test_ds  = TargetTensorDataset(tgt_test)

g = torch.Generator()
g.manual_seed(PROTO["seeds"]["torch"])

tgt_train_loader = DataLoader(
    tgt_train_ds,
    batch_size=PROTO["dataloader"]["batch_size_train"],
    shuffle=PROTO["dataloader"]["shuffle_train"],
    num_workers=PROTO["dataloader"]["num_workers"],
    pin_memory=PROTO["dataloader"]["pin_memory"],
    generator=g
)
tgt_val_loader = DataLoader(
    tgt_val_ds,
    batch_size=PROTO["dataloader"]["batch_size_eval"],
    shuffle=False,
    num_workers=0
)
tgt_test_loader = DataLoader(
    tgt_test_ds,
    batch_size=PROTO["dataloader"]["batch_size_eval"],
    shuffle=False,
    num_workers=0
)

b = next(iter(tgt_train_loader))
print("Target batch keys:", list(b.keys()))
print("input_ids:", b["input_ids"].shape, "attn_mask:", b["attn_mask"].shape, "labels:", b["labels"].shape)
print("nonpad lens (first 5):", [int(x) for x in b["attn_mask"].sum(dim=1)[:5].tolist()])
print("\nCHECKPOINT 3 ✅ Paste this batch summary if anything looks wrong.")


Target batch keys: ['input_ids', 'attn_mask', 'labels', 'session_id', 'user_id', 't', 'split']
input_ids: torch.Size([256, 20]) attn_mask: torch.Size([256, 20]) labels: torch.Size([256])
nonpad lens (first 5): [1, 1, 9, 3, 3]

CHECKPOINT 3 ✅ Paste this batch summary if anything looks wrong.


Source: build vocab from TRAIN sequences only (deterministic)

In [7]:
# [CELL 06-05] Build SOURCE vocab from TRAIN sequences (only) to avoid leakage

import pyarrow.parquet as pq
import pyarrow.dataset as ds

MAX_LEN = int(PROTO["max_prefix_len"])
CAP_ENABLED = bool(PROTO["source_long_session_policy"]["enabled"])
CAP_LEN = int(PROTO["source_long_session_policy"]["cap_session_len"])
CAP_STRAT = PROTO["source_long_session_policy"]["cap_strategy"]

src_train_path = str(SRC_SEQ_TRAIN_GLOB)
dataset_train = ds.dataset(str(SEQ_SRC_DIR / "train"), format="parquet")

# We read only the 'items' column to build vocab.
# items is a list column.
item_counts = {}

# Read in fragments (safe)
for frag in dataset_train.get_fragments():
    tbl = frag.to_table(columns=["items"])
    col = tbl["items"].to_pylist()  # list of lists
    for items in col:
        if items is None:
            continue
        if CAP_ENABLED and len(items) > CAP_LEN and CAP_STRAT == "take_last":
            items = items[-CAP_LEN:]
        for it in items:
            item_counts[it] = item_counts.get(it, 0) + 1

# Sort by frequency then item_id for determinism
sorted_items = sorted(item_counts.items(), key=lambda x: (-x[1], str(x[0])))

# IDs: 0 pad, 1 unk, then items
src_item2id = {"<PAD>": 0, "<UNK>": 1}
for it, _ in sorted_items:
    if it not in src_item2id:
        src_item2id[str(it)] = len(src_item2id)

src_id2item = {v:k for k,v in src_item2id.items()}
SRC_PAD_ID = 0
SRC_UNK_ID = 1
SRC_VOCAB_SIZE = len(src_item2id)

print("Source vocab size:", SRC_VOCAB_SIZE)
print("Top items (first 10):", list(sorted_items[:10]))

OUT_SRC_VOCAB = SEQ_SRC_DIR / f"source_vocab_items_{RUN_TAG_SOURCE}.json"
OUT_SRC_VOCAB.write_text(json.dumps({
    "run_tag_source": RUN_TAG_SOURCE,
    "built_from": "train only",
    "vocab_size": SRC_VOCAB_SIZE,
    "pad_id": SRC_PAD_ID,
    "unk_id": SRC_UNK_ID,
    "item2id": src_item2id
}, indent=2), encoding="utf-8")
print("Saved source vocab:", OUT_SRC_VOCAB.resolve())

print("\nCHECKPOINT 4 ✅ Paste vocab size. If this takes too long, we will add a faster path.")


Source vocab size: 1620
Top items (first 10): [('course-v1:TsinghuaX+30640014+2015_T2', 4902340), ('course-v1:TsinghuaX+80512073X+2016_T1', 2435199), ('course-v1:TsinghuaX+80512073X_2015_2+2015_T2', 2119141), ('course-v1:TsinghuaX+30640014X+2016_T1', 2095212), ('course-v1:TsinghuaX+10610224X+2016_T1', 2055980), ('course-v1:TsinghuaX+30640014X+2016_T2', 1774362), ('course-v1:TsinghuaX+10610204X_2015_2+2015_T2', 1740811), ('course-v1:TsinghuaX+10610183_2X+2016_T2', 1687189), ('course-v1:TsinghuaX+10610183X_2015_T2+2015_T2', 1662289), ('course-v1:MITx+6_00_1x+sp', 1332019)]
Saved source vocab: C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\source_vocab_items_20251229_232834.json

CHECKPOINT 4 ✅ Paste vocab size. If this takes too long, we will add a faster path.


Source Dataset: session sequences → supervised pairs ON THE FLY

This dataset enumerates sessions and picks a t position. To avoid building a 146M index, we do a streaming sampler approach for training and a deterministic fixed subset for eval.

We’ll implement:

SourceEvalPairsDataset: deterministic pairs from the first N_SESSIONS_EVAL sessions, all t positions (capped)

SourceTrainIterable: iterable dataset that samples pairs deterministically with seed (no huge index)

In [8]:
# [CELL 06-06] Source datasets (train iterable + eval fixed)

from torch.utils.data import IterableDataset

def encode_prefix(prefix_items, item2id, max_len, pad_id, unk_id):
    ids = [item2id.get(str(x), unk_id) for x in prefix_items]
    if len(ids) >= max_len:
        ids = ids[-max_len:]
    attn = [1] * len(ids)
    # pad left
    pad_n = max_len - len(ids)
    if pad_n > 0:
        ids = [pad_id]*pad_n + ids
        attn = [0]*pad_n + attn
    return torch.tensor(ids, dtype=torch.long), torch.tensor(attn, dtype=torch.long)

class SourceEvalPairsDataset(Dataset):
    """
    Deterministic evaluation dataset:
    - reads a limited number of sessions (N_SESSIONS_EVAL)
    - emits all prefix->next pairs for each session (with optional cap)
    """
    def __init__(self, split_dir: Path, item2id: dict, max_len: int, pad_id: int, unk_id: int,
                 cap_enabled: bool, cap_len: int, cap_strategy: str,
                 n_sessions_eval: int = 20000):
        self.item2id = item2id
        self.max_len = max_len
        self.pad_id = pad_id
        self.unk_id = unk_id
        self.cap_enabled = cap_enabled
        self.cap_len = cap_len
        self.cap_strategy = cap_strategy

        dset = ds.dataset(str(split_dir), format="parquet")
        sessions = []
        for frag in dset.get_fragments():
            tbl = frag.to_table(columns=["session_id","user_id","items"])
            df = tbl.to_pandas()
            for _, r in df.iterrows():
                items = r["items"]
                if items is None or len(items) < 2:
                    continue
                if self.cap_enabled and len(items) > self.cap_len and self.cap_strategy == "take_last":
                    items = items[-self.cap_len:]
                sessions.append((str(r["session_id"]), str(r["user_id"]), [str(x) for x in items]))
                if len(sessions) >= n_sessions_eval:
                    break
            if len(sessions) >= n_sessions_eval:
                break

        # Build flat list of examples (session_idx, t)
        ex = []
        for si, (_, _, items) in enumerate(sessions):
            L = len(items)
            for t in range(1, L):  # t=1..L-1
                ex.append((si, t))
        self.sessions = sessions
        self.examples = ex

        print(f"[SourceEvalPairsDataset] sessions={len(self.sessions):,} examples={len(self.examples):,}")

    def __len__(self): return len(self.examples)

    def __getitem__(self, idx):
        si, t = self.examples[idx]
        session_id, user_id, items = self.sessions[si]
        prefix = items[:t]
        label = items[t]
        x_ids, attn = encode_prefix(prefix, self.item2id, self.max_len, self.pad_id, self.unk_id)
        y = self.item2id.get(str(label), self.unk_id)
        return {
            "input_ids": x_ids,
            "attn_mask": attn,
            "labels": torch.tensor(y, dtype=torch.long),
            "session_id": session_id,
            "user_id": user_id,
            "t": torch.tensor(t, dtype=torch.long),
            "split": "eval"
        }

class SourceTrainIterable(IterableDataset):
    """
    Iterable training stream:
    - reads session rows sequentially
    - yields pairs for each session
    - deterministic order if shuffle_sessions=False
    - optional shuffle using hash(session_id) with seed
    """
    def __init__(self, split_dir: Path, item2id: dict, max_len: int, pad_id: int, unk_id: int,
                 cap_enabled: bool, cap_len: int, cap_strategy: str,
                 seed: int = 20251229, shuffle_sessions: bool = True, max_pairs_per_session: int = 0):
        self.split_dir = split_dir
        self.item2id = item2id
        self.max_len = max_len
        self.pad_id = pad_id
        self.unk_id = unk_id
        self.cap_enabled = cap_enabled
        self.cap_len = cap_len
        self.cap_strategy = cap_strategy
        self.seed = seed
        self.shuffle_sessions = shuffle_sessions
        self.max_pairs_per_session = int(max_pairs_per_session)

    def __iter__(self):
        dset = ds.dataset(str(self.split_dir), format="parquet")
        frags = list(dset.get_fragments())

        # deterministic "shuffle": sort by hash(fragment path + seed)
        if self.shuffle_sessions:
            frags.sort(key=lambda f: hash((str(f.path), self.seed)))

        for frag in frags:
            tbl = frag.to_table(columns=["session_id","user_id","items"])
            df = tbl.to_pandas()

            if self.shuffle_sessions:
                # deterministic per-fragment shuffle
                df["_k"] = df["session_id"].map(lambda s: hash((str(s), self.seed)))
                df = df.sort_values("_k").drop(columns=["_k"])

            for _, r in df.iterrows():
                items = r["items"]
                if items is None or len(items) < 2:
                    continue
                if self.cap_enabled and len(items) > self.cap_len and self.cap_strategy == "take_last":
                    items = items[-self.cap_len:]
                items = [str(x) for x in items]

                L = len(items)
                t_positions = list(range(1, L))

                # Optional cap: sample subset of t positions deterministically
                if self.max_pairs_per_session > 0 and len(t_positions) > self.max_pairs_per_session:
                    # deterministic pick based on session_id hash
                    sid = str(r["session_id"])
                    k = abs(hash((sid, self.seed))) % len(t_positions)
                    # take a window of size max_pairs_per_session
                    t_positions = (t_positions[k:] + t_positions[:k])[:self.max_pairs_per_session]

                for t in t_positions:
                    prefix = items[:t]
                    label = items[t]
                    x_ids, attn = encode_prefix(prefix, self.item2id, self.max_len, self.pad_id, self.unk_id)
                    y = self.item2id.get(str(label), self.unk_id)
                    yield {
                        "input_ids": x_ids,
                        "attn_mask": attn,
                        "labels": torch.tensor(y, dtype=torch.long),
                        "session_id": str(r["session_id"]),
                        "user_id": str(r["user_id"]),
                        "t": torch.tensor(t, dtype=torch.long),
                        "split": "train"
                    }

# Build datasets
src_train_iter = SourceTrainIterable(
    split_dir=SEQ_SRC_DIR / "train",
    item2id=src_item2id,
    max_len=MAX_LEN, pad_id=SRC_PAD_ID, unk_id=SRC_UNK_ID,
    cap_enabled=CAP_ENABLED, cap_len=CAP_LEN, cap_strategy=CAP_STRAT,
    seed=PROTO["seeds"]["torch"], shuffle_sessions=True,
    max_pairs_per_session=0  # set >0 to cap compute
)

src_val_ds = SourceEvalPairsDataset(
    split_dir=SEQ_SRC_DIR / "val",
    item2id=src_item2id,
    max_len=MAX_LEN, pad_id=SRC_PAD_ID, unk_id=SRC_UNK_ID,
    cap_enabled=CAP_ENABLED, cap_len=CAP_LEN, cap_strategy=CAP_STRAT,
    n_sessions_eval=20000
)

src_test_ds = SourceEvalPairsDataset(
    split_dir=SEQ_SRC_DIR / "test",
    item2id=src_item2id,
    max_len=MAX_LEN, pad_id=SRC_PAD_ID, unk_id=SRC_UNK_ID,
    cap_enabled=CAP_ENABLED, cap_len=CAP_LEN, cap_strategy=CAP_STRAT,
    n_sessions_eval=20000
)

print("\nCHECKPOINT 5 ✅ Paste src_val_ds examples count + src_test_ds examples count.")


[SourceEvalPairsDataset] sessions=20,000 examples=331,645
[SourceEvalPairsDataset] sessions=20,000 examples=327,311

CHECKPOINT 5 ✅ Paste src_val_ds examples count + src_test_ds examples count.


Source DataLoaders

In [9]:
# [CELL 06-07] Source loaders

def collate_dict(batch):
    # batch is list of dicts with tensors already padded
    out = {}
    keys = batch[0].keys()
    for k in keys:
        v0 = batch[0][k]
        if torch.is_tensor(v0):
            out[k] = torch.stack([b[k] for b in batch], dim=0)
        else:
            out[k] = [b[k] for b in batch]
    return out

src_train_loader = DataLoader(
    src_train_iter,
    batch_size=PROTO["dataloader"]["batch_size_train"],
    num_workers=0,
    collate_fn=collate_dict
)

src_val_loader = DataLoader(
    src_val_ds,
    batch_size=PROTO["dataloader"]["batch_size_eval"],
    shuffle=False,
    num_workers=0,
    collate_fn=collate_dict
)

src_test_loader = DataLoader(
    src_test_ds,
    batch_size=PROTO["dataloader"]["batch_size_eval"],
    shuffle=False,
    num_workers=0,
    collate_fn=collate_dict
)

b = next(iter(src_train_loader))
print("Source batch keys:", list(b.keys()))
print("input_ids:", b["input_ids"].shape, "attn_mask:", b["attn_mask"].shape, "labels:", b["labels"].shape)
print("nonpad lens (first 5):", [int(x) for x in b["attn_mask"].sum(dim=1)[:5].tolist()])
print("labels (first 5):", b["labels"][:5].tolist())
print("\nCHECKPOINT 6 ✅ Paste this source batch summary if anything looks wrong.")


Source batch keys: ['input_ids', 'attn_mask', 'labels', 'session_id', 'user_id', 't', 'split']
input_ids: torch.Size([256, 20]) attn_mask: torch.Size([256, 20]) labels: torch.Size([256])
nonpad lens (first 5): [1, 2, 3, 4, 5]
labels (first 5): [93, 93, 93, 93, 93]

CHECKPOINT 6 ✅ Paste this source batch summary if anything looks wrong.


Metrics (HR/MRR/NDCG) — shared

In [10]:
# [CELL 06-08] Metrics: HR@K / MRR@K / NDCG@K (batch)

import math

@torch.no_grad()
def batch_metrics_from_scores(scores: torch.Tensor, labels: torch.Tensor, ks=(5,10,20)):
    """
    scores: [B, V] higher is better
    labels: [B] int64
    """
    B, V = scores.shape
    max_k = max(ks)
    topk = torch.topk(scores, k=max_k, dim=1).indices  # [B, max_k]
    out = {}
    for K in ks:
        preds = topk[:, :K]  # [B, K]
        hit = (preds == labels.unsqueeze(1)).any(dim=1).float()
        out[f"HR@{K}"] = float(hit.mean().item())

        # rank (1..K) if hit else inf
        ranks = torch.full((B,), fill_value=0, dtype=torch.long)
        for i in range(B):
            m = (preds[i] == labels[i]).nonzero(as_tuple=False)
            ranks[i] = (m[0,0] + 1) if len(m) else 0

        rr = torch.where(ranks > 0, 1.0 / ranks.float(), torch.zeros_like(ranks, dtype=torch.float))
        out[f"MRR@{K}"] = float(rr.mean().item())

        ndcg = torch.where(ranks > 0, 1.0 / torch.log2(ranks.float() + 1.0), torch.zeros_like(ranks, dtype=torch.float))
        out[f"NDCG@{K}"] = float(ndcg.mean().item())
    return out

@torch.no_grad()
def eval_loader(model_fn, loader, vocab_size: int, ks=(5,10,20), device="cpu", max_batches=None):
    """
    model_fn: function(batch) -> scores [B, V]
    """
    totals = {f"HR@{k}": 0.0 for k in ks}
    totals.update({f"MRR@{k}": 0.0 for k in ks})
    totals.update({f"NDCG@{k}": 0.0 for k in ks})

    n = 0
    bcount = 0
    for batch in loader:
        bcount += 1
        if max_batches and bcount > max_batches:
            break

        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attn_mask"].to(device)
        labels = batch["labels"].to(device)

        scores = model_fn(input_ids, attn_mask)  # [B,V]
        m = batch_metrics_from_scores(scores, labels, ks=ks)
        bs = int(labels.shape[0])
        for k,v in m.items():
            totals[k] += v * bs
        n += bs

    for k in totals:
        totals[k] = totals[k] / max(n, 1)
    return totals, {"rows": n, "batches": bcount}


Sanity models: Random + MostPop on TARGET and SOURCE eval

In [11]:
# [CELL 06-09] Sanity: Random + MostPop

DEVICE = "cpu"
ks = (5,10,20)

# --- TARGET MostPop: compute popularity from target train tensors ---
tgt_train_labels = tgt_train["labels"].numpy().tolist()
pop_counts = {}
for y in tgt_train_labels:
    pop_counts[int(y)] = pop_counts.get(int(y), 0) + 1
tgt_pop_rank = [k for k,_ in sorted(pop_counts.items(), key=lambda x: (-x[1], x[0]))]
tgt_pop_rank = tgt_pop_rank[:VOCAB_SIZE_TGT]

tgt_pop_tensor = torch.tensor(tgt_pop_rank, dtype=torch.long)

def random_model_fn(vocab_size):
    def fn(input_ids, attn_mask):
        B = input_ids.shape[0]
        return torch.rand((B, vocab_size), device=input_ids.device)
    return fn

def mostpop_model_fn(pop_rank_tensor, vocab_size):
    # returns scores with higher for most popular
    def fn(input_ids, attn_mask):
        B = input_ids.shape[0]
        scores = torch.zeros((B, vocab_size), device=input_ids.device)
        # descending scores for ranked items
        # score = -(rank)
        scores[:, pop_rank_tensor.to(scores.device)] = torch.linspace(1.0, 0.0, steps=len(pop_rank_tensor), device=scores.device)
        return scores
    return fn

# Target sanity on VAL
tgt_random_metrics, tgt_random_aux = eval_loader(random_model_fn(VOCAB_SIZE_TGT), tgt_val_loader, VOCAB_SIZE_TGT, ks=ks, device=DEVICE, max_batches=1)
tgt_mostpop_metrics, tgt_mostpop_aux = eval_loader(mostpop_model_fn(tgt_pop_tensor, VOCAB_SIZE_TGT), tgt_val_loader, VOCAB_SIZE_TGT, ks=ks, device=DEVICE, max_batches=1)

print("[TARGET] Random on VAL:", tgt_random_metrics, "|", tgt_random_aux)
print("[TARGET] MostPop on VAL:", tgt_mostpop_metrics, "|", tgt_mostpop_aux)

# --- SOURCE MostPop: compute popularity from source TRAIN sequences (fast approximate: first N fragments) ---
# To keep this notebook responsive, we do an approximate MostPop from first ~N fragments.
# Later, in 07/08 we can compute full pop offline if needed.
SRC_V = SRC_VOCAB_SIZE

src_pop_counts = {}
dataset_train = ds.dataset(str(SEQ_SRC_DIR / "train"), format="parquet")
frag_limit = 50
frag_seen = 0
for frag in dataset_train.get_fragments():
    tbl = frag.to_table(columns=["items"])
    for items in tbl["items"].to_pylist():
        if items is None:
            continue
        if CAP_ENABLED and len(items) > CAP_LEN and CAP_STRAT == "take_last":
            items = items[-CAP_LEN:]
        for it in items:
            sid = src_item2id.get(str(it), SRC_UNK_ID)
            src_pop_counts[sid] = src_pop_counts.get(sid, 0) + 1
    frag_seen += 1
    if frag_seen >= frag_limit:
        break

src_pop_rank = [k for k,_ in sorted(src_pop_counts.items(), key=lambda x: (-x[1], x[0]))]
# ensure includes at least something
if len(src_pop_rank) < 10:
    src_pop_rank = list(range(2, min(SRC_V, 1000)))

src_pop_tensor = torch.tensor(src_pop_rank[:min(len(src_pop_rank), SRC_V)], dtype=torch.long)

src_random_metrics, src_random_aux = eval_loader(random_model_fn(SRC_V), src_val_loader, SRC_V, ks=ks, device=DEVICE, max_batches=1)
src_mostpop_metrics, src_mostpop_aux = eval_loader(mostpop_model_fn(src_pop_tensor, SRC_V), src_val_loader, SRC_V, ks=ks, device=DEVICE, max_batches=1)

print("\n[SOURCE] Random on VAL:", src_random_metrics, "|", src_random_aux)
print("[SOURCE] MostPop(approx) on VAL:", src_mostpop_metrics, "|", src_mostpop_aux)

print("\nCHECKPOINT 7 ✅ Paste the four metric dicts (target random/mostpop + source random/mostpop).")


[TARGET] Random on VAL: {'HR@5': 0.005291005130857229, 'HR@10': 0.01587301678955555, 'HR@20': 0.026455026119947433, 'MRR@5': 0.0026455025654286146, 'MRR@10': 0.0038947677239775658, 'MRR@20': 0.004629629664123058, 'NDCG@5': 0.0033382526598870754, 'NDCG@10': 0.006600130349397659, 'NDCG@20': 0.009275511838495731} | {'rows': 189, 'batches': 1}
[TARGET] MostPop on VAL: {'HR@5': 0.07407407462596893, 'HR@10': 0.13227513432502747, 'HR@20': 0.1746031790971756, 'MRR@5': 0.050264548510313034, 'MRR@10': 0.05870497226715088, 'MRR@20': 0.06130174547433853, 'NDCG@5': 0.05620747059583664, 'NDCG@10': 0.07570020109415054, 'NDCG@20': 0.08597812801599503} | {'rows': 189, 'batches': 1}

[SOURCE] Random on VAL: {'HR@5': 0.001953125, 'HR@10': 0.00390625, 'HR@20': 0.005859375, 'MRR@5': 0.001953125, 'MRR@10': 0.0021484375465661287, 'MRR@20': 0.0022879464086145163, 'NDCG@5': 0.001953125, 'NDCG@10': 0.002517704851925373, 'NDCG@20': 0.00301762274466455} | {'rows': 512, 'batches': 2}
[SOURCE] MostPop(approx) on VA

Save loader protocol config (recreate later)

In [12]:
# [CELL 06-10] Save dataloader protocol config (reproducibility)

OUT_CFG = PROC_DIR / "supervised" / f"dataloader_config_{RUN_TAG_TARGET}_{RUN_TAG_SOURCE}.json"
OUT_CFG.parent.mkdir(parents=True, exist_ok=True)

cfg = {
    "target": {
        "run_tag": RUN_TAG_TARGET,
        "tensor_dir": str(TENSOR_TGT_DIR.resolve()),
        "train_pt": str(TGT_TRAIN_PT.resolve()),
        "val_pt": str(TGT_VAL_PT.resolve()),
        "test_pt": str(TGT_TEST_PT.resolve()),
        "meta_json": str(TGT_META_JSON.resolve()),
    },
    "source": {
        "run_tag": RUN_TAG_SOURCE,
        "seq_dir": str(SEQ_SRC_DIR.resolve()),
        "train_glob": str(SRC_SEQ_TRAIN_GLOB),
        "val_glob": str(SRC_SEQ_VAL_GLOB),
        "test_glob": str(SRC_SEQ_TEST_GLOB),
        "vocab_json": str(OUT_SRC_VOCAB.resolve()),
    },
    "protocol": PROTO,
}

OUT_CFG.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
print("Wrote:", OUT_CFG.resolve())
print("\nDone ✅ 06_data_loader_and_eval_protocol is ready for baselines.")


Wrote: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json

Done ✅ 06_data_loader_and_eval_protocol is ready for baselines.


In [13]:
# [CELL 06-11] Stronger sanity (still fast)

ks = (5,10,20)
DEVICE = "cpu"

src_random_metrics_20, _ = eval_loader(random_model_fn(SRC_VOCAB_SIZE), src_val_loader, SRC_VOCAB_SIZE, ks=ks, device=DEVICE, max_batches=20)
src_mostpop_metrics_20, _ = eval_loader(mostpop_model_fn(src_pop_tensor, SRC_VOCAB_SIZE), src_val_loader, SRC_VOCAB_SIZE, ks=ks, device=DEVICE, max_batches=20)

print("[SOURCE] Random (20 batches):", src_random_metrics_20)
print("[SOURCE] MostPop approx (20 batches):", src_mostpop_metrics_20)


[SOURCE] Random (20 batches): {'HR@5': 0.00244140625, 'HR@10': 0.00625, 'HR@20': 0.0115234375, 'MRR@5': 0.001062825536064338, 'MRR@10': 0.0015507192772929556, 'MRR@20': 0.00190973978897091, 'NDCG@5': 0.0014018987130839378, 'NDCG@10': 0.0026125961798243225, 'NDCG@20': 0.003937902301549912}
[SOURCE] MostPop approx (20 batches): {'HR@5': 0.05693359375, 'HR@10': 0.13525390625, 'HR@20': 0.21279296875, 'MRR@5': 0.027268881801865062, 'MRR@10': 0.037007572944276035, 'MRR@20': 0.04246135508292355, 'NDCG@5': 0.03453330923803151, 'NDCG@10': 0.059153526765294374, 'NDCG@20': 0.07883587318938226}
