Imports + versions

In [2]:
#[CELL 11A-00] Imports + versions

import os, json, time, math, random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

print("[11A-00] torch:", torch.__version__)
print("[11A-00] pandas:", pd.__version__)
print("[11A-00] numpy:", np.__version__)


[11A-00] torch: 2.9.1+cpu
[11A-00] pandas: 2.3.3
[11A-00] numpy: 2.4.0


Repo root + run tag + load protocol JSONs

In [5]:
# [CELL 11A-01] Repo root + run tag + load protocol JSONs
from datetime import datetime

REPO_ROOT = Path(r"C:\mooc-coldstart-session-meta").resolve()
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")

cfg_path = REPO_ROOT / "data/processed/supervised/dataloader_config_20251229_163357_20251229_232834.json"
sanity_path = REPO_ROOT / "data/processed/supervised/sanity_metrics_20251229_163357_20251229_232834.json"
gaps_path = REPO_ROOT / "data/processed/normalized_events/session_gap_thresholds.json"

print("[11A-01] REPO_ROOT:", REPO_ROOT)
print("[11A-01] RUN_TAG:", RUN_TAG)
print("[11A-01] Expect config:", cfg_path)
print("[11A-01] Expect sanity:", sanity_path)
print("[11A-01] Expect gaps:", gaps_path)

with open(cfg_path, "r", encoding="utf-8") as f:
    DL_CFG = json.load(f)
with open(sanity_path, "r", encoding="utf-8") as f:
    SANITY = json.load(f)
with open(gaps_path, "r", encoding="utf-8") as f:
    GAPS = json.load(f)

print("[11A-01] Loaded dataloader_config keys:", list(DL_CFG.keys()))
print("[11A-01] Loaded sanity_metrics keys:", list(SANITY.keys()))
print("[11A-01] Loaded session_gap_thresholds keys:", list(GAPS.keys()))

# --- [PATCH v2] robust gap minutes inference (now supports primary_threshold_seconds) ---

def infer_gap_minutes(d: dict, name: str) -> int:
    """
    Accept multiple schema variants.
    Returns int minutes.
    """
    # explicit primary key in your file
    if "primary_threshold_seconds" in d:
        v = int(d["primary_threshold_seconds"])
        m = int(round(v / 60))
        lbl = d.get("primary_threshold_label", None)
        print(f"[11A-01] {name}: gap_minutes from 'primary_threshold_seconds'={v}s -> {m}m | label={lbl}")
        return m

    # common minute keys
    for k in ["gap_minutes", "session_gap_minutes", "threshold_minutes", "minutes", "gap_min"]:
        if k in d:
            v = int(d[k])
            print(f"[11A-01] {name}: gap_minutes from key '{k}' = {v}")
            return v

    # common second keys
    for k in ["gap_seconds", "session_gap_seconds", "threshold_seconds", "seconds", "gap_sec"]:
        if k in d:
            v = int(d[k])
            m = int(round(v / 60))
            print(f"[11A-01] {name}: gap_minutes inferred from '{k}'={v}s -> {m}m")
            return m

    # sometimes nested threshold object
    if "gap" in d and isinstance(d["gap"], dict):
        return infer_gap_minutes(d["gap"], name + ".gap")

    raise KeyError(f"[11A-01] {name}: Could not infer gap minutes. Keys={list(d.keys())}")

print("[11A-01] target keys:", list(GAPS["target"].keys()))
print("[11A-01] source keys:", list(GAPS["source"].keys()))

gap_target_m = infer_gap_minutes(GAPS["target"], "target")
gap_source_m = infer_gap_minutes(GAPS["source"], "source")

assert gap_target_m == 30, f"target gap mismatch: got {gap_target_m}m"
assert gap_source_m == 10, f"source gap mismatch: got {gap_source_m}m"

print("[11A-01] ✅ Session gaps confirmed:", f"target={gap_target_m}m,", f"source={gap_source_m}m")

print("\n[11A-01] CHECKPOINT A")
print("Paste: inferred minutes + labels (if printed).")


[11A-01] REPO_ROOT: C:\mooc-coldstart-session-meta
[11A-01] RUN_TAG: 20260103_220933
[11A-01] Expect config: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json
[11A-01] Expect sanity: C:\mooc-coldstart-session-meta\data\processed\supervised\sanity_metrics_20251229_163357_20251229_232834.json
[11A-01] Expect gaps: C:\mooc-coldstart-session-meta\data\processed\normalized_events\session_gap_thresholds.json
[11A-01] Loaded dataloader_config keys: ['target', 'source', 'protocol']
[11A-01] Loaded sanity_metrics keys: ['run_tag_target', 'run_tag_source', 'created_at', 'target', 'source', 'notes']
[11A-01] Loaded session_gap_thresholds keys: ['generated_from_run_tag', 'generated_at', 'target', 'source', 'decision_notes']
[11A-01] target keys: ['primary_threshold_seconds', 'primary_threshold_label']
[11A-01] source keys: ['primary_threshold_seconds', 'primary_threshold_label', 'sampling']
[11A-01] target: gap_minutes from 'primary_thr

Paths: source sequences + vocab

In [7]:
# [CELL 11A-02] Paths: source sequences + vocab
SOURCE_RUN_TAG = "20251229_232834"

source_train_dir = REPO_ROOT / f"data/processed/session_sequences/source_sessions_{SOURCE_RUN_TAG}/train"
source_val_dir   = REPO_ROOT / f"data/processed/session_sequences/source_sessions_{SOURCE_RUN_TAG}/val"
source_test_dir  = REPO_ROOT / f"data/processed/session_sequences/source_sessions_{SOURCE_RUN_TAG}/test"
source_vocab_path = REPO_ROOT / f"data/processed/session_sequences/source_sessions_{SOURCE_RUN_TAG}/source_vocab_items_{SOURCE_RUN_TAG}.json"

for p in [source_train_dir, source_val_dir, source_test_dir, source_vocab_path]:
    assert p.exists(), f"Missing: {p}"

train_files = sorted(source_train_dir.glob("*.parquet"))
val_files = sorted(source_val_dir.glob("*.parquet"))
test_files = sorted(source_test_dir.glob("*.parquet"))

print("[11A-02] Source shards counts: train=", len(train_files), "val=", len(val_files), "test=", len(test_files))
print("[11A-02] source_vocab:", source_vocab_path)

print("\n[11A-02] CHECKPOINT B")
print("Confirm all source dirs/files exist and shard counts look like 1024 each.")


[11A-02] Source shards counts: train= 1024 val= 1024 test= 1024
[11A-02] source_vocab: C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\source_vocab_items_20251229_232834.json

[11A-02] CHECKPOINT B
Confirm all source dirs/files exist and shard counts look like 1024 each.


Load source vocab + PAD/UNK + infer seq_col

In [8]:
# [CELL 11A-03] Load source vocab + PAD/UNK + infer seq_col
with open(source_vocab_path, "r", encoding="utf-8") as f:
    source_vocab = json.load(f)

PAD_ID_SOURCE = int(source_vocab.get("pad_id", 0))
UNK_ID_SOURCE = int(source_vocab.get("unk_id", 1))
VOCAB_SIZE_SOURCE = int(source_vocab["vocab_size"])
item2id = source_vocab["item2id"]  # token->id

print("[11A-03] source_vocab keys:", list(source_vocab.keys()))
print("[11A-03] VOCAB_SIZE_SOURCE:", VOCAB_SIZE_SOURCE)
print("[11A-03] PAD_ID_SOURCE:", PAD_ID_SOURCE, "| UNK_ID_SOURCE:", UNK_ID_SOURCE)
print("[11A-03] item2id size:", len(item2id))

probe = pd.read_parquet(train_files[0])
print("[11A-03] Probe columns:", list(probe.columns))

# we already used "items" in 07/08
seq_col = "items" if "items" in probe.columns else None
assert seq_col is not None, "Could not find seq column (expected 'items')."

x0 = probe[seq_col].iloc[0]
print("[11A-03] seq_col:", seq_col, "| first seq type:", type(x0), "| first elem type:", type(x0[0]))

print("\n[11A-03] CHECKPOINT C")
print("Paste: VOCAB_SIZE_SOURCE, PAD/UNK, seq_col, and first element type (should be str).")


[11A-03] source_vocab keys: ['run_tag_source', 'built_from', 'vocab_size', 'pad_id', 'unk_id', 'item2id']
[11A-03] VOCAB_SIZE_SOURCE: 1620
[11A-03] PAD_ID_SOURCE: 0 | UNK_ID_SOURCE: 1
[11A-03] item2id size: 1620
[11A-03] Probe columns: ['domain', 'user_id', 'session_id', 'session_length', 'start_ts', 'end_ts', 'items', 'split']
[11A-03] seq_col: items | first seq type: <class 'numpy.ndarray'> | first elem type: <class 'str'>

[11A-03] CHECKPOINT C
Paste: VOCAB_SIZE_SOURCE, PAD/UNK, seq_col, and first element type (should be str).


Add this new cell: Rebind PROTO safely
Bind + normalize PROTO (handles key name variants safely)

In [12]:
# [CELL 11A-03A] Bind + normalize PROTO from DL_CFG (actual schema)

if "DL_CFG" not in globals():
    cfg_path = REPO_ROOT / "data/processed/supervised/dataloader_config_20251229_163357_20251229_232834.json"
    with open(cfg_path, "r", encoding="utf-8") as f:
        DL_CFG = json.load(f)
    print("[11A-03A] Re-loaded dataloader_config:", cfg_path)

P = DL_CFG["protocol"]
print("[11A-03A] protocol keys:", list(P.keys()))

# 1) MAX_PREFIX_LEN
assert "max_prefix_len" in P, f"[11A-03A] missing max_prefix_len. keys={list(P.keys())}"
MAX_PREFIX_LEN = int(P["max_prefix_len"])

# 2) K_LIST is a fixed decision from Notebook 06 (not stored in config)
K_LIST = [5, 10, 20]

# 3) CAP policy from source_long_session_policy
# Expecting something like: {"enabled": true, "cap_session_len": 200, "strategy": "take_last"} (or similar)
SLSP = P.get("source_long_session_policy", {})
print("[11A-03A] source_long_session_policy keys:", list(SLSP.keys()))

def pick(d, keys, default=None):
    for k in keys:
        if k in d:
            return d[k]
    return default

CAP_ENABLED = bool(pick(SLSP, ["enabled", "cap_enabled", "capEnabled"], default=True))

# Prefer explicit cap length if present, else fall back to fixed decision (200)
CAP_SESSION_LEN = pick(SLSP, ["cap_session_len", "cap_len", "max_len", "capSessionLen"], default=200)
CAP_SESSION_LEN = int(CAP_SESSION_LEN)

# Prefer explicit strategy if present, else fixed decision (take_last)
CAP_STRATEGY = pick(SLSP, ["strategy", "cap_strategy", "capStrategy"], default="take_last")
CAP_STRATEGY = str(CAP_STRATEGY)

# Canonical dict used by remaining cells
PROTO = {
    "K_LIST": K_LIST,
    "MAX_PREFIX_LEN": MAX_PREFIX_LEN,
    "CAP_ENABLED": CAP_ENABLED,
    "CAP_SESSION_LEN": CAP_SESSION_LEN,
    "CAP_STRATEGY": CAP_STRATEGY,
}

print("[11A-03A] ✅ Normalized PROTO:", PROTO)

# Hard asserts (must match Notebook 06 / prior notebooks)
assert PROTO["MAX_PREFIX_LEN"] == 20, f"MAX_PREFIX_LEN drift: {PROTO['MAX_PREFIX_LEN']}"
assert PROTO["CAP_ENABLED"] is True, f"CAP_ENABLED drift: {PROTO['CAP_ENABLED']}"
assert PROTO["CAP_SESSION_LEN"] == 200, f"CAP_SESSION_LEN drift: {PROTO['CAP_SESSION_LEN']}"
assert PROTO["CAP_STRATEGY"] == "take_last", f"CAP_STRATEGY drift: {PROTO['CAP_STRATEGY']}"
assert PROTO["K_LIST"] == [5, 10, 20], f"K_LIST drift: {PROTO['K_LIST']}"

print("[11A-03A] ✅ PROTO asserts passed (matches Notebook 06).")

print("\n[11A-03A] CHECKPOINT D")
print("Paste: protocol keys + source_long_session_policy keys + PROTO.")


[11A-03A] protocol keys: ['max_prefix_len', 'source_vocab_mode', 'source_pair_rule', 'source_long_session_policy', 'dataloader', 'seeds']
[11A-03A] source_long_session_policy keys: ['enabled', 'cap_session_len', 'cap_strategy']
[11A-03A] ✅ Normalized PROTO: {'K_LIST': [5, 10, 20], 'MAX_PREFIX_LEN': 20, 'CAP_ENABLED': True, 'CAP_SESSION_LEN': 200, 'CAP_STRATEGY': 'take_last'}
[11A-03A] ✅ PROTO asserts passed (matches Notebook 06).

[11A-03A] CHECKPOINT D
Paste: protocol keys + source_long_session_policy keys + PROTO.


Streaming batch generator (one-pair-per-session)

In [18]:
# [CELL 11A-04] Streaming pair generator (one-pair-per-session) — SELF-SUFFICIENT
# - Reads parquet shards lazily per file
# - Maps string tokens -> ids using source_vocab["item2id"]
# - Applies CAP policy (take_last, 200) + MAX_PREFIX_LEN (20)
# - Emits one (x_ids, attn_mask, y_id) per session
# - Deterministic sampling for string session_id via stable hash

import time
import random
import hashlib
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# ---- Protocol (must already be bound in [11A-03A]) ----
MAX_LEN = int(PROTO["MAX_PREFIX_LEN"])
CAP_ENABLED = bool(PROTO["CAP_ENABLED"])
CAP_SESSION_LEN = int(PROTO["CAP_SESSION_LEN"])
CAP_STRATEGY = str(PROTO["CAP_STRATEGY"])

assert MAX_LEN == 20
assert CAP_ENABLED is True
assert CAP_SESSION_LEN == 200
assert CAP_STRATEGY == "take_last"

# ---- Ensure vocab mapping exists (rebuild if missing) ----
# Requires: source_vocab loaded in [11A-03] and PAD/UNK ids resolved (or fallback).
if "source_vocab" not in globals():
    raise NameError("[11A-04] source_vocab not found. Run [11A-03] first (loads source_vocab JSON).")

if "source_token_to_id" not in globals():
    # Build from source_vocab schema
    if "item2id" not in source_vocab:
        raise KeyError(f"[11A-04] source_vocab missing 'item2id'. keys={list(source_vocab.keys())}")
    source_token_to_id = source_vocab["item2id"]
    print(f"[11A-04] Built source_token_to_id from source_vocab['item2id'] size={len(source_token_to_id):,}")
else:
    print(f"[11A-04] source_token_to_id already in globals size={len(source_token_to_id):,}")

# Ensure PAD/UNK are present
PAD_ID_SOURCE = int(source_vocab.get("pad_id", 0)) if "PAD_ID_SOURCE" not in globals() else int(PAD_ID_SOURCE)
UNK_ID_SOURCE = int(source_vocab.get("unk_id", 1)) if "UNK_ID_SOURCE" not in globals() else int(UNK_ID_SOURCE)

print("[11A-04] PAD/UNK:", {"PAD_ID_SOURCE": PAD_ID_SOURCE, "UNK_ID_SOURCE": UNK_ID_SOURCE})

# seq_col must be detected in [11A-03]
if "seq_col" not in globals():
    raise NameError("[11A-04] seq_col not found. Run [11A-03] first (detects seq_col).")

# ---- Helpers ----
def stable_mod(value, mod: int) -> int:
    """Deterministic mod for any session_id type (str/int/...)."""
    s = str(value).encode("utf-8")
    h = hashlib.blake2b(s, digest_size=8).digest()
    n = int.from_bytes(h, byteorder="little", signed=False)
    return n % mod

def map_tokens_to_ids(tokens) -> np.ndarray:
    """tokens: iterable of str -> np.int64 ids with UNK fallback."""
    out = np.empty(len(tokens), dtype=np.int64)
    for i, tok in enumerate(tokens):
        out[i] = source_token_to_id.get(str(tok), UNK_ID_SOURCE)
    return out

def session_to_one_pair(seq_tokens):
    """
    seq_tokens: list/np array of tokens (strings)
    Returns (x_ids[int64, MAX_LEN], attn_mask[int64, MAX_LEN], y_id[int]) or None
    """
    seq = seq_tokens
    if not isinstance(seq, (np.ndarray, list, tuple)):
        return None

    # Cap long sessions first
    if CAP_ENABLED and len(seq) > CAP_SESSION_LEN and CAP_STRATEGY == "take_last":
        seq = seq[-CAP_SESSION_LEN:]

    if len(seq) < 2:
        return None

    prefix_tokens = seq[:-1]
    label_token = seq[-1]

    # Cap prefix to MAX_LEN (take_last)
    if len(prefix_tokens) > MAX_LEN:
        prefix_tokens = prefix_tokens[-MAX_LEN:]

    x_ids = map_tokens_to_ids(prefix_tokens)
    y_id = int(source_token_to_id.get(str(label_token), UNK_ID_SOURCE))

    # Left-pad to MAX_LEN
    pad_len = MAX_LEN - len(x_ids)
    if pad_len > 0:
        x_ids = np.concatenate([np.full(pad_len, PAD_ID_SOURCE, dtype=np.int64), x_ids], axis=0)

    attn_mask = (x_ids != PAD_ID_SOURCE).astype(np.int64)

    if y_id == PAD_ID_SOURCE:
        return None

    return x_ids, attn_mask, y_id

def iter_session_pairs_parquet(
    parquet_files,
    seed: int = 42,
    sample_mod: int = 1,
    sample_rem: int = 0,
    log_every_files: int = 50,
):
    """
    Yields (x_ids[int64], attn_mask[int64], y_id[int]) for one pair per session.
    Sampling: keep session if stable_mod(session_id, sample_mod) == sample_rem
    """
    rng = random.Random(seed)
    files = list(parquet_files)
    rng.shuffle(files)

    t0 = time.time()
    n_files = 0
    n_sessions = 0
    n_yield = 0
    n_short = 0
    n_unk_label = 0

    for fpath in files:
        n_files += 1
        df = pd.read_parquet(fpath, columns=["session_id", seq_col])

        for sid, seq in zip(df["session_id"].values, df[seq_col].values):
            n_sessions += 1

            if sample_mod and sample_mod > 1:
                if stable_mod(sid, sample_mod) != sample_rem:
                    continue

            pair = session_to_one_pair(seq)
            if pair is None:
                n_short += 1
                continue

            x_ids, attn_mask, y_id = pair
            if y_id == UNK_ID_SOURCE:
                n_unk_label += 1

            n_yield += 1
            yield torch.from_numpy(x_ids), torch.from_numpy(attn_mask), int(y_id)

        if log_every_files and (n_files % log_every_files == 0):
            elapsed = time.time() - t0
            print(
                f"[11A-04] scanned_files={n_files}/{len(files)} "
                f"sessions_seen={n_sessions:,} yielded={n_yield:,} short={n_short:,} "
                f"unk_labels={n_unk_label:,} elapsed={elapsed:.1f}s"
            )

    elapsed = time.time() - t0
    print(
        f"[11A-04] DONE files={n_files} sessions_seen={n_sessions:,} yielded={n_yield:,} "
        f"short={n_short:,} unk_labels={n_unk_label:,} elapsed={elapsed:.1f}s"
    )

print("[11A-04] ✅ Streaming pair generator ready")
print("[11A-04] stable_mod probe:", stable_mod("3160332::21", 10), stable_mod("3160332::21", 10), "(should match)")

print("\n[11A-04] CHECKPOINT D")
print("Next run [11A-05] to probe 3 yielded pairs (x[:10], mask sum, label).")


[11A-04] Built source_token_to_id from source_vocab['item2id'] size=1,620
[11A-04] PAD/UNK: {'PAD_ID_SOURCE': 0, 'UNK_ID_SOURCE': 1}
[11A-04] ✅ Streaming pair generator ready
[11A-04] stable_mod probe: 9 9 (should match)

[11A-04] CHECKPOINT D
Next run [11A-05] to probe 3 yielded pairs (x[:10], mask sum, label).


 Probe generator output (must look sane)

In [20]:
# [CELL 11A-05] Probe generator output (must look sane)

gen = iter_session_pairs_parquet(train_files[:2], seed=42, sample_mod=10, sample_rem=0, log_every_files=1)
for j in range(3):
    x, m, y = next(gen)
    print(f"[11A-05] sample {j}: x_nonzero={int(m.sum())} label={y}")
    print(" x[:10]=", x[:10].tolist())
    print(" m[:10]=", m[:10].tolist())



print("\n[11A-05] CHECKPOINT E")
print("Paste these 3 probe samples (x_nonzero + label).")


[11A-05] sample 0: x_nonzero=19 label=417
 x[:10]= [0, 133, 213, 251, 251, 133, 251, 251, 251, 251]
 m[:10]= [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[11A-05] sample 1: x_nonzero=20 label=13
 x[:10]= [13, 13, 13, 13, 13, 13, 13, 13, 13, 13]
 m[:10]= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[11A-05] sample 2: x_nonzero=18 label=14
 x[:10]= [0, 0, 14, 14, 14, 14, 14, 14, 14, 14]
 m[:10]= [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]

[11A-05] CHECKPOINT E
Paste these 3 probe samples (x_nonzero + label).


Model choice for 11A

GRU4Rec model (encoder we will transfer later)

In [21]:
# [CELL 11A-06] GRU4Rec model (encoder we will transfer later)
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class GRU4Rec(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, dropout, pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, attn_mask):
        x = self.emb(input_ids)              # [B,L,E]
        x = self.drop(x)
        lengths = attn_mask.sum(dim=1).clamp(min=1).to(torch.int64)  # [B]
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out_packed, _ = self.gru(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True, total_length=input_ids.size(1))
        idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, out.size(-1))
        last_h = out.gather(1, idx).squeeze(1)  # [B,H]
        logits = self.out(last_h)               # [B,V]
        return logits


Train config + batching (stream to minibatches)

In [22]:
# [CELL 11A-07] Train config + batching (stream to minibatches)
PRETRAIN_CFG = {
    "model": "gru4rec",
    "emb_dim": 64,
    "hidden_dim": 128,
    "dropout": 0.3,
    "batch_size": 512,
    "lr": 1e-3,
    "weight_decay": 1e-6,
    "grad_clip": 1.0,
    "seed": 42,

    # streaming controls (compute parity)
    "sample_mod": 5,        # keep 1/5 sessions
    "sample_rem": 0,
    "max_steps_per_epoch": 3000,  # hard cap on updates per epoch
    "max_epochs": 5,
    "log_every_steps": 200,
}

print("[11A-07] PRETRAIN_CFG:", PRETRAIN_CFG)
print("\n[11A-07] CHECKPOINT F")
print("If you want to adjust sample_mod or steps_per_epoch for CPU, do it now.")


[11A-07] PRETRAIN_CFG: {'model': 'gru4rec', 'emb_dim': 64, 'hidden_dim': 128, 'dropout': 0.3, 'batch_size': 512, 'lr': 0.001, 'weight_decay': 1e-06, 'grad_clip': 1.0, 'seed': 42, 'sample_mod': 5, 'sample_rem': 0, 'max_steps_per_epoch': 3000, 'max_epochs': 5, 'log_every_steps': 200}

[11A-07] CHECKPOINT F
If you want to adjust sample_mod or steps_per_epoch for CPU, do it now.


Training loop (source pretrain) + quick val eval hook
For validation, will keep it small but real (cap pairs like we did in 08 for compute), but log it explicitly.

In [26]:
# [CELL 11A-08] Training loop (source pretrain) + quick val eval hook

@torch.no_grad()
def eval_source_pairs(model, files, pair_cap=200_000, seed=42, sample_mod=1, sample_rem=0):
    model.eval()
    K_LIST = PROTO["K_LIST"]
    pad_id = PAD_ID_SOURCE

    # metric accumulators
    hits = {k: 0 for k in K_LIST}
    mrrs = {k: 0.0 for k in K_LIST}
    ndcgs = {k: 0.0 for k in K_LIST}
    n = 0

    gen = iter_session_pairs_parquet(
    files,
    seed=seed,
    sample_mod=sample_mod,
    sample_rem=sample_rem,
    log_every_files=999999
    )

    bs = 512
    buf_x, buf_m, buf_y = [], [], []

    def rank_metrics(logits, y, x_seen_mask=None):
        # logits: [B,V], y: [B]
        # exclude PAD from ranking
        logits[:, pad_id] = -1e9

        # topK
        maxk = max(K_LIST)
        topk = torch.topk(logits, k=maxk, dim=1).indices  # [B,maxk]
        # compute metrics
        out = {k: {"hit":0, "mrr":0.0, "ndcg":0.0} for k in K_LIST}
        for i in range(topk.size(0)):
            yi = int(y[i].item())
            preds = topk[i].tolist()
            for k in K_LIST:
                p = preds[:k]
                if yi in p:
                    r = p.index(yi) + 1
                    out[k]["hit"] += 1
                    out[k]["mrr"] += 1.0 / r
                    out[k]["ndcg"] += 1.0 / math.log2(r + 1)
        return out

    for x, m, y in gen:
        buf_x.append(x); buf_m.append(m); buf_y.append(y)
        if len(buf_x) == bs:
            xb = torch.tensor(np.stack(buf_x), dtype=torch.long)
            mb = torch.tensor(np.stack(buf_m), dtype=torch.long)
            yb = torch.tensor(np.array(buf_y), dtype=torch.long)

            logits = model(xb, mb)
            out = rank_metrics(logits, yb)
            for k in K_LIST:
                hits[k] += out[k]["hit"]
                mrrs[k] += out[k]["mrr"]
                ndcgs[k] += out[k]["ndcg"]
            n += xb.size(0)

            buf_x, buf_m, buf_y = [], [], []
            if pair_cap is not None and n >= pair_cap:
                break

    res = {}
    for k in K_LIST:
        res[f"HR@{k}"] = hits[k] / max(1, n)
        res[f"MRR@{k}"] = mrrs[k] / max(1, n)
        res[f"NDCG@{k}"] = ndcgs[k] / max(1, n)
    res["_n_pairs"] = n
    res["_pair_cap"] = pair_cap
    return res

device = torch.device("cpu")
set_seed(PRETRAIN_CFG["seed"])

model = GRU4Rec(
    vocab_size=VOCAB_SIZE_SOURCE,
    emb_dim=PRETRAIN_CFG["emb_dim"],
    hidden_dim=PRETRAIN_CFG["hidden_dim"],
    dropout=PRETRAIN_CFG["dropout"],
    pad_id=PAD_ID_SOURCE,
).to(device)

opt = torch.optim.Adam(model.parameters(), lr=PRETRAIN_CFG["lr"], weight_decay=PRETRAIN_CFG["weight_decay"])

print("[11A-08] Model params:", sum(p.numel() for p in model.parameters()))
print("[11A-08] Starting pretrain on SOURCE (streamed, capped steps).")

best = {"hr20": -1, "epoch": -1, "state": None}
t0 = time.time()

for epoch in range(1, PRETRAIN_CFG["max_epochs"] + 1):
    model.train()
    step = 0
    losses = []

    gen = iter_session_pairs_parquet(
    train_files,
    seed=PRETRAIN_CFG["seed"] + epoch,
    sample_mod=PRETRAIN_CFG["sample_mod"],
    sample_rem=PRETRAIN_CFG["sample_rem"],
    log_every_files=200
)


    bs = PRETRAIN_CFG["batch_size"]
    buf_x, buf_m, buf_y = [], [], []

    for x, m, y in gen:
        buf_x.append(x); buf_m.append(m); buf_y.append(y)
        if len(buf_x) == bs:
            xb = torch.tensor(np.stack(buf_x), dtype=torch.long)
            mb = torch.tensor(np.stack(buf_m), dtype=torch.long)
            yb = torch.tensor(np.array(buf_y), dtype=torch.long)

            opt.zero_grad(set_to_none=True)
            logits = model(xb, mb)
            loss = F.cross_entropy(logits, yb, ignore_index=PAD_ID_SOURCE)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), PRETRAIN_CFG["grad_clip"])
            opt.step()

            losses.append(float(loss.item()))
            step += 1
            buf_x, buf_m, buf_y = [], [], []

            if (step % PRETRAIN_CFG["log_every_steps"]) == 0:
                print(f"[11A-08] epoch={epoch:02d} step={step} loss={np.mean(losses[-50:]):.4f} elapsed={time.time()-t0:.1f}s")

            if step >= PRETRAIN_CFG["max_steps_per_epoch"]:
                break

    val_res = eval_source_pairs(model, val_files, pair_cap=200_000, seed=42, sample_mod=5, sample_rem=0)
    print(f"[11A-08] epoch={epoch:02d} done | train_loss_mean={np.mean(losses):.4f} | SOURCE_VAL(cap): {val_res}")

    hr20 = val_res["HR@20"]
    if hr20 > best["hr20"]:
        best["hr20"] = hr20
        best["epoch"] = epoch
        best["state"] = {k: v.cpu() for k,v in model.state_dict().items()}

print("[11A-08] ✅ Best epoch:", best["epoch"], "best HR@20:", best["hr20"])
print("\n[11A-08] CHECKPOINT G")
print("Paste: best epoch + val_res dict (HR/MRR/NDCG + _n_pairs/_pair_cap).")


[11A-08] Model params: 387156
[11A-08] Starting pretrain on SOURCE (streamed, capped steps).
[11A-08] epoch=01 step=200 loss=4.5336 elapsed=37.4s
[11A-08] epoch=01 step=400 loss=4.0626 elapsed=76.4s
[11A-04] scanned_files=200/1024 sessions_seen=1,303,802 yielded=260,817 short=0 unk_labels=0 elapsed=96.7s
[11A-08] epoch=01 step=600 loss=3.8752 elapsed=114.4s
[11A-08] epoch=01 step=800 loss=3.8486 elapsed=152.6s
[11A-08] epoch=01 step=1000 loss=3.8282 elapsed=191.3s
[11A-04] scanned_files=400/1024 sessions_seen=2,606,471 yielded=521,582 short=0 unk_labels=0 elapsed=194.1s
[11A-08] epoch=01 step=1200 loss=3.8084 elapsed=227.7s
[11A-08] epoch=01 step=1400 loss=3.7458 elapsed=267.4s
[11A-04] scanned_files=600/1024 sessions_seen=3,910,143 yielded=782,331 short=0 unk_labels=0 elapsed=290.0s
[11A-08] epoch=01 step=1600 loss=3.7681 elapsed=303.0s
[11A-08] epoch=01 step=1800 loss=3.7808 elapsed=343.8s
[11A-08] epoch=01 step=2000 loss=3.7739 elapsed=380.7s
[11A-04] scanned_files=800/1024 sessions

Save pretrained checkpoint + reports + meta.json update

In [27]:
# [CELL 11A-09] Save pretrained checkpoint + reports + meta.json update

report_dir = REPO_ROOT / "reports" / "11A_transfer_pretrain_source" / RUN_TAG
report_dir.mkdir(parents=True, exist_ok=True)

ckpt = {
    "run_tag": RUN_TAG,
    "source_run_tag": SOURCE_RUN_TAG,
    "protocol": PROTO,
    "pretrain_cfg": PRETRAIN_CFG,
    "vocab_size_source": VOCAB_SIZE_SOURCE,
    "pad_id_source": PAD_ID_SOURCE,
    "unk_id_source": UNK_ID_SOURCE,
    "best_epoch": best["epoch"],
    "best_val_hr20_cap": best["hr20"],
    # This is what we will transfer later:
    "transfer_notes": {
        "transferable": ["gru.*"],  # in 11B we will load only GRU weights, not embeddings/out head
        "not_transferable": ["emb.weight", "out.weight", "out.bias"]
    },
    "state_dict": best["state"],
}

ckpt_path = report_dir / "model_pretrained_source.pt"
torch.save(ckpt, ckpt_path)
print("[11A-09] ✅ Saved pretrained checkpoint:", ckpt_path)

# write meta.json update (append run)
meta_path = REPO_ROOT / "meta.json"
meta = {}
if meta_path.exists():
    meta = json.loads(meta_path.read_text(encoding="utf-8"))

meta.setdefault("runs", [])
meta["runs"].append({
    "notebook": "11A_transfer_pretrain_source.ipynb",
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "source_run_tag": SOURCE_RUN_TAG,
    "model": PRETRAIN_CFG["model"],
    "report_dir": str(report_dir),
    "checkpoint": str(ckpt_path),
    "notes": "Source pretraining (streamed one-pair-per-session, capped steps/epoch). Encoder weights intended for transfer; embeddings/head are source-specific."
})

meta_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
print("[11A-09] ✅ Updated meta.json:", meta_path)

print("\n[11A-09] CHECKPOINT H")
print("Paste: report_dir path + confirm meta.json updated.")


[11A-09] ✅ Saved pretrained checkpoint: C:\mooc-coldstart-session-meta\reports\11A_transfer_pretrain_source\20260103_220933\model_pretrained_source.pt
[11A-09] ✅ Updated meta.json: C:\mooc-coldstart-session-meta\meta.json

[11A-09] CHECKPOINT H
Paste: report_dir path + confirm meta.json updated.
