Imports + versions

In [26]:
# [CELL 12A-00] Imports + versions
import os, json, time, math, hashlib
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch

print("[12A-00] torch:", torch.__version__)
print("[12A-00] pandas:", pd.__version__)
print("[12A-00] numpy:", np.__version__)


[12A-00] torch: 2.9.1+cpu
[12A-00] pandas: 2.3.3
[12A-00] numpy: 2.4.0


Repo root (portable) + load single source-of-truth JSONs

In [27]:
# [CELL 12A-01] Repo root (portable) + load single source-of-truth JSONs
def find_repo_root(start: Path) -> Path:
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    # fallback: git root
    for p in [start, *start.parents]:
        if (p / ".git").exists():
            return p
    raise FileNotFoundError("Could not locate repo root (PROJECT_STATE.md or .git).")

REPO_ROOT = Path(os.getenv("MOOC_REPO_ROOT", "")).expanduser().resolve() if os.getenv("MOOC_REPO_ROOT") else find_repo_root(Path.cwd().resolve())

RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
print("[12A-01] REPO_ROOT:", str(REPO_ROOT))
print("[12A-01] RUN_TAG:", RUN_TAG)

CFG_PATH = REPO_ROOT / "data/processed/supervised/dataloader_config_20251229_163357_20251229_232834.json"
SANITY_PATH = REPO_ROOT / "data/processed/supervised/sanity_metrics_20251229_163357_20251229_232834.json"
GAPS_PATH = REPO_ROOT / "data/processed/normalized_events/session_gap_thresholds.json"

print("[12A-01] Expect config:", CFG_PATH)
print("[12A-01] Expect sanity:", SANITY_PATH)
print("[12A-01] Expect gaps  :", GAPS_PATH)

assert CFG_PATH.exists(), f"Missing: {CFG_PATH}"
assert SANITY_PATH.exists(), f"Missing: {SANITY_PATH}"
assert GAPS_PATH.exists(), f"Missing: {GAPS_PATH}"

DL_CFG = json.loads(CFG_PATH.read_text(encoding="utf-8"))
SANITY = json.loads(SANITY_PATH.read_text(encoding="utf-8"))
GAPS = json.loads(GAPS_PATH.read_text(encoding="utf-8"))

def infer_gap_minutes(d: dict, name: str) -> int:
    if "gap_minutes" in d:  # legacy
        return int(d["gap_minutes"])
    if "primary_threshold_seconds" in d:
        sec = int(d["primary_threshold_seconds"])
        return sec // 60
    raise KeyError(f"[12A-01] {name}: cannot infer gap minutes. keys={list(d.keys())}")

gap_target_m = infer_gap_minutes(GAPS["target"], "target")
gap_source_m = infer_gap_minutes(GAPS["source"], "source")
print(f"[12A-01] target: gap_minutes from primary_threshold_seconds={GAPS['target'].get('primary_threshold_seconds')} -> {gap_target_m}m | label={GAPS['target'].get('primary_threshold_label')}")
print(f"[12A-01] source: gap_minutes from primary_threshold_seconds={GAPS['source'].get('primary_threshold_seconds')} -> {gap_source_m}m | label={GAPS['source'].get('primary_threshold_label')}")

assert gap_target_m == 30, f"target gap mismatch: got {gap_target_m}m"
assert gap_source_m == 10, f"source gap mismatch: got {gap_source_m}m"
print("[12A-01] ✅ Session gaps confirmed.")

print("\n[12A-01] CHECKPOINT A")
print("Paste: inferred gaps lines + confirm asserts passed.")


[12A-01] REPO_ROOT: C:\mooc-coldstart-session-meta
[12A-01] RUN_TAG: 20260104_141727
[12A-01] Expect config: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json
[12A-01] Expect sanity: C:\mooc-coldstart-session-meta\data\processed\supervised\sanity_metrics_20251229_163357_20251229_232834.json
[12A-01] Expect gaps  : C:\mooc-coldstart-session-meta\data\processed\normalized_events\session_gap_thresholds.json
[12A-01] target: gap_minutes from primary_threshold_seconds=1800 -> 30m | label=30m
[12A-01] source: gap_minutes from primary_threshold_seconds=600 -> 10m | label=10m
[12A-01] ✅ Session gaps confirmed.

[12A-01] CHECKPOINT A
Paste: inferred gaps lines + confirm asserts passed.


Load SOURCE vocab + list shards

**CRITICAL CHECK**: Probe unique domains across shards (verify multi-task assumption)

Stable hash sampling + session→pairs (left-pad to MAX_PREFIX_LEN)

In [28]:
# [CELL 12A-02] Source artifacts + vocab
SRC_ROOT = REPO_ROOT / "data/processed/session_sequences/source_sessions_20251229_232834"
SRC_TRAIN_DIR = SRC_ROOT / "train"
SRC_VAL_DIR   = SRC_ROOT / "val"
SRC_TEST_DIR  = SRC_ROOT / "test"
SRC_VOCAB_PATH = SRC_ROOT / "source_vocab_items_20251229_232834.json"

for p in [SRC_TRAIN_DIR, SRC_VAL_DIR, SRC_TEST_DIR, SRC_VOCAB_PATH]:
    assert p.exists(), f"Missing: {p}"

def list_parquet_files(d: Path):
    files = sorted([p for p in d.glob("*.parquet")])
    return files

train_files = list_parquet_files(SRC_TRAIN_DIR)
val_files   = list_parquet_files(SRC_VAL_DIR)
test_files  = list_parquet_files(SRC_TEST_DIR)

print("[12A-02] shards: train=", len(train_files), "val=", len(val_files), "test=", len(test_files))
assert len(train_files) == 1024 and len(val_files) == 1024 and len(test_files) == 1024, "Shard count drift (expected 1024 each)."

source_vocab = json.loads(SRC_VOCAB_PATH.read_text(encoding="utf-8"))
print("[12A-02] source_vocab keys:", list(source_vocab.keys()))
VOCAB_SIZE_SOURCE = int(source_vocab["vocab_size"])
PAD_ID_SOURCE = int(source_vocab.get("pad_id", 0))
UNK_ID_SOURCE = int(source_vocab.get("unk_id", 1))
item2id = source_vocab["item2id"]

assert len(item2id) == VOCAB_SIZE_SOURCE, f"item2id size mismatch: {len(item2id)} vs vocab_size={VOCAB_SIZE_SOURCE}"

print("[12A-02] VOCAB_SIZE_SOURCE:", VOCAB_SIZE_SOURCE)
print("[12A-02] PAD/UNK:", PAD_ID_SOURCE, UNK_ID_SOURCE)

print("\n[12A-02] CHECKPOINT B")
print("Confirm: shards are 1024 each + vocab_size/pad/unk printed.")


[12A-02] shards: train= 1024 val= 1024 test= 1024
[12A-02] source_vocab keys: ['run_tag_source', 'built_from', 'vocab_size', 'pad_id', 'unk_id', 'item2id']
[12A-02] VOCAB_SIZE_SOURCE: 1620
[12A-02] PAD/UNK: 0 1

[12A-02] CHECKPOINT B
Confirm: shards are 1024 each + vocab_size/pad/unk printed.


Normalize protocol fields from dataloader_config (Notebook 06 compatibility)

In [29]:
# [CELL 12A-03] Normalize protocol fields from dataloader_config (Notebook 06 compatibility)
PROTO_RAW = DL_CFG.get("protocol", {})

# protocol keys in your config are nested; normalize to canonical fields used everywhere.
def normalize_proto(d: dict) -> dict:
    # fixed per repo decisions
    max_prefix_len = int(d.get("max_prefix_len", 20))
    long_pol = d.get("source_long_session_policy", {})
    cap_enabled = bool(long_pol.get("enabled", True))
    cap_session_len = int(long_pol.get("cap_session_len", 200))
    cap_strategy = str(long_pol.get("cap_strategy", "take_last"))

    # K_LIST is fixed by protocol in Notebook 06 (even if not stored under protocol)
    K_LIST = [5, 10, 20]

    return {
        "K_LIST": K_LIST,
        "MAX_PREFIX_LEN": max_prefix_len,
        "CAP_ENABLED": cap_enabled,
        "CAP_SESSION_LEN": cap_session_len,
        "CAP_STRATEGY": cap_strategy,
    }

PROTO = normalize_proto(PROTO_RAW)
print("[12A-03] protocol keys:", list(PROTO_RAW.keys()))
print("[12A-03] source_long_session_policy keys:", list(PROTO_RAW.get("source_long_session_policy", {}).keys()))
print("[12A-03] ✅ PROTO:", PROTO)

# Hard asserts to prevent drift
assert PROTO["K_LIST"] == [5, 10, 20]
assert int(PROTO["MAX_PREFIX_LEN"]) == 20
assert bool(PROTO["CAP_ENABLED"]) is True
assert int(PROTO["CAP_SESSION_LEN"]) == 200
assert str(PROTO["CAP_STRATEGY"]) == "take_last"
print("[12A-03] ✅ PROTO asserts passed (matches Notebook 06).")

print("\n[12A-03] CHECKPOINT C")
print("Paste: protocol keys + source_long_session_policy keys + PROTO.")


[12A-03] protocol keys: ['max_prefix_len', 'source_vocab_mode', 'source_pair_rule', 'source_long_session_policy', 'dataloader', 'seeds']
[12A-03] source_long_session_policy keys: ['enabled', 'cap_session_len', 'cap_strategy']
[12A-03] ✅ PROTO: {'K_LIST': [5, 10, 20], 'MAX_PREFIX_LEN': 20, 'CAP_ENABLED': True, 'CAP_SESSION_LEN': 200, 'CAP_STRATEGY': 'take_last'}
[12A-03] ✅ PROTO asserts passed (matches Notebook 06).

[12A-03] CHECKPOINT C
Paste: protocol keys + source_long_session_policy keys + PROTO.


Streaming episode builder (task = domain)

In [32]:
# [CELL 12A-04] Utilities + streaming pair generator WITH task_key derivation (multi-task inside SOURCE)
# FIX: no pd.read_parquet(..., nrows=1). Use pyarrow schema to check optional columns.

import hashlib
import time
import pyarrow.parquet as pq

MAX_LEN = int(PROTO["MAX_PREFIX_LEN"])
CAP_ENABLED = bool(PROTO["CAP_ENABLED"])
CAP_SESSION_LEN = int(PROTO["CAP_SESSION_LEN"])
CAP_STRATEGY = str(PROTO["CAP_STRATEGY"])
assert CAP_STRATEGY == "take_last"

# Ensure mapping exists
if "item2id" not in globals():
    raise AssertionError("[12A-04] item2id not found; run [12A-02] first.")
source_token_to_id = item2id

print(f"[12A-04] Built source_token_to_id size={len(source_token_to_id):,}")
print("[12A-04] PAD/UNK:", {"PAD_ID_SOURCE": PAD_ID_SOURCE, "UNK_ID_SOURCE": UNK_ID_SOURCE})

def stable_hash64(s: str) -> int:
    h = hashlib.blake2b(s.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(h, "little", signed=False)

def stable_mod(value, mod: int) -> int:
    return stable_hash64(str(value)) % mod

def map_seq_tokens_to_ids(seq) -> tuple[np.ndarray, int]:
    out = np.empty(len(seq), dtype=np.int64)
    unk = 0
    for i, tok in enumerate(seq):
        tid = source_token_to_id.get(str(tok), UNK_ID_SOURCE)
        out[i] = tid
        if tid == UNK_ID_SOURCE:
            unk += 1
    return out, unk

def session_to_one_pair(seq_ids: np.ndarray):
    if CAP_ENABLED and len(seq_ids) > CAP_SESSION_LEN and CAP_STRATEGY == "take_last":
        seq_ids = seq_ids[-CAP_SESSION_LEN:]
    if len(seq_ids) < 2:
        return None

    label = int(seq_ids[-1])
    prefix = seq_ids[:-1]
    if len(prefix) > MAX_LEN:
        prefix = prefix[-MAX_LEN:]

    x = np.zeros((MAX_LEN,), dtype=np.int64)
    m = np.zeros((MAX_LEN,), dtype=np.int64)
    plen = len(prefix)
    x[-plen:] = prefix
    m[-plen:] = 1
    return x, m, label

def reservoir_add(buf: list, item, cap: int | None, rng: np.random.Generator, seen_count: int):
    if cap is None:
        buf.append(item)
        return
    if len(buf) < cap:
        buf.append(item)
        return
    j = int(rng.integers(0, seen_count + 1))
    if j < cap:
        buf[j] = item

def detect_seq_col(cols: list[str]) -> str:
    for c in ["items", "item_seq", "sequence", "seq", "item_ids"]:
        if c in cols:
            return c
    raise KeyError(f"Could not detect sequence column. cols={cols}")

# ---- NEW: task_key derivation (multi-task inside one dataset) ----
def task_key_from_row(session_length: int, start_ts=None, mode: str = "len_bin") -> str:
    L = int(session_length)
    if mode == "len_bin":
        if L <= 2:   return "len_01_02"
        if L <= 5:   return "len_03_05"
        if L <= 10:  return "len_06_10"
        if L <= 20:  return "len_11_20"
        return "len_21_plus"

    if mode == "time_month":
        try:
            ts = pd.to_datetime(start_ts, unit="s", utc=True) if isinstance(start_ts, (int, np.integer)) else pd.to_datetime(start_ts, utc=True)
            return f"month_{ts.year:04d}_{ts.month:02d}"
        except Exception:
            return task_key_from_row(L, None, mode="len_bin")

    raise ValueError(f"Unknown task_key mode: {mode}")

def parquet_has_col(fp, col: str) -> bool:
    # Fast schema check (no full read)
    schema = pq.read_schema(fp)
    return col in schema.names

def iter_pairs_from_files(
    files,
    sample_mod: int,
    sample_rem: int,
    seed: int,
    task_key_mode: str = "len_bin",
    max_files=None,
    log_every_files: int = 200,
):
    """
    Stream (task_key, x, m, y) pairs.
    - sampling uses stable hash of session_id (works even when session_id isn't int)
    - filters out UNK labels
    """
    rng = np.random.default_rng(seed)
    n_files = 0
    sessions_seen = 0
    yielded = 0
    short = 0
    unk_labels = 0
    unk_tokens_total = 0
    t0 = time.time()

    for fp in files:
        n_files += 1
        if max_files is not None and n_files > max_files:
            break

        # Decide optional cols from schema
        has_start_ts = parquet_has_col(fp, "start_ts")

        cols = ["session_id", "session_length", "items"]
        if has_start_ts:
            cols.append("start_ts")

        df = pd.read_parquet(fp, columns=cols)

        seq_col = detect_seq_col(list(df.columns))

        for row in df.itertuples(index=False):
            sessions_seen += 1
            sid = getattr(row, "session_id")

            if sample_mod and sample_mod > 1:
                if stable_mod(sid, sample_mod) != sample_rem:
                    continue

            seq = getattr(row, seq_col)
            if seq is None:
                short += 1
                continue

            seq_ids, unk_tok = map_seq_tokens_to_ids(seq)
            unk_tokens_total += unk_tok

            pair = session_to_one_pair(seq_ids)
            if pair is None:
                short += 1
                continue

            x, m, y = pair
            if y == UNK_ID_SOURCE:
                unk_labels += 1
                continue

            sess_len = getattr(row, "session_length")
            st = getattr(row, "start_ts", None) if has_start_ts else None
            task_key = task_key_from_row(sess_len, st, mode=task_key_mode)

            yielded += 1
            yield task_key, x, m, y

        if (n_files % log_every_files) == 0:
            print(f"[12A-04] scanned_files={n_files}/{len(files)} sessions_seen={sessions_seen:,} "
                  f"yielded={yielded:,} short={short:,} unk_labels={unk_labels:,} unk_tokens_total={unk_tokens_total:,} "
                  f"elapsed={time.time()-t0:.1f}s")

    print(f"[12A-04] DONE files={n_files} sessions_seen={sessions_seen:,} yielded={yielded:,} short={short:,} "
          f"unk_labels={unk_labels:,} unk_tokens_total={unk_tokens_total:,} elapsed={time.time()-t0:.1f}s")

print("[12A-04] stable_mod probe:", stable_mod("3160332::21", 10), stable_mod("3160332::21", 10), "(should match)")
print("[12A-04] ✅ Streaming pair generator (task_key) ready")
print("\n[12A-04] CHECKPOINT D")
print("Next run [12A-05] again (it should work now).")


[12A-04] Built source_token_to_id size=1,620
[12A-04] PAD/UNK: {'PAD_ID_SOURCE': 0, 'UNK_ID_SOURCE': 1}
[12A-04] stable_mod probe: 9 9 (should match)
[12A-04] ✅ Streaming pair generator (task_key) ready

[12A-04] CHECKPOINT D
Next run [12A-05] again (it should work now).


 Probe 2 episodes (shape sanity)

In [33]:
# [CELL 12A-05] Probe generator output + task diversity diagnostics (replaces domain)

TASK_CFG = {
    "n_support": 20,
    "n_query": 20,
    "sample_mod": 10,
    "sample_rem": 0,
    "max_files": 50,  # diagnostic scan
    "max_pairs_per_task_buffer": 50_000,
    "seed": 42,
    "task_key_mode": "len_bin",   # <<< multi-task inside SOURCE
    "require_multi_task": True,
    "min_tasks_required": 3,
}

print("[12A-05] TASK_CFG:", TASK_CFG)
assert TASK_CFG["n_support"] + TASK_CFG["n_query"] <= TASK_CFG["max_pairs_per_task_buffer"], \
    "Invalid: n_support+n_query exceeds max_pairs_per_task_buffer."

# Probe 3 samples
gen = iter_pairs_from_files(
    train_files,
    sample_mod=TASK_CFG["sample_mod"],
    sample_rem=TASK_CFG["sample_rem"],
    seed=TASK_CFG["seed"],
    task_key_mode=TASK_CFG["task_key_mode"],
    max_files=2
)
for j in range(3):
    task_key, x, m, y = next(gen)
    print(f"[12A-05] sample {j}: task={task_key} x_nonzero={int(m.sum())} label={y}")
    print(" x[:10]=", x[:10].tolist())
    print(" m[:10]=", m[:10].tolist())

# Task diversity scan
task_counts = {}
pairs_scanned = 0
for task_key, *_ in iter_pairs_from_files(
    train_files,
    sample_mod=TASK_CFG["sample_mod"],
    sample_rem=TASK_CFG["sample_rem"],
    seed=TASK_CFG["seed"],
    task_key_mode=TASK_CFG["task_key_mode"],
    max_files=TASK_CFG["max_files"]
):
    task_counts[task_key] = task_counts.get(task_key, 0) + 1
    pairs_scanned += 1

tasks_sorted = sorted(task_counts.items(), key=lambda x: x[1], reverse=True)
print(f"[12A-05] task_unique={len(task_counts)} | pairs_scanned={pairs_scanned:,}")
print("[12A-05] top10 tasks:", tasks_sorted[:10])

if TASK_CFG["require_multi_task"]:
    assert len(task_counts) >= TASK_CFG["min_tasks_required"], (
        f"[12A-05] Too few tasks detected: {len(task_counts)}. "
        "Meta-learning would be weak. Switch task_key_mode or adjust bins."
    )

print("\n[12A-05] CHECKPOINT E")
print("Paste: 3 probe samples + task_unique + top10 tasks.")


[12A-05] TASK_CFG: {'n_support': 20, 'n_query': 20, 'sample_mod': 10, 'sample_rem': 0, 'max_files': 50, 'max_pairs_per_task_buffer': 50000, 'seed': 42, 'task_key_mode': 'len_bin', 'require_multi_task': True, 'min_tasks_required': 3}
[12A-05] sample 0: task=len_21_plus x_nonzero=20 label=197
 x[:10]= [197, 197, 197, 197, 197, 197, 197, 197, 197, 197]
 m[:10]= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[12A-05] sample 1: task=len_21_plus x_nonzero=20 label=344
 x[:10]= [344, 344, 344, 344, 344, 344, 344, 344, 344, 344]
 m[:10]= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[12A-05] sample 2: task=len_06_10 x_nonzero=9 label=179
 x[:10]= [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 m[:10]= [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[12A-04] DONE files=51 sessions_seen=325,575 yielded=32,540 short=0 unk_labels=0 unk_tokens_total=0 elapsed=3.3s
[12A-05] task_unique=5 | pairs_scanned=32,540
[12A-05] top10 tasks: [('len_21_plus', 7920), ('len_03_05', 7541), ('len_06_10', 6939), ('len_11_20', 6686), ('len_01_02', 3454)]

[12A-05] CHECKPOINT E
P

Save task-builder config + stats snapshot

In [34]:
# [CELL 12A-06] Build per-task buffers (reservoir sampled) for episode sampling

rng = np.random.default_rng(TASK_CFG["seed"])

buffers = {}        # task_key -> list of (x,m,y)
seen_per_task = {}  # task_key -> count for reservoir
cap = int(TASK_CFG["max_pairs_per_task_buffer"])

pairs_total_seen = 0
t0 = time.time()

for task_key, x, m, y in iter_pairs_from_files(
    train_files,
    sample_mod=TASK_CFG["sample_mod"],
    sample_rem=TASK_CFG["sample_rem"],
    seed=TASK_CFG["seed"],
    task_key_mode=TASK_CFG["task_key_mode"],
    max_files=TASK_CFG["max_files"]
):
    pairs_total_seen += 1
    if task_key not in buffers:
        buffers[task_key] = []
        seen_per_task[task_key] = 0
    seen_per_task[task_key] += 1

    reservoir_add(buffers[task_key], (x, m, y), cap=cap, rng=rng, seen_count=seen_per_task[task_key])

    if pairs_total_seen % 50_000 == 0:
        kept = sum(len(v) for v in buffers.values())
        print(f"[12A-06] pairs_seen={pairs_total_seen:,} kept={kept:,} tasks={len(buffers)} elapsed={time.time()-t0:.1f}s")

kept = sum(len(v) for v in buffers.values())
print(f"[12A-06] DONE build buffers | pairs_seen={pairs_total_seen:,} kept={kept:,} tasks={len(buffers)} elapsed={time.time()-t0:.1f}s")

need = TASK_CFG["n_support"] + TASK_CFG["n_query"]
ok = {k: len(v) for k, v in buffers.items() if len(v) >= need}
bad = {k: len(v) for k, v in buffers.items() if len(v) < need}

print("[12A-06] tasks_ok:", len(ok), "| tasks_bad:", len(bad), "| need per task:", need)
if len(bad) > 0:
    print("[12A-06] BAD tasks:", sorted(bad.items(), key=lambda x: x[1])[:10])

assert len(ok) >= TASK_CFG["min_tasks_required"], "[12A-06] Not enough tasks with sufficient pairs for episodes."

print("\n[12A-06] CHECKPOINT F")
print("Paste: DONE build buffers line + tasks_ok/tasks_bad + need.")


[12A-04] DONE files=51 sessions_seen=325,575 yielded=32,540 short=0 unk_labels=0 unk_tokens_total=0 elapsed=2.5s
[12A-06] DONE build buffers | pairs_seen=32,540 kept=32,540 tasks=5 elapsed=2.5s
[12A-06] tasks_ok: 5 | tasks_bad: 0 | need per task: 40

[12A-06] CHECKPOINT F
Paste: DONE build buffers line + tasks_ok/tasks_bad + need.


Episode sampler with max_episodes control

In [35]:
# [CELL 12A-07] Episode sampler with max_episodes control

def sample_episode(buf: list, n_support: int, n_query: int, rng: np.random.Generator):
    idx = rng.choice(len(buf), size=(n_support + n_query), replace=False)
    sup = [buf[i] for i in idx[:n_support]]
    qry = [buf[i] for i in idx[n_support:]]

    def stack(pairs):
        xs = np.stack([p[0] for p in pairs], axis=0)
        ms = np.stack([p[1] for p in pairs], axis=0)
        ys = np.asarray([p[2] for p in pairs], dtype=np.int64)
        return xs, ms, ys

    return stack(sup), stack(qry)

def iter_episodes(buffers: dict, n_support: int, n_query: int, seed: int, max_episodes: int = 100):
    rng = np.random.default_rng(seed)
    keys = sorted([k for k, v in buffers.items() if len(v) >= (n_support + n_query)])
    for ep in range(max_episodes):
        task_key = keys[int(rng.integers(0, len(keys)))]
        (sx, sm, sy), (qx, qm, qy) = sample_episode(buffers[task_key], n_support, n_query, rng)
        yield {
            "episode_id": ep,
            "task_key": task_key,
            "support": {"x": sx, "m": sm, "y": sy},
            "query": {"x": qx, "m": qm, "y": qy},
        }

print("[12A-07] ✅ Episode sampler ready")
print("\n[12A-07] CHECKPOINT G")
print("Next run [12A-08] to print a few episodes + task keys.")


[12A-07] ✅ Episode sampler ready

[12A-07] CHECKPOINT G
Next run [12A-08] to print a few episodes + task keys.


Probe episodes

In [36]:
# [CELL 12A-08] Probe episodes

MAX_EPISODES_PROBE = 5
eps = iter_episodes(
    buffers=buffers,
    n_support=TASK_CFG["n_support"],
    n_query=TASK_CFG["n_query"],
    seed=TASK_CFG["seed"],
    max_episodes=MAX_EPISODES_PROBE
)

seen = []
for ep in eps:
    tk = ep["task_key"]
    seen.append(tk)
    sm = ep["support"]["m"]
    qm = ep["query"]["m"]
    print(f"[12A-08] episode {ep['episode_id']} task={tk}")
    print("  support x", ep["support"]["x"].shape, "m", sm.shape, "y", ep["support"]["y"].shape,
          "| x_nonzero_mean", float(sm.sum(axis=1).mean()))
    print("  query   x", ep["query"]["x"].shape, "m", qm.shape, "y", ep["query"]["y"].shape,
          "| x_nonzero_mean", float(qm.sum(axis=1).mean()))

print("[12A-08] tasks_seen_in_probe:", seen)

assert len(set(seen)) >= 2, "[12A-08] probe did not show multiple tasks; increase MAX_EPISODES_PROBE or check bins."
print("\n[12A-08] CHECKPOINT H")
print("Paste: episodes output + tasks_seen_in_probe.")


[12A-08] episode 0 task=len_01_02
  support x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 1.0
  query   x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 1.0
[12A-08] episode 1 task=len_21_plus
  support x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 20.0
  query   x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 20.0
[12A-08] episode 2 task=len_21_plus
  support x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 20.0
  query   x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 20.0
[12A-08] episode 3 task=len_01_02
  support x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 1.0
  query   x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 1.0
[12A-08] episode 4 task=len_03_05
  support x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 3.05
  query   x (20, 20) m (20, 20) y (20,) | x_nonzero_mean 2.85
[12A-08] tasks_seen_in_probe: ['len_01_02', 'len_21_plus', 'len_21_plus', 'len_01_02', 'len_03_05']

[12A-08] CHECKPOINT H
Paste: episodes output + tasks_seen_in_probe.


Write report artifacts

In [37]:
# [CELL 12A-09] Write report artifacts

REPORT_DIR = REPO_ROOT / "reports/12A_task_builder_for_meta" / RUN_TAG
REPORT_DIR.mkdir(parents=True, exist_ok=True)

meta_task_config = {
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "protocol": PROTO,
    "session_gaps": {
        "target_gap_minutes": gap_target_m,
        "source_gap_minutes": gap_source_m,
        "target_gap_label": GAPS["target"].get("primary_threshold_label"),
        "source_gap_label": GAPS["source"].get("primary_threshold_label"),
    },
    "task_cfg": TASK_CFG,
    "task_definition": {
        "note": "SOURCE parquet has domain_unique=1 (always 'source'). Tasks are derived from real fields.",
        "task_key_mode": TASK_CFG["task_key_mode"],
        "len_bin_definition": ["len_01_02","len_03_05","len_06_10","len_11_20","len_21_plus"],
    },
    "source": {
        "run_tag_source": source_vocab.get("run_tag_source"),
        "vocab_size": VOCAB_SIZE_SOURCE,
        "pad_id": PAD_ID_SOURCE,
        "unk_id": UNK_ID_SOURCE,
        "shards": {"train": len(train_files), "val": len(val_files), "test": len(test_files)},
        "paths": {
            "train_dir": str(SRC_TRAIN_DIR),
            "val_dir": str(SRC_VAL_DIR),
            "test_dir": str(SRC_TEST_DIR),
            "vocab": str(SRC_VOCAB_PATH),
        },
    },
    "notes": [
        "UNK labels filtered out at pair construction.",
        "Reservoir sampling used for buffer cap to avoid recency bias.",
        "Episodes are sampled across task_key groups (multi-task within SOURCE).",
        "Long-session cap uses take_last (ablate later in Notebook 13).",
    ]
}

task_coverage = {
    "pairs_total_seen": int(pairs_total_seen),
    "pairs_total_kept": int(sum(len(v) for v in buffers.values())),
    "tasks_total": int(len(buffers)),
    "need_per_task": int(TASK_CFG["n_support"] + TASK_CFG["n_query"]),
    "pairs_kept_per_task": {k: len(v) for k, v in buffers.items()},
    "pairs_seen_per_task": {k: int(seen_per_task.get(k, 0)) for k in buffers.keys()},
}

META_CFG_PATH = REPORT_DIR / "meta_task_config.json"
COV_PATH = REPORT_DIR / "task_coverage.json"

META_CFG_PATH.write_text(json.dumps(meta_task_config, indent=2, ensure_ascii=False), encoding="utf-8")
COV_PATH.write_text(json.dumps(task_coverage, indent=2, ensure_ascii=False), encoding="utf-8")

print("[12A-09] ✅ Wrote:", str(META_CFG_PATH))
print("[12A-09] ✅ Wrote:", str(COV_PATH))

print("\n[12A-09] CHECKPOINT I")
print("Paste: both written paths.")


[12A-09] ✅ Wrote: C:\mooc-coldstart-session-meta\reports\12A_task_builder_for_meta\20260104_141727\meta_task_config.json
[12A-09] ✅ Wrote: C:\mooc-coldstart-session-meta\reports\12A_task_builder_for_meta\20260104_141727\task_coverage.json

[12A-09] CHECKPOINT I
Paste: both written paths.


Update repo meta.json

In [38]:
# [CELL 12A-10] Update repo meta.json

META_JSON = REPO_ROOT / "meta.json"
meta = json.loads(META_JSON.read_text(encoding="utf-8")) if META_JSON.exists() else {}
meta.setdefault("runs", [])

meta["runs"].append({
    "kind": "12A_task_builder_for_meta",
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "report_dir": str(REPORT_DIR),
    "artifacts": {
        "meta_task_config": str(META_CFG_PATH),
        "task_coverage": str(COV_PATH),
    }
})

META_JSON.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
print("[12A-10] ✅ Updated meta.json:", str(META_JSON))

print("\n[12A-10] CHECKPOINT J")
print("Confirm meta.json updated.")


[12A-10] ✅ Updated meta.json: C:\mooc-coldstart-session-meta\meta.json

[12A-10] CHECKPOINT J
Confirm meta.json updated.
