### imports, config, and build MARS shard

In [1]:
# Quick (unsafe) workaround to avoid the libiomp5md.dll crash.
# Use this only to continue working in the notebook quickly.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
print("Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.")

Set KMP_DUPLICATE_LIB_OK=TRUE — restart kernel and re-run cells now.


#### Purpose: transfer the pretrained SASRec encoder to MARS, fine-tune, validate, and save results.

### Imports, global config, paths

In [3]:
import json
import time
from pathlib import Path
from copy import deepcopy


import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader


# Paths (adjust if needed)
ROOT = Path('..')
DATA_DIR = ROOT / 'data' / 'processed'
CKPT_DIR = ROOT / 'models'
CKPT_DIR.mkdir(exist_ok=True)
MARS_VOCAB_DIR = DATA_DIR / 'vocab_mars'
MARS_SHARD_DIR = DATA_DIR / 'mars_shards'
MARS_SHARD_DIR.mkdir(exist_ok=True)


# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


# Fine-tune hyperparams (safe defaults)
MAX_PREFIX_LEN = 20
EMBED_DIM = 64
FT_BATCH_SIZE = 32
FT_LR = 1e-5
FT_EPOCHS = 12
FT_NEG = 32
UNFREEZE_AFTER = 3
EARLY_STOPPING_PATIENCE = 4
FP16 = True
VAL_FRAC = 0.2


# Filenames
MARS_INTERACTIONS = DATA_DIR / 'mars_interactions.parquet'
MARS_PAIRS = DATA_DIR / 'mars_prefix_target.parquet'
MARS_SHARD_FILE = MARS_SHARD_DIR / 'mars_shard_full.pt'


print('Paths set.')

Using device: cuda
Paths set.


### SASRecSmall model (exact compatible implementation)

In [5]:
class SASRecSmall(nn.Module):
    def __init__(self, vocab_size, embed_dim=EMBED_DIM, max_len=MAX_PREFIX_LEN, num_heads=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_len = max_len


        self.item_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, embed_dim)


        encoder_layer = nn.TransformerEncoderLayer(
        d_model=embed_dim,
        nhead=num_heads,
        dim_feedforward=2048,
        dropout=dropout,
        batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)


        # output head: map embedding -> embedding (we use sampled softmax that uses final*item_emb.T)
        self.out = nn.Linear(embed_dim, embed_dim, bias=False)


    def forward(self, x):
        # x: (B, L)
        B, L = x.size()
        pos_ids = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        seq = self.item_emb(x) + self.pos_emb(pos_ids)
        seq = self.encoder(seq)
        last = seq[:, -1, :]
        logits = self.out(last)
        return logits, last


    print('SASRecSmall defined')

SASRecSmall defined


### Build or load MARS vocab and shard (prefix-target pairs)

In [6]:
if not (MARS_VOCAB_DIR / 'item2id_mars.json').exists():
    if not MARS_INTERACTIONS.exists():
        raise FileNotFoundError(f"Missing {MARS_INTERACTIONS} — run sessionization notebook first")
    df_m = pd.read_parquet(MARS_INTERACTIONS)
    mars_items = sorted(df_m['item_id'].astype(str).unique())
    item2id_mars = {it: idx+1 for idx, it in enumerate(mars_items)}
    item2id_mars['<OOV>'] = 0
    MARS_VOCAB_DIR.mkdir(exist_ok=True)
    json.dump(item2id_mars, open(MARS_VOCAB_DIR / 'item2id_mars.json','w'))
    print('Saved item2id_mars.json')
else:
    item2id_mars = json.load(open(MARS_VOCAB_DIR / 'item2id_mars.json'))

vocab_size_mars = len(item2id_mars)
print('MARS vocab size:', vocab_size_mars)

# Build shard from MARS pairs if missing
if not MARS_SHARD_FILE.exists():
    if not MARS_PAIRS.exists():
        raise FileNotFoundError(f"Missing {MARS_PAIRS} — create prefix-target pairs first")
    print('Building MARS shard from', MARS_PAIRS)
    df_pairs = pd.read_parquet(MARS_PAIRS)
    prefixes, targets, lengths = [], [], []
    for _, r in df_pairs.iterrows():
        pref = r['prefix'] if isinstance(r['prefix'], str) else ''
        pref_ids = [ item2id_mars.get(x, 0) for x in pref.split() ] if pref else []
        if len(pref_ids) > MAX_PREFIX_LEN:
            pref_ids = pref_ids[-MAX_PREFIX_LEN:]
        padded = [0]*(MAX_PREFIX_LEN - len(pref_ids)) + pref_ids
        prefixes.append(padded)
        targets.append(item2id_mars.get(str(r['target']), 0))
        lengths.append(len(pref_ids))
    pt = {
        'prefix': torch.LongTensor(prefixes),
        'target': torch.LongTensor(targets),
        'length': torch.LongTensor(lengths)
    }
    torch.save(pt, MARS_SHARD_FILE)
    print('Wrote MARS shard:', MARS_SHARD_FILE, 'pairs:', len(prefixes))
else:
    print('MARS shard exists:', MARS_SHARD_FILE)

# Load shard into memory for small dataset
mp = torch.load(MARS_SHARD_FILE)
print('Loaded MARS shard with pairs:', mp['prefix'].size(0))

MARS vocab size: 777
MARS shard exists: ..\data\processed\mars_shards\mars_shard_full.pt
Loaded MARS shard with pairs: 2380


  mp = torch.load(MARS_SHARD_FILE)


### Auto-find pretrained checkpoint, create mars_model, copy weights, freeze encoder

In [7]:
models = sorted(CKPT_DIR.glob('*.pt'), key=lambda p: p.stat().st_mtime, reverse=True)
print('Found models:', [p.name for p in models[:10]])

# prefer 'full' > 'phaseb' > 'warmup' > newest
def score_name(fn):
    n = fn.name.lower()
    if 'full' in n: return 100
    if 'phaseb' in n: return 80
    if 'warmup' in n: return 50
    return 10

models = sorted(models, key=lambda p: (score_name(p), p.stat().st_mtime), reverse=True)
PRETRAIN_CKPT = models[0] if models else None
print('Auto-selected checkpoint:', PRETRAIN_CKPT)

# instantiate mars model
mars_model = SASRecSmall(vocab_size=vocab_size_mars, embed_dim=EMBED_DIM, max_len=MAX_PREFIX_LEN).to(device)

# load checkpoint & copy weights safely
if PRETRAIN_CKPT and PRETRAIN_CKPT.exists():
    ck = torch.load(PRETRAIN_CKPT, map_location=device)
    pretrained_state = ck['model_state'] if 'model_state' in ck else ck
    mars_state = mars_model.state_dict()
    copied, partial, skipped = [], [], []
    for k,v in pretrained_state.items():
        if k in mars_state:
            try:
                if mars_state[k].shape == v.shape:
                    mars_state[k] = v
                    copied.append(k)
                else:
                    if 'item_emb.weight' in k and v.ndim==2:
                        n = min(mars_state[k].shape[0], v.shape[0])
                        mars_state[k][:n] = v[:n]
                        partial.append((k, n))
                    else:
                        skipped.append((k, 'shape_mismatch'))
            except Exception as e:
                skipped.append((k, f'error:{e}'))
        else:
            skipped.append((k, 'missing_in_target'))
    mars_model.load_state_dict(mars_state)
    print('Loaded checkpoint:', PRETRAIN_CKPT.name)
    print('Copied exact:', len(copied), 'partial:', len(partial), 'skipped:', len(skipped))
else:
    print('No pretrained checkpoint found — training from scratch')

# Freeze encoder/embeddings initially
for name, p in mars_model.named_parameters():
    if name.startswith('encoder') or name.startswith('item_emb') or name.startswith('pos_emb'):
        p.requires_grad = False
print('Froze encoder + embeddings for initial fine-tune')

Found models: ['sasrec_full_top200000_epoch0.pt_epoch0.pt', 'sasrec_phaseB_top200000_epoch1_epoch1.pt', 'sasrec_phaseB_top200000_epoch0_epoch0.pt', 'sasrec_warmup_top200000_epoch2_epoch2.pt', 'sasrec_warmup_top200000_epoch1_epoch1.pt', 'sasrec_warmup_top200000_epoch0_epoch0.pt']
Auto-selected checkpoint: ..\models\sasrec_full_top200000_epoch0.pt_epoch0.pt
Loaded checkpoint: sasrec_full_top200000_epoch0.pt_epoch0.pt
Copied exact: 25 partial: 1 skipped: 2
Froze encoder + embeddings for initial fine-tune


  ck = torch.load(PRETRAIN_CKPT, map_location=device)


### Build DataLoader and validation split

In [9]:
class MarsInMem(IterableDataset):
    def __init__(self, pt): self.pt = pt
    def __iter__(self):
        P, L, T = self.pt['prefix'], self.pt['length'], self.pt['target']
        for i in range(P.size(0)):
            yield P[i], int(L[i].item()), int(T[i].item())

mp = torch.load(MARS_SHARD_FILE)
num_pairs = mp['prefix'].size(0)
val_n = max(1, int(num_pairs * VAL_FRAC))
train_n = num_pairs - val_n

P_all = mp['prefix']
T_all = mp['target']
val_prefixes = P_all[train_n:]
val_targets = T_all[train_n:]

mars_ds = MarsInMem(mp)
# Note: DataLoader with IterableDataset does not accept shuffle=True. If you need shuffling,
# either implement shuffling inside the IterableDataset or set up a map-style Dataset.
# For MARS (small dataset), deterministic ordering is fine.
mars_loader = DataLoader(mars_ds, batch_size=FT_BATCH_SIZE, collate_fn=lambda b: (
    torch.stack([x[0] for x in b], dim=0).to(device),
    torch.tensor([x[1] for x in b], dtype=torch.long).to(device),
    torch.tensor([x[2] for x in b], dtype=torch.long).to(device)
), num_workers=0)

print('Dataloaders ready. train_pairs:', train_n, 'val_pairs:', val_n)

Dataloaders ready. train_pairs: 1904 val_pairs: 476


  mp = torch.load(MARS_SHARD_FILE)


### Evaluation helpers (Recall@K & MRR) and sampled softmax loss stub

In [10]:
import torch.nn.functional as F

def evaluate_on_validation(model, k=20):
    model.eval()
    hits = 0
    rr_sum = 0.0
    total = 0
    with torch.no_grad():
        for i in range(val_prefixes.size(0)):
            X = val_prefixes[i].unsqueeze(0).to(device)
            target = int(val_targets[i].item())
            logits, final = model(X)
            scores = torch.matmul(final, model.item_emb.weight.t())
            topk = scores.topk(k, dim=1).indices.squeeze(0).cpu().numpy()
            total += 1
            if target in topk:
                hits += 1
                rank_idx = int((topk == target).nonzero()[0]) + 1
                rr_sum += 1.0 / rank_idx
    recall_at_k = hits / total if total>0 else 0.0
    mrr = rr_sum / total if total>0 else 0.0
    return recall_at_k, mrr

# Simple sampled softmax loss (reuse existing implementation if available)
def sampled_softmax_loss(final, y, emb_weights, num_negatives=32):
    # final: (B, D), emb_weights: (V, D)
    # Build positives scores
    pos_scores = (final * emb_weights[y]).sum(dim=1)  # assumes y is LongTensor of indices
    # sample negatives uniformly
    V = emb_weights.size(0)
    batch = final.size(0)
    neg_idx = torch.randint(0, V, (batch, num_negatives), device=final.device)
    neg_w = emb_weights[neg_idx]  # (B, N, D)
    neg_scores = (neg_w * final.unsqueeze(1)).sum(dim=2)  # (B, N)
    # combine
    logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)  # (B, 1+N)
    labels = torch.zeros(batch, dtype=torch.long, device=final.device)
    return F.cross_entropy(logits, labels)

print('Helpers ready')

Helpers ready


### Fine-tune loop with freeze/unfreeze, early stopping, saving best

In [12]:
# Initialize optimizer for parameters that require grad; if none, fall back to all params
trainable_params = [p for p in mars_model.parameters() if p.requires_grad]
if len(trainable_params) == 0:
    print('Warning: no trainable params found. Unfreezing output head (out) and using all params for optimizer.')
    # ensure at least the output head is trainable
    for name, p in mars_model.named_parameters():
        if name.startswith('out'):
            p.requires_grad = True
    trainable_params = [p for p in mars_model.parameters() if p.requires_grad]

opt = torch.optim.AdamW(trainable_params, lr=FT_LR, weight_decay=1e-6)
scaler = torch.cuda.amp.GradScaler(enabled=FP16)

# Sanity: print trainable parameter names
trainable_names = [name for name, p in mars_model.named_parameters() if p.requires_grad]
print('Trainable parameter names (sample):', trainable_names[:20])

best_val = -1.0
best_state = None
no_improve = 0

for epoch in range(FT_EPOCHS):
    t0 = time.time()
    mars_model.train()
    running = 0.0
    steps = 0
    for step, (X, L, y) in enumerate(mars_loader):
        with torch.cuda.amp.autocast(enabled=FP16):
            logits, final = mars_model(X)
            loss = sampled_softmax_loss(final, y, mars_model.item_emb.weight, num_negatives=FT_NEG)

        # If loss has no grad_fn (no trainable params), raise informative error and try to recover
        if not getattr(loss, 'requires_grad', False):
            print('ERROR: computed loss does not require grad. Diagnostics:')
            print('  - number of trainable params:', len([p for p in mars_model.parameters() if p.requires_grad]))
            print('  - trainable param names:', [n for n, p in mars_model.named_parameters() if p.requires_grad])
            # Attempt to recover by unfreezing all params and reinitializing optimizer
            print('Attempting to unfreeze all parameters and reinitialize optimizer...')
            for p in mars_model.parameters():
                p.requires_grad = True
            opt = torch.optim.AdamW(mars_model.parameters(), lr=FT_LR, weight_decay=1e-6)
            scaler = torch.cuda.amp.GradScaler(enabled=FP16)
            # recompute loss once more (forward again)
            with torch.cuda.amp.autocast(enabled=FP16):
                logits, final = mars_model(X)
                loss = sampled_softmax_loss(final, y, mars_model.item_emb.weight, num_negatives=FT_NEG)
            if not getattr(loss, 'requires_grad', False):
                raise RuntimeError('Recovery failed: loss still does not require grad after unfreezing. Please check model parameter requires_grad flags.')

        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        running += float(loss.item())
        steps += 1
    train_time = time.time() - t0
    avg_loss = running / max(1, steps)
    val_recall, val_mrr = evaluate_on_validation(mars_model)
    print(f"[FT] epoch {epoch} train_loss={avg_loss:.4f} val_rec@20={val_recall:.4f} val_mrr={val_mrr:.4f} time={train_time:.1f}s")

    # save checkpoint
    ckpt = CKPT_DIR / f"mars_finetune_epoch{epoch}.pt"
    torch.save({"epoch": epoch, "model_state": mars_model.state_dict(), "opt_state": opt.state_dict()}, ckpt)
    print('Saved checkpoint:', ckpt)

    # early stopping
    if val_recall > best_val:
        best_val = val_recall
        best_state = deepcopy(mars_model.state_dict())
        no_improve = 0
        print('New best val_rec@20:', best_val)
    else:
        no_improve += 1
        print('No improvement, patience', no_improve, '/', EARLY_STOPPING_PATIENCE)
        if no_improve >= EARLY_STOPPING_PATIENCE:
            print('Early stopping triggered')
            break

    # unfreeze after UNFREEZE_AFTER
    if epoch == UNFREEZE_AFTER:
        print('Unfreezing encoder and reinitializing optimizer')
        for p in mars_model.parameters():
            p.requires_grad = True
        opt = torch.optim.AdamW(mars_model.parameters(), lr=FT_LR, weight_decay=1e-6)
        scaler = torch.cuda.amp.GradScaler(enabled=FP16)

# restore best
if best_state is not None:
    mars_model.load_state_dict(best_state)
    torch.save({'model_state': best_state}, CKPT_DIR / 'mars_finetune_best.pt')
    print('Saved best model to mars_finetune_best.pt')

# save encoder-only
encoder_state = {k:v for k,v in mars_model.state_dict().items() if k.startswith('encoder') or k.startswith('item_emb') or k.startswith('pos_emb')}
torch.save(encoder_state, CKPT_DIR / 'mars_encoder_only.pt')
print('Saved encoder-only to mars_encoder_only.pt')

  scaler = torch.cuda.amp.GradScaler(enabled=FP16)
  with torch.cuda.amp.autocast(enabled=FP16):
  scaler = torch.cuda.amp.GradScaler(enabled=FP16)
  with torch.cuda.amp.autocast(enabled=FP16):


Trainable parameter names (sample): ['out.weight']
ERROR: computed loss does not require grad. Diagnostics:
  - number of trainable params: 1
  - trainable param names: ['out.weight']
Attempting to unfreeze all parameters and reinitialize optimizer...


  rank_idx = int((topk == target).nonzero()[0]) + 1


[FT] epoch 0 train_loss=4.2654 val_rec@20=0.0189 val_mrr=0.0021 time=1.1s
Saved checkpoint: ..\models\mars_finetune_epoch0.pt
New best val_rec@20: 0.018907563025210083
[FT] epoch 1 train_loss=4.0094 val_rec@20=0.0252 val_mrr=0.0021 time=1.0s
Saved checkpoint: ..\models\mars_finetune_epoch1.pt
New best val_rec@20: 0.025210084033613446
[FT] epoch 2 train_loss=3.8746 val_rec@20=0.0189 val_mrr=0.0018 time=1.0s
Saved checkpoint: ..\models\mars_finetune_epoch2.pt
No improvement, patience 1 / 4
[FT] epoch 3 train_loss=3.7910 val_rec@20=0.0189 val_mrr=0.0020 time=1.0s
Saved checkpoint: ..\models\mars_finetune_epoch3.pt
No improvement, patience 2 / 4
Unfreezing encoder and reinitializing optimizer


  scaler = torch.cuda.amp.GradScaler(enabled=FP16)


[FT] epoch 4 train_loss=3.7791 val_rec@20=0.0189 val_mrr=0.0025 time=1.0s
Saved checkpoint: ..\models\mars_finetune_epoch4.pt
No improvement, patience 3 / 4
[FT] epoch 5 train_loss=3.6720 val_rec@20=0.0189 val_mrr=0.0031 time=1.0s
Saved checkpoint: ..\models\mars_finetune_epoch5.pt
No improvement, patience 4 / 4
Early stopping triggered
Saved best model to mars_finetune_best.pt
Saved encoder-only to mars_encoder_only.pt


### Final evaluation on test pairs (if available)

In [13]:
TEST_PAIRS = DATA_DIR / 'mars_test_pairs.parquet'
if not TEST_PAIRS.exists():
    print('No mars_test_pairs.parquet — create a test split and run this cell later')
else:
    df_test = pd.read_parquet(TEST_PAIRS)
    mars_model.eval()
    K = 20
    hits = 0
    rr_sum = 0.0
    total = 0
    with torch.no_grad():
        for _, r in df_test.iterrows():
            pref = r['prefix'] if isinstance(r['prefix'], str) else ''
            pref_ids = [ item2id_mars.get(x,0) for x in pref.split() ] if pref else []
            if len(pref_ids) > MAX_PREFIX_LEN:
                pref_ids = pref_ids[-MAX_PREFIX_LEN:]
            padded = [0]*(MAX_PREFIX_LEN-len(pref_ids)) + pref_ids
            X = torch.LongTensor([padded]).to(device)
            logits, final = mars_model(X)
            scores = torch.matmul(final, mars_model.item_emb.weight.t())
            topk = scores.topk(K, dim=1).indices.squeeze(0).cpu().numpy()
            target = item2id_mars.get(str(r['target']), 0)
            total += 1
            if target in topk:
                hits += 1
                rank = int((topk == target).nonzero()[0]) + 1
                rr_sum += 1.0 / rank
    recall_at_k = hits / total if total>0 else 0.0
    mrr = rr_sum / total if total>0 else 0.0
    print(f'Final Eval Recall@{K}: {recall_at_k:.4f}, MRR: {mrr:.4f} (n={total})')

No mars_test_pairs.parquet — create a test split and run this cell later
