Imports + versions

In [50]:
# [CELL 12B-00] Imports + versions
import os, json, time, math, hashlib
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch

print("[12B-00] torch:", torch.__version__)
print("[12B-00] pandas:", pd.__version__)
print("[12B-00] numpy:", np.__version__)

[12B-00] torch: 2.9.1+cpu
[12B-00] pandas: 2.3.3
[12B-00] numpy: 2.4.0


Repo root (portable) + load single source-of-truth JSONs

In [51]:
# [CELL 12B-01] Repo root (portable) + load single source-of-truth JSONs
def find_repo_root(start: Path) -> Path:
    """Locate the repository root by walking upward from *start*.

    Every directory from *start* up to the filesystem root is checked for
    ``PROJECT_STATE.md`` first; only if none contains it is the same chain
    re-checked for a ``.git`` entry.

    Raises:
        FileNotFoundError: if neither marker exists in *start* or any parent.
    """
    chain = (start, *start.parents)
    # Marker order matters: PROJECT_STATE.md anywhere in the chain wins over
    # a nearer .git directory.
    for marker in ("PROJECT_STATE.md", ".git"):
        found = next((d for d in chain if (d / marker).exists()), None)
        if found is not None:
            return found
    raise FileNotFoundError("Could not locate repo root (PROJECT_STATE.md or .git).")

REPO_ROOT = Path(os.getenv("MOOC_REPO_ROOT", "")).expanduser().resolve() if os.getenv("MOOC_REPO_ROOT") else find_repo_root(Path.cwd().resolve())

RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[12B-01] REPO_ROOT:", str(REPO_ROOT))
print("[12B-01] RUN_TAG:", RUN_TAG)

CFG_PATH = REPO_ROOT / "data/processed/supervised/dataloader_config_20251229_163357_20251229_232834.json"
SANITY_PATH = REPO_ROOT / "data/processed/supervised/sanity_metrics_20251229_163357_20251229_232834.json"
GAPS_PATH = REPO_ROOT / "data/processed/normalized_events/session_gap_thresholds.json"

print("[12B-01] Expect config:", CFG_PATH)
print("[12B-01] Expect sanity:", SANITY_PATH)
print("[12B-01] Expect gaps  :", GAPS_PATH)

assert CFG_PATH.exists(), f"Missing: {CFG_PATH}"
assert SANITY_PATH.exists(), f"Missing: {SANITY_PATH}"
assert GAPS_PATH.exists(), f"Missing: {GAPS_PATH}"

DL_CFG = json.loads(CFG_PATH.read_text(encoding="utf-8"))
SANITY = json.loads(SANITY_PATH.read_text(encoding="utf-8"))
GAPS = json.loads(GAPS_PATH.read_text(encoding="utf-8"))

def infer_gap_minutes(d: dict, name: str) -> int:
    """Return the session-gap threshold of one domain entry, in whole minutes.

    Args:
        d: one entry of the gap-threshold JSON. Either the legacy
            ``gap_minutes`` key or ``primary_threshold_seconds`` is read;
            the legacy key wins when both are present.
        name: label of the entry, used only in error messages.

    Raises:
        KeyError: if neither key is present.
        ValueError: if ``primary_threshold_seconds`` is not a whole number of
            minutes. Flooring it would let e.g. 1830s pass as "30m" and hide
            drift from the equality asserts that follow.
    """
    if "gap_minutes" in d:  # legacy
        return int(d["gap_minutes"])
    if "primary_threshold_seconds" in d:
        sec = int(d["primary_threshold_seconds"])
        minutes, leftover = divmod(sec, 60)
        if leftover:
            raise ValueError(
                f"[12B-01] {name}: primary_threshold_seconds={sec} is not a whole number of minutes"
            )
        return minutes
    raise KeyError(f"[12B-01] {name}: cannot infer gap minutes. keys={list(d.keys())}")

gap_target_m = infer_gap_minutes(GAPS["target"], "target")
gap_source_m = infer_gap_minutes(GAPS["source"], "source")
print(f"[12B-01] target: gap_minutes from primary_threshold_seconds={GAPS['target'].get('primary_threshold_seconds')} -> {gap_target_m}m | label={GAPS['target'].get('primary_threshold_label')}")
print(f"[12B-01] source: gap_minutes from primary_threshold_seconds={GAPS['source'].get('primary_threshold_seconds')} -> {gap_source_m}m | label={GAPS['source'].get('primary_threshold_label')}")

assert gap_target_m == 30, f"target gap mismatch: got {gap_target_m}m"
assert gap_source_m == 10, f"source gap mismatch: got {gap_source_m}m"
print("[12B-01] ✅ Session gaps confirmed.")

print("\n[12B-01] CHECKPOINT A")
print("Paste: inferred gaps lines + confirm asserts passed.")


[12B-01] REPO_ROOT: C:\mooc-coldstart-session-meta
[12B-01] RUN_TAG: 20260104_165117
[12B-01] Expect config: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json
[12B-01] Expect sanity: C:\mooc-coldstart-session-meta\data\processed\supervised\sanity_metrics_20251229_163357_20251229_232834.json
[12B-01] Expect gaps  : C:\mooc-coldstart-session-meta\data\processed\normalized_events\session_gap_thresholds.json
[12B-01] target: gap_minutes from primary_threshold_seconds=1800 -> 30m | label=30m
[12B-01] source: gap_minutes from primary_threshold_seconds=600 -> 10m | label=10m
[12B-01] ✅ Session gaps confirmed.

[12B-01] CHECKPOINT A
Paste: inferred gaps lines + confirm asserts passed.


Load 12A task artifacts (pinned run dir) + basic asserts

In [52]:
# [CELL 12B-02] Load TASK_META + TASK_COV robustly and assert we can proceed

from pathlib import Path
import json

# REPO_ROOT comes from [12B-01]. The pinned 12A run directory is built from it
# so the notebook is not tied to one machine's absolute path.
TASK_RUN_DIR = REPO_ROOT / "reports" / "12A_task_builder_for_meta" / "20260104_141727"

meta_path = TASK_RUN_DIR / "meta_task_config.json"
cov_path  = TASK_RUN_DIR / "task_coverage.json"
assert meta_path.exists(), f"[12B-02] Missing: {meta_path}"
assert cov_path.exists(),  f"[12B-02] Missing: {cov_path}"

TASK_META = json.loads(meta_path.read_text(encoding="utf-8"))
TASK_COV  = json.loads(cov_path.read_text(encoding="utf-8"))

print("[12B-02] REPO_ROOT:", REPO_ROOT)
print("[12B-02] Using TASK_RUN_DIR:", TASK_RUN_DIR)
print("[12B-02] TASK_META keys:", list(TASK_META.keys()))
print("[12B-02] TASK_COV keys :", list(TASK_COV.keys()))

# IMPORTANT: real cfg is nested
TASK_CFG = TASK_META["task_cfg"]
print("[12B-02] TASK_CFG keys :", list(TASK_CFG.keys()))
print("[12B-02] TASK_CFG:", TASK_CFG)

# derive coverage in a defensible way (no guessing)
pairs_kept_per_task = TASK_COV.get("pairs_kept_per_task", {}) or {}
task_unique = int(TASK_COV.get("tasks_total", len(pairs_kept_per_task))) if (TASK_COV.get("tasks_total") is not None or pairs_kept_per_task) else None
top10 = sorted(pairs_kept_per_task.items(), key=lambda kv: kv[1], reverse=True)[:10]

need_per_task = int(TASK_COV.get("need_per_task", TASK_CFG["n_support"] + TASK_CFG["n_query"]))
print("[12B-02] task_unique:", task_unique, "| top10:", top10)
print("[12B-02] need_per_task:", need_per_task, "| n_support+n_query:", int(TASK_CFG["n_support"]) + int(TASK_CFG["n_query"]))

# hard sanity
assert int(TASK_CFG["n_support"]) > 0 and int(TASK_CFG["n_query"]) > 0
assert int(TASK_CFG["n_support"]) + int(TASK_CFG["n_query"]) <= int(TASK_CFG["max_pairs_per_task_buffer"]), \
    "[12B-02] n_support+n_query must be <= max_pairs_per_task_buffer"

min_required = int(TASK_CFG.get("min_tasks_required", 1))
assert task_unique is not None and task_unique >= min_required, \
    f"[12B-02] Not enough tasks: task_unique={task_unique}, min_required={min_required}"

print("\n[12B-02] CHECKPOINT B")
print("Paste: TASK_RUN_DIR + TASK_CFG + TASK_COV keys + task_unique/top10 + need_per_task line.")


[12B-02] REPO_ROOT: C:\mooc-coldstart-session-meta
[12B-02] Using TASK_RUN_DIR: C:\mooc-coldstart-session-meta\reports\12A_task_builder_for_meta\20260104_141727
[12B-02] TASK_META keys: ['run_tag', 'created_at', 'protocol', 'session_gaps', 'task_cfg', 'task_definition', 'source', 'notes']
[12B-02] TASK_COV keys : ['pairs_total_seen', 'pairs_total_kept', 'tasks_total', 'need_per_task', 'pairs_kept_per_task', 'pairs_seen_per_task']
[12B-02] TASK_CFG keys : ['n_support', 'n_query', 'sample_mod', 'sample_rem', 'max_files', 'max_pairs_per_task_buffer', 'seed', 'task_key_mode', 'require_multi_task', 'min_tasks_required']
[12B-02] TASK_CFG: {'n_support': 20, 'n_query': 20, 'sample_mod': 10, 'sample_rem': 0, 'max_files': 50, 'max_pairs_per_task_buffer': 50000, 'seed': 42, 'task_key_mode': 'len_bin', 'require_multi_task': True, 'min_tasks_required': 3}
[12B-02] task_unique: 5 | top10: [('len_21_plus', 7920), ('len_03_05', 7541), ('len_06_10', 6939), ('len_11_20', 6686), ('len_01_02', 3454)]
[12

Load source shard paths + vocab + normalize protocol (same logic as 12A)

In [53]:
# [CELL 12B-03] Resolve SOURCE shard dirs + vocab + protocol (robust: prefer TASK_META paths)

import json
from pathlib import Path
import numpy as np
import pandas as pd

# ---- PROTO: prefer TASK_META (already normalized + asserted in 12B-02) ----
assert "protocol" in TASK_META, f"[12B-03] TASK_META missing protocol. Keys={list(TASK_META.keys())}"
PROTO = TASK_META["protocol"]
print("[12B-03] ✅ PROTO:", PROTO)

# hard asserts (must match earlier notebooks)
assert PROTO["MAX_PREFIX_LEN"] == 20
assert PROTO["CAP_ENABLED"] is True
assert PROTO["CAP_SESSION_LEN"] == 200
assert PROTO["CAP_STRATEGY"] == "take_last"
assert PROTO["K_LIST"] == [5, 10, 20]

K_LIST = list(map(int, PROTO["K_LIST"]))
MAX_LEN = int(PROTO["MAX_PREFIX_LEN"])
CAP_ENABLED = bool(PROTO["CAP_ENABLED"])
CAP_SESSION_LEN = int(PROTO["CAP_SESSION_LEN"])
CAP_STRATEGY = str(PROTO["CAP_STRATEGY"])

# ---- SOURCE paths: prefer TASK_META (12A wrote them) ----
assert "source" in TASK_META and "paths" in TASK_META["source"], (
    f"[12B-03] TASK_META missing source.paths. "
    f"Keys={list(TASK_META.get('source', {}).keys())}"
)
SRC_PATHS = TASK_META["source"]["paths"]

TRAIN_DIR = Path(SRC_PATHS["train_dir"])
VAL_DIR   = Path(SRC_PATHS["val_dir"])
TEST_DIR  = Path(SRC_PATHS["test_dir"])
SRC_VOCAB_PATH = Path(SRC_PATHS["vocab"])

print("[12B-03] TRAIN_DIR:", TRAIN_DIR)
print("[12B-03] VAL_DIR  :", VAL_DIR)
print("[12B-03] TEST_DIR :", TEST_DIR)
print("[12B-03] VOCAB    :", SRC_VOCAB_PATH)

assert TRAIN_DIR.exists(), f"[12B-03] Missing TRAIN_DIR: {TRAIN_DIR}"
assert VAL_DIR.exists(),   f"[12B-03] Missing VAL_DIR: {VAL_DIR}"
assert TEST_DIR.exists(),  f"[12B-03] Missing TEST_DIR: {TEST_DIR}"
assert SRC_VOCAB_PATH.exists(), f"[12B-03] Missing SRC_VOCAB_PATH: {SRC_VOCAB_PATH}"

train_files = sorted(TRAIN_DIR.glob("*.parquet"))
val_files   = sorted(VAL_DIR.glob("*.parquet"))
test_files  = sorted(TEST_DIR.glob("*.parquet"))
print("[12B-03] Source shards:", "train=", len(train_files), "val=", len(val_files), "test=", len(test_files))
assert len(train_files) == 1024 and len(val_files) == 1024 and len(test_files) == 1024

source_vocab = json.loads(SRC_VOCAB_PATH.read_text(encoding="utf-8"))
VOCAB_SIZE_SOURCE = int(source_vocab["vocab_size"])
PAD_ID_SOURCE = int(source_vocab.get("pad_id", 0))
UNK_ID_SOURCE = int(source_vocab.get("unk_id", 1))

print("[12B-03] VOCAB_SIZE_SOURCE:", VOCAB_SIZE_SOURCE)
print("[12B-03] PAD/UNK:", PAD_ID_SOURCE, UNK_ID_SOURCE)

# detect seq_col (parquet schema must match 11A/12A)
probe = pd.read_parquet(train_files[0], columns=None)
cols = list(probe.columns)
assert "items" in cols, f"[12B-03] Expected 'items' col. Got: {cols}"
SEQ_COL = "items"

first_seq = probe[SEQ_COL].iloc[0]
first_elem = first_seq[0] if len(first_seq) > 0 else None
print("[12B-03] seq_col:", SEQ_COL, "| first seq type:", type(first_seq), "| first elem type:", type(first_elem))

print("\n[12B-03] CHECKPOINT C")
print("Paste: shard counts + VOCAB_SIZE_SOURCE/PAD/UNK + seq_col + first elem type.")


[12B-03] ✅ PROTO: {'K_LIST': [5, 10, 20], 'MAX_PREFIX_LEN': 20, 'CAP_ENABLED': True, 'CAP_SESSION_LEN': 200, 'CAP_STRATEGY': 'take_last'}
[12B-03] TRAIN_DIR: C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\train
[12B-03] VAL_DIR  : C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\val
[12B-03] TEST_DIR : C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\test
[12B-03] VOCAB    : C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\source_vocab_items_20251229_232834.json
[12B-03] Source shards: train= 1024 val= 1024 test= 1024
[12B-03] VOCAB_SIZE_SOURCE: 1620
[12B-03] PAD/UNK: 0 1
[12B-03] seq_col: items | first seq type: <class 'numpy.ndarray'> | first elem type: <class 'str'>

[12B-03] CHECKPOINT C
Paste: shard counts + VOCAB_SIZE_SOURCE/PAD/UNK + seq_col + first elem type.


Pair builder + task_key builder + reservoir buffer (no recency bias)
- Uses pyarrow row-group streaming (fast, avoids full shard load)

In [54]:
# [CELL 12B-04] Pair builder + task_key builder + reservoir buffer (no recency bias)
# Uses pyarrow row-group streaming (fast, avoids full shard load)

import hashlib
import pyarrow.parquet as pq

MAX_LEN = int(PROTO["MAX_PREFIX_LEN"])
CAP_ENABLED = bool(PROTO["CAP_ENABLED"])
CAP_SESSION_LEN = int(PROTO["CAP_SESSION_LEN"])
CAP_STRATEGY = str(PROTO["CAP_STRATEGY"])

source_token_to_id = dict(source_vocab["item2id"])
assert len(source_token_to_id) == VOCAB_SIZE_SOURCE

def stable_mod(value, mod: int) -> int:
    """Map *value* to a bucket in ``[0, mod)`` that is stable across runs.

    The value is stringified, hashed with MD5, and the leading 32 bits of the
    digest (big-endian, i.e. the first 8 hex characters) are reduced modulo
    *mod*. Unlike the builtin ``hash``, the result does not depend on the
    interpreter's hash seed.
    """
    digest = hashlib.md5(str(value).encode("utf-8")).digest()
    leading_32_bits = int.from_bytes(digest[:4], "big")
    return leading_32_bits % mod

def session_len_to_task_key(L: int) -> str:
    """Return the length-bin task key for a session of length *L*.

    Bins are closed on the right: <=2, <=5, <=10, <=20, and everything
    larger falls into ``len_21_plus``.
    """
    upper_bounds = (
        (2, "len_01_02"),
        (5, "len_03_05"),
        (10, "len_06_10"),
        (20, "len_11_20"),
    )
    for bound, key in upper_bounds:
        if L <= bound:
            return key
    return "len_21_plus"

def session_to_one_pair(seq_tokens):
    """Turn one session into a single (prefix, mask, label) training pair.

    Args:
        seq_tokens: session tokens as a list or ``np.ndarray``.

    Returns:
        ``None`` when the session has fewer than two tokens or when the last
        token maps to ``UNK_ID_SOURCE``; otherwise ``(x, m, y)`` where ``x``
        is the LEFT-padded id array of length ``MAX_LEN``, ``m`` is the 0/1
        non-pad mask of the same length, and ``y`` is the label id.
    """
    tokens = seq_tokens.tolist() if isinstance(seq_tokens, np.ndarray) else seq_tokens
    if len(tokens) < 2:
        return None

    # Long-session cap: keep only the most recent CAP_SESSION_LEN tokens.
    if CAP_ENABLED and CAP_STRATEGY == "take_last" and len(tokens) > CAP_SESSION_LEN:
        tokens = tokens[-CAP_SESSION_LEN:]

    # The final token is the prediction target; UNK targets are dropped.
    y = int(source_token_to_id.get(tokens[-1], UNK_ID_SOURCE))
    if y == UNK_ID_SOURCE:
        return None

    history = tokens[:-1]
    if len(history) > MAX_LEN:
        history = history[-MAX_LEN:]  # keep the most recent MAX_LEN tokens

    # Unknown prefix tokens are allowed and become UNK_ID_SOURCE; negative ids
    # are likewise folded into UNK_ID_SOURCE.
    ids = []
    for tok in history:
        tok_id = int(source_token_to_id.get(tok, UNK_ID_SOURCE))
        ids.append(tok_id if tok_id >= 0 else UNK_ID_SOURCE)

    # Left-pad so the most recent token sits in the last position.
    offset = MAX_LEN - len(ids)
    x = np.full((MAX_LEN,), PAD_ID_SOURCE, dtype=np.int64)
    x[offset:] = np.asarray(ids, dtype=np.int64)
    m = np.zeros((MAX_LEN,), dtype=np.int64)
    m[offset:] = 1
    return x, m, y

class ReservoirBuffer:
    """
    Fixed-capacity uniform reservoir sample over a stream of items.

    The first ``cap`` items are stored as they arrive. After that, the n-th
    item seen draws an index in ``[0, n)`` and overwrites that slot only if
    the index falls inside the buffer. This avoids the recency bias of
    keeping just the last ``cap`` items.
    """

    def __init__(self, cap: int, seed: int):
        self.cap = int(cap)
        self.rng = np.random.default_rng(int(seed))
        self.items = []  # stored sample, at most `cap` long
        self.seen = 0    # total number of items offered via add()

    def add(self, item):
        """Offer *item* to the reservoir; it may or may not be retained."""
        self.seen += 1
        if len(self.items) >= self.cap:
            # Buffer full: replace a stored item with probability cap/seen.
            slot = self.rng.integers(0, self.seen)
            if slot < self.cap:
                self.items[int(slot)] = item
            return
        self.items.append(item)

def iter_pairs_from_files(files, sample_mod: int, sample_rem: int, max_files=None, seed=42):
    """
    Yields: (task_key, x, m, y)
    Streams row groups via pyarrow for speed + low memory.

    Args:
        files: parquet shard paths, visited in the given order.
        sample_mod, sample_rem: when ``sample_mod`` > 1, only sessions whose
            ``stable_mod(session_id, sample_mod)`` equals ``sample_rem`` are kept.
        max_files: if not None, at most this many shards are read.
        seed: accepted for interface compatibility; sampling here is
            hash-based and deterministic, so no random generator is needed.

    Sessions for which ``session_to_one_pair`` returns None are skipped. The
    task key is derived from the ``session_length`` column.
    """
    file_limit = None if max_files is None else int(max_files)
    subsample = bool(sample_mod) and int(sample_mod) > 1
    if subsample:
        # Convert once instead of on every session.
        mod = int(sample_mod)
        rem = int(sample_rem)

    for file_idx, fp in enumerate(files):
        if file_limit is not None and file_idx >= file_limit:
            break

        pf = pq.ParquetFile(fp)
        for rg in range(pf.num_row_groups):
            table = pf.read_row_group(rg, columns=["session_id", "session_length", SEQ_COL])
            df = table.to_pandas()

            for sid, slen, seq in zip(df["session_id"], df["session_length"], df[SEQ_COL]):
                if subsample and stable_mod(sid, mod) != rem:
                    continue

                pair = session_to_one_pair(seq)
                if pair is None:
                    continue

                x, m, y = pair
                yield session_len_to_task_key(int(slen)), x, m, y

    # no prints here (caller prints)

print("[12B-04] ✅ Pair+task generator ready")
print("[12B-04] stable_mod probe:", stable_mod("3160332::21", 10), stable_mod("3160332::21", 10), "(should match)")

print("\n[12B-04] CHECKPOINT D")
print("Proceed to build task buffers.")


[12B-04] ✅ Pair+task generator ready
[12B-04] stable_mod probe: 3 3 (should match)

[12B-04] CHECKPOINT D
Proceed to build task buffers.


Build task buffers (train + val) with coverage checks

In [55]:
# [CELL 12B-05] Build task buffers (train + val) with coverage checks

TASKS_MIN = int(TASK_CFG.get("min_tasks_required", 3))
REQUIRE_MULTI_TASK = bool(TASK_CFG.get("require_multi_task", True))
N_SUPPORT = int(TASK_CFG["n_support"])
N_QUERY   = int(TASK_CFG["n_query"])
NEED_PER_TASK = N_SUPPORT + N_QUERY
BUF_CAP = int(TASK_CFG["max_pairs_per_task_buffer"])

SAMPLE_MOD = int(TASK_CFG["sample_mod"])
SAMPLE_REM = int(TASK_CFG["sample_rem"])
MAX_FILES  = TASK_CFG.get("max_files", None)

print("[12B-05] NEED_PER_TASK:", NEED_PER_TASK, "| BUF_CAP:", BUF_CAP)
assert NEED_PER_TASK <= BUF_CAP, "[12B-05] n_support+n_query must be <= buffer cap"

def build_task_buffers(files, seed: int, max_files=None):
    """Stream pairs from *files* into one reservoir buffer per task key.

    Args:
        files: parquet shard paths passed on to ``iter_pairs_from_files``.
        seed: base seed. Each task's reservoir gets
            ``seed + stable_mod(task_key, 9973)`` so tasks draw from
            different random streams.
        max_files: optional shard limit passed on to ``iter_pairs_from_files``.

    Returns:
        ``(buffers, counts, seen_pairs)``: task_key -> ReservoirBuffer,
        task_key -> number of pairs currently stored, and the total number of
        pairs produced by the generator.

    The whole stream is always consumed (no early stop), so the reservoirs
    sample from every scanned shard.
    """
    buffers = {}  # task_key -> ReservoirBuffer
    counts = {}   # task_key -> kept
    seen_pairs = 0

    gen = iter_pairs_from_files(
        files,
        sample_mod=SAMPLE_MOD,
        sample_rem=SAMPLE_REM,
        max_files=max_files,
        seed=seed,
    )

    for task_key, x, m, y in gen:
        seen_pairs += 1
        if task_key not in buffers:
            buffers[task_key] = ReservoirBuffer(cap=BUF_CAP, seed=seed + stable_mod(task_key, 9973))
            counts[task_key] = 0

        buffers[task_key].add((x, m, y))
        # The stored length is exactly min(cap, seen) for this task.
        counts[task_key] = len(buffers[task_key].items)

    return buffers, counts, seen_pairs

# Build train and val buffers. Both calls use the same MAX_FILES shard limit
# from TASK_CFG; only the seed differs (42 for train, 1337 for val).
train_buffers, train_counts, train_seen = build_task_buffers(train_files, seed=42, max_files=MAX_FILES)
val_buffers,   val_counts,   val_seen   = build_task_buffers(val_files,   seed=1337, max_files=MAX_FILES)

def summarize_counts(name, counts):
    """Print a per-split task summary and split the tasks by sufficiency.

    Args:
        name: split label used in the printed lines.
        counts: mapping task_key -> number of stored pairs.

    Returns:
        ``(ok, bad)``: task keys with at least ``NEED_PER_TASK`` pairs and
        those with fewer, each ordered by descending count.
    """
    items = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    print(f"[12B-05] {name}: task_unique={len(items)} | top10:", items[:10])

    ok, bad = [], []
    for key, n_pairs in items:
        target = ok if n_pairs >= NEED_PER_TASK else bad
        target.append(key)

    print(f"[12B-05] {name}: tasks_ok={len(ok)} tasks_bad={len(bad)} need={NEED_PER_TASK}")
    return ok, bad

train_ok, train_bad = summarize_counts("TRAIN", train_counts)
val_ok,   val_bad   = summarize_counts("VAL", val_counts)

if REQUIRE_MULTI_TASK:
    assert len(train_ok) >= TASKS_MIN, f"[12B-05] Not enough tasks_ok in TRAIN: {len(train_ok)} < {TASKS_MIN}"
    assert len(val_ok)   >= max(2, TASKS_MIN), f"[12B-05] Not enough tasks_ok in VAL: {len(val_ok)}"

print("\n[12B-05] CHECKPOINT E")
print("Paste: TRAIN/VAL task_unique+top10 + tasks_ok/tasks_bad.")


[12B-05] NEED_PER_TASK: 40 | BUF_CAP: 50000
[12B-05] TRAIN: task_unique=5 | top10: [('len_21_plus', 7968), ('len_03_05', 7722), ('len_11_20', 6794), ('len_06_10', 6749), ('len_01_02', 3431)]
[12B-05] TRAIN: tasks_ok=5 tasks_bad=0 need=40
[12B-05] VAL: task_unique=5 | top10: [('len_03_05', 979), ('len_21_plus', 953), ('len_11_20', 875), ('len_06_10', 833), ('len_01_02', 442)]
[12B-05] VAL: tasks_ok=5 tasks_bad=0 need=40

[12B-05] CHECKPOINT E
Paste: TRAIN/VAL task_unique+top10 + tasks_ok/tasks_bad.


Episode sampler (support/query) + quick probe

In [56]:
# [CELL 12B-06] Episode sampler (support/query) + quick probe

import torch

device = torch.device("cpu")


def sample_episode(buffers, ok_tasks, n_support, n_query, seed):
    """Draw one support/query episode from a randomly chosen task.

    Args:
        buffers: mapping task_key -> object exposing ``.items``, a sequence
            of ``(x, m, y)`` tuples.
        ok_tasks: task keys eligible for sampling.
        n_support, n_query: number of examples in each half of the episode.
        seed: seed for the episode-local random generator.

    Returns:
        ``(task, (sx, sm, sy), (qx, qm, qy))`` with all tensors of dtype
        ``torch.long`` on ``device``. Support and query examples are drawn
        without replacement from the same pool, so they never overlap.
    """
    gen = np.random.default_rng(int(seed))
    task = gen.choice(ok_tasks)
    examples = buffers[task].items

    # One draw without replacement covers both halves of the episode.
    picks = gen.choice(len(examples), size=(n_support + n_query), replace=False)
    support_rows = [examples[i] for i in picks[:n_support]]
    query_rows = [examples[i] for i in picks[n_support:]]

    def _to_tensors(rows):
        # Collate a list of (x, m, y) tuples into three batched tensors.
        xs = np.stack([row[0] for row in rows], axis=0)
        ms = np.stack([row[1] for row in rows], axis=0)
        ys = np.array([row[2] for row in rows], dtype=np.int64)
        return (
            torch.tensor(xs, dtype=torch.long, device=device),
            torch.tensor(ms, dtype=torch.long, device=device),
            torch.tensor(ys, dtype=torch.long, device=device),
        )

    return task, _to_tensors(support_rows), _to_tensors(query_rows)

# probe 3 episodes
for i in range(3):
    task, (sx, sm, sy), (qx, qm, qy) = sample_episode(train_buffers, train_ok, N_SUPPORT, N_QUERY, seed=100+i)
    print(f"[12B-06] episode {i} task={task} | support_nonzero_mean={float(sm.sum(1).float().mean()):.2f} | query_nonzero_mean={float(qm.sum(1).float().mean()):.2f}")
    print("  support shapes:", tuple(sx.shape), tuple(sm.shape), tuple(sy.shape))
    print("  query   shapes:", tuple(qx.shape), tuple(qm.shape), tuple(qy.shape))

print("\n[12B-06] CHECKPOINT F")
print("Paste: the 3 episode lines (task + nonzero means + shapes).")


[12B-06] episode 0 task=len_06_10 | support_nonzero_mean=6.55 | query_nonzero_mean=7.30
  support shapes: (20, 20) (20, 20) (20,)
  query   shapes: (20, 20) (20, 20) (20,)
[12B-06] episode 1 task=len_03_05 | support_nonzero_mean=3.15 | query_nonzero_mean=2.85
  support shapes: (20, 20) (20, 20) (20,)
  query   shapes: (20, 20) (20, 20) (20,)
[12B-06] episode 2 task=len_11_20 | support_nonzero_mean=15.20 | query_nonzero_mean=12.60
  support shapes: (20, 20) (20, 20) (20,)
  query   shapes: (20, 20) (20, 20) (20,)

[12B-06] CHECKPOINT F
Paste: the 3 episode lines (task + nonzero means + shapes).


Model: GRU4RecDropout (same family as 11A/11B) for SOURCE vocab

In [57]:
# [CELL 12B-07] Model: GRU4RecDropout (same family as 11A/11B) for SOURCE vocab

import torch.nn as nn
import torch.nn.functional as F
import copy

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's legacy global RNG, and torch."""
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

class GRU4RecDropout(nn.Module):
    """Single-layer GRU next-item model with dropout on the embeddings.

    Args:
        vocab_size: number of embedding rows and output logits.
        emb_dim: embedding width.
        hidden_dim: GRU hidden width.
        pad_id: padding token id; used as ``padding_idx`` of the embedding
            and to locate padding inside ``forward``.
        dropout: dropout probability applied to the embedded sequence.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int, pad_id: int, dropout: float = 0.3):
        super().__init__()
        self.pad_id = int(pad_id)
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=self.pad_id)
        self.drop = nn.Dropout(float(dropout))
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor):
        """Return next-item logits ``[B, V]`` for a batch of padded prefixes.

        Args:
            input_ids: ``[B, T]`` token ids, left- OR right-padded with ``pad_id``.
            lengths: ``[B]`` number of non-pad tokens per row (>= 1).
        """
        # pack_padded_sequence reads the FIRST `lengths[b]` positions of each
        # row. Left-padded rows would therefore feed padding to the GRU and
        # drop the real tokens. A stable sort on the pad flag moves non-pad
        # tokens to the front while keeping their order; it leaves rows that
        # are already right-padded unchanged.
        pad_flag = (input_ids == self.pad_id).to(torch.long)
        front_first = torch.sort(pad_flag, dim=1, stable=True).indices
        input_ids = torch.gather(input_ids, 1, front_first)

        emb = self.drop(self.emb(input_ids))  # [B,T,E]
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h = self.gru(packed)  # h: [1,B,H]
        logits = self.out(h.squeeze(0))  # [B,V]
        return logits

def make_lengths(attn_mask: torch.Tensor) -> torch.Tensor:
    """Per-row count of set mask entries for a ``[B, T]`` 0/1 mask.

    Counts are floored at 1 so that an all-zero row still yields a usable
    (non-zero) sequence length.
    """
    return torch.clamp(attn_mask.sum(dim=1), min=1)

print("[12B-07] ✅ GRU4RecDropout ready")


[12B-07] ✅ GRU4RecDropout ready


Metrics (HR/MRR/NDCG @ K) on logits (PAD excluded)

In [58]:
# [CELL 12B-08] Metrics (HR/MRR/NDCG @ K) on logits (PAD excluded)

K_LIST = list(map(int, PROTO["K_LIST"]))

@torch.no_grad()
def metrics_from_logits(logits: torch.Tensor, labels: torch.Tensor):
    """Compute HR@K, MRR@K and NDCG@K for every K in ``K_LIST``.

    Args:
        logits: ``[B, V]`` scores. The caller's tensor is not modified; a
            copy is made before the PAD column is masked out.
        labels: ``[B]`` target ids.

    Returns:
        dict with ``HR@k``, ``MRR@k``, ``NDCG@k`` per k (batch means as Python
        floats) plus ``_n_examples``. A label outside the top-k contributes 0
        to all three metrics.
    """
    scores = logits.clone()
    scores[:, PAD_ID_SOURCE] = -1e9  # PAD can never be ranked

    n_examples = labels.shape[0]
    targets = labels.unsqueeze(1)  # [B,1], broadcast against the top-k ids

    res = {}
    for k in K_LIST:
        is_target = torch.topk(scores, k=k, dim=1).indices == targets  # [B,k]
        found = is_target.any(dim=1)
        # 1-based rank of the label within the top-k; only meaningful where
        # `found` is True (argmax returns 0 for all-False rows).
        rank = (is_target.float().argmax(dim=1) + 1).float()
        zeros = torch.zeros_like(rank)

        res[f"HR@{k}"] = found.float().mean().item()
        res[f"MRR@{k}"] = torch.where(found, 1.0 / rank, zeros).mean().item()
        res[f"NDCG@{k}"] = torch.where(found, 1.0 / torch.log2(rank + 1.0), zeros).mean().item()

    res["_n_examples"] = int(n_examples)
    return res

print("[12B-08] ✅ Metrics ready (PAD excluded).")


[12B-08] ✅ Metrics ready (PAD excluded).


Reptile meta-train loop on SOURCE tasks (CPU-friendly)

In [64]:
# [CELL 12B-09] Build task buffers (train/val) + meta-train loop (Reptile-ish) + checkpoints
# Assumes:
#   - DL_CFG loaded in [12B-01]
#   - TASK_META/TASK_CFG/TASK_COV loaded in [12B-02]
# This cell is self-healing for protocol/paths/vocab.

import time, math, json, hashlib
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# 12B-09A) Resolve protocol + source paths + vocab robustly
# ----------------------------
t0 = time.time()
REPO_ROOT = REPO_ROOT if "REPO_ROOT" in globals() else Path.cwd().resolve()
print("[12B-09] REPO_ROOT:", REPO_ROOT)

assert "DL_CFG" in globals() and isinstance(DL_CFG, dict), "[12B-09] DL_CFG not found; run [12B-01]"
assert "TASK_META" in globals() and isinstance(TASK_META, dict), "[12B-09] TASK_META not found; run [12B-02]"
assert "TASK_CFG" in globals() and isinstance(TASK_CFG, dict), "[12B-09] TASK_CFG not found; run [12B-02]"

def pick_first(d: dict, keys: list[str], name: str):
    """Return ``(value, key)`` for the first of *keys* with a truthy value in *d*.

    Keys that are missing or map to a falsy value (empty string, None, 0, ...)
    are skipped.

    Raises:
        KeyError: if no candidate key yields a truthy value; *name* labels
            the missing setting in the message.
    """
    for candidate in keys:
        value = d.get(candidate)
        if value:
            return value, candidate
    raise KeyError(f"[12B-09] Missing '{name}'. Tried {keys}. Available keys={list(d.keys())}")

# Protocol: prefer TASK_META['protocol'] (frozen) then DL_CFG['protocol']
PROTO = None
if "protocol" in TASK_META and isinstance(TASK_META["protocol"], dict):
    PROTO = TASK_META["protocol"]
elif "protocol" in DL_CFG and isinstance(DL_CFG["protocol"], dict):
    PROTO_RAW = DL_CFG["protocol"]
    max_prefix_len, _ = pick_first(PROTO_RAW, ["MAX_PREFIX_LEN", "max_prefix_len", "max_len", "seq_len"], "MAX_PREFIX_LEN")
    slsp = PROTO_RAW.get("source_long_session_policy", {}) or {}
    cap_enabled = bool(slsp.get("enabled", True))
    cap_session_len = int(slsp.get("cap_session_len", 200))
    cap_strategy = str(slsp.get("cap_strategy", "take_last"))
    k_list = PROTO_RAW.get("K_LIST") or PROTO_RAW.get("k_list") or [5, 10, 20]
    PROTO = {
        "K_LIST": list(map(int, k_list)),
        "MAX_PREFIX_LEN": int(max_prefix_len),
        "CAP_ENABLED": bool(cap_enabled),
        "CAP_SESSION_LEN": int(cap_session_len),
        "CAP_STRATEGY": cap_strategy,
    }
else:
    raise KeyError("[12B-09] Could not find protocol in TASK_META or DL_CFG")

print("[12B-09] ✅ PROTO:", PROTO)
assert PROTO["MAX_PREFIX_LEN"] == 20, "[12B-09] Protocol drift: MAX_PREFIX_LEN != 20"
assert PROTO["CAP_ENABLED"] is True, "[12B-09] Protocol drift: CAP_ENABLED != True"
assert PROTO["CAP_SESSION_LEN"] == 200, "[12B-09] Protocol drift: CAP_SESSION_LEN != 200"
assert PROTO["CAP_STRATEGY"] == "take_last", "[12B-09] Protocol drift: CAP_STRATEGY != take_last"
assert PROTO["K_LIST"] == [5, 10, 20], "[12B-09] Protocol drift: K_LIST != [5,10,20]"

MAX_LEN = int(PROTO["MAX_PREFIX_LEN"])
CAP_ENABLED = bool(PROTO["CAP_ENABLED"])
CAP_SESSION_LEN = int(PROTO["CAP_SESSION_LEN"])
CAP_STRATEGY = str(PROTO["CAP_STRATEGY"])
K_LIST = list(map(int, PROTO["K_LIST"]))

# Source paths: your schema has seq_dir + *_glob + vocab_json
SRC = DL_CFG["source"]
seq_dir_raw, _ = pick_first(SRC, ["seq_dir"], "source seq_dir")
seq_dir = Path(seq_dir_raw)

TRAIN_DIR = seq_dir / "train"
VAL_DIR   = seq_dir / "val"
TEST_DIR  = seq_dir / "test"
print(f"[12B-09] Source dirs from seq_dir: train={TRAIN_DIR} | val={VAL_DIR} | test={TEST_DIR}")
assert TRAIN_DIR.exists() and VAL_DIR.exists() and TEST_DIR.exists(), "[12B-09] Missing one of train/val/test dirs"

train_files = sorted(TRAIN_DIR.glob("*.parquet"))
val_files   = sorted(VAL_DIR.glob("*.parquet"))
test_files  = sorted(TEST_DIR.glob("*.parquet"))
print("[12B-09] Source shards:", "train=", len(train_files), "val=", len(val_files), "test=", len(test_files))
assert len(train_files) == 1024 and len(val_files) == 1024 and len(test_files) == 1024, "[12B-09] shard count mismatch"

vocab_raw, vocab_key = pick_first(SRC, ["vocab_json", "vocab_items_path", "vocab"], "source vocab path")
SRC_VOCAB_PATH = Path(vocab_raw)
print(f"[12B-09] Source vocab: {vocab_key}={SRC_VOCAB_PATH}")
assert SRC_VOCAB_PATH.exists(), f"[12B-09] Missing source vocab: {SRC_VOCAB_PATH}"

source_vocab = json.loads(SRC_VOCAB_PATH.read_text(encoding="utf-8"))
VOCAB_SIZE_SOURCE = int(source_vocab["vocab_size"])
PAD_ID_SOURCE = int(source_vocab.get("pad_id", 0))
UNK_ID_SOURCE = int(source_vocab.get("unk_id", 1))
source_token_to_id = source_vocab["item2id"]
print("[12B-09] VOCAB_SIZE_SOURCE:", VOCAB_SIZE_SOURCE, "| PAD/UNK:", PAD_ID_SOURCE, UNK_ID_SOURCE)
assert isinstance(source_token_to_id, dict) and len(source_token_to_id) == VOCAB_SIZE_SOURCE, "[12B-09] item2id invalid"

# Detect seq col (expected)
probe_df = pd.read_parquet(train_files[0])
assert "items" in probe_df.columns, f"[12B-09] Expected 'items' col. Got: {list(probe_df.columns)}"
SEQ_COL = "items"

# ----------------------------
# 12B-09B) Pair construction helpers
# ----------------------------
def cap_session_tokens(seq_tokens: list[str]) -> list[str]:
    """Apply the long-session cap to a token list.

    Returns the input unchanged when capping is disabled or the session is
    already within ``CAP_SESSION_LEN``; otherwise keeps the last
    ``CAP_SESSION_LEN`` tokens.

    Raises:
        ValueError: if a cap is needed but ``CAP_STRATEGY`` is not ``take_last``.
    """
    if not CAP_ENABLED or len(seq_tokens) <= CAP_SESSION_LEN:
        return seq_tokens
    if CAP_STRATEGY != "take_last":
        raise ValueError(f"[12B-09] Unknown CAP_STRATEGY={CAP_STRATEGY}")
    return seq_tokens[-CAP_SESSION_LEN:]

def session_to_one_pair(seq_tokens: list[str]):
    """Build one (prefix, mask, label, prefix_len) example from a session.

    The session is capped first; the last token becomes the label and the
    remaining tokens the prefix.

    Returns:
        ``None`` if the capped session has fewer than two tokens or the label
        maps to ``UNK_ID_SOURCE``; otherwise ``(x, m, y, prefix_len)`` where
        ``x`` / ``m`` are RIGHT-padded int64 arrays of length ``MAX_LEN``
        (ids and 0/1 mask), ``y`` is the label id, and ``prefix_len`` is the
        prefix length before truncation to ``MAX_LEN``.
    """
    tokens = cap_session_tokens(seq_tokens)
    if len(tokens) < 2:
        return None

    history, target = tokens[:-1], tokens[-1]

    x_ids = [source_token_to_id.get(tok, UNK_ID_SOURCE) for tok in history]
    y_id = source_token_to_id.get(target, UNK_ID_SOURCE)

    # Examples whose label is unknown are discarded.
    if y_id == UNK_ID_SOURCE:
        return None

    # Keep only the most recent MAX_LEN prefix ids.
    if len(x_ids) > MAX_LEN:
        x_ids = x_ids[-MAX_LEN:]

    # Right-pad ids and mask up to MAX_LEN (a no-op when already full).
    n_real = len(x_ids)
    n_pad = MAX_LEN - n_real
    mask = [1] * n_real + [0] * n_pad
    x_ids = x_ids + [PAD_ID_SOURCE] * n_pad

    return (
        np.array(x_ids, dtype=np.int64),
        np.array(mask, dtype=np.int64),
        int(y_id),
        int(len(history)),
    )

def stable_mod(value, mod: int) -> int:
    """Deterministically map *value* to a bucket in ``[0, mod)``.

    * ``mod <= 1`` always yields 0.
    * Python / NumPy integers are reduced directly with ``%``.
    * Anything else is stringified and hashed with 8-byte BLAKE2b; the digest
      is read as a little-endian unsigned integer before the reduction.
    """
    if mod <= 1:
        return 0

    if isinstance(value, (int, np.integer)):
        as_int = int(value)
    else:
        digest = hashlib.blake2b(str(value).encode("utf-8"), digest_size=8).digest()
        as_int = int.from_bytes(digest, byteorder="little", signed=False)
    return as_int % mod

# (task key, inclusive lower bound, inclusive upper bound) on prefix length.
LEN_BINS = [
    ("len_01_02", 1, 2),
    ("len_03_05", 3, 5),
    ("len_06_10", 6, 10),
    ("len_11_20", 11, 20),
    ("len_21_plus", 21, 10**9),
]


def task_key_from_len(prefix_len: int) -> str:
    """Return the name of the first ``LEN_BINS`` entry containing *prefix_len*.

    Lengths that fall outside every bin (e.g. 0 or negative values) fall back
    to ``"len_21_plus"``.
    """
    matches = (name for name, lo, hi in LEN_BINS if lo <= prefix_len <= hi)
    return next(matches, "len_21_plus")

def iter_pairs_from_files(files, sample_mod: int, sample_rem: int, seed: int, max_files: int | None = None):
    """Yield ``(task_key, x, m, y)`` for sampled sessions in *files*.

    Args:
        files: parquet shard paths, read one at a time with pandas.
        sample_mod, sample_rem: when ``sample_mod`` > 1, only sessions with
            ``stable_mod(session_id, sample_mod) == sample_rem`` are used.
        seed: accepted for interface compatibility; not used by this function.
        max_files: if not None, at most this many shards are read.

    Progress is printed every 10 scanned files and once at the end. In those
    lines ``short`` counts sessions with fewer than two tokens and
    ``unk_labels`` counts sessions dropped because their label is unknown.
    """
    n_files = 0
    n_sessions = 0
    yielded = 0
    short = 0
    unk_labels = 0
    unk_tokens_total = 0
    t_start = time.time()

    for fp in files:
        # Check the limit BEFORE counting so the final "files=" figure is the
        # number of shards actually read.
        if max_files is not None and n_files >= max_files:
            break
        n_files += 1

        df = pd.read_parquet(fp, columns=["session_id", "session_length", SEQ_COL])
        for sid, slen, seq in zip(df["session_id"].tolist(), df["session_length"].tolist(), df[SEQ_COL].tolist()):
            n_sessions += 1

            if sample_mod > 1 and stable_mod(sid, sample_mod) != sample_rem:
                continue

            tokens = list(seq)
            pair = session_to_one_pair(tokens)
            if pair is None:
                # session_to_one_pair returns None both for too-short sessions
                # and for UNK labels; tell them apart here so the two counters
                # mean what their names say.
                if len(tokens) < 2:
                    short += 1
                else:
                    unk_labels += 1
                continue

            x, m, y, prefix_len = pair
            unk_tokens_total += int((x == UNK_ID_SOURCE).sum())
            if y == UNK_ID_SOURCE:
                # Defensive: never yield an UNK label even if the pair builder
                # stops filtering them.
                unk_labels += 1
                continue

            task_key = task_key_from_len(prefix_len)
            yielded += 1
            yield task_key, x, m, y

        if n_files % 10 == 0:
            elapsed = time.time() - t_start
            print(f"[12B-09] scanned_files={n_files} sessions_seen={n_sessions:,} yielded={yielded:,} short={short:,} unk_labels={unk_labels:,} elapsed={elapsed:.1f}s")

    elapsed = time.time() - t_start
    print(f"[12B-09] DONE files={n_files} sessions_seen={n_sessions:,} yielded={yielded:,} short={short:,} unk_labels={unk_labels:,} unk_tokens_total={unk_tokens_total:,} elapsed={elapsed:.1f}s")

# ----------------------------
# 12B-09C) Reservoir sampling per task buffer
# ----------------------------
MAX_BUF = int(TASK_CFG["max_pairs_per_task_buffer"])
NEED_PER_TASK = int(TASK_CFG["n_support"]) + int(TASK_CFG["n_query"])
assert NEED_PER_TASK <= MAX_BUF, "[12B-09] n_support+n_query must be <= max_pairs_per_task_buffer"

def reservoir_push(buf, item, rng: np.random.Generator, seen_count: int, cap: int):
    """Offer ``item`` to a capped reservoir ``buf`` (mutated in place).

    While ``buf`` holds fewer than ``cap`` entries the item is appended and no
    random number is drawn. Once it is full, a slot index is drawn from
    ``[0, seen_count)`` and the item overwrites that slot only when the index
    falls inside the buffer.

    Args:
        buf: List acting as the reservoir.
        item: Candidate element.
        rng: NumPy generator used for the replacement draw.
        seen_count: Number of items offered so far, including this one.
        cap: Maximum reservoir size.
    """
    has_room = len(buf) < cap
    if has_room:
        buf.append(item)
    else:
        slot = rng.integers(0, seen_count)
        if slot < cap:
            buf[int(slot)] = item

def build_task_buffers(files, split_name: str, seed: int):
    """Collect a capped, reservoir-sampled pool of pairs for every task key.

    Pairs are streamed from ``iter_pairs_from_files`` using the sampling
    settings in ``TASK_CFG`` and pushed into one buffer per task (capacity
    ``MAX_BUF``). Tasks whose buffer ends up with fewer than ``NEED_PER_TASK``
    pairs are excluded from the returned buffers and reported instead.

    Args:
        files: Parquet shards to scan.
        split_name: Label used only in the summary print.
        seed: Seeds the reservoir RNG and is forwarded to the pair iterator.

    Returns:
        ``(ok, stats)`` where ``ok`` maps task key -> list of ``(x, m, y)``
        and ``stats`` holds ``pairs_total_seen``, ``tasks_ok``, ``tasks_bad``
        and ``bad`` (task key -> buffer size for the rejected tasks).
    """
    rng = np.random.default_rng(seed)
    file_cap = TASK_CFG.get("max_files")

    pair_stream = iter_pairs_from_files(
        files,
        sample_mod=int(TASK_CFG["sample_mod"]),
        sample_rem=int(TASK_CFG["sample_rem"]),
        seed=seed,
        max_files=None if file_cap is None else int(file_cap),
    )

    buffers = {}
    seen_per_task = {}
    total_seen = 0
    for total_seen, (task_key, x, m, y) in enumerate(pair_stream, start=1):
        pool = buffers.setdefault(task_key, [])
        count = seen_per_task.get(task_key, 0) + 1
        seen_per_task[task_key] = count
        reservoir_push(pool, (x, m, y), rng, count, MAX_BUF)

    # Split tasks by whether they can supply a full support+query episode.
    ok = {}
    bad = {}
    for key, pool in buffers.items():
        if len(pool) >= NEED_PER_TASK:
            ok[key] = pool
        else:
            bad[key] = len(pool)

    print(f"[12B-09] [{split_name}] buffers built | total_seen={total_seen:,} | tasks_ok={len(ok)} | tasks_bad={len(bad)}")
    stats = {"pairs_total_seen": total_seen, "tasks_ok": len(ok), "tasks_bad": len(bad), "bad": bad}
    return ok, stats

# Build the per-task buffers for both splits; the seeds differ by one so the
# two reservoir RNG streams are not identical.
BUFFERS_TRAIN, TRAIN_STATS = build_task_buffers(train_files, "train", seed=int(TASK_CFG["seed"]) + 0)
BUFFERS_VAL,   VAL_STATS   = build_task_buffers(val_files,   "val",   seed=int(TASK_CFG["seed"]) + 1)

# Optional guard: only enforced when TASK_CFG["require_multi_task"] is truthy.
min_tasks_required = int(TASK_CFG.get("min_tasks_required", 1))
if bool(TASK_CFG.get("require_multi_task", False)):
    assert len(BUFFERS_TRAIN) >= min_tasks_required, f"[12B-09] Not enough tasks in train buffers: {len(BUFFERS_TRAIN)} < {min_tasks_required}"
    assert len(BUFFERS_VAL)   >= min_tasks_required, f"[12B-09] Not enough tasks in val buffers: {len(BUFFERS_VAL)} < {min_tasks_required}"

# Show up to five tasks per split, largest buffer first.
print("[12B-09] tasks(train) top:", sorted([(k, len(v)) for k, v in BUFFERS_TRAIN.items()], key=lambda x: -x[1])[:5])
print("[12B-09] tasks(val)   top:", sorted([(k, len(v)) for k, v in BUFFERS_VAL.items()], key=lambda x: -x[1])[:5])

# ----------------------------
# 12B-09D) Episode sampler
# ----------------------------
def sample_episode(buffers: dict, rng: np.random.Generator, n_support: int, n_query: int):
    """Draw one support/query episode from a randomly chosen task.

    A task key is picked with ``rng.choice``; then ``n_support + n_query``
    distinct examples are drawn from that task's pool without replacement. The
    first ``n_support`` drawn examples form the support set, the rest the
    query set.

    Args:
        buffers: Mapping task key -> list of ``(x, m, y)`` examples.
        rng: NumPy generator driving both random draws.
        n_support: Number of support examples.
        n_query: Number of query examples.

    Returns:
        ``(task_key, (x_s, m_s, y_s), (x_q, m_q, y_q))`` where every tensor has
        dtype ``torch.long``.
    """
    def _column(examples, pos):
        # Pull the pos-th field out of every (x, m, y) example.
        return [ex[pos] for ex in examples]

    def _to_batch(examples):
        tokens = torch.tensor(np.stack(_column(examples, 0), axis=0), dtype=torch.long)
        mask = torch.tensor(np.stack(_column(examples, 1), axis=0), dtype=torch.long)
        labels = torch.tensor(_column(examples, 2), dtype=torch.long)
        return tokens, mask, labels

    task_key = rng.choice(list(buffers))
    pool = buffers[task_key]
    picked = rng.choice(len(pool), size=(n_support + n_query), replace=False)

    support_examples = [pool[i] for i in picked[:n_support]]
    query_examples = [pool[i] for i in picked[n_support:]]
    return task_key, _to_batch(support_examples), _to_batch(query_examples)

# ----------------------------
# 12B-09E) Model
# ----------------------------
class GRU4RecDropout(nn.Module):
    """Embedding -> GRU -> dropout -> linear next-item scorer.

    The forward pass packs the embedded sequence by ``lengths``, takes the
    final hidden state of the last GRU layer, applies dropout, and projects it
    to one logit per vocabulary entry.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int, pad_id: int, dropout: float):
        """Create the layers.

        Args:
            vocab_size: Number of embedding rows and of output logits.
            emb_dim: Embedding width (GRU input size).
            hidden_dim: GRU hidden size.
            pad_id: Token id used as ``padding_idx`` of the embedding.
            dropout: Dropout probability applied to the final hidden state.
        """
        super().__init__()
        self.pad_id = pad_id
        self.emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=pad_id,
        )
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor):
        """Return logits over the vocabulary for each sequence in the batch.

        Args:
            input_ids: Batch-first token ids.
            lengths: Number of valid steps per sequence (moved to CPU for packing).
        """
        embedded = self.emb(input_ids)
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, hidden = self.gru(packed_seq)
        # hidden[-1] is the final state of the last GRU layer.
        return self.out(self.drop(hidden[-1]))

def make_lengths(attn_mask: torch.Tensor) -> torch.Tensor:
    """Sum the mask along dim 1 and floor the result at 1.

    The floor ensures rows whose mask sums to zero still get a length of 1.
    """
    per_row = torch.sum(attn_mask, dim=1)
    return torch.clamp(per_row, min=1)

# ----------------------------
# 12B-09F) Metrics
# ----------------------------
def topk_metrics(logits: torch.Tensor, labels: torch.Tensor, ks=(5,10,20)):
    """Compute batch-averaged HR@k, MRR@k and NDCG@k for single-label ranking.

    Args:
        logits: ``(batch, n_classes)`` scores; higher means ranked earlier.
        labels: ``(batch,)`` integer class ids of the true next item.
        ks: Cut-offs to report. A cut-off larger than ``n_classes`` is
            evaluated over the full ranking instead of raising.

    Returns:
        Dict with keys ``HR@k``, ``MRR@k`` and ``NDCG@k`` for every ``k`` in
        ``ks``; values are Python floats.
    """
    res = {}
    with torch.no_grad():
        # torch.topk raises when k exceeds the size of the ranked dimension,
        # so never ask for more candidates than there are classes.
        max_k = min(max(ks), logits.size(1))
        topk = torch.topk(logits, k=max_k, dim=1).indices
        # matches[b, r] is True when the (r+1)-th ranked item is the label;
        # computed once and sliced per cut-off.
        matches = topk == labels.unsqueeze(1)
        for k in ks:
            res[f"HR@{k}"] = matches[:, :k].any(dim=1).float().mean().item()
        for k in ks:
            eq = matches[:, :k]
            # 1-based rank of the label inside the top-k, or 0 when it is absent.
            ranks = torch.where(eq.any(dim=1), eq.float().argmax(dim=1) + 1, torch.zeros_like(labels))
            found = ranks > 0
            miss = torch.zeros_like(ranks).float()
            res[f"MRR@{k}"] = torch.where(found, 1.0 / ranks.float(), miss).mean().item()
            # With a single relevant item the ideal DCG is 1, so NDCG equals DCG.
            res[f"NDCG@{k}"] = torch.where(found, 1.0 / torch.log2(ranks.float() + 1.0), miss).mean().item()
    return res

# ----------------------------
# 12B-09G) Meta-train (Reptile-ish)
# ----------------------------
# Hyper-parameters for the model and the meta-training loop in this cell.
META_CFG = {
    "emb_dim": 64,            # embedding width passed to GRU4RecDropout
    "hidden_dim": 128,        # GRU hidden size
    "dropout": 0.3,           # dropout on the final hidden state
    "meta_lr": 5e-4,          # scale of the outer (Reptile) update; also the Adam lr below
    "inner_lr": 1e-2,         # SGD learning rate for per-task adaptation
    "inner_steps": 1,         # SGD steps per task on the support set
    "meta_steps": 2000,       # number of outer iterations
    "meta_batch_tasks": 4,    # episodes averaged per outer iteration
    "grad_clip": 1.0,         # max grad norm during adaptation
    "seed": 42,               # seeds torch and the episode-sampling RNG
    "log_every": 100,         # print cadence (in meta steps)
    "eval_every": 250,        # validation cadence (in meta steps)
    "val_episodes": 50,       # episodes per periodic validation pass
}
print("[12B-09] META_CFG:", META_CFG)

# Everything in this cell runs on CPU.
device = torch.device("cpu")
torch.manual_seed(int(META_CFG["seed"]))
# NumPy generator used to sample training episodes.
rng = np.random.default_rng(int(META_CFG["seed"]))

# Meta-model whose weights are updated by the outer loop.
model = GRU4RecDropout(
    vocab_size=VOCAB_SIZE_SOURCE,
    emb_dim=int(META_CFG["emb_dim"]),
    hidden_dim=int(META_CFG["hidden_dim"]),
    pad_id=PAD_ID_SOURCE,
    dropout=float(META_CFG["dropout"]),
).to(device)

# NOTE(review): meta_opt is not referenced again in the visible part of this
# cell (the outer update is applied by hand) — confirm whether it is still needed.
meta_opt = torch.optim.Adam(model.parameters(), lr=float(META_CFG["meta_lr"]), weight_decay=0.0)

def clone_state_dict(sd):
    """Return a new plain dict whose values are ``.clone()`` copies of ``sd``'s values."""
    copied = {}
    for name, tensor in sd.items():
        copied[name] = tensor.clone()
    return copied

def load_state_dict_(m: nn.Module, sd: dict):
    """Load ``sd`` into ``m`` in place, requiring an exact key match.

    The result object of ``load_state_dict`` is discarded; the function
    returns ``None``.
    """
    m.load_state_dict(state_dict=sd, strict=True)

# ✅ CRITICAL FIX: ensure adaptation runs with grads enabled even if caller is in no_grad
def adapt_one_task(base_sd: dict, support_batch, inner_lr: float, inner_steps: int):
    """Fine-tune a fresh copy of the model on one support set.

    A new ``GRU4RecDropout`` is built from ``META_CFG``, initialised from
    ``base_sd``, put in train mode, and updated with ``inner_steps`` plain SGD
    steps (gradient norm clipped to ``META_CFG["grad_clip"]``) on the
    cross-entropy of the support batch. Gradients are force-enabled so the
    function also works when the caller is inside ``torch.no_grad()``.

    Args:
        base_sd: State dict to start the adaptation from.
        support_batch: ``(x, m, y)`` tensors for the support set.
        inner_lr: SGD learning rate.
        inner_steps: Number of SGD steps.

    Returns:
        A cloned state dict of the adapted model.
    """
    with torch.enable_grad():
        learner = GRU4RecDropout(
            vocab_size=VOCAB_SIZE_SOURCE,
            emb_dim=int(META_CFG["emb_dim"]),
            hidden_dim=int(META_CFG["hidden_dim"]),
            pad_id=PAD_ID_SOURCE,
            dropout=float(META_CFG["dropout"]),
        ).to(device)
        load_state_dict_(learner, base_sd)
        learner.train()

        x_s, m_s, y_s = (part.to(device) for part in support_batch)
        seq_lens = make_lengths(m_s)

        inner_opt = torch.optim.SGD(learner.parameters(), lr=inner_lr)
        for _step in range(inner_steps):
            inner_opt.zero_grad(set_to_none=True)
            support_loss = F.cross_entropy(learner(x_s, seq_lens), y_s, ignore_index=PAD_ID_SOURCE)
            support_loss.backward()
            nn.utils.clip_grad_norm_(learner.parameters(), float(META_CFG["grad_clip"]))
            inner_opt.step()

        adapted_sd = clone_state_dict(learner.state_dict())
        return adapted_sd

# ✅ CRITICAL FIX: DO NOT decorate this with @torch.no_grad
def eval_meta(model: nn.Module, buffers: dict, n_episodes: int = 50, seed: int = 123):
    """Average adapt-then-evaluate ranking metrics over sampled episodes.

    For every episode the current weights of ``model`` are cloned, adapted on
    the support set via ``adapt_one_task`` (which needs gradients, so this
    function must not run under ``torch.no_grad``), and the adapted weights
    are scored on the query set with ``topk_metrics``.

    Args:
        model: Meta-model providing the starting weights; it is switched to
            eval mode and left that way.
        buffers: Task buffers to sample episodes from.
        n_episodes: Number of episodes to average over.
        seed: Seed for the episode-sampling RNG of this evaluation.

    Returns:
        Dict of mean ``HR@k`` / ``MRR@k`` / ``NDCG@k`` for ``k`` in ``K_LIST``
        plus ``"_n_episodes"``.
    """
    rng_eval = np.random.default_rng(seed)
    model.eval()

    totals = {f"{metric}@{k}": 0.0 for metric in ("HR", "MRR", "NDCG") for k in K_LIST}

    n_sup = int(TASK_CFG["n_support"])
    n_qry = int(TASK_CFG["n_query"])
    lr_inner = float(META_CFG["inner_lr"])
    steps_inner = int(META_CFG["inner_steps"])

    for _episode in range(n_episodes):
        _, sup, qry = sample_episode(buffers, rng_eval, n_sup, n_qry)
        start_sd = clone_state_dict(model.state_dict())

        # Support-set adaptation: gradients are required here.
        adapted_sd = adapt_one_task(start_sd, sup, inner_lr=lr_inner, inner_steps=steps_inner)

        # Query-set scoring: no gradients needed.
        with torch.no_grad():
            scorer = GRU4RecDropout(
                vocab_size=VOCAB_SIZE_SOURCE,
                emb_dim=int(META_CFG["emb_dim"]),
                hidden_dim=int(META_CFG["hidden_dim"]),
                pad_id=PAD_ID_SOURCE,
                dropout=float(META_CFG["dropout"]),
            ).to(device)
            load_state_dict_(scorer, adapted_sd)
            scorer.eval()

            x_q, m_q, y_q = (part.to(device) for part in qry)
            query_logits = scorer(x_q, make_lengths(m_q))
            episode_scores = topk_metrics(query_logits, y_q, ks=K_LIST)

        for name, value in episode_scores.items():
            totals[name] += value

    denom = float(n_episodes)
    for name in totals:
        totals[name] /= denom
    totals["_n_episodes"] = n_episodes
    return totals

# Outer meta-training loop: for each step, adapt copies of the model on a batch
# of tasks, average the resulting weight deltas, and move the meta-model a
# fraction (meta_lr) of the way along that average.
print("[12B-09] Starting meta-train on SOURCE tasks...")
start = time.time()
# Best-checkpoint tracking, keyed on validation HR@20.
best_val = -1.0
best_step = 0
best_sd = None

for step in range(1, int(META_CFG["meta_steps"]) + 1):
    model.train()
    # Snapshot of the meta-weights that every task in this step starts from.
    base_sd = clone_state_dict(model.state_dict())
    # Accumulator for (adapted - base) per state-dict entry.
    deltas = {k: torch.zeros_like(v).detach() for k, v in base_sd.items()}
    meta_q_loss = 0.0

    for _ in range(int(META_CFG["meta_batch_tasks"])):
        _, sup, qry = sample_episode(BUFFERS_TRAIN, rng, int(TASK_CFG["n_support"]), int(TASK_CFG["n_query"]))
        fast_sd = adapt_one_task(base_sd, sup, inner_lr=float(META_CFG["inner_lr"]), inner_steps=int(META_CFG["inner_steps"]))

        # logging-only query loss
        # (does not feed into the weight update; a fresh eval-mode model is
        # loaded with the adapted weights just to score the query set)
        with torch.no_grad():
            fast = GRU4RecDropout(
                vocab_size=VOCAB_SIZE_SOURCE,
                emb_dim=int(META_CFG["emb_dim"]),
                hidden_dim=int(META_CFG["hidden_dim"]),
                pad_id=PAD_ID_SOURCE,
                dropout=float(META_CFG["dropout"]),
            ).to(device)
            load_state_dict_(fast, fast_sd)
            fast.eval()
            x_q, m_q, y_q = qry
            x_q, m_q, y_q = x_q.to(device), m_q.to(device), y_q.to(device)
            logits_q = fast(x_q, make_lengths(m_q))
            meta_q_loss += float(F.cross_entropy(logits_q, y_q, ignore_index=PAD_ID_SOURCE).item())

        # Accumulate this task's weight change relative to the shared snapshot.
        for k in deltas:
            deltas[k] += (fast_sd[k].detach() - base_sd[k].detach())

    # Mean query loss over the tasks of this step (used for logging only).
    meta_q_loss /= float(META_CFG["meta_batch_tasks"])

    # Apply Reptile update
    # (in-place: p += meta_lr * mean_delta; deltas is keyed by state-dict
    # names and only entries that are also named parameters are applied)
    with torch.no_grad():
        for name, p in model.named_parameters():
            if name in deltas:
                p.add_(float(META_CFG["meta_lr"]) * (deltas[name] / float(META_CFG["meta_batch_tasks"])))

    if step % int(META_CFG["log_every"]) == 0:
        elapsed = time.time() - start
        print(f"[12B-09] step={step}/{META_CFG['meta_steps']} meta_query_loss~{meta_q_loss:.4f} elapsed={elapsed:.1f}s")

    if step % int(META_CFG["eval_every"]) == 0:
        # The eval seed varies with the step, so each validation pass samples
        # a different set of episodes.
        val_res = eval_meta(model, BUFFERS_VAL, n_episodes=int(META_CFG["val_episodes"]), seed=1000 + step)
        hr20 = float(val_res["HR@20"])
        print(f"[12B-09] EVAL step={step} | VAL(meta-adapt) HR@20={hr20:.6f} | full={val_res}")
        # Keep a copy of the weights only on a strict improvement (> 1e-6).
        if hr20 > best_val + 1e-6:
            best_val = hr20
            best_step = step
            best_sd = clone_state_dict(model.state_dict())

# Restore the best validated weights, if any evaluation ran.
if best_sd is not None:
    load_state_dict_(model, best_sd)
print(f"[12B-09] ✅ Meta-train done. best_step={best_step} best_val_HR@20={best_val:.6f}")

# ----------------------------
# 12B-09H) Save checkpoint + report
# ----------------------------
# Reuse RUN_TAG when it is already defined in the notebook globals; otherwise
# fall back to the run tag recorded in TASK_META.
RUN_TAG = RUN_TAG if "RUN_TAG" in globals() else TASK_META.get("run_tag", "unknown_run_tag")
# All artifacts of this run go under reports/12B_meta_train_on_source/<RUN_TAG>/.
REPORT_DIR = Path(REPO_ROOT) / "reports" / "12B_meta_train_on_source" / str(RUN_TAG)
REPORT_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint: current model weights plus the configuration needed to rebuild
# the model and trace the run.
ckpt_path = REPORT_DIR / "meta_model_source.pt"
torch.save(
    {
        "run_tag": str(RUN_TAG),
        "task_run_dir": str(Path(REPO_ROOT) / "reports" / "12A_task_builder_for_meta" / TASK_META["run_tag"]),
        "proto": PROTO,
        "task_cfg": TASK_CFG,
        "meta_cfg": META_CFG,
        "vocab_size_source": VOCAB_SIZE_SOURCE,
        "pad_id_source": PAD_ID_SOURCE,
        "unk_id_source": UNK_ID_SOURCE,
        "state_dict": model.state_dict(),
        "best_step": best_step,
        "best_val_hr20": best_val,
    },
    ckpt_path
)

# Final validation pass with a fixed seed and 100 episodes, run on the model
# as it stands at this point.
final_val = eval_meta(model, BUFFERS_VAL, n_episodes=100, seed=1234)
# JSON report: configuration, buffer statistics and validation metrics.
report = {
    "run_tag": str(RUN_TAG),
    "created_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
    "proto": PROTO,
    "task_cfg": TASK_CFG,
    "meta_cfg": META_CFG,
    "buffers": {
        "train": {"tasks": len(BUFFERS_TRAIN), **TRAIN_STATS},
        "val":   {"tasks": len(BUFFERS_VAL),   **VAL_STATS},
    },
    "best_step": best_step,
    "best_val_hr20": best_val,
    "final_val_meta_adapt": final_val,
    "notes": [
        "Tasks are multi-task within SOURCE derived from real session_length bins (len_bin).",
        "UNK labels are filtered out during pair construction.",
        "Reservoir sampling used for buffer cap to avoid recency bias.",
        "Adaptation (support) runs under torch.enable_grad(); query eval uses torch.no_grad().",
    ],
}
(REPORT_DIR / "report.json").write_text(json.dumps(report, indent=2), encoding="utf-8")

print("[12B-09] ✅ Saved:", ckpt_path)
print("[12B-09] ✅ Wrote report:", REPORT_DIR / "report.json")
print("[12B-09] REPORT_DIR:", str(REPORT_DIR))

# Manual checkpoint prompt for the notebook operator.
print("\n[12B-09] CHECKPOINT Z")
print("Paste: REPORT_DIR + best_step/best_val_hr20 + final_val_meta_adapt.")


[12B-09] REPO_ROOT: C:\mooc-coldstart-session-meta
[12B-09] ✅ PROTO: {'K_LIST': [5, 10, 20], 'MAX_PREFIX_LEN': 20, 'CAP_ENABLED': True, 'CAP_SESSION_LEN': 200, 'CAP_STRATEGY': 'take_last'}
[12B-09] Source dirs from seq_dir: train=C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\train | val=C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\val | test=C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\test
[12B-09] Source shards: train= 1024 val= 1024 test= 1024
[12B-09] Source vocab: vocab_json=C:\mooc-coldstart-session-meta\data\processed\session_sequences\source_sessions_20251229_232834\source_vocab_items_20251229_232834.json
[12B-09] VOCAB_SIZE_SOURCE: 1620 | PAD/UNK: 0 1
[12B-09] scanned_files=10 sessions_seen=64,811 yielded=6,475 short=0 unk_labels=0 elapsed=0.5s
[12B-09] scanned_files=20 sessions_seen=130,232 yielded=12,974 short=0 unk_labels=0 