In [2]:
import torch
from torch import nn
import pandas as pd
import numpy as np
import json
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm  # Use tqdm.notebook for Jupyter
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# For inline plots
%matplotlib inline

In [3]:
# Function to find unique labels
def find_unique_labels(filename):
    """Return the set of distinct, whitespace-stripped lines found in `filename`.

    Every line of the file contributes one entry, so a blank line adds the
    empty string to the result.
    """
    with open(filename, 'r') as handle:
        return {raw_line.strip() for raw_line in handle.readlines()}

# Function to load and preprocess data
def load_data(sequences_file, labels_file):
    """Load sequences and their class labels from two parallel text files.

    Line i of `sequences_file` is paired with line i of `labels_file`; both
    are stripped of surrounding whitespace.

    Args:
        sequences_file: Path to a file with one sequence per line.
        labels_file: Path to a file with one label per line.

    Returns:
        (sequences, label_ids, label_to_id, unique_labels) where
        `unique_labels` is the sorted list of distinct labels, `label_to_id`
        maps each label to its index in that list, and `label_ids` holds the
        index of every line's label.

    Raises:
        ValueError: If the two files do not contain the same number of lines.
    """
    with open(sequences_file, 'r') as file:
        sequences = [line.strip() for line in file]

    with open(labels_file, 'r') as file:
        labels = [line.strip() for line in file]

    if len(sequences) != len(labels):
        raise ValueError(
            f"Sequence/label count mismatch: {len(sequences)} sequences vs {len(labels)} labels"
        )

    # Sort the labels so the label -> id mapping is identical on every run.
    # Iterating a bare set of strings depends on hash randomisation, which
    # would make saved checkpoints and label mappings disagree across kernels.
    unique_labels = sorted(set(labels))
    label_to_id = {label: idx for idx, label in enumerate(unique_labels)}

    # Convert label text to numerical values
    label_ids = [label_to_id[label] for label in labels]

    # Guard the average so an empty file reports 0.0 instead of dividing by zero
    avg_length = sum(len(s) for s in sequences) / len(sequences) if sequences else 0.0
    print(f"Loaded {len(sequences)} sequences with average length: {avg_length:.1f}")
    print(f"Found {len(unique_labels)} unique labels: {', '.join(unique_labels)}")

    return sequences, label_ids, label_to_id, unique_labels

In [5]:
# Define SequenceDataset class
class SequenceDataset(Dataset):
    """Dataset that turns raw sequences into overlapping k-mer "sentences".

    Each sequence is split into k-mers, joined with spaces and tokenized.
    Sequences whose token count fits in `max_length - 2` become one sample;
    longer ones are cut into windows of `max_length - 2` tokens taken every
    `stride` tokens, and each kept window becomes its own sample carrying the
    label of the sequence it came from.
    """

    def __init__(self, texts, labels, tokenizer, k=6, max_length=512, stride=256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.k = k            # k-mer size
        self.stride = stride  # step between the starts of consecutive windows

        # Filled by _preprocess_sequences: one entry per sample (chunk)
        self.processed_texts = []
        self.processed_labels = []
        self._preprocess_sequences()

    def _create_kmers(self, sequence):
        """Return every length-k substring of `sequence`, in order."""
        last_start = len(sequence) - self.k
        return [sequence[pos:pos + self.k] for pos in range(last_start + 1)]

    def _preprocess_sequences(self):
        """Populate processed_texts / processed_labels from texts / labels."""
        # Two positions are reserved for the [CLS] and [SEP] tokens
        budget = self.max_length - 2

        for sequence, target in zip(self.texts, self.labels):
            sentence = " ".join(self._create_kmers(sequence))
            pieces = self.tokenizer.tokenize(sentence)

            # Short enough: keep the k-mer sentence as a single sample
            if len(pieces) <= budget:
                self.processed_texts.append(sentence)
                self.processed_labels.append(target)
                continue

            # Too long: slide a window over the tokens and keep every window
            # that is more than half full
            for start in range(0, len(pieces), self.stride):
                window = pieces[start:start + budget]
                if len(window) <= budget // 2:
                    continue
                self.processed_texts.append(self.tokenizer.convert_tokens_to_string(window))
                self.processed_labels.append(target)

        print(f"Processed {len(self.texts)} sequences into {len(self.processed_texts)} chunks")

    def __len__(self):
        return len(self.processed_texts)

    def __getitem__(self, idx):
        """Tokenize sample `idx` and return its tensors plus a 'label' entry."""
        target = self.processed_labels[idx]
        encoded = self.tokenizer(
            self.processed_texts[idx],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer returns a leading batch dimension of 1; drop it
        sample = {}
        for name, tensor in encoded.items():
            sample[name] = tensor.squeeze(0)

        sample['label'] = torch.tensor(target, dtype=torch.long)
        return sample
    
# Create specialized dataset for evaluation with chunk tracking
class ChunkTrackingDataset(SequenceDataset):
    """SequenceDataset that also reports which input sequence a chunk came from.

    After the parent class has chunked the data, `chunk_to_original[i]` holds
    the index (into `texts`) of the sequence that produced sample i, and every
    item gains an 'original_index' tensor so per-chunk outputs can be grouped
    back per sequence.
    """

    def __init__(self, texts, labels, tokenizer, k=6, max_length=512, stride=256):
        # Store original sequence count before chunking
        self.original_count = len(texts)

        # Call parent init to do chunking
        super().__init__(texts, labels, tokenizer, k, max_length, stride)

        # Re-derive the chunk layout with the same rules the parent applies,
        # recording only the owning sequence index for each kept chunk.
        self.chunk_to_original = []
        chunk_length = self.max_length - 2  # -2 for [CLS] and [SEP]
        for idx, seq in enumerate(texts):
            kmer_text = " ".join(self._create_kmers(seq))
            tokens = self.tokenizer.tokenize(kmer_text)

            if len(tokens) <= chunk_length:
                # One chunk for this sequence
                self.chunk_to_original.append(idx)
                continue

            # Multiple chunks: count the windows the parent would have kept
            kept = sum(
                1
                for start in range(0, len(tokens), self.stride)
                if len(tokens[start:start + chunk_length]) > chunk_length // 2
            )
            self.chunk_to_original.extend([idx] * kept)

        # The mapping is only meaningful if it lines up one-to-one with the
        # samples the parent produced (it would not if, for example, `labels`
        # were shorter than `texts`); fail loudly rather than mis-attribute.
        if len(self.chunk_to_original) != len(self.processed_texts):
            raise RuntimeError(
                f"Chunk mapping has {len(self.chunk_to_original)} entries but the dataset "
                f"holds {len(self.processed_texts)} chunks"
            )

    def __getitem__(self, idx):
        """Return the parent's encoding for chunk `idx` plus 'original_index'."""
        encoding = super().__getitem__(idx)

        # Add original sequence index
        encoding['original_index'] = torch.tensor(self.chunk_to_original[idx], dtype=torch.long)

        return encoding

# %%
# Model definition: BERT with linear regression head for classification
class BERTSequenceClassifier(nn.Module):
    """BERT encoder followed by dropout and a single linear classification layer.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        freeze_bert: If True, every BERT parameter has `requires_grad` set to
            False so only the linear head receives gradient updates.
        model_name: Name or path passed to `BertModel.from_pretrained`.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, num_classes, freeze_bert=True, model_name='bert-base-uncased', dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Linear classification head producing one logit per class
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)  # regularization on the pooled output

        # Freeze BERT parameters if specified
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return class logits for a batch of tokenized inputs."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        # Use the model's pooled output (outputs.pooler_output), not the raw
        # first-token hidden state.
        pooled_output = self.dropout(outputs.pooler_output)

        return self.classifier(pooled_output)
        
# Model for chunk-based prediction aggregation
class ChunkAggregationModel(nn.Module):
    """Wrapper that turns per-chunk logits into one prediction per sequence."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Delegate directly to the wrapped model."""
        return self.base_model(input_ids, attention_mask, token_type_ids)

    def predict_with_chunks(self, chunks_dataloader, device, original_sequence_count):
        """
        Predict using multiple chunks per sequence and aggregate results.

        The logits of all chunks that share an 'original_index' are averaged
        and the arg-max of the average is the sequence's prediction. The
        wrapped model is switched to eval mode and left in that mode.

        Args:
            chunks_dataloader: Iterable of batches with 'input_ids',
                'attention_mask' and 'original_index' entries
            device: Device to run inference on
            original_sequence_count: Number of original sequences before chunking

        Returns:
            List of predicted class indices, one per original sequence.
            Sequences for which no chunk was seen fall back to class 0.
        """
        self.base_model.eval()

        logit_batches = []
        index_batches = []

        # First, get logits for all chunks
        with torch.no_grad():
            for batch in chunks_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                logits = self.base_model(input_ids, attention_mask)

                logit_batches.append(logits.cpu())
                # Indices are only used for grouping on the CPU, so they
                # never need to visit the accelerator.
                index_batches.append(batch['original_index'].cpu())

        # An empty loader yields nothing to concatenate; every sequence then
        # takes the same class-0 fallback used for sequences without chunks.
        if not logit_batches:
            return [0] * original_sequence_count

        all_logits = torch.cat(logit_batches, dim=0)
        all_chunk_indices = torch.cat(index_batches, dim=0).long()

        # Ignore chunks whose index does not name one of the requested sequences
        in_range = (all_chunk_indices >= 0) & (all_chunk_indices < original_sequence_count)
        all_logits = all_logits[in_range]
        all_chunk_indices = all_chunk_indices[in_range]

        # Sum logits per sequence in one pass, then divide by the chunk count
        logit_sums = torch.zeros(original_sequence_count, all_logits.shape[1], dtype=all_logits.dtype)
        logit_sums.index_add_(0, all_chunk_indices, all_logits)
        chunk_counts = torch.bincount(all_chunk_indices, minlength=original_sequence_count)

        # clamp(min=1) keeps chunk-less rows at all-zero logits (arg-max 0)
        avg_logits = logit_sums / chunk_counts.clamp(min=1).unsqueeze(1)

        return avg_logits.argmax(dim=1).tolist()

In [7]:
# %%
# Function to evaluate the model
def evaluate_model(true_labels, pred_labels, label_names, original_sequences=None, predicted_sequences=None):
    """Print a classification report and show a confusion-matrix heatmap.

    Args:
        true_labels: Integer class IDs of the ground truth.
        pred_labels: Integer class IDs predicted by the model.
        label_names: Class names; position i names class ID i.
        original_sequences: Optional reference sequences.
        predicted_sequences: Optional sequences to compare against the
            references. BLEU/ROUGE metrics are computed only when both
            sequence lists are given.
    """
    # Name every class ID explicitly. Without `labels`, the report infers the
    # classes from the data and rejects `target_names` when some class is
    # absent from both the true and the predicted labels.
    class_ids = list(range(len(label_names)))
    report = classification_report(
        true_labels, pred_labels, labels=class_ids, target_names=label_names, digits=4
    )
    print("Classification Report:")
    print(report)

    # Confusion matrix: rows are true classes, columns are predicted classes
    cm = np.zeros((len(label_names), len(label_names)), dtype=int)
    for t, p in zip(true_labels, pred_labels):
        cm[t][p] += 1

    # Plot confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_names, yticklabels=label_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()  # Display directly in notebook

    # Calculate BLEU and ROUGE if sequences are provided
    if original_sequences is not None and predicted_sequences is not None:
        calculate_sequence_metrics(original_sequences, predicted_sequences)

# Function to calculate BLEU and ROUGE scores for sequences
def calculate_sequence_metrics(original_sequences, predicted_sequences):
    """Print average BLEU/ROUGE scores for sequence pairs and plot their histograms.

    BLEU is computed over the characters of each sequence; ROUGE F-measures
    come from the ROUGE scorer applied to the raw strings. A failure on one
    pair is reported and that pair is left out of the affected metric.
    """
    smoothing = SmoothingFunction().method1
    metric_names = ['rouge1', 'rouge2', 'rougeL']
    rouge = rouge_scorer.RougeScorer(metric_names, use_stemmer=True)

    bleu_values = []
    rouge_values = {name: [] for name in metric_names}

    for reference, candidate in zip(original_sequences, predicted_sequences):
        # BLEU on character-level tokens
        try:
            bleu_values.append(
                sentence_bleu([list(reference)], list(candidate), smoothing_function=smoothing)
            )
        except Exception as e:
            print(f"Error calculating BLEU: {e}")

        # ROUGE F-measure for each configured variant
        try:
            for name, result in rouge.score(reference, candidate).items():
                rouge_values[name].append(result.fmeasure)
        except Exception as e:
            print(f"Error calculating ROUGE: {e}")

    # ROUGE output is produced only when every variant has at least one score
    have_rouge = all(values for values in rouge_values.values())

    # Averages
    if bleu_values:
        print(f"\nAverage BLEU score: {np.mean(bleu_values):.4f}")

    if have_rouge:
        print("\nAverage ROUGE scores:")
        for name, values in rouge_values.items():
            print(f"  {name}: {np.mean(values):.4f}")

    # BLEU histogram
    if bleu_values:
        plt.figure(figsize=(10, 6))
        plt.hist(bleu_values, bins=20, alpha=0.7)
        plt.xlabel('BLEU Score')
        plt.ylabel('Number of Sequences')
        plt.title('Distribution of BLEU Scores')
        plt.grid(alpha=0.3)
        plt.show()

    # One ROUGE histogram per variant, side by side
    if have_rouge:
        plt.figure(figsize=(12, 6))

        for position, (name, values) in enumerate(rouge_values.items(), start=1):
            plt.subplot(1, 3, position)
            plt.hist(values, bins=20, alpha=0.7)
            plt.xlabel(f'{name} Score')
            plt.ylabel('Number of Sequences')
            plt.title(f'Distribution of {name} Scores')
            plt.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

# Function to get actual sequences for predictions
def get_sequences_for_predictions(val_sequences, true_labels, pred_labels, id_to_label):
    """
    Build the (reference, predicted) sequence lists used for BLEU/ROUGE.

    NOTE: this is a placeholder. There is no sequence-generation model here,
    so the "predicted" list is simply a copy of the validation sequences and
    any similarity metric computed between the two lists compares each
    sequence with itself.

    Args:
        val_sequences: Original validation sequences
        true_labels: Numerical IDs for true labels (currently unused)
        pred_labels: Numerical IDs for predicted labels (currently unused)
        id_to_label: Mapping from ID to label name (currently unused)

    Returns:
        Tuple (true_sequences, pred_sequences): two independent lists, each
        containing the validation sequences in their original order.
    """
    # Two separate copies so callers can mutate one list without touching
    # the other or the caller's input.
    true_sequences = list(val_sequences)
    pred_sequences = list(val_sequences)

    return true_sequences, pred_sequences

# Function to visualize training progress
def plot_training_progress(train_losses, val_losses, val_accuracies):
    """Show per-epoch loss curves (left panel) and validation accuracy (right panel)."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training vs validation loss
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    # Right panel: validation accuracy
    acc_ax.plot(val_accuracies, label='Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Validation Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()  # Display directly in notebook

# %%
# Set paths
sequences_file = "extracted-spike-data/Spike7k_sequences.txt"
labels_file = "extracted-spike-data/Spike7k_labels.txt"

# Hyperparameters
k = 6                    # k-mer size
max_length = 512         # BERT's max token length
stride = 256             # Stride for chunking overlaps
batch_size = 16          # Batch size for training
num_epochs = 5           # Number of training epochs
learning_rate = 2e-5     # Learning rate for optimizer
freeze_bert = True       # Whether to freeze BERT parameters

# Set device - with MPS support for Apple Silicon.
# `torch.backends.mps` is missing on older PyTorch builds, so look it up
# defensively instead of letting the cell fail with AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon) acceleration")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA acceleration")
else:
    device = torch.device("cpu")
    print("Using CPU (no GPU acceleration available)")

Using MPS (Apple Silicon) acceleration


In [8]:
# %%
# Read the raw sequences and encode their labels as integer IDs
sequences, label_ids, label_to_id, unique_labels = load_data(sequences_file, labels_file)

# Reverse lookup: integer ID -> label text
id_to_label = dict((idx, label) for label, idx in label_to_id.items())

# Hold out 20% for validation, keeping class proportions (fixed seed)
train_sequences, val_sequences, train_labels, val_labels = train_test_split(
    sequences,
    label_ids,
    test_size=0.2,
    random_state=25,
    stratify=label_ids,
)

# Tokenizer matching the pretrained BERT checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Training data: plain k-mer chunks
train_dataset = SequenceDataset(
    train_sequences, train_labels, tokenizer,
    k=k, max_length=max_length, stride=stride,
)

# Validation data: chunks that remember their source sequence, so per-chunk
# outputs can be aggregated back to one prediction per sequence
val_dataset = ChunkTrackingDataset(
    val_sequences, val_labels, tokenizer,
    k=k, max_length=max_length, stride=stride,
)

# Batch iterators (only the training set is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

FileNotFoundError: [Errno 2] No such file or directory: 'extracted-spike-data/Spike7k_sequences.txt'

In [None]:
# %%
# Build the classifier (BERT weights frozen when freeze_bert is set) and
# place it on the selected device
base_model = BERTSequenceClassifier(num_classes=len(unique_labels), freeze_bert=freeze_bert)
base_model.to(device)

# Hand the optimizer only the parameters that will actually receive gradients
trainable_params = [param for param in base_model.parameters() if param.requires_grad]
optimizer = AdamW(trainable_params, lr=learning_rate)

# Multi-class classification loss over raw logits
criterion = nn.CrossEntropyLoss()

# Per-epoch history, appended to by the training loop
train_losses, val_losses, val_accuracies = [], [], []

In [None]:
# %%
# Create checkpoint directory
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize checkpoint variables
best_accuracy = 0.0
start_epoch = 0


def _build_checkpoint(last_completed_epoch):
    """Snapshot model, optimizer, histories and config at the current moment.

    `last_completed_epoch` is stored under 'epoch'; resuming continues from
    the epoch after it. Values are read when the function is called, so call
    it only after `best_accuracy` and the histories have been updated.
    """
    return {
        'epoch': last_completed_epoch,
        'model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'best_accuracy': best_accuracy,
        'config': {
            'k': k,
            'max_length': max_length,
            'stride': stride,
            'freeze_bert': freeze_bert,
            'batch_size': batch_size,
            'learning_rate': learning_rate
        }
    }


# Check if checkpoint exists to resume training
checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    base_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_accuracy = checkpoint['best_accuracy']
    if 'train_losses' in checkpoint:
        train_losses = checkpoint['train_losses']
        val_losses = checkpoint['val_losses']
        val_accuracies = checkpoint['val_accuracies']
    print(f"Resuming from epoch {start_epoch} with best accuracy: {best_accuracy:.4f}")

# Aggregation wrapper shares base_model's weights, so one instance serves all epochs
chunk_model = ChunkAggregationModel(base_model)

val_true = val_labels  # Original (per-sequence) labels
val_preds = None       # Filled by the first completed validation pass

# Index of the last epoch whose training AND validation finished; this is
# what interrupted/emergency checkpoints record so that resuming repeats the
# unfinished epoch instead of skipping it.
last_completed_epoch = start_epoch - 1

# Training loop
for epoch in range(start_epoch, num_epochs):
    try:
        # Training phase
        base_model.train()
        train_loss = 0

        for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}"):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)

            # Calculate loss
            loss = criterion(outputs, labels)
            train_loss += loss.item()

            # Backward pass
            loss.backward()
            optimizer.step()

        avg_train_loss = train_loss / len(train_dataloader)

        # Validation phase - per-sequence accuracy via chunk aggregation
        base_model.eval()
        val_preds = chunk_model.predict_with_chunks(val_dataloader, device, len(val_sequences))
        accuracy = accuracy_score(val_true, val_preds)

        # Validation loss is computed per chunk, batch by batch
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Validation Epoch {epoch+1}"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_dataloader)

        # Record all three histories together, only once the whole epoch has
        # succeeded, so they always have the same length.
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_accuracies.append(accuracy)
        last_completed_epoch = epoch

        print(f"Epoch {epoch+1}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")
        print(f"  Val Accuracy: {accuracy:.4f}")

        # Update the running best BEFORE building the checkpoint, otherwise
        # every saved file carries the previous epoch's best accuracy.
        is_best = accuracy > best_accuracy
        if is_best:
            best_accuracy = accuracy

        checkpoint = _build_checkpoint(epoch)

        # Save the latest checkpoint (overwrite previous)
        torch.save(checkpoint, os.path.join(checkpoint_dir, 'latest_checkpoint.pt'))

        # Save epoch-specific checkpoint
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))

        # Save best model
        if is_best:
            torch.save(checkpoint, os.path.join(checkpoint_dir, 'best_model.pt'))
            print(f"  New best model saved with accuracy: {accuracy:.4f}")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving checkpoint...")
        # Save interrupt checkpoint
        checkpoint = _build_checkpoint(last_completed_epoch)
        torch.save(checkpoint, os.path.join(checkpoint_dir, 'interrupt_checkpoint.pt'))
        print(f"Interrupt checkpoint saved. Resume with checkpoint_path='interrupt_checkpoint.pt'")
        break

    except Exception as e:
        print(f"\nError during training: {str(e)}")
        # Save emergency checkpoint
        checkpoint = _build_checkpoint(last_completed_epoch)
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'emergency_checkpoint_epoch_{epoch+1}.pt'))
        print(f"Emergency checkpoint saved to 'emergency_checkpoint_epoch_{epoch+1}.pt'")
        raise  # Re-raise the exception for debugging

# If the resumed checkpoint already covered every epoch, the loop body never
# ran; compute predictions once so later evaluation cells still have val_preds.
if val_preds is None and start_epoch >= num_epochs:
    print("No epochs left to train; computing validation predictions from the loaded model")
    val_preds = chunk_model.predict_with_chunks(val_dataloader, device, len(val_sequences))

In [None]:
# %%
# Pair up reference and "predicted" sequences for the sequence-level metrics
true_sequences, pred_sequences = get_sequences_for_predictions(
    val_sequences,
    val_true,
    val_preds,
    id_to_label,
)

# Classification report, confusion matrix, then BLEU/ROUGE
evaluate_model(
    val_true,
    val_preds,
    unique_labels,
    true_sequences,
    pred_sequences,
)

# Loss and accuracy curves
plot_training_progress(train_losses, val_losses, val_accuracies)

# Persist the model weights as they are at this point
torch.save(base_model.state_dict(), 'bert_sequence_classifier.pt')

# Persist the settings needed to rebuild the data pipeline and the model
config = dict(
    k=k,
    max_length=max_length,
    stride=stride,
    num_classes=len(unique_labels),
    freeze_bert=freeze_bert,
)

with open('model_config.json', 'w') as f:
    json.dump(config, f)

# Persist the label mapping, one "<label>\t<id>" line per class
with open('label_mapping.txt', 'w') as f:
    f.writelines(f"{label}\t{idx}\n" for label, idx in label_to_id.items())

print("Training completed and model saved!")

In [None]:
# %%
# Generate detailed BLEU and ROUGE evaluation report
def generate_evaluation_report(true_sequences, pred_sequences, true_labels, pred_labels, id_to_label):
    """
    Print overall and per-class accuracy, BLEU and ROUGE statistics.

    Sequences are grouped by their TRUE label. BLEU is computed on
    character-level tokens; ROUGE values are F-measures. A metric that fails
    for one sequence is reported and skipped for that sequence only.

    Args:
        true_sequences: Reference sequences.
        pred_sequences: Sequences compared against the references.
        true_labels: Numerical IDs of the true labels.
        pred_labels: Numerical IDs of the predicted labels.
        id_to_label: Mapping from numerical ID to label name.
    """
    def _mean(values):
        # An empty list would make np.mean emit a warning; report NaN quietly
        return float(np.mean(values)) if len(values) > 0 else float('nan')

    # Get label names
    true_label_names = [id_to_label[label_id] for label_id in true_labels]
    pred_label_names = [id_to_label[label_id] for label_id in pred_labels]

    # Both scorers are stateless across calls, so build them once instead of
    # once per sequence.
    smooth = SmoothingFunction().method1
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    # Group sequences by class
    classes = {}
    for i, (true_seq, pred_seq, true_label, pred_label) in enumerate(zip(
            true_sequences, pred_sequences, true_label_names, pred_label_names)):

        # Initialize class if not seen before
        if true_label not in classes:
            classes[true_label] = {
                'count': 0,
                'correct': 0,
                'bleu_scores': [],
                'rouge1_scores': [],
                'rouge2_scores': [],
                'rougeL_scores': []
            }
        stats = classes[true_label]

        # Update class stats
        stats['count'] += 1
        if true_label == pred_label:
            stats['correct'] += 1

        # Calculate BLEU
        try:
            bleu = sentence_bleu([list(true_seq)], list(pred_seq), smoothing_function=smooth)
            stats['bleu_scores'].append(bleu)
        except Exception as e:
            print(f"Error calculating BLEU for sequence {i}: {e}")

        # Calculate ROUGE
        try:
            scores = scorer.score(true_seq, pred_seq)
            stats['rouge1_scores'].append(scores['rouge1'].fmeasure)
            stats['rouge2_scores'].append(scores['rouge2'].fmeasure)
            stats['rougeL_scores'].append(scores['rougeL'].fmeasure)
        except Exception as e:
            print(f"Error calculating ROUGE for sequence {i}: {e}")

    # Print report
    print("\n===== Detailed Evaluation Report =====")
    print(f"Total sequences: {len(true_sequences)}")

    # Overall metrics: pool the per-class score lists
    all_bleu = [score for class_data in classes.values() for score in class_data['bleu_scores']]
    all_rouge1 = [score for class_data in classes.values() for score in class_data['rouge1_scores']]
    all_rouge2 = [score for class_data in classes.values() for score in class_data['rouge2_scores']]
    all_rougeL = [score for class_data in classes.values() for score in class_data['rougeL_scores']]

    print("\nOverall Metrics:")
    print(f"  Average BLEU: {_mean(all_bleu):.4f}")
    print(f"  Average ROUGE-1: {_mean(all_rouge1):.4f}")
    print(f"  Average ROUGE-2: {_mean(all_rouge2):.4f}")
    print(f"  Average ROUGE-L: {_mean(all_rougeL):.4f}")

    # Per-class metrics
    print("\nPer-Class Metrics:")
    for class_name, data in classes.items():
        accuracy = data['correct'] / data['count'] if data['count'] > 0 else 0
        print(f"\n  Class: {class_name}")
        print(f"    Count: {data['count']}")
        print(f"    Accuracy: {accuracy:.4f}")
        print(f"    Average BLEU: {_mean(data['bleu_scores']):.4f}")
        print(f"    Average ROUGE-1: {_mean(data['rouge1_scores']):.4f}")
        print(f"    Average ROUGE-2: {_mean(data['rouge2_scores']):.4f}")
        print(f"    Average ROUGE-L: {_mean(data['rougeL_scores']):.4f}")

# Print the per-class report for the validation set
generate_evaluation_report(
    true_sequences=true_sequences,
    pred_sequences=pred_sequences,
    true_labels=val_true,
    pred_labels=val_preds,
    id_to_label=id_to_label,
)

# %%
# Function to evaluate sequence prediction quality with visualizations
def _print_prediction_examples(indices, true_sequences, pred_sequences,
                               true_label_names, pred_label_names, scorer, smooth, indent=""):
    """Print labels, BLEU/ROUGE scores and a 50-character preview for each index."""
    for idx in indices:
        true_seq = true_sequences[idx]
        pred_seq = pred_sequences[idx]

        # BLEU on character-level tokens
        bleu = sentence_bleu([list(true_seq)], list(pred_seq), smoothing_function=smooth)

        # ROUGE scores
        rouge_scores = scorer.score(true_seq, pred_seq)

        print(f"\n{indent}Example {idx}:")
        print(f"{indent}  True Label: {true_label_names[idx]}")
        print(f"{indent}  Predicted Label: {pred_label_names[idx]}")
        print(f"{indent}  BLEU Score: {bleu:.4f}")
        print(f"{indent}  ROUGE-1 F1: {rouge_scores['rouge1'].fmeasure:.4f}")
        print(f"{indent}  ROUGE-L F1: {rouge_scores['rougeL'].fmeasure:.4f}")

        # Sequence preview (first 50 chars, with an ellipsis when truncated)
        print(f"{indent}  True Sequence: {true_seq[:50]}..." if len(true_seq) > 50 else f"{indent}  True Sequence: {true_seq}")
        print(f"{indent}  Pred Sequence: {pred_seq[:50]}..." if len(pred_seq) > 50 else f"{indent}  Pred Sequence: {pred_seq}")


def visualize_sequence_predictions(true_sequences, pred_sequences, true_labels, pred_labels, id_to_label, num_examples=5):
    """
    Print a random sample of correctly and incorrectly classified sequences.

    For each sampled sequence the true and predicted label names, the
    character-level BLEU score, ROUGE-1/ROUGE-L F-measures and a short preview
    of both sequences are printed. Sampling uses NumPy's global random state.

    Args:
        true_sequences: Reference sequences.
        pred_sequences: Sequences compared against the references.
        true_labels: Numerical IDs of the true labels.
        pred_labels: Numerical IDs of the predicted labels.
        id_to_label: Mapping from numerical ID to label name.
        num_examples: Maximum number of examples shown per group.
    """
    # Convert numerical labels to names
    true_label_names = [id_to_label[label_id] for label_id in true_labels]
    pred_label_names = [id_to_label[label_id] for label_id in pred_labels]

    # Split positions into correctly and incorrectly classified
    correct_indices = [i for i, (t, p) in enumerate(zip(true_labels, pred_labels)) if t == p]
    incorrect_indices = [i for i, (t, p) in enumerate(zip(true_labels, pred_labels)) if t != p]

    # Shared scorers for every example
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    smooth = SmoothingFunction().method1

    print("\n===== Example Predictions =====")

    # Both groups are handled the same way: sample without replacement, or
    # say explicitly that the group is empty.
    groups = (
        ("\nCorrectly Classified Examples:", "\nNo correctly classified examples found!", correct_indices),
        ("\nIncorrectly Classified Examples:", "\nNo incorrectly classified examples found!", incorrect_indices),
    )
    for header, empty_message, indices in groups:
        if not indices:
            print(empty_message)
            continue

        print(header)
        sampled = np.random.choice(indices, min(num_examples, len(indices)), replace=False)
        _print_prediction_examples(
            sampled, true_sequences, pred_sequences,
            true_label_names, pred_label_names, scorer, smooth,
        )

In [None]:
# %%
# Function to plot BLEU and ROUGE score distributions by class
def _draw_class_boxplot(ax, values_by_class, class_names, title, ylabel):
    """Draw one box per class on `ax`, with rotated class names on the x-axis."""
    ax.set_title(title)
    ax.boxplot(values_by_class)
    # Label the ticks explicitly rather than through boxplot's `labels=` keyword,
    # which Matplotlib 3.9 deprecated in favour of `tick_labels=`. Boxes are drawn
    # at positions 1..N by default, so the ticks are placed there.
    ax.set_xticks(list(range(1, len(class_names) + 1)))
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_ylabel(ylabel)
    ax.grid(axis='y', alpha=0.3)


def plot_metric_distributions_by_class(true_sequences, pred_sequences, true_labels, id_to_label):
    """
    Plot the distribution of BLEU and ROUGE scores for each class

    Every (true, predicted) sequence pair is scored with smoothed BLEU over
    `list(seq)` (i.e. element by element) and with ROUGE-1/2/L F1, and the
    scores are grouped by the name of the pair's true label. One figure with
    BLEU box plots is shown, followed by one figure with three ROUGE panels.

    Args:
        true_sequences: Reference sequences (presumably strings - the ROUGE
            scorer is called on them directly; confirm against the caller).
        pred_sequences: Predicted sequences, paired positionally with
            `true_sequences`.
        true_labels: Label ids, paired positionally with the sequences; each
            must be a key of `id_to_label`.
        id_to_label: Mapping from label id to label name.

    Returns:
        None. The figures are displayed with `plt.show()`.

    Notes:
        Pairs are combined with `zip`, so extra items in the longer input are
        ignored. A pair whose metric cannot be computed is left out of that
        metric's distribution and counted in a warning printed before plotting.
    """
    # Convert numerical labels to names
    label_names = [id_to_label[label_id] for label_id in true_labels]
    if not label_names:
        print("No labelled sequences provided; nothing to plot.")
        return

    # Sort so the class order on the x-axis is the same on every run
    # (iteration order of a set of strings changes between interpreter sessions).
    unique_labels = sorted(set(label_names), key=str)

    rouge_keys = ('rouge1', 'rouge2', 'rougeL')
    class_metrics = {
        label: {key: [] for key in ('bleu',) + rouge_keys}
        for label in unique_labels
    }

    # Calculate metrics for each sequence
    smooth = SmoothingFunction().method1
    scorer = rouge_scorer.RougeScorer(list(rouge_keys), use_stemmer=True)

    skipped = {'BLEU': 0, 'ROUGE': 0}
    for true_seq, pred_seq, label in zip(true_sequences, pred_sequences, label_names):
        # BLEU score
        try:
            bleu = sentence_bleu([list(true_seq)], list(pred_seq), smoothing_function=smooth)
        except Exception:
            # Best effort: keep going, but remember that a pair was dropped
            skipped['BLEU'] += 1
        else:
            class_metrics[label]['bleu'].append(bleu)

        # ROUGE scores - read all three values before appending any of them so
        # the three per-class lists always stay the same length
        try:
            scores = scorer.score(true_seq, pred_seq)
            f1_by_key = {key: scores[key].fmeasure for key in rouge_keys}
        except Exception:
            skipped['ROUGE'] += 1
        else:
            for key, f1_value in f1_by_key.items():
                class_metrics[label][key].append(f1_value)

    for metric_name, count in skipped.items():
        if count:
            print(f"Warning: {metric_name} could not be computed for {count} sequence pair(s); they are omitted from the plots.")

    # Plot BLEU score distributions
    bleu_fig, bleu_ax = plt.subplots(figsize=(12, 6))
    _draw_class_boxplot(
        bleu_ax,
        [class_metrics[label]['bleu'] for label in unique_labels],
        unique_labels,
        'BLEU Score Distribution by Class',
        'BLEU Score',
    )
    bleu_fig.tight_layout()
    plt.show()

    # Plot ROUGE score distributions, one panel per ROUGE variant
    rouge_fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panel_names = ('ROUGE-1', 'ROUGE-2', 'ROUGE-L')
    for ax, key, display_name in zip(axes, rouge_keys, panel_names):
        _draw_class_boxplot(
            ax,
            [class_metrics[label][key] for label in unique_labels],
            unique_labels,
            f'{display_name} F1 Distribution by Class',
            f'{display_name} F1 Score',
        )

    rouge_fig.tight_layout()
    plt.show()

# Show the per-example report first, then the per-class score box plots
visualize_sequence_predictions(
    true_sequences,
    pred_sequences,
    val_true,
    val_preds,
    id_to_label,
)
plot_metric_distributions_by_class(
    true_sequences=true_sequences,
    pred_sequences=pred_sequences,
    true_labels=val_true,
    id_to_label=id_to_label,
)

In [None]:
# %%
# Function to test model on specific sequences
def test_on_specific_sequences(model, sequences, tokenizer, device, k=6, max_length=512):
    """
    Test the model on specific sequences and return predictions
    """
    model.eval()
    predictions = []
    
    # Process each sequence
    for seq in sequences:
        # Create k-mers
        kmers = [seq[i:i+k] for i in range(len(seq) - k + 1)]
        kmer_text = " ".join(kmers)
        
        # Tokenize
        encoding = tokenizer(
            kmer_text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Move to device
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        
        # Get prediction
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            pred = outputs.argmax(dim=1).item()
            predictions.append(pred)
    
    return predictions

In [None]:
# %%
# Save complete model evaluation results
def _to_native_scalar(value):
    """Return `value.item()` when the object offers it (e.g. NumPy scalars), else `value` unchanged."""
    return value.item() if hasattr(value, "item") else value


def save_evaluation_results(val_true, val_preds, true_sequences, pred_sequences, label_names, id_to_label,
                            output_path='evaluation_results.json'):
    """
    Save detailed evaluation results to file

    Writes a JSON document with the overall accuracy, the scikit-learn
    classification report (as a dict) and one record per sequence pair holding
    the true/predicted label ids and names, a correctness flag, smoothed BLEU,
    ROUGE-1/2/L F1 and the length of the true sequence.

    Args:
        val_true: True label ids; each must be a key of `id_to_label`.
        val_preds: Predicted label ids, paired positionally with `val_true`.
        true_sequences: Reference sequences (presumably strings - confirm
            against the caller).
        pred_sequences: Predicted sequences, paired positionally.
        label_names: Class names passed to `classification_report` as
            `target_names`.
        id_to_label: Mapping from label id to label name.
        output_path: Destination of the JSON file. Defaults to
            'evaluation_results.json' in the current working directory.

    Returns:
        None. The results are written to `output_path` and a confirmation
        line is printed.

    Notes:
        The four per-sequence inputs are combined with `zip`, so extra items
        in the longer inputs are ignored. A metric that cannot be computed for
        a pair is stored as `null` instead of aborting the export.
    """
    results = {
        "accuracy": float(accuracy_score(val_true, val_preds)),
        "classification_report": classification_report(val_true, val_preds, target_names=label_names, output_dict=True),
        "sequence_metrics": []
    }

    # Calculate per-sequence metrics
    smooth = SmoothingFunction().method1
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    paired = zip(true_sequences, pred_sequences, val_true, val_preds)
    for i, (true_seq, pred_seq, true_label, pred_label) in enumerate(paired):

        # Get label names
        true_label_name = id_to_label[true_label]
        pred_label_name = id_to_label[pred_label]

        # Label ids collected from array outputs are often NumPy scalars, which
        # the json module refuses to serialise (and `==` on them yields
        # numpy.bool_, equally unserialisable). Store native Python values.
        true_id = _to_native_scalar(true_label)
        pred_id = _to_native_scalar(pred_label)

        # Calculate BLEU
        try:
            bleu = sentence_bleu([list(true_seq)], list(pred_seq), smoothing_function=smooth)
        except Exception:
            bleu = None

        # Calculate ROUGE
        try:
            rouge_scores = scorer.score(true_seq, pred_seq)
            rouge1 = rouge_scores['rouge1'].fmeasure
            rouge2 = rouge_scores['rouge2'].fmeasure
            rougeL = rouge_scores['rougeL'].fmeasure
        except Exception:
            rouge1, rouge2, rougeL = None, None, None

        # Store metrics
        results["sequence_metrics"].append({
            "index": i,
            "true_label": true_id,
            "true_label_name": true_label_name,
            "pred_label": pred_id,
            "pred_label_name": pred_label_name,
            "correct": bool(true_id == pred_id),
            "bleu": bleu,
            "rouge1": rouge1,
            "rouge2": rouge2,
            "rougeL": rougeL,
            "sequence_length": len(true_seq)
        })

    # Save to file
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"Evaluation results saved to '{output_path}'")

# Write the validation metrics to disk, then announce completion
save_evaluation_results(
    val_true=val_true,
    val_preds=val_preds,
    true_sequences=true_sequences,
    pred_sequences=pred_sequences,
    label_names=unique_labels,
    id_to_label=id_to_label,
)

print("Complete evaluation with BLEU and ROUGE metrics finished!")