In [38]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from datetime import datetime
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc
import pandas as pd
from datetime import datetime
import os


def load_sequences(filename):
    """Read a FASTA file and map each record ID to its full sequence.

    The record ID is the first whitespace-delimited token after ``>``; the
    rest of the header line is ignored. Sequence lines are stripped and
    concatenated, so wrapped records and blank lines are handled. Lines that
    appear before the first header are dropped, and a repeated ID keeps the
    last record seen.

    Args:
        filename: path to the FASTA file.

    Returns:
        dict mapping record ID (str) to sequence (str).

    Raises:
        IndexError: if a header line consists only of ``>``.
    """
    records = {}
    record_id = None
    pieces = []

    with open(filename, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text.startswith('>'):
                pieces.append(text)
                continue
            # A new header closes the record collected so far.
            if record_id is not None:
                records[record_id] = ''.join(pieces)
            record_id = text[1:].split()[0]
            pieces = []

    # The final record has no following header to close it, so flush it here.
    if record_id is not None:
        records[record_id] = ''.join(pieces)
    return records

class ProteinTokenizer:
    """Character-level tokenizer for amino-acid strings.

    Ids 0-3 are the special tokens ``[PAD]``, ``[CLS]``, ``[SEP]``, ``[UNK]``;
    ids 4-23 are the 20 standard amino-acid letters in alphabetical order.
    Any character outside that alphabet (including lowercase) maps to ``[UNK]``.
    """

    def __init__(self):
        # The 20 standard amino acids, one token per residue letter.
        self.amino_acids = list('ACDEFGHIKLMNPQRSTVWY')
        self.special_tokens = ['[PAD]', '[CLS]', '[SEP]', '[UNK]']

        # Special tokens first so that [PAD] is id 0.
        ordered_tokens = [*self.special_tokens, *self.amino_acids]
        self.vocab = dict(zip(ordered_tokens, range(len(ordered_tokens))))
        self.vocab_size = len(self.vocab)

    def encode(self, sequence, max_length=256):
        """Turn *sequence* into fixed-length id and attention-mask tensors.

        The ids are ``[CLS] + residues + [SEP]``, right-padded with ``[PAD]``
        up to *max_length*. Longer inputs are cut to the first *max_length*
        ids, which drops ``[SEP]``. The mask is 1 for real tokens and 0 for
        padding.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` 1-D tensors.
        """
        unk_id = self.vocab['[UNK]']
        pad_id = self.vocab['[PAD]']

        ids = [self.vocab['[CLS]']]
        ids.extend(self.vocab.get(char, unk_id) for char in sequence)
        ids.append(self.vocab['[SEP]'])

        if len(ids) >= max_length:
            ids = ids[:max_length]
            mask = [1] * len(ids)
        else:
            pad_count = max_length - len(ids)
            mask = [1] * len(ids) + [0] * pad_count
            ids = ids + [pad_id] * pad_count

        return {
            'input_ids': torch.tensor(ids),
            'attention_mask': torch.tensor(mask),
        }

class ProteinDataset(Dataset):
    """Map-style dataset of pre-tokenised protein sequences with multi-hot labels.

    Every sequence is encoded once in ``__init__`` so that ``__getitem__`` only
    indexes into Python lists and tensors.

    Args:
        sequences: iterable of amino-acid strings.
        labels: array-like of per-sequence label vectors, one row per sequence;
            it is converted to a float32 tensor.
        max_length: padded/truncated token length per sequence. The default of
            256 matches ``ProteinBERTEmbeddings``' default position table. A
            larger value needs a model built with more position embeddings.

    Raises:
        ValueError: if the number of sequences differs from the number of label rows.
    """

    def __init__(self, sequences, labels, max_length=256):
        self.tokenizer = ProteinTokenizer()

        # Pre-process all sequences at initialization.
        print("Pre-processing sequences...")
        self.processed_sequences = [
            self.tokenizer.encode(sequence, max_length=max_length)
            for sequence in sequences
        ]

        self.labels = torch.FloatTensor(labels)
        # A mismatch would otherwise surface later as a misaligned or
        # out-of-range index inside the DataLoader.
        if len(self.processed_sequences) != len(self.labels):
            raise ValueError(
                f"Got {len(self.processed_sequences)} sequences but "
                f"{len(self.labels)} label rows"
            )

    def __len__(self):
        return len(self.processed_sequences)

    def __getitem__(self, idx):
        """Return the encoded tensors and the label vector for sample *idx*."""
        encoded = self.processed_sequences[idx]
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': self.labels[idx],
        }

class ProteinBERTEmbeddings(nn.Module):
    """Token plus learned absolute-position embeddings, then LayerNorm and dropout.

    Args:
        vocab_size: number of distinct token ids.
        hidden_size: embedding width.
        max_position_embeddings: size of the position table. Inputs longer than
            this make the position lookup go out of range.
        dropout: dropout probability applied after normalisation.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings=256, dropout=0.1):
        super().__init__()
        # Attribute names and registration order define the state_dict keys
        # and the parameter order the optimizer sees.
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """Embed a ``(batch, seq_len)`` id tensor into ``(batch, seq_len, hidden_size)``."""
        seq_length = input_ids.size(1)
        batch_size = input_ids.size(0)

        # Positions 0..seq_len-1, shared (broadcast) across the batch.
        positions = torch.arange(seq_length, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_length)

        combined = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(combined))

class ProteinBERTAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    There is no output projection: the concatenated head outputs are returned
    directly and have width ``hidden_size``.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
            Previously the head size was silently truncated, which produced an
            output narrower than the input and broke the caller's residual add.
    """

    def __init__(self, hidden_size, num_attention_heads=8, dropout=0.1):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(dropout)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over *hidden_states*.

        Args:
            hidden_states: ``(batch, seq, hidden_size)`` tensor.
            attention_mask: optional ``(batch, seq)`` tensor. Key positions
                where it is 0 receive no attention weight.

        Returns:
            ``(batch, seq, hidden_size)`` context tensor.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Plain Python float scale: avoids allocating a CPU tensor every call.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and queries.
            key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # The dtype's own minimum stays finite in fp16/bf16, where -1e9 overflows.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(key_mask == 0, fill_value)

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer

class ProteinBERTLayer(nn.Module):
    """One post-LayerNorm transformer encoder block.

    Order of operations: self-attention, then residual add and LayerNorm, then
    a GELU feed-forward layer 4x wider than the model, then dropout, then a
    second residual add and LayerNorm.

    Args:
        hidden_size: model width.
        num_attention_heads: heads passed to ``ProteinBERTAttention``
            (default 8, the previous fixed value).
        dropout: dropout probability for both the attention weights and the
            feed-forward output (default 0.1, the previous fixed value).
    """

    def __init__(self, hidden_size, num_attention_heads=8, dropout=0.1):
        super().__init__()
        # Attribute names are unchanged so existing checkpoints still load.
        self.attention = ProteinBERTAttention(
            hidden_size, num_attention_heads=num_attention_heads, dropout=dropout
        )
        self.intermediate = nn.Linear(hidden_size, hidden_size * 4)
        self.output = nn.Linear(hidden_size * 4, hidden_size)
        self.LayerNorm1 = nn.LayerNorm(hidden_size)
        self.LayerNorm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, attention_mask=None):
        """Apply the block; input and output are ``(batch, seq, hidden_size)``."""
        # NOTE(review): unlike reference BERT, there is no dropout on the
        # attention output before this residual add.
        attention_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.LayerNorm1(hidden_states + attention_output)

        intermediate_output = F.gelu(self.intermediate(hidden_states))
        layer_output = self.output(intermediate_output)
        layer_output = self.dropout(layer_output)
        layer_output = self.LayerNorm2(hidden_states + layer_output)

        return layer_output

class ProteinBERT(nn.Module):
    """Embedding layer followed by a stack of ``ProteinBERTLayer`` encoder blocks.

    Args:
        vocab_size: number of token ids.
        hidden_size: model width shared by every block.
        num_layers: number of encoder blocks.
    """

    def __init__(self, vocab_size, hidden_size=256, num_layers=6):
        super().__init__()
        self.embeddings = ProteinBERTEmbeddings(vocab_size, hidden_size)
        blocks = [ProteinBERTLayer(hidden_size) for _ in range(num_layers)]
        self.encoder = nn.ModuleList(blocks)

    def forward(self, input_ids, attention_mask=None):
        """Return the final hidden states, one vector per input token."""
        states = self.embeddings(input_ids)
        # The same padding mask is applied in every block.
        for encoder_block in self.encoder:
            states = encoder_block(states, attention_mask)
        return states

class ProteinFunctionPredictor(nn.Module):
    """Multi-label classifier over the ``[CLS]`` position of a ``ProteinBERT`` encoder.

    Returns per-class probabilities: a sigmoid is applied inside the model,
    which is why training pairs it with ``nn.BCELoss``.

    Args:
        num_classes: number of output labels.
        vocab_size: tokenizer vocabulary size.
        hidden_size: encoder width (default 256, the previous fixed value).
        num_layers: number of encoder blocks (default 6, ProteinBERT's default).
    """

    def __init__(self, num_classes, vocab_size, hidden_size=256, num_layers=6):
        super().__init__()

        self.bert = ProteinBERT(vocab_size=vocab_size, hidden_size=hidden_size,
                                num_layers=num_layers)

        # NOTE(review): emitting logits and training with BCEWithLogitsLoss is
        # more numerically stable. It is kept as Sigmoid here because callers
        # and checkpoints expect probabilities.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, num_classes),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask=None):
        """Return ``(batch, num_classes)`` probabilities for ``(batch, seq)`` token ids."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_representation = outputs[:, 0, :]  # position 0 holds [CLS]
        return self.classifier(cls_representation)

def _batch_to_device(batch, device):
    """Return ``(input_ids, attention_mask, labels)`` from a dataset batch, moved to *device*."""
    return (
        batch['input_ids'].to(device),
        batch['attention_mask'].to(device),
        batch['labels'].to(device),
    )


def _print_epoch_metrics(title, loss, metrics):
    """Print one split's epoch summary: the loss plus the calculate_metrics fields."""
    print(f'\n{title}:')
    print(f'  Loss:      {loss:.4f}')
    print(f'  Accuracy:  {metrics["accuracy"]:.4f}')
    print(f'  F1 Score:  {metrics["f1"]:.4f}')
    print(f'  Precision: {metrics["precision"]:.4f}')
    print(f'  Recall:    {metrics["recall"]:.4f}')


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_every):
    """Run one optimisation pass over *loader*; return (mean batch loss, epoch metrics)."""
    model.train()
    total_loss = 0.0
    outputs_all, labels_all = [], []

    for batch_idx, batch in enumerate(loader):
        input_ids, attention_mask, labels = _batch_to_device(batch, device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        detached = outputs.detach()
        # Keep the epoch accumulators on the CPU so they do not grow
        # accelerator memory when device is not 'cpu'.
        outputs_all.append(detached.cpu())
        labels_all.append(labels.cpu())

        if batch_idx % log_every == 0:
            batch_metrics = calculate_metrics(detached, labels)
            print(f'\nEpoch {epoch+1}, Batch {batch_idx}:')
            print(f'Loss: {loss.item():.4f}')
            print('Batch Metrics:')
            print(f'  Accuracy: {batch_metrics["accuracy"]:.4f}')
            print(f'  F1: {batch_metrics["f1"]:.4f}')
            print(f'  Precision: {batch_metrics["precision"]:.4f}')
            print(f'  Recall: {batch_metrics["recall"]:.4f}')

    if not outputs_all:
        raise ValueError('train_loader yielded no batches')
    metrics = calculate_metrics(torch.cat(outputs_all), torch.cat(labels_all))
    return total_loss / len(outputs_all), metrics


def _evaluate_epoch(model, loader, criterion, device):
    """Score *loader* without gradients; return (mean batch loss, metrics)."""
    model.eval()
    total_loss = 0.0
    outputs_all, labels_all = [], []

    with torch.no_grad():
        for batch in loader:
            input_ids, attention_mask, labels = _batch_to_device(batch, device)
            outputs = model(input_ids, attention_mask)
            total_loss += criterion(outputs, labels).item()
            outputs_all.append(outputs.cpu())
            labels_all.append(labels.cpu())

    if not outputs_all:
        raise ValueError('val_loader yielded no batches')
    metrics = calculate_metrics(torch.cat(outputs_all), torch.cat(labels_all))
    return total_loss / len(outputs_all), metrics


def train_model(model, train_loader, val_loader, num_epochs=5, *, lr=1e-4,
                patience=3, device='cpu', checkpoint_path='best_model.pt',
                log_every=50):
    """Train *model* with BCE loss and AdamW, keeping the best-validation-F1 checkpoint.

    Each epoch runs one training pass and one validation pass, then:

    - prints the metrics for both passes;
    - steps a ReduceLROnPlateau scheduler on validation F1 (patience 1, factor 0.5);
    - saves a checkpoint whenever validation F1 improves;
    - stops early after *patience* epochs without improvement.

    Args:
        model: module mapping ``(input_ids, attention_mask)`` to probabilities
            (BCELoss expects values in [0, 1]).
        train_loader, val_loader: loaders yielding dicts with ``input_ids``,
            ``attention_mask`` and ``labels``.
        num_epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without a validation-F1 improvement before stopping.
        device: torch device (or its name) to train on; 'cpu' as before.
        checkpoint_path: where the best checkpoint is written via ``torch.save``.
        log_every: print batch metrics every this many training batches (> 0).

    Returns:
        dict of per-epoch lists: ``train_loss``, ``val_loss``, and
        ``train_``/``val_`` accuracy, f1, precision and recall.

    Raises:
        ValueError: if either loader yields no batches.
    """
    device = torch.device(device)
    print(f"Using device: {device}")

    model = model.to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=1, factor=0.5)

    # Start at -inf so the first epoch always writes a checkpoint. With 0 and
    # a strict '>', a run whose validation F1 stayed at 0 never saved a model.
    best_val_f1 = float('-inf')
    patience_counter = 0

    metrics_history = {
        'train_loss': [], 'val_loss': [],
        'train_accuracy': [], 'val_accuracy': [],
        'train_f1': [], 'val_f1': [],
        'train_precision': [], 'val_precision': [],
        'train_recall': [], 'val_recall': []
    }

    for epoch in range(num_epochs):
        avg_train_loss, train_metrics = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, log_every)
        avg_val_loss, val_metrics = _evaluate_epoch(model, val_loader, criterion, device)

        for split, loss_value, split_metrics in (
            ('train', avg_train_loss, train_metrics),
            ('val', avg_val_loss, val_metrics),
        ):
            metrics_history[f'{split}_loss'].append(loss_value)
            for name in ('accuracy', 'f1', 'precision', 'recall'):
                metrics_history[f'{split}_{name}'].append(split_metrics[name])

        print(f'\n{"="*70}')
        print(f'Epoch {epoch+1}/{num_epochs} Summary:')
        print(f'{"="*70}')
        _print_epoch_metrics('Training Metrics', avg_train_loss, train_metrics)
        _print_epoch_metrics('Validation Metrics', avg_val_loss, val_metrics)

        scheduler.step(val_metrics['f1'])

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            patience_counter = 0
            print('\nSaving best model...')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'metrics_history': metrics_history,
            }, checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("\nEarly stopping triggered!")
                print(f"Best validation F1: {best_val_f1:.4f}")
                break

        print(f'\nLearning rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print(f'{"="*70}\n')

    return metrics_history

def main(terms_path="balanced_train_terms.tsv",
         fasta_path="/Users/devshah/Documents/WorkSpace/University/year 3/CSC413_project/CAFA 5 Protein Function Prediction/Train/train_sequences.fasta"):
    """Load the term annotations and sequences, build multi-hot labels, and train.

    Args:
        terms_path: TSV with at least ``EntryID`` and ``term`` columns.
        fasta_path: FASTA file whose record IDs match ``EntryID``.

    Returns:
        The metrics history returned by ``train_model``.
    """
    print("Loading data...")
    balanced_terms = pd.read_csv(terms_path, sep="\t")
    sequences = load_sequences(fasta_path)

    print("Preparing labels...")
    unique_terms = balanced_terms['term'].unique()
    term_to_idx = {term: idx for idx, term in enumerate(unique_terms)}
    num_classes = len(unique_terms)
    print(f"Number of classes: {num_classes}")

    # Series.unique() keeps first-appearance order, which is the same on every
    # run. list(set(...)) depended on per-process string-hash randomisation, so
    # the seeded split below picked different proteins from run to run.
    all_proteins = balanced_terms['EntryID'].unique()
    protein_list = [pid for pid in all_proteins if pid in sequences]
    n_missing = len(all_proteins) - len(protein_list)
    if n_missing:
        print(f"Warning: skipping {n_missing} proteins with no sequence in the FASTA file")
    protein_to_idx = {pid: idx for idx, pid in enumerate(protein_list)}

    # Vectorised multi-hot fill (replaces a per-row iterrows loop).
    labels = np.zeros((len(protein_list), num_classes), dtype=np.float32)
    row_idx = balanced_terms['EntryID'].map(protein_to_idx)
    col_idx = balanced_terms['term'].map(term_to_idx)
    known = row_idx.notna()
    labels[row_idx[known].astype(int).to_numpy(), col_idx[known].astype(int).to_numpy()] = 1.0

    print("Creating dataset...")
    sequence_list = [sequences[pid] for pid in protein_list]
    dataset = ProteinDataset(sequence_list, labels)

    # Split dataset
    train_indices, val_indices = train_test_split(
        range(len(dataset)),
        test_size=0.2,
        random_state=42
    )

    train_loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=SubsetRandomSampler(train_indices),
        num_workers=0
    )

    val_loader = DataLoader(
        dataset,
        batch_size=16,
        sampler=SubsetRandomSampler(val_indices),
        num_workers=0
    )

    print("Initializing model...")
    model = ProteinFunctionPredictor(num_classes=num_classes,
                                     vocab_size=dataset.tokenizer.vocab_size)

    print("Starting training...")
    return train_model(model, train_loader, val_loader)


if __name__ == '__main__':
    main()

Loading data...
Processed 0/34911 sequences...
Processed 1000/34911 sequences...
Processed 2000/34911 sequences...
Processed 3000/34911 sequences...
Processed 4000/34911 sequences...
Processed 5000/34911 sequences...
Processed 6000/34911 sequences...
Processed 7000/34911 sequences...
Processed 8000/34911 sequences...
Processed 9000/34911 sequences...
Processed 10000/34911 sequences...
Processed 11000/34911 sequences...
Processed 12000/34911 sequences...
Processed 13000/34911 sequences...
Processed 14000/34911 sequences...
Processed 15000/34911 sequences...
Processed 16000/34911 sequences...
Processed 17000/34911 sequences...
Processed 18000/34911 sequences...
Processed 19000/34911 sequences...
Processed 20000/34911 sequences...
Processed 21000/34911 sequences...
Processed 22000/34911 sequences...
Processed 23000/34911 sequences...
Processed 24000/34911 sequences...
Processed 25000/34911 sequences...
Processed 26000/34911 sequences...
Processed 27000/34911 sequences...
Processed 28000/3

  plt.style.use('seaborn')


Training plots saved in 'training_plots' directory


In [None]:
def plot_training_metrics(metrics_history, save_dir='training_plots'):
    """Save five PNG summaries of a training run's per-epoch metrics.

    Plots written:

    - loss curves;
    - a 2x2 grid of accuracy, F1, precision and recall;
    - an F1 learning curve with the train/validation gap shaded;
    - a correlation heatmap across all tracked series;
    - per-series standard deviations.

    Args:
        metrics_history: dict of equal-length per-epoch lists, as returned by
            ``train_model`` (``train_loss``, ``val_loss``, ``train_f1``, ...).
        save_dir: output directory; created if missing.

    File names carry a ``YYYYmmdd_HHMMSS`` timestamp so that plots from
    different runs do not overwrite one another. Nothing is drawn when no
    epochs were recorded.
    """
    epochs = list(range(1, len(metrics_history['train_loss']) + 1))
    if not epochs:
        print('No epochs recorded in metrics_history; skipping training plots.')
        return

    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # The bare 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6
    # and removed in 3.8, where plt.style.use('seaborn') raises. Use whichever
    # name this installation provides, and keep the current style otherwise.
    style = next((name for name in ('seaborn-v0_8', 'seaborn')
                  if name in plt.style.available), None)
    if style is not None:
        plt.style.use(style)

    # 1. Combined loss plot (x axis is the 1-based epoch, as in every plot here).
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, metrics_history['train_loss'], label='Training Loss', marker='o')
    plt.plot(epochs, metrics_history['val_loss'], label='Validation Loss', marker='o')
    plt.title('Training and Validation Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, f'loss_comparison_{timestamp}.png'))
    plt.close()

    # 2. All metrics comparison, one subplot per metric.
    plt.figure(figsize=(15, 10))
    metrics = ['accuracy', 'f1', 'precision', 'recall']
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#f1c40f']

    for position, (metric, color) in enumerate(zip(metrics, colors), start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, metrics_history[f'train_{metric}'],
                 label=f'Training {metric.capitalize()}', color=color, marker='o')
        plt.plot(epochs, metrics_history[f'val_{metric}'],
                 label=f'Validation {metric.capitalize()}', color=color,
                 linestyle='--', marker='o')
        plt.title(f'{metric.capitalize()} Over Time')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'all_metrics_comparison_{timestamp}.png'))
    plt.close()

    # 3. Learning curve. Shade the region between the train and validation F1
    # curves (the generalisation gap). The previous band ran from train-val to
    # train+val, which has no meaning.
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, metrics_history['train_f1'], 'b-', label='Training F1')
    plt.plot(epochs, metrics_history['val_f1'], 'r-', label='Validation F1')
    plt.fill_between(epochs, metrics_history['train_f1'], metrics_history['val_f1'],
                     alpha=0.1, color='gray', label='Train/Val gap')
    plt.title('Learning Curve Analysis')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, f'learning_curve_{timestamp}.png'))
    plt.close()

    # 4. Metric correlation heatmap. Entries are NaN for constant series or
    # when there are fewer than two epochs.
    plt.figure(figsize=(12, 10))
    correlation = pd.DataFrame(metrics_history).corr()
    sns.heatmap(correlation, annot=True, cmap='coolwarm', center=0)
    plt.title('Metric Correlation Heatmap')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'metric_correlation_{timestamp}.png'))
    plt.close()

    # 5. Training stability: standard deviation of each series across epochs.
    plt.figure(figsize=(12, 6))
    metrics_std = {name: np.std(values) for name, values in metrics_history.items()}
    plt.bar(list(metrics_std.keys()), list(metrics_std.values()))
    plt.xticks(rotation=45)
    plt.title('Metric Stability Analysis\n(Lower Standard Deviation = More Stable)')
    plt.ylabel('Standard Deviation')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'stability_analysis_{timestamp}.png'))
    plt.close()

def plot_prediction_analysis(model, val_loader, save_dir='prediction_analysis'):
    """Run *model* over *val_loader* and save diagnostic plots of its predictions.

    Every (sample, class) cell is treated as one binary prediction. Plots written:

    - a histogram of predicted probabilities versus labels;
    - a ROC curve (skipped when the labels contain a single class);
    - a 2x2 confusion matrix at a 0.5 threshold;
    - a histogram of absolute errors;
    - per-class mean predicted probability versus label frequency.

    Args:
        model: module mapping ``(input_ids, attention_mask)`` to probabilities.
        val_loader: loader yielding dicts with ``input_ids``,
            ``attention_mask`` and ``labels``.
        save_dir: output directory; created if missing.

    The model's train/eval mode is restored afterwards.

    Raises:
        ValueError: if *val_loader* yields no batches.
    """
    os.makedirs(save_dir, exist_ok=True)

    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    all_preds = []
    all_labels = []
    try:
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                outputs = model(input_ids, attention_mask)
                all_preds.append(outputs.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
    finally:
        # Do not leave a model that was mid-training stuck in eval mode.
        model.train(was_training)

    if not all_preds:
        raise ValueError('val_loader yielded no batches')

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    flat_preds = all_preds.ravel()
    flat_labels = all_labels.ravel()

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # 1. Prediction distribution
    plt.figure(figsize=(10, 6))
    plt.hist(flat_preds, bins=50, alpha=0.5, label='Predictions')
    plt.hist(flat_labels, bins=50, alpha=0.5, label='Ground Truth')
    plt.title('Distribution of Predictions vs Ground Truth')
    plt.xlabel('Value')
    plt.ylabel('Count')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'prediction_distribution_{timestamp}.png'))
    plt.close()

    # 2. ROC curve. It is undefined when only one class is present: roc_curve
    # warns and the AUC becomes NaN. Skip the plot in that case.
    if np.unique(flat_labels).size < 2:
        print('Skipping ROC curve: ground truth contains a single class.')
    else:
        plt.figure(figsize=(10, 6))
        fpr, tpr, _ = roc_curve(flat_labels, flat_preds)
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(save_dir, f'roc_curve_{timestamp}.png'))
        plt.close()

    # 3. Confusion matrix. Passing labels=[0, 1] keeps it 2x2 even when only
    # one class appears in the data; otherwise it collapses to 1x1.
    plt.figure(figsize=(10, 8))
    predictions_binary = (flat_preds >= 0.5).astype(int)
    labels_binary = (flat_labels >= 0.5).astype(int)
    cm = confusion_matrix(labels_binary, predictions_binary, labels=[0, 1])

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[0, 1], yticklabels=[0, 1])
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.savefig(os.path.join(save_dir, f'confusion_matrix_{timestamp}.png'))
    plt.close()

    # 4. Prediction error analysis
    plt.figure(figsize=(10, 6))
    errors = np.abs(all_preds - all_labels)
    plt.hist(errors.ravel(), bins=50)
    plt.title('Distribution of Prediction Errors')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.savefig(os.path.join(save_dir, f'error_distribution_{timestamp}.png'))
    plt.close()

    # 5. Class-wise performance: column means over samples (assumes 2-D
    # (samples, classes) outputs, as produced by ProteinFunctionPredictor).
    avg_preds_per_class = np.mean(all_preds, axis=0)
    avg_labels_per_class = np.mean(all_labels, axis=0)
    class_ids = np.arange(len(avg_preds_per_class))

    plt.figure(figsize=(12, 6))
    plt.scatter(class_ids, avg_preds_per_class, alpha=0.5, label='Predicted')
    plt.scatter(class_ids, avg_labels_per_class, alpha=0.5, label='Actual')
    plt.title('Class-wise Prediction Performance')
    plt.xlabel('Class Index')
    plt.ylabel('Average Probability')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'class_performance_{timestamp}.png'))
    plt.close()

def generate_training_plots(model, metrics_history, val_loader):
    """Write both plot families for a finished run.

    The training-history plots go to ``training_plots`` and the
    validation-prediction plots to ``prediction_analysis``; both are the
    default directories of the two plotting functions.
    """
    print("\nGenerating comprehensive analysis plots...")
    plot_steps = (
        (plot_training_metrics, (metrics_history,)),
        (plot_prediction_analysis, (model, val_loader)),
    )
    for plot_fn, plot_args in plot_steps:
        plot_fn(*plot_args)
    print("Analysis plots saved in 'training_plots' and 'prediction_analysis' directories")

def calculate_metrics(outputs, labels, threshold=0.5):
    """Score thresholded predictions against labels, cell by cell.

    Both inputs are flattened, so an ``(N, C)`` multi-label batch is scored as
    ``N * C`` independent binary decisions. For precision, recall and F1 this
    is equivalent to micro-averaging over classes.

    Args:
        outputs: predicted probabilities, as a torch tensor (any device) or
            array-like.
        labels: 0/1 ground truth with the same number of elements, as a
            tensor or array-like.
        threshold: a score at or above this value counts as a positive
            prediction.

    Returns:
        dict with ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``.
        Label 1 is the positive class; undefined ratios are reported as 0.
    """
    def _as_numpy(values):
        # Convert each argument on its own. Previously labels were converted
        # only when outputs was a tensor, so mixed tensor/array input crashed.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    y_score = _as_numpy(outputs).ravel()
    y_true = _as_numpy(labels).ravel()
    y_pred = (y_score >= threshold).astype(int)

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true,
        y_pred,
        average='binary',
        zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }