Imports & config

In [None]:
# CELL 1 — Imports & config

import pandas as pd
import torch
from pathlib import Path
import json

# Paths resolve against the notebook's working directory.
DATA_DIR = Path("../data/processed")
OUT = DATA_DIR / "tensor_shards"  # destination folder for the .pt tensor shards
OUT.mkdir(exist_ok=True)

# Sequence encoding: fixed prefix length and the two reserved ids.
MAX_PREFIX_LEN = 20
PAD_ID, UNK_ID = 0, 1

# Number of rows written per shard file.
SHARD_SIZE = 250_000

print("[05B] Output dir:", OUT.resolve())


[05B] Output dir: C:\Users\User\Documents\ml-workspace\session-transfer-mooc\data\processed\tensor_shards


Load manifest & splits

In [2]:
# CELL 2 — Load manifest & splits

# Path.read_text opens, reads and closes the manifest in one call.
manifest = json.loads((DATA_DIR / "sessionization_manifest.json").read_text())

# splits: domain -> {"train"/"val"/"test": ...}; the per-split values are
# presumably lists of parquet part paths (see how they are consumed below).
splits = manifest["splits"]

print("[05B] Domains in splits:", splits.keys())


[05B] Domains in splits: dict_keys(['mars'])


Vocabulary builder helper

In [3]:
# CELL 2.1 — Vocabulary builder

from collections import Counter

def build_vocab_from_prefix_parts(files, top_k=None):
    """Build an item -> id vocabulary from prefix/target parquet parts.

    Counts every whitespace-separated token of the ``prefix`` column and every
    non-null ``target`` (converted with ``str``) across all ``files``, then
    assigns ids in descending frequency order. Items with equal counts keep
    the order in which they were first seen.

    Args:
        files: Iterable of parquet paths, each with ``prefix`` and ``target``
            columns.
        top_k: Keep only the ``top_k`` most frequent items; ``None`` keeps all.

    Returns:
        dict mapping ``"<PAD>"`` -> PAD_ID, ``"<UNK>"`` -> UNK_ID, and each kept
        item -> a unique id allocated above both reserved ids.
    """
    counter = Counter()

    for part in files:
        df = pd.read_parquet(part, columns=["prefix", "target"])
        # Zipping the two columns is cheaper than itertuples() and preserves the
        # per-row counting order, which decides tie-breaks in most_common().
        for prefix, target in zip(df["prefix"], df["target"]):
            if isinstance(prefix, str):
                counter.update(prefix.split())
            # A missing target is not an item: str(NaN) / str(None) would
            # otherwise add a bogus "nan" / "None" entry to the vocabulary.
            if pd.notna(target):
                counter.update([str(target)])

    # Reserve PAD and UNK.
    item2id = {"<PAD>": PAD_ID, "<UNK>": UNK_ID}

    # Allocate new ids above the largest reserved id, so they can never collide
    # with PAD_ID/UNK_ID even if those constants change (with 0/1 this starts at 2).
    next_id = max(PAD_ID, UNK_ID) + 1
    for item, _ in counter.most_common(top_k):
        if item not in item2id:
            item2id[item] = next_id
            next_id += 1

    return item2id


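Build SOURCE vocabulary (Amazon + YooChoose TRAIN splits)

The source vocabulary needs the `amazon` and `yoochoose` train splits. The first cell below therefore checks which domains the manifest actually contains.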

In [4]:
# CELL 2.1.5 — Debug splits keys

print("[DEBUG] splits keys:", splits.keys())

# For each domain, also show its split names when the entry is a nested dict.
for domain_name, domain_splits in splits.items():
    print(f"[DEBUG] domain = {domain_name}")
    if not isinstance(domain_splits, dict):
        continue
    print("        subkeys:", domain_splits.keys())


[DEBUG] splits keys: dict_keys(['mars'])
[DEBUG] domain = mars
        subkeys: dict_keys(['train', 'val', 'test'])


In [5]:
# CELL 2.2 — Build SOURCE vocabulary

# Domains whose TRAIN splits make up the source vocabulary.
SOURCE_DOMAINS = ("amazon", "yoochoose")

# Fail early with an actionable message instead of a bare KeyError: the
# manifest may not contain every source domain (the recorded run lists only
# 'mars').
missing_domains = [d for d in SOURCE_DOMAINS if d not in splits]
if missing_domains:
    raise KeyError(
        f"[VOCAB][SOURCE] Domains {missing_domains} not found in manifest splits "
        f"(available: {sorted(splits)}). Sessionize the source domains before "
        "building the source vocabulary."
    )

source_train_files = [
    part for domain in SOURCE_DOMAINS for part in splits[domain]["train"]
]

print("[VOCAB][SOURCE] Number of train files:", len(source_train_files))

item2id_source = build_vocab_from_prefix_parts(
    source_train_files,
    top_k=200_000,  # keep consistent with earlier design
)

print("[VOCAB][SOURCE] Vocab size:", len(item2id_source))


KeyError: 'amazon'

Load vocabularies

In [None]:
# CELL 3 — Load vocabularies

VOCAB_DIR = DATA_DIR / "vocab_topn"

# Path.read_text closes the file handle right after reading, unlike
# json.load(open(...)), which leaves it open until garbage collection.
item2id_source = json.loads((VOCAB_DIR / "item2id_source.json").read_text())
item2id_target = json.loads((VOCAB_DIR / "item2id_target.json").read_text())

print("[05B] Source vocab size:", len(item2id_source))
print("[05B] Target vocab size:", len(item2id_target))


Tensor builder (core logic)

In [None]:
# CELL 4 — Tensor builder

def build_tensors(domain, split, files, item2id):
    """Encode prefix/target parquet parts into fixed-length int64 tensor shards.

    Each row's whitespace-separated ``prefix`` is mapped through ``item2id``
    (unknown items -> UNK_ID), truncated to its last MAX_PREFIX_LEN items and
    left-padded with PAD_ID. Each row's ``target`` is looked up as
    ``str(target)``. Every SHARD_SIZE rows, and once at the end, a dict of
    ``torch.long`` tensors with keys ``prefix``, ``target``, ``length``,
    ``attn_mask`` and ``pos_ids`` is saved to
    ``OUT / f"{domain}_{split}_shard_{NNN:03d}.pt"``.

    Args:
        domain: Domain name used in the shard file names.
        split: Split name used in the shard file names.
        files: Iterable of parquet paths, each with ``prefix`` and ``target``
            columns.
        item2id: Mapping from item token (str) to integer id.

    Returns:
        Number of shard files written (0 when ``files`` yields no rows).
    """
    buffer = {key: [] for key in ("prefix", "target", "length", "attn_mask", "pos_ids")}
    written = []

    # Positions are identical for every row, so build the list once. Rows are
    # never mutated, and torch.tensor copies the data on flush, so sharing is safe.
    pos_ids = list(range(MAX_PREFIX_LEN))

    def flush():
        # Write the buffered rows as one shard, then empty the buffer.
        if not buffer["prefix"]:
            return
        shard = {k: torch.tensor(v, dtype=torch.long) for k, v in buffer.items()}
        out_path = OUT / f"{domain}_{split}_shard_{len(written):03d}.pt"
        torch.save(shard, out_path)
        print(f"[05B] Saved {out_path}")
        written.append(out_path)
        for rows in buffer.values():
            rows.clear()

    for part in files:
        df = pd.read_parquet(part, columns=["prefix", "target"])
        for prefix, target in zip(df["prefix"], df["target"]):
            tokens = prefix.split() if isinstance(prefix, str) else []
            # Keep only the most recent MAX_PREFIX_LEN items. The explicit start
            # index avoids the tokens[-0:] pitfall when MAX_PREFIX_LEN is 0.
            tokens = tokens[max(len(tokens) - MAX_PREFIX_LEN, 0):]
            ids = [item2id.get(tok, UNK_ID) for tok in tokens]

            length = len(ids)
            pad_len = MAX_PREFIX_LEN - length

            buffer["prefix"].append([PAD_ID] * pad_len + ids)
            buffer["target"].append(item2id.get(str(target), UNK_ID))
            buffer["length"].append(length)
            buffer["attn_mask"].append([0] * pad_len + [1] * length)
            buffer["pos_ids"].append(pos_ids)

            if len(buffer["prefix"]) >= SHARD_SIZE:
                flush()

    flush()

    # An earlier run that produced more shards leaves its extra files in place.
    # Report them so they are not mistaken for output of this run.
    stale = sorted(set(OUT.glob(f"{domain}_{split}_shard_*.pt")) - set(written))
    if stale:
        print(
            f"[05B] WARNING: {len(stale)} stale shard(s) for {domain}/{split} "
            f"were not rewritten by this run, e.g. {stale[0]}"
        )

    return len(written)
