# SASRec + CatBoost Ranking Pipeline (with 5% Group Hold-out Validation)

This notebook trains a **hybrid recommender** for next-item prediction using:
- **SASRec** (sequential Transformer) for sequence signal
- **Co-occurrence / Popularity / Simple category** features
- **CatBoostRanker (YetiRank)** for final re-ranking

It includes:
1. Data loading & transformation
2. Train/validation split (**5% group hold-out by `CUSTOMER_ID`**)
3. Artifact building on *train only* (no leakage)
4. SASRec training on leave-one-out (LOO) queries built from the train split
5. CatBoost listwise training
6. Validation (**Recall@3** + candidate coverage)
7. Test inference and CSV export

> **Note:** This notebook is designed to _run end-to-end_. If you only want the code files,
> see `requirements.txt` and `README.md` generated alongside this notebook.

## 0. Setup & Config
- Reproducibility seeds
- Device selection
- Core configuration

In [None]:
import os, re, ast, json, time, random, gc
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from catboost import CatBoostRanker, Pool
from sklearn.model_selection import GroupKFold

# ----------------------------
# Reproducibility
# ----------------------------
# One seed drives every RNG the notebook draws from (NumPy, stdlib, PyTorch).
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; everything else falls back to CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Device:", DEVICE)

# ----------------------------
# Config
# ----------------------------
CFG = dict(
    SASREC_EPOCHS=5,          # increase to 10-15 for better SASRec
    SASREC_D_MODEL=128,
    BASE_CAND_N=150,          # base candidate pool size before re-ranking
    SASREC_BLEND_TOP=60,      # SASRec candidates blended into base pool
    PREPOOL_GLOB=1000,        # how many global-pop items to prepool for SASRec scoring
    HARD_NEGS=30,             # LOO negatives per query for CatBoost training
    HARD_NEIGHBORS=200,       # per-item hardest neighbors from co-occurrence
    GLOB_TOP=1000,            # add global-pop items into hard pool
    CATBOOST_ITERS=300,       # bump to 1000+ if training time allows
    BATCH_SIZE=256,
    VAL_FRACTION=0.05,        # ~5% group hold-out by CUSTOMER_ID
    MIN_VAL_QUERIES=500,      # ensure at least this many validation orders if possible
)

## 1. Utilities
- JSON parsing helpers for `ORDERS` column
- String normalization

In [None]:
def canon_item(name: str) -> str:
    """Return *name* as a string with whitespace runs collapsed and ends trimmed.

    ``None`` is passed through unchanged; any other value is stringified first.
    """
    if name is None:
        return None
    collapsed = re.sub(r'\s+', ' ', str(name))
    return collapsed.strip()

def is_non_food(name: str) -> bool:
    """Return True when the lower-cased name contains one of the non-item marker substrings."""
    markers = ('order', 'memo paid', 'unpaid', 'unavailable')
    lowered = name.lower()
    return any(marker in lowered for marker in markers)

def parse_orders_cell(cell):
    """Extract the canonical item names from one raw ``ORDERS`` cell.

    The cell is parsed as JSON first and as a Python literal second. The parsed
    value may be a dict with an ``'orders'`` key, a single order dict, or a
    list of order dicts. Each order contributes the names found under
    ``'item_details'`` (falling back to ``'items'``), read from the first
    non-empty of ``'item_name'``, ``'name'`` or ``'ItemName'``.

    Parsing is best-effort: an unparseable cell yields ``[]``, and entries that
    do not have the expected shape are skipped individually instead of
    discarding the rest of the cell.

    Returns:
        list[str]: canonicalised names, with empty names and names flagged by
        ``is_non_food`` removed.
    """
    if cell is None or (isinstance(cell, float) and pd.isna(cell)):
        return []
    s = str(cell)
    # Exceptions json.loads / ast.literal_eval raise on malformed input.
    parse_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)
    try:
        obj = json.loads(s)
    except parse_errors:
        try:
            obj = ast.literal_eval(s)
        except parse_errors:
            return []

    # A top-level list is already the list of order blocks.
    blocks = obj.get('orders', obj) if isinstance(obj, dict) else obj
    if isinstance(blocks, dict):
        blocks = [blocks]
    if not isinstance(blocks, (list, tuple)):
        return []

    items = []
    for blk in blocks:
        if not isinstance(blk, dict):
            continue
        dets = blk.get('item_details', blk.get('items', []))
        if isinstance(dets, dict):
            dets = [dets]
        if not isinstance(dets, (list, tuple)):
            continue
        for d in dets:
            if not isinstance(d, dict):
                continue
            nm = d.get('item_name') or d.get('name') or d.get('ItemName')
            if nm:
                items.append(canon_item(nm))
    return [x for x in items if x and not is_non_food(x)]

def norm_str(x):
    """Return the stripped string form of *x*, or NaN when *x* is missing or blank."""
    if pd.isna(x):
        return np.nan
    stripped = str(x).strip()
    if not stripped:
        return np.nan
    return stripped

## 2. Load & Prepare Data
- Reads `order_data.csv`, `customer_data.csv`, `store_data.csv`, `test_data_question.csv`
- Parses item lists, drops short orders
- Merges context (customer/store)
- Encodes categorical IDs

In [None]:
DATA_DIR = "/kaggle/input/wwtsets"  # change if running locally

orders = pd.read_csv(f"{DATA_DIR}/order_data.csv")
cust   = pd.read_csv(f"{DATA_DIR}/customer_data.csv")
stores = pd.read_csv(f"{DATA_DIR}/store_data.csv")
testdf = pd.read_csv(f"{DATA_DIR}/test_data_question.csv")

# Parse items for training orders (parse_orders_cell already drops empty names).
orders["ITEMS"] = orders["ORDERS"].apply(parse_orders_cell)
orders["ITEM_COUNT"] = orders["ITEMS"].str.len()
# A leave-one-out query needs at least one cart item plus one target.
orders = orders[orders["ITEM_COUNT"] >= 2].copy()
orders.reset_index(drop=True, inplace=True)


def _upper_or_nan(x):
    """Upper-case a non-blank string; map blanks and non-strings to NaN."""
    if not isinstance(x, str):
        return np.nan
    s = norm_str(x)
    # norm_str returns NaN for whitespace-only input, so check before .upper().
    return s.upper() if isinstance(s, str) else np.nan


# Normalize/merge context
cust["CUSTOMER_TYPE"] = cust["CUSTOMER_TYPE"].map(norm_str)
stores["CITY"]  = stores["CITY"].map(_upper_or_nan)
stores["STATE"] = stores["STATE"].map(_upper_or_nan)

orders = orders.merge(cust[["CUSTOMER_ID", "CUSTOMER_TYPE"]], on="CUSTOMER_ID", how="left")
orders = orders.merge(stores[["STORE_NUMBER", "CITY", "STATE"]], on="STORE_NUMBER", how="left")

# Defaults for orders whose customer/store has no context row.
orders["CUSTOMER_TYPE"] = orders["CUSTOMER_TYPE"].fillna("REGISTERED")
orders["CITY"]  = orders["CITY"].fillna("0")
orders["STATE"] = orders["STATE"].fillna("0")

# Encoders: ids start at 1 so that 0 stays free for "unknown".
cust_type2id = {x: i+1 for i, x in enumerate(sorted(orders["CUSTOMER_TYPE"].unique()))}
city2id      = {x: i+1 for i, x in enumerate(sorted(orders["CITY"].unique()))}
state2id     = {x: i+1 for i, x in enumerate(sorted(orders["STATE"].unique()))}

orders["cust_type_id"] = orders["CUSTOMER_TYPE"].map(cust_type2id).fillna(0).astype(int)
orders["city_id"]      = orders["CITY"].map(city2id).fillna(0).astype(int)
orders["state_id"]     = orders["STATE"].map(state2id).fillna(0).astype(int)

## 3. Group Hold-out Split (5% by `CUSTOMER_ID`)
Artifacts and models will be trained **only** on the train split to avoid leakage.

In [None]:
# Split by customer so that no customer appears on both sides of the split.
groups = orders["CUSTOMER_ID"].values
n_splits = max(2, int(1.0 / max(1e-6, CFG["VAL_FRACTION"])))
gkf = GroupKFold(n_splits=n_splits)
# Only the first fold is used: one fold out of n_splits becomes validation.
train_idx, val_idx = next(gkf.split(orders, groups=groups))
train_df, val_df = orders.iloc[train_idx].copy(), orders.iloc[val_idx].copy()

# If val too small, fallback to random groups
if len(val_df) < CFG["MIN_VAL_QUERIES"]:
    uniq_users = orders["CUSTOMER_ID"].drop_duplicates().sample(
        frac=CFG["VAL_FRACTION"], random_state=SEED
    )
    val_mask = orders["CUSTOMER_ID"].isin(uniq_users)
    train_df, val_df = orders[~val_mask].copy(), orders[val_mask].copy()

# Build vocab from TRAIN only
_train_item_names = set()
for _basket in train_df["ITEMS"]:
    _train_item_names.update(_basket)
vocab_items = sorted(_train_item_names)
item2id = dict(zip(vocab_items, range(len(vocab_items))))
id2item = dict(enumerate(vocab_items))


def _encode_known_items(names):
    """Map item names to ids, silently dropping names absent from the train vocab."""
    return [item2id[name] for name in names if name in item2id]


train_df["ITEM_IDS"] = train_df["ITEMS"].apply(_encode_known_items)
val_df["ITEM_IDS"]   = val_df["ITEMS"].apply(_encode_known_items)

n_items = len(item2id)
print("Train orders:", len(train_df), "| Val orders:", len(val_df), "| #Items:", n_items)

## 4. Train-only Artifacts
- Co-occurrence similarity
- Popularity tables (store/channel/occasion/city/state/global)
- User frequency stats
- Simple item categories + global category popularity

In [None]:
def build_cooccurrence_fast(train_item_ids_series, n_items):
    """Build item degrees and cosine-normalised co-occurrence similarities.

    Each order is reduced to its set of distinct item ids, giving a binary
    order-by-item incidence matrix ``B``. ``C = B.T @ B`` then holds, for every
    item pair, the number of orders containing both.

    Args:
        train_item_ids_series: iterable of per-order item-id sequences.
        n_items: size of the item vocabulary (number of columns of ``B``).

    Returns:
        tuple: ``(deg, sim)`` where ``deg`` is an int64 array with the number of
        orders containing each item, and ``sim`` maps ``(i, j)`` with ``i != j``
        to ``C[i, j] / sqrt(deg[i] * deg[j])`` for pairs that co-occur.
    """
    indptr, indices = [0], []
    for ids in train_item_ids_series:
        uniq = np.unique(np.asarray(ids, dtype=np.int64))
        indices.extend(uniq.tolist())
        indptr.append(len(indices))

    # int64 incidence values: a uint8 matrix makes the product wrap around for
    # any item or pair seen in more than 255 orders.
    B = csr_matrix(
        (
            np.ones(len(indices), dtype=np.int64),
            np.asarray(indices, dtype=np.int64),
            np.asarray(indptr, dtype=np.int64),
        ),
        shape=(len(indptr) - 1, n_items),
    )
    C = (B.T @ B).tocsr()
    deg = np.asarray(C.diagonal()).astype(np.int64)

    C_coo = C.tocoo()
    rows, cols = C_coo.row, C_coo.col
    # Normalise only the stored entries; a dense n_items x n_items denominator
    # is never needed.
    denom = np.sqrt(deg[rows].astype(np.float64) * deg[cols].astype(np.float64))
    denom[denom == 0] = 1.0
    vals = C_coo.data.astype(np.float64) / denom

    off_diag = rows != cols
    sim = {
        (int(i), int(j)): float(v)
        for i, j, v in zip(rows[off_diag], cols[off_diag], vals[off_diag])
        if v > 0.0
    }
    return deg, sim

def build_popularity_tables_fast(train_df, item2id):
    """Compute per-context and global item popularity as probabilities.

    For each context column (store, channel, occasion, city, state) the result
    maps ``(context_value, item_id)`` to that item's share of all item
    occurrences seen under that context value. ``pop_glob`` maps ``item_id`` to
    its share of all item occurrences. Items missing from ``item2id`` are ignored.
    """
    context_columns = {
        "pop_store": "STORE_NUMBER",
        "pop_chan": "ORDER_CHANNEL_NAME",
        "pop_occ": "ORDER_OCCASION_NAME",
        "pop_city": "CITY",
        "pop_state": "STATE",
    }

    # One row per (order, item) occurrence, restricted to known items.
    long_df = train_df[[*context_columns.values(), "ITEMS"]].explode("ITEMS")
    long_df["item_id"] = long_df["ITEMS"].map(item2id).astype("Int64")
    long_df = long_df.dropna(subset=["item_id"]).copy()
    long_df["item_id"] = long_df["item_id"].astype(int)

    tables = {}
    for table_name, column in context_columns.items():
        pair_counts = long_df.groupby([column, "item_id"]).size()
        context_totals = pair_counts.groupby(level=0).transform("sum")
        tables[table_name] = (pair_counts / context_totals).to_dict()

    item_counts = long_df["item_id"].value_counts()
    tables["pop_glob"] = (item_counts / item_counts.sum()).to_dict()
    return tables

def build_user_features_fast(train_df, item2id):
    """Collect per-customer purchase statistics from the given orders.

    Returns a dict with:
      * ``user_item_freq``: customer -> Counter of item_id -> occurrence count
      * ``user_total_orders``: customer -> number of rows (orders) in ``train_df``
      * ``user_unique_items``: customer -> number of distinct known items bought
    Items missing from ``item2id`` are ignored in the item-level statistics.
    """
    purchases = train_df[["CUSTOMER_ID", "ITEMS"]].explode("ITEMS")
    purchases["item_id"] = purchases["ITEMS"].map(item2id).astype("Int64")
    purchases = purchases.dropna(subset=["item_id"]).copy()
    purchases["item_id"] = purchases["item_id"].astype(int)

    pair_counts = purchases.groupby(["CUSTOMER_ID", "item_id"]).size()
    order_counts = train_df.groupby("CUSTOMER_ID")["ITEMS"].size()
    distinct_counts = purchases.groupby("CUSTOMER_ID")["item_id"].nunique()

    user_item_freq = defaultdict(Counter)
    for (customer, item), count in pair_counts.items():
        user_item_freq[customer][int(item)] = int(count)

    return {
        "user_item_freq": user_item_freq,
        "user_total_orders": {customer: int(n) for customer, n in order_counts.items()},
        "user_unique_items": {customer: int(n) for customer, n in distinct_counts.items()},
    }

def build_item_categories(item2id):
    """Assign every item a coarse category derived from the first token of its name.

    The name is upper-cased and split on runs of non-word characters; an empty
    first token falls back to ``"UNCAT"``. Category ids start at 1 in sorted
    label order, leaving 0 for "unknown".

    Returns:
        tuple: ``(item_cat_id, cat2id)`` mapping item id -> category id and
        category label -> category id.
    """
    def first_token(name: str):
        head = re.split(r"[^\w]+", str(name).strip().upper())[0]
        return head or "UNCAT"

    labels = {iid: first_token(name) for name, iid in item2id.items()}
    cat2id = {
        label: rank
        for rank, label in enumerate(sorted(set(labels.values())), start=1)
    }  # 0 = unknown
    item_cat_id = {iid: cat2id.get(label, 0) for iid, label in labels.items()}
    return item_cat_id, cat2id

def build_category_glob_pop(train_df, item2id, item_cat_id):
    """Return each category's share of all known-item occurrences in ``train_df``.

    Items missing from ``item2id`` are ignored; items without an entry in
    ``item_cat_id`` are counted under category 0.
    """
    occurrences = train_df[["ITEMS"]].explode("ITEMS")
    occurrences["item_id"] = occurrences["ITEMS"].map(item2id).astype("Int64")
    known = occurrences.dropna(subset=["item_id"]).copy()
    known["item_id"] = known["item_id"].astype(int)
    known["cat_id"] = known["item_id"].map(item_cat_id).fillna(0).astype(int)

    per_category = known["cat_id"].value_counts()
    shares = per_category / per_category.sum()
    return shares.to_dict()

# Name-derived categories only need the train vocabulary.
item_cat_id, cat2id = build_item_categories(item2id)

# Every statistic below is computed from train_df alone.
item_deg_tr, sim_tr = build_cooccurrence_fast(
    train_item_ids_series=train_df["ITEM_IDS"], n_items=n_items
)
pops_tr = build_popularity_tables_fast(train_df=train_df, item2id=item2id)
userf_tr = build_user_features_fast(train_df=train_df, item2id=item2id)
cat_glob_pop_tr = build_category_glob_pop(
    train_df=train_df, item2id=item2id, item_cat_id=item_cat_id
)

## 5. LOO with Hard Negatives (Train)
We construct leave-one-out queries with hard negatives from:
- co-occurrence neighbors
- global popularity
- random fill if needed

In [None]:
def _build_sim_by_right(sim: dict, top_k: int = None):
    by_right = defaultdict(list)
    for (i, j), s in sim.items():
        by_right[j].append((i, s))
    for j, lst in by_right.items():
        lst.sort(key=lambda t: -t[1])
        if top_k is not None:
            by_right[j] = lst[:top_k]
    return by_right

def make_loo_rows_hard(train_df, item2id, n_items, sim, pops,
                       k_negs=30, hard_pool_top=200, glob_top=1000, preindex_top_k=2000):
    """Build one leave-one-out query per order, with sampled negatives.

    For every order with at least two known items, one item is drawn at random
    as the positive; the remaining items (all copies of the positive removed)
    form the cart. Negatives are sampled from a pool made of co-occurrence
    neighbours of the cart items plus the globally most popular items, topped
    up with uniformly random items when the pool is smaller than ``k_negs``.

    Args:
        train_df: orders; reads ``ITEMS``, ``CUSTOMER_ID`` and, when present,
            the store/channel/occasion and encoded context id columns.
        item2id: item name -> item id.
        n_items: vocabulary size used for the random fill.
        sim: ``(i, j)`` -> similarity.
        pops: popularity tables; only ``pops["pop_glob"]`` is read here.
        k_negs: number of negatives sampled per query.
        hard_pool_top: neighbours taken per cart item.
        glob_top: number of globally popular items considered for the pool.
        preindex_top_k: neighbours kept per item when indexing ``sim``.

    Returns:
        pandas.DataFrame with one row per query.

    Note: draws from both ``np.random`` and ``random``, so the result depends
    on the state of both generators.
    """
    rows = []
    all_ids = np.arange(n_items, dtype=int)
    sim_by_right = _build_sim_by_right(sim, top_k=preindex_top_k)
    top_glob = sorted(pops["pop_glob"].items(), key=lambda kv: -kv[1])
    top_glob_ids = [i for i, _ in top_glob[:glob_top]]

    for _, row in train_df.iterrows():
        ids = np.fromiter((item2id[x] for x in row["ITEMS"] if x in item2id), dtype=int)
        if ids.size < 2: continue
        # Duplicated items are more likely to be drawn as the positive.
        pos = int(np.random.choice(ids))
        # NOTE(review): if every known item equals `pos`, the cart is empty.
        cart = [int(i) for i in ids if i != pos]
        in_cart = set(cart)

        # Hard candidates: nearest co-occurrence neighbours of each cart item.
        hard_pool = set()
        for j in cart:
            neigh = sim_by_right.get(j)
            if not neigh: continue
            for i, _s in neigh[:hard_pool_top]:
                if i not in in_cart and i != pos:
                    hard_pool.add(i)

        # Popular items, stopping once the pool holds 5000 ids.
        for i in top_glob_ids:
            if i not in in_cart and i != pos:
                hard_pool.add(i)
                if len(hard_pool) >= 5000: break

        # Random fill from all items outside the cart when the pool is too small.
        if len(hard_pool) < k_negs:
            mask = np.ones(n_items, dtype=bool)
            if in_cart: mask[list(in_cart)] = False
            mask[pos] = False
            if mask.any():
                extra = np.random.choice(all_ids[mask], size=min(k_negs, int(mask.sum())), replace=False)
                hard_pool.update(extra.tolist())

        hard_pool = list(hard_pool)
        # Fewer than k_negs negatives are returned when the pool is still short.
        negs = random.sample(hard_pool, k_negs) if len(hard_pool) > k_negs else hard_pool

        rows.append({
            "user": row["CUSTOMER_ID"],
            "store": row.get("STORE_NUMBER"),
            "channel": row.get("ORDER_CHANNEL_NAME"),
            "occasion": row.get("ORDER_OCCASION_NAME"),
            "cart_ids": cart,
            "pos_id": pos,
            "neg_ids": negs,
            "cart_len": len(cart),
            "cust_type_id": row.get("cust_type_id", 0),
            "city_id": row.get("city_id", 0),
            "state_id": row.get("state_id", 0),
        })
    return pd.DataFrame(rows)

# Negative-sampling knobs come straight from CFG.
_hard_neg_params = {
    "k_negs": CFG["HARD_NEGS"],
    "hard_pool_top": CFG["HARD_NEIGHBORS"],
    "glob_top": CFG["GLOB_TOP"],
}
loo_tr = make_loo_rows_hard(train_df, item2id, n_items, sim_tr, pops_tr, **_hard_neg_params)
print("LOO train queries:", len(loo_tr))

## 6. SASRec (Train on LOO)
We train a lightweight masked-item-prediction variant of SASRec (one randomly chosen cart item is masked and the model predicts it) to obtain a sequence signal.

In [None]:
class LooSasrecDataset(Dataset):
    """Masked-item training samples built from leave-one-out query rows.

    Each sample is the row's cart (with the held-out ``pos_id`` removed),
    truncated to ``max_len``, with one randomly chosen position replaced by
    ``mask_id`` and right-padded with ``pad_id``. The label vector is -100
    everywhere except at the masked position, where it holds the original item.
    """

    def __init__(self, loo_df, max_len=20, pad_id=0, mask_id=1):
        self.loo_df = loo_df
        self.max_len = max_len
        self.pad_id = pad_id
        self.mask_id = mask_id

    def __len__(self):
        return len(self.loo_df)

    @staticmethod
    def _as_list(raw_cart):
        """Coerce a stored cart value to a list; unusable values become []."""
        if isinstance(raw_cart, str):
            try:
                return ast.literal_eval(raw_cart)
            except Exception:
                return []
        if isinstance(raw_cart, (np.ndarray, set, tuple)):
            return list(raw_cart)
        if isinstance(raw_cart, list):
            return raw_cart
        return []

    def __getitem__(self, idx):
        record = self.loo_df.iloc[idx]
        held_out = record.get("pos_id", -100)
        history = [item for item in self._as_list(record["cart_ids"]) if item != held_out]

        tokens = history[:self.max_len]
        labels = [-100] * self.max_len
        if tokens:
            # Mask a single position; only that position contributes to the loss.
            slot = np.random.choice(len(tokens))
            labels[slot] = tokens[slot]
            tokens[slot] = self.mask_id
        tokens = tokens + [self.pad_id] * (self.max_len - len(tokens))

        context = [
            torch.tensor(record.get(column, 0), dtype=torch.long)
            for column in ("cust_type_id", "city_id", "state_id")
        ]
        return (
            torch.tensor(tokens, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
            context[0],
            context[1],
            context[2],
        )

class SASRecPlus(nn.Module):
    """Transformer encoder over a cart sequence, conditioned on order context.

    Item and position embeddings are summed per token; the customer-type, city
    and state embeddings are added to every position. The encoder output is
    projected to one logit per vocabulary entry at every position.
    """

    def __init__(self, vocab_size, max_len=20, d_model=128, n_heads=4, n_layers=2,
                 n_cust=1, n_city=1, n_state=1, dropout=0.2, ffn_mult=2):
        super().__init__()
        # Token-level embeddings (index 0 of item_emb is the padding row).
        self.item_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_emb  = nn.Embedding(max_len, d_model)
        # Sequence-level context embeddings.
        self.cust_emb = nn.Embedding(n_cust, d_model)
        self.city_emb = nn.Embedding(n_city, d_model)
        self.state_emb= nn.Embedding(n_state, d_model)

        layer_kwargs = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * ffn_mult,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        enc_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, seqs, cust_ids, city_ids, state_ids):
        """Return logits of shape (batch, seq_len, vocab_size) for (batch, seq_len) ids."""
        seq_len = seqs.size(1)
        positions = torch.arange(seq_len, device=seqs.device).unsqueeze(0).expand_as(seqs)

        hidden = self.item_emb(seqs) + self.pos_emb(positions)
        # Context vectors are broadcast over the sequence dimension.
        hidden = hidden + self.cust_emb(cust_ids).unsqueeze(1)
        hidden = hidden + self.city_emb(city_ids).unsqueeze(1)
        hidden = hidden + self.state_emb(state_ids).unsqueeze(1)

        hidden = self.norm(self.dropout(hidden))
        encoded = self.transformer(hidden)
        return self.fc(encoded)

class LabelSmoothingCE(nn.Module):
    """Cross-entropy with uniform label smoothing over non-ignored targets.

    Args:
        eps: weight of the uniform-distribution term.
        ignore_index: target value marking positions excluded from the loss.
    """

    def __init__(self, eps=0.05, ignore_index=-100):
        super().__init__()
        self.eps = eps
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        """Return the smoothed loss for (N, C) logits and (N,) targets.

        When every target is ignored the result is a zero that is still
        attached to ``logits``, so ``loss.backward()`` remains valid.
        """
        valid = target != self.ignore_index
        if not bool(valid.any()):
            # A detached constant would make backward() raise in the train loop.
            return logits.sum() * 0.0
        log_probs = F.log_softmax(logits[valid], dim=-1)
        nll = F.nll_loss(log_probs, target[valid], reduction='mean')
        smooth = -log_probs.mean(dim=-1).mean()
        return (1 - self.eps) * nll + self.eps * smooth

def train_sasrec_plus(model, loader, device=DEVICE, epochs=5, lr=1e-3):
    """Train ``model`` on masked-item batches and return it in eval mode.

    Args:
        model: module called as ``model(seqs, cust, city, state)`` and exposing
            ``model.fc.out_features`` (the vocabulary size).
        loader: yields ``(seqs, labels, cust, city, state)`` batches, with
            ``-100`` in ``labels`` at positions excluded from the loss.
        device: device the model and batches are moved to.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate; halved every 3 epochs by the scheduler.

    Returns:
        The trained model, switched to eval mode so that later scoring calls
        are not perturbed by dropout.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.5)
    crit = LabelSmoothingCE(eps=0.05, ignore_index=-100)
    model.to(device)
    for ep in range(1, epochs+1):
        model.train()
        total = 0.0
        for seqs, labels, cust, city, state in loader:
            seqs, labels = seqs.to(device), labels.to(device)
            cust, city, state = cust.to(device), city.to(device), state.to(device)
            # Flatten (batch, seq, vocab) so every position is one classification row.
            logits = model(seqs, cust, city, state).view(-1, model.fc.out_features)
            loss = crit(logits, labels.view(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += float(loss.item())
        sched.step()
        # `total` is the sum of the per-batch losses for the epoch.
        print(f"[SASRec] Epoch {ep}/{epochs} loss={total:.4f}")
    # Leaving the model in train mode would keep dropout active at inference.
    model.eval()
    return model

@torch.no_grad()
def sasrec_score_for_candidates(model, cart_ids, cust_type_id, city_id, state_id, candidate_ids,
                                max_len=20, device=DEVICE, mask_id=1):
    """Score candidate items for a cart with the sequence model.

    The cart (its last ``max_len - 1`` items) is followed by one ``mask_id``
    token and right-padded with 0; the logits at the mask position are the
    scores. This mirrors the training samples, where the model predicts the
    item hidden behind the mask token. Reading the last position instead would
    hit a padding slot for every cart shorter than ``max_len``.

    Args:
        model: module called as ``model(seq, cust, city, state)`` returning
            ``(1, max_len, vocab)`` logits. It is switched to eval mode.
        cart_ids: item ids in the cart (sequence, set, array or single scalar).
        cust_type_id, city_id, state_id: encoded context ids. Ids outside the
            model's embedding tables are replaced by 0 (the "unknown" id)
            instead of raising an index error.
        candidate_ids: item ids to score.
        max_len: sequence length fed to the model (expected >= 1).
        device: device the input tensors are created on.
        mask_id: token id used for the prediction slot.

    Returns:
        numpy.ndarray with one score per entry of ``candidate_ids``.

    NOTE(review): ``mask_id`` and the padding id 0 share the id space of real
    items when item ids start at 0 - confirm against the vocabulary builder.
    """
    if not isinstance(cart_ids, (list, tuple, set, np.ndarray)):
        cart_ids = [] if pd.isna(cart_ids) else [int(cart_ids)]

    # Dropout must be off, otherwise repeated calls give different scores.
    if model.training:
        model.eval()

    def _known(idx, emb_name):
        """Return idx when it fits the named embedding table, else 0."""
        idx = int(idx)
        emb = getattr(model, emb_name, None)
        if emb is None or not hasattr(emb, "num_embeddings"):
            return idx
        return idx if 0 <= idx < emb.num_embeddings else 0

    # Reserve one slot for the mask token (lst[-0:] would keep the whole list).
    history = [int(i) for i in cart_ids][-(max_len - 1):] if max_len > 1 else []
    query_pos = len(history)
    tokens = history + [mask_id] + [0] * (max_len - query_pos - 1)

    seq  = torch.tensor([tokens], dtype=torch.long, device=device)
    cust = torch.tensor([_known(cust_type_id, "cust_emb")], dtype=torch.long, device=device)
    city = torch.tensor([_known(city_id, "city_emb")], dtype=torch.long, device=device)
    state= torch.tensor([_known(state_id, "state_emb")], dtype=torch.long, device=device)

    logits = model(seq, cust, city, state)[0, query_pos, :]
    idx = torch.tensor([int(c) for c in candidate_ids], dtype=torch.long, device=device)
    return logits[idx].detach().cpu().numpy()

# Train SASRec
sas_ds = LooSasrecDataset(loo_tr, max_len=20)
sas_dl = DataLoader(sas_ds, batch_size=CFG["BATCH_SIZE"], shuffle=True)

# Context embedding tables are sized from the largest id seen in the train split.
_sasrec_kwargs = dict(
    vocab_size=n_items,
    max_len=20,
    d_model=CFG["SASREC_D_MODEL"],
    n_heads=4,
    n_layers=2,
    n_cust=int(train_df["cust_type_id"].max()) + 1,
    n_city=int(train_df["city_id"].max()) + 1,
    n_state=int(train_df["state_id"].max()) + 1,
    dropout=0.2,
    ffn_mult=2,
)
sasrec = SASRecPlus(**_sasrec_kwargs)
sasrec = train_sasrec_plus(sasrec, sas_dl, device=DEVICE, epochs=CFG["SASREC_EPOCHS"], lr=1e-3)

## 7. Candidate Generation + Feature Builder
- Base scoring: co-occurrence + user repeat + contextual popularity + global popularity
- SASRec blending: re-score a pre-pool and merge top-N
- Feature builder for CatBoost

In [None]:
# Number of distinct train orders per customer.
orders_per_user_tr = train_df.groupby('CUSTOMER_ID')['ORDER_ID'].nunique().to_dict()


def segment(user_id):
    """Return 'S1' for customers with at least 3 train orders, otherwise 'S2'."""
    n_orders = orders_per_user_tr.get(user_id, 0)
    if n_orders >= 3:
        return 'S1'
    return 'S2'

def generate_candidates_for_cart_blend(sim, pops, userf, cart_ids, user, store, channel, occasion,
                                       n_items, city=None, state=None, customer_type=None,
                                       N=150, sasrec_model=None, cust_type_id=0, city_id=0, state_id=0,
                                       blend_top=60, prepool_extra_glob=1000):
    """Generate re-ranking candidates for a cart.

    A heuristic score (co-occurrence with the cart, the customer's repeat rate
    for 'S1' customers, contextual popularity and global popularity) selects
    the ``N`` best non-cart items. When ``sasrec_model`` is given, the base
    pool plus the ``prepool_extra_glob`` most popular items are scored by the
    sequence model and its ``blend_top`` best items are appended to the pool.

    Args:
        sim: ``(i, j)`` -> similarity.
        pops: popularity tables (``pop_store``, ``pop_chan``, ``pop_occ``,
            ``pop_city``, ``pop_state``, ``pop_glob``).
        userf: user statistics (``user_item_freq``, ``user_total_orders``).
        cart_ids: item ids already in the cart.
        user, store, channel, occasion, city, state: query context; a ``None``
            context contributes no popularity signal.
        n_items: vocabulary size.
        customer_type: accepted for call-site symmetry; not used.
        N: size of the heuristic base pool.
        sasrec_model: optional sequence model used for blending.
        cust_type_id, city_id, state_id: encoded context for the sequence model.
        blend_top: number of sequence-model picks merged into the pool.
        prepool_extra_glob: number of popular items added before re-scoring.

    Returns:
        numpy.ndarray of unique item ids, base pool first. It never contains a
        cart item and holds at most ``N + blend_top`` ids. Truncating the
        merged list back to ``N`` would drop every blended id, because the
        base pool already fills the first ``N`` positions.
    """
    seg = segment(user)
    cart = list(cart_ids)
    in_cart = set(cart)
    scores = np.zeros(n_items, dtype=np.float32)

    cnts, tot = None, 0
    if seg == 'S1' and user in userf["user_item_freq"]:
        cnts = userf["user_item_freq"][user]
        tot = userf["user_total_orders"].get(user, 0)

    ctx_w = (0.15 if seg == 'S1' else 0.20)
    glob_w = (0.05 if seg == 'S1' else 0.15)
    popg = pops["pop_glob"]
    context_tables = [
        (pops[table], value)
        for table, value in (
            ("pop_store", store), ("pop_chan", channel), ("pop_occ", occasion),
            ("pop_city", city), ("pop_state", state),
        )
        if value is not None
    ]

    for i in range(n_items):
        if i in in_cart:
            continue
        # co-occurrence + user repeat
        scores[i] += sum(sim.get((i, j), 0.0) for j in cart)
        if cnts is not None and tot > 0:
            scores[i] += 0.25 * (cnts.get(i, 0) / tot)
        # context pops
        s = 0.0
        for table, value in context_tables:
            s += table.get((value, i), 0.0)
        scores[i] += ctx_w * s
        # global pop
        scores[i] += glob_w * popg.get(i, 0.0)

    if cart:
        scores[cart] = -1e9
    # With a small vocabulary the top-N slice can reach the cart items; drop them.
    base = [int(i) for i in np.argsort(-scores)[:N] if int(i) not in in_cart]

    # SASRec blend
    if sasrec_model is not None:
        top_glob = sorted(popg.items(), key=lambda kv: -kv[1])[:prepool_extra_glob]
        pre_pool = [
            int(i) for i in {*base, *(i for i, _ in top_glob)} if i not in in_cart
        ]
        if pre_pool:
            s_scores = sasrec_score_for_candidates(
                sasrec_model, cart, cust_type_id, city_id, state_id, pre_pool,
                max_len=20, device=DEVICE
            )
            order = np.argsort(-s_scores)
            blend_ids = [pre_pool[k] for k in order[:blend_top]]
            merged = list(dict.fromkeys(base + blend_ids))
            return np.array(merged, dtype=int)

    return np.array(base, dtype=int)

def make_features_for_pairs(sim, pops, userf, cart_ids, user, store, ch, oc, cand_ids, seg,
                            city=None, state=None, customer_type=None, sasrec_scores=None,
                            item_cat_id=None, cat_glob_pop=None):
    """Build one feature row per candidate item for the ranker.

    Columns, in order: ``cooc`` (summed similarity to the cart), ``u_freq``
    (customer repeat rate, only for segment 'S1'), the five contextual
    popularities, ``p_glob``, ``sasrec_score``, ``same_cat_in_cart`` and
    ``cat_glob_pop``. A context passed as ``None`` yields 0.0 for its column.
    ``customer_type`` is accepted but not used.

    Returns:
        pandas.DataFrame with ``len(cand_ids)`` rows.
    """
    cart_categories = [item_cat_id.get(j, 0) for j in cart_ids] if item_cat_id else []
    context_specs = (
        ("p_store", "pop_store", store),
        ("p_chan", "pop_chan", ch),
        ("p_occ", "pop_occ", oc),
        ("p_city", "pop_city", city),
        ("p_state", "pop_state", state),
    )

    records = []
    for slot, cand in enumerate(cand_ids):
        record = {"cooc": sum(sim.get((cand, j), 0.0) for j in cart_ids)}

        if seg == 'S1':
            seen = userf["user_item_freq"].get(user, {})
            n_orders = userf["user_total_orders"].get(user, 0)
            record["u_freq"] = (seen.get(cand, 0) / n_orders) if n_orders else 0.0
        else:
            record["u_freq"] = 0.0

        for column, table_name, ctx_value in context_specs:
            table = pops[table_name]
            record[column] = 0.0 if ctx_value is None else table.get((ctx_value, cand), 0.0)
        record["p_glob"] = pops["pop_glob"].get(cand, 0.0)

        cand_cat = item_cat_id.get(cand, 0) if item_cat_id else 0
        record["sasrec_score"] = float(sasrec_scores[slot]) if sasrec_scores is not None else 0.0
        # Category 0 means "unknown" and never counts as a match.
        record["same_cat_in_cart"] = int(cand_cat in cart_categories) if cand_cat else 0
        record["cat_glob_pop"] = cat_glob_pop.get(cand_cat, 0.0) if cat_glob_pop else 0.0

        records.append(record)
    return pd.DataFrame(records)

## 8. Train CatBoost (YetiRank)
We construct listwise training data using LOO queries.

In [None]:
# LOO rows only carry the encoded city/state ids, while the popularity tables
# are keyed by the raw values; invert the encoders to recover them. Without
# this, p_city / p_state would be constant 0 during training but populated at
# validation and test time.
id2city = {v: k for k, v in city2id.items()}
id2state = {v: k for k, v in state2id.items()}

X_rows, y_rows, group_rows = [], [], []
feature_names_ref = None
qid = 0
for _, r in loo_tr.iterrows():
    user = r["user"]; store, ch, oc = r.get("store"), r.get("channel"), r.get("occasion")
    cart_list = list(r["cart_ids"])
    cart_ids = set(cart_list); seg = segment(user)
    cand_ids = [r["pos_id"]] + list(r["neg_ids"])
    # Id 0 ("unknown") has no raw value, so the lookup yields None.
    city = id2city.get(int(r.get("city_id", 0)))
    state = id2state.get(int(r.get("state_id", 0)))
    # The sequence model gets the cart in its stored order, not set order.
    s_scores = sasrec_score_for_candidates(
        sasrec, cart_list, r.get("cust_type_id",0), r.get("city_id",0), r.get("state_id",0),
        cand_ids, max_len=20, device=DEVICE
    )
    feats = make_features_for_pairs(
        sim_tr, pops_tr, userf_tr, cart_ids, user, store, ch, oc, cand_ids, seg,
        city=city, state=state, customer_type=None, sasrec_scores=s_scores,
        item_cat_id=item_cat_id, cat_glob_pop=cat_glob_pop_tr
    )
    if feats is None or feats.empty: continue
    if feature_names_ref is None:
        feature_names_ref = feats.columns.tolist()
    else:
        for c in feature_names_ref:
            if c not in feats.columns: feats[c] = 0.0
        feats = feats[feature_names_ref]
    # The positive is always the first candidate of its group.
    y = np.zeros(len(cand_ids), dtype=np.int32); y[0] = 1
    X_rows.append(feats.values); y_rows.append(y); group_rows.append(np.full(len(cand_ids), qid, dtype=np.int64))
    qid += 1

if not X_rows:
    raise ValueError("No LOO training queries produced features; cannot train the ranker.")

X_tr = np.vstack(X_rows).astype(np.float32)
y_tr = np.concatenate(y_rows)
g_tr = np.concatenate(group_rows)
# A stable sort keeps each group's rows contiguous and in their original order.
order_tr = np.argsort(g_tr, kind="stable")

train_pool = Pool(
    data=X_tr[order_tr],
    label=y_tr[order_tr],
    group_id=g_tr[order_tr],
    feature_names=feature_names_ref
)

cat_model = CatBoostRanker(
    loss_function="YetiRank",
    eval_metric="NDCG:top=3",
    iterations=CFG["CATBOOST_ITERS"],
    learning_rate=0.1,
    depth=6,
    verbose=100,
    task_type="GPU" if torch.cuda.is_available() else "CPU"
)
t0 = time.time()
cat_model.fit(train_pool)
print(f"✅ CatBoost trained in {time.time()-t0:.1f}s on TRAIN LOO")

## 9. Validation (Recall@3)
We evaluate on the hold-out validation set (**built with train-only artifacts**).

In [None]:
def eval_recall_at_3(val_df):
    """Measure candidate coverage and Recall@3 on hold-out orders.

    For every order with at least two known items, the last item is the target
    and the preceding items form the cart. Candidates come from
    ``generate_candidates_for_cart_blend`` and are ranked by the CatBoost model.

    The target is never injected into the candidate list. Doing so would make
    coverage trivially 1.0 and let the ranker see an item the candidate
    generator failed to retrieve, inflating Recall@3. A target missing from
    the candidates is counted as a miss.

    Args:
        val_df: orders with an ``ITEM_IDS`` list column plus the context columns.

    Returns:
        tuple: ``(recall_at_3, candidate_coverage, n_queries)``.
    """
    hit = total = 0
    in_cand = 0
    for _, row in val_df.iterrows():
        items = row["ITEM_IDS"]
        if not isinstance(items, list) or len(items) < 2:
            continue
        user   = row["CUSTOMER_ID"]
        store  = row.get("STORE_NUMBER")
        ch     = row.get("ORDER_CHANNEL_NAME")
        oc     = row.get("ORDER_OCCASION_NAME")
        city   = row.get("CITY")
        state  = row.get("STATE")
        cart_list = list(items[:-1])
        cart_ids = set(cart_list); true_item = items[-1]

        cand_ids = generate_candidates_for_cart_blend(
            sim_tr, pops_tr, userf_tr, cart_ids, user, store, ch, oc,
            n_items=len(item2id),
            city=city, state=state, customer_type=row.get("CUSTOMER_TYPE"),
            N=CFG["BASE_CAND_N"], sasrec_model=sasrec,
            cust_type_id=row.get("cust_type_id",0), city_id=row.get("city_id",0), state_id=row.get("state_id",0),
            blend_top=CFG["SASREC_BLEND_TOP"], prepool_extra_glob=CFG["PREPOOL_GLOB"]
        ).tolist()

        # Exclude cart items
        cand_ids = [cid for cid in cand_ids if cid not in cart_ids]

        total += 1
        # A target that was not retrieved (or that duplicates a cart item)
        # cannot be ranked, so it is a miss; no need to score this query.
        if true_item not in cand_ids:
            continue
        in_cand += 1

        seg = segment(user)
        # The sequence model gets the cart in order, not in set order.
        s_scores = sasrec_score_for_candidates(
            sasrec, cart_list,
            row.get("cust_type_id",0), row.get("city_id",0), row.get("state_id",0),
            cand_ids, max_len=20, device=DEVICE
        )
        feats = make_features_for_pairs(
            sim_tr, pops_tr, userf_tr, cart_ids, user, store, ch, oc, cand_ids, seg,
            city=city, state=state, customer_type=row.get("CUSTOMER_TYPE"),
            sasrec_scores=s_scores,
            item_cat_id=item_cat_id, cat_glob_pop=cat_glob_pop_tr
        )
        # Present the columns in the order the ranker was trained on.
        for f in cat_model.feature_names_:
            if f not in feats.columns: feats[f] = 0.0
        feats = feats[cat_model.feature_names_]

        scores = cat_model.predict(feats)
        ranked = [cand_ids[j] for j in np.argsort(-scores)]
        if true_item in ranked[:3]:
            hit += 1

    recall = (hit/total) if total>0 else 0.0
    coverage = (in_cand/total) if total>0 else 0.0
    print(f"🔎 Validation — queries: {total}, candidate coverage: {coverage:.4f}, Recall@3: {recall:.4f}")
    return recall, coverage, total

# Evaluate once on the hold-out split; the three values are printed again in the validation summary.
val_recall3, val_cov, val_n = eval_recall_at_3(val_df)

## 10. Save Artifacts
We save:
- CatBoost model and feature names
- SASRec weights and config
- Train-only artifacts (popularity tables, user features, co-occurrence, category maps)
- `item2id` and `id2item`

In [None]:
import joblib, json, os
os.makedirs("/mnt/data/artifacts", exist_ok=True)

cat_model.save_model("/mnt/data/artifacts/catboost_yetirank.cbm")

# Save CPU copies of the weights so the checkpoint loads on machines without a GPU.
torch.save(
    {name: tensor.detach().cpu() for name, tensor in sasrec.state_dict().items()},
    "/mnt/data/artifacts/sasrec_train.pt"
)
# Read the architecture from the trained model itself so that the config can
# never drift from the saved weights.
_enc_layer = sasrec.transformer.layers[0]
_d_model = sasrec.item_emb.embedding_dim
sasrec_cfg = {
    "vocab_size": sasrec.item_emb.num_embeddings,
    "max_len": sasrec.pos_emb.num_embeddings,
    "d_model": _d_model,
    "n_heads": _enc_layer.self_attn.num_heads,
    "n_layers": len(sasrec.transformer.layers),
    "n_cust": sasrec.cust_emb.num_embeddings,
    "n_city": sasrec.city_emb.num_embeddings,
    "n_state": sasrec.state_emb.num_embeddings,
    "dropout": sasrec.dropout.p,
    "ffn_mult": _enc_layer.linear1.out_features // _d_model,
}
with open("/mnt/data/artifacts/sasrec_config.json", "w") as f:
    json.dump(sasrec_cfg, f)

with open("/mnt/data/artifacts/catboost_feature_names.json", "w") as f:
    json.dump(cat_model.feature_names_, f)

with open("/mnt/data/artifacts/item2id.json", "w") as f:
    json.dump(item2id, f)
# JSON object keys are strings: convert them back to int when reloading id2item.
with open("/mnt/data/artifacts/id2item.json", "w") as f:
    json.dump(id2item, f)

# The context encoders are required to rebuild cust_type_id / city_id / state_id
# for new orders.
with open("/mnt/data/artifacts/context_encoders.json", "w") as f:
    json.dump({"cust_type2id": cust_type2id, "city2id": city2id, "state2id": state2id}, f)

joblib.dump(pops_tr, "/mnt/data/artifacts/pops_train.joblib")
joblib.dump(userf_tr, "/mnt/data/artifacts/userf_train.joblib")
joblib.dump(item_cat_id, "/mnt/data/artifacts/item_cat_id.joblib")
joblib.dump(cat2id, "/mnt/data/artifacts/cat2id.joblib")
joblib.dump(cat_glob_pop_tr, "/mnt/data/artifacts/cat_glob_pop_train.joblib")
np.save("/mnt/data/artifacts/item_deg_train.npy", item_deg_tr)
joblib.dump(sim_tr, "/mnt/data/artifacts/sim_train.joblib")

print("✅ Saved artifacts to /mnt/data/artifacts")

## 11. Test Inference + CSV Export
Produces `sasrec_catboost_recommendations.csv` with required columns:
`CUSTOMER_ID, ORDER_ID, item1, item2, item3, Recommendation 1, Recommendation 2, Recommendation 3`.

In [None]:
# Merge city/state for context. Store keys are de-duplicated first so the merge
# can never multiply test rows (the submission needs one row per test order).
_store_ctx = stores[["STORE_NUMBER","CITY","STATE"]].drop_duplicates("STORE_NUMBER")
testdf = testdf.merge(_store_ctx, on="STORE_NUMBER", how="left")
testdf["CITY"] = testdf["CITY"].fillna("0")
testdf["STATE"] = testdf["STATE"].fillna("0")

# Attach the customer type the same way the training orders got it; otherwise
# every test row would use cust_type_id 0, an id never seen during training.
if "CUSTOMER_TYPE" not in testdf.columns:
    _cust_ctx = cust[["CUSTOMER_ID","CUSTOMER_TYPE"]].drop_duplicates("CUSTOMER_ID")
    testdf = testdf.merge(_cust_ctx, on="CUSTOMER_ID", how="left")
testdf["CUSTOMER_TYPE"] = testdf["CUSTOMER_TYPE"].map(norm_str).fillna("REGISTERED")

def to_cart_ids_from_three(row, item2id):
    """Return the ids of the known items found in columns item1..item3 of a test row."""
    items = []
    for col in ["item1", "item2", "item3"]:
        v = row.get(col) if isinstance(row, dict) else row[col]
        if pd.notna(v) and str(v).strip():
            nm = canon_item(str(v))
            if nm in item2id:
                items.append(item2id[nm])
    return items

testdf["CART_IDS"] = testdf.apply(lambda r: to_cart_ids_from_three(r, item2id), axis=1)
# Values unseen at training time map to 0 ("unknown").
testdf["cust_type_id"] = testdf["CUSTOMER_TYPE"].map(cust_type2id).fillna(0).astype(int)
testdf["city_id"]      = testdf["CITY"].map(city2id).fillna(0).astype(int)
testdf["state_id"]     = testdf["STATE"].map(state2id).fillna(0).astype(int)

def ids_to_names(ids):
    """Translate item ids back to names, skipping ids absent from id2item."""
    return [id2item[i] for i in ids if i in id2item]

pred_rows = []
for idx, row in testdf.iterrows():
    user   = row["CUSTOMER_ID"]
    order  = row["ORDER_ID"]
    store  = row.get("STORE_NUMBER")
    ch     = row.get("ORDER_CHANNEL_NAME")
    oc     = row.get("ORDER_OCCASION_NAME")
    city   = row.get("CITY")
    state  = row.get("STATE")
    cart   = row["CART_IDS"]
    cart_set = set(cart)
    seg    = segment(user)

    cand_ids = generate_candidates_for_cart_blend(
        sim_tr, pops_tr, userf_tr, cart_set, user, store, ch, oc,
        n_items=len(item2id),
        city=city, state=state, customer_type=row.get("CUSTOMER_TYPE"),
        N=CFG["BASE_CAND_N"], sasrec_model=sasrec,
        cust_type_id=row.get("cust_type_id",0), city_id=row.get("city_id",0), state_id=row.get("state_id",0),
        blend_top=CFG["SASREC_BLEND_TOP"], prepool_extra_glob=CFG["PREPOOL_GLOB"]
    ).tolist()
    # Never recommend something that is already in the cart.
    cand_ids = [cid for cid in cand_ids if cid not in cart_set]

    top3_names = []
    if cand_ids:
        s_scores = sasrec_score_for_candidates(
            sasrec, list(cart),
            row.get("cust_type_id",0), row.get("city_id",0), row.get("state_id",0),
            cand_ids, max_len=20, device=DEVICE
        )
        feats = make_features_for_pairs(
            sim_tr, pops_tr, userf_tr, cart_set, user, store, ch, oc, cand_ids, seg,
            city=city, state=state, customer_type=row.get("CUSTOMER_TYPE"),
            sasrec_scores=s_scores,
            item_cat_id=item_cat_id, cat_glob_pop=cat_glob_pop_tr
        )
        # align features
        for f in cat_model.feature_names_:
            if f not in feats.columns: feats[f] = 0.0
        feats = feats[cat_model.feature_names_]

        scores = cat_model.predict(feats)
        ranked_ids = [cand_ids[j] for j in np.argsort(-scores)]
        top3_names = ids_to_names(ranked_ids[:3])

    i1 = row.get("item1", "")
    i2 = row.get("item2", "")
    i3 = row.get("item3", "")

    pred_rows.append({
        "CUSTOMER_ID": user,
        "ORDER_ID": order,
        "item1": i1 if pd.notna(i1) else "",
        "item2": i2 if pd.notna(i2) else "",
        "item3": i3 if pd.notna(i3) else "",
        "Recommendation 1": top3_names[0] if len(top3_names) > 0 else "",
        "Recommendation 2": top3_names[1] if len(top3_names) > 1 else "",
        "Recommendation 3": top3_names[2] if len(top3_names) > 2 else "",
    })

submission = pd.DataFrame(pred_rows, columns=[
    "CUSTOMER_ID","ORDER_ID","item1","item2","item3",
    "Recommendation 1","Recommendation 2","Recommendation 3"
])

out_path = "/mnt/data/sasrec_catboost_recommendations.csv"
# The target directory may not exist when this cell runs before the artifact cell.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
submission.to_csv(out_path, index=False)
print("✅ Wrote:", out_path)
submission.head(3)

## 12. Validation Summary

In [None]:
# Recap of the hold-out metrics computed by eval_recall_at_3.
print("\n==== Validation Summary ====")
_summary_line = f"Queries: {val_n}, Candidate coverage: {val_cov:.4f}, Recall@3: {val_recall3:.4f}"
print(_summary_line)