# **Comparing Baseline and Recurrent Transformers on different subset sizes of the SST-2 dataset**

This notebook compares the Baseline and Recurrent transformer architectures on different subsets of the SST-2 dataset:
   
   1. **10% of the training data**
   2. **50% of the training data** 

The goal is to analyze how data quantity affects the performance, convergence, and efficiency of each model.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import time
import os
import warnings
import copy
from tqdm.notebook import tqdm # Use tqdm.notebook for Jupyter progress bars
from sklearn.metrics import f1_score, precision_score, recall_score

# Notebook-wide presentation settings: silence all warnings and use
# seaborn's "whitegrid" theme for every figure drawn below.
warnings.filterwarnings('ignore')
sns.set_theme(style="whitegrid")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")



## 1. Configurations

In [None]:
# Hyperparameters for the baseline transformer. Entries prefixed with
# 'use_' are forwarded verbatim to the model config by get_models().
BASELINE_CONFIG = dict(
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
    dropout_prob=0.1,
    use_flash_attention=True,
    use_swiglu=True,
    use_rope=True,
    use_rms_norm=True,
)

# Hyperparameters for the recurrent transformer.
RECURRENT_CONFIG = dict(
    hidden_size=256,
    num_hidden_layers=3,
    recurrent_depth=2,  # Effective depth: 3 × 2 = 6
    num_attention_heads=4,
    intermediate_size=1024,
    dropout_prob=0.1,
    residual_scale=0.5,
    use_flash_attention=True,
    use_swiglu=True,
    use_rope=True,
    use_rms_norm=True,
)

# Optimisation settings shared by both models.
# NOTE(review): 'warmup_steps' and 'eval_steps' are not read by any code
# visible in this notebook — confirm whether they are still needed.
TRAINING_CONFIG = dict(
    num_epochs=5,
    batch_size=16,
    learning_rate=3e-5,
    warmup_steps=100,
    eval_steps=50,
    max_length=128,
    patience=3,       # For early stopping
    min_delta=0.001,  # For early stopping
)

# Names of the training-data subsets; each one is also the name of the
# directory the experiment loop loads its data from.
DATA_SUBSETS = [f'{percent}_percent' for percent in (10, 50)]

## 2. Functions for Models, Training, and Evaluation

In [None]:
import sys
import os
# get project root
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
# add to python path
sys.path.append(PROJECT_ROOT)


from training.utils import prepare_sst2_data, load_tokenizer
from models.baseline import BaselineModel, BaselineConfig
from models.recurrent import RecurrentModel, RecurrentConfig
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau




def count_parameters(model):
    """Return the total element count over everything ``model.parameters()`` yields."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def get_models():
    """Create a brand-new (baseline, recurrent) model pair placed on ``device``.

    Hyperparameters are read from the module-level ``BASELINE_CONFIG`` and
    ``RECURRENT_CONFIG`` dicts. Each dict's single ``dropout_prob`` value is
    used for both the hidden and the attention dropout, and every entry whose
    key starts with ``use_`` is forwarded to the config class unchanged.

    Returns:
        tuple: ``(baseline_model, recurrent_model)``.
    """
    def _feature_flags(settings):
        # Pick out the 'use_'-prefixed entries, which are passed through as-is.
        return {key: value for key, value in settings.items() if key.startswith('use_')}

    # NOTE(review): 30522 presumably matches the 'bert-base-uncased'
    # tokenizer vocabulary used by the experiment cell — confirm.
    base = BASELINE_CONFIG
    baseline_kwargs = {
        'vocab_size': 30522,
        'hidden_size': base['hidden_size'],
        'num_hidden_layers': base['num_hidden_layers'],
        'num_attention_heads': base['num_attention_heads'],
        'intermediate_size': base['intermediate_size'],
        'hidden_dropout_prob': base['dropout_prob'],
        'attention_probs_dropout_prob': base['dropout_prob'],
        'num_labels': 2,
    }
    baseline_kwargs.update(_feature_flags(base))
    baseline_config = BaselineConfig(**baseline_kwargs)

    rec = RECURRENT_CONFIG
    recurrent_kwargs = {
        'vocab_size': 30522,
        'hidden_size': rec['hidden_size'],
        'num_hidden_layers': rec['num_hidden_layers'],
        'recurrent_depth': rec['recurrent_depth'],
        'num_attention_heads': rec['num_attention_heads'],
        'intermediate_size': rec['intermediate_size'],
        'hidden_dropout_prob': rec['dropout_prob'],
        'attention_probs_dropout_prob': rec['dropout_prob'],
        'residual_scale': rec['residual_scale'],
        'num_labels': 2,
    }
    recurrent_kwargs.update(_feature_flags(rec))
    recurrent_config = RecurrentConfig(**recurrent_kwargs)

    # Instantiate the baseline first, then the recurrent model, and move
    # each one to the selected device.
    baseline_model = BaselineModel(baseline_config).to(device)
    recurrent_model = RecurrentModel(recurrent_config).to(device)

    return baseline_model, recurrent_model

def train_model(model, name, train_loader, val_loader, num_epochs, patience, min_delta,
                learning_rate=None):
    """Train ``model`` with early stopping and restore its best weights.

    Each epoch runs one pass over ``train_loader`` (AdamW, gradient-norm
    clipping at 1.0) followed by one pass over ``val_loader``. The mean
    validation loss drives both a ``ReduceLROnPlateau`` scheduler and the
    early-stopping check.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` that returns a mapping with ``'loss'`` and
            ``'logits'`` entries.
        name: Label used as the progress-bar prefix.
        train_loader: Iterable of batches; each batch is a dict of tensors
            containing ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        val_loader: Iterable of batches in the same format as ``train_loader``.
        num_epochs: Maximum number of epochs to run.
        patience: Stop after this many consecutive epochs without improvement.
        min_delta: Validation loss must drop by more than this amount to
            count as an improvement.
        learning_rate: AdamW learning rate. Defaults to
            ``TRAINING_CONFIG['learning_rate']`` when ``None``.

    Returns:
        dict: Per-epoch lists under ``'train_loss'``, ``'val_loss'`` and
        ``'val_acc'``.
    """
    if learning_rate is None:
        learning_rate = TRAINING_CONFIG['learning_rate']

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated in recent PyTorch releases; False was the
    # default, so leaving it out keeps the same behaviour.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_losses = []
        running_loss = 0.0  # running sum so the progress bar avoids re-averaging the whole list
        progress_bar = tqdm(train_loader, desc=f"{name} Epoch {epoch+1}", leave=False)
        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])
            loss = outputs['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            running_loss += batch_loss
            progress_bar.set_postfix(loss=running_loss / len(train_losses))

        # ---- validation pass ----
        model.eval()
        val_losses, val_correct, val_total = [], 0, 0
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])
                val_losses.append(outputs['loss'].item())
                predictions = outputs['logits'].argmax(dim=-1)
                val_correct += (predictions == batch['labels']).sum().item()
                val_total += batch['labels'].size(0)

        train_loss, val_loss, val_acc = np.mean(train_losses), np.mean(val_losses), val_correct / val_total
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        # ---- early stopping bookkeeping ----
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            # Deep copy: state_dict() returns references to the live tensors,
            # which later optimizer steps would otherwise overwrite.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Explicit None check: "no snapshot taken" is the only case to skip.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with val_loss={best_val_loss:.4f}")

    return history

def evaluate_model(model, loader):
    """Compute classification metrics and per-sample latency for ``model``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns a mapping with a ``'logits'`` entry.
        loader: Iterable of batches; each batch is a dict of tensors
            containing ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.

    Returns:
        dict: ``'accuracy'``, weighted ``'f1'`` / ``'precision'`` /
        ``'recall'``, and ``'inference_time_ms'`` — the mean over batches of
        the forward-pass time divided by that batch's size, in milliseconds.
    """
    model.eval()
    all_predictions, all_labels, inference_times = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            on_gpu = batch['input_ids'].is_cuda

            # CUDA kernels are launched asynchronously, so the wall clock must
            # be read only after pending GPU work has finished; otherwise the
            # measurement covers little more than the kernel launches.
            if on_gpu:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            if on_gpu:
                torch.cuda.synchronize()
            elapsed_ms = (time.perf_counter() - start_time) * 1000

            inference_times.append(elapsed_ms / batch['input_ids'].size(0))
            predictions = outputs['logits'].argmax(dim=-1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

    accuracy = (np.array(all_predictions) == np.array(all_labels)).mean()
    return {
        'accuracy': accuracy,
        'f1': f1_score(all_labels, all_predictions, average='weighted'),
        'precision': precision_score(all_labels, all_predictions, average='weighted'),
        'recall': recall_score(all_labels, all_predictions, average='weighted'),
        'inference_time_ms': np.mean(inference_times)
    }


## 3. Run Experiments

In [None]:
# One tokenizer is shared by every experiment; the two accumulators below
# are read by the analysis cells that follow.
tokenizer = load_tokenizer('bert-base-uncased')
all_results = []
training_histories = {}

for subset in DATA_SUBSETS:
    banner = '=' * 20
    print(f"\n{banner} Experiment: {subset.replace('_', ' ').title()} {banner}")
    data_dir = f'../data/processed_size_splits_sst2/{subset}'
    training_histories[subset] = {}

    # Build the train / validation / test loaders for this subset.
    train_loader, val_loader, test_loader = prepare_sst2_data(
        data_dir=data_dir,
        tokenizer=tokenizer,
        batch_size=TRAINING_CONFIG['batch_size'],
        max_length=TRAINING_CONFIG['max_length'],
    )
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    n_test = len(test_loader.dataset)
    print(f"Dataset sizes - Train: {n_train}, Validation: {n_val}, Test: {n_test}")

    # Start every subset from freshly initialised models.
    baseline_model, recurrent_model = get_models()

    for model_name, model in (('Baseline', baseline_model), ('Recurrent', recurrent_model)):
        print(f"\n--- Training {model_name} on {subset} ---")

        # Fit the model; train_model leaves the best checkpoint loaded.
        history = train_model(
            model,
            f"{model_name}-{subset}",
            train_loader,
            val_loader,
            num_epochs=TRAINING_CONFIG['num_epochs'],
            patience=TRAINING_CONFIG['patience'],
            min_delta=TRAINING_CONFIG['min_delta'],
        )
        training_histories[subset][model_name] = history

        # Score the trained model on the held-out test split.
        print(f"--- Evaluating {model_name} on {subset} test set ---")
        metrics = evaluate_model(model, test_loader)

        # Record one row per (model, subset) pair.
        params = count_parameters(model)
        # Size estimate uses 4 bytes per parameter — presumably float32 weights; confirm.
        size_mb = (params * 4) / 1024 ** 2
        result = {
            'Model': model_name,
            'Subset': subset,
            'Parameters': params,
            'Size (MB)': size_mb,
        }
        result.update(metrics)
        all_results.append(result)
        print(f"Results for {model_name} on {subset}: Accuracy={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")

# Gather all result rows into a single table for the analysis cells.
results_df = pd.DataFrame(all_results)


## 4. Analyze and Visualize Results

In [None]:
# Render the collected metrics, rounded to four decimal places.
print("\n--- Final Comparison Table ---")
rounded_results = results_df.round(4)
display(rounded_results)

In [None]:
# Plot training curves: one row per data subset, with loss on the left and
# validation accuracy on the right.
fig, axes = plt.subplots(
    len(DATA_SUBSETS), 2,
    figsize=(14, 6 * len(DATA_SUBSETS)),
    sharex='col',
    squeeze=False,  # always return a 2-D grid, even when there is only one subset
)
fig.suptitle('Training Dynamics Across Data Subsets', fontsize=16, y=1.02)

# Line colour and accuracy marker for each model.
model_styles = {
    'Baseline': ('blue', 'o'),
    'Recurrent': ('orange', 's'),
}

for (ax_loss, ax_acc), subset in zip(axes, DATA_SUBSETS):
    subset_title = subset.replace("_", " ").title()

    for model_name, (color, marker) in model_styles.items():
        hist = training_histories[subset][model_name]
        # 1-based epochs to match the "Epoch N" lines printed during training.
        # Histories can differ in length because of early stopping.
        epochs = range(1, len(hist['train_loss']) + 1)

        # Plot loss
        ax_loss.plot(epochs, hist['train_loss'], label=f'{model_name} Train', color=color, linestyle='-')
        ax_loss.plot(epochs, hist['val_loss'], label=f'{model_name} Val', color=color, linestyle='--')

        # Plot accuracy
        ax_acc.plot(epochs, hist['val_acc'], label=model_name, color=color, marker=marker)

    ax_loss.set_ylabel('Loss')
    ax_loss.set_title(f'Loss on {subset_title}')
    ax_loss.legend()

    ax_acc.set_ylabel('Validation Accuracy')
    ax_acc.set_title(f'Accuracy on {subset_title}')
    ax_acc.legend()

# Only the bottom row needs an x-axis label because columns share their x-axis.
axes[-1][0].set_xlabel('Epoch')
axes[-1][1].set_xlabel('Epoch')
plt.tight_layout()
plt.show()

In [None]:
# Bubble plot for performance comparison: model size (x) against test
# accuracy (y), with bubble area scaled by per-sample inference time.
from matplotlib.lines import Line2D

plt.figure(figsize=(12, 8))

markers = {
    '10_percent': 'o', # Circle
    '50_percent': 's'  # Square
}
colors = {
    'Baseline': 'blue',
    'Recurrent': 'orange'
}

for _, row in results_df.iterrows():
    # The metric columns keep the lowercase keys returned by evaluate_model,
    # so the accuracy column is 'accuracy', not 'Accuracy'.
    plt.scatter(
        row['Size (MB)'],
        row['accuracy'],
        s=row['inference_time_ms'] * 150, # Bubble size by inference time
        c=colors[row['Model']],
        marker=markers[row['Subset']],
        alpha=0.6,
    )
    # Annotate each bubble with its model name and subset percentage.
    plt.text(row['Size (MB)']+0.5, row['accuracy'], f"{row['Model']}\n{row['Subset'].split('_')[0]}", fontsize=9)

# Create custom legend: marker shape encodes the data subset, colour encodes
# the model. Per-point labels are not needed because these handles are
# passed to the legend explicitly.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label='10% Data', markerfacecolor='gray', markersize=10),
    Line2D([0], [0], marker='s', color='w', label='50% Data', markerfacecolor='gray', markersize=10),
    Line2D([0], [0], color='blue', lw=4, label='Baseline Model'),
    Line2D([0], [0], color='orange', lw=4, label='Recurrent Model')
]

plt.xlabel('Model Size (MB)')
plt.ylabel('Test Accuracy')
plt.title('Performance vs. Model Size Across Data Subsets\n(Bubble Size ~ Inference Time)')
plt.legend(handles=legend_elements, loc='lower right')
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()
