This notebook will:

Load the shards from data/processed/tensor_shards_v2/
Build PyTorch datasets + dataloaders for each {domain, split}
Provide consistent evaluation metrics: Recall@K, MRR@K
Provide an evaluation loop skeleton (model-agnostic)

Assumption (from 05B outputs): each shard contains
input_ids, attention_mask, pos_ids, labels, and lengths, and MARS has exactly 1 shard per split.

In [16]:
# Unsafe stopgap for the libiomp5md.dll crash.
# Meant only for getting on with work in this notebook quickly.
import os

os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})
print("Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.")

Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.


In [17]:
# CELL [06-01] — Imports & run header

import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
from torch.utils.data import IterableDataset, DataLoader

# Run header: announce the notebook and record the installed torch version in the cell output.
print("[06-01] Starting 06_data_loader_and_eval_protocol.ipynb")
print("[06-01] torch:", torch.__version__)


[06-01] Starting 06_data_loader_and_eval_protocol.ipynb
[06-01] torch: 2.5.1


Config & paths

In [18]:
# CELL [06-02] — Config & paths

# Shard directory and its metadata file, given relative to the notebook's working directory.
DATA_DIR = Path("../data/processed")
SHARDS_DIR = DATA_DIR.joinpath("tensor_shards_v2")
META_PATH = SHARDS_DIR.joinpath("metadata.json")

# Stop early with a clear message if the expected inputs are not on disk.
assert SHARDS_DIR.exists(), f"[06-02] Missing SHARDS_DIR: {SHARDS_DIR}"
assert META_PATH.exists(), f"[06-02] Missing metadata.json: {META_PATH}"

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print("[06-02] SHARDS_DIR:", SHARDS_DIR.resolve())
print("[06-02] META_PATH:", META_PATH.resolve())
print("[06-02] DEVICE:", DEVICE)


[06-02] SHARDS_DIR: D:\00_DS-ML-Workspace\session-transfer-mooc\data\processed\tensor_shards_v2
[06-02] META_PATH: D:\00_DS-ML-Workspace\session-transfer-mooc\data\processed\tensor_shards_v2\metadata.json
[06-02] DEVICE: cuda


Load metadata + sanity print

In [19]:
# CELL [06-03] — Load metadata

# Read the metadata file and parse it into a plain dict.
with META_PATH.open("r") as f:
    meta = json.loads(f.read())

# Echo the fields the rest of the notebook relies on.
print("[06-03] max_prefix_len:", meta["max_prefix_len"])
print("[06-03] pad_id:", meta["pad_id"], "unk_id:", meta["unk_id"])
print("[06-03] shard_size:", meta["shard_size"])
print("[06-03] source_vocab_size:", meta["vocab"]["source"]["size"])
print("[06-03] target_vocab_size:", meta["vocab"]["target"]["size"])
print("[06-03] tensor_fields:", meta["tensor_fields"])


[06-03] max_prefix_len: 20
[06-03] pad_id: 0 unk_id: 1
[06-03] shard_size: 250000
[06-03] source_vocab_size: 200002
[06-03] target_vocab_size: 702
[06-03] tensor_fields: ['input_ids', 'attention_mask', 'pos_ids', 'labels', 'lengths']


Discover shard files (from filesystem)

In [20]:
# CELL [06-04] — Discover shard files

def list_shards(domain: str, split: str, shards_dir: Optional[Path] = None) -> List[Path]:
    """
    Return the shard files of one {domain, split}, sorted by path.

    Args:
        domain: Domain prefix used in the shard file names (e.g. "mars").
        split: Split name used in the shard file names (e.g. "train").
        shards_dir: Directory to search. When omitted, the notebook-level
            SHARDS_DIR is used, which is the original behaviour.

    Returns:
        Sorted list of paths matching "{domain}_{split}_shard_*.pt";
        an empty list when nothing matches.
    """
    # Resolve the global lazily so an explicit directory never needs SHARDS_DIR.
    root = SHARDS_DIR if shards_dir is None else Path(shards_dir)
    pattern = f"{domain}_{split}_shard_*.pt"
    return sorted(root.glob(pattern))

# Domains and splits to look up on disk.
domains = ["amazon", "yoochoose", "mars"]
splits = ["train", "val", "test"]

# Nested index: shard_index[domain][split] -> list of shard paths.
shard_index: Dict[str, Dict[str, List[Path]]] = {name: {} for name in domains}
for d in domains:
    per_split = shard_index[d]
    for s in splits:
        per_split[s] = list_shards(d, s)
        print(f"[06-04] {d}/{s}: {len(shard_index[d][s])} shards")

# MARS is expected to have exactly one shard in every split.
mars_shards = shard_index["mars"]
assert len(mars_shards["train"]) == 1, "[06-04] Expected 1 mars train shard"
assert len(mars_shards["val"]) == 1, "[06-04] Expected 1 mars val shard"
assert len(mars_shards["test"]) == 1, "[06-04] Expected 1 mars test shard"


[06-04] amazon/train: 56 shards
[06-04] amazon/val: 8 shards
[06-04] amazon/test: 6 shards
[06-04] yoochoose/train: 76 shards
[06-04] yoochoose/val: 8 shards
[06-04] yoochoose/test: 12 shards
[06-04] mars/train: 1 shards
[06-04] mars/val: 1 shards
[06-04] mars/test: 1 shards


Check a shard schema (must match exactly)

In [21]:
# CELL [06-05] — Validate shard schema

# Tensor fields that a shard is required to provide.
REQUIRED_KEYS = ["input_ids", "attention_mask", "pos_ids", "labels", "lengths"]

# Use the first MARS training shard as the sample to inspect.
sample_path = shard_index["mars"]["train"][0]
sample = torch.load(sample_path, map_location="cpu", weights_only=True)

print("[06-05] Sample shard:", sample_path.name)
print("[06-05] Keys:", list(sample.keys()))

# Every required field has to exist before its shape is inspected.
for k in REQUIRED_KEYS:
    assert k in sample, f"[06-05] Missing key in shard: {k}"

print("[06-05] Shapes:")
for k in REQUIRED_KEYS:
    print("  ", k, tuple(sample[k].shape), sample[k].dtype)

# All fields must agree with input_ids on the leading (example) dimension.
n = sample["input_ids"].shape[0]
for field in ("labels", "attention_mask", "pos_ids", "lengths"):
    assert sample[field].shape[0] == n

print("[06-05] Schema checks PASSED")


[06-05] Sample shard: mars_train_shard_000.pt
[06-05] Keys: ['input_ids', 'attention_mask', 'pos_ids', 'labels', 'lengths']
[06-05] Shapes:
   input_ids (1744, 20) torch.int64
   attention_mask (1744, 20) torch.int64
   pos_ids (1744, 20) torch.int64
   labels (1744,) torch.int64
   lengths (1744,) torch.int64
[06-05] Schema checks PASSED


Iterable shard dataset (memory safe)
This streams shards one by one and yields individual examples.

In [22]:
# CELL [06-06] — IterableDataset over shards (example-level)

class ShardExamplesDataset(IterableDataset):
    """
    Stream examples from a list of shard files, loading one shard at a time.

    Each shard is loaded with ``torch.load`` and indexed as a dict of tensors;
    the dataset yields one dict per example holding that example's
    ``input_ids``, ``attention_mask``, ``pos_ids``, ``labels`` and ``lengths``.

    Args:
        shard_paths: Shard files to read, in their default visiting order.
        shuffle_shards: If True, visit the shards in a pseudo-random order.
            Only the shard order is shuffled; examples inside a shard keep
            their stored order.
        seed: Base seed of the shard permutation. Each pass draws its
            permutation from ``seed + epoch``, so the order is reproducible
            but changes from one pass to the next. The first pass uses
            ``seed`` itself.

    Notes:
        The epoch counter advances automatically each time iteration starts
        in this process. With DataLoader worker processes the dataset is
        iterated on copies, so call ``set_epoch`` on the original object
        before each epoch to vary the order.
    """

    def __init__(self, shard_paths: List[Path], shuffle_shards: bool = False, seed: int = 42):
        super().__init__()
        self.shard_paths = list(shard_paths)
        self.shuffle_shards = shuffle_shards
        self.seed = seed
        # Number of passes started so far; offsets the shuffle seed.
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Choose the epoch whose shard permutation the next pass will use."""
        self._epoch = int(epoch)

    def __iter__(self):
        # Claim this pass's epoch and advance the counter for the next pass.
        epoch = self._epoch
        self._epoch += 1

        paths = self.shard_paths
        if self.shuffle_shards:
            # Shard-level shuffle: cheap, and seeded per epoch so passes differ.
            g = torch.Generator()
            g.manual_seed(self.seed + epoch)
            order = torch.randperm(len(paths), generator=g).tolist()
            paths = [paths[i] for i in order]

        # Inside a DataLoader worker, take every num_workers-th shard so the
        # workers split the shards instead of each streaming all of them.
        worker = torch.utils.data.get_worker_info()
        if worker is not None:
            paths = paths[worker.id::worker.num_workers]

        for sp in paths:
            shard = torch.load(sp, map_location="cpu", weights_only=True)
            n = shard["input_ids"].shape[0]

            for i in range(n):
                yield {
                    "input_ids": shard["input_ids"][i],
                    "attention_mask": shard["attention_mask"][i],
                    "pos_ids": shard["pos_ids"][i],
                    "labels": shard["labels"][i],
                    "lengths": shard["lengths"][i],
                }


Collate function (batch tensors)

In [23]:
# CELL [06-07] — Collate function

def collate_batch(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Merge per-example tensor dicts into one dict of batched tensors.

    For each of the five fields, the per-example tensors are stacked along a
    new leading dimension, giving one row per example in ``batch``.
    """
    field_names = ("input_ids", "attention_mask", "pos_ids", "labels", "lengths")
    collated: Dict[str, torch.Tensor] = {}
    for name in field_names:
        rows = [example[name] for example in batch]
        collated[name] = torch.stack(rows, dim=0)
    return collated


DataLoader factory (domain/split)

In [24]:
# CELL [06-08] — DataLoader factory

def make_loader(domain: str, split: str, batch_size: int, shuffle_shards: bool) -> DataLoader:
    """
    Create a DataLoader that streams the shards of one {domain, split}.

    Shard paths are looked up in the notebook-level ``shard_index``. The
    DataLoader itself does no shuffling (the dataset is iterable-style);
    ``shuffle_shards`` is forwarded to the dataset, which shuffles shard
    order instead.

    Raises:
        AssertionError: if no shards are indexed for the requested pair.
    """
    shard_paths = shard_index[domain][split]
    assert len(shard_paths) > 0, f"[06-08] No shards found for {domain}/{split}"

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=0,      # single-process loading, kept for Windows stability
        pin_memory=False,   # CPU training for now
        collate_fn=collate_batch,
        drop_last=False,
    )
    loader = DataLoader(
        ShardExamplesDataset(shard_paths, shuffle_shards=shuffle_shards),
        **loader_kwargs,
    )
    print(f"[06-08] Built loader {domain}/{split} | shards={len(shard_paths)} | batch_size={batch_size}")
    return loader


Quick loader test (MARS)

In [25]:
# CELL [06-09] — Quick loader test (MARS)

# Build the MARS training loader and pull a single batch to inspect it.
mars_train_loader = make_loader("mars", "train", batch_size=64, shuffle_shards=False)

batch = next(iter(mars_train_loader))
print("[06-09] Batch keys:", batch.keys())
for k in batch:
    v = batch[k]
    print("  ", k, tuple(v.shape), v.dtype)

# Show the first ten labels of the batch.
print("[06-09] Example labels:", batch["labels"].tolist()[:10])


[06-08] Built loader mars/train | shards=1 | batch_size=64
[06-09] Batch keys: dict_keys(['input_ids', 'attention_mask', 'pos_ids', 'labels', 'lengths'])
   input_ids (64, 20) torch.int64
   attention_mask (64, 20) torch.int64
   pos_ids (64, 20) torch.int64
   labels (64,) torch.int64
   lengths (64,) torch.int64
[06-09] Example labels: [226, 129, 119, 209, 165, 150, 151, 197, 210, 90]


Evaluation Protocol (Metrics)

These metrics assume the model returns scores/logits over item IDs with shape [B, V]
(where V = vocab size), and that labels has shape [B].

Metric helpers (Recall@K, MRR@K)

In [26]:
# CELL [06-10] — Metrics

@torch.no_grad()
def recall_at_k(scores: torch.Tensor, labels: torch.Tensor, k: int) -> float:
    """
    Recall@K: fraction of rows whose label is among the k highest-scoring columns.

    Args:
        scores: [B, V] score for every candidate column.
        labels: [B] column index of the correct candidate in each row.
        k: Cut-off. A value larger than V is treated as V, because
            ``torch.topk`` raises when asked for more entries than exist.

    Returns:
        Mean hit rate over the B rows, as a Python float.
    """
    # Clamp so a cut-off larger than the number of candidates does not crash.
    k_eff = min(k, scores.shape[1])
    topk = torch.topk(scores, k=k_eff, dim=1).indices  # [B, k_eff]
    hit = (topk == labels.unsqueeze(1)).any(dim=1).float()  # [B]
    return float(hit.mean().item())

@torch.no_grad()
def mrr_at_k(scores: torch.Tensor, labels: torch.Tensor, k: int) -> float:
    """
    Mean Reciprocal Rank at K.

    A row contributes 1/rank when its label appears at position ``rank``
    (1-based) among the k highest-scoring columns, and 0 otherwise.

    Args:
        scores: [B, V] score for every candidate column.
        labels: [B] column index of the correct candidate in each row.
        k: Cut-off. A value larger than V is treated as V, because
            ``torch.topk`` raises when asked for more entries than exist.

    Returns:
        Mean reciprocal rank over the B rows, as a Python float.
    """
    # Clamp so a cut-off larger than the number of candidates does not crash.
    k_eff = min(k, scores.shape[1])
    topk = torch.topk(scores, k=k_eff, dim=1).indices  # [B, k_eff]
    match = topk == labels.unsqueeze(1)                # [B, k_eff]

    # Rank positions are 1..k_eff; kept as floats so 1/rank is a float tensor.
    ranks = torch.arange(1, k_eff + 1, device=scores.device, dtype=torch.float).unsqueeze(0)  # [1, k_eff]
    rr = torch.where(match, 1.0 / ranks, torch.zeros_like(ranks))  # [B, k_eff]
    # At most one position matches per row, so the max picks that reciprocal rank (or 0).
    rr_max = rr.max(dim=1).values  # [B]
    return float(rr_max.mean().item())


Eval loop (model-agnostic)

In [27]:
# Models evaluated with evaluate_model must be callable as:
#   scores = model(batch)  # scores shape [B, V]


In [28]:
# CELL [06-11] — Evaluation loop

@torch.no_grad()
def evaluate_model(model, loader: DataLoader, k: int = 20, device: str = DEVICE) -> Dict[str, float]:
    """
    Compute Recall@k and MRR@k of ``model`` over every batch of ``loader``.

    Args:
        model: Module called as ``model(batch)``; must return scores of
            shape [B, V]. It is switched to eval mode.
        loader: Yields dict batches containing the tensors ``input_ids``,
            ``attention_mask``, ``pos_ids``, ``labels`` and ``lengths``.
        k: Cut-off used for both metrics.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Dict with ``Recall@{k}``, ``MRR@{k}`` and ``batches`` (number of
        batches seen). The metrics are means over examples: each batch is
        weighted by its size, so a smaller final batch does not count as
        much as a full one. Both metrics are 0.0 when the loader is empty.
    """
    model.eval()
    recall_sum = 0.0
    mrr_sum = 0.0
    n_examples = 0
    n_batches = 0

    for bi, batch in enumerate(loader, 1):
        # Move to device in a shallow copy so the loader's batch dict is not mutated.
        batch = dict(batch)
        for key in ["input_ids", "attention_mask", "pos_ids", "labels", "lengths"]:
            batch[key] = batch[key].to(device)

        scores = model(batch)  # must return [B, V]
        labels = batch["labels"]
        bsz = int(labels.shape[0])

        r = recall_at_k(scores, labels, k=k)
        m = mrr_at_k(scores, labels, k=k)

        # The helpers return per-batch means; convert back to sums over examples.
        recall_sum += r * bsz
        mrr_sum += m * bsz
        n_examples += bsz
        n_batches += 1

        if bi % 50 == 0:
            print(f"[06-11][EVAL] batch={bi} recall@{k}={r:.4f} mrr@{k}={m:.4f}")

    denom = max(n_examples, 1)
    return {
        f"Recall@{k}": recall_sum / denom,
        f"MRR@{k}": mrr_sum / denom,
        "batches": n_batches
    }


Dummy model sanity check (for pipeline only)

This is only to verify the evaluation loop works end-to-end.
It returns random scores.

In [29]:
# CELL [06-12] — Dummy model (pipeline sanity only)

class DummyRandomModel(torch.nn.Module):
    """
    Stand-in model for checking the pipeline only.

    Ignores the content of the batch and returns standard-normal noise of
    shape [B, vocab_size] on the same device as ``batch["input_ids"]``.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, batch):
        input_ids = batch["input_ids"]
        out_shape = (input_ids.shape[0], self.vocab_size)
        return torch.randn(out_shape, device=input_ids.device)

# Random scorer sized to the target vocabulary, placed on DEVICE.
dummy = DummyRandomModel(vocab_size=meta["vocab"]["target"]["size"])
dummy = dummy.to(DEVICE)

# Run the evaluation loop end-to-end on the MARS validation split.
mars_val_loader = make_loader(domain="mars", split="val", batch_size=64, shuffle_shards=False)

metrics = evaluate_model(dummy, mars_val_loader, 20, DEVICE)
print("[06-12] Dummy metrics:", metrics)


[06-08] Built loader mars/val | shards=1 | batch_size=64
[06-12] Dummy metrics: {'Recall@20': 0.028125, 'MRR@20': 0.004089781828224659, 'batches': 5}
