In [None]:
import tqdm as notebook_tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split

# Hugging Face model id of the pretrained protein language model used as encoder
MODEL_NAME = "Rostlab/prot_bert"

# Pick the fastest available backend: CUDA (e.g. on Colab), then Apple MPS (local Mac),
# and only fall back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

NUM_CLASSES = 6  # num classes for classification
BATCH_SIZE = 16
EPOCHS = 5
LR = 0.001

In [None]:
from google.colab import drive
import os
import pandas as pd

# Make the Drive folders (data + checkpoints) reachable from the Colab runtime
drive.mount('/content/drive')

# Number of cross-validation folds
NUM_FOLDS = 5

# Where the CSV data lives and where the per-fold checkpoints are written
# ("{}" in MODEL_SAVE_PATH is filled with the fold number via str.format)
DATA_PATH = "/content/drive/MyDrive/PBLRost/data/reduced_30/"
DATA_PATH_FOLDS = "/content/drive/MyDrive/PBLRost/data/reduced_30/5-fold/"
MODEL_SAVE_PATH = "/content/drive/MyDrive/PBLRost/models/6state_protbert_lstm_cnn_fold{}.pt"
TEST_CSV = os.path.join(DATA_PATH, "reduced_30_signalP6_test.csv")

# The held-out test set is shared by every fold, so it is read only once
print("Loading test data...")
test_df = pd.read_csv(TEST_CSV)
print(f"Test records: {test_df.shape[0]}")

# Show the first rows as a sanity check
test_df.head()


In [None]:
# Keep only the test records whose label string contains no "P";
# rows with a missing label string are kept (na=False makes them count as "no match").
print("\nProcessing test data...")
has_p_label = test_df["labels"].str.contains("P", na=False)
test_df_filtered = test_df[~has_p_label]
print(f"Test records after filtering: {len(test_df_filtered)}")

test_df_filtered.describe()


In [None]:
# Integer class id for each per-residue label character
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}


def _encode_label_string(label_string):
    """Translate a label string into a list of class ids, skipping characters not in label_map."""
    return [label_map[c] for c in label_string if c in label_map]


# Encode the test labels on a copy so the filtered frame stays untouched
test_df_encoded = test_df_filtered.copy()
test_df_encoded["label"] = test_df_encoded["labels"].apply(_encode_label_string)

# Discard records that ended up with no encoded label at all
has_any_label = test_df_encoded["label"].map(len) > 0
test_df_encoded = test_df_encoded[has_any_label]

test_label_seqs = test_df_encoded["label"].tolist()
test_seqs = test_df_encoded["sequence"].tolist()

print(f"Test sequences: {len(test_seqs)}")
test_df_encoded.describe()


In [None]:
# Pretrained tokenizer (case is preserved) and encoder weights; the encoder is moved to DEVICE
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
encoder = BertModel.from_pretrained(MODEL_NAME).to(DEVICE)
print("ProtBERT model loaded successfully")

# SPDataset class definition
class SPDataset(Dataset):
    """Dataset yielding tokenized protein sequences with token-aligned label tensors.

    Args:
        sequences: list of amino-acid strings.
        label_seqs: list of per-residue class-id lists, parallel to ``sequences``.
        label_map: mapping from label character to class id (stored, not used here).

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``, all of
    length 512. Relies on the module-level ``tokenizer``.
    """

    def __init__(self, sequences, label_seqs, label_map):
        self.label_map = label_map
        self.label_seqs = label_seqs
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sequence ``idx`` and build its label tensor.

        Positions that carry no residue label (first token, trailing special token,
        padding, residues without a label) are set to -100.
        """
        seq = self.sequences[idx]
        labels = self.label_seqs[idx]
        # preprocess the sequence (insert spaces between amino acids)
        seq_processed = " ".join(list(seq))
        # Tokenize the sequence (padded / truncated so every item has 512 tokens)
        encoded = tokenizer(seq_processed, return_tensors="pt",
                            padding="max_length", truncation=True, max_length=512)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Number of attended tokens that are residues: everything the tokenizer kept
        # minus the leading and trailing special token.
        # NOTE(review): assumes the tokenizer wraps the input in exactly two special
        # tokens ([CLS] ... [SEP]), as BertTokenizer does by default.
        n_residue_tokens = max(int(attention_mask.sum().item()) - 2, 0)
        # Never label more positions than there are residues, labels, or kept tokens.
        # This keeps the trailing special token unlabeled for truncated sequences and
        # avoids an IndexError when the label list is shorter than the sequence.
        n_labelled = min(len(seq), len(labels), n_residue_tokens)

        # Start from "ignore everywhere", then fill in the residue labels after position 0.
        labels_tensor = torch.full((input_ids.size(0),), -100, dtype=torch.long)
        if n_labelled > 0:
            labels_tensor[1:1 + n_labelled] = torch.tensor(labels[:n_labelled], dtype=torch.long)

        return {
            'input_ids': input_ids,  # tokenized and padded
            'attention_mask': attention_mask,  # differentiate between padding and non-padding tokens
            'labels': labels_tensor  # aligned label tensor
        }

# The test split is identical for every fold, so its dataset and loader are built once.
# Order is kept fixed (no shuffling).
test_dataset = SPDataset(sequences=test_seqs, label_seqs=test_label_seqs, label_map=label_map)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"\nTest data prepared: {len(test_seqs)} sequences")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF

class SPCNNClassifier(nn.Module):
    """Per-token tagger: encoder -> dilated Conv1d -> BatchNorm + ReLU -> BiLSTM -> linear -> dropout -> CRF.

    Args:
        encoder_model: pretrained encoder exposing ``config.hidden_size`` and returning
            an output with ``last_hidden_state``.
        num_labels: number of tag classes scored by the linear layer and the CRF.
    """

    def __init__(self, encoder_model, num_labels):
        super().__init__()
        self.encoder = encoder_model
        self.dropout = nn.Dropout(0.2)
        embed_dim = self.encoder.config.hidden_size
        # Dilated convolution over the token axis; with kernel 8, dilation 2 and
        # padding 7 the output has the same length as the input.
        self.conv = nn.Conv1d(embed_dim, 1024, kernel_size=8, dilation=2, padding=7)
        # Batch normalisation over the 1024 convolution channels
        self.bn_conv = nn.BatchNorm1d(1024)
        # Two stacked bidirectional LSTM layers (512 units per direction)
        self.lstm = nn.LSTM(1024, 512, num_layers=2, batch_first=True, bidirectional=True)
        # Projects the concatenated forward/backward states onto the tag scores
        self.classifier = nn.Linear(2 * 512, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the CRF loss when ``labels`` is given, otherwise the decoded tag paths.

        With labels: negative CRF log-likelihood (``reduction='mean'``).
        Without labels: the result of ``self.crf.decode`` under the attention mask.
        """
        # Encoder states: (batch, seq_len, hidden_size)
        token_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Conv1d wants channels first: (batch, hidden_size, seq_len) -> (batch, 1024, seq_len)
        features = self.bn_conv(self.conv(token_states.transpose(1, 2)))
        features = F.relu(features, inplace=True)

        # Back to (batch, seq_len, 1024) for the batch-first BiLSTM
        recurrent_out, _ = self.lstm(features.transpose(1, 2))

        # Tag scores per token, with dropout applied on the scores themselves
        emissions = self.dropout(self.classifier(recurrent_out))  # (batch, seq_len, num_labels)

        crf_mask = attention_mask.bool()
        if labels is None:
            return self.crf.decode(emissions, mask=crf_mask)

        # The CRF cannot take the ignore index, so -100 is swapped for class 0.
        # NOTE(review): the mask is the attention mask, so any attended position whose
        # label was -100 is scored as class 0 in the loss - confirm this is intended.
        crf_labels = labels.masked_fill(labels == -100, 0)
        return -self.crf(emissions, crf_labels, mask=crf_mask, reduction='mean')



In [None]:
from transformers import get_linear_schedule_with_warmup

# Function to load and prepare data for a specific fold
def _resolve_fold_csv(filename):
    """Return the path of a fold CSV, looking in DATA_PATH first and DATA_PATH_FOLDS second.

    If the file exists in neither directory, the DATA_PATH location is returned so the
    subsequent read fails with an error that names the originally expected path.
    """
    primary = os.path.join(DATA_PATH, filename)
    if os.path.exists(primary):
        return primary
    fallback = os.path.join(DATA_PATH_FOLDS, filename)
    return fallback if os.path.exists(fallback) else primary


def _drop_p_labelled(df):
    """Keep only the rows whose label string contains no "P" (missing labels are kept)."""
    return df[~df["labels"].str.contains("P", na=False)]


def _encode_split(df_filtered):
    """Encode the label strings of one split and return (sequences, label id lists).

    Characters missing from ``label_map`` are skipped; rows left with no label are dropped.
    """
    encoded = df_filtered.copy()
    encoded["label"] = encoded["labels"].apply(lambda x: [label_map[c] for c in x if c in label_map])
    encoded = encoded[encoded["label"].map(len) > 0]
    return encoded["sequence"].tolist(), encoded["label"].tolist()


def prepare_fold_data(fold_num):
    """Load and prepare data for a specific fold.

    Args:
        fold_num: fold number used in the file names ``fold_{fold_num}_train.csv`` and
            ``fold_{fold_num}_val.csv``.

    Returns:
        ``(train_loader, val_loader, train_seqs, val_seqs)``; the training loader
        shuffles, the validation loader does not.
    """
    # Locate the fold files
    train_csv = _resolve_fold_csv(f"fold_{fold_num}_train.csv")
    val_csv = _resolve_fold_csv(f"fold_{fold_num}_val.csv")

    print(f"\n=== Fold {fold_num} ===")
    print(f"Loading training data from: {train_csv}")
    train_df = pd.read_csv(train_csv)
    print(f"Training records: {len(train_df)}")

    print(f"Loading validation data from: {val_csv}")
    val_df = pd.read_csv(val_csv)
    print(f"Validation records: {len(val_df)}")

    # Filter data
    train_df_filtered = _drop_p_labelled(train_df)
    val_df_filtered = _drop_p_labelled(val_df)
    print(f"Training records after filtering: {len(train_df_filtered)}")
    print(f"Validation records after filtering: {len(val_df_filtered)}")

    # Encode labels
    train_seqs, train_label_seqs = _encode_split(train_df_filtered)
    val_seqs, val_label_seqs = _encode_split(val_df_filtered)

    # Create datasets
    train_dataset = SPDataset(train_seqs, train_label_seqs, label_map)
    val_dataset = SPDataset(val_seqs, val_label_seqs, label_map)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"Fold {fold_num} prepared: {len(train_seqs)} train, {len(val_seqs)} val sequences")

    return train_loader, val_loader, train_seqs, val_seqs

print("Data preparation function ready")


In [None]:
def sequence_level_accuracy(preds_flat, labels_flat, test_label_seqs):
    """Fraction of sequences whose predictions match the labels at every non-ignored position.

    Args:
        preds_flat: predictions of all sequences concatenated into one flat sequence.
        labels_flat: labels concatenated the same way; entries equal to -100 are ignored.
        test_label_seqs: reference label sequences; only their lengths are used, to cut
            the flat inputs back into per-sequence chunks.

    Returns:
        Number of fully correct sequences divided by ``len(test_label_seqs)``,
        or 0.0 when there are no sequences.
    """
    n_sequences = len(test_label_seqs)
    if n_sequences == 0:
        return 0.0

    n_correct = 0
    start = 0
    for reference in test_label_seqs:
        # Cut the next chunk out of both flat inputs
        stop = start + len(reference)
        pred_chunk = preds_flat[start:stop]
        label_chunk = labels_flat[start:stop]
        start = stop

        # Keep only the positions whose label is not the ignore value
        kept_preds = [p for p, lab in zip(pred_chunk, label_chunk) if lab != -100]
        kept_labels = [lab for lab in label_chunk if lab != -100]
        if kept_preds == kept_labels:
            n_correct += 1

    return n_correct / n_sequences


In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import gc
import matplotlib.pyplot as plt

# Store results for all folds: per-epoch loss curves and the best validation loss of each fold
fold_results = {
    'train_losses': [],
    'val_losses': [],
    'best_val_losses': []
}

# K-Fold Cross Validation Training Loop
for fold in range(1, NUM_FOLDS + 1):
    print(f"\n{'='*60}")
    print(f"Starting Fold {fold}/{NUM_FOLDS}")
    print(f"{'='*60}")

    # Prepare data for this fold
    train_loader, val_loader, train_seqs, val_seqs = prepare_fold_data(fold)

    # Initialize fresh model for each fold
    encoder_fold = BertModel.from_pretrained(MODEL_NAME)
    encoder_fold.to(DEVICE)
    model = SPCNNClassifier(encoder_fold, NUM_CLASSES).to(DEVICE)

    # Initialize optimizer: small learning rate for the last four encoder layers,
    # larger one for the task head. The BatchNorm parameters are part of the head
    # and must be listed too, otherwise they would never be updated.
    optimizer = torch.optim.AdamW([
        {"params": model.encoder.encoder.layer[-4:].parameters(), "lr": 5e-6},
        {"params": model.conv.parameters(), "lr": 1e-3},
        {"params": model.bn_conv.parameters(), "lr": 1e-3},
        {"params": model.classifier.parameters(), "lr": 1e-3},
        {"params": model.lstm.parameters(), "lr": 1e-3},
        {"params": model.crf.parameters(), "lr": 1e-3},
    ])

    # Linear warmup over the first 10% of all optimizer steps, then linear decay
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # NOTE(review): no autocast context is used in this loop, so the scaler only
    # wraps plain full-precision steps.
    scaler = GradScaler()

    # Freeze all encoder layers except the last ten
    # NOTE(review): only the last four layers are given to the optimizer; the six
    # layers in between still receive gradients (and take part in gradient clipping)
    # but are never updated - confirm whether they should be frozen as well.
    for param in model.encoder.encoder.layer[:-10].parameters():
        param.requires_grad = False
    model.encoder.gradient_checkpointing_enable()

    # Track losses for this fold
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    # Training loop for this fold
    for epoch in range(EPOCHS):
        # Training phase
        model.train()
        pbar = tqdm(train_loader, desc=f"Fold {fold} - Epoch {epoch+1}/{EPOCHS} [Train]", unit="batch")
        total_train_loss = 0
        completed_batches = 0  # batches that finished an optimizer step
        skipped_batches = 0    # batches dropped because of a RuntimeError

        for batch in pbar:
            try:
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                token_labels = batch['labels'].to(DEVICE)

                optimizer.zero_grad()

                loss = model(input_ids, attention_mask, token_labels)

                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured on the true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                total_train_loss += loss.item()
                completed_batches += 1
                pbar.set_postfix(loss=loss.item())

            except RuntimeError as e:
                # Best effort: report the failure, release what we can and move on
                print("Error during training:", e)
                skipped_batches += 1
                gc.collect()
                continue

        if skipped_batches > 0:
            print(f"Fold {fold} - Epoch {epoch+1}/{EPOCHS} - skipped {skipped_batches} training batch(es) due to errors")

        # Average only over the batches that actually contributed a loss; dividing by
        # len(train_loader) would understate the loss whenever batches were skipped.
        avg_train_loss = total_train_loss / completed_batches if completed_batches > 0 else float('nan')
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        total_val_loss = 0
        val_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Fold {fold} - Epoch {epoch+1}/{EPOCHS} [Val]", unit="batch"):
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                token_labels = batch['labels'].to(DEVICE)

                loss = model(input_ids, attention_mask, token_labels)
                total_val_loss += loss.item()
                val_batches += 1

        avg_val_loss = total_val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        print(f"Fold {fold} - Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save best model for this fold (lowest validation loss seen so far)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model_path = MODEL_SAVE_PATH.format(fold)
            torch.save(model.state_dict(), model_path)
            print(f"  → Best model for fold {fold} saved to {model_path}")

    # Store results for this fold
    fold_results['train_losses'].append(train_losses)
    fold_results['val_losses'].append(val_losses)
    fold_results['best_val_losses'].append(best_val_loss)

    print(f"\nFold {fold} complete! Best validation loss: {best_val_loss:.4f}")

    # Plot losses for this fold
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, EPOCHS + 1), train_losses, label='Train Loss', marker='o')
    plt.plot(range(1, EPOCHS + 1), val_losses, label='Validation Loss', marker='s')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {fold} - Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Banner marking the end of cross validation
separator = "="*60
print("\n" + separator)
print("K-Fold Cross Validation Complete!")
print(separator)

# Best validation loss of every fold, then their mean over NUM_FOLDS
print("\nSummary of all folds:")
fold_number = 1
for fold_best in fold_results['best_val_losses']:
    print(f"Fold {fold_number}: Best Validation Loss = {fold_best:.4f}")
    fold_number += 1

avg_best_val_loss = sum(fold_results['best_val_losses']) / NUM_FOLDS
print(f"\nAverage Best Validation Loss across all folds: {avg_best_val_loss:.4f}")

# One validation curve per fold in a single figure
epoch_axis = range(1, EPOCHS + 1)
plt.figure(figsize=(12, 6))
for fold_number in range(1, NUM_FOLDS + 1):
    fold_curve = fold_results['val_losses'][fold_number - 1]
    plt.plot(epoch_axis, fold_curve, label=f'Fold {fold_number}', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Validation Loss Across All Folds')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:

# List the checkpoint path that belongs to each fold
print("All fold models saved successfully!")
for fold_number in range(1, NUM_FOLDS + 1):
    checkpoint_path = MODEL_SAVE_PATH.format(fold_number)
    print(f"Fold {fold_number}: {checkpoint_path}")


In [None]:
# Evaluation on Test Set using all fold models

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, accuracy_score

# Store predictions from all folds
all_fold_predictions = []
all_fold_metrics = []

# Explicit class ids and names for the report: without `labels`, the report derives the
# classes from the data and raises when their number differs from `target_names`
# (i.e. whenever a class shows up in neither the labels nor the predictions).
class_ids = list(label_map.values())
class_names = list(label_map.keys())

print("Evaluating all fold models on test set...\n")

for fold in range(1, NUM_FOLDS + 1):
    print(f"Evaluating Fold {fold} Model")

    # Load the trained model for this fold
    encoder_eval = BertModel.from_pretrained(MODEL_NAME)
    encoder_eval.to(DEVICE)
    model = SPCNNClassifier(encoder_eval, NUM_CLASSES).to(DEVICE)

    model_path = MODEL_SAVE_PATH.format(fold)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Model loaded from {model_path}")

    model.eval()
    test_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            # Compute loss using CRF (the model returns the loss when labels are passed)
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            test_loss += loss.item()

            # Decode predictions using CRF (the model returns tag paths without labels)
            predictions = model(input_ids=input_ids, attention_mask=attention_mask)

            # Collect valid tokens: attended positions whose label is not the ignore index
            for pred_seq, label_seq, mask in zip(predictions, labels, attention_mask):
                for pred, true, is_valid in zip(pred_seq, label_seq, mask):
                    if true.item() != -100 and is_valid.item() == 1:
                        all_preds.append(pred)
                        all_labels.append(true.item())

    # Store predictions for ensemble
    all_fold_predictions.append(all_preds)

    # Calculate metrics for this fold
    print(f"\nFold {fold} Results:")
    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

    f1_weighted = f1_score(all_labels, all_preds, average='weighted')
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    mcc = matthews_corrcoef(all_labels, all_preds)
    token_acc = accuracy_score(all_labels, all_preds)
    seq_acc = sequence_level_accuracy(all_preds, all_labels, test_label_seqs)
    avg_loss = test_loss / len(test_loader)

    print(f"F1 Score (weighted): {f1_weighted:.4f}")
    print(f"F1 Score (macro): {f1_macro:.4f}")
    print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")
    print(f"Token-level Accuracy: {token_acc:.4f}")
    print(f"Sequence Level Accuracy: {seq_acc:.4f}")
    print(f"Average test loss: {avg_loss:.4f}")

    all_fold_metrics.append({
        'fold': fold,
        'f1_weighted': f1_weighted,
        'f1_macro': f1_macro,
        'mcc': mcc,
        'token_acc': token_acc,
        'seq_acc': seq_acc,
        'test_loss': avg_loss
    })


# One row per fold, one column per metric
metrics_df = pd.DataFrame(all_fold_metrics)
print(metrics_df.to_string(index=False))

# Mean and standard deviation of every metric over the folds
print(f"\nAverage across folds:")
summary_rows = [
    ("F1 (weighted)", "f1_weighted"),
    ("F1 (macro)", "f1_macro"),
    ("MCC", "mcc"),
    ("Token Accuracy", "token_acc"),
    ("Sequence Accuracy", "seq_acc"),
    ("Test Loss", "test_loss"),
]
for display_name, column in summary_rows:
    fold_values = metrics_df[column]
    print(f"{display_name}: {fold_values.mean():.4f} ± {fold_values.std():.4f}")

# Confusion matrix for the fold with the highest weighted F1 on the test set
best_fold_idx = metrics_df['f1_weighted'].idxmax()
best_fold_num = all_fold_metrics[best_fold_idx]['fold']

print(f"\nShowing confusion matrix for best performing fold: Fold {best_fold_num}")

# Re-evaluate best fold for confusion matrix
encoder_best = BertModel.from_pretrained(MODEL_NAME)
encoder_best.to(DEVICE)
model_best = SPCNNClassifier(encoder_best, NUM_CLASSES).to(DEVICE)
model_best.load_state_dict(torch.load(MODEL_SAVE_PATH.format(best_fold_num), map_location=DEVICE))
model_best.eval()

all_preds_best = []
all_labels_best = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Without labels the model returns the decoded tag paths
        predictions = model_best(input_ids=input_ids, attention_mask=attention_mask)

        # Keep attended positions whose label is not the ignore index
        for pred_seq, label_seq, mask in zip(predictions, labels, attention_mask):
            for pred, true, is_valid in zip(pred_seq, label_seq, mask):
                if true.item() != -100 and is_valid.item() == 1:
                    all_preds_best.append(pred)
                    all_labels_best.append(true.item())

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(all_labels_best, all_preds_best, labels=list(label_map.values()))
# Normalise each row by its number of true samples. Rows of classes without any true
# sample stay all-zero instead of turning into NaN through a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_relative = np.divide(cm.astype("float"), row_totals,
                        out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
disp = ConfusionMatrixDisplay(confusion_matrix=cm_relative, display_labels=list(label_map.keys()))
disp.plot(cmap="OrRd", xticks_rotation=45)
plt.title(f"Confusion Matrix - Fold {best_fold_num} (Best)")
plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def draw_vertical_model():
    """Draw the SPCNNClassifier pipeline as a top-to-bottom chain of labelled boxes."""
    box_width = 2.0
    box_height = 0.6

    # Two input boxes side by side on top; every later stage is centred below them.
    # Each entry is (text, x of the left edge, y of the bottom edge).
    input_blocks = [
        ("Input IDs", 0.5, 9.0),
        ("Attention Mask", 2.5, 9.0),
    ]
    stage_blocks = [
        ("Encoder\n(BERT-like)", 1.5, 8.0),
        ("Conv1D\n(1 layer, kernel=8, dilation=2)", 1.5, 7.0),
        ("BatchNorm1D + ReLU", 1.5, 6.0),
        ("BiLSTM\n(2 layers, hidden=512)", 1.5, 5.0),
        ("Dense\n(1024 → num_label (6))", 1.5, 4.0),
        ("Dropout\n(p=0.2)", 1.5, 3.0),
        ("CRF Layer", 1.5, 2.0),
        ("Per token Predictions", 1.5, 1.0),
    ]
    arrow_style = dict(facecolor='black', arrowstyle='->')

    fig, ax = plt.subplots(figsize=(6, 10))
    ax.axis('off')

    # Boxes with their centred captions
    for text, left, bottom in input_blocks + stage_blocks:
        box = mpatches.FancyBboxPatch((left, bottom), box_width, box_height,
                                      boxstyle="round,pad=0.03",
                                      edgecolor='black', facecolor='skyblue', linewidth=2)
        ax.add_patch(box)
        ax.text(left + box_width / 2, bottom + box_height / 2, text,
                ha='center', va='center', fontsize=10)

    # Arrow from the bottom edge of each stage to the top edge of the stage below it
    for (_, left, bottom), (_, _, next_bottom) in zip(stage_blocks, stage_blocks[1:]):
        centre_x = left + box_width / 2
        ax.annotate('', xy=(centre_x, next_bottom + box_height), xytext=(centre_x, bottom),
                    arrowprops=arrow_style)

    # Arrows from both input boxes to the top of the encoder box
    for input_left in (0.5, 2.5):
        ax.annotate('', xy=(1.5 + box_width/2, 8.6), xytext=(input_left + box_width/2, 9.0),
                    arrowprops=arrow_style)

    ax.set_title("SPCNNClassifier Architecture", fontsize=14)
    ax.set_ylim(0, 10)
    ax.set_xlim(0, 5)
    plt.tight_layout()
    plt.show()

draw_vertical_model()
