In [None]:
import pandas as pd
import numpy as np
import copy
import pickle
from tqdm.notebook import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import AutoModel
from transformers import get_cosine_schedule_with_warmup
import os

import warnings
warnings.filterwarnings('ignore')

import builtins
from datetime import datetime


def get_time():
    """Return the current wall-clock time formatted as ``[HH:MM:SS]``."""
    return f"[{datetime.now():%H:%M:%S}]"


# Re-running this cell must not wrap an already-wrapped print. If it did,
# `original_print` would point at the wrapper itself and every call to print
# would recurse forever. The wrapper therefore remembers the function it
# wraps, and we always unwrap back to that one here.
original_print = getattr(builtins.print, "_wrapped_print", builtins.print)


def _timestamped_print(*args, **kwargs):
    """Forward to the real print, prefixing every message with the time."""
    original_print(get_time(), "   ", *args, **kwargs)


_timestamped_print._wrapped_print = original_print
builtins.print = _timestamped_print

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

This notebook trains the embedder models that are used before the boosting model in the pipeline.

In [None]:
def seed_everything(seed=42, torch_stuff=True):
    """
    Make runs repeatable by seeding every random number generator in use.

    Args:
        seed (int): Seed value shared by all generators.
        torch_stuff (bool): When True, also seed PyTorch (CPU and, when
            available, CUDA) and switch it to deterministic algorithms.
    """
    # Environment settings: Python hash seed and the cuBLAS workspace config.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })

    # Python and NumPy generators are always seeded.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    if not torch_stuff:
        return

    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


### Load data

In [None]:
# Validation split: tokenized users and reviews plus the ground-truth matches.
val_users = pd.read_parquet('val_users_tokens.parquet')
val_reviews = pd.read_parquet('val_reviews_tokens.parquet')
val_matches = pd.read_csv('val_matches.csv')

# Indexed views for the per-item lookups done during evaluation:
# users by user_id, reviews by (review_id, accommodation_id).
val_users_indexed = val_users.set_index(keys='user_id')
val_reviews_indexed = val_reviews.set_index(keys=['review_id', 'accommodation_id'])

# Training split.
train_reviews = pd.read_parquet('train_reviews_tokens.parquet')
train_matches = pd.read_csv('train_matches.csv')

### Functions and classes used in training

In [None]:
def pad_sequence(sequences, batch_first=False, padding_value=0):
    """
    Pad a list of variable-length 1-D tensors with padding_value and stack them.

    Args:
        sequences (list[torch.Tensor]): Non-empty list of 1-D tensors.
        batch_first (bool): If True the result is (batch, max_len), otherwise
            it is transposed to (max_len, batch).
        padding_value: Value written into the padded positions.
    Returns:
        torch.Tensor: Stacked tensor padded to the longest sequence.
    Raises:
        ValueError: If sequences is empty.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequence expects at least one sequence")

    max_len = max(len(seq) for seq in sequences)

    # new_full allocates the padding with the same dtype AND device as `seq`,
    # so the concatenation also works for tensors that are not on the CPU.
    padded_sequences = [
        torch.cat((seq, seq.new_full((max_len - len(seq),), padding_value)))
        for seq in sequences
    ]

    stacked = torch.stack(padded_sequences)
    return stacked if batch_first else stacked.transpose(0, 1)
    
def collate_fn(batch):
    """
    Merge a list of dataset items into one padded batch for the DataLoader.

    Each of the four token fields is gathered across the items and padded
    with 0, batch dimension first.
    """
    field_names = (
        'user_input_ids',
        'user_attention_mask',
        'review_input_ids',
        'review_attention_mask',
    )
    return {
        name: pad_sequence(
            [sample[name] for sample in batch],
            batch_first=True,
            padding_value=0,
        )
        for name in field_names
    }

class TrainDataset(Dataset):
    """Dataset of positive (user, review) pairs served as token tensors."""

    def __init__(self, users_dict, reviews_dict, positive_pairs, group_to_indices):
        """
        Args:
            users_dict (dict): Maps user_id to a tuple (input_ids, attention_mask).
            reviews_dict (dict): Maps review_id to a tuple (input_ids, attention_mask).
            positive_pairs (list): Tuples (user_id, review_id), one per sample.
            group_to_indices (dict): Maps group_id to a list of sample indices.
        """
        self.users_dict = users_dict
        self.reviews_dict = reviews_dict
        self.positive_pairs = positive_pairs
        self.group_to_indices = group_to_indices
        self.groups = list(group_to_indices.keys())

    def __len__(self):
        """Number of positive pairs."""
        return len(self.positive_pairs)

    def __getitem__(self, idx):
        """Look up the tokens of pair `idx` and return them as tensors."""
        user_key, review_key = self.positive_pairs[idx]
        user_ids, user_mask = self.users_dict[user_key]
        review_ids, review_mask = self.reviews_dict[review_key]

        return {
            'user_input_ids': torch.tensor(user_ids),
            'user_attention_mask': torch.tensor(user_mask),
            'review_input_ids': torch.tensor(review_ids),
            'review_attention_mask': torch.tensor(review_mask),
        }

class GroupBatchSampler(Sampler):
    """Batch sampler whose batches never mix indices from different groups."""

    def __init__(self, group_to_indices, batch_size, drop_last=False):
        """
        Args:
            group_to_indices (dict): Dictionary with group_id as key and list of indices as value.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If True, a group that is larger than batch_size
                drops its trailing partial batch. Groups with at most
                batch_size indices are always emitted as a single batch.
        """
        self.group_to_indices = group_to_indices
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = list(group_to_indices.keys())

    def __iter__(self):
        """Yield batches (lists of indices) in random order."""
        all_batches = []

        for group_id in self.groups:
            indices = self.group_to_indices[group_id]

            # An empty group would produce an empty batch, which the collate
            # function cannot pad; skip it.
            if len(indices) == 0:
                continue

            if len(indices) <= self.batch_size:
                # Small group: one batch holding the whole group.
                all_batches.append(list(indices))
                continue

            # Large group: shuffle a copy, then slice it into batches.
            shuffled = np.array(indices)
            np.random.shuffle(shuffled)

            for start in range(0, len(shuffled) - self.batch_size + 1, self.batch_size):
                all_batches.append(shuffled[start:start + self.batch_size].tolist())

            leftover = len(shuffled) % self.batch_size
            if leftover > 0 and not self.drop_last:
                all_batches.append(shuffled[-leftover:].tolist())

        # Mix the batches of all groups together.
        np.random.shuffle(all_batches)
        return iter(all_batches)

    def __len__(self):
        """Number of batches one pass of __iter__ yields."""
        total_batches = 0
        for indices in self.group_to_indices.values():
            group_size = len(indices)
            if group_size == 0:
                continue  # mirrors __iter__: empty groups yield nothing
            if group_size <= self.batch_size:
                total_batches += 1
                continue
            total_batches += group_size // self.batch_size
            if not self.drop_last and group_size % self.batch_size > 0:
                total_batches += 1
        return total_batches

class ContrastiveLoss(nn.Module):
    """
    In-batch contrastive loss over cosine similarities.

    Row i of the user embeddings is the positive for row i of the review
    embeddings; every other review in the batch acts as a negative.
    """

    def __init__(self, margin=0.5, temperature=0.07):
        """
        Args:
            margin (float): Stored for interface compatibility; not used by forward.
            temperature (float): Divisor applied to the cosine similarities
                before the softmax inside the cross-entropy.
        """
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, user_embeddings, review_embeddings):
        """
        Args:
            user_embeddings (torch.Tensor): 2-D tensor, one row per user.
            review_embeddings (torch.Tensor): 2-D tensor, one row per review,
                aligned row-by-row with user_embeddings.
        Returns:
            torch.Tensor: Scalar loss.
        """
        user_embeddings = F.normalize(user_embeddings, p=2, dim=1)
        review_embeddings = F.normalize(review_embeddings, p=2, dim=1)

        # Temperature-scaled cosine similarities are passed to cross_entropy
        # as raw logits. They must NOT be squashed with a sigmoid first:
        # cross_entropy applies its own log-softmax, and logits confined to
        # (0, 1) can never separate the positive from the negatives.
        logits = torch.mm(user_embeddings, review_embeddings.t()) / self.temperature

        # The matching review of user i sits on the diagonal.
        labels = torch.arange(logits.size(0), device=logits.device)

        return F.cross_entropy(logits, labels)

class TwoTowersNetwork(nn.Module):
    """Two separate encoders: `bert1` for the user context, `bert2` for reviews."""

    def __init__(self, model_id):
        """
        Args:
            model_id (str): Identifier passed to AutoModel.from_pretrained
                for both towers (each tower gets its own copy).
        """
        super().__init__()
        self.bert1 = AutoModel.from_pretrained(model_id)
        self.bert2 = AutoModel.from_pretrained(model_id)

    def mean_pooling(self, model_output, attention_mask):
        """Average the first element of model_output over unmasked positions."""
        hidden_states = model_output[0]
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = torch.sum(hidden_states * weights, 1)
        # Clamp so a fully masked row divides by a tiny number, not by zero.
        counts = torch.clamp(weights.sum(1), min=1e-9)
        return summed / counts

    def forward(self, context_ids, context_mask, review_ids, review_mask):
        """Encode both inputs and return (context_embed, review_embed)."""
        context_embed = self.mean_pooling(
            self.bert1(context_ids, attention_mask=context_mask),
            context_mask,
        )
        review_embed = self.mean_pooling(
            self.bert2(review_ids, attention_mask=review_mask),
            review_mask,
        )
        return context_embed, review_embed

def get_scheduler(optimizer, num_training_steps, warmup_ratio=0.02, num_cycles=2):
    """
    Get a cosine learning-rate scheduler with a linear warmup phase.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer.
        num_training_steps (int): Number of training steps.
        warmup_ratio (float): Fraction of num_training_steps spent warming up.
        num_cycles (float): Number of cosine waves over the post-warmup steps,
            forwarded to get_cosine_schedule_with_warmup. Defaults to 2, the
            value this function previously hard-coded.
    Returns:
        scheduler: Scheduler.
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    # NOTE(review): with a whole number of cycles the cosine curve finishes
    # back at its peak, so the learning rate ends training at its maximum
    # rather than near zero. Confirm that is intended; the library default
    # of 0.5 decays to zero instead.
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )

def get_val_item(val_users_indexed, val_reviews_indexed, val_matches, idx, max_len=512):
    """
    Get a validation item: one user and every candidate review of the
    accommodation that user was matched to.

    Args:
        val_users_indexed (pd.DataFrame): Validation users indexed by user_id;
            the last two columns are read as input_ids and attention_mask.
        val_reviews_indexed (pd.DataFrame): Validation reviews indexed by
            (review_id, accommodation_id); the last two columns are read as
            input_ids and attention_mask.
        val_matches (pd.DataFrame): Validation matches with the columns
            user_id, accommodation_id and review_id.
        idx (int): Index label of the row of val_matches to use.
        max_len (int): Length every review sequence is padded to.
    Returns:
        dict: Dictionary with user_id, accommodation_id, review_id, user_tokens,
            review_tokens and review_ids.
    Raises:
        ValueError: If a review sequence is longer than max_len.
    """
    user_id = val_matches.loc[idx, 'user_id']
    accommodation_id = val_matches.loc[idx, 'accommodation_id']
    review_id = val_matches.loc[idx, 'review_id']

    # Look the user row up once and read both token columns from it.
    user_row = val_users_indexed.loc[user_id]
    user_tokens = (np.array(user_row.iloc[-2]), np.array(user_row.iloc[-1]))

    def _pad_to_max_len(seq, pad_value):
        if len(seq) > max_len:
            raise ValueError(
                f"sequence of length {len(seq)} exceeds max_len={max_len}"
            )
        return np.pad(seq, (0, max_len - len(seq)), mode='constant', constant_values=pad_value)

    # Slice the MultiIndex once; the same slice feeds tokens, masks and ids.
    accommodation_reviews = val_reviews_indexed.loc[(slice(None), accommodation_id), :]
    review_values = accommodation_reviews.values

    # NOTE(review): input ids are padded with 1 here while the training
    # collate function pads with 0. Padded positions carry attention_mask 0,
    # but confirm which id the tokenizer actually uses for padding.
    padded_input_ids = np.stack([_pad_to_max_len(x, 1) for x in review_values[:, -2]])
    padded_attention_mask = np.stack([_pad_to_max_len(x, 0) for x in review_values[:, -1]])

    review_tokens = (padded_input_ids, padded_attention_mask)
    review_ids = accommodation_reviews.index.get_level_values('review_id').tolist()

    return {
        'user_id': user_id,
        'accommodation_id': accommodation_id,
        'review_id': review_id,
        'user_tokens': user_tokens,
        'review_tokens': review_tokens,
        'review_ids': review_ids
    }

def evaluate(model, val_users_indexed, val_reviews_indexed, val_matches, device, criterion, batch_size=64, num_samples=3000):
    """
    Evaluate the model on a random sample of validation matches.

    For each sampled match, every review of the matched accommodation is
    scored against the user by the dot product of the two embeddings, and the
    true review is looked up among the 10 best-scoring reviews. The reported
    MRR is therefore MRR@10: a true review ranked below 10 contributes 0.

    Args:
        model: Model.
        val_users_indexed (pd.DataFrame): Indexed validation users.
        val_reviews_indexed (pd.DataFrame): Indexed validation reviews.
        val_matches (pd.DataFrame): Validation matches.
        device: Device.
        criterion: Criterion.
        batch_size (int): Number of reviews scored per forward pass.
        num_samples (int): Number of samples to evaluate; capped at
            len(val_matches).
    Returns:
        dict: Dictionary with MRR, Hit@10 and loss.
    Raises:
        ValueError: If there is nothing to sample.
    """
    top_k = 10

    model.eval()

    # Sampling without replacement cannot draw more rows than exist.
    num_samples = min(num_samples, len(val_matches))
    if num_samples <= 0:
        raise ValueError("evaluate() needs at least one validation match to sample")

    mrr_scores = []
    hit_rates = []
    total_loss = 0.0
    indices = np.random.choice(len(val_matches), num_samples, replace=False)

    with torch.no_grad():
        for ii, idx in enumerate(indices):
            item = get_val_item(val_users_indexed, val_reviews_indexed, val_matches, idx)

            user_input_ids = torch.tensor(item['user_tokens'][0], dtype=torch.long).unsqueeze(0).to(device)
            user_attention_mask = torch.tensor(item['user_tokens'][1], dtype=torch.long).unsqueeze(0).to(device)

            review_input_ids = torch.tensor(item['review_tokens'][0], dtype=torch.long)
            review_attention_mask = torch.tensor(item['review_tokens'][1], dtype=torch.long)

            true_review_id = item['review_id']
            review_ids = item['review_ids']

            all_scores = []
            item_loss = 0.0

            for j in range(0, len(review_input_ids), batch_size):
                batch_review_input_ids = review_input_ids[j:j+batch_size].to(device)
                batch_review_attention_mask = review_attention_mask[j:j+batch_size].to(device)

                # The model takes paired inputs, so the single user is
                # repeated once per review in this chunk.
                chunk_size = len(batch_review_input_ids)
                batch_user_input_ids = user_input_ids.repeat(chunk_size, 1)
                batch_user_attention_mask = user_attention_mask.repeat(chunk_size, 1)

                user_emb, review_emb = model(
                    batch_user_input_ids,
                    batch_user_attention_mask,
                    batch_review_input_ids,
                    batch_review_attention_mask
                )

                all_scores.append(torch.sum(user_emb * review_emb, dim=1))

                # NOTE(review): every row in this chunk is the same user, so
                # a criterion that expects aligned (user, review) pairs is
                # not measuring retrieval quality here. Treat the validation
                # loss as a rough diagnostic only.
                item_loss += criterion(user_emb, review_emb).item()

            all_scores = torch.cat(all_scores).cpu().numpy()

            # Keep only the top-k reviews by score (ties keep input order).
            ranked = sorted(zip(all_scores, review_ids), key=lambda pair: pair[0], reverse=True)
            top_review_ids = [rid for _, rid in ranked][:top_k]

            if true_review_id in top_review_ids:
                rank = top_review_ids.index(true_review_id) + 1
                mrr_scores.append(1.0 / rank)
                hit_rates.append(1)
            else:
                mrr_scores.append(0.0)
                hit_rates.append(0)

            total_loss += item_loss

            # Progress is measured by the loop counter `ii`, not by the
            # randomly drawn dataset index `idx`.
            processed = ii + 1
            if processed % 250 == 0:
                print(f"Processed {processed}/{num_samples} samples, "
                      f"MRR: {np.mean(mrr_scores):.4f}, "
                      f"Hit@10: {np.mean(hit_rates):.4f}, "
                      f"Loss: {total_loss/processed:.4f}", flush=True)

    metrics = {
        'mrr': np.mean(mrr_scores),
        'hit_rate': np.mean(hit_rates),
        'loss': total_loss/num_samples
    }

    return metrics

def _next_batch(loader, iterator):
    """Return (batch, iterator), restarting the loader once it is exhausted."""
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator


def train_model(kmeans_loader, accomodation_loader, val_users_indexed, val_reviews_indexed, val_matches, 
                model, num_epochs, learning_rate, device, max_grad_norm=1.0, batch_size=32):    
    """
    Train a model, alternating between batches from the two loaders.

    Even steps take a batch from kmeans_loader, odd steps from
    accomodation_loader; an exhausted loader is restarted. After every epoch
    the model is validated, and whenever the validation MRR improves the
    weights are copied into the history and saved to 'model_<mrr>.pt'.

    Args:
        kmeans_loader (DataLoader): DataLoader based on kmeans groups.
        accomodation_loader (DataLoader): DataLoader based on accomodation groups.
        val_users_indexed (pd.DataFrame): Indexed validation users.
        val_reviews_indexed (pd.DataFrame): Indexed validation reviews.
        val_matches (pd.DataFrame): Validation matches.
        model: Model.
        num_epochs (int): Number of epochs.
        learning_rate (float): Learning rate.
        device: Device.
        max_grad_norm (float): Maximum gradient norm.
        batch_size (int): Batch size used to derive the number of steps per
            epoch and passed on to validation.
    Returns:
        dict: Training history.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = ContrastiveLoss()
    total_samples = len(kmeans_loader.dataset) + len(accomodation_loader.dataset)

    num_training_steps = num_epochs * total_samples // batch_size
    scheduler = get_scheduler(optimizer, num_training_steps)
    
    history = {
        'train_loss': [],
        'val_mrr': [],
        'val_hit_rate': [],
        'best_mrr': 0,
        'best_epoch': 0,
        'best_model': None,
        'learning_rates': []
    }
    
    best_mrr = 0
    
    # The iterators live across epochs so a loader resumes where it stopped.
    kmeans_iter = iter(kmeans_loader)
    accomodation_iter = iter(accomodation_loader)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        num_batches = total_samples // batch_size
        
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rates'].append(current_lr)
        
        for batch_idx in range(num_batches):
            # Alternate the batch source: kmeans on even steps, accomodation on odd.
            if batch_idx % 2 == 0:
                batch, kmeans_iter = _next_batch(kmeans_loader, kmeans_iter)
            else:
                batch, accomodation_iter = _next_batch(accomodation_loader, accomodation_iter)

            user_input_ids = batch['user_input_ids'].to(device)
            user_attention_mask = batch['user_attention_mask'].to(device)
            review_input_ids = batch['review_input_ids'].to(device)
            review_attention_mask = batch['review_attention_mask'].to(device)

            optimizer.zero_grad()
            
            user_embeddings, review_embeddings = model(
                user_input_ids,
                user_attention_mask,
                review_input_ids,
                review_attention_mask
            )
            
            loss = criterion(user_embeddings, review_embeddings)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            
            if (batch_idx + 1) % 1500 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                      f'Step [{(batch_idx+1)}/{num_batches}], '
                      f'LR: {current_lr:.6f}, '
                      f'Running Loss: {epoch_loss/(batch_idx+1):.4f}', flush=True)

        # max(..., 1) avoids dividing by zero when fewer than batch_size
        # samples are available and no step was run.
        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        history['train_loss'].append(avg_epoch_loss)
        
        print(f"\nValidating epoch {epoch+1}...", flush=True)
        val_metrics = evaluate(
            model=model,
            val_users_indexed=val_users_indexed,
            val_reviews_indexed=val_reviews_indexed,
            val_matches=val_matches,
            device=device,
            criterion=criterion,
            batch_size=batch_size
        )
        
        history['val_mrr'].append(val_metrics['mrr'])
        history['val_hit_rate'].append(val_metrics['hit_rate'])
        
        print(f"Epoch {epoch+1} Summary:", flush=True)
        print(f"Train Loss: {avg_epoch_loss:.4f}", flush=True)
        print(f"Val MRR: {val_metrics['mrr']:.4f}", flush=True)
        print(f"Val Hit@10: {val_metrics['hit_rate']:.4f}", flush=True)
        print(f"Val Loss: {val_metrics['loss']:.4f}\n", flush=True)
        
        if val_metrics['mrr'] > best_mrr:
            best_mrr = val_metrics['mrr']
            history['best_mrr'] = best_mrr
            history['best_epoch'] = epoch
            # Deep copy: state_dict() returns references to the live weights.
            history['best_model'] = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), f'model_{best_mrr:.4f}.pt')
            print(f"New best MRR: {best_mrr:.4f}\n", flush=True)
        
        print("-" * 60 + "\n", flush=True)
    
    print(f"Training completed. Best MRR: {best_mrr:.4f} at epoch {history['best_epoch']+1}", flush=True)
    
    return history


### Load objects and create dataloaders

In [None]:
def _load_pickle(path):
    """Read one pickled object from `path`."""
    # pickle can execute arbitrary code on load: only open trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


users_dict = _load_pickle('users_dict.pkl')
reviews_dict = _load_pickle('reviews_dict.pkl')
positive_pairs = _load_pickle('positive_pairs.pkl')
kmeans_groups = _load_pickle('kmeans_groups.pkl')

# Loader whose batches are drawn from within one k-means group.
kmeans_train_dataset = TrainDataset(users_dict, reviews_dict, positive_pairs, kmeans_groups)
kmeans_sampler = GroupBatchSampler(kmeans_groups, batch_size=32, drop_last=True)
kmeans_loader = DataLoader(
    kmeans_train_dataset,
    batch_sampler=kmeans_sampler,
    collate_fn=collate_fn,
)

# Loader whose batches are drawn from within one accommodation.
# The groups hold row labels of train_matches.
# NOTE(review): these labels are used as positions into positive_pairs;
# confirm the two are aligned.
rows_per_accommodation = train_matches.groupby('accommodation_id').apply(lambda x: x.index.tolist())
accomodation_groups = rows_per_accommodation.to_dict()
accomodation_train_dataset = TrainDataset(users_dict, reviews_dict, positive_pairs, accomodation_groups)
accomodation_sampler = GroupBatchSampler(accomodation_groups, batch_size=32, drop_last=True)
accomodation_loader = DataLoader(
    accomodation_train_dataset,
    batch_sampler=accomodation_sampler,
    collate_fn=collate_fn,
)

### Set training parameters and train model

In [None]:
# Training hyper-parameters.
epochs = 3
learning_rate = 2e-5
batch_size = 32

# Pretrained encoder loaded into both towers.
model_id = "sentence-transformers/all-MiniLM-L12-v2"

seed_everything(42)
model = TwoTowersNetwork(model_id)

# train_model takes the two loaders as `kmeans_loader` and
# `accomodation_loader`; it has no `train_loader` parameter.
history = train_model(
    kmeans_loader=kmeans_loader,
    accomodation_loader=accomodation_loader,
    val_users_indexed=val_users_indexed,
    val_reviews_indexed=val_reviews_indexed,
    val_matches=val_matches,
    model=model,
    num_epochs=epochs,
    learning_rate=learning_rate,
    device=device,
    batch_size=batch_size
)