# Notebook 04: User Split (XuetangX)

**Purpose:** Create deterministic user-level split (80/10/10) for cold-start evaluation.

**Cold-Start Focus:**
- **Disjoint users** across train/val/test (no user appears in multiple splits)
- Test users are **completely unseen** during training (true cold-start)
- Pairs follow their user's assignment: all of a user's pairs go to the same split

**Inputs:**
- `data/processed/xuetangx/pairs/pairs.parquet` (264K pairs, 42K users)

**Outputs:**
- `data/processed/xuetangx/user_splits/users_train.json` (80% of users)
- `data/processed/xuetangx/user_splits/users_val.json` (10% of users)
- `data/processed/xuetangx/user_splits/users_test.json` (10% of users)
- `data/processed/xuetangx/pairs/pairs_train.parquet`
- `data/processed/xuetangx/pairs/pairs_val.parquet`
- `data/processed/xuetangx/pairs/pairs_test.parquet`
- DuckDB views: `xuetangx_pairs_train`, `xuetangx_pairs_val`, `xuetangx_pairs_test`
- `reports/04_user_split_xuetangx/<run_tag>/report.json`

**Strategy:**
1. Load all pairs from Notebook 03
2. Extract unique users and sort them deterministically (lexicographic string order of `user_id`)
3. Shuffle users with a fixed seed, then split 80% train, 10% val, 10% test (train/val sizes rounded down; test takes the remainder)
4. Assign all pairs for each user to corresponding split
5. Save user lists + split pairs separately
6. Validate: no user overlap, all pairs assigned
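The strategy above can be sketched end to end on a toy frame. This is a minimal illustration, not the notebook's code: `user_level_split` is a hypothetical helper, and the toy `user_id` values are made up. It mirrors the cells below (sort, seeded `RandomState` shuffle, floored 80/10 slices, `isin` assignment).

```python
import numpy as np
import pandas as pd

def user_level_split(pairs: pd.DataFrame, seed: int,
                     train_ratio: float = 0.8, val_ratio: float = 0.1):
    """Return (pairs per split, users per split) with disjoint user sets."""
    # Sort first so the shuffle input does not depend on the frame's row order
    users = np.array(sorted(pairs["user_id"].unique()))
    np.random.RandomState(seed).shuffle(users)
    n = len(users)
    n_train = int(n * train_ratio)  # floored
    n_val = int(n * val_ratio)      # floored; test takes the remainder
    groups = {
        "train": set(users[:n_train].tolist()),
        "val": set(users[n_train:n_train + n_val].tolist()),
        "test": set(users[n_train + n_val:].tolist()),
    }
    # Every pair follows its user's assignment
    splits = {name: pairs[pairs["user_id"].isin(g)] for name, g in groups.items()}
    return splits, groups

# Toy input (made up): user i has (i % 3) + 1 pairs
toy = pd.DataFrame({"user_id": [f"u{i}" for i in range(10) for _ in range(i % 3 + 1)]})
splits, groups = user_level_split(toy, seed=20260107)
print({name: len(g) for name, g in groups.items()})  # {'train': 8, 'val': 1, 'test': 1}
```

Because users (not pairs) are split, the pair proportions will generally drift from 80/10/10; the real run below lands at 80.6/9.3/10.1.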

In [1]:
# [CELL 04-00] Bootstrap: repo root + paths + logger

import os
import sys
import json
import time
import uuid
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import pandas as pd

t0 = datetime.now()
print(f"[CELL 04-00] start={t0.isoformat(timespec='seconds')}")
print("[CELL 04-00] CWD:", Path.cwd().resolve())

def find_repo_root(start: Path) -> Path:
    start = start.resolve()
    for p in [start, *start.parents]:
        if (p / "PROJECT_STATE.md").exists():
            return p
    raise RuntimeError("Could not find PROJECT_STATE.md. Open notebook from within the repo.")

REPO_ROOT = find_repo_root(Path.cwd())
print("[CELL 04-00] REPO_ROOT:", REPO_ROOT)

PATHS = {
    "META_REGISTRY": REPO_ROOT / "meta.json",
    "DATA_INTERIM": REPO_ROOT / "data" / "interim",
    "DATA_PROCESSED": REPO_ROOT / "data" / "processed",
    "REPORTS": REPO_ROOT / "reports",
}
for k, v in PATHS.items():
    print(f"[CELL 04-00] {k}={v}")

def cell_start(cell_id: str, title: str, **kwargs: Any) -> float:
    t = time.time()
    print(f"\n[{cell_id}] {title}")
    print(f"[{cell_id}] start={datetime.now().isoformat(timespec='seconds')}")
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    return t

def cell_end(cell_id: str, t0: float, **kwargs: Any) -> None:
    for k, v in kwargs.items():
        print(f"[{cell_id}] {k}={v}")
    print(f"[{cell_id}] elapsed={time.time()-t0:.2f}s")
    print(f"[{cell_id}] done")

print("[CELL 04-00] done")

[CELL 04-00] start=2026-01-07T14:40:34
[CELL 04-00] CWD: C:\anonymous-users-mooc-session-meta\notebooks
[CELL 04-00] REPO_ROOT: C:\anonymous-users-mooc-session-meta
[CELL 04-00] META_REGISTRY=C:\anonymous-users-mooc-session-meta\meta.json
[CELL 04-00] DATA_INTERIM=C:\anonymous-users-mooc-session-meta\data\interim
[CELL 04-00] DATA_PROCESSED=C:\anonymous-users-mooc-session-meta\data\processed
[CELL 04-00] REPORTS=C:\anonymous-users-mooc-session-meta\reports
[CELL 04-00] done


In [2]:
# [CELL 04-01] Reproducibility: seed everything

t0 = cell_start("CELL 04-01", "Seed everything")

GLOBAL_SEED = 20260107

def seed_everything(seed: int) -> None:
    import random
    random.seed(seed)
    np.random.seed(seed)

seed_everything(GLOBAL_SEED)

cell_end("CELL 04-01", t0, seed=GLOBAL_SEED)


[CELL 04-01] Seed everything
[CELL 04-01] start=2026-01-07T14:40:34
[CELL 04-01] seed=20260107
[CELL 04-01] elapsed=0.00s
[CELL 04-01] done
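Note that the split itself (CELL 04-06) draws from a dedicated `np.random.RandomState(GLOBAL_SEED)` rather than the global RNG seeded here, so the user order depends only on the seed and the sorted user list. A small sketch of that property, using a hypothetical `seeded_order` helper and a few user ids taken from the output further below:

```python
import numpy as np

SEED = 20260107
users = ["1000009", "1000019", "1000066", "1000080", "1000085"]

def seeded_order(seed: int, items: list) -> list:
    # A dedicated RandomState: its stream depends only on `seed`,
    # not on how many draws the global NumPy RNG has already made
    arr = np.array(items)
    np.random.RandomState(seed).shuffle(arr)
    return arr.tolist()

first = seeded_order(SEED, users)
np.random.seed(0)
np.random.rand(1000)  # disturb the global RNG state
second = seeded_order(SEED, users)
print(first == second)  # True
```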


In [3]:
# [CELL 04-02] JSON IO + hashing helpers

t0 = cell_start("CELL 04-02", "JSON IO + hashing")

def write_json_atomic(path: Path, obj: Any, indent: int = 2) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp_{uuid.uuid4().hex}")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=indent)
    tmp.replace(path)

def read_json(path: Path) -> Any:
    if not path.exists():
        raise RuntimeError(f"Missing JSON file: {path}")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    h = hashlib.sha256()
    with path.open("rb") as f:
        while True:
            b = f.read(chunk_size)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

cell_end("CELL 04-02", t0)


[CELL 04-02] JSON IO + hashing
[CELL 04-02] start=2026-01-07T14:40:34
[CELL 04-02] elapsed=0.00s
[CELL 04-02] done
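The SHA256 fingerprints recorded later are only useful if identical content always produces identical bytes. A quick check under that assumption: writing the same list twice with the same `indent` yields the same hash, the round trip is lossless, and no temp files are left behind. The helpers are simplified copies of the ones above (`sha256_file` reads the whole file at once here).

```python
import hashlib
import json
import tempfile
import uuid
from pathlib import Path
from typing import Any

def write_json_atomic(path: Path, obj: Any, indent: int = 2) -> None:
    # Same temp-file-then-rename pattern as CELL 04-02
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + f".tmp_{uuid.uuid4().hex}")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=indent)
    tmp.replace(path)

def sha256_file(path: Path) -> str:
    # Simplified: fine for small JSON files, unlike the chunked version above
    return hashlib.sha256(path.read_bytes()).hexdigest()

with tempfile.TemporaryDirectory() as d:
    p = Path(d) / "users_test.json"
    users = ["1000009", "999996"]
    write_json_atomic(p, users)
    h1 = sha256_file(p)
    write_json_atomic(p, users)  # rewrite identical content
    h2 = sha256_file(p)
    roundtrip = json.loads(p.read_text(encoding="utf-8"))
    leftovers = [q.name for q in Path(d).iterdir() if ".tmp_" in q.name]

print(h1 == h2, roundtrip == users, leftovers)  # True True []
```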


In [4]:
# [CELL 04-03] Run tagging + report/config/manifest + meta.json

t0 = cell_start("CELL 04-03", "Start run + init files + meta.json")

NOTEBOOK_NAME = "04_user_split_xuetangx"
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_ID = uuid.uuid4().hex

OUT_DIR = PATHS["REPORTS"] / NOTEBOOK_NAME / RUN_TAG
OUT_DIR.mkdir(parents=True, exist_ok=True)

REPORT_PATH = OUT_DIR / "report.json"
CONFIG_PATH = OUT_DIR / "config.json"
MANIFEST_PATH = OUT_DIR / "manifest.json"

DUCKDB_PATH = PATHS["DATA_INTERIM"] / "xuetangx.duckdb"
PAIRS_PARQUET = PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs" / "pairs.parquet"

USER_SPLITS_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "user_splits"
PAIRS_DIR = PATHS["DATA_PROCESSED"] / "xuetangx" / "pairs"
USER_SPLITS_DIR.mkdir(parents=True, exist_ok=True)

OUT_USERS_TRAIN = USER_SPLITS_DIR / "users_train.json"
OUT_USERS_VAL = USER_SPLITS_DIR / "users_val.json"
OUT_USERS_TEST = USER_SPLITS_DIR / "users_test.json"

OUT_PAIRS_TRAIN = PAIRS_DIR / "pairs_train.parquet"
OUT_PAIRS_VAL = PAIRS_DIR / "pairs_val.parquet"
OUT_PAIRS_TEST = PAIRS_DIR / "pairs_test.parquet"

CFG = {
    "notebook": NOTEBOOK_NAME,
    "run_id": RUN_ID,
    "run_tag": RUN_TAG,
    "seed": GLOBAL_SEED,
    "inputs": {
        "pairs": str(PAIRS_PARQUET),
    },
    "outputs": {
        "users_train": str(OUT_USERS_TRAIN),
        "users_val": str(OUT_USERS_VAL),
        "users_test": str(OUT_USERS_TEST),
        "pairs_train": str(OUT_PAIRS_TRAIN),
        "pairs_val": str(OUT_PAIRS_VAL),
        "pairs_test": str(OUT_PAIRS_TEST),
        "out_dir": str(OUT_DIR),
    },
    "split": {
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "test_ratio": 0.1,
        "strategy": "user_level_disjoint",  # no user overlap between splits
        "ordering": "alphabetical_then_shuffle",  # deterministic user ordering before split
    }
}

write_json_atomic(CONFIG_PATH, CFG)

report = {
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "created_at": datetime.now().isoformat(timespec="seconds"),
    "repo_root": str(REPO_ROOT),
    "metrics": {},
    "key_findings": [],
    "sanity_samples": {},
    "data_fingerprints": {},
    "notes": [],
}
write_json_atomic(REPORT_PATH, report)

manifest = {"run_id": RUN_ID, "notebook": NOTEBOOK_NAME, "run_tag": RUN_TAG, "artifacts": []}
write_json_atomic(MANIFEST_PATH, manifest)

# meta.json append-only
META_PATH = PATHS["META_REGISTRY"]
if not META_PATH.exists():
    write_json_atomic(META_PATH, {"schema_version": 1, "runs": []})
meta = read_json(META_PATH)
meta["runs"].append({
    "run_id": RUN_ID,
    "notebook": NOTEBOOK_NAME,
    "run_tag": RUN_TAG,
    "out_dir": str(OUT_DIR),
    "created_at": datetime.now().isoformat(timespec="seconds"),
})
write_json_atomic(META_PATH, meta)

cell_end("CELL 04-03", t0, out_dir=str(OUT_DIR))


[CELL 04-03] Start run + init files + meta.json
[CELL 04-03] start=2026-01-07T14:40:34
[CELL 04-03] out_dir=C:\anonymous-users-mooc-session-meta\reports\04_user_split_xuetangx\20260107_144034
[CELL 04-03] elapsed=0.01s
[CELL 04-03] done


In [5]:
# [CELL 04-04] Load pairs from Notebook 03

t0 = cell_start("CELL 04-04", "Load pairs", pairs=str(PAIRS_PARQUET))

if not PAIRS_PARQUET.exists():
    raise RuntimeError(f"Missing pairs.parquet: {PAIRS_PARQUET}. Run Notebook 03 first.")

pairs_df = pd.read_parquet(PAIRS_PARQUET)

print(f"[CELL 04-04] Loaded pairs shape: {pairs_df.shape}")
print(f"[CELL 04-04] Columns: {list(pairs_df.columns)}")
print(f"\n[CELL 04-04] Head(3):")
print(pairs_df.head(3).to_string(index=False))

cell_end("CELL 04-04", t0, n_pairs=int(len(pairs_df)))


[CELL 04-04] Load pairs
[CELL 04-04] start=2026-01-07T14:40:34
[CELL 04-04] pairs=C:\anonymous-users-mooc-session-meta\data\processed\xuetangx\pairs\pairs.parquet
[CELL 04-04] Loaded pairs shape: (264229, 7)
[CELL 04-04] Columns: ['pair_id', 'user_id', 'session_id', 'prefix', 'label', 'label_ts_epoch', 'prefix_len']

[CELL 04-04] Head(3):
 pair_id user_id                               session_id          prefix  label  label_ts_epoch  prefix_len
       0 1000009 1000009_95163f59939941d9fd47d6c9b17fdaf6           [107]    133      1443241561           1
       1 1000009 1000009_95163f59939941d9fd47d6c9b17fdaf6      [107, 133]    334      1443242059           2
       2 1000009 1000009_95163f59939941d9fd47d6c9b17fdaf6 [107, 133, 334]    297      1443242155           3
[CELL 04-04] n_pairs=264229
[CELL 04-04] elapsed=0.22s
[CELL 04-04] done


In [6]:
# [CELL 04-05] Extract unique users (deterministic ordering)

t0 = cell_start("CELL 04-05", "Extract unique users")

# Check first: a missing user_id would also break the sort below
n_missing = int(pairs_df["user_id"].isna().sum())
if n_missing > 0:
    raise RuntimeError(f"Found {n_missing} pairs with missing user_id")

# Extract unique users and sort them as strings (lexicographic, so '999956' sorts
# after '1000009'); a fixed order makes the seeded shuffle independent of row order
unique_users = sorted(pairs_df["user_id"].unique())

n_users_total = len(unique_users)
print(f"[CELL 04-05] Total unique users: {n_users_total:,}")
print(f"[CELL 04-05] First 5 users: {unique_users[:5]}")
print(f"[CELL 04-05] Last 5 users: {unique_users[-5:]}")

cell_end("CELL 04-05", t0, n_users=n_users_total)


[CELL 04-05] Extract unique users
[CELL 04-05] start=2026-01-07T14:40:34
[CELL 04-05] Total unique users: 42,171
[CELL 04-05] First 5 users: ['1000009', '1000019', '1000066', '1000080', '1000085']
[CELL 04-05] Last 5 users: ['999956', '999958', '999970', '999981', '999996']
[CELL 04-05] n_users=42171
[CELL 04-05] elapsed=0.02s
[CELL 04-05] done
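Why sort before shuffling: `Series.unique()` returns values in order of first appearance, so without the sort the shuffle input (and therefore the split) would change whenever Notebook 03 emits rows in a different order. A small sketch with a hypothetical `split_users` helper and made-up rows:

```python
import numpy as np
import pandas as pd

def split_users(user_ids: pd.Series, seed: int) -> list:
    # sorted() removes any dependence on the row order of the input frame
    users = np.array(sorted(user_ids.unique()))
    np.random.RandomState(seed).shuffle(users)
    return users.tolist()

df = pd.DataFrame({"user_id": ["1000009", "999996", "1000066", "1000009", "999958"]})
shuffled_rows = df.sample(frac=1.0, random_state=7)  # same users, different row order

a = split_users(df["user_id"], seed=20260107)
b = split_users(shuffled_rows["user_id"], seed=20260107)
print(a == b)  # True

# String ids sort lexicographically, not numerically
print(sorted(["1000009", "999996"]))  # ['1000009', '999996']
```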


In [7]:
# [CELL 04-06] Shuffle users (seeded) and split 80/10/10

t0 = cell_start("CELL 04-06", "Split users (80/10/10)")

# Shuffle users with seed for reproducibility
rng = np.random.RandomState(GLOBAL_SEED)
shuffled_users = np.array(unique_users)
rng.shuffle(shuffled_users)

# Split indices: 80% train, 10% val, 10% test
n_train = int(n_users_total * CFG["split"]["train_ratio"])
n_val = int(n_users_total * CFG["split"]["val_ratio"])
# Remaining goes to test (handles rounding)
n_test = n_users_total - n_train - n_val

users_train = shuffled_users[:n_train].tolist()
users_val = shuffled_users[n_train:n_train+n_val].tolist()
users_test = shuffled_users[n_train+n_val:].tolist()

print(f"[CELL 04-06] Shuffled users with seed={GLOBAL_SEED}")
print(f"[CELL 04-06] Split sizes:")
print(f"  Train: {len(users_train):,} users ({len(users_train)/n_users_total*100:.1f}%)")
print(f"  Val:   {len(users_val):,} users ({len(users_val)/n_users_total*100:.1f}%)")
print(f"  Test:  {len(users_test):,} users ({len(users_test)/n_users_total*100:.1f}%)")
print(f"  Total: {len(users_train) + len(users_val) + len(users_test):,} users")

# Validation: no overlap
train_set = set(users_train)
val_set = set(users_val)
test_set = set(users_test)

overlap_train_val = train_set & val_set
overlap_train_test = train_set & test_set
overlap_val_test = val_set & test_set

if overlap_train_val or overlap_train_test or overlap_val_test:
    raise RuntimeError(f"User overlap detected: train-val={len(overlap_train_val)}, train-test={len(overlap_train_test)}, val-test={len(overlap_val_test)}")

print(f"\n[CELL 04-06] [OK] No user overlap between splits")

cell_end("CELL 04-06", t0, n_train=len(users_train), n_val=len(users_val), n_test=len(users_test))


[CELL 04-06] Split users (80/10/10)
[CELL 04-06] start=2026-01-07T14:40:34
[CELL 04-06] Shuffled users with seed=20260107
[CELL 04-06] Split sizes:
  Train: 33,736 users (80.0%)
  Val:   4,217 users (10.0%)
  Test:  4,218 users (10.0%)
  Total: 42,171 users

[CELL 04-06] [OK] No user overlap between splits
[CELL 04-06] n_train=33736
[CELL 04-06] n_val=4217
[CELL 04-06] n_test=4218
[CELL 04-06] elapsed=0.01s
[CELL 04-06] done
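The split sizes follow from flooring the train and val counts and giving the remainder to test, which is why test has one more user than val here. A minimal sketch of that arithmetic (`split_sizes` is a hypothetical helper):

```python
def split_sizes(n_users: int, train_ratio: float = 0.8, val_ratio: float = 0.1) -> tuple:
    # int() floors train and val; test absorbs the remainder, so sizes always sum to n_users
    n_train = int(n_users * train_ratio)
    n_val = int(n_users * val_ratio)
    n_test = n_users - n_train - n_val
    return n_train, n_val, n_test

print(split_sizes(42171))  # (33736, 4217, 4218), matching CELL 04-06
```

One consequence of flooring: for very small user counts the val split can be empty (e.g. `split_sizes(7)` gives `(5, 0, 2)`), which would need a guard if the code were reused on a tiny dataset.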


In [8]:
# [CELL 04-07] Save user splits

t0 = cell_start("CELL 04-07", "Save user splits")

write_json_atomic(OUT_USERS_TRAIN, users_train)
write_json_atomic(OUT_USERS_VAL, users_val)
write_json_atomic(OUT_USERS_TEST, users_test)

users_train_sha = sha256_file(OUT_USERS_TRAIN)
users_val_sha = sha256_file(OUT_USERS_VAL)
users_test_sha = sha256_file(OUT_USERS_TEST)

print(f"[CELL 04-07] Saved: {OUT_USERS_TRAIN.name} (SHA256: {users_train_sha[:16]}...)")
print(f"[CELL 04-07] Saved: {OUT_USERS_VAL.name} (SHA256: {users_val_sha[:16]}...)")
print(f"[CELL 04-07] Saved: {OUT_USERS_TEST.name} (SHA256: {users_test_sha[:16]}...)")

cell_end("CELL 04-07", t0)


[CELL 04-07] Save user splits
[CELL 04-07] start=2026-01-07T14:40:34
[CELL 04-07] Saved: users_train.json (SHA256: 698c7ef143352601...)
[CELL 04-07] Saved: users_val.json (SHA256: d3d23ca226c91d04...)
[CELL 04-07] Saved: users_test.json (SHA256: ca62c227326741e9...)
[CELL 04-07] elapsed=0.06s
[CELL 04-07] done


In [9]:
# [CELL 04-08] Assign pairs to splits based on user_id

t0 = cell_start("CELL 04-08", "Assign pairs to splits")

# Create lookup sets for fast membership testing
train_set = set(users_train)
val_set = set(users_val)
test_set = set(users_test)

# Assign each pair to a split based on user_id
pairs_train = pairs_df[pairs_df["user_id"].isin(train_set)].copy()
pairs_val = pairs_df[pairs_df["user_id"].isin(val_set)].copy()
pairs_test = pairs_df[pairs_df["user_id"].isin(test_set)].copy()

print(f"[CELL 04-08] Pairs assigned to splits:")
print(f"  Train: {len(pairs_train):,} pairs ({len(pairs_train)/len(pairs_df)*100:.1f}%)")
print(f"  Val:   {len(pairs_val):,} pairs ({len(pairs_val)/len(pairs_df)*100:.1f}%)")
print(f"  Test:  {len(pairs_test):,} pairs ({len(pairs_test)/len(pairs_df)*100:.1f}%)")
print(f"  Total: {len(pairs_train) + len(pairs_val) + len(pairs_test):,} pairs")

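# Validation: all pairs assigned. The user sets are disjoint (checked in CELL 04-06),
# so matching totals imply each pair landed in exactly one split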
n_total_assigned = len(pairs_train) + len(pairs_val) + len(pairs_test)
if n_total_assigned != len(pairs_df):
    raise RuntimeError(f"Not all pairs assigned: {n_total_assigned} != {len(pairs_df)}")

print(f"\n[CELL 04-08] [OK] All pairs assigned to exactly one split")

# User stats per split
n_users_train = pairs_train["user_id"].nunique()
n_users_val = pairs_val["user_id"].nunique()
n_users_test = pairs_test["user_id"].nunique()

print(f"\n[CELL 04-08] Unique users in pairs:")
print(f"  Train: {n_users_train:,} users")
print(f"  Val:   {n_users_val:,} users")
print(f"  Test:  {n_users_test:,} users")

# Verify user counts match original split
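if n_users_train != len(users_train) or n_users_val != len(users_val) or n_users_test != len(users_test):
    raise RuntimeError(
        f"User count mismatch in pair splits: train {n_users_train} vs {len(users_train)}, "
        f"val {n_users_val} vs {len(users_val)}, test {n_users_test} vs {len(users_test)}"
    )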

print(f"\n[CELL 04-08] [OK] User counts match original split")

cell_end("CELL 04-08", t0, 
         n_pairs_train=int(len(pairs_train)), 
         n_pairs_val=int(len(pairs_val)), 
         n_pairs_test=int(len(pairs_test)))


[CELL 04-08] Assign pairs to splits
[CELL 04-08] start=2026-01-07T14:40:34
[CELL 04-08] Pairs assigned to splits:
  Train: 212,923 pairs (80.6%)
  Val:   24,698 pairs (9.3%)
  Test:  26,608 pairs (10.1%)
  Total: 264,229 pairs

[CELL 04-08] [OK] All pairs assigned to exactly one split

[CELL 04-08] Unique users in pairs:
  Train: 33,736 users
  Val:   4,217 users
  Test:  4,218 users

[CELL 04-08] [OK] User counts match original split
[CELL 04-08] n_pairs_train=212923
[CELL 04-08] n_pairs_val=24698
[CELL 04-08] n_pairs_test=26608
[CELL 04-08] elapsed=0.09s
[CELL 04-08] done
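An alternative to three separate `isin` filters is a single user-to-split lookup applied with `Series.map`: a user listed in two splits is caught while building the lookup, and unassigned users show up as NaN in one pass. This is a sketch with a hypothetical `assign_split` helper and toy data, not what the cell above does:

```python
import pandas as pd

def assign_split(pairs: pd.DataFrame, users_by_split: dict) -> pd.Series:
    # Build one user -> split lookup; a user listed twice would be a leakage bug
    lookup = {}
    for split, users in users_by_split.items():
        for u in users:
            if u in lookup:
                raise ValueError(f"user {u} in both {lookup[u]} and {split}")
            lookup[u] = split
    # Users missing from the lookup map to NaN
    split_col = pairs["user_id"].map(lookup)
    if split_col.isna().any():
        raise ValueError(f"{int(split_col.isna().sum())} pairs have unassigned users")
    return split_col

pairs = pd.DataFrame({"user_id": ["a", "a", "b", "c", "c", "c"]})
pairs["split"] = assign_split(pairs, {"train": ["a"], "val": ["b"], "test": ["c"]})
print(pairs["split"].value_counts().to_dict())
```

Storing the split as a column also makes it easy to write a single file partitioned by split instead of three files, if that ever suits downstream notebooks better.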


In [10]:
# [CELL 04-09] Save split pairs

t0 = cell_start("CELL 04-09", "Save split pairs")

pairs_train.to_parquet(OUT_PAIRS_TRAIN, index=False, compression="zstd")
pairs_val.to_parquet(OUT_PAIRS_VAL, index=False, compression="zstd")
pairs_test.to_parquet(OUT_PAIRS_TEST, index=False, compression="zstd")

pairs_train_bytes = int(OUT_PAIRS_TRAIN.stat().st_size)
pairs_val_bytes = int(OUT_PAIRS_VAL.stat().st_size)
pairs_test_bytes = int(OUT_PAIRS_TEST.stat().st_size)

pairs_train_sha = sha256_file(OUT_PAIRS_TRAIN)
pairs_val_sha = sha256_file(OUT_PAIRS_VAL)
pairs_test_sha = sha256_file(OUT_PAIRS_TEST)

print(f"[CELL 04-09] Saved: {OUT_PAIRS_TRAIN.name} ({pairs_train_bytes / 1024 / 1024:.1f} MB, SHA256: {pairs_train_sha[:16]}...)")
print(f"[CELL 04-09] Saved: {OUT_PAIRS_VAL.name} ({pairs_val_bytes / 1024 / 1024:.1f} MB, SHA256: {pairs_val_sha[:16]}...)")
print(f"[CELL 04-09] Saved: {OUT_PAIRS_TEST.name} ({pairs_test_bytes / 1024 / 1024:.1f} MB, SHA256: {pairs_test_sha[:16]}...)")

cell_end("CELL 04-09", t0)


[CELL 04-09] Save split pairs
[CELL 04-09] start=2026-01-07T14:40:34
[CELL 04-09] Saved: pairs_train.parquet (4.0 MB, SHA256: a58c082770686850...)
[CELL 04-09] Saved: pairs_val.parquet (0.5 MB, SHA256: 7ab45f2fcf94e9c7...)
[CELL 04-09] Saved: pairs_test.parquet (0.6 MB, SHA256: 047794fc5de7ae7b...)
[CELL 04-09] elapsed=0.35s
[CELL 04-09] done


In [11]:
# [CELL 04-10] Register DuckDB views for split pairs

t0 = cell_start("CELL 04-10", "Register DuckDB views", duckdb=str(DUCKDB_PATH))

import duckdb

con = duckdb.connect(str(DUCKDB_PATH), read_only=False)

# Escape single quotes so the path can be embedded in a SQL string literal
def esc_path(p: Path) -> str:
    return str(p).replace("'", "''")

# Drop existing views
con.execute("DROP VIEW IF EXISTS xuetangx_pairs_train;")
con.execute("DROP VIEW IF EXISTS xuetangx_pairs_val;")
con.execute("DROP VIEW IF EXISTS xuetangx_pairs_test;")

# Create views
con.execute(f"""
CREATE VIEW xuetangx_pairs_train AS
SELECT * FROM read_parquet('{esc_path(OUT_PAIRS_TRAIN)}')
""")

con.execute(f"""
CREATE VIEW xuetangx_pairs_val AS
SELECT * FROM read_parquet('{esc_path(OUT_PAIRS_VAL)}')
""")

con.execute(f"""
CREATE VIEW xuetangx_pairs_test AS
SELECT * FROM read_parquet('{esc_path(OUT_PAIRS_TEST)}')
""")

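# Distinct names so the user-split sizes from CELL 04-06 are not overwritten
n_view_train = int(con.execute("SELECT COUNT(*) FROM xuetangx_pairs_train").fetchone()[0])
n_view_val = int(con.execute("SELECT COUNT(*) FROM xuetangx_pairs_val").fetchone()[0])
n_view_test = int(con.execute("SELECT COUNT(*) FROM xuetangx_pairs_test").fetchone()[0])

print(f"[CELL 04-10] View xuetangx_pairs_train: {n_view_train:,} rows")
print(f"[CELL 04-10] View xuetangx_pairs_val: {n_view_val:,} rows")
print(f"[CELL 04-10] View xuetangx_pairs_test: {n_view_test:,} rows")

# The views must expose exactly the rows written in CELL 04-09
if (n_view_train, n_view_val, n_view_test) != (len(pairs_train), len(pairs_val), len(pairs_test)):
    con.close()
    raise RuntimeError("DuckDB view row counts do not match the saved split pairs")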

con.close()
print(f"[CELL 04-10] Closed DuckDB connection")

cell_end("CELL 04-10", t0)


[CELL 04-10] Register DuckDB views
[CELL 04-10] start=2026-01-07T14:40:35
[CELL 04-10] duckdb=C:\anonymous-users-mooc-session-meta\data\interim\xuetangx.duckdb
[CELL 04-10] View xuetangx_pairs_train: 212,923 rows
[CELL 04-10] View xuetangx_pairs_val: 24,698 rows
[CELL 04-10] View xuetangx_pairs_test: 26,608 rows
[CELL 04-10] Closed DuckDB connection
[CELL 04-10] elapsed=0.12s
[CELL 04-10] done


In [12]:
# [CELL 04-11] Validation: cold-start eligibility check

t0 = cell_start("CELL 04-11", "Cold-start eligibility check")

# An episode needs K support + Q query pairs, so a user is eligible only with >= K+Q pairs
# (e.g., K=5, Q=10 from Notebook 02 requires at least 15 pairs)

K_CONFIGS = [(5, 10), (10, 20)]  # (K support, Q query)

print(f"[CELL 04-11] Cold-start eligibility (pairs per user):")
print(f"\n[CELL 04-11] Train split:")
train_pair_counts = pairs_train.groupby("user_id").size()
print(f"  Min pairs/user: {train_pair_counts.min()}")
print(f"  p50 pairs/user: {train_pair_counts.quantile(0.50):.0f}")
print(f"  p90 pairs/user: {train_pair_counts.quantile(0.90):.0f}")
print(f"  Max pairs/user: {train_pair_counts.max()}")

print(f"\n[CELL 04-11] Val split:")
val_pair_counts = pairs_val.groupby("user_id").size()
print(f"  Min pairs/user: {val_pair_counts.min()}")
print(f"  p50 pairs/user: {val_pair_counts.quantile(0.50):.0f}")
print(f"  p90 pairs/user: {val_pair_counts.quantile(0.90):.0f}")
print(f"  Max pairs/user: {val_pair_counts.max()}")

print(f"\n[CELL 04-11] Test split:")
test_pair_counts = pairs_test.groupby("user_id").size()
print(f"  Min pairs/user: {test_pair_counts.min()}")
print(f"  p50 pairs/user: {test_pair_counts.quantile(0.50):.0f}")
print(f"  p90 pairs/user: {test_pair_counts.quantile(0.90):.0f}")
print(f"  Max pairs/user: {test_pair_counts.max()}")

print(f"\n[CELL 04-11] Eligibility by K+Q budget:")
for K, Q in K_CONFIGS:
    min_pairs = K + Q
    n_eligible_train = (train_pair_counts >= min_pairs).sum()
    n_eligible_val = (val_pair_counts >= min_pairs).sum()
    n_eligible_test = (test_pair_counts >= min_pairs).sum()
    print(f"  K={K}, Q={Q} (min {min_pairs} pairs):")
    print(f"    Train: {n_eligible_train:,}/{len(users_train):,} users ({n_eligible_train/len(users_train)*100:.1f}%)")
    print(f"    Val:   {n_eligible_val:,}/{len(users_val):,} users ({n_eligible_val/len(users_val)*100:.1f}%)")
    print(f"    Test:  {n_eligible_test:,}/{len(users_test):,} users ({n_eligible_test/len(users_test)*100:.1f}%)")

cell_end("CELL 04-11", t0)


[CELL 04-11] Cold-start eligibility check
[CELL 04-11] start=2026-01-07T14:40:35
[CELL 04-11] Cold-start eligibility (pairs per user):

[CELL 04-11] Train split:
  Min pairs/user: 1
  p50 pairs/user: 2
  p90 pairs/user: 13
  Max pairs/user: 1318

[CELL 04-11] Val split:
  Min pairs/user: 1
  p50 pairs/user: 2
  p90 pairs/user: 12
  Max pairs/user: 377

[CELL 04-11] Test split:
  Min pairs/user: 1
  p50 pairs/user: 2
  p90 pairs/user: 13
  Max pairs/user: 564

[CELL 04-11] Eligibility by K+Q budget:
  K=5, Q=10 (min 15 pairs):
    Train: 3,006/33,736 users (8.9%)
    Val:   340/4,217 users (8.1%)
    Test:  346/4,218 users (8.2%)
  K=10, Q=20 (min 30 pairs):
    Train: 1,070/33,736 users (3.2%)
    Val:   120/4,217 users (2.8%)
    Test:  139/4,218 users (3.3%)
[CELL 04-11] elapsed=0.04s
[CELL 04-11] done
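CELL 04-11 and CELL 04-13 run the same eligibility loop twice. A sketch of one shared helper that returns a DataFrame both cells could print or save; `eligibility_table` is hypothetical and the counts below are made up. It divides by the number of users that have pairs, which equals `len(users_*)` here because CELL 04-08 confirmed every split user has at least one pair.

```python
import pandas as pd

def eligibility_table(pair_counts: dict, configs: list) -> pd.DataFrame:
    """pair_counts: split name -> Series of pairs per user; configs: (K, Q) tuples."""
    rows = []
    for K, Q in configs:
        row = {"K": K, "Q": Q, "min_pairs": K + Q}
        for split, counts in pair_counts.items():
            # A user qualifies when it has at least K support + Q query pairs
            eligible = int((counts >= K + Q).sum())
            row[f"{split}_eligible"] = eligible
            row[f"{split}_pct"] = round(100.0 * eligible / len(counts), 1)
        rows.append(row)
    return pd.DataFrame(rows)

# Made-up per-user counts for illustration only
toy_counts = {"train": pd.Series([1, 2, 15, 40]), "val": pd.Series([3, 30])}
tbl = eligibility_table(toy_counts, [(5, 10), (10, 20)])
print(tbl.to_string(index=False))
```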


In [13]:
# [CELL 04-12] Update report + manifest

t0 = cell_start("CELL 04-12", "Write report + manifest")

report = read_json(REPORT_PATH)
manifest = read_json(MANIFEST_PATH)

# Metrics
report["metrics"] = {
    "n_users_total": n_users_total,
    "n_users_train": len(users_train),
    "n_users_val": len(users_val),
    "n_users_test": len(users_test),
    "n_pairs_total": int(len(pairs_df)),
    "n_pairs_train": int(len(pairs_train)),
    "n_pairs_val": int(len(pairs_val)),
    "n_pairs_test": int(len(pairs_test)),
    "split_ratios": {
        "train": CFG["split"]["train_ratio"],
        "val": CFG["split"]["val_ratio"],
        "test": CFG["split"]["test_ratio"],
    },
    "pairs_per_user": {
        "train": {
            "min": int(train_pair_counts.min()),
            "p50": float(train_pair_counts.quantile(0.50)),
            "p90": float(train_pair_counts.quantile(0.90)),
            "max": int(train_pair_counts.max()),
        },
        "val": {
            "min": int(val_pair_counts.min()),
            "p50": float(val_pair_counts.quantile(0.50)),
            "p90": float(val_pair_counts.quantile(0.90)),
            "max": int(val_pair_counts.max()),
        },
        "test": {
            "min": int(test_pair_counts.min()),
            "p50": float(test_pair_counts.quantile(0.50)),
            "p90": float(test_pair_counts.quantile(0.90)),
            "max": int(test_pair_counts.max()),
        },
    },
}

# Key findings
report["key_findings"].append(
    f"Created deterministic user-level split: {len(users_train):,} train ({len(users_train)/n_users_total*100:.1f}%), "
    f"{len(users_val):,} val ({len(users_val)/n_users_total*100:.1f}%), "
    f"{len(users_test):,} test ({len(users_test)/n_users_total*100:.1f}%) users. "
    f"No user overlap between splits (cold-start guarantee)."
)

report["key_findings"].append(
    f"Pairs distributed: {len(pairs_train):,} train ({len(pairs_train)/len(pairs_df)*100:.1f}%), "
    f"{len(pairs_val):,} val ({len(pairs_val)/len(pairs_df)*100:.1f}%), "
    f"{len(pairs_test):,} test ({len(pairs_test)/len(pairs_df)*100:.1f}%). "
    f"All pairs assigned to exactly one split based on user_id."
)

# Sanity samples
report["sanity_samples"]["users_train_head5"] = users_train[:5]
report["sanity_samples"]["users_val_head5"] = users_val[:5]
report["sanity_samples"]["users_test_head5"] = users_test[:5]

# Fingerprints
report["data_fingerprints"]["users_train"] = {"path": str(OUT_USERS_TRAIN), "sha256": users_train_sha}
report["data_fingerprints"]["users_val"] = {"path": str(OUT_USERS_VAL), "sha256": users_val_sha}
report["data_fingerprints"]["users_test"] = {"path": str(OUT_USERS_TEST), "sha256": users_test_sha}
report["data_fingerprints"]["pairs_train"] = {"path": str(OUT_PAIRS_TRAIN), "bytes": pairs_train_bytes, "sha256": pairs_train_sha}
report["data_fingerprints"]["pairs_val"] = {"path": str(OUT_PAIRS_VAL), "bytes": pairs_val_bytes, "sha256": pairs_val_sha}
report["data_fingerprints"]["pairs_test"] = {"path": str(OUT_PAIRS_TEST), "bytes": pairs_test_bytes, "sha256": pairs_test_sha}

write_json_atomic(REPORT_PATH, report)

# Manifest
def add_artifact(path: Path) -> None:
    rec = {"path": str(path), "bytes": int(path.stat().st_size), "sha256": None, "sha256_error": None}
    try:
        rec["sha256"] = sha256_file(path)
    except PermissionError as e:
        rec["sha256_error"] = f"PermissionError: {e}"
    manifest["artifacts"].append(rec)

add_artifact(OUT_USERS_TRAIN)
add_artifact(OUT_USERS_VAL)
add_artifact(OUT_USERS_TEST)
add_artifact(OUT_PAIRS_TRAIN)
add_artifact(OUT_PAIRS_VAL)
add_artifact(OUT_PAIRS_TEST)

write_json_atomic(MANIFEST_PATH, manifest)

print(f"[CELL 04-12] Updated: {REPORT_PATH}")
print(f"[CELL 04-12] Updated: {MANIFEST_PATH}")

cell_end("CELL 04-12", t0)


[CELL 04-12] Write report + manifest
[CELL 04-12] start=2026-01-07T14:40:35
[CELL 04-12] Updated: C:\anonymous-users-mooc-session-meta\reports\04_user_split_xuetangx\20260107_144034\report.json
[CELL 04-12] Updated: C:\anonymous-users-mooc-session-meta\reports\04_user_split_xuetangx\20260107_144034\manifest.json
[CELL 04-12] elapsed=0.04s
[CELL 04-12] done
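The manifest records `path`, `bytes` and `sha256` for each artifact, so a downstream notebook could re-check the split files before trusting them. A minimal sketch, assuming the manifest layout written above; `verify_manifest` is a hypothetical helper, and an entry whose `sha256` is `None` (a recorded `sha256_error`) is simply reported as stale.

```python
import hashlib
import json
import tempfile
from pathlib import Path

def sha256_file(path: Path) -> str:
    return hashlib.sha256(path.read_bytes()).hexdigest()

def verify_manifest(manifest_path: Path) -> list:
    """Return artifact paths whose current size or sha256 no longer match the manifest."""
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    stale = []
    for rec in manifest["artifacts"]:
        p = Path(rec["path"])
        if not p.exists() or p.stat().st_size != rec["bytes"] or sha256_file(p) != rec["sha256"]:
            stale.append(str(p))
    return stale

with tempfile.TemporaryDirectory() as d:
    art = Path(d) / "users_val.json"
    art.write_text('["1000019"]', encoding="utf-8")
    man = Path(d) / "manifest.json"
    man.write_text(json.dumps({"artifacts": [
        {"path": str(art), "bytes": art.stat().st_size, "sha256": sha256_file(art)}
    ]}), encoding="utf-8")
    before = verify_manifest(man)
    art.write_text('["1000066"]', encoding="utf-8")  # same size, different content
    after = verify_manifest(man)

print(before, len(after))  # [] 1
```

This matters because the split files live at fixed paths under `data/processed`, while each manifest belongs to one `run_tag`: rerunning this notebook overwrites the files and leaves older manifests stale.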


In [14]:
# [CELL 04-13] Visualizations: plots and tables for reporting

t0 = cell_start("CELL 04-13", "Generate plots and tables")

import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 10

VIZ_DIR = OUT_DIR / "visualizations"
VIZ_DIR.mkdir(exist_ok=True)

print(f"[CELL 04-13] Creating visualizations in {VIZ_DIR}")

# ===== PLOT 1: Split distribution (users and pairs) =====
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Users
splits = ['Train', 'Val', 'Test']
user_counts = [len(users_train), len(users_val), len(users_test)]
colors = ['#2ecc71', '#3498db', '#e74c3c']

ax1.bar(splits, user_counts, color=colors, alpha=0.7, edgecolor='black')
ax1.set_ylabel('Number of Users', fontsize=11, fontweight='bold')
ax1.set_title('User Split Distribution', fontsize=12, fontweight='bold')
ax1.grid(axis='y', alpha=0.3)
for i, (split, count) in enumerate(zip(splits, user_counts)):
    ax1.text(i, count + 500, f'{count:,}\n({count/n_users_total*100:.1f}%)', 
             ha='center', va='bottom', fontsize=9, fontweight='bold')

# Pairs
pair_counts = [len(pairs_train), len(pairs_val), len(pairs_test)]
ax2.bar(splits, pair_counts, color=colors, alpha=0.7, edgecolor='black')
ax2.set_ylabel('Number of Pairs', fontsize=11, fontweight='bold')
ax2.set_title('Pair Split Distribution', fontsize=12, fontweight='bold')
ax2.grid(axis='y', alpha=0.3)
for i, (split, count) in enumerate(zip(splits, pair_counts)):
    ax2.text(i, count + 3000, f'{count:,}\n({count/len(pairs_df)*100:.1f}%)', 
             ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig(VIZ_DIR / "fig1_split_distribution.png", bbox_inches='tight')
plt.close()
print(f"[CELL 04-13] Saved: fig1_split_distribution.png")

# ===== PLOT 2: Pairs per user distribution by split =====
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)

for idx, (split_name, pairs_split, ax) in enumerate([
    ('Train', pairs_train, axes[0]),
    ('Val', pairs_val, axes[1]),
    ('Test', pairs_test, axes[2])
]):
    pair_counts_split = pairs_split.groupby('user_id').size()
    
    ax.hist(pair_counts_split, bins=50, color=colors[idx], alpha=0.7, edgecolor='black')
    ax.axvline(pair_counts_split.median(), color='red', linestyle='--', linewidth=2, label=f'Median={pair_counts_split.median():.0f}')
    ax.set_xlabel('Pairs per User', fontsize=11, fontweight='bold')
    if idx == 0:
        ax.set_ylabel('Number of Users', fontsize=11, fontweight='bold')
    ax.set_title(f'{split_name} Split', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(VIZ_DIR / "fig2_pairs_per_user.png", bbox_inches='tight')
plt.close()
print(f"[CELL 04-13] Saved: fig2_pairs_per_user.png")

# ===== TABLE 1: Split summary statistics =====
table1 = pd.DataFrame([
    {
        'Split': 'Train',
        'Users': f"{len(users_train):,}",
        'Pairs': f"{len(pairs_train):,}",
        'Users (%)': f"{len(users_train)/n_users_total*100:.1f}%",
        'Pairs (%)': f"{len(pairs_train)/len(pairs_df)*100:.1f}%",
        'Min Pairs/User': int(train_pair_counts.min()),
        'Median Pairs/User': int(train_pair_counts.median()),
        'Max Pairs/User': int(train_pair_counts.max()),
    },
    {
        'Split': 'Val',
        'Users': f"{len(users_val):,}",
        'Pairs': f"{len(pairs_val):,}",
        'Users (%)': f"{len(users_val)/n_users_total*100:.1f}%",
        'Pairs (%)': f"{len(pairs_val)/len(pairs_df)*100:.1f}%",
        'Min Pairs/User': int(val_pair_counts.min()),
        'Median Pairs/User': int(val_pair_counts.median()),
        'Max Pairs/User': int(val_pair_counts.max()),
    },
    {
        'Split': 'Test',
        'Users': f"{len(users_test):,}",
        'Pairs': f"{len(pairs_test):,}",
        'Users (%)': f"{len(users_test)/n_users_total*100:.1f}%",
        'Pairs (%)': f"{len(pairs_test)/len(pairs_df)*100:.1f}%",
        'Min Pairs/User': int(test_pair_counts.min()),
        'Median Pairs/User': int(test_pair_counts.median()),
        'Max Pairs/User': int(test_pair_counts.max()),
    },
])

table1.to_csv(VIZ_DIR / "table1_split_summary.csv", index=False)
print(f"[CELL 04-13] Saved: table1_split_summary.csv")
print(f"\n[CELL 04-13] Table 1: Split Summary")
print(table1.to_string(index=False))

# ===== TABLE 2: Cold-start eligibility by K+Q budget =====
K_CONFIGS = [(5, 10), (10, 20)]
eligibility_records = []

for K, Q in K_CONFIGS:
    min_pairs = K + Q
    n_eligible_train = (train_pair_counts >= min_pairs).sum()
    n_eligible_val = (val_pair_counts >= min_pairs).sum()
    n_eligible_test = (test_pair_counts >= min_pairs).sum()
    
    eligibility_records.append({
        'Config': f'K={K}, Q={Q}',
        'Min Pairs': min_pairs,
        'Train Eligible': f"{n_eligible_train:,} ({n_eligible_train/len(users_train)*100:.1f}%)",
        'Val Eligible': f"{n_eligible_val:,} ({n_eligible_val/len(users_val)*100:.1f}%)",
        'Test Eligible': f"{n_eligible_test:,} ({n_eligible_test/len(users_test)*100:.1f}%)",
    })

table2 = pd.DataFrame(eligibility_records)
table2.to_csv(VIZ_DIR / "table2_cold_start_eligibility.csv", index=False)
print(f"[CELL 04-13] Saved: table2_cold_start_eligibility.csv")
print(f"\n[CELL 04-13] Table 2: Cold-Start Eligibility")
print(table2.to_string(index=False))

print(f"\n[CELL 04-13] All visualizations saved to {VIZ_DIR}")

cell_end("CELL 04-13", t0)


[CELL 04-13] Generate plots and tables
[CELL 04-13] start=2026-01-07T14:40:35
[CELL 04-13] Creating visualizations in C:\anonymous-users-mooc-session-meta\reports\04_user_split_xuetangx\20260107_144034\visualizations
[CELL 04-13] Saved: fig1_split_distribution.png
[CELL 04-13] Saved: fig2_pairs_per_user.png
[CELL 04-13] Saved: table1_split_summary.csv

[CELL 04-13] Table 1: Split Summary
Split  Users   Pairs Users (%) Pairs (%)  Min Pairs/User  Median Pairs/User  Max Pairs/User
Train 33,736 212,923     80.0%     80.6%               1                  2            1318
  Val  4,217  24,698     10.0%      9.3%               1                  2             377
 Test  4,218  26,608     10.0%     10.1%               1                  2             564
[CELL 04-13] Saved: table2_cold_start_eligibility.csv

[CELL 04-13] Table 2: Cold-Start Eligibility
    Config  Min Pairs Train Eligible Val Eligible Test Eligible
 K=5, Q=10         15   3,006 (8.9%)   340 (8.1%)    346 (8.2%)
K=10, Q=20         30   1,070 (3.2%)   120 (2.8%)    139 (3.3%)

## ✅ Notebook 04 Complete

**Outputs:**
- ✅ `data/processed/xuetangx/user_splits/users_train.json` (80% of users)
- ✅ `data/processed/xuetangx/user_splits/users_val.json` (10% of users)
- ✅ `data/processed/xuetangx/user_splits/users_test.json` (10% of users)
- ✅ `data/processed/xuetangx/pairs/pairs_train.parquet`
- ✅ `data/processed/xuetangx/pairs/pairs_val.parquet`
- ✅ `data/processed/xuetangx/pairs/pairs_test.parquet`
- ✅ DuckDB views: `xuetangx_pairs_train`, `xuetangx_pairs_val`, `xuetangx_pairs_test`
- ✅ `reports/04_user_split_xuetangx/<run_tag>/report.json`

**Validation Passed:**
- ✅ No user overlap between splits (disjoint users)
- ✅ All pairs assigned to exactly one split
- ✅ User counts match in pairs and user lists
- ✅ Cold-start guarantee: test users completely unseen during training

**Next:** Notebook 05 (Episode Index)
- Create episodic meta-learning indices
- Sample K-shot support + Q query pairs per user
- Chronological validation (support timestamps < query timestamps)
- Multiple episodes per user for training
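As a rough preview of the chronological support/query idea, here is a sketch for a single episode per user. `chronological_episode` is hypothetical and is not the Notebook 05 implementation. It orders a user's pairs by `label_ts_epoch` (with `pair_id` as a deterministic tie-breaker), takes the first K as support and the next Q as query, and checks `<=` rather than the strict `<` above, because identical timestamps are possible and would need explicit tie handling under a strict rule.

```python
from typing import Optional, Tuple

import pandas as pd

def chronological_episode(user_pairs: pd.DataFrame, K: int, Q: int) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]:
    """Hypothetical: first K pairs in time as support, next Q as query."""
    if len(user_pairs) < K + Q:
        return None  # user is not eligible for this (K, Q) budget
    ordered = user_pairs.sort_values(["label_ts_epoch", "pair_id"])
    support = ordered.iloc[:K]
    query = ordered.iloc[K:K + Q]
    # Chronological check (non-strict, to tolerate equal timestamps)
    assert support["label_ts_epoch"].max() <= query["label_ts_epoch"].min()
    return support, query

# Toy user with 6 pairs; timestamps are illustrative
toy = pd.DataFrame({
    "pair_id": range(6),
    "label_ts_epoch": [1443241561, 1443242059, 1443242155, 1443243000, 1443244000, 1443245000],
})
support, query = chronological_episode(toy.sample(frac=1.0, random_state=0), K=2, Q=3)
print(support["pair_id"].tolist(), query["pair_id"].tolist())  # [0, 1] [2, 3, 4]
```

Multiple episodes per user could be drawn by sliding or resampling this window, subject to the same ordering check.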