# Neural Matrix Factorization with SBERT Embeddings


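This notebook implements an extended Neural Matrix Factorization (NeuMF) model that adds Sentence-BERT (SBERT) embeddings of each item's title and description as a content-based item representation. The content signal is intended to help in cold-start scenarios, where an item has too few interactions for its learned ID embeddings to be informative.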

## Overview:
 1. Data loading and preprocessing
 2. Sentence-BERT embedding extraction
 3. Model definition (NeuMF++ architecture with content embeddings)
 4. Training with early stopping
 5. Evaluation and cold-start analysis
 6. Model saving for inference

## 1. Setup and Data Loading

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [20]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
import gc
import os
import pickle

In [37]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Define model save directory - for Google Drive
MODEL_DIR = "/content/drive/MyDrive/bt4222data/sbert_models"
os.makedirs(MODEL_DIR, exist_ok=True)
print(f"Model will be saved to: {MODEL_DIR}")

Using device: cuda
Model will be saved to: /content/drive/MyDrive/bt4222data/sbert_models


### Loading Data

In [38]:
reviews_df = pd.read_csv("/content/drive/My Drive/bt4222data/Reviews Data Cleaned/cleaned_reviews.csv", keep_default_na=False)
metadata_df = pd.read_csv("/content/drive/My Drive/bt4222data/Meta Data Cleaned/final_metadata_cleaned.csv", keep_default_na=False)


In [39]:
# Print basic info about the datasets
print(f"Reviews shape: {reviews_df.shape}")
print(f"Metadata shape: {metadata_df.shape}")

Reviews shape: (1689188, 18)
Metadata shape: (492009, 19)


## 2. Data Preprocessing

In [40]:
# Filter to common ASINs if needed
common_asins = set(reviews_df['asin']).intersection(set(metadata_df['asin']))
print(f"Common ASINs: {len(common_asins)}")

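# .copy() makes independent frames, so adding columns below does not trigger SettingWithCopyWarning
reviews_df = reviews_df[reviews_df['asin'].isin(common_asins)].copy()
metadata_df = metadata_df[metadata_df['asin'].isin(common_asins)].copy()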

# Convert ratings to binary (1 if rating >= 3, else 0)
reviews_df['interaction'] = (reviews_df['overall'] >= 3).astype(int)

# Display binary interaction distribution
print(f"Interaction distribution (binary):\n{reviews_df['interaction'].value_counts()}")

# Encode user and item IDs
user_encoder = LabelEncoder()
item_encoder = LabelEncoder()

reviews_df['user_idx'] = user_encoder.fit_transform(reviews_df['reviewerID'])
reviews_df['item_idx'] = item_encoder.fit_transform(reviews_df['asin'])

# Create a mapping from asin to item_idx for later use
asin_to_idx = dict(zip(reviews_df['asin'], reviews_df['item_idx']))

# Add item_idx to metadata_df
metadata_df['item_idx'] = metadata_df['asin'].map(asin_to_idx)

# Fill any NaN in description with empty string
metadata_df['description'] = metadata_df['description'].fillna('')
metadata_df['title'] = metadata_df['title'].fillna('')

print(f"Unique users: {reviews_df['reviewerID'].nunique()}")
print(f"Unique items: {reviews_df['asin'].nunique()}")

Common ASINs: 61709


Interaction distribution (binary):
interaction
1    1436504
0     183627
Name: count, dtype: int64
Unique users: 192395
Unique items: 61709
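With roughly 89% of interactions labelled positive, accuracy, precision, recall and F1 at a 0.5 threshold are dominated by the majority class. As a reference point, the sketch below computes those metrics, plus binary cross-entropy, for a trivial baseline that scores every interaction with the positive-class rate. It uses only the counts printed above; the helper name is hypothetical and it assumes positives are the majority class.

```python
import math

def majority_baseline(n_pos: int, n_neg: int) -> dict:
    """Metrics for a constant score equal to the positive-class rate.

    Assumes positives are the majority class, so the constant score is >= 0.5
    and every interaction is classified as positive at the 0.5 threshold.
    """
    total = n_pos + n_neg
    prior = n_pos / total
    accuracy = prior          # only the positive rows are classified correctly
    precision = prior         # every prediction is positive
    recall = 1.0              # no positive row is missed
    f1 = 2 * precision * recall / (precision + recall)
    # Binary cross-entropy when every row is scored with the prior
    bce = -(prior * math.log(prior) + (1 - prior) * math.log(1 - prior))
    return {"accuracy": accuracy, "precision": precision, "recall": recall,
            "f1": f1, "bce": bce}

baseline = majority_baseline(n_pos=1436504, n_neg=183627)
for name, value in baseline.items():
    print(f"{name}: {value:.4f}")
```

The first four values (0.8867, 0.8867, 1.0000, 0.9399) coincide with the epoch-1 validation metrics in the training log (the stratified validation split has essentially the same positive rate), which suggests that after one epoch the model classified practically every validation interaction as positive at the 0.5 threshold. A constant score has an AUC of exactly 0.5 and a cross-entropy of about 0.35, so AUC and loss (best validation loss 0.3178) are the more informative signals of what the model has learned.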


## 3. Sentence-BERT Embedding Extraction

In [41]:
def generate_item_embeddings(metadata_df, model_name='all-MiniLM-L6-v2', batch_size=32, cache_file=None):
    """
    Generate Sentence-BERT embeddings for item metadata (title + description)
    """
    # Check if cached embeddings exist
    if cache_file and os.path.exists(cache_file):
        print(f"Loading cached embeddings from {cache_file}")
        return np.load(cache_file)

    # Load Sentence-BERT model
    print(f"Loading Sentence-BERT model: {model_name}")
    model = SentenceTransformer(model_name, device=device)

    # Combine title and description
    texts = []
    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc="Preparing texts"):
        # Combine title and description, with special handling for missing values
        combined_text = str(row['title'])
        if row['description']:
            combined_text += " " + str(row['description'])
        texts.append(combined_text)

    # Generate embeddings in batches to avoid OOM issues
    print("Generating embeddings...")
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)

    # Save to cache if specified
    if cache_file:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        np.save(cache_file, embeddings)
        print(f"Saved embeddings to {cache_file}")

    return embeddings

# Generate or load item embeddings
CACHE_DIR = os.path.join(MODEL_DIR, "embeddings/")
os.makedirs(CACHE_DIR, exist_ok=True)
EMBEDDING_CACHE = os.path.join(CACHE_DIR, "item_sbert_embeddings.npy")

# Check if embeddings already exist
if os.path.exists(EMBEDDING_CACHE):
    print(f"Loading cached embeddings from {EMBEDDING_CACHE}")
    item_embeddings = np.load(EMBEDDING_CACHE)
    print("Embeddings loaded successfully")
else:
    # Sort metadata by item_idx to ensure consistent ordering
    metadata_df = metadata_df.sort_values('item_idx').reset_index(drop=True)
    item_embeddings = generate_item_embeddings(metadata_df, cache_file=EMBEDDING_CACHE)

print(f"Embeddings shape: {item_embeddings.shape}")

Loading Sentence-BERT model: all-MiniLM-L6-v2


Generating embeddings...


Saved embeddings to /content/drive/MyDrive/bt4222data/sbert_models/embeddings/item_sbert_embeddings.npy
Embeddings shape: (61709, 384)
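`NeuMFDataset` looks up `item_embeddings[item_id]`, so row i of the embedding matrix must come from the item whose `item_idx` is i, with exactly one metadata row per item. The notebook gets this by sorting `metadata_df` on `item_idx` before encoding. A minimal pure-Python sketch of the same invariant, using hypothetical `(asin, title, description)` tuples instead of the DataFrame and a hypothetical helper name, builds the title-plus-description texts directly in index order and raises if an index is duplicated or missing:

```python
def build_item_texts(records, asin_to_idx):
    """Return one 'title description' string per item, ordered by item index.

    records: iterable of (asin, title, description) tuples (hypothetical input format).
    asin_to_idx: mapping from ASIN to contiguous item index 0..n-1.
    """
    texts = [None] * len(asin_to_idx)
    for asin, title, description in records:
        idx = asin_to_idx[asin]
        if texts[idx] is not None:
            raise ValueError(f"duplicate metadata row for item_idx {idx} ({asin})")
        text = str(title)
        if description:                      # skip empty descriptions, as the notebook does
            text += " " + str(description)
        texts[idx] = text
    missing = [i for i, t in enumerate(texts) if t is None]
    if missing:
        raise ValueError(f"no metadata for item_idx {missing[:5]}")
    return texts

asin_to_idx = {"B000A": 0, "B000B": 1, "B000C": 2}
records = [("B000C", "Blender", ""), ("B000A", "Kettle", "1.7L steel"), ("B000B", "Toaster", "2-slice")]
print(build_item_texts(records, asin_to_idx))
# ['Kettle 1.7L steel', 'Toaster 2-slice', 'Blender']
```

The resulting list could be passed to `SentenceTransformer.encode` in place of the `iterrows` loop; that substitution is an untested suggestion, not what this notebook ran.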


## 4. Dataset and DataLoader Creation

In [42]:
class NeuMFDataset(Dataset):
    def __init__(self, interactions_df, item_embeddings):
        self.users = torch.tensor(interactions_df['user_idx'].values, dtype=torch.long)
        self.items = torch.tensor(interactions_df['item_idx'].values, dtype=torch.long)
        self.labels = torch.tensor(interactions_df['interaction'].values, dtype=torch.float)
        self.item_embeddings = torch.tensor(item_embeddings, dtype=torch.float)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        user_id = self.users[idx]
        item_id = self.items[idx]
        label = self.labels[idx]
        item_emb = self.item_embeddings[item_id]

        return user_id, item_id, item_emb, label

# Train/validation/test split using stratified sampling
train_df, temp_df = train_test_split(
    reviews_df, test_size=0.3, random_state=seed,
    stratify=reviews_df['interaction']
)

val_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=seed,
    stratify=temp_df['interaction']
)

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")

# Create datasets and loaders
BATCH_SIZE = 1024

train_dataset = NeuMFDataset(train_df, item_embeddings)
val_dataset = NeuMFDataset(val_df, item_embeddings)
test_dataset = NeuMFDataset(test_df, item_embeddings)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

Train size: 1134091
Validation size: 243020
Test size: 243020
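The sizes printed above come from splitting twice: 30% is held out, then that hold-out is halved into validation and test, giving roughly 70/15/15. The short sketch below reproduces the arithmetic, assuming scikit-learn's rule for a float `test_size` with `train_size` left unset (test count rounded up, remainder to train); the helper name is hypothetical.

```python
import math

def split_sizes(n_samples: int, test_size: float) -> tuple:
    """Return (n_train, n_test) for a float test_size when train_size is unset."""
    n_test = math.ceil(test_size * n_samples)   # test share is rounded up
    return n_samples - n_test, n_test           # train gets the remainder

n_total = 1436504 + 183627                      # interactions after filtering to common ASINs
n_train, n_temp = split_sizes(n_total, 0.3)     # train vs. hold-out
n_val, n_test = split_sizes(n_temp, 0.5)        # hold-out halved into validation and test
print(n_train, n_val, n_test)
# 1134091 243020 243020
```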


## 5. Model Definition

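NeuMF++ extends NeuMF with Sentence-BERT item embeddings, combining collaborative filtering with content-based filtering. The collaborative signal comes from learned user and item ID embeddings in two branches: a GMF branch (element-wise product of user and item embeddings) and an MLP branch. The content signal is each item's SBERT vector (384-dimensional for `all-MiniLM-L6-v2`), projected to `embedding_dim` and concatenated into the MLP input. The output layer combines the GMF and MLP outputs into a single interaction probability.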

In [43]:
class NeuMFPlusPlus(nn.Module):
    """
    Neural Matrix Factorization Plus Plus (NeuMF++) model integrating Sentence-BERT embeddings
    for content-based representation of items.
    """
    def __init__(self, num_users, num_items, item_bert_dim,
                 embedding_dim=64, mlp_dims=[128, 64, 32], dropout_rate=0.2):
        super(NeuMFPlusPlus, self).__init__()

        # GMF part
        self.user_gmf_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_gmf_embedding = nn.Embedding(num_items, embedding_dim)

        # MLP part
        self.user_mlp_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_mlp_embedding = nn.Embedding(num_items, embedding_dim)

        # Item content projection (Sentence-BERT embedding)
        self.item_bert_projection = nn.Linear(item_bert_dim, embedding_dim)
        self.bert_bn = nn.BatchNorm1d(embedding_dim)

        # MLP layers
        mlp_input_dim = embedding_dim * 2 + embedding_dim  # user + item + bert
        self.mlp_layers = nn.ModuleList()
        self.mlp_batch_norms = nn.ModuleList()

        # First layer
        self.mlp_layers.append(nn.Linear(mlp_input_dim, mlp_dims[0]))
        self.mlp_batch_norms.append(nn.BatchNorm1d(mlp_dims[0]))

        # Hidden layers
        for i in range(len(mlp_dims)-1):
            self.mlp_layers.append(nn.Linear(mlp_dims[i], mlp_dims[i+1]))
            self.mlp_batch_norms.append(nn.BatchNorm1d(mlp_dims[i+1]))

        # Output layer
        self.output_layer = nn.Linear(mlp_dims[-1] + embedding_dim, 1)

        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Initialize embeddings with normal distribution
        nn.init.normal_(self.user_gmf_embedding.weight, std=0.01)
        nn.init.normal_(self.item_gmf_embedding.weight, std=0.01)
        nn.init.normal_(self.user_mlp_embedding.weight, std=0.01)
        nn.init.normal_(self.item_mlp_embedding.weight, std=0.01)

        # Initialize linear layers with Xavier/Glorot
        for layer in self.mlp_layers:
            nn.init.xavier_uniform_(layer.weight)

        nn.init.xavier_uniform_(self.output_layer.weight)
        nn.init.xavier_uniform_(self.item_bert_projection.weight)

        # Initialize biases with zeros
        for layer in self.mlp_layers:
            nn.init.zeros_(layer.bias)

        nn.init.zeros_(self.output_layer.bias)
        nn.init.zeros_(self.item_bert_projection.bias)

    def forward(self, user_indices, item_indices, item_bert_emb):
        # GMF part
        user_gmf_emb = self.user_gmf_embedding(user_indices)
        item_gmf_emb = self.item_gmf_embedding(item_indices)
        gmf_output = user_gmf_emb * item_gmf_emb

        # MLP part
        user_mlp_emb = self.user_mlp_embedding(user_indices)
        item_mlp_emb = self.item_mlp_embedding(item_indices)

        # Process item Sentence-BERT embedding
        item_bert_emb = self.item_bert_projection(item_bert_emb)
        item_bert_emb = self.bert_bn(item_bert_emb)
        item_bert_emb = self.relu(item_bert_emb)

        # Concatenate user, item and BERT embeddings
        mlp_input = torch.cat([user_mlp_emb, item_mlp_emb, item_bert_emb], dim=1)

        # Apply MLP layers
        for i, layer in enumerate(self.mlp_layers):
            mlp_input = layer(mlp_input)
            mlp_input = self.mlp_batch_norms[i](mlp_input)
            mlp_input = self.relu(mlp_input)
            mlp_input = self.dropout(mlp_input)

        # Concatenate GMF and MLP parts
        concat_output = torch.cat([gmf_output, mlp_input], dim=1)

        # Final prediction
        prediction = self.sigmoid(self.output_layer(concat_output))

        return prediction.squeeze(-1)  # keep a 1-D batch dimension even for a batch of one
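To see where the model's capacity sits, the sketch below counts trainable parameters for the configuration trained in this notebook (192,395 users, 61,709 items, 384-dimensional SBERT vectors, `embedding_dim=64`, `mlp_dims=[128, 64, 32]`). It re-derives the layer sizes of `NeuMFPlusPlus` by hand in plain Python (the helper names are hypothetical) instead of instantiating the model; with PyTorch available, `sum(p.numel() for p in model.parameters())` should give the same total. Each BatchNorm layer contributes a trainable scale and shift per feature; its running statistics are buffers, not parameters.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias

def neumfpp_param_counts(num_users, num_items, bert_dim, emb_dim=64, mlp_dims=(128, 64, 32)):
    counts = {}
    # Four ID embedding tables: user/item for the GMF branch and for the MLP branch
    counts["id_embeddings"] = 2 * (num_users + num_items) * emb_dim
    # SBERT projection (Linear) plus its BatchNorm1d scale and shift
    counts["bert_projection"] = linear_params(bert_dim, emb_dim) + 2 * emb_dim
    # MLP input = user MLP emb + item MLP emb + projected SBERT emb
    dims = [3 * emb_dim, *mlp_dims]
    counts["mlp"] = sum(linear_params(d_in, d_out) + 2 * d_out   # Linear + BatchNorm1d
                        for d_in, d_out in zip(dims[:-1], dims[1:]))
    # Output layer sees [GMF output ; last MLP layer]
    counts["output"] = linear_params(emb_dim + mlp_dims[-1], 1)
    counts["total"] = sum(counts.values())
    return counts

counts = neumfpp_param_counts(192395, 61709, 384)
for name, n in counts.items():
    print(f"{name:>16}: {n:,}")
print(f"ID-embedding share: {counts['id_embeddings'] / counts['total']:.2%}")
```

Over 99.8% of the roughly 32.6M parameters sit in the four ID embedding tables. An item with only a handful of training interactions gets few gradient updates to its rows in those tables, so its representation leans heavily on its fixed SBERT vector passed through the shared projection, which is the mechanism the content branch relies on for cold-start items.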

## 6. Training and Evaluation Functions

### Training and Evaluating the Model

In [44]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """
    Train model for one epoch
    """
    model.train()
    total_loss = 0

    train_bar = tqdm(train_loader, desc="Training")
    for user_ids, item_ids, item_embs, labels in train_bar:
        # Move tensors to device
        user_ids = user_ids.to(device)
        item_ids = item_ids.to(device)
        item_embs = item_embs.to(device)
        labels = labels.to(device)

        # Forward pass
        predictions = model(user_ids, item_ids, item_embs)
        loss = criterion(predictions, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(labels)
        train_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(train_loader.dataset)

def evaluate(model, data_loader, criterion, device):
    """
    Evaluate model on a dataset
    """
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for user_ids, item_ids, item_embs, labels in tqdm(data_loader, desc="Evaluating"):
            # Move tensors to device
            user_ids = user_ids.to(device)
            item_ids = item_ids.to(device)
            item_embs = item_embs.to(device)
            labels = labels.to(device)

            # Get predictions
            predictions = model(user_ids, item_ids, item_embs)
            loss = criterion(predictions, labels)
            total_loss += loss.item() * len(labels)

            # Store predictions and labels
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Convert predictions to binary (threshold = 0.5)
    binary_preds = (all_preds >= 0.5).astype(int)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, binary_preds)
    precision = precision_score(all_labels, binary_preds)
    recall = recall_score(all_labels, binary_preds)
    f1 = f1_score(all_labels, binary_preds)
    auc = roc_auc_score(all_labels, all_preds)

    avg_loss = total_loss / len(data_loader.dataset)

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'predictions': all_preds,
        'labels': all_labels
    }
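`evaluate` reports accuracy, precision, recall and F1 after thresholding at 0.5, plus ROC AUC computed from the raw scores. With about 89% positive labels, a model can push nearly every score above 0.5, which yields recall close to 1.0, while still ordering positives and negatives imperfectly; only AUC measures that ordering. The sketch below, using hypothetical scores and hypothetical helper names, computes AUC from its pairwise definition (the probability that a random positive outscores a random negative, ties counting half), which agrees with `roc_auc_score` for binary labels, alongside the thresholded precision and recall.

```python
def pairwise_auc(scores, labels):
    """ROC AUC as P(random positive scores higher than random negative), ties count 0.5.

    O(n_pos * n_neg), so only suitable for small illustrative inputs.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def threshold_metrics(scores, labels, threshold=0.5):
    """Precision and recall after converting scores to 0/1 at the threshold."""
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical scores: all above 0.5, but positives mostly rank above negatives
scores = [0.95, 0.90, 0.80, 0.70, 0.60]
labels = [1, 1, 0, 1, 0]
print(pairwise_auc(scores, labels))        # 5 of 6 positive/negative pairs ordered correctly
print(threshold_metrics(scores, labels))   # every row predicted positive
```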

### Evaluating on Cold-Start Items

In [45]:
def evaluate_cold_start_items(model, all_df, test_df, min_interactions=5):
    """
    Evaluate model specifically on cold start items from the test set.
    Cold start items are defined as items with minimal user interactions.
    """
    print("\nEvaluating model on cold start items...")
    print(f"Cold start items are defined as items with <= {min_interactions} interactions")

    # Count interactions per item in the full dataset
    item_counts = all_df['asin'].value_counts()

    # Find cold start items that appear in the test set
    cold_start_items = set(item_counts[item_counts <= min_interactions].index) & set(test_df['asin'].unique())

    if not cold_start_items:
        print(f"No cold start items found with <= {min_interactions} interactions in the test set.")
        return None

    print(f"Found {len(cold_start_items)} cold start items in the test set.")

    # Filter test data to only include cold start items
    cold_start_test_df = test_df[test_df['asin'].isin(cold_start_items)]

    if len(cold_start_test_df) == 0:
        print("No data available for cold start evaluation.")
        return None

    print(f"Number of interactions with cold start items: {len(cold_start_test_df)}")

    # Create dataset and dataloader
    cold_start_dataset = NeuMFDataset(cold_start_test_df, item_embeddings)
    cold_start_loader = DataLoader(cold_start_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Evaluate using the same criterion
    criterion = nn.BCELoss()
    cold_start_metrics = evaluate(model, cold_start_loader, criterion, device)

    print("\nCold Start Items Metrics:")
    print("-" * 50)
    print(f"Loss: {cold_start_metrics['loss']:.4f}")
    print(f"Accuracy: {cold_start_metrics['accuracy']:.4f}")
    print(f"Precision: {cold_start_metrics['precision']:.4f}")
    print(f"Recall: {cold_start_metrics['recall']:.4f}")
    print(f"F1 Score: {cold_start_metrics['f1']:.4f}")
    print(f"AUC: {cold_start_metrics['auc']:.4f}")

    return cold_start_metrics
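`evaluate_cold_start_items` counts interactions over train, validation and test combined, so an item it labels cold-start may still have one or more rows in `train_df` and hence a partly trained ID embedding. To separate that sparse-item setting from strict cold start (no training interactions at all), one could also check training counts. A minimal sketch on hypothetical toy lists of item IDs, with a hypothetical helper name:

```python
from collections import Counter

def cold_start_breakdown(train_items, test_items, all_items, max_interactions=5):
    """Split sparse test items by whether they ever appeared in training.

    Returns (sparse, unseen): test items with <= max_interactions interactions
    overall, and the subset of those with zero training interactions.
    """
    total_counts = Counter(all_items)
    train_counts = Counter(train_items)      # missing keys count as 0
    sparse = {item for item in set(test_items) if total_counts[item] <= max_interactions}
    unseen = {item for item in sparse if train_counts[item] == 0}
    return sparse, unseen

train = ["a", "a", "b", "c", "c", "c"]
val = ["b"]
test = ["a", "b", "d"]
sparse, unseen = cold_start_breakdown(train, test, train + val + test, max_interactions=3)
print(sorted(sparse), sorted(unseen))
# ['a', 'b', 'd'] ['d']
```

Applied to `train_df['asin']`, `test_df['asin']` and `all_df['asin']`, the second set would isolate test items the model never saw during training; the notebook does not report this split.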

### Generate Recommendations

In [46]:
def generate_recommendations(model, user_encoder, item_encoder, all_df, test_df, n_recommendations=5):
    """
    Generate n recommendations for users in the test set.
    Items are ranked by predicted score.
    """
    model.eval()

    # Get unique users from test set
    test_users = test_df['reviewerID'].unique()

    # Get all items
    all_items = all_df['asin'].unique()

    recommendations = {}

    for user in tqdm(test_users[:10], desc="Generating recommendations"):  # Just do 10 users for demo
        user_idx = user_encoder.transform([user])[0]

        # Get items the user hasn't interacted with yet
        user_items = all_df[all_df['reviewerID'] == user]['asin'].values
        unseen_items = np.setdiff1d(all_items, user_items)

        if len(unseen_items) == 0:
            print(f"User {user} has interacted with all items!")
            continue

        # Convert unseen items to indices
        unseen_item_indices = torch.tensor([asin_to_idx[asin] for asin in unseen_items if asin in asin_to_idx], dtype=torch.long)

        if len(unseen_item_indices) == 0:
            print(f"No valid unseen items for user {user}")
            continue

        # Get embeddings for unseen items
        unseen_item_embs = torch.tensor(item_embeddings[unseen_item_indices], dtype=torch.float)

        # Process in batches to avoid memory issues
        batch_size = 1024
        all_scores = []

        for i in range(0, len(unseen_item_indices), batch_size):
            batch_indices = unseen_item_indices[i:i+batch_size]
            batch_embs = unseen_item_embs[i:i+batch_size]

            user_tensor = torch.tensor([user_idx] * len(batch_indices), dtype=torch.long).to(device)
            item_tensor = batch_indices.to(device)
            item_embs_tensor = batch_embs.to(device)

            with torch.no_grad():
                scores = model(user_tensor, item_tensor, item_embs_tensor)
                all_scores.append(scores.cpu().numpy())

        if all_scores:
            all_scores = np.concatenate(all_scores)

            # Get the indices of the top N scores
            if len(all_scores) >= n_recommendations:
                top_n_indices = np.argsort(all_scores)[-n_recommendations:][::-1]
                recommended_items = [unseen_items[i] for i in top_n_indices]
            else:
                # If we have fewer items than requested recommendations
                recommended_items = [unseen_items[i] for i in np.argsort(all_scores)[::-1]]

            recommendations[user] = recommended_items
        else:
            print(f"No scores computed for user {user}")

    return recommendations
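`generate_recommendations` scores every unseen item for a user, then fully sorts the scores with `np.argsort` to take the top N. Only the N best are needed, so a partial selection is enough. A minimal sketch of that step with the standard library's `heapq.nlargest`, assuming a hypothetical dictionary of item scores, a set of already-reviewed items and a hypothetical helper name:

```python
import heapq

def top_n_unseen(scores, seen, n=5):
    """Return the n highest-scoring item IDs the user has not interacted with.

    scores: mapping item_id -> predicted probability (hypothetical input format)
    seen: set of item IDs the user has already reviewed
    """
    candidates = ((score, item) for item, score in scores.items() if item not in seen)
    return [item for score, item in heapq.nlargest(n, candidates)]

scores = {"B01": 0.91, "B02": 0.97, "B03": 0.42, "B04": 0.88, "B05": 0.95}
print(top_n_unseen(scores, seen={"B02"}, n=3))
# ['B05', 'B01', 'B04']
```

On NumPy arrays, `np.argpartition(scores, -n)[-n:]` followed by sorting just those n entries is the vectorised equivalent.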


## 7. Model Training with Early Stopping
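The training loop in `main()` stops once validation AUC has failed to improve for `EARLY_STOPPING_PATIENCE` consecutive epochs, and it keeps the checkpoint from the best epoch. A small pure-Python replay of that rule (a hypothetical helper, not part of the notebook), fed with the validation AUCs from the training log at the end of this notebook, shows why training stopped after epoch 5 while the epoch-2 checkpoint was restored.

```python
def replay_early_stopping(val_aucs, patience=3):
    """Return (best_epoch, stop_epoch), both 1-based, mirroring main()'s rule:
    reset the counter on a strict AUC improvement, otherwise increment it,
    and stop as soon as it reaches `patience`."""
    best_auc, best_epoch, counter = 0.0, None, 0
    for epoch, auc in enumerate(val_aucs, start=1):
        if auc > best_auc:
            best_auc, best_epoch, counter = auc, epoch, 0   # new best: checkpoint saved
        else:
            counter += 1
        if counter >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_aucs)   # ran out of epochs without triggering

print(replay_early_stopping([0.7125, 0.7305, 0.7269, 0.7169, 0.7071]))
# (2, 5)
```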

In [47]:
def main():
    print("\n" + "="*80)
    print("STEP 1: SAVING ENCODERS FOR INFERENCE")
    print("="*80)

    # Save encoders for inference
    encoder_path = os.path.join(MODEL_DIR, 'sbert_encoders.pkl')
    with open(encoder_path, 'wb') as f:
        pickle.dump({
            'user_encoder': user_encoder,
            'item_encoder': item_encoder,
            'asin_to_idx': asin_to_idx
        }, f)
    print(f"Encoders saved to {encoder_path}")

    print("\n" + "="*80)
    print("STEP 2: INITIALIZING MODEL")
    print("="*80)

    # Model parameters
    num_users = len(user_encoder.classes_)
    num_items = len(item_encoder.classes_)
    item_bert_dim = item_embeddings.shape[1]

    print(f"Number of users: {num_users}")
    print(f"Number of items: {num_items}")
    print(f"BERT embedding dimension: {item_bert_dim}")

    # Model hyperparameters
    EMBEDDING_DIM = 64
    MLP_DIMS = [128, 64, 32]
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-5
    EPOCHS = 20
    EARLY_STOPPING_PATIENCE = 3

    print(f"Embedding dimension: {EMBEDDING_DIM}")
    print(f"MLP dimensions: {MLP_DIMS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Weight decay: {WEIGHT_DECAY}")
    print(f"Maximum epochs: {EPOCHS}")
    print(f"Early stopping patience: {EARLY_STOPPING_PATIENCE}")

    # Initialize model
    model = NeuMFPlusPlus(
        num_users=num_users,
        num_items=num_items,
        item_bert_dim=item_bert_dim,
        embedding_dim=EMBEDDING_DIM,
        mlp_dims=MLP_DIMS
    ).to(device)

    print("\n" + "="*80)
    print("STEP 3: TRAINING THE MODEL")
    print("="*80)

    # Loss and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Training loop with early stopping
    train_losses = []
    val_losses = []
    val_aucs = []
    best_val_auc = 0
    patience_counter = 0
    best_model_path = os.path.join(MODEL_DIR, 'neumf_sbert_model.pt')

    print("\nStarting training...")
    for epoch in range(EPOCHS):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_losses.append(train_loss)

        # Validate
        val_metrics = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_metrics['loss'])
        val_aucs.append(val_metrics['auc'])

        # Print metrics
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Accuracy: {val_metrics['accuracy']:.4f}")
        print(f"Val Precision: {val_metrics['precision']:.4f}, Val Recall: {val_metrics['recall']:.4f}")
        print(f"Val F1 Score: {val_metrics['f1']:.4f}, Val AUC: {val_metrics['auc']:.4f}")

        # Early stopping based on validation AUC
        if val_metrics['auc'] > best_val_auc:
            best_val_auc = val_metrics['auc']
            patience_counter = 0

            # Save the best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'num_users': num_users,
                'num_items': num_items,
                'item_bert_dim': item_bert_dim,
                'embedding_dim': EMBEDDING_DIM,
                'mlp_dims': MLP_DIMS,
                'best_val_auc': best_val_auc
            }, best_model_path)
            print(f"Saved best model with Val AUC: {best_val_auc:.4f}")
        else:
            patience_counter += 1
            print(f"Early stopping patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered!")
            break

        print("-" * 50)

    print("\n" + "="*80)
    print("STEP 4: LOADING BEST MODEL FOR EVALUATION")
    print("="*80)

    # Try to load the best model, but continue with the current model if loading fails
    try:
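        # The checkpoint is written by this notebook and also stores non-tensor values,
        # so request full unpickling explicitly (recent PyTorch versions default to weights_only=True)
        checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)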
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded best model with validation AUC: {checkpoint['best_val_auc']:.4f}")
    except (FileNotFoundError, RuntimeError, pickle.UnpicklingError) as e:
        print(f"Could not load saved model: {e}")
        print("Continuing with current model state.")

    print("\n" + "="*80)
    print("STEP 5: EVALUATING ON TEST SET")
    print("="*80)

    # Evaluate on test set
    test_metrics = evaluate(model, test_loader, criterion, device)
    print("\nTest Results:")
    print("-" * 50)
    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test Precision: {test_metrics['precision']:.4f}")
    print(f"Test Recall: {test_metrics['recall']:.4f}")
    print(f"Test F1 Score: {test_metrics['f1']:.4f}")
    print(f"Test AUC: {test_metrics['auc']:.4f}")

    print("\n" + "="*80)
    print("STEP 6: EVALUATING ON COLD START ITEMS")
    print("="*80)
    print("Cold start items are defined as items with <= 5 interactions in the dataset")

    # Evaluate on cold start items
    all_df = pd.concat([train_df, val_df, test_df])
    cold_start_metrics = evaluate_cold_start_items(
        model=model,
        all_df=all_df,
        test_df=test_df,
        min_interactions=5  # Define cold start items as those with <= 5 interactions
    )

    print("\n" + "="*80)
    print("STEP 7: GENERATING SAMPLE RECOMMENDATIONS")
    print("="*80)

    # Generate recommendations
    recommendations = generate_recommendations(
        model=model,
        user_encoder=user_encoder,
        item_encoder=item_encoder,
        all_df=all_df,
        test_df=test_df,
        n_recommendations=5
    )

    # Display sample recommendations
    print("\nSAMPLE RECOMMENDATIONS (ASINs only):")
    print("-" * 50)
    for i, (user, items) in enumerate(recommendations.items()):
        if i >= 3:  # Just show 3 users for brevity
            break
        print(f"\nUser ID: {user}")
        print(f"Top 5 recommended items (ASINs):")
        for item in items:
            print(f"{item}")

    print("\n" + "="*80)
    print("STEP 8: PLOTTING TRAINING CURVES")
    print("="*80)

    # Plot training curves
    plt.figure(figsize=(15, 5))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    # AUC curve
    plt.subplot(1, 2, 2)
    plt.plot(val_aucs, label='Validation AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_DIR, "sbert_training_curves.png"))
    plt.close()
    print(f"Training curves saved to {os.path.join(MODEL_DIR, 'sbert_training_curves.png')}")

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Model saved to: {best_model_path}")
    print(f"Encoders saved to: {os.path.join(MODEL_DIR, 'sbert_encoders.pkl')}")
    print(f"SBERT embeddings saved to: {EMBEDDING_CACHE}")
    print(f"Training curves saved to: {os.path.join(MODEL_DIR, 'sbert_training_curves.png')}")
    print("\nYou can now use the inference notebook to generate recommendations for specific users.")


## Running the Model

In [48]:
if __name__ == "__main__":
    main()

    # Memory cleanup
    del train_dataset, val_dataset, test_dataset
    del train_loader, val_loader, test_loader
    gc.collect()
    torch.cuda.empty_cache()


STEP 1: SAVING ENCODERS FOR INFERENCE
Encoders saved to /content/drive/MyDrive/bt4222data/sbert_models/sbert_encoders.pkl

STEP 2: INITIALIZING MODEL
Number of users: 192395
Number of items: 61709
BERT embedding dimension: 384
Embedding dimension: 64
MLP dimensions: [128, 64, 32]
Learning rate: 0.001
Weight decay: 1e-05
Maximum epochs: 20
Early stopping patience: 3

STEP 3: TRAINING THE MODEL

Starting training...


Epoch 1/20
Train Loss: 0.3597
Val Loss: 0.3240, Val Accuracy: 0.8867
Val Precision: 0.8867, Val Recall: 1.0000
Val F1 Score: 0.9399, Val AUC: 0.7125
Saved best model with Val AUC: 0.7125
--------------------------------------------------


Epoch 2/20
Train Loss: 0.3022
Val Loss: 0.3178, Val Accuracy: 0.8876
Val Precision: 0.8906, Val Recall: 0.9955
Val F1 Score: 0.9401, Val AUC: 0.7305
Saved best model with Val AUC: 0.7305
--------------------------------------------------


Epoch 3/20
Train Loss: 0.2704
Val Loss: 0.3246, Val Accuracy: 0.8863
Val Precision: 0.8951, Val Recall: 0.9875
Val F1 Score: 0.9390, Val AUC: 0.7269
Early stopping patience: 1/3
--------------------------------------------------


Epoch 4/20
Train Loss: 0.2428
Val Loss: 0.3419, Val Accuracy: 0.8841
Val Precision: 0.8968, Val Recall: 0.9823
Val F1 Score: 0.9376, Val AUC: 0.7169
Early stopping patience: 2/3
--------------------------------------------------


Epoch 5/20
Train Loss: 0.2160
Val Loss: 0.3720, Val Accuracy: 0.8757
Val Precision: 0.9007, Val Recall: 0.9663
Val F1 Score: 0.9324, Val AUC: 0.7071
Early stopping patience: 3/3
Early stopping triggered!

STEP 4: LOADING BEST MODEL FOR EVALUATION
Loaded best model with validation AUC: 0.7305

STEP 5: EVALUATING ON TEST SET


Test Results:
--------------------------------------------------
Test Loss: 0.3168
Test Accuracy: 0.8879
Test Precision: 0.8907
Test Recall: 0.9956
Test F1 Score: 0.9403
Test AUC: 0.7329

STEP 6: EVALUATING ON COLD START ITEMS
Cold start items are defined as items with <= 5 interactions in the dataset

Evaluating model on cold start items...
Cold start items are defined as items with <= 5 interactions
Found 4857 cold start items in the test set.
Number of interactions with cold start items: 6583


Cold Start Items Metrics:
--------------------------------------------------
Loss: 0.3744
Accuracy: 0.8615
Precision: 0.8649
Recall: 0.9944
F1 Score: 0.9251
AUC: 0.6972

STEP 7: GENERATING SAMPLE RECOMMENDATIONS


SAMPLE RECOMMENDATIONS (ASINs only):
--------------------------------------------------

User ID: A15PUGYZ6C2IPU
Top 5 recommended items (ASINs):
B000OGX5AM
B004EBUXHQ
B003FVVMS0
B002ZIMEMW
B007NZGPAY

User ID: AKXT3E60ZZQCY
Top 5 recommended items (ASINs):
B000OGX5AM
B003FVVMS0
B000OG6I6A
B0018QNYSK
B002ZIMEMW

User ID: ABHM4V3BH2C2T
Top 5 recommended items (ASINs):
B000OGX5AM
B003FVVMS0
B004EBUXHQ
B000OG6I6A
B002ZIMEMW

STEP 8: PLOTTING TRAINING CURVES
Training curves saved to /content/drive/MyDrive/bt4222data/sbert_models/sbert_training_curves.png

TRAINING COMPLETE
Model saved to: /content/drive/MyDrive/bt4222data/sbert_models/neumf_sbert_model.pt
Encoders saved to: /content/drive/MyDrive/bt4222data/sbert_models/sbert_encoders.pkl
SBERT embeddings saved to: /content/drive/MyDrive/bt4222data/sbert_models/embeddings/item_sbert_embeddings.npy
Training curves saved to: /content/drive/MyDrive/bt4222data/sbert_models/sbert_training_curves.png

You can now use the inference notebook to g