### Title & notes 

Purpose: Grid search over learning rates for the strategy:
- copy pretrained encoder weights
- re-initialize item embedding & out head
- freeze encoder + pos_emb
- train item_emb + out only

Outputs:
- checkpoint per lr/seed: ../models/reinit_emb_lr{lr}_s{seed}.pt
- CSV summary: ../models/reinit_emb_grid_results.csv


In [2]:
# Quick (unsafe) workaround to avoid the libiomp5md.dll crash.
# Use this only to continue working in the notebook quickly.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
print("Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.")

Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.


### Imports & global config

In [3]:
import json, time, random, numpy as np, torch
from pathlib import Path
from copy import deepcopy
import torch.nn.functional as F

ROOT = Path("..")
DATA_DIR = ROOT/"data"/"processed"
CKPT_DIR = ROOT/"models"
CKPT_DIR.mkdir(exist_ok=True)
MARS_SHARD = DATA_DIR/"mars_shards"/"mars_shard_full.pt"
VOCAB_FILE = DATA_DIR/"vocab_mars"/"item2id_mars.json"
PRETRAIN = sorted(CKPT_DIR.glob("*full*.pt"), key=lambda p:p.stat().st_mtime, reverse=True)
PRETRAIN = PRETRAIN[0] if PRETRAIN else None

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", DEVICE)
print("PRETRAIN:", PRETRAIN)


device: cuda
PRETRAIN: ..\models\sasrec_full_top200000_epoch0.pt_epoch0.pt


### Model definition & helpers

In [4]:
import torch.nn as nn

class SASRecSmall(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, max_len=20):
        super().__init__()
        self.item_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        enc_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=2048, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
        self.out = nn.Linear(embed_dim, embed_dim, bias=False)
    def forward(self, x):
        B,L = x.size()
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B,L)
        seq = self.item_emb(x) + self.pos_emb(pos)
        seq = self.encoder(seq)
        last = seq[:,-1,:]
        logits = self.out(last)
        return logits, last

def sampled_loss(final, y, emb, neg=32):
    pos = (final * emb[y]).sum(dim=1)
    V = emb.size(0); B = final.size(0)
    neg_idx = torch.randint(0, V, (B, neg), device=final.device)
    negW = emb[neg_idx]
    neg_scores = (negW * final.unsqueeze(1)).sum(dim=2)
    logits = torch.cat([pos.unsqueeze(1), neg_scores], dim=1)
    labels = torch.zeros(B, dtype=torch.long, device=final.device)
    return F.cross_entropy(logits, labels)


### Load data shard & splits

In [5]:
mp = torch.load(MARS_SHARD)
P_all = mp['prefix']; T_all = mp['target']
N = P_all.size(0)
VAL_FRAC = 0.2
val_n = max(1, int(N * VAL_FRAC)); train_n = N - val_n
train_P, train_T = P_all[:train_n], T_all[:train_n]
val_P, val_T = P_all[train_n:], T_all[train_n:]
vocab = len(json.load(open(VOCAB_FILE)))
print("pairs:", N, "train:", train_n, "val:", val_n, "vocab:", vocab)


pairs: 2380 train: 1904 val: 476 vocab: 777


  mp = torch.load(MARS_SHARD)


### Training loop function

In [6]:
from torch.utils.data import Dataset, DataLoader

class PairsDataset(Dataset):
    def __init__(self,P,T): self.P=P; self.T=T
    def __len__(self): return self.P.size(0)
    def __getitem__(self,i): return self.P[i], int(self.T[i].item())

def validate_model(model, val_P, val_T, K=20):
    model.eval()
    hits=0; rr=0.0; tot=val_P.size(0)
    with torch.no_grad():
        for i in range(tot):
            X = val_P[i].unsqueeze(0).to(DEVICE)
            tgt = int(val_T[i].item())
            _, final = model(X)
            scores = final @ model.item_emb.weight.t()
            topk = scores.topk(K, dim=1).indices.squeeze(0).cpu().numpy()
            if tgt in topk:
                hits += 1
                rank = int((topk==tgt).nonzero()[0]) + 1
                rr += 1.0 / rank
    return hits / tot, rr / tot


### Grid run (LRs & seeds)

In [7]:
LRs=[1e-2,5e-3,1e-3,5e-4,1e-4]
SEEDS=[42,100,2023]
EPOCHS=20; BATCH=32
results=[]

for lr in LRs:
    for seed in SEEDS:
        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
        model = SASRecSmall(vocab).to(DEVICE)
        if PRETRAIN:
            ck = torch.load(PRETRAIN, map_location=DEVICE)
            st = ck.get('model_state', ck)
            ms = model.state_dict()
            for k,v in st.items():
                if k in ms and ms[k].shape == v.shape:
                    ms[k] = v
            model.load_state_dict(ms)
        # reinit item emb + out
        nn.init.normal_(model.item_emb.weight, mean=0.0, std=0.01)
        nn.init.normal_(model.out.weight, mean=0.0, std=0.01)
        # freeze encoder + pos_emb
        for name,p in model.named_parameters():
            if name.startswith('encoder') or name.startswith('pos_emb'):
                p.requires_grad = False
            else:
                p.requires_grad = True
        opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-6)
        loader = DataLoader(PairsDataset(train_P, train_T), batch_size=BATCH, shuffle=True, num_workers=0)
        best_val=-1; best_state=None; patience=0
        for ep in range(EPOCHS):
            model.train(); running=0.0; steps=0
            for Xidx, yidx in loader:
                X = Xidx.to(DEVICE); y = yidx.to(DEVICE)
                _, final = model(X)
                loss = sampled_loss(final, y, model.item_emb.weight, neg=32)
                opt.zero_grad(); loss.backward(); opt.step()
                running += float(loss.item()); steps += 1
            val_rec, val_mrr = validate_model(model, val_P, val_T, K=20)
            print(f"lr={lr} seed={seed} ep={ep} loss={running/max(1,steps):.4f} val_rec={val_rec:.4f} val_mrr={val_mrr:.4f}")
            if val_rec > best_val:
                best_val = val_rec; best_state = deepcopy(model.state_dict()); patience=0
            else:
                patience += 1
                if patience >= 5:
                    break
        out_path = CKPT_DIR/f"reinit_emb_lr{lr}_s{seed}.pt"
        torch.save({'model_state': best_state}, out_path)
        results.append({'lr':lr,'seed':seed,'val_rec':best_val})
import pandas as pd
pd.DataFrame(results).to_csv(CKPT_DIR/'reinit_emb_grid_results.csv', index=False)
print("Grid done. Saved results to", CKPT_DIR/'reinit_emb_grid_results.csv')


  ck = torch.load(PRETRAIN, map_location=DEVICE)
  rank = int((topk==tgt).nonzero()[0]) + 1


lr=0.01 seed=42 ep=0 loss=3.2958 val_rec=0.0735 val_mrr=0.0153
lr=0.01 seed=42 ep=1 loss=2.8451 val_rec=0.0798 val_mrr=0.0298
lr=0.01 seed=42 ep=2 loss=2.6776 val_rec=0.0861 val_mrr=0.0275
lr=0.01 seed=42 ep=3 loss=2.5633 val_rec=0.1008 val_mrr=0.0386
lr=0.01 seed=42 ep=4 loss=2.4984 val_rec=0.1408 val_mrr=0.0362
lr=0.01 seed=42 ep=5 loss=2.4146 val_rec=0.1282 val_mrr=0.0409
lr=0.01 seed=42 ep=6 loss=2.3329 val_rec=0.1534 val_mrr=0.0398
lr=0.01 seed=42 ep=7 loss=2.2555 val_rec=0.1723 val_mrr=0.0465
lr=0.01 seed=42 ep=8 loss=2.1906 val_rec=0.2017 val_mrr=0.0617
lr=0.01 seed=42 ep=9 loss=2.1104 val_rec=0.2164 val_mrr=0.0652
lr=0.01 seed=42 ep=10 loss=2.0385 val_rec=0.2185 val_mrr=0.0692
lr=0.01 seed=42 ep=11 loss=1.9612 val_rec=0.2269 val_mrr=0.0770
lr=0.01 seed=42 ep=12 loss=1.9119 val_rec=0.2542 val_mrr=0.0837
lr=0.01 seed=42 ep=13 loss=1.8273 val_rec=0.2857 val_mrr=0.0976
lr=0.01 seed=42 ep=14 loss=1.7673 val_rec=0.3046 val_mrr=0.1126
lr=0.01 seed=42 ep=15 loss=1.7035 val_rec=0.3256 v

  ck = torch.load(PRETRAIN, map_location=DEVICE)


lr=0.01 seed=100 ep=0 loss=3.2997 val_rec=0.0609 val_mrr=0.0154
lr=0.01 seed=100 ep=1 loss=2.8402 val_rec=0.0609 val_mrr=0.0213
lr=0.01 seed=100 ep=2 loss=2.6531 val_rec=0.0672 val_mrr=0.0262
lr=0.01 seed=100 ep=3 loss=2.5777 val_rec=0.0903 val_mrr=0.0295
lr=0.01 seed=100 ep=4 loss=2.4926 val_rec=0.1050 val_mrr=0.0278
lr=0.01 seed=100 ep=5 loss=2.4168 val_rec=0.1092 val_mrr=0.0363
lr=0.01 seed=100 ep=6 loss=2.3370 val_rec=0.1387 val_mrr=0.0421
lr=0.01 seed=100 ep=7 loss=2.2764 val_rec=0.1681 val_mrr=0.0509
lr=0.01 seed=100 ep=8 loss=2.2088 val_rec=0.1786 val_mrr=0.0544
lr=0.01 seed=100 ep=9 loss=2.1343 val_rec=0.1807 val_mrr=0.0602
lr=0.01 seed=100 ep=10 loss=2.0601 val_rec=0.1912 val_mrr=0.0641
lr=0.01 seed=100 ep=11 loss=1.9748 val_rec=0.2164 val_mrr=0.0757
lr=0.01 seed=100 ep=12 loss=1.9295 val_rec=0.2395 val_mrr=0.0829
lr=0.01 seed=100 ep=13 loss=1.8474 val_rec=0.2605 val_mrr=0.0925
lr=0.01 seed=100 ep=14 loss=1.7763 val_rec=0.2773 val_mrr=0.0985
lr=0.01 seed=100 ep=15 loss=1.7019 