In [None]:
import os
import gc
import time
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve

from tensorboardX import SummaryWriter

# Reproducibility: seed Python, NumPy and PyTorch (CPU, plus every GPU when present)
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Output tree: TensorBoard events, model checkpoints and result artefacts
for _log_dir in (
    'transformers_logs',
    'transformers_logs/tensorboard',
    'transformers_logs/model_checkpoints',
    'transformers_logs/results',
):
    os.makedirs(_log_dir, exist_ok=True)

# Prefer the GPU when one is visible
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
def load_sequences(file_path="Sequence_data.txt"):
    """Load protein sequences from a FASTA file into a DataFrame.

    Each record starts at a '>' line. The record ID is the second
    '|'-separated field of the header (e.g. ``P12345`` in ``>sp|P12345|NAME``),
    or the whole header text when it contains no '|'. Sequence lines are
    stripped and concatenated. Any lines before the first header are ignored.

    Args:
        file_path: Path to the FASTA file.

    Returns:
        DataFrame with columns "Header" (record ID) and "Sequence".
    """
    print("Loading protein sequences...")

    headers = []
    sequences = []
    current_header = None  # None means "no record open yet"
    chunks = []            # sequence lines of the open record, joined once at the end

    with open(file_path, "r") as file:
        for line in file:
            line = line.strip()
            if line.startswith(">"):
                # Close the previous record. Test against None, not truthiness:
                # an empty header string ('>' alone) is still a record.
                if current_header is not None:
                    headers.append(current_header)
                    sequences.append("".join(chunks))

                # Extract header ID (middle part)
                full_header = line[1:]
                parts = full_header.split("|")
                current_header = parts[1] if len(parts) > 1 else full_header
                chunks = []
            elif current_header is not None:
                chunks.append(line)

    # Don't forget the last sequence
    if current_header is not None:
        headers.append(current_header)
        sequences.append("".join(chunks))

    df = pd.DataFrame({
        "Header": headers,
        "Sequence": sequences
    })

    print(f"Loaded {len(df)} protein sequences")
    return df

def load_labels(file_path="labels.xlsx"):
    """Read the phosphorylation-site annotation table from an Excel workbook.

    Args:
        file_path: Path of the workbook passed to ``pandas.read_excel``.

    Returns:
        The sheet as a DataFrame, one row per annotated site.
    """
    print("Loading phosphorylation site labels...")
    site_table = pd.read_excel(file_path)
    n_sites = len(site_table)
    print(f"Loaded {n_sites} phosphorylation sites")
    return site_table

def merge_sequence_and_labels(df_seq, df_labels):
    """Attach each labelled site to its protein sequence.

    Inner-joins ``df_seq.Header`` with ``df_labels["UniProt ID"]``, so proteins
    without labels and labels without a sequence are dropped. Every surviving
    row is a known site, so a constant ``target`` column of 1 is added.

    Returns:
        The merged DataFrame.
    """
    print("Merging sequences with labels...")

    merged_df = df_seq.merge(
        df_labels,
        how="inner",
        left_on="Header",
        right_on="UniProt ID",
    ).assign(target=1)  # all merged rows are positive examples

    print(f"Merged data contains {len(merged_df)} rows")
    return merged_df

# Load data: FASTA sequences and site labels, joined on the UniProt ID
df_seq, df_labels = load_sequences(), load_labels()
df_merged = merge_sequence_and_labels(df_seq, df_labels)

# Peek at the joined frame (rendered as this cell's output)
df_merged.head()

In [None]:
import pandas as pd
import random
import progressbar

def generate_negative_samples(df_merged):
    """
    Generate negative samples for each protein sequence by randomly sampling
    from S/T/Y sites that are not known phosphorylation sites.

    For every protein (group of rows sharing 'Header'), all annotated rows are
    kept as positives. Up to one negative per positive row is then drawn,
    without replacement and with ``random.sample``, from that protein's S/T/Y
    positions (1-based) that are not annotated. Each negative row copies the
    protein's first row, then overwrites 'AA', 'Position' and sets 'target' to 0.

    Proteins with no unannotated S/T/Y site keep their positives. Previously
    they were dropped entirely, silently losing labelled data.

    Args:
        df_merged: Labelled sites with at least 'Header', 'Sequence' and
            'Position' columns.

    Returns:
        DataFrame of positives followed by each protein's negatives, with a
        fresh RangeIndex. Returns an empty frame with df_merged's columns when
        df_merged is empty.
    """
    print("Generating negative samples...")

    all_rows = []
    groups = list(df_merged.groupby('Header'))

    # Create a text-based progress bar
    bar = progressbar.ProgressBar(max_value=len(groups))

    for i, (_header, group) in enumerate(groups):
        seq = group['Sequence'].iloc[0]
        positive_positions = group['Position'].astype(int).tolist()
        positive_set = set(positive_positions)  # O(1) membership tests
        sty_positions = [idx + 1 for idx, aa in enumerate(seq) if aa in ("S", "T", "Y")]
        negative_candidates = [pos for pos in sty_positions if pos not in positive_set]

        # Keep all positives, including those that cannot be paired with a negative
        all_rows.append(group)

        sample_size = min(len(positive_positions), len(negative_candidates))
        if sample_size > 0:
            sampled_negatives = random.sample(negative_candidates, sample_size)

            template = group.iloc[0]
            negative_rows = []
            for neg_pos in sampled_negatives:
                new_row = template.copy()
                new_row['AA']       = seq[neg_pos - 1]
                new_row['Position'] = neg_pos
                new_row['target']   = 0
                negative_rows.append(new_row)
            # One frame per protein rather than one per negative row
            all_rows.append(pd.DataFrame(negative_rows))

        # advance the bar
        bar.update(i + 1)

    if not all_rows:
        # pd.concat([]) raises, so handle empty input explicitly
        print("Generated dataset with 0 rows (positives + negatives)")
        return df_merged.iloc[0:0].copy()

    df_final = pd.concat(all_rows, ignore_index=True)
    print(f"Generated dataset with {len(df_final)} rows (positives + negatives)")
    return df_final

# === Usage ===
# df_merged comes from the loading cell above
df_final = generate_negative_samples(df_merged)

# Sanity-check the positive/negative balance
_class_counts = df_final['target'].value_counts()
print("Class distribution:")
print(_class_counts)

In [None]:
class PhosphorylationDataset(Dataset):
    """Torch dataset of fixed-width residue windows centred on candidate sites.

    Each item is built from one dataframe row. The columns read are 'Sequence',
    'Position' (1-based), 'target' and 'Header'.

    The window always holds ``2 * window_size + 1`` residues with the site in
    the middle. Residues missing past either terminus are filled with
    ``pad_char``, so the site stays centred.

    Token tensors are padded or truncated to
    ``min(max_length, 2 * window_size + 1 + tokenizer.num_special_tokens_to_add())``.
    With an ESM-style tokenizer (one token per residue, one leading and one
    trailing special token), this places the site token at index
    ``seq_len // 2``, which is where PhosphoTransformer looks.

    Previously every item was padded to 512 tokens, so ``seq_len // 2`` pointed
    at padding. Windows at a terminus were also not centred.
    """

    def __init__(self, dataframe, tokenizer, window_size=20, max_length=512, pad_char="X"):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.max_length = max_length
        # Filler residue for windows running past a terminus. 'X' (unknown
        # residue) is presumably in the ESM-2 vocabulary; TODO confirm for
        # other tokenizers.
        self.pad_char = pad_char

    def __len__(self):
        return len(self.dataframe)

    def _centered_window(self, sequence, position):
        """Return the 2*window_size+1 residues around 0-based `position`, filled with pad_char."""
        if not 0 <= position < len(sequence):
            # ValueError rather than IndexError: an IndexError from __getitem__
            # would silently end plain `for item in dataset` iteration.
            raise ValueError(
                f"Site position {position + 1} is outside a sequence of length {len(sequence)}"
            )
        w = self.window_size
        left = sequence[max(0, position - w):position]
        right = sequence[position + 1:position + w + 1]
        return (
            self.pad_char * (w - len(left))
            + left
            + sequence[position]
            + right
            + self.pad_char * (w - len(right))
        )

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        sequence = row['Sequence']
        position = int(row['Position']) - 1  # Convert to 0-based indexing
        target = int(row['target'])

        # The window centered on the target site (always 2*window_size+1 residues)
        window_sequence = self._centered_window(sequence, position)

        # Pad only to the window's own token length. Every item then has the same
        # length and no padding, so the site token is the middle token. A
        # smaller max_length truncates from the right and breaks centring.
        natural_length = len(window_sequence) + self.tokenizer.num_special_tokens_to_add(pair=False)
        encoding = self.tokenizer(
            window_sequence,
            padding="max_length",
            truncation=True,
            max_length=min(self.max_length, natural_length),
            return_tensors="pt"
        )

        # Remove the batch dimension added by the tokenizer
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target': torch.tensor(target, dtype=torch.float),
            'sequence': window_sequence,
            'position': torch.tensor(position, dtype=torch.long),
            'header': row['Header']
        }

def split_dataset(df, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Split the dataset into training, validation, and test sets by protein.

    Unique 'Header' values are shuffled with NumPy's global RNG (seeded in the
    setup cell) and cut at the cumulative ratios. Every row of a protein
    therefore lands in exactly one split, so no protein leaks across splits.

    Args:
        df: DataFrame with a 'Header' column.
        train_ratio, val_ratio, test_ratio: Fractions of proteins per split.
            They must sum to 1 (within floating-point tolerance).

    Returns:
        (train_df, val_df, test_df) row subsets of ``df``.

    Raises:
        ValueError: if the ratios do not sum to 1.
    """
    total = train_ratio + val_ratio + test_ratio
    # Tolerance check: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floating point
    if not np.isclose(total, 1.0):
        raise ValueError(f"Split ratios must sum to 1.0, got {total}")

    # Group by Header to ensure proteins don't leak across splits
    headers = df['Header'].unique()
    n_headers = len(headers)

    # Shuffle the headers
    np.random.shuffle(headers)

    # Split points. The epsilon absorbs FP round-down before flooring
    # (10 * (0.7 + 0.2) == 8.999999999999998 would otherwise give 8).
    eps = 1e-9
    train_split = int(n_headers * train_ratio + eps)
    val_split = int(n_headers * (train_ratio + val_ratio) + eps)

    # Split headers
    train_headers = headers[:train_split]
    val_headers = headers[train_split:val_split]
    test_headers = headers[val_split:]

    # Create dataframes
    train_df = df[df['Header'].isin(train_headers)]
    val_df = df[df['Header'].isin(val_headers)]
    test_df = df[df['Header'].isin(test_headers)]

    print(f"Train set: {len(train_df)} samples from {len(train_headers)} proteins")
    print(f"Validation set: {len(val_df)} samples from {len(val_headers)} proteins")
    print(f"Test set: {len(test_df)} samples from {len(test_headers)} proteins")

    return train_df, val_df, test_df

# ESM-2 (t6, 8M) tokenizer; the same checkpoint ID as PhosphoTransformer's default encoder
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Protein-level split so no protein contributes sites to more than one partition
train_df, val_df, test_df = split_dataset(df_final)

# One windowed dataset per partition (window_size=20 residues each side of the site)
train_dataset, val_dataset, test_dataset = (
    PhosphorylationDataset(part_df, tokenizer, window_size=20)
    for part_df in (train_df, val_df, test_df)
)

# Only the training loader shuffles; validation/test keep a stable order
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=32) for split_ds in (val_dataset, test_dataset)
)

# Smoke-test: pull one training batch and report tensor shapes
sample = next(iter(train_loader))
print(f"Input shape: {sample['input_ids'].shape}")
print(f"Target shape: {sample['target'].shape}")

In [None]:
class PhosphoTransformer(nn.Module):
    """Binary phosphosite classifier: a pretrained protein encoder plus an MLP head.

    The head receives the encoder's hidden states for the ``2 * window_context + 1``
    tokens around the centre of each example's attended tokens
    (attention_mask == 1), concatenated left to right. It returns one logit per
    example.
    """

    def __init__(self, model_name="facebook/esm2_t6_8M_UR50D", dropout_rate=0.3, window_context=3):
        super().__init__()

        # Load pre-trained protein language model
        self.protein_encoder = AutoModel.from_pretrained(model_name)

        # Get hidden size from the model config
        hidden_size = self.protein_encoder.config.hidden_size

        # Context aggregation (lightweight): centre token plus window_context on each side
        self.window_context = window_context
        context_size = hidden_size * (2*window_context + 1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(context_size, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return one logit per example.

        Args:
            input_ids: Token ids, expected shape (batch, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, same shape.
                When None, every position counts as real.

        Returns:
            Logits of shape (batch,).

        The centre is ``n_real // 2`` per example, where ``n_real`` is the
        number of attended tokens. This assumes right padding (real tokens start
        at index 0). The previous ``seq_len // 2`` pointed into padding whenever
        inputs were padded past the window. Context positions outside the tensor
        or on padding contribute zero vectors.
        """
        outputs = self.protein_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        batch_size, seq_len, hidden_dim = sequence_output.shape
        out_device = sequence_output.device

        if attention_mask is None:
            n_real = torch.full((batch_size,), seq_len, dtype=torch.long, device=out_device)
        else:
            n_real = attention_mask.sum(dim=1).long()
        center_pos = n_real // 2  # [batch_size]

        # Token indices of the context window for every example: [batch_size, 2k+1]
        offsets = torch.arange(
            -self.window_context, self.window_context + 1, device=center_pos.device
        )
        idx = center_pos.unsqueeze(1) + offsets.unsqueeze(0)
        valid = (idx >= 0) & (idx < seq_len)
        safe_idx = idx.clamp(0, seq_len - 1)
        if attention_mask is not None:
            valid = valid & attention_mask.gather(1, safe_idx).bool()

        # Gather hidden states and zero out invalid positions (same zero-padding idea as before)
        gathered = sequence_output.gather(
            1, safe_idx.to(out_device).unsqueeze(-1).expand(-1, -1, hidden_dim)
        )
        gathered = gathered * valid.to(out_device).unsqueeze(-1).to(gathered.dtype)

        # [batch, 2k+1, hidden] -> [batch, (2k+1)*hidden], same order as the old per-offset concat
        concat_features = gathered.reshape(batch_size, -1)

        logits = self.classifier(concat_features)
        return logits.squeeze(-1)

In [None]:
def train_epoch(model, data_loader, optimizer, scheduler, device):
    """Run one optimisation pass over `data_loader` and report training metrics.

    For each batch: forward, BCE-with-logits loss, backward, gradient-norm
    clipping at 1.0, then an optimizer step and a per-batch scheduler step.
    Accuracy, precision, recall and F1 use a 0.5 probability threshold; AUC
    uses the raw probabilities.

    Returns:
        dict with "loss" (mean of per-batch losses), "accuracy", "precision",
        "recall", "f1" and "auc".
    """
    model.train()
    n_batches = len(data_loader)
    batch_losses = []
    y_true = []
    y_prob = []

    print("Training:")
    for step, batch in enumerate(data_loader):
        input_ids, attention_mask, targets = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'target')
        )

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = F.binary_cross_entropy_with_logits(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # Progress line every 20 batches and on the final batch
        if step % 20 == 0 or step == n_batches - 1:
            print(f"\rBatch {step+1}/{n_batches}, Loss: {loss.item():.4f}", end="")

        batch_losses.append(loss.item())
        y_true.extend(targets.cpu().numpy())
        y_prob.extend(torch.sigmoid(logits).detach().cpu().numpy())

    print()  # terminate the \r progress line

    y_pred = (np.array(y_prob) > 0.5).astype(int)
    return {
        "loss": sum(batch_losses) / n_batches,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
    }

def evaluate(model, data_loader, device):
    """Score `model` on `data_loader` without gradient tracking.

    Loss is BCE-with-logits. Label-based metrics use a 0.5 probability
    threshold; AUC uses the raw probabilities.

    Returns:
        dict with "loss" (mean of per-batch losses), "accuracy", "precision",
        "recall", "f1", "auc", "confusion_matrix" (targets vs thresholded
        predictions), "predictions" (list of probabilities) and "targets"
        (list of labels), both in loader order.
    """
    model.eval()
    n_batches = len(data_loader)
    batch_losses = []
    y_true = []
    y_prob = []

    print("Evaluating:")
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            input_ids, attention_mask, targets = (
                batch[key].to(device) for key in ('input_ids', 'attention_mask', 'target')
            )

            logits = model(input_ids, attention_mask)
            batch_losses.append(F.binary_cross_entropy_with_logits(logits, targets).item())

            y_true.extend(targets.cpu().numpy())
            y_prob.extend(torch.sigmoid(logits).detach().cpu().numpy())

            # Progress line every 20 batches and on the final batch
            if step % 20 == 0 or step == n_batches - 1:
                print(f"\rBatch {step+1}/{n_batches}", end="")

    print()  # terminate the \r progress line

    y_pred = (np.array(y_prob) > 0.5).astype(int)
    return {
        "loss": sum(batch_losses) / n_batches,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
        "confusion_matrix": confusion_matrix(y_true, y_pred),
        "predictions": y_prob,
        "targets": y_true,
    }

In [None]:
def train_model(model, train_loader, val_loader, epochs=10, lr=5e-5, weight_decay=0.01, device=None):
    """Fine-tune `model` with AdamW and linear warmup/decay, logging and checkpointing each epoch.

    Each epoch runs one train_epoch pass and one evaluate pass on `val_loader`,
    then:
    - writes scalars to TensorBoard (transformers_logs/tensorboard);
    - dumps metrics to transformers_logs/epoch_<n>_metrics.json;
    - saves the full training state to .../model_checkpoints/latest_model.pt;
    - saves the weights to .../model_checkpoints/best_model.pt whenever
      validation F1 improves.

    Warmup covers the first 10% of the total optimizer steps.

    Args:
        model: Module called as model(input_ids, attention_mask) -> logits.
        train_loader, val_loader: DataLoaders yielding the dataset's batch dicts.
        epochs, lr, weight_decay: Optimisation hyper-parameters.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (the model is expected to be moved there
            already).

    Returns:
        The same `model` object, trained in place.
    """
    if device is None:
        device = next(model.parameters()).device

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Total steps for scheduler (stepped once per batch in train_epoch)
    total_steps = len(train_loader) * epochs

    # Set up scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Keys in evaluate()'s result that are not scalars
    non_scalar_keys = ("confusion_matrix", "predictions", "targets")

    # Start at -inf so the first epoch always writes best_model.pt. Starting at
    # 0.0 left no checkpoint (or a stale one from an earlier run) if F1 stayed at 0.
    best_val_f1 = float("-inf")

    writer = SummaryWriter(log_dir="transformers_logs/tensorboard")
    try:
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device)
            val_metrics = evaluate(model, val_loader, device)

            # Log metrics
            for metric_name, metric_value in train_metrics.items():
                writer.add_scalar(f"train/{metric_name}", metric_value, epoch)
                print(f"Train {metric_name}: {metric_value:.4f}")

            for metric_name, metric_value in val_metrics.items():
                if metric_name not in non_scalar_keys:
                    writer.add_scalar(f"val/{metric_name}", metric_value, epoch)
                    print(f"Val {metric_name}: {metric_value:.4f}")

            # Save metrics to log file (ndarray -> list so json can serialise it)
            with open(f"transformers_logs/epoch_{epoch+1}_metrics.json", "w") as f:
                json.dump({
                    "train": train_metrics,
                    "val": {k: v if not isinstance(v, np.ndarray) else v.tolist()
                            for k, v in val_metrics.items() if k != "predictions" and k != "targets"}
                }, f, indent=2)

            # Save model if it's the best so far
            if val_metrics["f1"] > best_val_f1:
                best_val_f1 = val_metrics["f1"]
                torch.save(model.state_dict(), "transformers_logs/model_checkpoints/best_model.pt")
                print(f"Saved new best model with F1 score: {best_val_f1:.4f}")

            # Always save the latest model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_metrics': train_metrics,
                'val_metrics': {k: v for k, v in val_metrics.items() if k != "predictions" and k != "targets"}
            }, "transformers_logs/model_checkpoints/latest_model.pt")
    finally:
        # Flush and close TensorBoard even if an epoch raises
        writer.close()

    return model

# Build the classifier on the notebook-wide device, then fine-tune it in place
model = PhosphoTransformer().to(device)

trained_model = train_model(
    model,
    train_loader,
    val_loader,
    epochs=5,           # adjust based on your needs
    lr=2e-5,            # small learning rate for fine-tuning a pretrained encoder
    weight_decay=0.01,
)

In [None]:
def plot_training_curves():
    """Plot training and validation curves from the per-epoch JSON logs.

    Reads every transformers_logs/epoch_<n>_metrics.json (as written by
    train_model), orders them by epoch number, then saves and displays two
    figures:
    - transformers_logs/results/training_curves.png (loss and accuracy);
    - transformers_logs/results/all_metrics_curves.png (accuracy, precision,
      recall, f1, auc).

    Prints a message and returns None when no logs exist.

    NOTE(review): train_model does not delete logs from an earlier, longer run;
    those stale epochs will be plotted too.
    """
    import glob
    import re

    # Pair each log with its epoch number and sort on the number, not the path.
    # A path sort puts epoch_10 between epoch_1 and epoch_2.
    epoch_logs = []
    for path in glob.glob("transformers_logs/epoch_*_metrics.json"):
        match = re.search(r"epoch_(\d+)_metrics", path)
        if match:
            epoch_logs.append((int(match.group(1)), path))
    epoch_logs.sort()

    if not epoch_logs:
        print("No training log files found. Make sure you've run training first.")
        return

    epochs = [epoch for epoch, _ in epoch_logs]

    # Read each log exactly once and close it
    history = []
    for _, path in epoch_logs:
        with open(path, 'r') as f:
            history.append(json.load(f))

    def series(split, metric):
        """Per-epoch values of `metric` for `split` ("train" or "val")."""
        return [entry[split][metric] for entry in history]

    # Loss and accuracy side by side
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, series('train', 'loss'), 'b-', label='Training Loss')
    plt.plot(epochs, series('val', 'loss'), 'r-', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.subplot(1, 2, 2)
    plt.plot(epochs, series('train', 'accuracy'), 'b-', label='Training Accuracy')
    plt.plot(epochs, series('val', 'accuracy'), 'r-', label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig("transformers_logs/results/training_curves.png", dpi=300)
    plt.close()

    print("Training curves saved to transformers_logs/results/training_curves.png")

    # Display the plot in the notebook
    from IPython.display import Image, display
    display(Image("transformers_logs/results/training_curves.png"))

    # One panel per metric
    metrics_to_plot = ['accuracy', 'precision', 'recall', 'f1', 'auc']

    plt.figure(figsize=(15, 10))

    for i, metric in enumerate(metrics_to_plot):
        plt.subplot(2, 3, i + 1)
        plt.plot(epochs, series('train', metric), 'b-', label=f'Training {metric}')
        plt.plot(epochs, series('val', metric), 'r-', label=f'Validation {metric}')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.title(f'Training and Validation {metric.capitalize()}')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig("transformers_logs/results/all_metrics_curves.png", dpi=300)
    plt.close()

    print("All metrics curves saved to transformers_logs/results/all_metrics_curves.png")
    display(Image("transformers_logs/results/all_metrics_curves.png"))

# Render loss/accuracy and per-metric curves from the epoch_*_metrics.json logs written by train_model
plot_training_curves()

In [None]:
def evaluate_and_save_results(model, test_loader, device):
    """Evaluate `model` on the test set and persist metrics and per-site predictions.

    Writes two files:
    - transformers_logs/results/test_metrics.json: scalar metrics plus the
      confusion matrix;
    - transformers_logs/results/test_predictions.csv: one row per example with
      header, 1-based position, window sequence, target, thresholded
      prediction (0.5) and probability.

    Returns:
        (test_metrics, test_predictions): evaluate()'s result dict and the list
        of per-example dicts written to the CSV.
    """
    test_metrics = evaluate(model, test_loader, device)

    print("\nTest Set Metrics:")
    for metric_name, metric_value in test_metrics.items():
        if metric_name not in ["confusion_matrix", "predictions", "targets"]:
            print(f"{metric_name}: {metric_value:.4f}")

    # Save metrics to file (ndarray -> list so json can serialise it)
    with open("transformers_logs/results/test_metrics.json", "w") as f:
        json.dump({
            k: v if not isinstance(v, np.ndarray) else v.tolist()
            for k, v in test_metrics.items() if k != "predictions" and k != "targets"
        }, f, indent=2)

    # Second pass to pair each prediction with its metadata (header/position/window)
    test_predictions = []
    print("Generating detailed predictions:")
    model.eval()  # evaluate() already did this; stated explicitly because dropout must be off here
    with torch.no_grad():
        batch_count = len(test_loader)
        for batch_idx, batch in enumerate(test_loader):
            if batch_idx % 5 == 0 or batch_idx == batch_count - 1:
                print(f"\rBatch {batch_idx+1}/{batch_count}", end="")

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['target'].cpu().numpy()
            headers = batch['header']
            positions = batch['position'].cpu().numpy()
            sequences = batch['sequence']

            outputs = model(input_ids, attention_mask)
            probas = torch.sigmoid(outputs).cpu().numpy()
            preds = (probas > 0.5).astype(int)

            # Separate loop names so the batch index is not shadowed
            for header, pos, seq, tgt, pred, proba in zip(
                headers, positions, sequences, targets, preds, probas
            ):
                test_predictions.append({
                    "header": header,
                    "position": int(pos) + 1,  # Convert back to 1-based indexing
                    "sequence": seq,
                    "target": int(tgt),
                    "prediction": int(pred),
                    "probability": float(proba)
                })

    print()  # New line after progress tracking

    pd.DataFrame(test_predictions).to_csv("transformers_logs/results/test_predictions.csv", index=False)

    return test_metrics, test_predictions

# Rebuild the architecture and load the best-validation-F1 weights saved by train_model.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine (and vice versa).
best_model = PhosphoTransformer().to(device)
best_model.load_state_dict(
    torch.load("transformers_logs/model_checkpoints/best_model.pt", map_location=device)
)

# Evaluate on the held-out test split; writes metrics and predictions under transformers_logs/results
test_metrics, test_predictions = evaluate_and_save_results(best_model, test_loader, device)

In [None]:
def plot_roc_curve(targets, probas, output_path):
    """Save a ROC curve, annotated with its AUC, to `output_path` (300 dpi)."""
    fpr, tpr, _ = roc_curve(targets, probas)
    roc_auc = roc_auc_score(targets, probas)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True)
    fig.savefig(output_path, dpi=300)
    plt.close(fig)

def plot_precision_recall_curve(targets, probas, output_path):
    """Save a precision-recall curve, annotated with average precision, to `output_path` (300 dpi)."""
    from sklearn.metrics import average_precision_score

    precision, recall, _ = precision_recall_curve(targets, probas)
    ap = average_precision_score(targets, probas)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(recall, precision, lw=2, label=f'PR curve (AP = {ap:.4f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc="lower left")
    ax.grid(True)
    fig.savefig(output_path, dpi=300)
    plt.close(fig)

def plot_confusion_matrix(cm, output_path):
    """Save an annotated heatmap of confusion matrix `cm` (integer counts) to `output_path`."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)

def plot_metrics_comparison(metrics, output_path):
    """Save a bar chart of accuracy/precision/recall/f1/auc from `metrics` to `output_path`.

    Each bar is labelled with its value to four decimals.
    """
    names = ['accuracy', 'precision', 'recall', 'f1', 'auc']
    scores = [metrics[name] for name in names]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(names, scores, color='steelblue')

    for rect in bars:
        bar_height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2., bar_height + 0.01,
                f'{bar_height:.4f}', ha='center', va='bottom')

    ax.set_ylim(0, 1.1)
    ax.set_ylabel('Score')
    ax.set_title('Model Performance Metrics')
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)

In [None]:
# Reload the per-example test predictions written by evaluate_and_save_results
test_predictions_df = pd.read_csv("transformers_logs/results/test_predictions.csv")

targets = test_predictions_df['target'].values
probabilities = test_predictions_df['probability'].values

# Threshold-free curves
plot_roc_curve(targets, probabilities, "transformers_logs/results/roc_curve.png")
plot_precision_recall_curve(targets, probabilities, "transformers_logs/results/precision_recall_curve.png")

# Confusion matrix from the thresholded predictions stored in the CSV
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(test_predictions_df['target'], test_predictions_df['prediction'])
plot_confusion_matrix(cm, "transformers_logs/results/confusion_matrix.png")

# Summary metrics recomputed from the CSV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

_y_true = test_predictions_df['target']
_y_pred = test_predictions_df['prediction']
metrics = {
    'accuracy': accuracy_score(_y_true, _y_pred),
    'precision': precision_score(_y_true, _y_pred),
    'recall': recall_score(_y_true, _y_pred),
    'f1': f1_score(_y_true, _y_pred),
    'auc': roc_auc_score(_y_true, test_predictions_df['probability']),
}
plot_metrics_comparison(metrics, "transformers_logs/results/metrics_comparison.png")

# Show every saved figure inline, each preceded by its caption
from IPython.display import Image, display

for _caption, _png_path in (
    ("ROC Curve:", "transformers_logs/results/roc_curve.png"),
    ("Precision-Recall Curve:", "transformers_logs/results/precision_recall_curve.png"),
    ("Confusion Matrix:", "transformers_logs/results/confusion_matrix.png"),
    ("Metrics Comparison:", "transformers_logs/results/metrics_comparison.png"),
):
    print(_caption)
    display(Image(_png_path))

In [None]:
def analyze_misclassifications(predictions_df, output_path="transformers_logs/results/misclassification_analysis.txt"):
    """Summarise the model's errors on the test set and write a plain-text report.

    Prints, and writes to `output_path`:
    - error counts and percentages;
    - the 10 most confident false positives and false negatives;
    - a tally of the middle residue of each misclassified window.

    Args:
        predictions_df: DataFrame with columns 'header', 'position',
            'sequence', 'target', 'prediction' and 'probability' (the layout
            written by evaluate_and_save_results), or a path to a CSV with
            that layout.
        output_path: Destination of the text report.

    Returns:
        The rows whose 'prediction' differs from 'target'.

    The residue tally uses ``sequence[len(sequence) // 2]``. That is the site
    only when the stored window is centred on it. NOTE(review): confirm for
    windows clipped at a sequence terminus. Missing or empty windows are skipped.
    """
    if isinstance(predictions_df, str):
        predictions_df = pd.read_csv(predictions_df)

    misclassified = predictions_df[predictions_df['target'] != predictions_df['prediction']]
    false_positives = misclassified[misclassified['prediction'] == 1]  # predicted 1, actual 0
    false_negatives = misclassified[misclassified['prediction'] == 0]  # predicted 0, actual 1

    def pct(part, whole):
        """Percentage of `part` in `whole`; 0.0 for an empty denominator (e.g. no errors)."""
        return part / whole * 100 if whole else 0.0

    summary_lines = [
        f"Total misclassifications: {len(misclassified)} ({pct(len(misclassified), len(predictions_df)):.2f}% of test set)",
        f"False positives: {len(false_positives)} ({pct(len(false_positives), len(misclassified)):.2f}% of errors)",
        f"False negatives: {len(false_negatives)} ({pct(len(false_negatives), len(misclassified)):.2f}% of errors)",
    ]
    for line in summary_lines:
        print(line)

    report_cols = ['header', 'position', 'sequence', 'probability']
    most_confident_fp = false_positives.sort_values('probability', ascending=False).head(10)
    most_confident_fn = false_negatives.sort_values('probability').head(10)

    print("\nMost confident false positives:")
    print(most_confident_fp[report_cols])

    print("\nMost confident false negatives:")
    print(most_confident_fn[report_cols])

    def center_residue_counts(windows):
        """Count the middle character of each non-empty string window."""
        centers = [seq[len(seq) // 2] for seq in windows if isinstance(seq, str) and seq]
        return pd.Series(centers, dtype=object).value_counts()

    fp_aa_counts = center_residue_counts(false_positives['sequence'])
    fn_aa_counts = center_residue_counts(false_negatives['sequence'])

    print("\nCentral amino acid distribution in false positives:")
    print(fp_aa_counts)

    print("\nCentral amino acid distribution in false negatives:")
    print(fn_aa_counts)

    with open(output_path, "w") as f:
        f.write("\n".join(summary_lines) + "\n\n")

        f.write("Most confident false positives:\n")
        f.write(most_confident_fp[report_cols].to_string() + "\n\n")

        f.write("Most confident false negatives:\n")
        f.write(most_confident_fn[report_cols].to_string() + "\n\n")

        f.write("Central amino acid distribution in false positives:\n")
        f.write(fp_aa_counts.to_string() + "\n\n")

        f.write("Central amino acid distribution in false negatives:\n")
        f.write(fn_aa_counts.to_string() + "\n")

    return misclassified

# Analyze misclassifications from the per-example CSV written by evaluate_and_save_results
misclassified = analyze_misclassifications("transformers_logs/results/test_predictions.csv")

In [None]:
def compare_with_xgboost():
    """
    Compare transformer model performance with XGBoost baseline
    Note: This assumes you have previously run the XGBoost model and have results
    """
    try:
        # Try to load XGBoost results
        with open("results/metrics_20250414_195426.json", "r") as f:
            xgboost_metrics = json.load(f)
        
        # Load transformer results
        with open("transformers_logs/results/test_metrics.json", "r") as f:
            transformer_metrics = json.load(f)
        
        # Extract metrics for comparison
        metrics_to_compare = ['accuracy', 'precision', 'recall', 'f1', 'auc']
        
        xgb_values = [xgboost_metrics.get(m.capitalize(), xgboost_metrics.get(m.upper(), xgboost_metrics.get(m, 0))) 
                     for m in metrics_to_compare]
        transformer_values = [transformer_metrics[m] for m in metrics_to_compare]
        
        # Create comparison bar chart
        plt.figure(figsize=(12, 8))
        
        x = np.arange(len(metrics_to_compare))
        width = 0.35
        
        plt.bar(x - width/2, xgb_values, width, label='XGBoost')
        plt.bar(x + width/2, transformer_values, width, label='Transformer')
        
        plt.xlabel('Metrics')
        plt.ylabel('Scores')
        plt.title('Performance Comparison: XGBoost vs Transformer')
        plt.xticks(x, metrics_to_compare)
        plt.ylim(0, 1.1)
        plt.legend()
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        
        # Add value labels
        for i, v in enumerate(xgb_values):
            plt.text(i - width/2, v + 0.01, f'{v:.4f}', ha='center', va='bottom')
            
        for i, v in enumerate(transformer_values):
            plt.text(i + width/2, v + 0.01, f'{v:.4f}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.savefig("transformers_logs/results/model_comparison.png", dpi=300)
        plt.close()
        
        print("XGBoost vs Transformer comparison:")
        display(Image("transformers_logs/results/model_comparison.png"))
        
        # Save comparison to text file
        with open("transformers_logs/results/model_comparison.txt", "w") as f:
            f.write("Performance Comparison: XGBoost vs Transformer\n")
            f.write("="*50 + "\n\n")
            
            for i, metric in enumerate(metrics_to_compare):
                f.write(f"{metric.capitalize()}:\n")
                f.write(f"  - XGBoost:     {xgb_values[i]:.4f}\n")
                f.write(f"  - Transformer: {transformer_values[i]:.4f}\n")
                f.write(f"  - Difference:  {transformer_values[i] - xgb_values[i]:.4f}\n\n")
        
        return True
    
    except Exception as e:
        print(f"Could not compare with XGBoost: {e}")
        return False

# Compare with XGBoost if available.
# compare_with_xgboost() already prints a message when its inputs cannot be
# loaded; this guard only reports unexpected errors. Unlike a bare `except:`,
# it lets KeyboardInterrupt/SystemExit through so the kernel stays interruptible.
try:
    compare_with_xgboost()
except Exception as exc:
    print(f"XGBoost comparison failed or not available: {exc}")

In [None]:
def predict_phosphorylation(model, tokenizer, sequence, positions=None, window_size=10, device=None,
                            threshold=0.5, max_length=512):
    """
    Predict phosphorylation sites in a protein sequence.

    For each position, a window of up to `window_size` residues on each side
    is cut from `sequence`. Near the sequence ends the window is simply
    shorter; it is not padded. The window is tokenized and scored by `model`
    one window at a time.

    Args:
        model: Trained model. Called as `model(input_ids, attention_mask)`; its
            output must hold a single logit per window.
        tokenizer: Tokenizer for the model (Hugging Face-style `__call__`).
        sequence: Protein sequence as string.
        positions: Positions to check (1-based indexing). If None, check all
            S/T/Y positions.
        window_size: Number of residues taken on each side of the site.
        device: Device to run predictions on. If None, use the device of the
            model's first parameter (CPU if the model has no parameters).
        threshold: A site is predicted positive when its probability exceeds this.
        max_length: Length every window is padded/truncated to by the tokenizer.

    Returns:
        DataFrame with columns position, amino_acid, window, probability and
        prediction, one row per position. If there is nothing to score, the
        DataFrame is empty but still has these columns.

    Raises:
        ValueError: if any position lies outside 1..len(sequence).
    """
    model.eval()

    if device is None:
        # Run wherever the model already lives, so CPU-only machines work too.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # If positions not specified, find all S/T/Y positions
    if positions is None:
        positions = [i + 1 for i, aa in enumerate(sequence) if aa in ("S", "T", "Y")]
    positions = list(positions)

    # A position <= 0 would otherwise index from the END of the sequence and
    # silently score the wrong residue; reject out-of-range input up front.
    invalid = [p for p in positions if not 1 <= p <= len(sequence)]
    if invalid:
        raise ValueError(f"Positions out of range 1..{len(sequence)}: {invalid}")

    columns = ["position", "amino_acid", "window", "probability", "prediction"]
    results = []
    total = len(positions)

    print(f"Predicting phosphorylation for {total} positions:")
    with torch.no_grad():
        for i, pos in enumerate(positions):
            # Print progress
            if i % 10 == 0 or i == total - 1:
                print(f"\rPosition {i+1}/{total}", end="")

            # Extract window around the (0-based) site, clipped at the sequence ends
            pos_idx = pos - 1
            start = max(0, pos_idx - window_size)
            end = min(len(sequence), pos_idx + window_size + 1)
            window = sequence[start:end]

            # Tokenize
            encoding = tokenizer(
                window,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )

            # Move to device
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            # Predict: the model outputs a logit; sigmoid turns it into a probability
            output = model(input_ids, attention_mask)
            probability = torch.sigmoid(output).item()
            prediction = 1 if probability > threshold else 0

            results.append({
                "position": pos,
                "amino_acid": sequence[pos_idx],
                "window": window,
                "probability": probability,
                "prediction": prediction
            })

    print()  # New line after progress tracking

    return pd.DataFrame(results, columns=columns)

# Example: Predict phosphorylation on a test sequence
# NOTE(review): iloc[0] is just the first row of df_final, which is not
# necessarily a held-out test protein. Replace with a sequence from the test
# split for an honest example.
test_seq = df_final['Sequence'].iloc[0]

# Score each annotated position once, in sequence order. The same position can
# appear in several rows for an identical sequence, which would otherwise
# produce duplicate predictions.
test_positions = sorted({int(pos) for pos in df_final.loc[df_final['Sequence'] == test_seq, 'Position']})

predictions = predict_phosphorylation(
    model=best_model,
    tokenizer=tokenizer,
    sequence=test_seq,
    positions=test_positions,
    window_size=10,
    device=device
)

print("Prediction results for test sequence:")
print(predictions.head(10))

# Save predictions
predictions.to_csv("transformers_logs/results/example_predictions.csv", index=False)

In [None]:
# Assemble a markdown report of the run from the saved test metrics and the
# dataset objects built earlier; write it to disk, then echo it to the output.
metrics_file = "transformers_logs/results/test_metrics.json"
with open(metrics_file, "r") as metrics_fh:
    test_metrics = json.load(metrics_fh)

summary = f"""
# Phosphorylation Site Prediction with Transformers: Summary

## Model Architecture
- Pre-trained protein language model: ESM-2
- Fine-tuned on phosphorylation site data
- Window size: 10 amino acids on each side of the site

## Dataset Statistics
- Training samples: {len(train_dataset)}
- Validation samples: {len(val_dataset)}
- Test samples: {len(test_dataset)}

## Performance Metrics
- Accuracy: {test_metrics['accuracy']:.4f}
- Precision: {test_metrics['precision']:.4f}
- Recall: {test_metrics['recall']:.4f}
- F1 Score: {test_metrics['f1']:.4f}
- ROC AUC: {test_metrics['auc']:.4f}

## Confusion Matrix
{test_metrics['confusion_matrix']}

## Key Findings
- The transformer-based model achieves good performance on phosphorylation site prediction
- The model leverages pre-trained protein language model knowledge
- Misclassification analysis shows patterns that could be addressed in future work

## Advantages Over Feature Engineering Approaches
- No manual feature engineering required
- Captures complex patterns and long-range dependencies in protein sequences
- Can be fine-tuned with relatively small datasets

## Future Improvements
- Try different protein language models (ESM-2 larger variants, ProtTrans, etc.)
- Combine with structural information using Graph Neural Networks
- Experiment with different window sizes and model architectures
- Incorporate kinase-specific information for better predictions
"""

summary_file = "transformers_logs/results/final_summary.md"
with open(summary_file, "w") as summary_fh:
    summary_fh.write(summary)

print(summary)

In [None]:
# Launch TensorBoard
# Loads the TensorBoard notebook extension, then embeds TensorBoard in the cell
# output, reading event files from transformers_logs/tensorboard (the directory
# created in the setup cell for TensorBoard logs).
# Note: re-running this cell once the extension is loaded makes %load_ext print
# an "already loaded" notice; `%reload_ext tensorboard` avoids that.
%load_ext tensorboard
%tensorboard --logdir=transformers_logs/tensorboard