Bootstrap (root, run dir, device, seed)

In [1]:
# [CELL 12C-00] Bootstrap
import os, json, time, random
from pathlib import Path
from datetime import datetime

import numpy as np
import torch

CELL = "12C-00"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

def find_repo_root(start: Path) -> Path:
    for p in [start, *start.parents]:
        if (p / "meta.json").exists():
            return p
    return start

REPO_ROOT = str(find_repo_root(Path.cwd().resolve()))
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
REPORT_DIR = os.path.join(REPO_ROOT, "reports", "12C_meta_adapt_and_eval_on_target", RUN_TAG)
os.makedirs(REPORT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SEED = 20260105
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"[{CELL}] REPO_ROOT  : {REPO_ROOT}")
print(f"[{CELL}] REPORT_DIR : {REPORT_DIR}")
print(f"[{CELL}] DEVICE     : {DEVICE} | torch={torch.__version__}")
print(f"[{CELL}] SEED       : {SEED}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-00] Starting... 2026-01-05 01:33:27
[12C-00] REPO_ROOT  : C:\mooc-coldstart-session-meta
[12C-00] REPORT_DIR : C:\mooc-coldstart-session-meta\reports\12C_meta_adapt_and_eval_on_target\20260105_013327
[12C-00] DEVICE     : cpu | torch=2.9.1+cpu
[12C-00] SEED       : 20260105
[12C-00] Done in 0.01s


Load latest 12B run (report + ckpt + meta_cfg/proto)

In [2]:
# [CELL 12C-01] Load latest 12B meta-train run from meta.json
import time
from datetime import datetime

CELL = "12C-01"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

META_JSON_PATH = os.path.join(REPO_ROOT, "meta.json")
with open(META_JSON_PATH, "r", encoding="utf-8") as f:
    META = json.load(f)

runs_12b = [r for r in META.get("runs", []) if r.get("kind") == "12B_meta_train_on_source"]
if not runs_12b:
    raise RuntimeError(f"[{CELL}] No 12B run found in meta.json")

# pick last (created_at sort if present)
runs_12b = sorted(runs_12b, key=lambda r: r.get("created_at") or "")
RUN_12B = runs_12b[-1]

RUN_TAG_12B = RUN_12B["run_tag"]
REPORT_DIR_12B = RUN_12B["report_dir"]
REPORT_JSON_12B = RUN_12B["artifacts"]["report_json"]
CKPT_META_INIT = RUN_12B["artifacts"]["meta_model_source_pt"]

print(f"[{CELL}] 12B run_tag: {RUN_TAG_12B}")
print(f"[{CELL}] 12B report : {REPORT_JSON_12B}")
print(f"[{CELL}] 12B ckpt   : {CKPT_META_INIT}")

with open(REPORT_JSON_12B, "r", encoding="utf-8") as f:
    REPORT_12B = json.load(f)

PROTO = REPORT_12B["proto"]
META_CFG = REPORT_12B["meta_cfg"]  # your report uses meta_cfg
print(f"[{CELL}] PROTO keys   : {list(PROTO.keys())}")
print(f"[{CELL}] META_CFG keys: {list(META_CFG.keys())}")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-01] Starting... 2026-01-05 01:33:49
[12C-01] 12B run_tag: 20260104_165117
[12C-01] 12B report : C:\mooc-coldstart-session-meta\reports\12B_meta_train_on_source\20260104_165117\report.json
[12C-01] 12B ckpt   : C:\mooc-coldstart-session-meta\reports\12B_meta_train_on_source\20260104_165117\meta_model_source.pt
[12C-01] PROTO keys   : ['K_LIST', 'MAX_PREFIX_LEN', 'CAP_ENABLED', 'CAP_SESSION_LEN', 'CAP_STRATEGY']
[12C-01] META_CFG keys: ['emb_dim', 'hidden_dim', 'dropout', 'meta_lr', 'inner_lr', 'inner_steps', 'meta_steps', 'meta_batch_tasks', 'grad_clip', 'seed', 'log_every', 'eval_every', 'val_episodes']
[12C-01] Done in 0.00s


Load dataloader_config + target tensor metadata (PAD/VOCAB)

In [3]:
# [CELL 12C-02] Load tensor target config + target_tensor_metadata_*.json
import time
from datetime import datetime

CELL = "12C-02"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

DL_CFG_PATH = os.path.join(REPO_ROOT, "data", "processed", "supervised", "dataloader_config_20251229_163357_20251229_232834.json")
print(f"[{CELL}] DL_CFG_PATH: {DL_CFG_PATH}")

with open(DL_CFG_PATH, "r", encoding="utf-8") as f:
    DL_CFG = json.load(f)

TARGET_CFG = DL_CFG["target"]
TRAIN_PT = TARGET_CFG["train_pt"]
VAL_PT   = TARGET_CFG["val_pt"]
TEST_PT  = TARGET_CFG["test_pt"]
TARGET_META_JSON = TARGET_CFG["meta_json"]

print(f"[{CELL}] train_pt : {TRAIN_PT}")
print(f"[{CELL}] val_pt   : {VAL_PT}")
print(f"[{CELL}] test_pt  : {TEST_PT}")
print(f"[{CELL}] meta_json: {TARGET_META_JSON}")

with open(TARGET_META_JSON, "r", encoding="utf-8") as f:
    TARGET_META = json.load(f)

PAD_ID = int(TARGET_META["pad_id"])
UNK_ID = int(TARGET_META["unk_id"])
VOCAB_SIZE_TARGET = int(TARGET_META["vocab_size"])

print(f"[{CELL}] PAD_ID={PAD_ID} UNK_ID={UNK_ID} VOCAB_SIZE_TARGET={VOCAB_SIZE_TARGET}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-02] Starting... 2026-01-05 01:34:46
[12C-02] DL_CFG_PATH: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json
[12C-02] train_pt : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_train_20251229_163357.pt
[12C-02] val_pt   : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_val_20251229_163357.pt
[12C-02] test_pt  : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_test_20251229_163357.pt
[12C-02] meta_json: C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_metadata_20251229_163357.json
[12C-02] PAD_ID=0 UNK_ID=1 VOCAB_SIZE_TARGET=747
[12C-02] Done in 0.00s


Load latest 12B run (report + ckpt + meta_cfg/proto)

In [4]:
# [CELL 12C-01] Load latest 12B meta-train run from meta.json
import time
from datetime import datetime

CELL = "12C-01"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

META_JSON_PATH = os.path.join(REPO_ROOT, "meta.json")
with open(META_JSON_PATH, "r", encoding="utf-8") as f:
    META = json.load(f)

runs_12b = [r for r in META.get("runs", []) if r.get("kind") == "12B_meta_train_on_source"]
if not runs_12b:
    raise RuntimeError(f"[{CELL}] No 12B run found in meta.json")

# pick last (created_at sort if present)
runs_12b = sorted(runs_12b, key=lambda r: r.get("created_at") or "")
RUN_12B = runs_12b[-1]

RUN_TAG_12B = RUN_12B["run_tag"]
REPORT_DIR_12B = RUN_12B["report_dir"]
REPORT_JSON_12B = RUN_12B["artifacts"]["report_json"]
CKPT_META_INIT = RUN_12B["artifacts"]["meta_model_source_pt"]

print(f"[{CELL}] 12B run_tag: {RUN_TAG_12B}")
print(f"[{CELL}] 12B report : {REPORT_JSON_12B}")
print(f"[{CELL}] 12B ckpt   : {CKPT_META_INIT}")

with open(REPORT_JSON_12B, "r", encoding="utf-8") as f:
    REPORT_12B = json.load(f)

PROTO = REPORT_12B["proto"]
META_CFG = REPORT_12B["meta_cfg"]  # your report uses meta_cfg
print(f"[{CELL}] PROTO keys   : {list(PROTO.keys())}")
print(f"[{CELL}] META_CFG keys: {list(META_CFG.keys())}")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-01] Starting... 2026-01-05 01:35:11
[12C-01] 12B run_tag: 20260104_165117
[12C-01] 12B report : C:\mooc-coldstart-session-meta\reports\12B_meta_train_on_source\20260104_165117\report.json
[12C-01] 12B ckpt   : C:\mooc-coldstart-session-meta\reports\12B_meta_train_on_source\20260104_165117\meta_model_source.pt
[12C-01] PROTO keys   : ['K_LIST', 'MAX_PREFIX_LEN', 'CAP_ENABLED', 'CAP_SESSION_LEN', 'CAP_STRATEGY']
[12C-01] META_CFG keys: ['emb_dim', 'hidden_dim', 'dropout', 'meta_lr', 'inner_lr', 'inner_steps', 'meta_steps', 'meta_batch_tasks', 'grad_clip', 'seed', 'log_every', 'eval_every', 'val_episodes']
[12C-01] Done in 0.00s


Load dataloader_config + target tensor metadata (PAD/VOCAB)

In [5]:
# [CELL 12C-02] Load tensor target config + target_tensor_metadata_*.json
import time
from datetime import datetime

CELL = "12C-02"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

DL_CFG_PATH = os.path.join(REPO_ROOT, "data", "processed", "supervised", "dataloader_config_20251229_163357_20251229_232834.json")
print(f"[{CELL}] DL_CFG_PATH: {DL_CFG_PATH}")

with open(DL_CFG_PATH, "r", encoding="utf-8") as f:
    DL_CFG = json.load(f)

TARGET_CFG = DL_CFG["target"]
TRAIN_PT = TARGET_CFG["train_pt"]
VAL_PT   = TARGET_CFG["val_pt"]
TEST_PT  = TARGET_CFG["test_pt"]
TARGET_META_JSON = TARGET_CFG["meta_json"]

print(f"[{CELL}] train_pt : {TRAIN_PT}")
print(f"[{CELL}] val_pt   : {VAL_PT}")
print(f"[{CELL}] test_pt  : {TEST_PT}")
print(f"[{CELL}] meta_json: {TARGET_META_JSON}")

with open(TARGET_META_JSON, "r", encoding="utf-8") as f:
    TARGET_META = json.load(f)

PAD_ID = int(TARGET_META["pad_id"])
UNK_ID = int(TARGET_META["unk_id"])
VOCAB_SIZE_TARGET = int(TARGET_META["vocab_size"])

print(f"[{CELL}] PAD_ID={PAD_ID} UNK_ID={UNK_ID} VOCAB_SIZE_TARGET={VOCAB_SIZE_TARGET}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-02] Starting... 2026-01-05 01:35:34
[12C-02] DL_CFG_PATH: C:\mooc-coldstart-session-meta\data\processed\supervised\dataloader_config_20251229_163357_20251229_232834.json
[12C-02] train_pt : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_train_20251229_163357.pt
[12C-02] val_pt   : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_val_20251229_163357.pt
[12C-02] test_pt  : C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_test_20251229_163357.pt
[12C-02] meta_json: C:\mooc-coldstart-session-meta\data\processed\tensor_target\target_tensor_metadata_20251229_163357.json
[12C-02] PAD_ID=0 UNK_ID=1 VOCAB_SIZE_TARGET=747
[12C-02] Done in 0.00s


Load .pt tensors (PyTorch 2.6+ weights_only fix)

In [6]:
# [CELL 12C-03] Load TARGET tensors (trusted) + lengths from attn_mask
import time, torch
from datetime import datetime

CELL = "12C-03"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

def torch_load_trusted(pt_path: str):
    try:
        return torch.load(pt_path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(pt_path, map_location="cpu")

def make_lengths(attn_mask: torch.Tensor) -> torch.Tensor:
    return attn_mask.sum(dim=1).clamp(min=1)

def load_split(pt_path, name):
    raw = torch_load_trusted(pt_path)
    if not isinstance(raw, dict):
        raise RuntimeError(f"[{CELL}] {name}: expected dict, got {type(raw)}")
    for k in ["input_ids", "attn_mask", "labels"]:
        if k not in raw:
            raise RuntimeError(f"[{CELL}] {name}: missing key '{k}' in pt dict. keys={list(raw.keys())}")

    out = {
        "input_ids": raw["input_ids"].long(),
        "attn_mask": raw["attn_mask"].long(),
        "labels": raw["labels"].long(),
        # keep metadata if present (optional)
        "session_id": raw.get("session_id"),
        "user_id": raw.get("user_id"),
        "t": raw.get("t"),
        "split": raw.get("split"),
    }
    out["lengths"] = make_lengths(out["attn_mask"])

    N, L = out["input_ids"].shape
    lens = out["lengths"]
    print(f"[{CELL}] {name}: input_ids={tuple(out['input_ids'].shape)} labels={tuple(out['labels'].shape)}")
    print(f"[{CELL}] {name}: lengths min={int(lens.min())} p50={int(lens.median())} max={int(lens.max())}")

    # left-pad sanity: last token should be non-pad for len>=1
    i0 = 0
    Li = int(lens[i0])
    tail = out["input_ids"][i0].tolist()[-5:]
    print(f"[{CELL}] {name}: sample0 len={Li} y={int(out['labels'][i0])} x_last5={tail}")
    return out

TARGET_TRAIN = load_split(TRAIN_PT, "TARGET_TRAIN")
TARGET_VAL   = load_split(VAL_PT,   "TARGET_VAL")
TARGET_TEST  = load_split(TEST_PT,  "TARGET_TEST")

TENSOR_SEQ_LEN = int(TARGET_TRAIN["input_ids"].shape[1])
print(f"[{CELL}] TENSOR_SEQ_LEN={TENSOR_SEQ_LEN}")

# verify vocab bound
mx = int(torch.max(torch.stack([
    TARGET_TRAIN["input_ids"].max(),
    TARGET_VAL["input_ids"].max(),
    TARGET_TEST["input_ids"].max(),
    TARGET_TRAIN["labels"].max(),
    TARGET_VAL["labels"].max(),
    TARGET_TEST["labels"].max(),
])).item())
print(f"[{CELL}] max_token_id_seen={mx} | VOCAB_SIZE_TARGET={VOCAB_SIZE_TARGET}")
if mx >= VOCAB_SIZE_TARGET:
    raise RuntimeError(f"[{CELL}] VOCAB_SIZE_TARGET too small: max_id={mx} >= {VOCAB_SIZE_TARGET}")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-03] Starting... 2026-01-05 01:36:56
[12C-03] TARGET_TRAIN: input_ids=(1944, 20) labels=(1944,)
[12C-03] TARGET_TRAIN: lengths min=1 p50=4 max=20
[12C-03] TARGET_TRAIN: sample0 len=1 y=423 x_last5=[0, 0, 0, 0, 416]
[12C-03] TARGET_VAL: input_ids=(189, 20) labels=(189,)
[12C-03] TARGET_VAL: lengths min=1 p50=3 max=20
[12C-03] TARGET_VAL: sample0 len=1 y=380 x_last5=[0, 0, 0, 0, 383]
[12C-03] TARGET_TEST: input_ids=(200, 20) labels=(200,)
[12C-03] TARGET_TEST: lengths min=1 p50=3 max=20
[12C-03] TARGET_TEST: sample0 len=1 y=151 x_last5=[0, 0, 0, 0, 150]
[12C-03] TENSOR_SEQ_LEN=20
[12C-03] max_token_id_seen=746 | VOCAB_SIZE_TARGET=747
[12C-03] Done in 0.01s


Metrics + exclude-seen masking (FIXED for left-padding)

In [7]:
# [CELL 12C-04] Metrics + exclude-seen mask (left-padded => seen are last 'length' tokens)
import time, torch
from datetime import datetime

CELL = "12C-04"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

K_LIST = PROTO.get("K_LIST", [5,10,20])

def get_batch(split_dict, idxs):
    x = split_dict["input_ids"][idxs].to(DEVICE).long()   # [B,L]
    l = split_dict["lengths"][idxs].to(DEVICE).long()     # [B]
    y = split_dict["labels"][idxs].to(DEVICE).long()      # [B]
    return x, l, y

@torch.no_grad()
def mask_seen_logits_leftpad(logits, input_ids, lengths, pad_id: int):
    # left-padded: actual tokens are at the tail => use input_ids[i, -Li:]
    masked = logits.clone()
    B = input_ids.size(0)
    for i in range(B):
        Li = int(lengths[i].item())
        if Li <= 0:
            continue
        seen = input_ids[i, -Li:]
        seen = seen[seen != int(pad_id)]
        if seen.numel() > 0:
            masked[i, seen] = float("-inf")
    return masked

@torch.no_grad()
def metrics_from_logits(logits, y_true, k_list):
    res = {}
    max_k = max(k_list)
    topk = torch.topk(logits, k=max_k, dim=-1).indices  # [B,max_k]
    y = y_true.view(-1,1)

    for k in k_list:
        match = (topk[:, :k] == y)
        hits = match.any(dim=1).float()
        hr = hits.mean().item()

        rank = torch.where(match.any(dim=1), match.float().argmax(dim=1)+1, torch.zeros_like(hits, dtype=torch.long))
        mrr = torch.where(rank>0, 1.0/rank.float(), torch.zeros_like(hits)).mean().item()
        ndcg = torch.where(rank>0, 1.0/torch.log2(rank.float()+1.0), torch.zeros_like(hits)).mean().item()

        res[f"HR@{k}"] = hr
        res[f"MRR@{k}"] = mrr
        res[f"NDCG@{k}"] = ndcg
    return res

print(f"[{CELL}] K_LIST={K_LIST} PAD_ID={PAD_ID} VOCAB_SIZE_TARGET={VOCAB_SIZE_TARGET} DEVICE={DEVICE}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-04] Starting... 2026-01-05 01:37:20
[12C-04] K_LIST=[5, 10, 20] PAD_ID=0 VOCAB_SIZE_TARGET=747 DEVICE=cpu
[12C-04] Done in 0.00s


Model definition (copy from 12B unchanged)

In [8]:
# [CELL 12C-05] Model: copied from 12B unchanged
import torch.nn as nn
import torch.nn.functional as F
import copy

class GRU4RecDropout(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int, pad_id: int, dropout: float = 0.3):
        super().__init__()
        self.pad_id = int(pad_id)
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=self.pad_id)
        self.drop = nn.Dropout(float(dropout))
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor):
        # input_ids: [B, T] left-padded. lengths: [B] counts of non-pad
        emb = self.drop(self.emb(input_ids))  # [B,T,E]
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h = self.gru(packed)  # h: [1,B,H]
        logits = self.out(h.squeeze(0))  # [B,V]
        return logits

print("[12C-05] ✅ GRU4RecDropout ready")


[12C-05] ✅ GRU4RecDropout ready


Load 12B checkpoint and build TARGET model with partial load (gru only)

In [9]:
# [CELL 12C-06] Load 12B ckpt; build target model; load only GRU weights (plus optional PAD/UNK rows)
import time, torch
from datetime import datetime

CELL = "12C-06"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

ckpt = torch.load(CKPT_META_INIT, map_location="cpu", weights_only=False)
print(f"[{CELL}] ckpt keys: {list(ckpt.keys())}")

state_dict_src = ckpt["state_dict"]
VOCAB_SIZE_SOURCE = int(ckpt.get("vocab_size_source", -1))
print(f"[{CELL}] VOCAB_SIZE_SOURCE(from ckpt)={VOCAB_SIZE_SOURCE} | VOCAB_SIZE_TARGET={VOCAB_SIZE_TARGET}")

# Instantiate TARGET-sized model
emb_dim = int(META_CFG["emb_dim"])
hidden_dim = int(META_CFG["hidden_dim"])
dropout = float(META_CFG["dropout"])

model = GRU4RecDropout(
    vocab_size=VOCAB_SIZE_TARGET,
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    pad_id=PAD_ID,
    dropout=dropout,
).to(DEVICE)

# Build a filtered state_dict: keep only GRU weights (transferable across vocab)
filtered = {k: v for k, v in state_dict_src.items() if k.startswith("gru.")}
print(f"[{CELL}] Will load {len(filtered)} GRU keys. Example keys: {list(filtered.keys())[:6]}")

missing, unexpected = model.load_state_dict(filtered, strict=False)
print(f"[{CELL}] load_state_dict(gru-only strict=False): missing={len(missing)} unexpected={len(unexpected)}")
if missing[:10]:
    print(f"[{CELL}] missing sample: {missing[:10]}")
if unexpected[:10]:
    print(f"[{CELL}] unexpected sample: {unexpected[:10]}")

# Optional: copy PAD/UNK embedding rows if they exist in source (safe and tiny)
if "emb.weight" in state_dict_src:
    with torch.no_grad():
        src_emb = state_dict_src["emb.weight"]
        if src_emb.size(0) > max(PAD_ID, UNK_ID):
            model.emb.weight[PAD_ID].copy_(src_emb[PAD_ID])
            model.emb.weight[UNK_ID].copy_(src_emb[UNK_ID])
            print(f"[{CELL}] Copied emb rows for PAD_ID={PAD_ID} and UNK_ID={UNK_ID} from source checkpoint")

model.eval()
print(f"[{CELL}] ✅ Target model ready (gru transferred, emb/out target-sized).")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-06] Starting... 2026-01-05 01:38:01
[12C-06] ckpt keys: ['run_tag', 'task_run_dir', 'proto', 'task_cfg', 'meta_cfg', 'vocab_size_source', 'pad_id_source', 'unk_id_source', 'state_dict', 'best_step', 'best_val_hr20']
[12C-06] VOCAB_SIZE_SOURCE(from ckpt)=1620 | VOCAB_SIZE_TARGET=747
[12C-06] Will load 4 GRU keys. Example keys: ['gru.weight_ih_l0', 'gru.weight_hh_l0', 'gru.bias_ih_l0', 'gru.bias_hh_l0']
[12C-06] load_state_dict(gru-only strict=False): missing=3 unexpected=0
[12C-06] missing sample: ['emb.weight', 'out.weight', 'out.bias']
[12C-06] Copied emb rows for PAD_ID=0 and UNK_ID=1 from source checkpoint
[12C-06] ✅ Target model ready (gru transferred, emb/out target-sized).
[12C-06] Done in 0.01s


Evaluate META-INIT on TARGET (VAL/TEST), raw + exclude-seen

In [10]:
# [CELL 12C-07] Meta-init eval on TARGET (no adaptation)
import time, torch
from datetime import datetime

CELL = "12C-07"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

@torch.no_grad()
def eval_no_adapt(split_dict, batch_size=256):
    model.eval()
    N = int(split_dict["input_ids"].shape[0])

    logits_raw_list = []
    logits_mask_list = []
    y_list = []

    for start in range(0, N, batch_size):
        end = min(N, start+batch_size)
        idxs = list(range(start, end))
        x, l, y = get_batch(split_dict, idxs)

        logits = model(x, l)  # [B,V]
        if start == 0:
            print(f"[{CELL}] batch0: x={x.shape} l={l.shape} y={y.shape} logits={logits.shape}")

        logits_masked = mask_seen_logits_leftpad(logits, x, l, pad_id=PAD_ID)

        logits_raw_list.append(logits.cpu())
        logits_mask_list.append(logits_masked.cpu())
        y_list.append(y.cpu())

    logits_raw = torch.cat(logits_raw_list, dim=0)
    logits_masked = torch.cat(logits_mask_list, dim=0)
    y_cat = torch.cat(y_list, dim=0)

    res_raw = metrics_from_logits(logits_raw, y_cat, K_LIST)
    res_mask = metrics_from_logits(logits_masked, y_cat, K_LIST)

    return {
        "raw": {**res_raw, "_n_examples": N},
        "exclude_seen": {**res_mask, "_n_examples": N},
    }

VAL_META_INIT = eval_no_adapt(TARGET_VAL)
TEST_META_INIT = eval_no_adapt(TARGET_TEST)

print(f"[{CELL}] META-INIT VAL  exclude_seen HR@20={VAL_META_INIT['exclude_seen'].get('HR@20')} | raw HR@20={VAL_META_INIT['raw'].get('HR@20')}")
print(f"[{CELL}] META-INIT TEST exclude_seen HR@20={TEST_META_INIT['exclude_seen'].get('HR@20')} | raw HR@20={TEST_META_INIT['raw'].get('HR@20')}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-07] Starting... 2026-01-05 01:38:24
[12C-07] batch0: x=torch.Size([189, 20]) l=torch.Size([189]) y=torch.Size([189]) logits=torch.Size([189, 747])
[12C-07] batch0: x=torch.Size([200, 20]) l=torch.Size([200]) y=torch.Size([200]) logits=torch.Size([200, 747])
[12C-07] META-INIT VAL  exclude_seen HR@20=0.05820105969905853 | raw HR@20=0.05820105969905853
[12C-07] META-INIT TEST exclude_seen HR@20=0.05000000074505806 | raw HR@20=0.05000000074505806
[12C-07] Done in 0.07s


Meta-adapt on TARGET (deepcopy + SGD inner loop)

In [11]:
# [CELL 12C-08] Meta-adapt on TARGET: adapt on TRAIN support, eval on VAL/TEST query
import time, copy
import numpy as np
import torch.nn.functional as F
from datetime import datetime

CELL = "12C-08"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

INNER_STEPS = int(META_CFG.get("inner_steps", 1))
INNER_LR = float(META_CFG.get("inner_lr", 1e-3))
SUPPORT_SIZE = int(META_CFG.get("n_support", 32)) if "n_support" in META_CFG else 32
QUERY_SIZE = int(META_CFG.get("n_query", 64)) if "n_query" in META_CFG else 64
EPISODES_VAL = int(META_CFG.get("val_episodes", 100))
EPISODES_TEST = int(META_CFG.get("test_episodes", 200)) if "test_episodes" in META_CFG else 200

print(f"[{CELL}] inner_steps={INNER_STEPS} inner_lr={INNER_LR} support={SUPPORT_SIZE} query={QUERY_SIZE} episodes(val/test)={EPISODES_VAL}/{EPISODES_TEST}")

def sample_indices(N, k, rng):
    replace = N < k
    return rng.choice(N, size=k, replace=replace)

def adapt_on_support(base_model, support_idxs):
    m = copy.deepcopy(base_model)
    m.train()
    opt = torch.optim.SGD(m.parameters(), lr=INNER_LR)

    x, l, y = get_batch(TARGET_TRAIN, support_idxs)

    for s in range(INNER_STEPS):
        opt.zero_grad(set_to_none=True)
        with torch.enable_grad():
            logits = m(x, l)
            loss = F.cross_entropy(logits, y)  # labels are next-item ids (not PAD)
            loss.backward()
            opt.step()
        if s == 0:
            print(f"[{CELL}]   adapt step0 support_loss={loss.item():.6f} B={x.size(0)}")
    return m

@torch.no_grad()
def eval_meta_adapt(query_split, n_episodes, seed):
    rng = np.random.RandomState(seed)
    N_train = int(TARGET_TRAIN["input_ids"].shape[0])
    N_q = int(query_split["input_ids"].shape[0])

    logits_raw_list = []
    logits_mask_list = []
    y_list = []

    for ep in range(n_episodes):
        sup = sample_indices(N_train, SUPPORT_SIZE, rng)
        qry = sample_indices(N_q, QUERY_SIZE, rng)

        adapted = adapt_on_support(model, sup)

        xq, lq, yq = get_batch(query_split, qry)
        adapted.eval()
        logits = adapted(xq, lq)
        logits_masked = mask_seen_logits_leftpad(logits, xq, lq, pad_id=PAD_ID)

        logits_raw_list.append(logits.cpu())
        logits_mask_list.append(logits_masked.cpu())
        y_list.append(yq.cpu())

        if ep in [0, 1, 2, 9]:
            print(f"[{CELL}]   ep={ep+1}/{n_episodes} query_logits={tuple(logits.shape)}")

    logits_raw = torch.cat(logits_raw_list, dim=0)
    logits_masked = torch.cat(logits_mask_list, dim=0)
    y_cat = torch.cat(y_list, dim=0)

    res_raw = metrics_from_logits(logits_raw, y_cat, K_LIST)
    res_mask = metrics_from_logits(logits_masked, y_cat, K_LIST)

    return {
        "raw": {**res_raw, "_n_episodes": n_episodes, "_support": SUPPORT_SIZE, "_query": QUERY_SIZE, "_inner_steps": INNER_STEPS, "_inner_lr": INNER_LR},
        "exclude_seen": {**res_mask, "_n_episodes": n_episodes, "_support": SUPPORT_SIZE, "_query": QUERY_SIZE, "_inner_steps": INNER_STEPS, "_inner_lr": INNER_LR},
    }

VAL_META_ADAPT = eval_meta_adapt(TARGET_VAL, n_episodes=EPISODES_VAL, seed=SEED+10)
TEST_META_ADAPT = eval_meta_adapt(TARGET_TEST, n_episodes=EPISODES_TEST, seed=SEED+20)

print(f"[{CELL}] META-ADAPT VAL  exclude_seen HR@20={VAL_META_ADAPT['exclude_seen'].get('HR@20')} | raw HR@20={VAL_META_ADAPT['raw'].get('HR@20')}")
print(f"[{CELL}] META-ADAPT TEST exclude_seen HR@20={TEST_META_ADAPT['exclude_seen'].get('HR@20')} | raw HR@20={TEST_META_ADAPT['raw'].get('HR@20')}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-08] Starting... 2026-01-05 01:38:49
[12C-08] inner_steps=1 inner_lr=0.01 support=32 query=64 episodes(val/test)=50/200
[12C-08]   adapt step0 support_loss=6.610803 B=32
[12C-08]   ep=1/50 query_logits=(64, 747)
[12C-08]   adapt step0 support_loss=6.626634 B=32
[12C-08]   ep=2/50 query_logits=(64, 747)
[12C-08]   adapt step0 support_loss=6.621930 B=32
[12C-08]   ep=3/50 query_logits=(64, 747)
[12C-08]   adapt step0 support_loss=6.600657 B=32
[12C-08]   adapt step0 support_loss=6.594532 B=32
[12C-08]   adapt step0 support_loss=6.625181 B=32
[12C-08]   adapt step0 support_loss=6.605661 B=32
[12C-08]   adapt step0 support_loss=6.601851 B=32
[12C-08]   adapt step0 support_loss=6.615778 B=32
[12C-08]   adapt step0 support_loss=6.590286 B=32
[12C-08]   ep=10/50 query_logits=(64, 747)
[12C-08]   adapt step0 support_loss=6.602042 B=32
[12C-08]   adapt step0 support_loss=6.598908 B=32
[12C-08]   adapt step0 support_loss=6.619802 B=32
[12C-08]   adapt step0 support_loss=6.593378 B=32
[12C-08

In [18]:
# [CELL 12C-08B] Compact results summary (META-INIT vs META-ADAPT) + deltas
import time
from datetime import datetime

CELL = "12C-08B"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

def _delta(a, b, key):
    return float(b.get(key, 0.0) - a.get(key, 0.0))

for split_name, init_res, adapt_res in [
    ("VAL",  VAL_META_INIT["exclude_seen"],  VAL_META_ADAPT["exclude_seen"]),
    ("TEST", TEST_META_INIT["exclude_seen"], TEST_META_ADAPT["exclude_seen"]),
]:
    print(f"[{CELL}] {split_name} META-INIT  exclude_seen HR@20={init_res.get('HR@20'):.6f}")
    print(f"[{CELL}] {split_name} META-ADAPT exclude_seen HR@20={adapt_res.get('HR@20'):.6f}")
    print(f"[{CELL}] {split_name} ΔHR@20 = {_delta(init_res, adapt_res, 'HR@20'):.6f}")
    print("-" * 60)

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-08B] Starting... 2026-01-05 01:47:45
[12C-08B] VAL META-INIT  exclude_seen HR@20=0.058201
[12C-08B] VAL META-ADAPT exclude_seen HR@20=0.054375
[12C-08B] VAL ΔHR@20 = -0.003826
------------------------------------------------------------
[12C-08B] TEST META-INIT  exclude_seen HR@20=0.050000
[12C-08B] TEST META-ADAPT exclude_seen HR@20=0.049531
[12C-08B] TEST ΔHR@20 = -0.000469
------------------------------------------------------------
[12C-08B] Done in 0.00s


In [19]:
# [CELL 12C-08C] Audit episodic sampling: replacement + unique coverage stats (VAL/TEST)
import numpy as np, time
from datetime import datetime

CELL = "12C-08C"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

def audit_sampling(N_q, n_episodes, query_size, seed):
    rng = np.random.RandomState(seed)
    all_q = []
    for _ in range(n_episodes):
        replace = N_q < query_size
        q = rng.choice(N_q, size=query_size, replace=replace)
        all_q.append(q)
    all_q = np.concatenate(all_q, axis=0)

    unique = np.unique(all_q)
    return {
        "N_q": int(N_q),
        "n_episodes": int(n_episodes),
        "query_size": int(query_size),
        "total_query_draws": int(n_episodes * query_size),
        "unique_query_indices": int(unique.size),
        "unique_coverage_ratio": float(unique.size / max(1, N_q)),
        "sampling_with_replacement": bool(N_q < query_size),
    }

N_val = int(TARGET_VAL["input_ids"].shape[0])
N_test = int(TARGET_TEST["input_ids"].shape[0])

audit_val  = audit_sampling(N_val,  EPISODES_VAL,  QUERY_SIZE, seed=SEED+10)
audit_test = audit_sampling(N_test, EPISODES_TEST, QUERY_SIZE, seed=SEED+20)

print(f"[{CELL}] VAL  audit: {audit_val}")
print(f"[{CELL}] TEST audit: {audit_test}")

if audit_val["total_query_draws"] > audit_val["N_q"]:
    print(f"[{CELL}] NOTE: VAL total draws exceed dataset size (expected for episodic eval).")
if audit_test["total_query_draws"] > audit_test["N_q"]:
    print(f"[{CELL}] NOTE: TEST total draws exceed dataset size (expected for episodic eval).")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-08C] Starting... 2026-01-05 01:48:07
[12C-08C] VAL  audit: {'N_q': 189, 'n_episodes': 50, 'query_size': 64, 'total_query_draws': 3200, 'unique_query_indices': 189, 'unique_coverage_ratio': 1.0, 'sampling_with_replacement': False}
[12C-08C] TEST audit: {'N_q': 200, 'n_episodes': 200, 'query_size': 64, 'total_query_draws': 12800, 'unique_query_indices': 200, 'unique_coverage_ratio': 1.0, 'sampling_with_replacement': False}
[12C-08C] NOTE: VAL total draws exceed dataset size (expected for episodic eval).
[12C-08C] NOTE: TEST total draws exceed dataset size (expected for episodic eval).
[12C-08C] Done in 0.02s


In [20]:
# [CELL 12C-08D] Exhaustive evaluation: single pass over full VAL/TEST (no resampling)
import time, torch
from datetime import datetime

CELL = "12C-08D"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

@torch.no_grad()
def eval_exhaustive(split_dict, batch_size=256):
    model.eval()
    N = int(split_dict["input_ids"].shape[0])

    logits_raw_list = []
    logits_mask_list = []
    y_list = []

    for start in range(0, N, batch_size):
        end = min(N, start+batch_size)
        idxs = list(range(start, end))
        x, l, y = get_batch(split_dict, idxs)

        logits = model(x, l)
        logits_masked = mask_seen_logits_leftpad(logits, x, l, pad_id=PAD_ID)

        logits_raw_list.append(logits.cpu())
        logits_mask_list.append(logits_masked.cpu())
        y_list.append(y.cpu())

    logits_raw = torch.cat(logits_raw_list, dim=0)
    logits_masked = torch.cat(logits_mask_list, dim=0)
    y_cat = torch.cat(y_list, dim=0)

    return {
        "raw": {**metrics_from_logits(logits_raw, y_cat, K_LIST), "_n_examples": N},
        "exclude_seen": {**metrics_from_logits(logits_masked, y_cat, K_LIST), "_n_examples": N},
    }

EXH_VAL  = eval_exhaustive(TARGET_VAL)
EXH_TEST = eval_exhaustive(TARGET_TEST)

print(f"[{CELL}] EXHAUSTIVE VAL  exclude_seen HR@20={EXH_VAL['exclude_seen'].get('HR@20'):.6f}")
print(f"[{CELL}] EXHAUSTIVE TEST exclude_seen HR@20={EXH_TEST['exclude_seen'].get('HR@20'):.6f}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-08D] Starting... 2026-01-05 01:48:45
[12C-08D] EXHAUSTIVE VAL  exclude_seen HR@20=0.058201
[12C-08D] EXHAUSTIVE TEST exclude_seen HR@20=0.050000
[12C-08D] Done in 0.07s


Save report + manifest

In [12]:
# [CELL 12C-09] Save report.json + manifest.json
import time
from datetime import datetime

CELL = "12C-09"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

report = {
    "kind": "12C_meta_adapt_and_eval_on_target",
    "run_tag": RUN_TAG,
    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "source_12b_run_tag": RUN_TAG_12B,
    "proto": PROTO,
    "meta_cfg_from_12b": META_CFG,
    "target_meta_json": TARGET_META_JSON,
    "target_pts": {"train_pt": TRAIN_PT, "val_pt": VAL_PT, "test_pt": TEST_PT},
    "transfer_loading": {
        "note": "Loaded GRU weights only (vocab mismatch between source and target). emb/out are target-sized init; copied PAD/UNK emb rows if possible."
    },
    "target_eval": {
        "meta_init": {"val": VAL_META_INIT, "test": TEST_META_INIT},
        "meta_adapt": {"val": VAL_META_ADAPT, "test": TEST_META_ADAPT},
    },
}

report_path = os.path.join(REPORT_DIR, "report.json")
with open(report_path, "w", encoding="utf-8") as f:
    json.dump(report, f, indent=2)

manifest = {
    "run_tag": RUN_TAG,
    "report_dir": REPORT_DIR,
    "files": {"report.json": report_path},
    "sizes_bytes": {"report.json": os.path.getsize(report_path)},
    "updated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
manifest_path = os.path.join(REPORT_DIR, "manifest.json")
with open(manifest_path, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)

print(f"[{CELL}] ✅ report.json   : {report_path}")
print(f"[{CELL}] ✅ manifest.json : {manifest_path}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-09] Starting... 2026-01-05 01:39:35
[12C-09] ✅ report.json   : C:\mooc-coldstart-session-meta\reports\12C_meta_adapt_and_eval_on_target\20260105_013327\report.json
[12C-09] ✅ manifest.json : C:\mooc-coldstart-session-meta\reports\12C_meta_adapt_and_eval_on_target\20260105_013327\manifest.json
[12C-09] Done in 0.00s


Update root meta.json (idempotent)

In [13]:
# [CELL 12C-10] Update meta.json with 12C run entry
import time
from datetime import datetime

CELL = "12C-10"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

meta_path = os.path.join(REPO_ROOT, "meta.json")
with open(meta_path, "r", encoding="utf-8") as f:
    meta = json.load(f)

entry = {
    "kind": "12C_meta_adapt_and_eval_on_target",
    "run_tag": RUN_TAG,
    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "inputs": {
        "source_12b_run_tag": RUN_TAG_12B,
        "ckpt_meta_init": CKPT_META_INIT,
        "dataloader_config": DL_CFG_PATH,
        "target_meta_json": TARGET_META_JSON,
    },
    "report_dir": REPORT_DIR,
    "artifacts": {
        "report_json": os.path.join(REPORT_DIR, "report.json"),
        "manifest_json": os.path.join(REPORT_DIR, "manifest.json"),
    },
    "results": {
        "target_val_meta_init": VAL_META_INIT,
        "target_test_meta_init": TEST_META_INIT,
        "target_val_meta_adapt": VAL_META_ADAPT,
        "target_test_meta_adapt": TEST_META_ADAPT,
    },
    "updated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}

meta.setdefault("runs", [])
idx = None
for i, r in enumerate(meta["runs"]):
    if r.get("kind") == entry["kind"] and r.get("run_tag") == entry["run_tag"]:
        idx = i
        break

if idx is None:
    meta["runs"].append(entry)
    print(f"[{CELL}] ✅ Added 12C entry run_tag={RUN_TAG}")
else:
    meta["runs"][idx] = entry
    print(f"[{CELL}] ✅ Updated 12C entry run_tag={RUN_TAG}")

meta["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)

print(f"[{CELL}] ✅ Saved meta.json: {meta_path}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-10] Starting... 2026-01-05 01:40:01
[12C-10] ✅ Added 12C entry run_tag=20260105_013327
[12C-10] ✅ Saved meta.json: C:\mooc-coldstart-session-meta\meta.json
[12C-10] Done in 0.00s


Sanity: confirm exclude-seen masking changes logits for at least one batch

In [14]:
# [CELL 12C-11] Sanity: confirm exclude-seen masking changes logits for at least one batch
import torch, time
from datetime import datetime

CELL = "12C-11"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

x, l, y = get_batch(TARGET_VAL, list(range(min(32, TARGET_VAL["input_ids"].shape[0]))))
with torch.no_grad():
    logits = model(x, l)
    logits_m = mask_seen_logits_leftpad(logits, x, l, pad_id=PAD_ID)

diff = (logits - logits_m).abs()
num_changed = int((diff > 0).sum().item())
print(f"[{CELL}] num_changed_elements={num_changed} (should be > 0)")

# check how many positions were set to -inf
num_infs = int(torch.isinf(logits_m).sum().item())
print(f"[{CELL}] num_-inf_elements={num_infs} (should be > 0 if any seen tokens exist)")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-11] Starting... 2026-01-05 01:41:43
[12C-11] num_changed_elements=176 (should be > 0)
[12C-11] num_-inf_elements=176 (should be > 0 if any seen tokens exist)
[12C-11] Done in 0.02s


Config: head-only adaptation settings (freeze GRU)

In [21]:
# [CELL 12C-12] Config: head-only adaptation settings (freeze GRU)
import time
from datetime import datetime

CELL = "12C-12"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Keep your previous episode settings for apples-to-apples comparison
INNER_STEPS_H = 3          # try 3 steps (more signal than 1)
INNER_LR_H = 0.05          # head can take higher LR
SUPPORT_SIZE_H = 64        # more support since target is small
QUERY_SIZE_H = 64
EPISODES_VAL_H = 100
EPISODES_TEST_H = 200

# Choose adaptation scope:
ADAPT_SCOPE = "out+emb"        # "out" or "out+emb"

print(f"[{CELL}] ADAPT_SCOPE={ADAPT_SCOPE}")
print(f"[{CELL}] inner_steps={INNER_STEPS_H} inner_lr={INNER_LR_H} support={SUPPORT_SIZE_H} query={QUERY_SIZE_H} episodes(val/test)={EPISODES_VAL_H}/{EPISODES_TEST_H}")

print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-12] Starting... 2026-01-05 01:53:58
[12C-12] ADAPT_SCOPE=out+emb
[12C-12] inner_steps=3 inner_lr=0.05 support=64 query=64 episodes(val/test)=100/200
[12C-12] Done in 0.00s


Meta-adapt eval (HEAD-ONLY): freeze GRU, adapt out (or out+emb)

In [22]:
# [CELL 12C-13] Meta-adapt eval (HEAD-ONLY): freeze GRU, adapt out (or out+emb)
import time, copy
import numpy as np
import torch
import torch.nn.functional as F
from datetime import datetime

CELL = "12C-13"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

def sample_indices(N, k, rng):
    replace = N < k
    return rng.choice(N, size=k, replace=replace)

def clone_for_adapt(base_model):
    # deepcopy keeps this simple & reproducible (slow but OK for small target)
    m = copy.deepcopy(base_model)

    # Freeze GRU always
    for p in m.gru.parameters():
        p.requires_grad = False

    # Set scope
    if ADAPT_SCOPE == "out+emb":
        for p in m.emb.parameters():
            p.requires_grad = False
        for p in m.out.parameters():
            p.requires_grad = True

        params = list(m.out.parameters())

    elif ADAPT_SCOPE == "out+emb":
        for p in m.emb.parameters():
            p.requires_grad = True
        for p in m.out.parameters():
            p.requires_grad = True

        params = list(m.emb.parameters()) + list(m.out.parameters())

    else:
        raise RuntimeError(f"[{CELL}] Unknown ADAPT_SCOPE={ADAPT_SCOPE}")

    # Debug: count trainable params
    n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in m.parameters())
    print(f"[{CELL}] clone_for_adapt: trainable={n_trainable:,}/{n_total:,} ({ADAPT_SCOPE})")

    return m, params

def adapt_head_on_support(base_model, support_idxs):
    m, params = clone_for_adapt(base_model)
    m.train()

    opt = torch.optim.SGD(params, lr=INNER_LR_H)

    x, l, y = get_batch(TARGET_TRAIN, support_idxs)
    for s in range(INNER_STEPS_H):
        opt.zero_grad(set_to_none=True)
        with torch.enable_grad():
            logits = m(x, l)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            opt.step()
        if s == 0:
            print(f"[{CELL}]   adapt step0 support_loss={loss.item():.6f}")
    return m

@torch.no_grad()
def eval_meta_adapt_head(query_split, n_episodes, seed):
    rng = np.random.RandomState(seed)
    N_train = int(TARGET_TRAIN["input_ids"].shape[0])
    N_q = int(query_split["input_ids"].shape[0])

    logits_raw_list = []
    logits_mask_list = []
    y_list = []

    for ep in range(n_episodes):
        sup = sample_indices(N_train, SUPPORT_SIZE_H, rng)
        qry = sample_indices(N_q, QUERY_SIZE_H, rng)

        adapted = adapt_head_on_support(model, sup)

        xq, lq, yq = get_batch(query_split, qry)
        adapted.eval()
        logits = adapted(xq, lq)
        logits_masked = mask_seen_logits_leftpad(logits, xq, lq, pad_id=PAD_ID)

        logits_raw_list.append(logits.cpu())
        logits_mask_list.append(logits_masked.cpu())
        y_list.append(yq.cpu())

        if ep in [0, 1, 2, 9]:
            print(f"[{CELL}]   ep={ep+1}/{n_episodes} query_logits={tuple(logits.shape)}")

    logits_raw = torch.cat(logits_raw_list, dim=0)
    logits_masked = torch.cat(logits_mask_list, dim=0)
    y_cat = torch.cat(y_list, dim=0)

    res_raw = metrics_from_logits(logits_raw, y_cat, K_LIST)
    res_mask = metrics_from_logits(logits_masked, y_cat, K_LIST)

    return {
        "raw": {**res_raw, "_n_episodes": n_episodes, "_support": SUPPORT_SIZE_H, "_query": QUERY_SIZE_H, "_inner_steps": INNER_STEPS_H, "_inner_lr": INNER_LR_H, "_scope": ADAPT_SCOPE},
        "exclude_seen": {**res_mask, "_n_episodes": n_episodes, "_support": SUPPORT_SIZE_H, "_query": QUERY_SIZE_H, "_inner_steps": INNER_STEPS_H, "_inner_lr": INNER_LR_H, "_scope": ADAPT_SCOPE},
    }

VAL_META_ADAPT_HEAD = eval_meta_adapt_head(TARGET_VAL, n_episodes=EPISODES_VAL_H, seed=SEED+110)
TEST_META_ADAPT_HEAD = eval_meta_adapt_head(TARGET_TEST, n_episodes=EPISODES_TEST_H, seed=SEED+120)

print(f"[{CELL}] HEAD-ADAPT VAL  exclude_seen HR@20={VAL_META_ADAPT_HEAD['exclude_seen'].get('HR@20')} | raw HR@20={VAL_META_ADAPT_HEAD['raw'].get('HR@20')}")
print(f"[{CELL}] HEAD-ADAPT TEST exclude_seen HR@20={TEST_META_ADAPT_HEAD['exclude_seen'].get('HR@20')} | raw HR@20={TEST_META_ADAPT_HEAD['raw'].get('HR@20')}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-13] Starting... 2026-01-05 01:54:04
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.624877
[12C-13]   ep=1/100 query_logits=(64, 747)
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.634795
[12C-13]   ep=2/100 query_logits=(64, 747)
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.628953
[12C-13]   ep=3/100 query_logits=(64, 747)
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.602334
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.652944
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.638491
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 support_loss=6.615194
[12C-13] clone_for_adapt: trainable=96,363/218,667 (out+emb)
[12C-13]   adapt step0 supp

Save head-adapt addendum (does not overwrite prior report)

In [23]:
# [CELL 12C-14] Save head-adapt addendum (does not overwrite prior report)
import os, json, time
from datetime import datetime

CELL = "12C-14"
t0 = time.time()
print(f"[{CELL}] Starting... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

addendum = {
    "kind": "12C_head_only_meta_adapt_addendum",
    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "base_run_tag": RUN_TAG,
    "scope": ADAPT_SCOPE,
    "cfg": {
        "inner_steps": INNER_STEPS_H,
        "inner_lr": INNER_LR_H,
        "support": SUPPORT_SIZE_H,
        "query": QUERY_SIZE_H,
        "episodes_val": EPISODES_VAL_H,
        "episodes_test": EPISODES_TEST_H,
        "seed_base": SEED,
    },
    "results": {
        "val": VAL_META_ADAPT_HEAD,
        "test": TEST_META_ADAPT_HEAD,
    },
    "notes": [
        "GRU frozen. Adapted head only (out) or head+emb depending on scope.",
        "This is a targeted test because source->target vocab mismatch prevents loading emb/out from source.",
    ],
}

path = os.path.join(REPORT_DIR, f"head_adapt_addendum_{ADAPT_SCOPE.replace('+','_')}.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(addendum, f, indent=2)

print(f"[{CELL}] ✅ Wrote addendum: {path}")
print(f"[{CELL}] Done in {time.time()-t0:.2f}s")


[12C-14] Starting... 2026-01-05 01:54:39
[12C-14] ✅ Wrote addendum: C:\mooc-coldstart-session-meta\reports\12C_meta_adapt_and_eval_on_target\20260105_013327\head_adapt_addendum_out_emb.json
[12C-14] Done in 0.00s
