### Goal

Reptile meta-learning starter for session-based recommendation. Uses tasks built from pretraining datasets (Yoochoose / Amazon categories). Produces meta-model that can be adapted quickly to MARS.


In [1]:
# Quick (unsafe) workaround to avoid the libiomp5md.dll crash.
# Use this only to continue working in the notebook quickly.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
print("Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.")

Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.


### Notes
- Cell 1: imports & config
- Cell 2: hashing utils + map one example
- Cell 3: build tasks (scan prefix-target parquet files, convert tokens -> hashed ids, create tasks)
- Cell 4: quick task sanity report
- Cell 5: SASRecSmall model (same as earlier)
- Cell 6: Reptile training loop (train meta-initialization)
- Cell 7: Few-shot adaptation to MARS & evaluation
- Cell 8: Save results and tips

### Cell 1 — Imports & config

In [2]:
# Cell 1 - imports & config
import json
import time
import math
import hashlib
from pathlib import Path
from copy import deepcopy
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Paths (adjust to your repo layout)
ROOT = Path('..')
DATA_DIR = ROOT / 'data' / 'processed'
META_DIR = DATA_DIR / 'meta_vocab'
META_DIR.mkdir(parents=True, exist_ok=True)
TASKS_OUT = META_DIR / 'tasks_reduced_hashed_top200k.pt'
TASKS_CSV = META_DIR / 'tasks_summary_hashed_top200k.csv'

# find all candidate prefix-target parquet files recursively (includes amazon parts)
prefix_glob = list((DATA_DIR).rglob("*prefix*target*.parquet"))   # recursive search
# if you used a different folder for amazon parts, e.g. ../data/processed/amazon_prefix_parts, use:
# prefix_glob = list((DATA_DIR/'amazon_prefix_parts').glob("*.parquet"))
print("Found candidate prefix-target files:", len(prefix_glob))

# Hash vocab size
K = 200_000
PAD_IDX = 0
HASH_MOD = K

# Task builder params
MIN_PAIRS_PER_TASK = 50   # keep tasks with >= this many pairs
MAX_TASKS = 300           # reduce number of tasks for quicker experiments (tunable)
MAX_PREFIX_LEN = 20

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE, 'K=', K)


Found candidate prefix-target files: 164
Device: cuda K= 200000


### Cell 2 — Hashing helpers

In [3]:
# Cell 2 - deterministic hash -> id
def token_to_hash_id(token: str, K=HASH_MOD):
    # deterministic MD5 hash mapping, returns 1..K (0 reserved for PAD)
    if token is None or token == '':
        return PAD_IDX
    # normalize token to str
    s = str(token)
    h = hashlib.md5(s.encode('utf-8')).hexdigest()
    idx = (int(h, 16) % K) + 1
    return idx

# quick check
examples = ['B002KQ6BT6', '12345', 'movie_abc', '42']
print([ (t, token_to_hash_id(t)) for t in examples ])


[('B002KQ6BT6', 22123), ('12345', 97916), ('movie_abc', 96575), ('42', 195175)]


### Cell 3 — Build hashed tasks from prefix-target files

In [4]:
# Cell 3 - build tasks
from collections import defaultdict
tasks = []    # each task: {'name': name, 'P': LongTensor (N, L), 'T': LongTensor (N)}

def process_file_to_task(path: Path, max_prefix_len=MAX_PREFIX_LEN):
    df = pd.read_parquet(path)
    # expect columns 'prefix' and 'target' (prefix as space-separated tokens or list)
    # tolerate both formats
    P_list = []
    T_list = []
    for _, r in df.iterrows():
        pref = r.get('prefix', '')
        # if prefix already stored as list-like, handle; else treat as string
        if isinstance(pref, (list, tuple)):
            tokens = [str(x) for x in pref if x is not None and x != '']
        else:
            # assume space-separated token ids or ASINs; handle empty string
            tokens = [t for t in str(pref).split() if t != '']
        # map tokens via hashing
        ids = [token_to_hash_id(t) for t in tokens]
        if len(ids) > max_prefix_len:
            ids = ids[-max_prefix_len:]
        # remove leading PADs if they were created by empty tokens; but keep lengths for nonzeros
        if len(ids) == 0:
            padded = [PAD_IDX] * max_prefix_len
            nonzero_len = 0
        else:
            padded = [PAD_IDX] * (max_prefix_len - len(ids)) + ids
            nonzero_len = sum(1 for x in ids if x != PAD_IDX)
        # if target missing, skip
        target = r.get('target', None)
        if target is None:
            continue
        # map target to hashed id (string target will be hashed)
        tid = token_to_hash_id(target)
        P_list.append(padded)
        T_list.append(int(tid))
    if len(P_list) < MIN_PAIRS_PER_TASK:
        return None
    P_t = torch.LongTensor(P_list)
    T_t = torch.LongTensor(T_list)
    # compute example nonzero length stats quickly
    nonzero_example_len = (P_t != PAD_IDX).sum(dim=1).clamp(max=MAX_PREFIX_LEN)
    return {'name': path.name, 'P': P_t, 'T': T_t, 'n_pairs': P_t.size(0),
            'median_nonzero_len': int(nonzero_example_len.median().item()),
            'frac_nonzero_gt0': float((nonzero_example_len>0).float().mean().item())}

# iterate
count = 0
for p in tqdm(prefix_glob, desc="Scanning files"):
    t = process_file_to_task(p)
    if t is None:
        continue
    tasks.append(t)
    count += 1
    if MAX_TASKS and count >= MAX_TASKS:
        break

print("Built tasks:", len(tasks))
# Save tasks in compact format: store P and T as tensors (could be large)
torch.save(tasks, TASKS_OUT)
# Also write CSV summary
rows = [{'name': t['name'], 'pairs': t['n_pairs'],
         'median_nonzero_len': t['median_nonzero_len'],
         'frac_nonzero_gt0': t['frac_nonzero_gt0']} for t in tasks]
pd.DataFrame(rows).to_csv(TASKS_CSV, index=False)
print("Saved tasks:", TASKS_OUT, "summary:", TASKS_CSV)


Scanning files: 100%|██████████| 164/164 [25:05<00:00,  9.18s/it]


Built tasks: 164
Saved tasks: ..\data\processed\meta_vocab\tasks_reduced_hashed_top200k.pt summary: ..\data\processed\meta_vocab\tasks_summary_hashed_top200k.csv


### Cell 4 — Quick task sanity report

In [5]:
# Cell 4 - Sanity checks
tasks = torch.load(TASKS_OUT)
print("Total tasks loaded:", len(tasks))
# show top 10 by pairs
sorted_tasks = sorted(tasks, key=lambda x: x['n_pairs'], reverse=True)
for t in sorted_tasks[:10]:
    print(t['name'], "pairs=", t['n_pairs'], "median_nonzero_len=", t['median_nonzero_len'],
          "frac_nonzero_gt0=", f"{t['frac_nonzero_gt0']:.3f}")
# distribution of frac_nonzero_gt0
fracs = [t['frac_nonzero_gt0'] for t in tasks]
print("frac_nonzero_gt0 median:", np.median(fracs), "mean:", np.mean(fracs))


  tasks = torch.load(TASKS_OUT)


Total tasks loaded: 164
amazon_prefix_target_part0000.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0001.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0003.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0004.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0006.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0007.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0009.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0010.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0012.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.000
amazon_prefix_target_part0013.parquet pairs= 200000 median_nonzero_len= 4 frac_nonzero_gt0= 1.

### Cell 5 — SASRecSmall (same architecture as pretrain) — instantiate meta-model

In [15]:
# Cell 5 - SASRecSmall
class SASRecSmall(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, max_len=20, num_heads=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.item_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=2048, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(embed_dim, embed_dim, bias=False)
    def forward(self, x):
        B, L = x.size()
        pos_ids = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        seq = self.item_emb(x) + self.pos_emb(pos_ids)
        seq = self.encoder(seq)
        last = seq[:, -1, :]
        logits = self.out(last)
        return logits, last

# Create meta-model with hashed vocab
META_VOCAB = K + 1  # 0..K
meta_model = SASRecSmall(vocab_size=META_VOCAB, embed_dim=64, max_len=MAX_PREFIX_LEN).to(DEVICE)
print("Meta-model created. Vocab:", META_VOCAB)


Meta-model created. Vocab: 200001


### Cell 6 — Reptile meta-training loop (simple version)

In [17]:
# Cell 6 - Reptile
# hyperparams (tune)
META_ITERS = 1000            # number of meta-iterations
TASK_BATCH = 16              # number of tasks sampled per meta-iteration 
INNER_STEPS = 16             # SGD steps per task (support) 
SUPPORT_BATCH = 128          # batch size for support updates
INNER_LR = 5e-3
META_STEP = 0.05             # step size to move meta weights toward adapted weights
VAL_TASKS_SAMPLE = 50       # tasks to evaluate on during meta training (optional)

# helper to get minibatches from a task
def task_sampler_from_task(tdict):
    P = tdict['P']
    T = tdict['T']
    N = P.size(0)
    idxs = np.arange(N)
    def gen(batch_size=SUPPORT_BATCH):
        np.random.shuffle(idxs)
        for i in range(0, N, batch_size):
            sel = idxs[i:i+batch_size]
            yield P[sel], T[sel]
    return gen

# utility: copy model parameters (state_dict)
def clone_state_dict(state):
    return {k: v.clone().detach() for k,v in state.items()}

# training loop
tasks = torch.load(TASKS_OUT)
opt_null = None
print("Starting Reptile meta-training (tasks:", len(tasks), ")")
meta_state = meta_model.state_dict()

for it in range(META_ITERS):
    sampled = np.random.choice(len(tasks), size=min(TASK_BATCH, len(tasks)), replace=False)
    meta_state_before = clone_state_dict(meta_state)
    adapted_states = []
    for tid in sampled:
        tinfo = tasks[tid]
        # build a small copy model
        local_model = SASRecSmall(vocab_size=META_VOCAB, embed_dim=64, max_len=MAX_PREFIX_LEN).to(DEVICE)
        local_model.load_state_dict(meta_state)  # start from meta
        local_opt = torch.optim.AdamW(local_model.parameters(), lr=INNER_LR, weight_decay=1e-6)
        # inner-loop: iterate INNER_STEPS over support batches
        gen = task_sampler_from_task(tinfo)()
        step = 0
        try:
            while step < INNER_STEPS:
                Xb, yb = next(gen)
                Xb = Xb.to(DEVICE)
                yb = yb.to(DEVICE)
                local_model.train()
                _, final = local_model(Xb)
                # sampled softmax loss
                V = local_model.item_emb.weight.size(0)
                pos_scores = (final * local_model.item_emb.weight[yb]).sum(dim=1)
                neg_idx = torch.randint(0, V, (Xb.size(0), 32), device=DEVICE)
                neg_w = local_model.item_emb.weight[neg_idx]
                neg_scores = (neg_w * final.unsqueeze(1)).sum(dim=2)
                logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
                labels = torch.zeros(Xb.size(0), dtype=torch.long, device=DEVICE)
                loss = F.cross_entropy(logits, labels)
                local_opt.zero_grad(); loss.backward(); local_opt.step()
                step += 1
        except StopIteration:
            pass
        adapted_states.append(clone_state_dict(local_model.state_dict()))
        # free local model
        del local_model, local_opt

    # meta-update: move meta_state toward average adapted_state
    avg_state = {}
    for k in meta_state:
        stacked = torch.stack([s[k].to('cpu') for s in adapted_states], dim=0)
        avg = torch.mean(stacked, dim=0)
        avg_state[k] = avg.to(meta_state[k].device)
    # apply reptile update: meta = meta + eps * (avg - meta)
    for k in meta_state:
        meta_state[k] = meta_state[k] + META_STEP * (avg_state[k].to(meta_state[k].device) - meta_state[k])

    # every N iterations optionally evaluate quick validation
    if (it+1) % 50 == 0 or it == 0:
        # compute a very cheap diagnostic: random task few-shot adapt and eval on its heldout pairs
        # we'll do one quick task eval to monitor progress
        idx = np.random.randint(len(tasks))
        tdiag = tasks[idx]
        # split task into support/query
        N = tdiag['P'].size(0)
        qn = max(1, int(0.2 * N))
        perm = np.random.permutation(N)
        sup_idx = perm[:-qn]; qry_idx = perm[-qn:]
        # adapt from meta_state for a few steps
        tmp_model = SASRecSmall(vocab_size=META_VOCAB, embed_dim=64, max_len=MAX_PREFIX_LEN).to(DEVICE)
        tmp_model.load_state_dict(meta_state)
        tmp_opt = torch.optim.AdamW(tmp_model.parameters(), lr=INNER_LR)
        # support steps
        for s in range(5):
            sel = sup_idx[s::5][:SUPPORT_BATCH] if len(sup_idx)>0 else sup_idx
            if len(sel)==0: break
            Xb = tdiag['P'][sel].to(DEVICE)
            yb = tdiag['T'][sel].to(DEVICE)
            _, final = tmp_model(Xb)
            pos_scores = (final * tmp_model.item_emb.weight[yb]).sum(dim=1)
            neg_idx = torch.randint(0, tmp_model.item_emb.weight.size(0), (Xb.size(0), 32), device=DEVICE)
            neg_w = tmp_model.item_emb.weight[neg_idx]
            neg_scores = (neg_w * final.unsqueeze(1)).sum(dim=2)
            logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
            loss = F.cross_entropy(logits, torch.zeros(Xb.size(0), dtype=torch.long, device=DEVICE))
            tmp_opt.zero_grad(); loss.backward(); tmp_opt.step()
        # evaluate on query
        hits = 0; total = 0
        for qi in qry_idx:
            Xq = tdiag['P'][qi].unsqueeze(0).to(DEVICE)
            tq = int(tdiag['T'][qi].item())
            _, final = tmp_model(Xq)
            scores = torch.matmul(final, tmp_model.item_emb.weight.t())
            topk = scores.topk(20, dim=1).indices.squeeze(0).cpu().numpy()
            total += 1
            if tq in topk: hits += 1
        quick_recall = hits / total if total>0 else 0.0
        print(f"[Reptile] iter {it+1}/{META_ITERS} quick_recall@20={quick_recall:.4f}")

# After meta loop - save meta_state as meta initialization
meta_model.load_state_dict(meta_state)
torch.save({'meta_state': meta_state}, META_DIR / 'reptile_meta_state_top200k.pt')
print("Saved meta init to:", META_DIR / 'reptile_meta_state_top200k.pt')


  tasks = torch.load(TASKS_OUT)


Starting Reptile meta-training (tasks: 164 )
[Reptile] iter 1/1000 quick_recall@20=0.0001
[Reptile] iter 50/1000 quick_recall@20=0.0004
[Reptile] iter 100/1000 quick_recall@20=0.0003
[Reptile] iter 150/1000 quick_recall@20=0.0036
[Reptile] iter 200/1000 quick_recall@20=0.0035
[Reptile] iter 250/1000 quick_recall@20=0.0007
[Reptile] iter 300/1000 quick_recall@20=0.0000
[Reptile] iter 350/1000 quick_recall@20=0.0013
[Reptile] iter 400/1000 quick_recall@20=0.0035
[Reptile] iter 450/1000 quick_recall@20=0.0013
[Reptile] iter 500/1000 quick_recall@20=0.0007
[Reptile] iter 550/1000 quick_recall@20=0.0026
[Reptile] iter 600/1000 quick_recall@20=0.0037
[Reptile] iter 650/1000 quick_recall@20=0.0012
[Reptile] iter 700/1000 quick_recall@20=0.0008
[Reptile] iter 750/1000 quick_recall@20=0.0082
[Reptile] iter 800/1000 quick_recall@20=0.0038
[Reptile] iter 850/1000 quick_recall@20=0.0008
[Reptile] iter 900/1000 quick_recall@20=0.0061
[Reptile] iter 950/1000 quick_recall@20=0.0166
[Reptile] iter 100

### Cell 7 — Few-shot adaptation to MARS and final eval

In [8]:
# Cell 7 - Adapt to MARS (few-shot) and evaluate
# load MARS shard (built earlier), map using same hashing
MARS_SHARD_FILE = DATA_DIR / 'mars_shards' / 'mars_shard_full.pt'
if not MARS_SHARD_FILE.exists():
    raise FileNotFoundError("Please ensure MARS shard exists (built in 07_transfer_to_mars).")

mp = torch.load(MARS_SHARD_FILE)
P_all = mp['prefix']   # NOTE: these were constructed earlier with original item2id mapping; we re-hash tokens here
# We need to rebuild a hashed MARS shard to be consistent. If your MARS_shard already uses text tokens, rebuild; otherwise convert.
# For simplicity assume mars_prefix_target.parquet exists and we rebuild hashed prefixes here
MARS_PAIRS = DATA_DIR / 'mars_prefix_target.parquet'
if not MARS_PAIRS.exists():
    raise FileNotFoundError("Please create mars_prefix_target.parquet first.")
df_mars_pairs = pd.read_parquet(MARS_PAIRS)
# build hashed MARS tensors
P_list = []
T_list = []
for _, r in df_mars_pairs.iterrows():
    pref = r['prefix'] if isinstance(r['prefix'], str) else ''
    tokens = [t for t in str(pref).split() if t!='']
    ids = [token_to_hash_id(t) for t in tokens]
    if len(ids) > MAX_PREFIX_LEN: ids = ids[-MAX_PREFIX_LEN:]
    padded = [PAD_IDX]*(MAX_PREFIX_LEN - len(ids)) + ids
    P_list.append(padded)
    T_list.append(token_to_hash_id(r['target']))
P_H = torch.LongTensor(P_list)
T_H = torch.LongTensor(T_list)
print("Built hashed MARS pairs:", P_H.size(0))

# Split train/val/test
n = P_H.size(0)
test_n = max(1, int(0.1*n))
val_n = max(1, int(0.1*n))
train_n = n - val_n - test_n
train_P, train_T = P_H[:train_n].to(DEVICE), T_H[:train_n].to(DEVICE)
val_P, val_T = P_H[train_n:train_n+val_n].to(DEVICE), T_H[train_n:train_n+val_n].to(DEVICE)
test_P, test_T = P_H[train_n+val_n:].to(DEVICE), T_H[train_n+val_n:].to(DEVICE)
print("MARS splits: train", train_P.size(0), "val", val_P.size(0), "test", test_P.size(0))

# Load meta init
meta_ck = torch.load(META_DIR / 'reptile_meta_state_top200k.pt', map_location=DEVICE)
meta_state = meta_ck['meta_state']
adapt_model = SASRecSmall(vocab_size=META_VOCAB, embed_dim=64, max_len=MAX_PREFIX_LEN).to(DEVICE)
adapt_model.load_state_dict(meta_state)

# Few-shot fine-tune (support small K shots) — try different K_shots
def adapt_and_eval(K_shots=50, adapt_steps=10, lr=1e-4):
    # sample K_shots from train
    idxs = np.random.choice(train_P.size(0), size=min(K_shots, train_P.size(0)), replace=False)
    Xs = train_P[idxs]
    ys = train_T[idxs]
    model = SASRecSmall(vocab_size=META_VOCAB, embed_dim=64, max_len=MAX_PREFIX_LEN).to(DEVICE)
    model.load_state_dict(meta_state)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    for s in range(adapt_steps):
        model.train()
        _, final = model(Xs)
        pos_scores = (final * model.item_emb.weight[ys]).sum(dim=1)
        neg_idx = torch.randint(0, model.item_emb.weight.size(0), (Xs.size(0), 32), device=DEVICE)
        neg_w = model.item_emb.weight[neg_idx]
        neg_scores = (neg_w * final.unsqueeze(1)).sum(dim=2)
        logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
        loss = F.cross_entropy(logits, torch.zeros(Xs.size(0), dtype=torch.long, device=DEVICE))
        opt.zero_grad(); loss.backward(); opt.step()
    # evaluate on test
    model.eval()
    hits = 0; rr_sum = 0.0; total = 0
    with torch.no_grad():
        for i in range(test_P.size(0)):
            Xq = test_P[i].unsqueeze(0)
            target = int(test_T[i].item())
            _, final = model(Xq)
            scores = torch.matmul(final, model.item_emb.weight.t())
            topk = scores.topk(20, dim=1).indices.squeeze(0).cpu().numpy()
            total += 1
            if target in topk:
                hits += 1
                rank = int((topk == target).nonzero()[0]) + 1
                rr_sum += 1.0 / rank
    recall = hits/total if total>0 else 0.0
    mrr = rr_sum/total if total>0 else 0.0
    return recall, mrr

for K_shots in [10, 50, 100, 200]:
    r, m = adapt_and_eval(K_shots=K_shots, adapt_steps=10, lr=1e-4)
    print(f"Few-shot K={K_shots} -> Recall@20={r:.4f}, MRR={m:.4f}")


  mp = torch.load(MARS_SHARD_FILE)


Built hashed MARS pairs: 2380
MARS splits: train 1904 val 238 test 238


  meta_ck = torch.load(META_DIR / 'reptile_meta_state_top200k.pt', map_location=DEVICE)


Few-shot K=10 -> Recall@20=0.0000, MRR=0.0000


  rank = int((topk == target).nonzero()[0]) + 1


Few-shot K=50 -> Recall@20=0.0042, MRR=0.0003
Few-shot K=100 -> Recall@20=0.0000, MRR=0.0000
Few-shot K=200 -> Recall@20=0.0000, MRR=0.0000


### Cell 8 — Save & short tips

In [9]:
# Cell 8 - Save and tips
torch.save({'meta_state': meta_state, 'K': K, 'pad': PAD_IDX}, META_DIR / 'meta_info_top200k.pt')
print("Saved meta info.")

# Quick tips:
# - If many tasks still have very low frac_nonzero_gt0 (<0.05) consider increasing K or using per-file remapping.
# - Reptile hyperparams: TASK_BATCH, INNER_STEPS, META_STEP - tune depending on compute.
# - Consider adapter layers (09_adapters) combined with Reptile: adapt fewer params and get robust few-shot.


Saved meta info.


In [12]:
# Build meta_item2id_top200k.json from tasks and re-run the overlap diagnostic
# Paste and run this in the same notebook/kernel after your previous cells.

import json
from pathlib import Path
import torch
from collections import Counter
import numpy as np

ROOT = Path('..')
META_DIR = ROOT / 'data' / 'processed' / 'meta_vocab'
META_DIR.mkdir(parents=True, exist_ok=True)

TASKS_FILE = META_DIR / 'tasks_reduced_hashed_top200k.pt'   # adjust if your tasks file has a different name
OUT_MAP = META_DIR / 'meta_item2id_top200k.json'
TOPK = 200000   # your chosen K (you set K=200k earlier)

print("Loading tasks from:", TASKS_FILE)
tasks = torch.load(TASKS_FILE)

# Count token frequencies across tasks (tokens stored as ints in P tensors)
cnt = Counter()
for t in tasks:
    P = t['P']  # torch.LongTensor (n_pairs, L)
    # flatten and count non-zero tokens (exclude padding 0)
    flat = P.view(-1).cpu().numpy()
    flat_nonzero = flat[flat != 0]
    cnt.update(flat_nonzero.tolist())

print("Unique tokens counted:", len(cnt))

# select top-K most frequent tokens
most_common = cnt.most_common(TOPK)
meta_item2id = {}
# reserve 0 for OOV/pad
meta_item2id['<OOV>'] = 0
next_id = 1
for token, freq in most_common:
    # store keys as strings (your downstream code expects string keys)
    meta_item2id[str(int(token))] = next_id
    next_id += 1

print(f"Built meta map with {len(meta_item2id)-1} tokens (plus <OOV>). Saving to:", OUT_MAP)
with open(OUT_MAP, 'w') as f:
    json.dump(meta_item2id, f)

# Quick overlap diagnostic (same as before) to confirm
hits = 0
tot = 0
sample_tasks = tasks[:min(len(tasks),50)]
for t in sample_tasks:
    P = t['P']
    flat = np.unique(P.view(-1).cpu().numpy())
    tot += len(flat)
    hits += sum(1 for v in flat if str(int(v)) in meta_item2id)
print("sample meta overlap (first 50 tasks):", hits, tot, (hits/tot if tot>0 else 0.0))

# Print a few entries to sanity-check
some_keys = list(meta_item2id.items())[:10]
print("sample mapping entries:", some_keys)


Loading tasks from: ..\data\processed\meta_vocab\tasks_reduced_hashed_top200k.pt


  tasks = torch.load(TASKS_FILE)


Unique tokens counted: 200000
Built meta map with 200000 tokens (plus <OOV>). Saving to: ..\data\processed\meta_vocab\meta_item2id_top200k.json
sample meta overlap (first 50 tasks): 3179118 3179168 0.9999842726147219
sample mapping entries: [('<OOV>', 0), ('154710', 1), ('34011', 2), ('31351', 3), ('100781', 4), ('76434', 5), ('108491', 6), ('29454', 7), ('127543', 8), ('188382', 9)]


In [13]:
# DIAGNOSTICS: tasks sanity & vocab overlap
from pathlib import Path
import torch
import json
import numpy as np
import pandas as pd

TASKS_FILE = Path("..") / "data" / "processed" / "meta_vocab" / "tasks_reduced_hashed_top200k.pt"
META_MAP = Path("..") / "data" / "processed" / "meta_vocab" / "meta_item2id_top200k.json"
MARS_VOCAB = Path("..") / "data" / "processed" / "vocab_mars" / "item2id_mars.json"

tasks = torch.load(TASKS_FILE)
print("Tasks loaded:", len(tasks))

# per-task stats
rows = []
for i, t in enumerate(tasks):
    p = t['P'] if isinstance(t, dict) else t['P']
    nonzero = (p.sum(dim=1) > 0).float().mean().item()
    rows.append((i, len(p), float(nonzero)))
df = pd.DataFrame(rows, columns=['task_idx','pairs','frac_nonzero_prefix'])
print(df.describe())
print("Examples (first 10):")
print(df.head(10))

# vocab overlap
meta_map = json.load(open(META_MAP))
meta_vsize = len(meta_map)
print("meta vocab size:", meta_vsize)

# check sample of task tokens => how many map to meta vocab
def sample_vocab_overlap(tasks, ncheck=1000):
    hits=0; tot=0
    for t in tasks[:min(len(tasks),50)]:  # sample first 50 tasks
        P = t['P']
        flat = P.view(-1).unique().cpu().numpy()
        tot += len(flat)
        hits += sum(1 for v in flat if str(int(v)) in meta_map)  # tokens saved as strings
    return hits, tot, hits/tot if tot>0 else 0.0

hits,tot,frac = sample_vocab_overlap(tasks)
print("sample meta overlap (first 50 tasks):", hits, tot, frac)


  tasks = torch.load(TASKS_FILE)


Tasks loaded: 164
        task_idx          pairs  frac_nonzero_prefix
count  164.00000     164.000000                164.0
mean    81.50000  135930.170732                  1.0
std     47.48684   89265.090386                  0.0
min      0.00000    2380.000000                  1.0
25%     40.75000   12800.000000                  1.0
50%     81.50000  200000.000000                  1.0
75%    122.25000  200000.000000                  1.0
max    163.00000  200000.000000                  1.0
Examples (first 10):
   task_idx   pairs  frac_nonzero_prefix
0         0    2380                  1.0
1         1  200000                  1.0
2         2  200000                  1.0
3         3   12544                  1.0
4         4  200000                  1.0
5         5  200000                  1.0
6         6   10989                  1.0
7         7  200000                  1.0
8         8  200000                  1.0
9         9   12115                  1.0
meta vocab size: 200001
sample me

In [14]:
# ADAPTERS + Reptile meta-training (adapter inner updates)
import torch.nn as nn
import copy
import random
from tqdm.auto import trange

# small adapter module
class Adapter(nn.Module):
    def __init__(self, dim, bottleneck=32):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck, dim)
    def forward(self, x):
        return self.up(self.act(self.down(x))) + x

# attach adapters to a SASRecSmall instance (in-place)
def attach_adapters(model, bottleneck=32):
    adapters = {}
    # find encoder layers
    for i,layer in enumerate(model.encoder.layers):
        a = Adapter(model.embed_dim, bottleneck=bottleneck)
        layer.adapter = a
        adapters[f"encoder.layers.{i}.adapter"] = a
    # also add adapter on final projection if desired
    model.adapters = adapters
    return model

# apply adapter forward: modify SASRecSmall.forward to use adapter if present
_orig_forward = SASRecSmall.forward
def _forward_with_adapters(self, x):
    B, L = x.size()
    pos_ids = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
    seq = self.item_emb(x) + self.pos_emb(pos_ids)
    # pass through encoder layers and apply adapter after each encoder layer
    for i,layer in enumerate(self.encoder.layers):
        seq = layer(seq)
        if hasattr(layer, "adapter"):
            seq = layer.adapter(seq)
    last = seq
