### Purpose

Adapters experiments on top of pretrained SASRec encoder.

Goal:
- Add lightweight adapters to each Transformer layer.
- Copy encoder weights from pretrained checkpoint (if available).
- Train adapters + out head (optionally item_emb) and evaluate on MARS.
- Run a small grid over adapter bottleneck sizes and learning rates, saving best results.

### Imports & config

In [15]:
# Quick (unsafe) workaround to avoid the libiomp5md.dll crash.
# Use this only to continue working in the notebook quickly.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
print("Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.")

Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.


In [9]:
import torch, random, numpy as np, json, time
from copy import deepcopy
from pathlib import Path
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

ROOT = Path('..')
DATA_DIR = ROOT/'data'/'processed'
CKPT_DIR = ROOT/'models'
CKPT_DIR.mkdir(exist_ok=True)
MARS_SHARD = DATA_DIR/'mars_shards'/'mars_shard_full.pt'
VOCAB_FILE = DATA_DIR/'vocab_mars'/'item2id_mars.json'
TEST_PAIRS = DATA_DIR/'mars_test_pairs.parquet'  # optional

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", DEVICE)


Device: cuda


In [10]:
class Adapter(nn.Module):
    def __init__(self, dim, bottleneck=16):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.up = nn.Linear(bottleneck, dim, bias=False)
    def forward(self, x):
        # x: (B, L, D)
        return x + self.up(self.act(self.down(x)))

class SASRecSmall(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, max_len=20, num_layers=2):
        super().__init__()
        self.item_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=2048, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.out = nn.Linear(embed_dim, embed_dim, bias=False)
    def forward(self, x):
        B,L = x.size()
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B,L)
        seq = self.item_emb(x) + self.pos_emb(pos)
        seq = self.encoder(seq)
        last = seq[:, -1, :]
        logits = self.out(last)
        return logits, last

class SASRecWithAdapters(SASRecSmall):
    def __init__(self, vocab_size, embed_dim=64, max_len=20, num_layers=2, adapter_bottleneck=16):
        super().__init__(vocab_size, embed_dim, max_len, num_layers=num_layers)
        # One adapter per encoder layer
        self.adapters = nn.ModuleList([Adapter(embed_dim, adapter_bottleneck) for _ in range(num_layers)])
    def forward(self, x):
        B,L = x.size()
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B,L)
        out = self.item_emb(x) + self.pos_emb(pos)
        # iterate through encoder layers and adapters
        for i, layer in enumerate(self.encoder.layers):
            out = layer(out)
            out = self.adapters[i](out)
        last = out[:, -1, :]
        logits = self.out(last)
        return logits, last


### Loss + evaluation helpers

In [11]:
# Cell 3
def sampled_loss(final, y, emb_weights, num_neg=32):
    # final: (B, D), y: (B,) indices, emb_weights: (V, D)
    pos_scores = (final * emb_weights[y]).sum(dim=1)        # (B,)
    V = emb_weights.size(0)
    B = final.size(0)
    neg_idx = torch.randint(0, V, (B, num_neg), device=final.device)
    neg_w = emb_weights[neg_idx]                            # (B, N, D)
    neg_scores = (neg_w * final.unsqueeze(1)).sum(dim=2)   # (B, N)
    logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)  # (B, 1+N)
    labels = torch.zeros(B, dtype=torch.long, device=final.device)
    return F.cross_entropy(logits, labels)

def eval_tensor_prefix(model, P_tensor, T_tensor, K=20):
    model.eval()
    hits = 0
    rr = 0.0
    tot = P_tensor.size(0)
    with torch.no_grad():
        for i in range(tot):
            X = P_tensor[i].unsqueeze(0).to(DEVICE)
            tgt = int(T_tensor[i].item())
            _, final = model(X)
            scores = final @ model.item_emb.weight.t()
            topk = scores.topk(K, dim=1).indices.squeeze(0).cpu().numpy()
            if tgt in topk:
                hits += 1
                rank = int((topk == tgt).nonzero()[0]) + 1
                rr += 1.0 / rank
    return hits / max(1, tot), rr / max(1, tot)

def eval_test_df(model, df_test):
    if df_test is None: return None, None
    model.eval()
    hits = 0
    rr = 0.0
    tot = len(df_test)
    with torch.no_grad():
        for _, r in df_test.iterrows():
            pref = r['prefix'] if isinstance(r['prefix'], str) else ''
            pref_ids = [int(x) for x in pref.split()] if pref else []
            if len(pref_ids) > 20: pref_ids = pref_ids[-20:]
            padded = [0] * (20 - len(pref_ids)) + pref_ids
            X = torch.LongTensor([padded]).to(DEVICE)
            tgt = int(r['target'])
            _, final = model(X)
            scores = final @ model.item_emb.weight.t()
            topk = scores.topk(20, dim=1).indices.squeeze(0).cpu().numpy()
            if tgt in topk:
                hits += 1
                rank = int((topk == tgt).nonzero()[0]) + 1
                rr += 1.0 / rank
    return hits / max(1, tot), rr / max(1, tot)


### Load MARS shard, splits, and optional test df

In [12]:
mp = torch.load(MARS_SHARD)
P_all = mp['prefix']; T_all = mp['target']
N = P_all.size(0)
VAL_FRAC = 0.2
val_n = max(1, int(N * VAL_FRAC)); train_n = N - val_n
train_P, train_T = P_all[:train_n], T_all[:train_n]
val_P, val_T = P_all[train_n:], T_all[train_n:]

df_test = None
if TEST_PAIRS.exists():
    import pandas as pd
    df_test = pd.read_parquet(TEST_PAIRS)
    print("Loaded test pairs:", len(df_test))
else:
    print("No test pairs found at", TEST_PAIRS)

item2id = json.load(open(VOCAB_FILE))
vocab = len(item2id)
print(f"Pairs: {N}, train: {train_n}, val: {val_n}, vocab: {vocab}")


  mp = torch.load(MARS_SHARD)


Loaded test pairs: 238
Pairs: 2380, train: 1904, val: 476, vocab: 777


### Grid config (bottlenecks, lrs, seeds) and small utilities

In [13]:
# Cell 5
BOTTLE_GRID = [8, 16, 32]       # adapter bottleneck sizes to try
LR_GRID = [5e-4, 1e-4]          # learning rates to try
SEEDS = [42, 100]               # seeds
EPOCHS = 12
BATCH = 32
PATIENCE = 4

from torch.utils.data import Dataset, DataLoader
class PairsDataset(Dataset):
    def __init__(self, P, T): self.P=P; self.T=T
    def __len__(self): return self.P.size(0)
    def __getitem__(self, i): return self.P[i], int(self.T[i].item())
train_ds = PairsDataset(train_P, train_T)


### Main grid (train adapters + out); copies pretrained encoder if available

In [14]:
results = []
# find pretrained checkpoint (prefer full)
pretrain_candidates = sorted(CKPT_DIR.glob("*full*.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
PRETRAIN = pretrain_candidates[0] if pretrain_candidates else None
print("PRETRAIN checkpoint:", PRETRAIN)

for bottleneck in BOTTLE_GRID:
    for lr in LR_GRID:
        for seed in SEEDS:
            # reproducibility
            torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
            model = SASRecWithAdapters(vocab_size=vocab, adapter_bottleneck=bottleneck).to(DEVICE)

            # try copying pretrained weights (encoder/pos_emb/out/item_emb if shapes match)
            if PRETRAIN is not None:
                try:
                    ck = torch.load(PRETRAIN, map_location='cpu')
                    state = ck.get('model_state', ck)
                    ms = model.state_dict()
                    for k, v in state.items():
                        if k in ms and ms[k].shape == v.shape:
                            ms[k] = v
                    model.load_state_dict(ms)
                    print(f"Copied pretrained weights into adapter model (b={bottleneck}, lr={lr}, seed={seed})")
                except Exception as e:
                    print("Warning: failed to copy pretrained checkpoint:", e)

            # freeze base params except adapters + out (optionally keep item_emb frozen)
            for name, p in model.named_parameters():
                if name.startswith('adapters') or name.startswith('out'):
                    p.requires_grad = True
                else:
                    # freeze encoder and pos_emb and item_emb initially
                    p.requires_grad = False

            train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0,
                                      collate_fn=lambda b: (torch.stack([x[0] for x in b]).to(DEVICE),
                                                            torch.tensor([x[1] for x in b], device=DEVICE)))
            opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-6)

            best_val = -1; best_state = None; bad = 0
            for ep in range(EPOCHS):
                model.train(); running = 0.0; steps = 0
                for X, y in train_loader:
                    _, final = model(X)
                    loss = sampled_loss(final, y, model.item_emb.weight, num_neg=32)
                    opt.zero_grad(); loss.backward(); opt.step()
                    running += float(loss.item()); steps += 1
                val_rec, val_mrr = eval_tensor_prefix(model, val_P, val_T, K=20)
                print(f"[adapters] b={bottleneck} lr={lr} seed={seed} ep={ep} loss={running/max(1,steps):.4f} val_rec={val_rec:.4f} val_mrr={val_mrr:.4f}")
                if val_rec > best_val:
                    best_val = val_rec
                    best_state = deepcopy(model.state_dict())
                    bad = 0
                    print("  ✓ New best")
                else:
                    bad += 1
                    if bad >= PATIENCE:
                        print("  early stopping")
                        break

            # save best
            if best_state is not None:
                ck_name = CKPT_DIR / f"adapters_b{bottleneck}_lr{lr}_s{seed}.pt"
                torch.save({'model_state': best_state}, ck_name)
                print("Saved adapter checkpoint:", ck_name)
            results.append({'bottleneck': bottleneck, 'lr': lr, 'seed': seed, 'val_rec': best_val})
            # small evaluation on test if exists
            if df_test is not None and best_state is not None:
                # load best_state into model
                ms = model.state_dict(); 
                for k in best_state: 
                    if k in ms and ms[k].shape == best_state[k].shape:
                        ms[k] = best_state[k]
                model.load_state_dict(ms)
                trec, tmrr = eval_test_df(model, df_test)
            else:
                trec, tmrr = None, None
            results[-1].update({'test_rec': trec, 'test_mrr': tmrr})
            print("Result row:", results[-1])
# save grid results
res_df = pd.DataFrame(results)
res_df.to_csv(CKPT_DIR/'adapters_grid_results.csv', index=False)
print("Adapters grid finished. Results saved to", CKPT_DIR/'adapters_grid_results.csv')


PRETRAIN checkpoint: ..\models\sasrec_full_top200000_epoch0.pt_epoch0.pt


  ck = torch.load(PRETRAIN, map_location='cpu')


Copied pretrained weights into adapter model (b=8, lr=0.0005, seed=42)


  rank = int((topk == tgt).nonzero()[0]) + 1


[adapters] b=8 lr=0.0005 seed=42 ep=0 loss=3.8681 val_rec=0.0273 val_mrr=0.0032
  ✓ New best
[adapters] b=8 lr=0.0005 seed=42 ep=1 loss=3.7740 val_rec=0.0294 val_mrr=0.0041
  ✓ New best
[adapters] b=8 lr=0.0005 seed=42 ep=2 loss=3.6945 val_rec=0.0336 val_mrr=0.0050
  ✓ New best
[adapters] b=8 lr=0.0005 seed=42 ep=3 loss=3.6285 val_rec=0.0357 val_mrr=0.0043
  ✓ New best
[adapters] b=8 lr=0.0005 seed=42 ep=4 loss=3.5793 val_rec=0.0294 val_mrr=0.0044
[adapters] b=8 lr=0.0005 seed=42 ep=5 loss=3.5456 val_rec=0.0420 val_mrr=0.0049
  ✓ New best
[adapters] b=8 lr=0.0005 seed=42 ep=6 loss=3.5238 val_rec=0.0357 val_mrr=0.0040
[adapters] b=8 lr=0.0005 seed=42 ep=7 loss=3.4990 val_rec=0.0357 val_mrr=0.0040
[adapters] b=8 lr=0.0005 seed=42 ep=8 loss=3.4742 val_rec=0.0357 val_mrr=0.0044
[adapters] b=8 lr=0.0005 seed=42 ep=9 loss=3.4714 val_rec=0.0315 val_mrr=0.0045
  early stopping
Saved adapter checkpoint: ..\models\adapters_b8_lr0.0005_s42.pt


  rank = int((topk == tgt).nonzero()[0]) + 1


Result row: {'bottleneck': 8, 'lr': 0.0005, 'seed': 42, 'val_rec': 0.04201680672268908, 'test_rec': 0.046218487394957986, 'test_mrr': 0.005527609948480464}
Copied pretrained weights into adapter model (b=8, lr=0.0005, seed=100)
[adapters] b=8 lr=0.0005 seed=100 ep=0 loss=3.7883 val_rec=0.0210 val_mrr=0.0030
  ✓ New best
[adapters] b=8 lr=0.0005 seed=100 ep=1 loss=3.7046 val_rec=0.0357 val_mrr=0.0035
  ✓ New best
[adapters] b=8 lr=0.0005 seed=100 ep=2 loss=3.6522 val_rec=0.0378 val_mrr=0.0035
  ✓ New best
[adapters] b=8 lr=0.0005 seed=100 ep=3 loss=3.5990 val_rec=0.0399 val_mrr=0.0033
  ✓ New best
[adapters] b=8 lr=0.0005 seed=100 ep=4 loss=3.5637 val_rec=0.0336 val_mrr=0.0029
[adapters] b=8 lr=0.0005 seed=100 ep=5 loss=3.5246 val_rec=0.0294 val_mrr=0.0031
[adapters] b=8 lr=0.0005 seed=100 ep=6 loss=3.4994 val_rec=0.0273 val_mrr=0.0043
[adapters] b=8 lr=0.0005 seed=100 ep=7 loss=3.4796 val_rec=0.0399 val_mrr=0.0044
  early stopping
Saved adapter checkpoint: ..\models\adapters_b8_lr0.000

### Quick analysis & recommendation (Markdown + code)

In [None]:
import pandas as pd
df = pd.read_csv(CKPT_DIR/'adapters_grid_results.csv')
df_sorted = df.sort_values('val_rec', ascending=False)
print("Top results:\n", df_sorted.head(10))
df_sorted.to_csv(CKPT_DIR/'adapters_grid_results_sorted.csv', index=False)
print("Saved sorted results to", CKPT_DIR/'adapters_grid_results_sorted.csv')
