# SASRec Training for CoT-Rec

This notebook trains the SASRec (Self-Attentive Sequential Recommendation) model, which serves as the retriever in Stage 1 of CoT-Rec.

## Prerequisites
1. Run `preprocess_amazon.py` to generate:
   - `datasets/processed/Grocery_and_Gourmet_Food.csv`
   - `datasets/processed/Grocery_and_Gourmet_Food.json`
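2. Make these files available under your working directory (`WORK_DIR` in Step 0), either by uploading them to Colab or by storing them on Google Drive and mounting it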

## Output
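After training, this notebook will generate the following files (names follow `DATASET_NAME` from Step 1):
- `SASRec/checkpoint/Grocery_and_Gourmet_Food.pth` - Trained model (best epoch by validation NDCG)
- `SASRec/checkpoint/Grocery_and_Gourmet_Food_rec_list_valid.pkl` - Validation recommendations
- `SASRec/checkpoint/Grocery_and_Gourmet_Food_rec_list_test.pkl` - Test recommendations
- `SASRec/checkpoint/Grocery_and_Gourmet_Food.log` - Summary of the best epoch's metrics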


## Step 0: Setup and Installation


In [1]:
# Install required packages
!pip install torch pandas tqdm -q


In [None]:
# Point the notebook at the project folder.
# On Colab the project is expected on Google Drive, so mount it first; if you
# uploaded the files directly instead, skip the mount and change WORK_DIR.
import os

from google.colab import drive

drive.mount('/content/drive')

WORK_DIR = '/content/drive/MyDrive/CoT-Rec'  # Change this to your directory
os.chdir(WORK_DIR)

# Every output of this notebook (model, recommendation lists, log) is written
# below this relative folder, so make sure it exists up front.
os.makedirs('SASRec/checkpoint', exist_ok=True)


Mounted at /content/drive


## Step 1: Configuration


In [None]:
import torch

# ---------------------------------------------------------------------------
# Experiment configuration -- edit these values as needed.
# ---------------------------------------------------------------------------
DATASET_NAME = 'Grocery_and_Gourmet_Food'

# Model size
EMBEDDING_DIM = 128
MAX_LENGTH = 32      # length of every (left-padded) input window
NUM_LAYERS = 2       # number of stacked Transformer layers
DROPOUT = 0.2

# Optimisation
BATCH_SIZE = 256
NUM_EPOCHS = 100
NUM_PATIENCE = 5     # early stopping: epochs allowed without a validation improvement
LEARNING_RATE = 1e-3

# Evaluation / logging
TOPK_LIST = [10]     # cut-offs k for Hit@k, NDCG@k and MRR@k
VERBOSE = 100        # print the running training loss every N batches
SEED = 2025

# Device: a non-negative DEVICE_ID selects that GPU when CUDA is available;
# -1 (or no CUDA) falls back to the CPU.
DEVICE_ID = 0
device = torch.device(
    f'cuda:{DEVICE_ID}'
    if DEVICE_ID >= 0 and torch.cuda.is_available()
    else 'cpu'
)
print(f"Using device: {device}")

# Input interactions file, relative to the working directory.
FILEPATH = f'datasets/processed/{DATASET_NAME}.csv'
print(f"Data file: {FILEPATH}")
print("Checkpoint directory: SASRec/checkpoint/")


## Step 2: Data Loading Class


In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset


class ItemSequenceDataset(Dataset):
    """
    Dataset class for sequential recommendation.

    Reads ``filepath`` as a headerless CSV whose first two columns are
    (user_id, item_id) and groups the rows into one interaction sequence per
    user, in file order (presumably chronological -- confirm in preprocessing).
    Ids are used directly as list indices / class ids, so they are assumed to
    be non-negative integers; ``num_users = max(user_id) + 1`` and
    ``num_items = max(item_id) + 1``.

    Each user's sequence is split leave-last-out:
      * train: next-item windows over all items except the last two;
      * valid: predict the second-to-last item from the items before it;
      * test:  predict the last item from the items before it.
    Inputs are left-padded to ``max_length`` with the padding id ``num_items``
    (one past the largest real item id).

    Indexing the dataset yields training ``(input, target)`` pairs only; the
    validation/test tensors are exposed as ``X_valid``/``y_valid`` and
    ``X_test``/``y_test``, one row per user in user-id order.
    """

    def __init__(self, filepath, max_length):
        df = pd.read_csv(filepath, names=['user_id', 'item_id'], usecols=[0, 1])
        self.num_users, self.num_items = df['user_id'].max() + 1, df['item_id'].max() + 1
        pad = int(self.num_items)  # padding id used in inputs and training targets

        # Build sequences for each user. Zipping plain Python lists is much
        # faster than DataFrame.iterrows(), which builds a Series for every row.
        self.all_records = [[] for _ in range(self.num_users)]
        user_ids, item_ids = df['user_id'].tolist(), df['item_id'].tolist()
        for user_id, item_id in tqdm(zip(user_ids, item_ids), total=len(df),
                                     desc="Building user sequences"):
            self.all_records[user_id].append(item_id)

        print('# Users:', self.num_users)
        print('# Items:', self.num_items)
        print('# Interactions:', len(df))

        X_train, y_train = [], []
        X_valid, y_valid = [], []
        X_test, y_test = [], []

        for seq in tqdm(self.all_records, desc="Creating sequences"):
            # Training: all except the last 2 items (held out for valid/test).
            for x, y in self._train_windows(seq[:-2], max_length, pad):
                X_train.append(x)
                y_train.append(y)

            # Validation: drop the test item, predict the second-to-last item.
            x, y = self._last_item_example(seq[:-1], max_length, pad)
            X_valid.append(x)
            y_valid.append(y)

            # Test: full sequence, predict the last item.
            x, y = self._last_item_example(seq, max_length, pad)
            X_test.append(x)
            y_test.append(y)

        self.X_train, self.y_train = torch.tensor(X_train), torch.tensor(y_train)
        self.X_valid, self.y_valid = torch.tensor(X_valid), torch.tensor(y_valid)
        self.X_test, self.y_test = torch.tensor(X_test), torch.tensor(y_test)

        print('Data loading completed.')

    @staticmethod
    def _train_windows(train_seq, max_length, pad):
        """Return next-item training pairs ``(input, target)``, each of length ``max_length``.

        Targets are the inputs shifted left by one. Short sequences (length
        <= max_length) give one left-padded pair. Longer ones give one window
        per stride-1 offset. Sequences with fewer than 2 items have no target
        and give no pairs.
        """
        n = len(train_seq)
        if n < 2:
            # Previously an empty sequence produced a (max_length + 1)-long row,
            # which broke torch.tensor; a 1-item sequence produced an all-padding row.
            return []
        if n <= max_length:
            # '<=' (not '<'): a sequence of exactly max_length items used to
            # fall through to the windowing branch, whose range(0) emitted nothing.
            num_pad = max_length - n + 1
            return [(num_pad * [pad] + train_seq[:-1], num_pad * [pad] + train_seq[1:])]
        return [
            (train_seq[i:i + max_length], train_seq[i + 1:i + max_length + 1])
            for i in range(n - max_length)
        ]

    @staticmethod
    def _last_item_example(seq, max_length, pad):
        """Return ``(input, target)`` predicting ``seq[-1]`` from up to ``max_length`` preceding items.

        The input is left-padded with ``pad`` to exactly ``max_length`` items.
        Raises IndexError if ``seq`` is empty.
        """
        start = max(len(seq) - 1 - max_length, 0)
        history = seq[start:-1]
        return (max_length - len(history)) * [pad] + history, seq[-1]

    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, idx):
        return self.X_train[idx], self.y_train[idx]

print("✅ ItemSequenceDataset class loaded!")


In [None]:
import torch.nn as nn
import torch.nn.functional as F


class UnidirectionalSelfAttention(nn.Module):
    """Single-head causal self-attention: position t attends only to positions <= t."""
    def __init__(self, embedding_dim):
        super(UnidirectionalSelfAttention, self).__init__()
        self.query = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x):
        """Attend over x of shape [batch_size, seq_len, embedding_dim]; returns the same shape."""
        seq_len, embedding_dim = x.size(1), x.size(2)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (embedding_dim ** 0.5)
        # Causal mask: only attend to previous positions (and itself). It is built
        # directly on the scores' device instead of being allocated on the CPU
        # and copied over on every forward pass.
        mask = torch.tril(
            torch.ones(seq_len, seq_len, device=attention_scores.device)
        ).bool().unsqueeze(0)
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=-1)  # (batch_size, seq_len, seq_len)
        attention_output = torch.matmul(attention_weights, V)    # (batch_size, seq_len, embedding_dim)
        return attention_output


class TransformerLayer(nn.Module):
    """Post-norm Transformer layer: causal self-attention then feed-forward, each with a residual.

    NOTE(review): one LayerNorm instance (shared parameters) is applied after
    both sub-layers. Splitting it would change the parameter set and therefore
    the saved checkpoints, so it is kept as is.
    """
    def __init__(self, embedding_dim, dropout):
        super(TransformerLayer, self).__init__()
        self.attn = UnidirectionalSelfAttention(embedding_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [batch_size, seq_len, embedding_dim]
        attn_output = self.attn(x)
        x = self.layer_norm(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.layer_norm(x + self.dropout(ff_output))
        return x


class SASRec(nn.Module):
    """
    SASRec: Self-Attentive Sequential Recommendation Model.

    Architecture:
    - Item embeddings + Position embeddings
    - Stack of Transformer layers
    - Output: logits for next item prediction, scored against the item
      embedding table itself (tied weights, padding row excluded)

    Item id ``num_items`` is the padding id; its embedding row is initialised
    to zero and excluded from the output logits.
    NOTE(review): padded positions still receive position embeddings and are
    not masked out of attention.
    """
    def __init__(self, num_items, embedding_dim=64, max_length=50, num_layers=2, dropout=0.2, std=1e-3):
        super(SASRec, self).__init__()
        self.std = std
        self.item_embedding = nn.Embedding(num_items + 1, embedding_dim, padding_idx=num_items)  # padding
        self.position_embedding = nn.Embedding(max_length, embedding_dim)
        self.attn_layers = nn.ModuleList([TransformerLayer(embedding_dim, dropout) for _ in range(num_layers)])
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, std), zero biases and the padding row, reset LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, item_ids):
        """Map item_ids [batch_size, seq_len] (seq_len <= max_length) to logits [batch_size, seq_len, num_items]."""
        seq_len = item_ids.size(1)
        # [1, seq_len]: broadcasting adds the same position embeddings to every
        # sequence in the batch, so no per-batch copy of the ids is needed.
        positions = torch.arange(seq_len, device=item_ids.device).unsqueeze(0)
        item_embeds = self.item_embedding(item_ids)      # [batch_size, seq_len, embedding_dim]
        pos_embeds = self.position_embedding(positions)  # [1, seq_len, embedding_dim]
        x = item_embeds + pos_embeds
        for attn_layer in self.attn_layers:
            x = attn_layer(x)
        # Dot product with every real item embedding (the last row is padding).
        logits = torch.matmul(x, self.item_embedding.weight[:-1].T)  # [batch_size, seq_len, num_items]
        return logits

print("✅ Model classes loaded!")


In [None]:
import math


class Metrics:
    """Running totals for Hit@k, NDCG@k and MRR@k over evaluated users.

    Feed batches through ``accumulate`` and read the averages with ``get``.
    """
    def __init__(self, topk_list):
        self.topk_list = topk_list
        self.hit_total = dict.fromkeys(topk_list, 0)
        self.ndcg_total = dict.fromkeys(topk_list, 0)
        self.mrr_total = dict.fromkeys(topk_list, 0)
        # (user_index, recommendation_list, target_item) for every user whose
        # target appears somewhere in its recommendation list; misses are
        # counted in the averages but not recorded here.
        self.rec_list = []
        self.total_nums = 0

    def get(self):
        """Return ``(hit, ndcg, mrr, rec_list)``; each metric dict maps k to its mean over all accumulated users."""
        denominator = self.total_nums
        hit, ndcg, mrr = (
            {k: totals[k] / denominator for k in self.topk_list}
            for totals in (self.hit_total, self.ndcg_total, self.mrr_total)
        )
        return hit, ndcg, mrr, self.rec_list

    def accumulate(self, ranks_list, y, start=0):
        """
        Add one batch to the running totals.

        Args:
            ranks_list: Per-user ranked recommendation lists (Python lists).
            y: Per-user target items, aligned with ``ranks_list``.
            start: Index of the batch's first user, used for the recorded user index.
        """
        for offset, true_item in enumerate(y):
            ranks = ranks_list[offset]
            if true_item not in ranks:
                continue
            position = ranks.index(true_item) + 1  # 1-based rank of the target
            self.rec_list.append((start + offset, ranks, true_item))
            for k in self.topk_list:
                if position > k:
                    continue
                self.hit_total[k] += 1
                self.ndcg_total[k] += 1 / math.log2(position + 1)
                self.mrr_total[k] += 1 / position
        self.total_nums += len(y)


def get_top_k_recommendations(scores, all_records, k, phase):
    """
    Get top-k recommendations, filtering out already-interacted items.

    Args:
        scores: Prediction scores [batch_size, num_items]. Modified in place:
            every history item of each row is set to -inf.
        all_records: Full interaction sequence of each user in the batch,
            aligned with the rows of ``scores``.
        k: Number of recommendations to return (capped at num_items).
        phase: 'valid' or 'test'

    Returns:
        Top-k item indices for each user, [batch_size, min(k, num_items)].
    """
    # all_records holds the full sequence. In 'valid' the last two items are the
    # valid target and the held-out test item; in 'test' only the last one is
    # the target. Everything before them is history to filter out.
    delta = 2 if phase == 'valid' else 1
    for row, interacted_items in enumerate(all_records):
        scores[row, interacted_items[:-delta]] = -torch.inf
    # torch.topk raises if k exceeds the number of items, so cap it.
    _, top_indices = torch.topk(scores, min(k, scores.size(1)), dim=1)
    return top_indices

print("✅ Utility classes loaded!")


## Step 5: Training and Evaluation Functions


In [None]:
from torch.utils.data import DataLoader
from torch import optim
from torch import nn


def train(dataloader, model, loss_func, optimizer, epoch, device, verbose=100):
    """
    Run one optimisation pass over ``dataloader``.

    Each ``(X, y)`` batch is moved to ``device``. Logits [batch, seq, items]
    are flattened to [batch*seq, items] and targets to [batch*seq] before the
    loss is computed. The running mean loss is printed every ``verbose``
    batches, starting with the first batch.

    Args:
        dataloader: Iterable of (input, target) batches that supports len().
        model: SASRec model
        loss_func: Loss function
        optimizer: Optimizer
        epoch: Current epoch number (only used in the progress line)
        device: Device to run on
        verbose: Print loss every N batches
    """
    model.train()
    total_batches = len(dataloader)
    running_loss = 0
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        output = model(inputs)
        flat_logits = output.view(-1, output.size(2))
        flat_targets = targets.view(-1)
        loss = loss_func(flat_logits, flat_targets)
        running_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if not step % verbose:
            print(f"loss: {running_loss/(step+1):>7f}  [{step+1:>5d}/{total_batches:>5d}] epoch: {epoch}")


def test(dataset, model, device, batch_size, topk_list, phase):
    """
    Evaluate the model on validation or test set.

    Scores are taken from the last position of every input window. Items the
    user has already interacted with are filtered out before taking the
    top-max(topk_list) items.

    Args:
        dataset: ItemSequenceDataset
        model: SASRec model
        device: Device to run on
        batch_size: Batch size for evaluation
        topk_list: List of top-k values for evaluation
        phase: 'valid' or 'test'

    Returns:
        hit: Hit Rate dictionary
        ndcg: NDCG dictionary
        mrr: MRR dictionary
        rec_list: List of (user_index, recommendation_list, target_item) for
            users whose target appears in their recommendation list
    """
    X_all, y_all = (dataset.X_valid, dataset.y_valid) if phase == 'valid' else (dataset.X_test, dataset.y_test)
    max_k = max(topk_list)
    num_users = len(y_all)
    model.eval()
    metrics = Metrics(topk_list)
    with torch.no_grad():
        for start in range(0, num_users, batch_size):
            end = min(start + batch_size, num_users)
            X = X_all[start:end].to(device)
            # Targets are only consumed as Python ints by Metrics, so keep them on the CPU.
            y = y_all[start:end]
            # Scores at the last position in the sequence, [batch_size, num_items].
            # Plain indexing (no squeeze) keeps the item axis even when num_items == 1.
            scores = model(X)[:, -1, :]
            ranks_list = get_top_k_recommendations(scores, dataset.all_records[start:end], max_k, phase)
            metrics.accumulate(ranks_list.tolist(), y.tolist(), start)
    hit, ndcg, mrr, rec_list = metrics.get()
    print(f'[{phase}]')
    print("Hit:", hit)
    print("NDCG:", ndcg)
    print("MRR:", mrr)
    return hit, ndcg, mrr, rec_list

print("✅ Training and evaluation functions loaded!")


## Step 6: Load Dataset


In [None]:
print("=" * 60, "Loading Dataset", "=" * 60, sep="\n")

# Seed torch's RNG so later torch randomness in this notebook (weight
# initialisation, DataLoader shuffling) is repeatable across runs.
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Read the interactions and build train/valid/test sequences; only the
# training pairs go through the shuffled DataLoader.
dataset = ItemSequenceDataset(FILEPATH, MAX_LENGTH)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

print("\nDataset loaded successfully!")
print(f"Training samples: {len(dataset)}")
print(f"Validation samples: {len(dataset.y_valid)}")
print(f"Test samples: {len(dataset.y_test)}")


## Step 7: Initialize Model


In [None]:
print("=" * 60, "Initializing Model", "=" * 60, sep="\n")

model = SASRec(
    num_items=dataset.num_items,
    embedding_dim=EMBEDDING_DIM,
    max_length=MAX_LENGTH,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
).to(device)

# Adam over all parameters. The loss skips target positions equal to the
# padding id (num_items) that the dataset uses to left-pad training targets.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss(ignore_index=dataset.num_items)

# Parameter counts
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))

print("\nModel created!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Optimizer: Adam (lr={LEARNING_RATE})")
print("Loss function: CrossEntropyLoss")
print(f"Device: {device}")


## Step 8: Initial Evaluation (Before Training)


In [None]:
print("=" * 60, "Initial Evaluation (Untrained Model)", "=" * 60, sep="\n")

# Baseline metrics from the freshly initialised weights, useful as a
# reference point for the trained model.
print("\nValidation Set:")
hit_valid, ndcg_valid, mrr_valid, _ = test(
    dataset, model, device, batch_size=BATCH_SIZE, topk_list=TOPK_LIST, phase='valid'
)

print("\nTest Set:")
hit_test, ndcg_test, mrr_test, _ = test(
    dataset, model, device, batch_size=BATCH_SIZE, topk_list=TOPK_LIST, phase='test'
)


## Step 9: Training Loop


In [None]:
import pickle

print("="*60)
print("Starting Training")
print("="*60)
print(f"Total epochs: {NUM_EPOCHS}")
print(f"Early stopping patience: {NUM_PATIENCE}")
print(f"Best model will be saved based on validation NDCG@{max(TOPK_LIST)}")
print("="*60)

# Model selection uses validation NDCG at the largest cut-off in TOPK_LIST.
select_k = max(TOPK_LIST)
checkpoint_path = f"SASRec/checkpoint/{DATASET_NAME}.pth"
rec_list_valid_path = f'SASRec/checkpoint/{DATASET_NAME}_rec_list_valid.pkl'
rec_list_test_path = f'SASRec/checkpoint/{DATASET_NAME}_rec_list_test.pkl'

patience = NUM_PATIENCE
# Start below any reachable NDCG so the first epoch always becomes the best one.
best_ndcg_valid = float('-inf')
best_valid, best_test, best_epoch = None, None, None

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    # Train
    train(dataloader, model, loss_func, optimizer, epoch, device, VERBOSE)

    # Evaluate on validation set
    hit_valid, ndcg_valid, mrr_valid, rec_list_valid = test(
        dataset, model, device, BATCH_SIZE, TOPK_LIST, phase='valid'
    )

    # Evaluate on test set (only kept for the epoch selected on validation)
    hit_test, ndcg_test, mrr_test, rec_list_test = test(
        dataset, model, device, BATCH_SIZE, TOPK_LIST, phase='test'
    )

    # Check if this is the best model. A strict '>' means an unchanged score
    # is not an improvement. With '>=' a plateau (e.g. NDCG stuck at 0.0)
    # reset patience every epoch and early stopping could never trigger.
    current_ndcg = ndcg_valid[select_k]
    if current_ndcg > best_ndcg_valid:
        patience = NUM_PATIENCE
        best_ndcg_valid = current_ndcg
        best_valid = (hit_valid, ndcg_valid, mrr_valid)
        best_test = (hit_test, ndcg_test, mrr_test)
        best_epoch = epoch

        # Save the whole model object (not just the state_dict)
        torch.save(model, checkpoint_path)
        print(f"\n✅ Best model saved! (NDCG@{select_k}: {best_ndcg_valid:.4f})")

        # Save recommendation lists
        with open(rec_list_valid_path, 'wb') as f:
            pickle.dump(rec_list_valid, f)
        with open(rec_list_test_path, 'wb') as f:
            pickle.dump(rec_list_test, f)
        print("✅ Recommendation lists saved!")
    else:
        patience -= 1
        print(f"\n⚠️  No improvement. Patience: {patience}/{NUM_PATIENCE}")
        # '<=' so that NUM_PATIENCE = 0 stops at the first non-improving epoch.
        if patience <= 0:
            print("\n🛑 Early stopping triggered!")
            break

print(f"\n{'='*60}")
print("Training Complete!")
print(f"{'='*60}")
if best_valid is None or best_test is None:
    # Only possible when the loop body never ran (NUM_EPOCHS < 1).
    print("No epoch was run; no model was saved.")
else:
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation NDCG@{select_k}: {best_ndcg_valid:.4f}")
    print("\nBest validation results:")
    print(f"  Hit: {best_valid[0]}")
    print(f"  NDCG: {best_valid[1]}")
    print(f"  MRR: {best_valid[2]}")
    print("\nBest test results:")
    print(f"  Hit: {best_test[0]}")
    print(f"  NDCG: {best_test[1]}")
    print(f"  MRR: {best_test[2]}")


## Step 10: Save Training Log


In [None]:
# Save a plain-text summary of the best epoch next to the checkpoint.
log_file = f"SASRec/checkpoint/{DATASET_NAME}.log"
if best_valid is None or best_test is None:
    # The training loop recorded no best epoch, so there is nothing to summarise.
    print("No best epoch recorded; skipping training log.")
else:
    log_lines = [
        f'Dataset: {DATASET_NAME}',
        f'Best epoch: {best_epoch}',
        f'Best validation NDCG@{max(TOPK_LIST)}: {best_ndcg_valid:.4f}',
        '',
        'Best validation results:',
        f'  Hit: {best_valid[0]}',
        f'  NDCG: {best_valid[1]}',
        f'  MRR: {best_valid[2]}',
        '',
        'Best test results:',
        f'  Hit: {best_test[0]}',
        f'  NDCG: {best_test[1]}',
        f'  MRR: {best_test[2]}',
    ]
    with open(log_file, 'w') as f:
        f.write('\n'.join(log_lines) + '\n')
    print(f"✅ Training log saved: {log_file}")


## Step 11: Verify Output Files


In [None]:
import os
import pickle  # imported here too so this cell does not depend on the training cell

print("="*60)
print("Verifying Output Files")
print("="*60)

# All paths derive from DATASET_NAME so this check follows the configuration.
checkpoint_dir = "SASRec/checkpoint"
model_path = f"{checkpoint_dir}/{DATASET_NAME}.pth"
rec_list_valid_path = f"{checkpoint_dir}/{DATASET_NAME}_rec_list_valid.pkl"
rec_list_test_path = f"{checkpoint_dir}/{DATASET_NAME}_rec_list_test.pkl"
log_path = f"{checkpoint_dir}/{DATASET_NAME}.log"

# Check if files exist
files_to_check = [model_path, rec_list_valid_path, rec_list_test_path, log_path]
for file_path in files_to_check:
    if os.path.exists(file_path):
        size = os.path.getsize(file_path)
        print(f"✅ {file_path} ({size:,} bytes)")
    else:
        print(f"❌ {file_path} - NOT FOUND")

# Load and verify recommendation lists
print("\n" + "="*60)
print("Verifying Recommendation Lists")
print("="*60)

# Only load the lists when both exist; the check above already reported
# which files are missing, so don't crash on a FileNotFoundError here.
if os.path.exists(rec_list_valid_path) and os.path.exists(rec_list_test_path):
    # These pickles are written by this notebook's training loop; never
    # unpickle files obtained from untrusted sources.
    with open(rec_list_valid_path, 'rb') as f:
        rec_list_valid = pickle.load(f)
    with open(rec_list_test_path, 'rb') as f:
        rec_list_test = pickle.load(f)

    print(f"\nValidation recommendations: {len(rec_list_valid)} entries")
    print(f"Test recommendations: {len(rec_list_test)} entries")

    # Show sample
    if rec_list_valid:
        sample = rec_list_valid[0]
        print("\nSample validation entry:")
        print(f"  User index: {sample[0]}")
        print(f"  Recommendation list: {sample[1][:5]}... (showing first 5)")
        print(f"  Target item: {sample[2]}")
else:
    print("\nRecommendation lists not found -- run the training loop first.")

print("\n" + "="*60)
print("🎉 SASRec Training Complete!")
print("="*60)
print("\nNext steps:")
print("1. Use the pickle files in Stage 1 extraction:")
print(f"   - {rec_list_valid_path}")
print(f"   - {rec_list_test_path}")
print("2. Run Stage1_Extraction_Colab.ipynb to extract personalized information")
