<a href="https://colab.research.google.com/github/leonardp315/loss-functions-analyses/blob/main/loss-functions-analyses.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers sentence-transformers datasets pandas numpy matplotlib seaborn scipy tqdm scikit-learn

In [1]:
"""
Comparative Evaluation of Sentence-Transformer Models with Different Loss Functions
for Semantic Similarity and Paraphrase Tasks

This script performs a systematic evaluation of different Sentence-Transformer models
combined with various loss functions on textual similarity (STS-B) and paraphrase
detection (MRPC) datasets.
"""

import torch
import torch.nn.functional as F
import random
import numpy as np
import pandas as pd
import time
import os
import json
from datetime import datetime
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from scipy.stats import spearmanr, pearsonr
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Output locations for every artefact the experiment writes
RESULTS_DIR = Path("results")
FIGURES_DIR = RESULTS_DIR.joinpath("figures")
MODELS_DIR = RESULTS_DIR.joinpath("models")

# Create the folders up front (idempotently) so later save calls never hit a missing path.
for directory in (RESULTS_DIR, FIGURES_DIR, MODELS_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# Configuration for reproducibility
def set_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed_value: Integer seed applied to every generator.

    Returns:
        The seed that was applied, so callers can record it.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    return seed_value

SEED = set_seed(42)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Configuration: Seed={SEED}, Device={DEVICE}")

# Experiment settings
SAMPLE_SIZE = None  # None -> every example is used; an int -> random sample of that many rows
NUM_EPOCHS = 3      # passes over the training data per run
BATCH_SIZE = 16     # examples per optimisation step
SAVE_MODELS = True  # write each fine-tuned model under MODELS_DIR

# Load and prepare datasets
class DatasetLoader:
    """Manager for loading and preparing textual similarity datasets.

    Both supported GLUE tasks are returned as pandas DataFrames with
    ``sentence1``/``sentence2`` text columns, a ``label`` column and a binary
    ``label_bin`` column.
    """

    def __init__(self, cache_dir=None):
        """
        Args:
            cache_dir: Optional cache directory forwarded to ``datasets.load_dataset``.
        """
        self.cache_dir = cache_dir

    def load_dataset(self, name, split='train', sample_size=None, random_state=42):
        """
        Loads and prepares popular textual similarity datasets.

        Args:
            name: Dataset name ('stsb' or 'mrpc', case-insensitive)
            split: Dataset partition ('train', 'validation', 'test')
            sample_size: Number of examples for sampling (None to use all)
            random_state: Seed for reproducible sampling

        Returns:
            DataFrame with processed data. Rows without a gold label are removed.

        Raises:
            ValueError: If ``name`` is not a supported dataset.
        """
        key = name.lower()
        if key == 'stsb':
            return self._load_stsb(split, sample_size, random_state)
        if key == 'mrpc':
            return self._load_mrpc(split, sample_size, random_state)
        raise ValueError(f"Unsupported dataset: {name}. Use 'stsb' or 'mrpc'")

    @staticmethod
    def _prepare_stsb(df):
        """Return a copy of raw STS-B rows with labels scaled to [0, 1] plus a binary label.

        Rows whose label is missing or negative are dropped: GLUE test splits carry
        no gold scores and store the placeholder -1, which would otherwise be
        clipped to a spurious similarity of 0.
        """
        df = df.copy()
        df['label'] = pd.to_numeric(df['label'], errors='coerce')
        df = df[df['label'].notna() & (df['label'] >= 0)].copy()

        # 0-5 gold scores -> [0, 1]; the clip guards against out-of-range values.
        df['label'] = (df['label'] / 5.0).clip(lower=0.0, upper=1.0)

        # Binary label for classification-style losses/metrics
        df['label_bin'] = (df['label'] > 0.5).astype(int)
        return df

    @staticmethod
    def _prepare_mrpc(df):
        """Return a copy of raw MRPC rows with integer labels, dropping unlabelled (-1) rows."""
        df = df.copy()
        df['label'] = df['label'].astype(int)
        df = df[df['label'] >= 0].copy()
        df['label_bin'] = df['label']
        return df

    @staticmethod
    def _sample(df, sample_size, random_state):
        """Return ``df`` unchanged when ``sample_size`` is None, else a seeded random sample."""
        if sample_size is None:
            return df
        sample_size = min(sample_size, len(df))
        sampled = df.sample(n=sample_size, random_state=random_state)
        print(f"- Sample used: {sample_size} examples")
        return sampled

    def _load_stsb(self, split, sample_size, random_state):
        """Loads the STS-B (Semantic Textual Similarity Benchmark) dataset."""
        ds = load_dataset('glue', 'stsb', cache_dir=self.cache_dir)[split]
        df = self._prepare_stsb(pd.DataFrame(ds))

        # Dataset statistics
        print(f"\n[STS-B - {split}] Statistics:")
        print(f"- Examples: {len(df)}")
        if df.empty:
            print(f"- Warning: split '{split}' has no labelled examples")
        print(f"- Similarity range: [{df['label'].min():.2f}, {df['label'].max():.2f}]")
        print(f"- Binary distribution: {df['label_bin'].value_counts().to_dict()}")

        return self._sample(df, sample_size, random_state)

    def _load_mrpc(self, split, sample_size, random_state):
        """Loads the MRPC (Microsoft Research Paraphrase Corpus) dataset."""
        ds = load_dataset('glue', 'mrpc', cache_dir=self.cache_dir)[split]
        df = self._prepare_mrpc(pd.DataFrame(ds))

        # Dataset statistics
        print(f"\n[MRPC - {split}] Statistics:")
        print(f"- Examples: {len(df)}")
        if df.empty:
            print(f"- Warning: split '{split}' has no labelled examples")
        print(f"- Distribution: {df['label'].value_counts().to_dict()}")

        return self._sample(df, sample_size, random_state)

    def visualize_dataset_distribution(self, df, dataset_name):
        """Save a label histogram (STS-B) or class-count bar chart (MRPC) of ``df``.

        Returns:
            Path of the PNG written under FIGURES_DIR.
        """
        fig = plt.figure(figsize=(10, 6))
        try:
            if dataset_name.lower() == 'stsb':
                sns.histplot(df['label'], bins=20, kde=True)
                plt.title('Similarity Distribution in STS-B')
                plt.xlabel('Normalized Similarity [0,1]')
            else:  # MRPC
                counts = df['label'].value_counts().sort_index()
                sns.barplot(x=counts.index, y=counts.values)
                plt.title('Class Distribution in MRPC')
                plt.xlabel('Class (0=Not Paraphrase, 1=Paraphrase)')
                # Label only the classes actually present; bars sit at positions 0..k-1,
                # so hard-coding two ticks mislabels a single-class sample.
                class_names = {0: 'Not Paraphrase', 1: 'Paraphrase'}
                plt.xticks(range(len(counts)), [class_names.get(c, str(c)) for c in counts.index])

            plt.ylabel('Count')
            plt.tight_layout()

            fig_path = FIGURES_DIR / f"{dataset_name}_distribution.png"
            plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)

        return fig_path

# Classes for triplet learning
class TripletGenerator:
    """Generator of triplets (anchor, positive, negative) for Triplet Loss.

    Anchors come from ``sentence1`` and positives from ``sentence2`` of the same
    row. Negatives are chosen, in priority order, as: ``fixed_negative`` (if
    truthy); a random sentence from a *different* row (if ``hard_negatives`` and
    the dataset has at least two rows); otherwise a constant placeholder sentence.

    NOTE(review): despite its name, ``hard_negatives`` samples random in-dataset
    negatives; no similarity-based mining is performed.
    """

    # Placeholder negative used when neither a fixed nor a dataset negative applies.
    DEFAULT_NEGATIVE = "This is a negative sentence for the triplet."

    def __init__(self, dataset, fixed_negative=None, hard_negatives=False):
        """
        Initializes the triplet generator.

        Args:
            dataset: DataFrame with 'sentence1' and 'sentence2' columns
            fixed_negative: Fixed negative sentence (optional; takes precedence)
            hard_negatives: If True, draws negatives from other rows of the dataset
        """
        self.dataset = dataset
        self.fixed_negative = fixed_negative
        self.hard_negatives = hard_negatives

    def generate_triplets(self, n_triplets=None):
        """
        Generates sentence triplets for training.

        Rows are drawn without replacement using the global ``random`` module, so
        at most ``len(dataset)`` triplets are produced.

        Args:
            n_triplets: Number of triplets to generate (default: dataset size)

        Returns:
            List of (anchor, positive, negative) string tuples
        """
        n_rows = len(self.dataset)
        if n_triplets is None:
            n_triplets = n_rows

        # Materialise the text columns once instead of calling .iloc per row.
        anchors = self.dataset['sentence1'].tolist()
        positives = self.dataset['sentence2'].tolist()

        triplets = []
        for i in random.sample(range(n_rows), k=min(n_triplets, n_rows)):
            negative = self._pick_negative(i, anchors, positives)
            triplets.append((anchors[i], positives[i], negative))

        return triplets

    def _pick_negative(self, i, anchors, positives):
        """Return the negative sentence for row ``i`` following the priority rules above."""
        if self.fixed_negative:
            return self.fixed_negative

        n_rows = len(anchors)
        if self.hard_negatives and n_rows > 1:
            # Uniform draw over every row except i in O(1): sample one of the
            # n_rows - 1 other slots and shift indices at or past i up by one.
            j = random.randrange(n_rows - 1)
            if j >= i:
                j += 1
            return random.choice((anchors[j], positives[j]))

        # No usable source of negatives (or a single-row dataset): constant placeholder.
        return self.DEFAULT_NEGATIVE

class TripletDataset(Dataset):
    """Map-style PyTorch dataset exposing (anchor, positive, negative) tuples as InputExamples."""

    def __init__(self, triplets):
        # Kept as given; each tuple is wrapped lazily on access.
        self.triplets = triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        anchor, positive, negative = self.triplets[idx]
        return InputExample(texts=[anchor, positive, negative])

# Custom loss functions
class TripletLoss(torch.nn.Module):
    """Hinge triplet loss on L2-normalised embeddings.

    ``sf`` holds (anchor, positive, negative) feature batches; ``lbl`` is accepted
    for interface compatibility but unused.
    loss = mean(max(0, ||a - p|| - ||a - n|| + margin))
    """

    def __init__(self, model, margin=1.0):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        anchor, positive, negative = unit[0], unit[1], unit[2]
        d_pos = torch.norm(anchor - positive, p=2, dim=1)
        d_neg = torch.norm(anchor - negative, p=2, dim=1)
        return F.relu(d_pos - d_neg + self.margin).mean()

# NOTE(review): the four classes below are placeholders. They inherit TripletLoss
# unchanged (`pass`), so they do NOT implement online, batch-hard, batch-semi-hard
# or batch-all mining; each computes exactly the same fixed-triplet hinge loss as
# TripletLoss. Any difference between their results comes from run-to-run
# randomness, not from the objective.
class OnlineTripletLoss(TripletLoss): pass
class BatchHardTripletLoss(TripletLoss): pass
class BatchSemiHardTripletLoss(TripletLoss): pass
class BatchAllTripletLoss(TripletLoss): pass

class MSELoss(torch.nn.Module):
    """Mean squared error between the raw (unnormalised) embeddings of each pair.

    NOTE(review): ``lbl`` is ignored, so every pair is pulled together regardless
    of its similarity label.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, sf, lbl):
        embeddings = []
        for features in sf:
            embeddings.append(self.model(features)['sentence_embedding'])
        return F.mse_loss(embeddings[0], embeddings[1])

class EuclideanLoss(torch.nn.Module):
    """Mean Euclidean distance between the raw (unnormalised) embeddings of each pair.

    Unlike NormalizedEuclideanLoss, no L2 normalisation is applied, so the
    magnitude of the embeddings contributes to the distance. (Normalising here
    made the two losses identical.) ``lbl`` is accepted for interface
    compatibility but unused: every pair is pulled together.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, sf, lbl):
        e = [self.model(f)['sentence_embedding'] for f in sf]
        return torch.mean(torch.norm(e[0] - e[1], p=2, dim=1))

class NormalizedEuclideanLoss(torch.nn.Module):
    """Mean Euclidean distance between L2-normalised embeddings of each pair.

    ``lbl`` is accepted for interface compatibility but unused.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        gap = unit[0] - unit[1]
        return gap.norm(p=2, dim=1).mean()

class AngularMarginLoss(torch.nn.Module):
    """Regress the angle between each pair's embeddings onto a label-dependent target.

    target_angle = margin * (1 - label): positives (label 1) are pulled to angle 0,
    negatives (label 0) to ``margin`` radians, and continuous labels interpolate
    linearly. loss = mean((theta - target_angle) ** 2).
    """

    def __init__(self, model, margin=0.5):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        cosine = torch.sum(e[0] * e[1], dim=1)
        # Clamp away from +-1, where acos has an infinite derivative.
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        # The target must be subtracted: with `theta + target` the loss is minimised
        # at theta == 0 for every label, i.e. negatives are pulled together too.
        target_angle = self.margin * (1.0 - lbl.float())
        return torch.mean((theta - target_angle) ** 2)

class CircleLoss(torch.nn.Module):
    """Pair-wise Circle loss (Sun et al., 2020) on cosine similarity.

    Each pair contributes its positive term with weight ``label`` and its
    negative term with weight ``1 - label``:
        positive: softplus(-gamma * alpha_p * (s - (1 - m)))
        negative: softplus( gamma * alpha_n * (s - m))
    with alpha_p = max(0, 1 + m - s) and alpha_n = max(0, s + m).
    """

    def __init__(self, model, m=0.25, gamma=256):
        super().__init__()
        self.model = model
        self.m = m
        self.gamma = gamma

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        sim = torch.sum(e[0] * e[1], dim=1)
        target = lbl.float()

        # Circle-loss weights act as constants (no gradient through them).
        alpha_p = torch.clamp_min(1 + self.m - sim.detach(), 0.0)
        alpha_n = torch.clamp_min(sim.detach() + self.m, 0.0)
        delta_p = 1 - self.m
        delta_n = self.m

        logits_p = (-self.gamma) * alpha_p * (sim - delta_p)
        logits_n = self.gamma * alpha_n * (sim - delta_n)

        # softplus(x) == log1p(exp(x)) but does not overflow for the large logits
        # that gamma=256 produces.
        loss = target * F.softplus(logits_p) + (1 - target) * F.softplus(logits_n)
        return loss.mean()

class SphereLoss(torch.nn.Module):
    """Cosine-distance loss on L2-normalised embeddings: mean(1 - cos(u, v)).

    ``lbl`` is accepted for interface compatibility but unused, so every pair is
    pulled together regardless of its label.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        cos = (unit[0] * unit[1]).sum(dim=1)
        return (1 - cos).mean()

class HistogramLoss(torch.nn.Module):
    """Differentiable histogram loss (Ustinova & Lempitsky, 2016) on pair cosine similarities.

    Similarities of positive pairs (label > 0.5) and negative pairs (label <= 0.5)
    are each soft-binned with a triangular kernel over ``num_bins`` nodes spanning
    [-1, 1]. The loss sum_r h_neg[r] * CDF_pos[r] estimates the probability that
    a negative pair is more similar than a positive pair, so minimising it
    separates the two distributions. (A hard ``torch.histc`` has no gradient.)
    Returns a zero loss when the batch lacks either positives or negatives.
    """

    def __init__(self, model, num_bins=10):
        super().__init__()
        if num_bins < 2:
            raise ValueError(f"num_bins must be >= 2, got {num_bins}")
        self.model = model
        self.num_bins = num_bins

    def _soft_histogram(self, values):
        """Normalised triangular-kernel histogram of ``values`` over evenly spaced nodes in [-1, 1]."""
        nodes = torch.linspace(-1.0, 1.0, self.num_bins, device=values.device, dtype=values.dtype)
        step = 2.0 / (self.num_bins - 1)
        weights = torch.clamp(1.0 - torch.abs(values.unsqueeze(1) - nodes.unsqueeze(0)) / step, min=0.0)
        return weights.sum(dim=0) / values.numel()

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        sim = torch.sum(e[0] * e[1], dim=1)

        # Same binarisation rule as label_bin: label > 0.5 counts as positive.
        is_pos = lbl > 0.5
        pos_sim, neg_sim = sim[is_pos], sim[~is_pos]
        if pos_sim.numel() == 0 or neg_sim.numel() == 0:
            # Keep the result attached to the graph so backward() still works.
            return sim.sum() * 0.0

        hist_pos = self._soft_histogram(pos_sim)
        hist_neg = self._soft_histogram(neg_sim)
        cdf_pos = torch.cumsum(hist_pos, dim=0)
        return torch.sum(hist_neg * cdf_pos)

class CentroidLoss(torch.nn.Module):
    """MSE between the mean first-sentence embedding of positive (label 1) and negative (label 0) pairs.

    Only the first sentence of each pair enters the centroids; pairs whose label
    is neither exactly 1 nor exactly 0 are ignored.
    NOTE(review): minimising this pulls the two centroids together, i.e. it makes
    positive and negative anchors harder to tell apart; confirm this is intended.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        anchors = unit[0]
        centroids = []
        for mask in ((lbl == 1).unsqueeze(1), (lbl == 0).unsqueeze(1)):
            # Epsilon avoids 0/0 when the batch has no pair of this label.
            centroids.append((anchors * mask).sum(0) / (mask.sum() + 1e-10))
        return F.mse_loss(centroids[0], centroids[1])

class HyperSphereLoss(torch.nn.Module):
    """Penalise the squared deviation of both raw embedding norms from ``radius``.

    ``lbl`` is accepted for interface compatibility but unused.
    """

    def __init__(self, model, radius=1.0):
        super().__init__()
        self.model = model
        self.radius = radius

    def forward(self, sf, lbl):
        raw = [self.model(feat)['sentence_embedding'] for feat in sf]
        dev_first = raw[0].norm(p=2, dim=1) - self.radius
        dev_second = raw[1].norm(p=2, dim=1) - self.radius
        return (dev_first ** 2 + dev_second ** 2).mean()

class ProbabilisticLoss(torch.nn.Module):
    """Binary cross-entropy on sigmoid(scale * cosine similarity) against the pair label.

    With the default ``scale=1.0`` the predicted probability is confined to
    [sigmoid(-1), sigmoid(1)] ~= [0.27, 0.73]; a larger ``scale`` (a temperature
    inverse) lets the model express confident predictions. The logits form of
    BCE is used for numerical stability.
    """

    def __init__(self, model, scale=1.0):
        super().__init__()
        self.model = model
        self.scale = scale

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        sim = torch.sum(e[0] * e[1], dim=1)
        return F.binary_cross_entropy_with_logits(self.scale * sim, lbl.float())

class LiftedStructuredLoss(torch.nn.Module):
    """Simplified lifted-structure loss over a batch of (sentence1, sentence2) pairs.

    With D[i, j] = ||a_i - b_j|| between L2-normalised embeddings:
        pos_term = log(1 + sum_i label_i * exp(D[i, i]))
        neg_term = log(1 + sum_{i,j} w_ij * exp(margin - D[i, j]))
    where w_ii = 1 - label_i for the labelled pairs and w_ij = 1 for every
    off-diagonal (i != j) in-batch pair, which is treated as a negative.
    Returns pos_term + neg_term.
    """

    def __init__(self, model, margin=1.0):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        dist = torch.cdist(e[0], e[1], p=2)  # dist[i, j] = ||a_i - b_j||
        target = lbl.float()

        # Only the diagonal holds the labelled pairs; weighting (rather than multiplying
        # inside exp) makes excluded entries contribute 0 instead of exp(0) = 1.
        pos_term = torch.log1p(torch.sum(target * torch.exp(torch.diagonal(dist))))

        # 1 - label on the diagonal, 1 everywhere off it.
        neg_weights = torch.ones_like(dist) - torch.diag(target)
        neg_term = torch.log1p(torch.sum(neg_weights * torch.exp(self.margin - dist)))
        return pos_term + neg_term

class GeneralPairLoss(torch.nn.Module):
    """Weighted squared pair loss on cosine similarity.

    Positive pairs are pushed towards cos = 1 via (1 - cos)^2 and negative pairs
    towards cos = 0 via cos^2. Each pair contributes to the positive term with
    weight ``label`` and to the negative term with weight ``1 - label``, and each
    term is the weighted mean over the batch. For 0/1 labels this equals the
    mean over the positive (resp. negative) pairs; continuous labels (e.g.
    normalised STS-B scores) are used instead of being dropped. A term with zero
    total weight contributes 0 instead of NaN.
    """

    def __init__(self, model, pos_weight=1.0, neg_weight=1.0):
        super().__init__()
        self.model = model
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight

    @staticmethod
    def _weighted_mean(values, weights):
        """Weighted mean of ``values``; 0 when all weights are zero."""
        return (values * weights).sum() / torch.clamp_min(weights.sum(), 1e-12)

    def forward(self, sf, lbl):
        e = [F.normalize(self.model(f)['sentence_embedding'], p=2, dim=1) for f in sf]
        sim = torch.sum(e[0] * e[1], dim=1)
        target = lbl.float()
        pos_loss = self.pos_weight * self._weighted_mean((1 - sim) ** 2, target)
        neg_loss = self.neg_weight * self._weighted_mean(sim ** 2, 1.0 - target)
        return pos_loss + neg_loss

class AngularLoss(torch.nn.Module):
    """Angle-based contrastive loss on L2-normalised embeddings.

    Positives are penalised by their angle theta; negatives by the hinge
    max(0, angle_bound - theta). Continuous labels blend the two terms linearly.
    """

    def __init__(self, model, angle_bound=1.0):
        super().__init__()
        self.model = model
        self.angle_bound = angle_bound

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        cos = (unit[0] * unit[1]).sum(dim=1)
        # Keep acos away from +-1, where its derivative is infinite.
        theta = torch.acos(cos.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
        y = lbl.float()
        pull = y * theta
        push = (1 - y) * (self.angle_bound - theta).clamp(min=0.0)
        return (pull + push).mean()

class MarginRankingLoss(torch.nn.Module):
    """Hinge on signed cosine similarity: mean(max(0, margin - y * cos)), y = 2*label - 1."""

    def __init__(self, model, margin=0.5):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sf, lbl):
        unit = [F.normalize(self.model(feat)['sentence_embedding'], p=2, dim=1) for feat in sf]
        cos_sim = (unit[0] * unit[1]).sum(dim=1)
        # Map labels {0, 1} to signs {-1, +1}.
        sign = lbl.float() * 2 - 1
        return (self.margin - sign * cos_sim).clamp(min=0.0).mean()

# Dictionary with loss functions
# Registry mapping experiment loss names to loss classes. main() iterates it in
# insertion order, and generate_examples() chooses the example format from the
# name ('Triplet' in the name -> triplets, 'Contrastive' -> binary labels).
# NOTE(review): 'InfoNCE' and 'MultiSimilarity' both resolve to
# MultipleNegativesRankingLoss, so those two rows measure the same objective.
# NOTE(review): 'NPairs' resolves to BatchAllTripletLoss but, lacking 'Triplet'
# in its name, receives sentence-pair examples; confirm this is intended.
loss_functions = dict(
    MSE=MSELoss,
    Cosine=losses.CosineSimilarityLoss,
    Contrastive=losses.ContrastiveLoss,
    InfoNCE=losses.MultipleNegativesRankingLoss,
    Euclidean=EuclideanLoss,
    NormaEuc=NormalizedEuclideanLoss,
    NPairs=losses.BatchAllTripletLoss,
    MultiSimilarity=losses.MultipleNegativesRankingLoss,
    AngularMargin=AngularMarginLoss,
    Sphere=SphereLoss,
    HyperSphere=HyperSphereLoss,
    Probabilistic=ProbabilisticLoss,
    LiftedStructured=LiftedStructuredLoss,
    GeneralPair=GeneralPairLoss,
    Angular=AngularLoss,
    MarginRanking=MarginRankingLoss,
    Triplet=TripletLoss,
    OnlineTriplet=OnlineTripletLoss,
    BatchHardTriplet=BatchHardTripletLoss,
    BatchSemiHardTriplet=BatchSemiHardTripletLoss,
    BatchAllTriplet=BatchAllTripletLoss,
)

def generate_examples(df, loss_name, fixed_negative=None):
    """
    Builds training examples in the format the named loss function expects.

    Args:
        df: DataFrame with 'sentence1', 'sentence2', 'label' and 'label_bin' columns
        loss_name: Key of the loss function to train with
        fixed_negative: Negative sentence for triplet losses. Because
            TripletGenerator gives a truthy fixed negative precedence, passing one
            means the dataset negatives requested below are never used.

    Returns:
        TripletDataset for losses whose name contains 'Triplet'; otherwise a list
        of two-sentence InputExamples labelled with 'label_bin' (Contrastive) or
        the raw 'label' column (every other loss).
    """
    if 'Triplet' in loss_name:
        generator = TripletGenerator(df, fixed_negative, hard_negatives=True)
        return TripletDataset(generator.generate_triplets())

    # Contrastive needs 0/1 targets; all other losses receive the raw label.
    label_column = 'label_bin' if loss_name == 'Contrastive' else 'label'
    return [
        InputExample(texts=[first, second], label=float(target))
        for first, second, target in zip(df['sentence1'], df['sentence2'], df[label_column])
    ]

# Evaluation functions
def evaluate_model(model, test_df, dataset_name, threshold=0.5):
    """
    Evaluates a model on a test dataset.

    The similarity of each pair is the cosine between the embeddings of
    'sentence1' and 'sentence2' (0.0 when either embedding is all zeros).

    Args:
        model: Trained SentenceTransformer model (anything exposing ``encode``)
        test_df: DataFrame with 'sentence1', 'sentence2' and 'label' columns
        dataset_name: Dataset name ('stsb' or 'mrpc')
        threshold: Cosine cut-off at or above which an MRPC pair is predicted
            as a paraphrase (default 0.5)

    Returns:
        Dictionary with 'mean_similarity' and 'std_similarity', plus
        'pearson'/'spearman' for STS-B or 'accuracy'/'f1'/'precision'/'recall'
        for MRPC.

    Raises:
        ValueError: If ``test_df`` has no rows.
    """
    if len(test_df) == 0:
        raise ValueError("test_df is empty; nothing to evaluate")

    # Prepare data
    sent1 = test_df['sentence1'].tolist()
    sent2 = test_df['sentence2'].tolist()
    labels = test_df['label'].tolist()

    # One encode call for both sides so they share batching.
    embeddings = np.asarray(
        model.encode(sent1 + sent2, batch_size=32, show_progress_bar=False),
        dtype=np.float64,
    )
    embeddings1 = embeddings[:len(sent1)]
    embeddings2 = embeddings[len(sent1):]

    # Vectorised row-wise cosine; zero-norm rows yield 0.0, as sklearn's cosine_similarity does.
    dots = np.einsum('ij,ij->i', embeddings1, embeddings2)
    denom = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)
    similarities = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0).tolist()

    # Basic metrics
    results = {
        'mean_similarity': float(np.mean(similarities)),
        'std_similarity': float(np.std(similarities)),
    }

    # Dataset-specific metrics
    if dataset_name.lower() == 'stsb':
        # Correlations are undefined when either side is constant.
        if len(set(labels)) > 1 and len(set(similarities)) > 1:
            results['pearson'] = float(pearsonr(labels, similarities)[0])
            results['spearman'] = float(spearmanr(labels, similarities)[0])
        else:
            results['pearson'] = float('nan')
            results['spearman'] = float('nan')

        # Example for debugging
        print("\n[STS-B] Evaluation example:")
        for i in range(min(3, len(labels))):
            print(f"  Label: {labels[i]:.2f} | Similarity: {similarities[i]:.2f}")

    elif dataset_name.lower() == 'mrpc':
        # Classification metrics
        binary_preds = [1 if s >= threshold else 0 for s in similarities]
        results['accuracy'] = accuracy_score(labels, binary_preds)
        # zero_division=0: report 0 instead of warning when no positive is predicted/present.
        results['f1'] = f1_score(labels, binary_preds, zero_division=0)
        results['precision'] = precision_score(labels, binary_preds, zero_division=0)
        results['recall'] = recall_score(labels, binary_preds, zero_division=0)

        # Example for debugging
        print("\n[MRPC] Evaluation example:")
        for i in range(min(3, len(labels))):
            print(f"  Label: {labels[i]} | Predicted: {binary_preds[i]} | Similarity: {similarities[i]:.2f}")

    return results

def plot_results(results_df, metric, dataset_name):
    """
    Generates a grouped bar chart of one metric: one group per model, one bar per loss function.

    Args:
        results_df: DataFrame with results ('Dataset', 'Model', 'Loss Function' columns)
        metric: Metric column to visualize
        dataset_name: Dataset whose rows are plotted

    Returns:
        Path to saved figure file, or None when ``metric`` has no values for this
        dataset (e.g. every run failed) and nothing was plotted.
    """
    # Filter data for specific dataset
    df = results_df[results_df['Dataset'] == dataset_name]
    if metric not in df.columns or df[metric].dropna().empty:
        print(f"Skipping '{metric}' plot for {dataset_name}: no values available.")
        return None

    pivot_df = df.pivot(index='Model', columns='Loss Function', values=metric)

    # Draw on one explicitly created figure. Calling DataFrame.plot() without `ax`
    # after plt.figure() opens a second figure and leaves the first one open.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pivot_df.plot(kind='bar', ax=ax)

        # Graph settings
        ax.set_title(f'{metric} by Model and Loss Function - {dataset_name.upper()}', fontsize=14)
        ax.set_xlabel('Model', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        ax.legend(title='Loss Function', fontsize=10)

        # Add values on bars
        for container in ax.containers:
            ax.bar_label(container, fmt='%.3f', fontsize=8)

        fig.tight_layout()

        filepath = FIGURES_DIR / f"{dataset_name}_{metric}_comparison.png"
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    return filepath

def plot_training_curve(history, model_name, loss_name, dataset_name):
    """
    Plots the per-epoch training loss.

    Args:
        history: Dict whose 'train_loss' entry lists one loss value per epoch
            (a missing entry is treated as an empty list)
        model_name: Model name (the part after the last '/' is used in the file name)
        loss_name: Loss function name
        dataset_name: Dataset name

    Returns:
        Path to saved figure file
    """
    train_loss = history.get('train_loss', [])
    epochs = range(1, len(train_loss) + 1)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Markers keep single-epoch runs visible (a bare line needs two points).
        ax.plot(epochs, train_loss, 'b-o', label='Training Loss')

        ax.set_title(f'Training Curve: {model_name}\n{loss_name} on {dataset_name}', fontsize=14)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend()

        model_short = model_name.split('/')[-1]
        filepath = FIGURES_DIR / f"{dataset_name}_{model_short}_{loss_name}_training.png"
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    return filepath

# Main training and evaluation function
def _epoch_means(step_losses, epochs):
    """Average per-step loss values over ``epochs`` consecutive, near-equal chunks (one per epoch)."""
    if not step_losses:
        return []
    chunks = np.array_split(np.asarray(step_losses, dtype=float), max(int(epochs), 1))
    return [float(chunk.mean()) for chunk in chunks if chunk.size]


def train_and_evaluate(model_name, dataset_name, loss_name, train_df, test_df,
                      epochs=3, batch_size=16, save_model=False):
    """
    Trains and evaluates a model with a specific loss function.

    Args:
        model_name: Sentence-Transformer model name
        dataset_name: Dataset name ('stsb' or 'mrpc')
        loss_name: Key into ``loss_functions``
        train_df: DataFrame with training data
        test_df: DataFrame with test data
        epochs: Number of training epochs
        batch_size: Batch size
        save_model: If True, saves the trained model under MODELS_DIR

    Returns:
        Dictionary with results and metrics. On any exception, a dictionary with
        'Dataset', 'Model', 'Loss Function' and 'Error' is returned instead (the
        traceback is printed).
    """
    try:
        # Initialize model
        model = SentenceTransformer(model_name).to(DEVICE)
        model_identifier = model_name.split('/')[-1]

        # Configure training
        fixed_negative = "This is an example negative sentence for training triplets."
        dataset = generate_examples(train_df, loss_name, fixed_negative)
        dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
        loss_fn = loss_functions[loss_name](model)

        # sentence-transformers documents fit()'s `callback` as being invoked after
        # each evaluation; no evaluator is configured here, so a callback would never
        # fire. A forward hook on the loss module sees every training-step loss instead.
        step_losses = []

        def _record_loss(_module, _inputs, output):
            if torch.is_tensor(output) and output.numel() == 1:
                step_losses.append(float(output.detach()))

        hook = loss_fn.register_forward_hook(_record_loss)

        # Execute training
        start_time = time.time()
        try:
            model.fit(
                train_objectives=[(dataloader, loss_fn)],
                epochs=epochs,
                warmup_steps=int(len(dataloader) * 0.1),
                show_progress_bar=True,
                output_path=None,
            )
        finally:
            hook.remove()
        training_time = time.time() - start_time

        # Record training history (one averaged loss per epoch)
        history = {'train_loss': _epoch_means(step_losses, epochs)}

        # Evaluate model
        evaluation_results = evaluate_model(model, test_df, dataset_name)

        # Save model if requested
        model_path = None
        if save_model:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_path = MODELS_DIR / f"{dataset_name}_{model_identifier}_{loss_name}_{timestamp}"
            model.save(str(model_path))

        # Plot training curve
        training_plot = plot_training_curve(history, model_name, loss_name, dataset_name)

        # Consolidate results
        results = {
            'Dataset': dataset_name,
            'Model': model_name,
            'Loss Function': loss_name,
            'Training Time (s)': round(training_time, 2),
            'Mean Similarity': round(evaluation_results['mean_similarity'], 4),
            'STD Similarity': round(evaluation_results['std_similarity'], 4),
            'Epochs': epochs,
            'Batch Size': batch_size,
            'Training Plot': str(training_plot),
            'Model Path': str(model_path) if model_path else None
        }

        # Add dataset-specific metrics (report column -> evaluate_model key)
        metric_columns = {
            'stsb': (('Pearson', 'pearson'), ('Spearman', 'spearman')),
            'mrpc': (('Accuracy', 'accuracy'), ('F1 Score', 'f1'),
                     ('Precision', 'precision'), ('Recall', 'recall')),
        }
        for column, key in metric_columns.get(dataset_name.lower(), ()):
            results[column] = round(evaluation_results[key], 4) if key in evaluation_results else None

        return results

    except Exception as e:
        # Top-level boundary for one grid cell: log and keep the experiment running.
        print(f"Error in train_and_evaluate({model_name}, {dataset_name}, {loss_name}): {e}")
        import traceback
        traceback.print_exc()
        return {
            'Dataset': dataset_name,
            'Model': model_name,
            'Loss Function': loss_name,
            'Error': str(e)
        }

# Main function
def main():
    """
    Runs every (dataset, model, loss function) combination and saves all artefacts.

    Writes the experiment configuration, per-dataset and combined CSVs, comparison
    figures and an HTML report under RESULTS_DIR.

    Returns:
        DataFrame with one row per run (failed runs carry an 'Error' column).
    """
    # List of models to evaluate
    model_names = [
        'sentence-transformers/all-mpnet-base-v2',
        'sentence-transformers/bert-base-nli-mean-tokens',
        'sentence-transformers/paraphrase-MiniLM-L6-v2'
    ]

    # Datasets to evaluate
    datasets = ['stsb', 'mrpc']

    # Validation rows used for evaluation: the whole split when SAMPLE_SIZE is None
    # ("full dataset" mode), otherwise at most 408 rows (presumably chosen to match
    # the MRPC validation size) so sampled runs stay quick.
    eval_sample_size = None if SAMPLE_SIZE is None else min(408, SAMPLE_SIZE)

    # Experimental configurations
    experiment_config = {
        'seed': SEED,
        'device': str(DEVICE),
        'epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'sample_size': SAMPLE_SIZE,
        'eval_sample_size': eval_sample_size,
        'save_models': SAVE_MODELS,
        'models': model_names,
        'datasets': datasets,
        'loss_functions': list(loss_functions.keys()),
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Save experiment configuration
    with open(RESULTS_DIR / "experiment_config.json", 'w', encoding='utf-8') as f:
        json.dump(experiment_config, f, indent=2)

    # Initialize dataset loader
    loader = DatasetLoader()

    # Results stored here
    all_results = []
    dataset_figures = {}

    # Main loop
    for dataset_name in datasets:
        print(f"\n\n{'='*60}")
        print(f"Dataset: {dataset_name.upper()}")
        print(f"{'='*60}")

        # Load datasets
        train_df = loader.load_dataset(dataset_name, 'train', sample_size=SAMPLE_SIZE)
        test_df = loader.load_dataset(dataset_name, 'validation', sample_size=eval_sample_size)

        # Distribution visualization
        dist_fig = loader.visualize_dataset_distribution(train_df, dataset_name)
        dataset_figures[dataset_name] = str(dist_fig)

        # Data sample
        print(f"\nData sample ({dataset_name.upper()}):")
        print(train_df[['sentence1', 'sentence2', 'label']].head(3).to_string())

        # Run evaluation for each combination
        results_dataset = []

        for model_name in model_names:
            model_short = model_name.split('/')[-1]
            print(f"\n{'-'*40}")
            print(f"Model: {model_short}")
            print(f"{'-'*40}")

            for loss_name in loss_functions.keys():
                print(f"\nEvaluating {model_short} with {loss_name} on {dataset_name.upper()}...")

                result = train_and_evaluate(
                    model_name=model_name,
                    dataset_name=dataset_name,
                    loss_name=loss_name,
                    train_df=train_df,
                    test_df=test_df,
                    epochs=NUM_EPOCHS,
                    batch_size=BATCH_SIZE,
                    save_model=SAVE_MODELS
                )

                results_dataset.append(result)
                all_results.append(result)

                # Immediate result logging
                if 'Error' in result:
                    print(f" Error: {result['Error']}")
                else:
                    print(f" Completed: Mean Sim = {result['Mean Similarity']}")
                    if dataset_name.lower() == 'stsb':
                        print(f"   Pearson = {result['Pearson']}")
                    else:
                        print(f"   Accuracy = {result['Accuracy']}, F1 = {result['F1 Score']}")

        # Save results per dataset
        results_df = pd.DataFrame(results_dataset)
        results_df.to_csv(RESULTS_DIR / f"results_{dataset_name}.csv", index=False)

        # Generate visualizations
        metrics_to_plot = ['Pearson'] if dataset_name.lower() == 'stsb' else ['F1 Score', 'Accuracy']
        for metric in metrics_to_plot:
            # The metric column is absent when every run on this dataset failed.
            if metric in results_df.columns:
                plot_results(results_df, metric, dataset_name)
            else:
                print(f"Skipping {metric} plot for {dataset_name}: no successful runs.")

    # Consolidate all results
    all_results_df = pd.DataFrame(all_results)
    all_results_df.to_csv(RESULTS_DIR / "complete_results.csv", index=False)

    # Generate HTML report
    generate_html_report(all_results_df, experiment_config, dataset_figures)

    print("\n\nExperiment completed. Results available in:", RESULTS_DIR)
    return all_results_df

def _report_subset(results_df, dataset_name):
    """Return the rows of *results_df* for *dataset_name* with short model names for display."""
    if 'Dataset' not in results_df.columns:
        return results_df.iloc[0:0].copy()
    subset = results_df[results_df['Dataset'] == dataset_name].copy()
    if 'Model' in subset.columns:
        # 'org/model-name' -> 'model-name' (names without '/' are unchanged)
        subset['Model'] = subset['Model'].map(lambda name: str(name).split('/')[-1])
    return subset

def _report_best_row(df, metric):
    """Return the row with the highest *metric*, or None when no row has a numeric value."""
    if metric not in df.columns:
        return None
    scores = pd.to_numeric(df[metric], errors='coerce').reset_index(drop=True)
    if scores.isna().all():
        return None
    return df.iloc[scores.idxmax()]

def _report_table_html(df, cols, sort_by, labels):
    """Render df[cols] sorted by *sort_by* (descending) as one complete HTML table.

    *labels* maps column names to the header text shown in the report.
    A placeholder paragraph is returned when there is nothing to show.
    """
    present = [col for col in cols if col in df.columns]
    if df.empty or not present:
        return '<p><em>No results available.</em></p>'
    view = df[present]
    if sort_by in view.columns:
        view = view.sort_values(sort_by, ascending=False, na_position='last')
    # to_html escapes cell contents by default; na_rep keeps failed runs readable
    return view.rename(columns=labels).to_html(index=False, classes='results-table', na_rep='-')

def _report_rel_src(path, base):
    """Return *path* as a URL-quoted POSIX path relative to *base* ('' if missing).

    The report lives inside *base*, so a figure saved at e.g.
    ``results/figures/x.png`` must be referenced as ``figures/x.png``.
    Assumes figure paths are absolute or relative to the current working
    directory (as FIGURES_DIR is).
    """
    from urllib.parse import quote

    # None / NaN (e.g. the 'Training Plot' of a failed run) -> no image
    if path is None or not isinstance(path, (str, os.PathLike)):
        return ""
    text = os.fspath(path)
    if not text:
        return ""
    try:
        rel = os.path.relpath(text, os.fspath(base))
    except ValueError:
        # e.g. path and base on different drives (Windows): keep the original
        rel = text
    # Quote characters such as the space in 'mrpc_F1 Score_comparison.png'
    return quote(Path(rel).as_posix(), safe="/")

def _report_img(path, alt, base):
    """Return an <img> tag for *path* relative to *base*, or a placeholder when missing."""
    src = _report_rel_src(path, base)
    if not src:
        return '<p><em>Figure not available.</em></p>'
    return f'<img src="{src}" alt="{alt}">'

def generate_html_report(results_df, config, dataset_figures):

    """
    Generates an HTML report with experiment results.

    The report is written to ``RESULTS_DIR / "experiment_report.html"``. Image
    links are emitted relative to the report's own directory, so they resolve
    when the file is opened in a browser. A dataset without successful runs
    gets placeholders instead of aborting the report.

    Args:
        results_df: DataFrame with all results
        config: Experiment configuration; must provide the keys 'timestamp',
            'device', 'seed', 'epochs', 'batch_size', 'sample_size', 'models'
            and 'loss_functions'
        dataset_figures: Dictionary with paths to dataset figures, keyed by
            dataset name ('stsb', 'mrpc')

    Returns:
        Path of the generated HTML report.
    """
    import html

    def esc(value):
        # Escape any interpolated value so it cannot break the markup
        return html.escape(str(value))

    report_path = RESULTS_DIR / "experiment_report.html"
    report_dir = report_path.parent

    # Prepare per-dataset result tables (model names shortened for display)
    stsb_df = _report_subset(results_df, 'stsb')
    mrpc_df = _report_subset(results_df, 'mrpc')

    # Select relevant columns
    stsb_cols = ['Model', 'Loss Function', 'Training Time (s)', 'Mean Similarity', 'Pearson', 'Spearman']
    mrpc_cols = ['Model', 'Loss Function', 'Training Time (s)', 'Mean Similarity', 'Accuracy', 'F1 Score', 'Precision', 'Recall']
    header_labels = {
        'Training Time (s)': 'Time (s)',
        'Pearson': 'Pearson Correlation',
        'Spearman': 'Spearman Correlation',
    }

    stsb_table = _report_table_html(stsb_df, stsb_cols, 'Pearson', header_labels)
    mrpc_table = _report_table_html(mrpc_df, mrpc_cols, 'F1 Score', header_labels)

    # Best configuration per dataset (computed once, reused below)
    best_stsb = _report_best_row(stsb_df, 'Pearson')
    best_mrpc = _report_best_row(mrpc_df, 'F1 Score')

    def describe_best(best, metric, label):
        if best is None:
            return "not available (no successful runs)"
        return f"{esc(best['Model'])} with {esc(best['Loss Function'])} ({label}: {esc(best[metric])})"

    stsb_dist_img = _report_img(dataset_figures.get('stsb'), 'STS-B Distribution', report_dir)
    mrpc_dist_img = _report_img(dataset_figures.get('mrpc'), 'MRPC Distribution', report_dir)
    stsb_pearson_img = _report_img(FIGURES_DIR / 'stsb_Pearson_comparison.png', 'Pearson Comparison STS-B', report_dir)
    mrpc_f1_img = _report_img(FIGURES_DIR / 'mrpc_F1 Score_comparison.png', 'F1 Comparison MRPC', report_dir)
    mrpc_acc_img = _report_img(FIGURES_DIR / 'mrpc_Accuracy_comparison.png', 'Accuracy Comparison MRPC', report_dir)
    stsb_curve_img = _report_img(best_stsb.get('Training Plot') if best_stsb is not None else None,
                                 'Best Curve STS-B', report_dir)
    mrpc_curve_img = _report_img(best_mrpc.get('Training Plot') if best_mrpc is not None else None,
                                 'Best Curve MRPC', report_dir)

    sample_text = config['sample_size'] if config['sample_size'] else 'Full Dataset'
    models_text = ', '.join(str(m).split('/')[-1] for m in config['models'])
    losses_text = ', '.join(config['loss_functions'])

    # Generate HTML
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Sentence-Transformers Evaluation Report</title>
        <style>
            body {{ font-family: Arial, sans-serif; line-height: 1.6; margin: 0; padding: 20px; color: #333; }}
            h1, h2, h3 {{ color: #2c3e50; }}
            table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            th {{ background-color: #f2f2f2; color: #333; font-weight: bold; }}
            tr:nth-child(even) {{ background-color: #f9f9f9; }}
            tr:hover {{ background-color: #f5f5f5; }}
            .container {{ max-width: 1200px; margin: 0 auto; padding: 20px; }}
            .section {{ margin-bottom: 30px; }}
            .best-result {{ font-weight: bold; color: #27ae60; }}
            img {{ max-width: 100%; height: auto; margin: 10px 0; border: 1px solid #ddd; }}
            .config {{ background-color: #f8f9fa; padding: 15px; border-radius: 4px; margin-bottom: 20px; }}
            footer {{ margin-top: 30px; padding-top: 10px; border-top: 1px solid #eee; color: #7f8c8d; font-size: 0.9em; }}
        </style>
    </head>
    <body>
        <div class="container">
            <header>
                <h1>Comparative Evaluation of Sentence-Transformer Models</h1>
                <p>Report generated at: {esc(config['timestamp'])}</p>
            </header>

            <div class="section">
                <h2>Experiment Configuration</h2>
                <div class="config">
                    <p><strong>Device:</strong> {esc(config['device'])}</p>
                    <p><strong>Seed:</strong> {esc(config['seed'])}</p>
                    <p><strong>Epochs:</strong> {esc(config['epochs'])}</p>
                    <p><strong>Batch Size:</strong> {esc(config['batch_size'])}</p>
                    <p><strong>Sample:</strong> {esc(sample_text)}</p>
                    <p><strong>Models:</strong> {esc(models_text)}</p>
                    <p><strong>Loss Functions:</strong> {esc(losses_text)}</p>
                </div>
            </div>

            <div class="section">
                <h2>Results - STS-B (Semantic Similarity)</h2>
                <p>Training data distribution:</p>
                {stsb_dist_img}

                <h3>Performance Metrics</h3>
                {stsb_table}

                <h3>Results Visualization</h3>
                {stsb_pearson_img}
            </div>

            <div class="section">
                <h2>Results - MRPC (Paraphrase Detection)</h2>
                <p>Training data distribution:</p>
                {mrpc_dist_img}

                <h3>Performance Metrics</h3>
                {mrpc_table}

                <h3>Results Visualization</h3>
                {mrpc_f1_img}
                {mrpc_acc_img}
            </div>

            <div class="section">
                <h2>Training Curve Analysis</h2>
                <p>Examples of training curves for the best models:</p>

                <h3>STS-B (Best model)</h3>
                {stsb_curve_img}

                <h3>MRPC (Best model)</h3>
                {mrpc_curve_img}
            </div>

            <div class="section">
                <h2>Conclusions</h2>
                <p><strong>Best configuration for STS-B:</strong> {describe_best(best_stsb, 'Pearson', 'Pearson')}</p>
                <p><strong>Best configuration for MRPC:</strong> {describe_best(best_mrpc, 'F1 Score', 'F1')}</p>

                <p>General observations:</p>
                <ul>
                    <li>Loss functions have significant impact on model performance.</li>
                    <li>Models specialized in paraphrase tend to perform better in evaluated tasks.</li>
                    <li>Training time varies considerably between models.</li>
                </ul>
            </div>

            <footer>
                <p>Report automatically generated by Sentence-Transformers evaluation script.</p>
                <p>All models and results are available in directory: {esc(RESULTS_DIR)}</p>
            </footer>
        </div>
    </body>
    </html>
    """

    # Save report
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"HTML report generated at: {report_path}")
    return report_path

# Additional functions for advanced analysis

def analyze_similarity_metrics_correlation(results_df, dataset_name):
    """
    Analyzes correlation between mean similarity and performance metrics.

    Computes the pairwise correlation of the per-run metrics of one dataset,
    plots it as an annotated heatmap and saves it to
    ``FIGURES_DIR / f"{dataset_name}_metric_correlation.png"``.

    Args:
        results_df: DataFrame with results (one row per run, with a 'Dataset' column)
        dataset_name: Dataset name to analyze ('stsb' selects the STS-B metrics,
            anything else the MRPC metrics); matched case-insensitively
            against the 'Dataset' column.

    Returns:
        Path of the saved heatmap image, or None when there is nothing
        meaningful to correlate (fewer than two runs, fewer than two metric
        columns, or an all-NaN correlation matrix).
    """
    if 'Dataset' not in results_df.columns:
        return None

    # Filter data for specific dataset
    in_dataset = results_df['Dataset'].astype(str).str.lower() == dataset_name.lower()
    df = results_df[in_dataset]

    # Columns to analyze
    if dataset_name.lower() == 'stsb':
        cols = ['Mean Similarity', 'STD Similarity', 'Pearson', 'Spearman', 'Training Time (s)']
    else:  # MRPC
        cols = ['Mean Similarity', 'STD Similarity', 'Accuracy', 'F1 Score',
                'Precision', 'Recall', 'Training Time (s)']

    # Failed runs may leave some metric columns absent; correlate what exists
    cols = [col for col in cols if col in df.columns]
    if len(df) < 2 or len(cols) < 2:
        return None

    # Calculate correlation matrix (non-numeric leftovers become NaN)
    corr_matrix = df[cols].apply(pd.to_numeric, errors='coerce').corr()
    if corr_matrix.isna().all().all():
        return None

    # Visualize correlation matrix on a dedicated figure
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5, ax=ax)
    ax.set_title(f'Metrics Correlation - {dataset_name.upper()}')
    fig.tight_layout()

    # Save figure; always release it, even if saving fails
    filepath = FIGURES_DIR / f"{dataset_name}_metric_correlation.png"
    try:
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    return filepath

def _plot_time_vs_metric(ax, df, metric, title, ylabel):
    """Scatter *metric* against training time on *ax*: one series per model, points labelled by loss."""
    time_col = 'Training Time (s)'
    has_points = False
    if {'Model', 'Loss Function', time_col, metric}.issubset(df.columns):
        # Failed runs have NaN metrics; they cannot be placed on the plot
        plotted = df.dropna(subset=[time_col, metric])
        for model in plotted['Model'].unique():
            model_df = plotted[plotted['Model'] == model]
            ax.scatter(model_df[time_col], model_df[metric],
                       label=model, alpha=0.7, s=80)

            # Add labels for each point
            for _, row in model_df.iterrows():
                ax.annotate(str(row['Loss Function']),
                            (row[time_col], row[metric]),
                            fontsize=8, alpha=0.8)
            has_points = True

    ax.set_title(title)
    ax.set_xlabel('Training Time (seconds)')
    ax.set_ylabel(ylabel)
    ax.grid(True, linestyle='--', alpha=0.6)
    if has_points:
        ax.legend()
    else:
        # Avoid matplotlib's "no artists with labels" warning on empty panels
        ax.text(0.5, 0.5, 'No results available', ha='center', va='center',
                transform=ax.transAxes)

def analyze_time_vs_performance(results_df):
    """
    Analyzes relationship between training time and performance metrics.

    Draws two stacked scatter panels: STS-B Pearson vs. training time and
    MRPC F1 score vs. training time. Each model is one series, and each
    point is labelled with its loss function. The figure is saved to
    ``FIGURES_DIR / "time_vs_performance.png"``.

    Args:
        results_df: DataFrame with results

    Returns:
        Path of the saved figure.
    """
    # One figure only: an extra plt.figure() here would never be closed
    fig, axes = plt.subplots(2, 1, figsize=(12, 12))

    panels = (
        ('stsb', 'Pearson', 'STS-B: Pearson Correlation vs. Training Time', 'Pearson Correlation'),
        ('mrpc', 'F1 Score', 'MRPC: F1 Score vs. Training Time', 'F1 Score'),
    )
    for ax, (dataset, metric, title, ylabel) in zip(axes, panels):
        if 'Dataset' in results_df.columns:
            subset = results_df[results_df['Dataset'] == dataset].copy()
        else:
            subset = results_df.iloc[0:0].copy()
        if 'Model' in subset.columns:
            # 'org/model-name' -> 'model-name' for legend readability
            subset['Model'] = subset['Model'].map(lambda name: str(name).split('/')[-1])
        _plot_time_vs_metric(ax, subset, metric, title, ylabel)

    fig.tight_layout()

    # Save figure; always release it, even if saving fails
    filepath = FIGURES_DIR / "time_vs_performance.png"
    try:
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    return filepath

def _summarize_loss_functions(df, dataset_label, metrics):
    """Return one stats dict per loss function in *df*: run count plus mean/std/max/min of each metric."""
    summaries = []
    if 'Loss Function' not in df.columns:
        return summaries
    # sort=False keeps loss functions in order of first appearance
    for loss_fn, group in df.groupby('Loss Function', sort=False):
        loss_stats = {
            'Dataset': dataset_label,
            'Loss Function': loss_fn,
            'Count': len(group),
        }
        for metric in metrics:
            # A metric column missing entirely (all runs failed) yields NaN stats
            values = group[metric] if metric in group.columns else pd.Series(dtype=float)
            loss_stats[f'Mean {metric}'] = values.mean()
            loss_stats[f'Std {metric}'] = values.std()
            loss_stats[f'Max {metric}'] = values.max()
            loss_stats[f'Min {metric}'] = values.min()
        summaries.append(loss_stats)
    return summaries

def analyze_loss_function_impact(results_df):
    """
    Analyzes impact of different loss functions on performance.

    For each dataset (STS-B: Pearson/Spearman; MRPC: Accuracy/F1 Score) and
    each loss function, records the number of runs and the mean, standard
    deviation, max and min of every metric across models. The table is also
    saved to ``RESULTS_DIR / "loss_functions_impact.csv"``.

    Args:
        results_df: DataFrame with results

    Returns:
        DataFrame with impact statistics
    """
    specs = (
        ('stsb', 'STS-B', ['Pearson', 'Spearman']),
        ('mrpc', 'MRPC', ['Accuracy', 'F1 Score']),
    )

    impact = []
    for dataset_key, dataset_label, metrics in specs:
        if 'Dataset' in results_df.columns:
            subset = results_df[results_df['Dataset'] == dataset_key]
        else:
            subset = results_df.iloc[0:0]
        impact.extend(_summarize_loss_functions(subset, dataset_label, metrics))

    # Create DataFrame with statistics
    impact_df = pd.DataFrame(impact)

    # Save analysis
    impact_df.to_csv(RESULTS_DIR / "loss_functions_impact.csv", index=False)

    return impact_df

def _download_results_archive():
    """Zip RESULTS_DIR and trigger a browser download when running inside Google Colab."""
    try:
        from google.colab import files
    except ImportError:
        # Not running in Colab: results stay on disk only
        print("Download automático indisponível (executando fora do Google Colab).")
        return

    import shutil

    try:
        zip_path = shutil.make_archive("resultados_experimento", 'zip', RESULTS_DIR)
        files.download(zip_path)
    except Exception as e:
        # Best effort: the results are already saved, so only report the failure
        print(f"Automatic download failed: {e}")

def extended_experiment():
    """
    Main function that executes the experiment and additional analyses.

    Runs ``main()``, then the metric-correlation, time-vs-performance and
    loss-function-impact analyses. It prints a short summary with the best
    loss function per dataset and, inside Google Colab, downloads the zipped
    results directory. Errors are printed with a traceback instead of
    propagating, so an interactive session is not interrupted.
    """
    try:
        # Run main experiment
        results_df = main()

        # Validate we have results for analyses
        if results_df is None or len(results_df) == 0:
            print("No results for additional analyses.")
            return

        print("\n\n" + "="*60)
        print("Additional Analyses")
        print("="*60)

        # Correlation analysis between metrics
        print("\nAnalyzing correlation between metrics...")
        analyze_similarity_metrics_correlation(results_df, 'stsb')
        analyze_similarity_metrics_correlation(results_df, 'mrpc')

        # Time vs. performance analysis
        print("\nAnalyzing time vs. performance relationship...")
        analyze_time_vs_performance(results_df)

        # Loss function impact analysis
        print("\nAnalyzing loss functions impact...")
        impact_df = analyze_loss_function_impact(results_df)

        # Analysis summary
        print("\nAnalysis Summary:")
        print(f"- {len(results_df)} model-loss combinations tested")

        # (dataset label, ranking column in impact_df, label used in the printout)
        rankings = (
            ('STS-B', 'Mean Pearson', 'Mean Pearson'),
            ('MRPC', 'Mean F1 Score', 'Mean F1'),
        )
        for dataset, metric_col, metric_label in rankings:
            print(f"\n{dataset}:")
            if 'Dataset' in impact_df.columns:
                dataset_impact = impact_df[impact_df['Dataset'] == dataset]
            else:
                dataset_impact = impact_df.iloc[0:0]

            # A dataset where every run failed has nothing to rank
            if metric_col not in dataset_impact.columns or dataset_impact[metric_col].isna().all():
                print("- No successful runs available to rank loss functions.")
                continue

            best_loss = dataset_impact.sort_values(metric_col, ascending=False).iloc[0]
            print(f"- Best loss function: {best_loss['Loss Function']} ({metric_label}: {best_loss[metric_col]:.4f})")

        print("\nAdditional analyses completed and saved in:", RESULTS_DIR)

        _download_results_archive()

    except Exception as e:
        print(f"Error in additional analyses: {e}")
        import traceback
        traceback.print_exc()

# Script entry point: print the run settings, then launch the full experiment
if __name__ == "__main__":
    _banner = "=" * 80
    print(_banner)
    print("Evaluation of Sentence-Transformer Models for Semantic Similarity")
    print(_banner)
    print("Settings:")
    for _setting_line in (
        f"- Seed: {SEED}",
        f"- Device: {DEVICE}",
        f"- Epochs: {NUM_EPOCHS}",
        f"- Batch Size: {BATCH_SIZE}",
        f"- Sample size: {SAMPLE_SIZE if SAMPLE_SIZE else 'Full dataset'}",
        f"- Results directory: {RESULTS_DIR}",
    ):
        print(_setting_line)
    print(_banner)

    # Training/evaluation of every model-loss pair, followed by the extra analyses
    extended_experiment()



ModuleNotFoundError: No module named 'datasets'