In [None]:
#@title 0. uv calibration
import os
!curl -Ls https://astral.sh/uv/install.sh | bash
os.environ["PATH"] += ":/root/.cargo/bin"
!uv --version

In [None]:
#@title 1. Installs, Imports and Main Configuration
# Third-party dependencies, installed one package per command through uv.
# Every version is pinned; `-q` silences the installer output.
# NOTE(review): torch-xla is pinned independently of torch (which is not
# installed here) — confirm 2.1.0 is compatible with the runtime's torch build.
!uv pip install transformers==4.38.2 -q
!uv pip install sentencepiece==0.2.0 -q
!uv pip install torch-xla==2.1.0 -q # For TPU support
!uv pip install pytorch-crf==0.7.2 -q
!uv pip install pandas==2.2.2 -q
!uv pip install scikit-learn==1.4.2 -q
!uv pip install tensorboard==2.15.2 -q

# General imports
import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Scikit-learn imports for data handling and metrics
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import (
    classification_report,
    matthews_corrcoef,
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)

# PyTorch and Transformers imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcrf import CRF
from transformers import T5Tokenizer, T5EncoderModel, get_linear_schedule_with_warmup
import tensorflow as tf
# TensorFlow-side TPU bootstrap: resolve the TPU cluster, connect to it,
# initialise the TPU system and build a TPUStrategy. If the resolver raises
# ValueError, fall back to TensorFlow's default strategy.
# NOTE(review): `strategy` is not referenced anywhere else in the visible
# notebook (training goes through PyTorch / torch_xla) — confirm this block is
# still needed and does not compete with torch_xla for the TPU.
try:
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  print("All devices: ", tf.config.list_logical_devices('TPU'))
  strategy = tf.distribute.experimental.TPUStrategy(resolver)
except ValueError:
  # Only ValueError is handled; any other failure propagates.
  strategy = tf.distribute.get_strategy()

# --- Main Configuration ---
# Notebook-wide constants; the cells below read these as globals.

# NOTE(review): this looks like a debugging aid (synchronous CUDA kernel
# launches) and presumably slows GPU training — confirm it should stay enabled.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Model Configuration
MODEL_NAME = "Rostlab/ProstT5"  # checkpoint name passed to from_pretrained() for tokenizer and encoder
# Must stay equal to the number of entries in `label_map` (defined in the main execution cell).
NUM_CLASSES = 6  # num classes for classification ('S', 'T', 'L', 'I', 'M', 'O')

# Training Hyperparameters
BATCH_SIZE = 16 # Reduced batch size for better memory management
EPOCHS = 10
MAX_LENGTH = 512 # Max sequence length for tokenizer

# Optimizer Hyperparameters
CLASSIFIER_LR = 1e-3 # Learning rate for the new layers (classifier head)
ENCODER_LR_INITIAL = 0.0 # Initial LR for the transformer encoder (frozen)
ENCODER_LR_UNFROZEN = 2e-5 # LR for the encoder when unfrozen
WEIGHT_DECAY = 0.01  # passed to AdamW as weight_decay

# --- Device Setup (CPU, GPU, or TPU) ---
# Probe for an XLA (TPU) device. `xm.xla_device()` returns a `torch.device`,
# so its `.type` attribute is compared against 'xla' (comparing the device
# object itself to a string never matches). torch_xla raises RuntimeError when
# it is installed but no XLA device can be acquired, so that is treated the
# same as the package being absent.
TPU_AVAILABLE = False
try:
    import torch_xla.core.xla_model as xm
    TPU_AVAILABLE = xm.xla_device().type == 'xla'
except (ImportError, RuntimeError):
    TPU_AVAILABLE = False

# Preference order: TPU (XLA) > Apple MPS > CUDA > CPU.
DEVICE = (
    "xla" if TPU_AVAILABLE else
    "mps" if torch.backends.mps.is_available() else
    "cuda" if torch.cuda.is_available() else
    "cpu"
)
print(f"Using device: {DEVICE}")

# --- File Paths ---
# Ensure you have your data in the specified Google Drive path
# All paths live under DRIVE_PATH, which only exists once Drive is mounted below.
DRIVE_PATH = "/content/drive/MyDrive/PBL Rost/"
DATA_FILE = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")  # input FASTA read by load_and_prepare_data
MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, "models/optimized_bert_classifier.pt")  # destination of the final state_dict
LOG_DIR = os.path.join(DRIVE_PATH, "logs/")  # TensorBoard log directory

# Mount Google Drive
# Colab-only: the import fails outside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title 2. Data Loading and Preparation
def _parse_fasta_records(data_path: str) -> list:
    """Parses the three-line FASTA variant (header, sequence, label) into record dicts."""
    records = []
    current_record = {}
    with open(data_path, "r") as f:
        for line in f:
            if line.startswith(">"):
                if current_record:
                    records.append(current_record)
                header = line[1:].strip().split("|")
                # Only headers of the form ">ac|kingdom|type" are accepted;
                # anything else resets the record so its body lines are skipped.
                if len(header) == 3:
                    current_record = {
                        "uniprot_ac": header[0],
                        "kingdom": header[1],
                        "type": header[2],
                        "sequence": "",
                        "label": ""
                    }
                else:
                    current_record = {}
            elif current_record:
                # First body line is the sequence, second is the label string;
                # any further lines of the same record are ignored.
                if not current_record["sequence"]:
                    current_record["sequence"] = line.strip()
                elif not current_record["label"]:
                    current_record["label"] = line.strip()
    if current_record:
        records.append(current_record)
    return records


def load_and_prepare_data(data_path: str, label_map: dict):
    """
    Loads data from a FASTA file, cleans it, splits it, and balances the training split.

    The train/test split is performed BEFORE oversampling, and only the training
    split is oversampled. Oversampling with replacement before splitting would put
    copies of the same record into both splits and leak training data into the
    test set.

    Args:
        data_path (str): Path to the FASTA file (header / sequence / label per record).
        label_map (dict): Mapping from character labels to integer indices.

    Returns:
        tuple: train_sequences, test_sequences, train_labels, test_labels
            (lists; labels are lists of integer class indices per sequence).

    Raises:
        ValueError: If the file contains no parseable records (or, from
            scikit-learn, if too few records remain to create the split).
    """
    # 1. Load data from FASTA file
    print("Loading data from FASTA file...")
    records = _parse_fasta_records(data_path)
    if not records:
        raise ValueError(f"No valid records found in {data_path}")
    df_raw = pd.DataFrame(records)
    print(f"Loaded {len(df_raw)} raw records.")

    # 2. Clean data: drop rows with missing values
    df_clean = df_raw.dropna(subset=['sequence', 'label', 'type'])
    print(f"Records after dropping NA: {len(df_clean)}")

    # 3. Filter out records with 'P' in the label (as in original notebook)
    df = df_clean[~df_clean["label"].str.contains("P")].copy()
    print(f"Records after filtering 'P' labels: {len(df)}")

    # 4. Encode labels; characters missing from label_map are dropped.
    df["label_encoded"] = df["label"].apply(
        lambda x: [label_map[c] for c in x if c in label_map]
    )
    # Remove rows whose label sequence is empty after mapping (this also
    # removes records that never received a sequence/label line).
    df = df[df["label_encoded"].map(len) > 0].reset_index(drop=True)
    print(f"Records with usable labels: {len(df)}")

    # 5. Split into training and testing sets (before any oversampling).
    # Stratification needs at least two members per class; fall back to a
    # plain random split otherwise instead of crashing.
    type_counts = df["type"].value_counts()
    stratify_labels = df["type"] if len(type_counts) > 0 and type_counts.min() >= 2 else None
    train_df, test_df = train_test_split(
        df, test_size=0.2, random_state=42, stratify=stratify_labels
    )

    # 6. Balance classes in the TRAINING split only, using oversampling.
    print("Balancing training classes using oversampling...")
    train_majority = train_df[train_df["type"] == "NO_SP"]
    train_minority = train_df[train_df["type"] != "NO_SP"]

    if not train_minority.empty and not train_majority.empty:
        train_minority_upsampled = resample(
            train_minority,
            replace=True,
            n_samples=len(train_majority),
            random_state=42
        )
        train_balanced = pd.concat([train_majority, train_minority_upsampled])
    else:
        # Nothing to balance against; keep the training split as is.
        train_balanced = train_df

    train_balanced = train_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"Training records after oversampling: {len(train_balanced)}")
    print("Training class distribution after oversampling:")
    print(train_balanced["type"].value_counts())

    train_seqs = train_balanced["sequence"].tolist()
    train_label_seqs = train_balanced["label_encoded"].tolist()
    test_seqs = test_df["sequence"].tolist()
    test_label_seqs = test_df["label_encoded"].tolist()
    print(f"Training set size: {len(train_seqs)}")
    print(f"Test set size: {len(test_seqs)}")

    return train_seqs, test_seqs, train_label_seqs, test_label_seqs

class SPDataset(Dataset):
    """
    Custom PyTorch Dataset for protein sequences and their per-residue labels.

    Each item is a dict with 'input_ids', 'attention_mask' and 'labels', all of
    length `max_length`. Labels are right-padded with 0 (NOT -100): the model in
    this notebook feeds them to a CRF together with the attention mask, so
    padded positions are masked out rather than ignored by value.

    NOTE(review): the position of the tokenizer's end-of-sequence token has
    attention_mask == 1 but keeps the padding label 0, so consumers that mask
    only by attention_mask will treat it as class 0 — confirm downstream handling.

    Args:
        sequences: List of amino-acid strings.
        label_seqs: List of integer label lists, one per sequence.
        tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'.
        max_length (int): Fixed length used for padding / truncation.
        num_classes (int, optional): Labels outside [0, num_classes) are dropped.
            Defaults to the notebook-level NUM_CLASSES constant.
    """
    def __init__(self, sequences, label_seqs, tokenizer, max_length, num_classes=None):
        self.sequences = sequences
        self.label_seqs = label_seqs
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Resolved once here so __getitem__ does not depend on globals that are
        # only defined in later notebook cells.
        self.num_classes = NUM_CLASSES if num_classes is None else num_classes

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        labels = self.label_seqs[idx]

        # Tokenize the sequence
        # We add spaces between amino acids for the T5 tokenizer
        spaced_seq = " ".join(list(seq))
        encoded = self.tokenizer(
            spaced_seq,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Label tensor, padded with 0 (see class docstring)
        token_labels = torch.zeros(self.max_length, dtype=torch.long)

        # Number of non-padding tokens (sum of the attention mask)
        actual_token_len = int(attention_mask.sum().item())

        # One attended token is reserved for the end-of-sequence marker the
        # tokenizer appends, so at most actual_token_len - 1 labels are placed.
        label_len_to_use = min(len(labels), actual_token_len - 1)

        if label_len_to_use > 0:
            # Keep only labels within the valid range [0, num_classes - 1]
            valid_labels = [
                l for l in labels[:label_len_to_use] if 0 <= l < self.num_classes
            ]
            if valid_labels:
                token_labels[:len(valid_labels)] = torch.tensor(valid_labels, dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': token_labels
        }

In [None]:
#@title 3. Model Definition (ProstT5-CNN-BiLSTM-CRF)
class SPCNNClassifier(nn.Module):
    """
    Sequence tagger stacked on a pre-trained encoder:
    encoder -> Conv1d + BatchNorm + ReLU -> 3-layer BiLSTM -> Linear -> Dropout -> CRF.
    """
    def __init__(self, encoder_model, num_labels):
        super().__init__()
        self.encoder = encoder_model
        self.dropout = nn.Dropout(0.2)
        width = self.encoder.config.hidden_size

        # Local feature extractor; "same" padding keeps the sequence length
        # and the channel count is left unchanged.
        self.conv = nn.Conv1d(width, width, kernel_size=5, padding="same")
        self.bn_conv = nn.BatchNorm1d(width)

        # Bidirectional, so each direction gets half the width and the
        # concatenated output is `width` again.
        self.lstm = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=3,
            bidirectional=True,
            batch_first=True,
            dropout=0.1,
        )
        # Per-position emission scores, followed by the CRF over tag sequences.
        self.classifier = nn.Linear(width, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, return_loss_only=False):
        """
        Returns decoded tag sequences when `labels` is None; otherwise the CRF
        negative log-likelihood, alone if `return_loss_only` is set, or as a
        `(loss, predictions)` pair.
        """
        # (batch, seq_len, hidden) contextual embeddings
        features = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # Conv1d works channels-first, so swap to (batch, hidden, seq_len) and back.
        features = self.conv(features.transpose(1, 2))
        features = F.relu(self.bn_conv(features)).transpose(1, 2)

        features, _ = self.lstm(features)

        # Dropout is applied to the emission scores themselves.
        emissions = self.dropout(self.classifier(features))

        # The CRF layer needs a boolean mask.
        tag_mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(emissions, mask=tag_mask)

        loss = -self.crf(emissions, labels, mask=tag_mask, reduction='mean')
        if return_loss_only:
            return loss
        return loss, self.crf.decode(emissions, mask=tag_mask)

In [None]:
#@title 4. Training and Evaluation Functions
def train_one_epoch(model, loader, optimizer, scheduler, device, scaler=None):
    """
    Runs a single optimisation pass over `loader` and returns the mean batch loss.

    Args:
        model: Model called as `model(input_ids, attention_mask, labels, return_loss_only=True)`.
        loader: DataLoader yielding dicts with 'input_ids', 'attention_mask', 'labels'.
        optimizer: Optimizer updated once per batch.
        scheduler: LR scheduler stepped once per batch.
        device: Target device for the batch tensors.
        scaler: Optional GradScaler; when given, the backward pass and optimizer
            step are routed through it.

    Returns:
        float: Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc="Training", leave=False)

    for batch in progress:
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')
        )

        optimizer.zero_grad()

        # Only the loss is needed for the backward pass.
        loss = model(input_ids, attention_mask, labels, return_loss_only=True)

        if not scaler:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if TPU_AVAILABLE:
                # On TPU the step goes through torch_xla.
                xm.optimizer_step(optimizer, barrier=True)
            else:
                optimizer.step()
        else:
            # Scaled backward; gradients are unscaled before clipping.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=f'{batch_loss:.4f}')

    return running_loss / len(loader)

def evaluate(model, loader, device, exclude_eos=True):
    """
    Evaluates the model on a labelled loader and collects token-level results.

    The dataset in this notebook pads labels with 0 and the tokenizer appends an
    end-of-sequence token that is covered by the attention mask. Scoring every
    attended position would therefore count the end token as a class-0 target
    and produce one extra entry per sequence, which also misaligns the flat
    lists with the per-sequence label lengths used by sequence_level_accuracy.
    With `exclude_eos=True` (default) the last attended position of every
    sequence is left out of the returned predictions/labels.

    Args:
        model: Model returning `(loss, predictions)` when called with labels,
            where `predictions` is one list of tag ids per sequence.
        loader: The DataLoader for validation data.
        device: The device to evaluate on.
        exclude_eos (bool): Drop the last attended position of each sequence.
            Set to False for tokenizers that do not append an end token.

    Returns:
        tuple: (validation_loss, all_predictions, all_labels) where the last two
            are flat lists of Python ints in loader order. The loss is 0.0 for
            an empty loader.
    """
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []
    pbar = tqdm(loader, desc="Evaluating", leave=False)

    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Loss and decoded tag sequences in one forward pass.
            loss, predictions = model(input_ids, attention_mask, labels, return_loss_only=False)

            total_loss += loss.item()

            # Move once per batch; the per-sequence bookkeeping runs on CPU.
            labels_cpu = labels.cpu()
            mask_cpu = attention_mask.bool().cpu()

            for pred_seq, label_seq, mask in zip(predictions, labels_cpu, mask_cpu):
                pred_tags = [int(tag) for tag in pred_seq]
                gold_tags = label_seq[mask].tolist()

                # Score only positions present in both lists, minus the end token.
                n_scored = min(len(pred_tags), len(gold_tags))
                if exclude_eos:
                    n_scored = max(n_scored - 1, 0)

                for pred_tag, gold_tag in zip(pred_tags[:n_scored], gold_tags[:n_scored]):
                    # -100 is still honoured as an "ignore" marker should a
                    # dataset use it.
                    if gold_tag != -100:
                        all_preds.append(pred_tag)
                        all_labels.append(gold_tag)

    num_batches = len(loader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    return avg_loss, all_preds, all_labels

def sequence_level_accuracy(preds_flat, labels_flat, test_label_seqs, num_classes=None):
    """
    Computes sequence-level accuracy. A sequence is correct only if all its labels are correct.

    The flat prediction/label lists are cut into consecutive chunks, one per
    entry of `test_label_seqs`, whose length is the number of labels of that
    sequence lying in [0, num_classes). This assumes the flat lists are in the
    same order as `test_label_seqs` and contain exactly that many entries per
    sequence; a warning is emitted when the totals do not match, because the
    chunks are then misaligned and the result is unreliable.

    Args:
        preds_flat: Flat sequence of predicted tag ids.
        labels_flat: Flat sequence of gold tag ids (same length as preds_flat).
        test_label_seqs: Original per-sequence label lists.
        num_classes (int, optional): Number of valid classes. Defaults to the
            notebook-level NUM_CLASSES constant.

    Returns:
        float: Fraction of sequences whose chunk matches exactly (0.0 when
            there is nothing to compare).
    """
    import warnings

    if num_classes is None:
        num_classes = NUM_CLASSES

    if not test_label_seqs or len(preds_flat) == 0 or len(labels_flat) == 0:
        return 0.0

    # Plain ints so that list comparison is independent of numpy scalar types.
    preds = [int(p) for p in preds_flat]
    golds = [int(t) for t in labels_flat]

    chunk_lengths = [
        sum(1 for label in seq_labels if 0 <= label < num_classes)
        for seq_labels in test_label_seqs
    ]

    usable = min(len(preds), len(golds))
    if len(preds) != len(golds) or sum(chunk_lengths) != len(preds):
        warnings.warn(
            "sequence_level_accuracy: flat predictions/labels "
            f"({len(preds)}/{len(golds)} entries) do not match the "
            f"{sum(chunk_lengths)} labels expected from test_label_seqs; "
            "sequence boundaries are misaligned."
        )

    correct_sequences = 0
    start = 0
    for length in chunk_lengths:
        end = start + length
        # A chunk running past the available data cannot be a fully correct sequence.
        if end <= usable and preds[start:end] == golds[start:end]:
            correct_sequences += 1
        start = end

    return correct_sequences / len(test_label_seqs)

In [None]:
#@title 5. Main Execution: Setup, Train, and Evaluate
# --- 1. Setup ---
# Seed the RNGs so that weight initialisation and shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Label mapping
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}
inv_label_map = {v: k for k, v in label_map.items()}

# Load and prepare data
train_seqs, test_seqs, train_labels, test_labels = load_and_prepare_data(
    data_path=DATA_FILE, label_map=label_map
)

# Initialize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
encoder = T5EncoderModel.from_pretrained(MODEL_NAME)
model = SPCNNClassifier(encoder, NUM_CLASSES).to(DEVICE)

# Create Datasets and DataLoaders
train_dataset = SPDataset(train_seqs, train_labels, tokenizer, MAX_LENGTH)
test_dataset = SPDataset(test_seqs, test_labels, tokenizer, MAX_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=2)

# --- 2. Optimizer and Scheduler Setup (with Gradual Unfreezing) ---
# Freeze the encoder initially
for param in model.encoder.parameters():
    param.requires_grad = False

# Define parameter groups for different learning rates.
# Group 0 is the encoder; it starts at ENCODER_LR_INITIAL while frozen.
ENCODER_GROUP_IDX = 0
optimizer = torch.optim.AdamW([
    {'params': model.encoder.parameters(), 'lr': ENCODER_LR_INITIAL},
    {'params': model.conv.parameters(), 'lr': CLASSIFIER_LR},
    {'params': model.bn_conv.parameters(), 'lr': CLASSIFIER_LR},
    {'params': model.lstm.parameters(), 'lr': CLASSIFIER_LR},
    {'params': model.classifier.parameters(), 'lr': CLASSIFIER_LR},
    {'params': model.crf.parameters(), 'lr': CLASSIFIER_LR}
], weight_decay=WEIGHT_DECAY)

# Learning rate scheduler
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)


def set_group_base_lr(optimizer, scheduler, group_idx, new_lr):
    """
    Changes the base learning rate of one parameter group under a LambdaLR schedule.

    The scheduler recomputes every group's lr as `base_lr * factor(step)` on each
    `scheduler.step()`, using the base values it captured at construction.
    Writing only `param_groups[i]['lr']` is therefore overwritten after a single
    step; the scheduler's stored base lr has to be updated as well.
    """
    scheduler.base_lrs[group_idx] = new_lr
    optimizer.param_groups[group_idx]['initial_lr'] = new_lr
    # Apply the current schedule factor right away so the next optimizer step
    # already uses the scheduled value.
    current_factor = scheduler.lr_lambdas[group_idx](scheduler.last_epoch)
    optimizer.param_groups[group_idx]['lr'] = new_lr * current_factor


# Mixed precision scaler for CUDA
scaler = torch.amp.GradScaler("cuda") if DEVICE == 'cuda' else None

# TensorBoard writer
writer = SummaryWriter(log_dir=LOG_DIR)


# --- 3. Training Loop ---
print("\n--- Starting Training ---")
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # --- Gradual Unfreezing Logic ---
    # At epoch index 4, unfreeze the last 4 encoder layers and set their learning rate
    if epoch == 4:
        print("Unfreezing last 4 encoder layers...")
        for block in model.encoder.block[-4:]:
            for param in block.parameters():
                param.requires_grad = True
        set_group_base_lr(optimizer, scheduler, ENCODER_GROUP_IDX, ENCODER_LR_UNFROZEN)

    # At epoch index 7, unfreeze all remaining encoder layers
    if epoch == 7:
        print("Unfreezing all encoder layers...")
        for param in model.encoder.parameters():
            param.requires_grad = True
        set_group_base_lr(optimizer, scheduler, ENCODER_GROUP_IDX, ENCODER_LR_UNFROZEN)

    # Train for one epoch
    avg_train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, DEVICE, scaler)
    print(f"Average Training Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/train", avg_train_loss, epoch)

    # Evaluate on the test set
    avg_val_loss, preds, labels = evaluate(model, test_loader, DEVICE)
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    writer.add_scalar("Loss/validation", avg_val_loss, epoch)

    # Log metrics if predictions were made
    if preds and labels:
        token_acc = accuracy_score(labels, preds)
        writer.add_scalar("Accuracy/token", token_acc, epoch)
        print(f"Token-level Accuracy: {token_acc:.4f}")

# --- 4. Final Evaluation and Saving ---
print("\n--- Final Evaluation ---")
val_loss, all_preds, all_labels = evaluate(model, test_loader, DEVICE)

if all_preds and all_labels:
    # Get metrics
    report = classification_report(
        all_labels, all_preds,
        target_names=list(label_map.keys()),
        digits=4,
        zero_division=0
    )
    seq_acc = sequence_level_accuracy(all_preds, all_labels, test_labels)
    mcc = matthews_corrcoef(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    print("Classification Report:")
    print(report)
    print(f"Sequence Level Accuracy: {seq_acc:.4f}")
    print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")
    print(f"F1 Score (Macro): {f1_macro:.4f}")

    # Plot confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=list(label_map.values()))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_map.keys()))
    fig, ax = plt.subplots(figsize=(8, 8))
    disp.plot(cmap="OrRd", xticks_rotation=45, ax=ax)
    plt.title("Confusion Matrix")
    plt.show()
else:
    print("Evaluation produced no predictions. Skipping metrics calculation.")

# Save the model
print(f"Saving model to {MODEL_SAVE_PATH}")
# Ensure directory exists
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
if TPU_AVAILABLE:
    # Use xm.save for TPUs
    xm.save(model.state_dict(), MODEL_SAVE_PATH)
else:
    torch.save(model.state_dict(), MODEL_SAVE_PATH)

writer.close()
print("--- Training and Evaluation Complete ---")

