In [None]:
#@title 0. uv calibration
import os
!curl -Ls https://astral.sh/uv/install.sh | bash
os.environ["PATH"] += ":/root/.cargo/bin"
!uv --version

In [None]:
#@title 1. Installs, Imports and Main Configuration
# Runtime dependencies, installed one package per uv invocation.
# `-q` silences installer output; every package except flaxcrf is version-pinned.
!uv pip install transformers==4.38.2 -q
!uv pip install sentencepiece==0.2.0 -q
!uv pip install torch-xla==2.1.0 -q # For TPU support
# NOTE(review): pytorch-crf is installed here, but the visible cells only import
# a CRF from flaxcrf — confirm this package is still needed.
!uv pip install pytorch-crf==0.7.2 -q
!uv pip install pandas==2.2.2 -q
!uv pip install scikit-learn==1.4.2 -q
!uv pip install tensorboard==2.15.2 -q
# NOTE(review): flaxcrf is unpinned and `-U` upgrades it on every run, so two
# runs of this notebook may use different versions — consider pinning.
!uv pip install flaxcrf -U


# General imports
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# Scikit-learn imports for data handling and metrics
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import (
    classification_report,
    matthews_corrcoef,
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)

# check if tpu is available
import torch
try:
    import torch_xla.core.xla_model as xm
    # xla_device() returns a torch.device object; comparing that object with the
    # string 'xla' is never true. Ask torch_xla which hardware backs the device
    # instead, so the flag is only set when a TPU is actually present.
    TPU_AVAILABLE = xm.xla_device_hw(xm.xla_device()) == 'TPU'
except (ImportError, RuntimeError):
    # ImportError: torch_xla is not installed.
    # RuntimeError: presumably what torch_xla raises when it cannot acquire an
    # XLA device — treated as "no TPU" rather than aborting the cell.
    TPU_AVAILABLE = False

# Jax imports for TPU support
import jax
import jax.numpy as jnp
import flax.linen as nn
from flaxcrf import CRF
import optax
from transformers import FlaxT5ForConditionalGeneration, AutoTokenizer
from typing import Sequence, Tuple, Any, Optional

# --- Main Configuration ---

MODEL_NAME = "Rostlab/ProstT5"
# Per-residue label alphabet mapped to integer class ids.
LABEL_MAP = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}
# Derived from LABEL_MAP so the class count cannot drift from the label alphabet.
NUM_CLASSES = len(LABEL_MAP)  # num classes for classification ('S', 'T', 'L', 'I', 'M', 'O')

# Training Hyperparameters
BATCH_SIZE = 16 # Reduced batch size for better memory management
EPOCHS = 10
MAX_LENGTH = 512 # Max sequence length for tokenizer

# Optimizer Hyperparameters
CLASSIFIER_LR = 1e-3 # Learning rate for the new layers (classifier head)
ENCODER_LR_INITIAL = 0.0 # Initial LR for the transformer encoder (frozen)
ENCODER_LR_UNFROZEN = 2e-5 # LR for the encoder when unfrozen
WEIGHT_DECAY = 0.01

# --- Device Setup (CPU, GPU, or TPU) ---
# NOTE(review): DEVICE is a torch-style device name; in the visible cells it is
# only printed, while the model itself runs through JAX.
TPU_AVAILABLE = False
try:
    import torch_xla.core.xla_model as xm
    # xla_device() returns a torch.device, which never compares equal to the
    # string 'xla'; query the backing hardware type instead.
    TPU_AVAILABLE = xm.xla_device_hw(xm.xla_device()) == 'TPU'
except (ImportError, RuntimeError):
    # ImportError: torch_xla missing. RuntimeError: presumably raised when no
    # XLA device can be acquired — both mean "no TPU".
    TPU_AVAILABLE = False

# First match wins: TPU, then Apple MPS, then CUDA, falling back to CPU.
DEVICE = (
    "xla" if TPU_AVAILABLE else
    "mps" if torch.backends.mps.is_available() else
    "cuda" if torch.cuda.is_available() else
    "cpu"
)
print(f"Using device: {DEVICE}")

# --- File Paths ---
# Attach Google Drive first: every path defined below lives under the mount point.
from google.colab import drive
drive.mount('/content/drive')

# Project folder on Drive; the dataset is expected to already be stored there.
DRIVE_PATH = "/content/drive/MyDrive/PBL Rost/"

# Input dataset (FASTA-style text file).
DATA_FILE = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")

# Output locations for the trained model and for training logs.
MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, "models/optimized_bert_classifier.pt")
LOG_DIR = os.path.join(DRIVE_PATH, "logs/")

In [None]:
# load tokenizer and encoder
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
# Keep the whole pretrained Flax model object. The wrapper returned by
# `from_pretrained` provides the `encode(...)` method that SPCNNClassifier calls,
# but it has no `.encoder` attribute, so accessing `.encoder` raises AttributeError.
# NOTE(review): if the hub repository only ships PyTorch weights, this call needs
# `from_pt=True` — confirm against the model repository.
ENCODER = FlaxT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

In [None]:
#@title 2. Data Loading and Preparation
def _read_fasta_records(data_path: str) -> list:
    """Parse `>ac|kingdom|type` headers followed by a sequence line and a label line."""
    records = []
    current_record = {}
    with open(data_path, "r") as f:
        for line in f:
            if line.startswith(">"):
                # A new header closes the record collected so far.
                if current_record:
                    records.append(current_record)
                header = line[1:].strip().split("|")
                if len(header) == 3:
                    current_record = {
                        "uniprot_ac": header[0],
                        "kingdom": header[1],
                        "type": header[2],
                        "sequence": "",
                        "label": ""
                    }
                else:
                    # Malformed header: skip its body lines until the next header.
                    current_record = {}
            elif current_record:
                # The first non-empty body line is the sequence, the second the label.
                text = line.strip()
                if not current_record["sequence"]:
                    current_record["sequence"] = text
                elif not current_record["label"]:
                    current_record["label"] = text
    if current_record:
        records.append(current_record)
    return records


def load_and_prepare_data(data_path: str, label_map: dict):
    """
    Loads data from a FASTA file, cleans it, splits it, and balances the training set.

    The train/test split is done BEFORE oversampling. Oversampling draws records
    with replacement, so balancing first would put copies of the same record in
    both the training and the test set and inflate the evaluation scores. The
    test set therefore keeps the original class distribution.

    Args:
        data_path (str): Path to the FASTA file.
        label_map (dict): Mapping from character labels to integer indices.

    Returns:
        tuple: train_sequences, test_sequences, train_labels, test_labels

    Raises:
        ValueError: If no usable record remains after cleaning and filtering.
    """
    # 1. Load data from FASTA file
    print("Loading data from FASTA file...")
    columns = ["uniprot_ac", "kingdom", "type", "sequence", "label"]
    # Explicit columns keep the frame well-formed even when nothing was parsed.
    df_raw = pd.DataFrame(_read_fasta_records(data_path), columns=columns)
    print(f"Loaded {len(df_raw)} raw records.")

    # 2. Clean data: the parser fills missing fields with "", which dropna()
    #    does not treat as missing, so empty strings are filtered explicitly.
    required = ["sequence", "label", "type"]
    df_clean = df_raw.dropna(subset=required)
    df_clean = df_clean[(df_clean[required] != "").all(axis=1)]
    print(f"Records after dropping NA: {len(df_clean)}")

    # 3. Filter out records with 'P' in the label (as in original notebook)
    df = df_clean[~df_clean["label"].str.contains("P", regex=False)].copy()
    print(f"Records after filtering 'P' labels: {len(df)}")

    # 4. Encode labels.
    # NOTE(review): characters missing from label_map are dropped silently, which
    # shortens the label relative to its sequence — confirm this cannot happen
    # for the data in use.
    df["label_encoded"] = df["label"].apply(
        lambda text: [label_map[c] for c in text if c in label_map]
    )
    # Remove rows where the label sequence became empty after mapping
    df = df[df["label_encoded"].map(len) > 0]
    if df.empty:
        raise ValueError(f"No usable records found in {data_path!r}.")
    print(f"Final dataset size before splitting: {len(df)}")

    # 5. Split into training and testing sets (stratified by type when every
    #    type has at least two members, which stratification requires).
    stratify_on = df["type"] if df["type"].value_counts().min() >= 2 else None
    if stratify_on is None:
        print("Warning: a type with a single record exists; splitting without stratification.")
    df_train, df_test = train_test_split(
        df, test_size=0.2, random_state=42, stratify=stratify_on
    )

    # 6. Balance the TRAINING set only, using oversampling
    print("Balancing training classes using oversampling...")
    df_majority = df_train[df_train["type"] == "NO_SP"]
    df_minority = df_train[df_train["type"] != "NO_SP"]

    if not df_minority.empty and not df_majority.empty:
        df_minority_upsampled = resample(
            df_minority,
            replace=True,
            n_samples=len(df_majority),
            random_state=42
        )
        df_train = pd.concat([df_majority, df_minority_upsampled])

    df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"Total training records after oversampling: {len(df_train)}")
    print("Training class distribution after oversampling:")
    print(df_train["type"].value_counts())

    train_seqs = df_train["sequence"].tolist()
    train_label_seqs = df_train["label_encoded"].tolist()
    test_seqs = df_test["sequence"].tolist()
    test_label_seqs = df_test["label_encoded"].tolist()
    print(f"Training set size: {len(train_seqs)}")
    print(f"Test set size: {len(test_seqs)}")

    return train_seqs, test_seqs, train_label_seqs, test_label_seqs



# Tokenization and Batching
def tokenize_and_batch(sequences, labels, tokenizer, max_length, batch_size):
    """
    Tokenizes sequences and creates batches for training.

    Args:
        sequences: List of input strings handed to `tokenizer` unchanged.
        labels: One integer label sequence per input sequence.
        tokenizer: Callable invoked with padding='max_length', truncation=True,
            max_length=max_length and return_tensors='np'; its result must
            provide 'input_ids' and 'attention_mask'.
        max_length: Length every token row and label row is padded/truncated to.
        batch_size: Number of samples per batch (the last batch may be smaller).

    Returns:
        A list of dicts with input_ids, attention_mask, and labels (JAX arrays).
        Label rows are padded with -100. An empty input yields an empty list.

    Raises:
        ValueError: If `sequences` and `labels` differ in length.

    NOTE(review): labels are padded independently of the tokenizer output; nothing
    here checks that label position i corresponds to token position i (e.g. when
    the tokenizer adds special tokens) — confirm the alignment.
    """
    num_samples = len(sequences)
    if num_samples != len(labels):
        raise ValueError(
            f"Got {num_samples} sequences but {len(labels)} label sequences."
        )
    if num_samples == 0:
        return []

    # Tokenize all sequences
    encodings = tokenizer(
        sequences,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='np'  # using numpy for JAX
    )

    # Truncate first, then pad. Padding first passes a negative width to np.pad
    # (a ValueError) whenever a label sequence is longer than max_length.
    padded_labels = np.full((num_samples, max_length), -100, dtype=np.int32)
    for row, label_seq in enumerate(labels):
        kept = np.asarray(label_seq, dtype=np.int32)[:max_length]
        if kept.size:
            padded_labels[row, :kept.size] = kept

    # Create batches
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']
    return [
        {
            'input_ids': jnp.array(input_ids[start:start + batch_size]),
            'attention_mask': jnp.array(attention_mask[start:start + batch_size]),
            'labels': jnp.array(padded_labels[start:start + batch_size]),
        }
        for start in range(0, num_samples, batch_size)
    ]


In [None]:
#@title 3. Tokenization and Batching for preparation for training

# Parse the FASTA file into train/test sequences plus their encoded label sequences.
train_seqs, test_seqs, train_label_seqs, test_label_seqs = load_and_prepare_data(
    data_path=DATA_FILE,
    label_map=LABEL_MAP,
)

# Tokenize both splits with the same tokenizer, length and batch-size settings.
train_batches = tokenize_and_batch(
    sequences=train_seqs,
    labels=train_label_seqs,
    tokenizer=TOKENIZER,
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE,
)
test_batches = tokenize_and_batch(
    sequences=test_seqs,
    labels=test_label_seqs,
    tokenizer=TOKENIZER,
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE,
)

In [None]:
class SPCNNClassifier(nn.Module):
    """Per-token tagger: pretrained encoder -> Conv1D -> stacked BiLSTM -> Dense -> CRF.

    Called with `labels`, it returns `(loss, predictions)`; without labels it
    returns only the Viterbi-decoded `predictions`.

    NOTE(review): the CRF comes from `flaxcrf`; the sign and reduction of the value
    returned by calling it are not visible here — confirm it is a scalar loss to
    minimise before relying on it for gradients.
    """
    # Object exposing `encode(input_ids=..., attention_mask=..., train=..., dropout_rng=...)`
    # whose result has `.last_hidden_state` (the pretrained Flax T5 wrapper's interface).
    # NOTE(review): such a wrapper holds its own weights, so they presumably do
    # not appear in this module's 'params' collection — verify before relying on
    # encoder freezing/unfreezing.
    encoder: Any
    num_labels: int
    dropout_rate: float = 0.2
    lstm_dropout_rate: float = 0.1
    kernel_size: int = 5
    num_lstm_layers: int = 3

    @nn.compact
    def __call__(self, input_ids, attention_mask, labels=None, training=False):
        # In training mode the encoder's own dropout needs an explicit RNG key.
        encoder_output = self.encoder.encode(
            input_ids=input_ids,
            attention_mask=attention_mask,
            train=training,
            dropout_rng=self.make_rng('dropout') if training else None,
        )
        hidden_states = encoder_output.last_hidden_state
        hidden_size = hidden_states.shape[-1]

        x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(hidden_states)

        # Flax convolutions are channels-last, i.e. (batch, length, features), so
        # the encoder output is convolved directly along the sequence axis.
        # Transposing to the PyTorch (batch, features, length) layout would make
        # the convolution slide over the feature axis instead.
        x_conv = nn.Conv(features=hidden_size, kernel_size=(self.kernel_size,),
                         padding='SAME', name='conv1d')(x)
        x_conv = nn.BatchNorm(use_running_average=not training, name='bn_conv')(x_conv)
        x_conv = nn.relu(x_conv)

        # flax.linen has no `nn.LSTM`; a bidirectional layer is built from
        # nn.Bidirectional wrapping two nn.RNN(LSTMCell) scans. Each direction
        # yields hidden_size // 2 features, which are concatenated.
        # seq_lengths lets the backward scan start at the last real token
        # instead of at the padding.
        seq_lengths = jnp.sum(attention_mask, axis=-1)
        lstm_out = x_conv
        for i in range(self.num_lstm_layers):
            lstm_out = nn.Bidirectional(
                forward_rnn=nn.RNN(nn.LSTMCell(features=hidden_size // 2)),
                backward_rnn=nn.RNN(nn.LSTMCell(features=hidden_size // 2)),
                name=f'lstm_layer_{i}',
            )(lstm_out, seq_lengths=seq_lengths)
            # Dropout between stacked LSTM layers (not after the last one).
            if i < self.num_lstm_layers - 1:
                lstm_out = nn.Dropout(rate=self.lstm_dropout_rate,
                                      deterministic=not training)(lstm_out)

        logits = nn.Dense(features=self.num_labels, name='classifier')(lstm_out)
        # NOTE(review): dropout is applied to the emission logits themselves (kept
        # from the original design) — confirm this is intended.
        logits = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(logits)
        crf_mask = attention_mask.astype(jnp.bool_)
        crf_layer = CRF(num_tags=self.num_labels, name='crf')

        if labels is not None:
            # The CRF is called before decoding, keeping the original call order.
            loss = crf_layer(logits, labels, mask=crf_mask)
            predictions = crf_layer.viterbi_decode(logits, mask=crf_mask)
            return loss, predictions
        else:
            predictions = crf_layer.viterbi_decode(logits, mask=crf_mask)
            return predictions

In [None]:
# --- State Management (using Optax and Flax's TrainState pattern) ---
from flax.training import train_state

class TrainState(train_state.TrainState):
    """Flax TrainState extended with the extra state threaded through training."""
    batch_stats: Any # To store BatchNorm moving averages/variances
    # PRNG key carried from step to step. `train_step` reads `state.rng` and the
    # training cell passes `rng=` to `TrainState.create`, so the field must exist.
    # It defaults to None so code that builds a state without a key still works.
    rng: Any = None

# Build the classifier and initialise its variables (params and BatchNorm stats).
def create_model_and_params(key, encoder, num_labels, batch_size, seq_len):
    """Instantiate SPCNNClassifier and initialise it on dummy inputs.

    Args:
        key: PRNG key; split internally into separate 'params' and 'dropout' keys.
        encoder: Encoder object handed to SPCNNClassifier.
        num_labels: Number of output tags.
        batch_size: Batch dimension of the dummy inputs used for initialisation.
        seq_len: Sequence dimension of the dummy inputs used for initialisation.

    Returns:
        (model, params, batch_stats)
    """
    # Independent keys for parameter initialisation and dropout, instead of
    # reusing `key` for both streams.
    params_key, dropout_key = jax.random.split(key)

    dummy_input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    dummy_attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    # Tag id 0 is valid for any num_labels >= 1.
    dummy_labels = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

    model = SPCNNClassifier(encoder=encoder, num_labels=num_labels)
    # No `mutable=` argument here: `init` already treats the collections it has to
    # create as mutable. Restricting it to ['batch_stats'] would make 'params'
    # immutable, so parameter creation would fail.
    variables = model.init(
        {'params': params_key, 'dropout': dropout_key},
        input_ids=dummy_input_ids,
        attention_mask=dummy_attention_mask,
        labels=dummy_labels,
        training=True,
    )
    return model, variables['params'], variables['batch_stats']

# `model` is a Python object rather than an array, so it must be a static
# argument: plain `jax.jit` would try to trace it and fail.
@functools.partial(jax.jit, static_argnames=("model",))
def train_step(state, batch, model):
    """
    Performs a single training step.

    Args:
        state: Current TrainState containing model parameters, optimizer state,
            `batch_stats` and the PRNG key `rng`.
        batch: Dictionary containing 'input_ids', 'attention_mask', and 'labels'.
        model: The SPCNNClassifier instance (hashable; treated as static by jit).

    Returns:
        A new TrainState, the computed loss, and predictions.
    """
    key, dropout_key = jax.random.split(state.rng) # Split RNG for next step and dropout

    def loss_fn(params):
        # 'params' holds the trainable weights, 'batch_stats' the BatchNorm state.
        variables = {'params': params, 'batch_stats': state.batch_stats}

        # With `mutable=[...]`, `apply` returns `(model_output, mutated_variables)`.
        # The model output is itself `(loss, predictions)`, so both levels are
        # unpacked; unpacking only one level would bind `loss` to the output
        # tuple and `predictions` to the mutated variables.
        (loss, predictions), mutated = model.apply(
            variables,
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
            training=True,
            mutable=['batch_stats'],
            rngs={'dropout': dropout_key}
        )
        # Return the UPDATED BatchNorm statistics, not the ones passed in.
        return loss, (predictions, mutated['batch_stats'])

    # Compute loss and gradients
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (predictions, updated_batch_stats)), grads = grad_fn(state.params)

    # `apply_gradients` runs the optimizer update and applies it to the params;
    # the extra keyword arguments replace the matching fields of the state.
    new_state = state.apply_gradients(
        grads=grads,
        batch_stats=updated_batch_stats, # Update BatchNorm stats
        rng=key # Update the PRNGKey for the next step
    )
    return new_state, loss, predictions

# `model` is a Python object rather than an array, so it must be a static
# argument: plain `jax.jit` would try to trace it and fail.
@functools.partial(jax.jit, static_argnames=("model",))
def eval_step(state, batch, model):
    """
    Performs a single evaluation step (no gradient updates).

    Args:
        state: Current TrainState (used for parameters and batch_stats).
        batch: Dictionary containing 'input_ids', 'attention_mask', and 'labels'.
        model: The SPCNNClassifier instance (hashable; treated as static by jit).

    Returns:
        The computed loss and predictions.
    """
    # Use the current parameters and the running batch statistics for evaluation.
    variables = {'params': state.params, 'batch_stats': state.batch_stats}

    # `training=False` and no `mutable` collections: `apply` returns the model
    # output directly, and no dropout RNG is supplied.
    loss, predictions = model.apply(
        variables,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['labels'],
        training=False, # Important: disable training for evaluation (e.g., for BatchNorm, dropout)
    )
    return loss, predictions

In [None]:
# helper func for gradual unfreezing
def get_encoder_mask(params, encoder_key='encoder', unfreeze_layers=None):
    """Build a boolean pytree marking which parameters are trainable.

    Args:
        params: Parameter pytree (nested dicts or similar containers).
        encoder_key: Path component that identifies encoder parameters.
        unfreeze_layers: Encoder layers to keep trainable. Each item is either a
            single path component (str) or a tuple of consecutive path
            components, e.g. ('encoder', 'block', '0'). None freezes the whole
            encoder.

    Returns:
        A pytree with the structure of `params`. A leaf is True (trainable) when
        its path is outside the encoder or matches one of `unfreeze_layers`, and
        False (frozen) otherwise.
    """
    def path_names(path):
        # tree_map_with_path passes key-entry objects (dict key, attribute name,
        # sequence index), not strings, so a plain `'encoder' in path` never
        # matches. Convert every entry to its string name first.
        names = []
        for entry in path:
            for attr in ('key', 'name', 'idx'):
                if hasattr(entry, attr):
                    names.append(str(getattr(entry, attr)))
                    break
            else:
                names.append(str(entry))
        return tuple(names)

    def matches(names, layer):
        if isinstance(layer, str):
            return layer in names
        # A tuple must appear as a consecutive run inside the path.
        run = tuple(str(part) for part in layer)
        width = len(run)
        return any(
            names[start:start + width] == run
            for start in range(len(names) - width + 1)
        )

    def mask_fn(path, _):
        names = path_names(path)
        if encoder_key in names:
            if unfreeze_layers is None:
                return False
            return any(matches(names, layer) for layer in unfreeze_layers)
        return True

    return jax.tree_util.tree_map_with_path(mask_fn, params)

In [None]:

#@title Training and Evaluation Loop
if __name__ == '__main__':
    # Hyperparameters. Values shared with the configuration cell are read from it
    # instead of being repeated. The unfreeze schedule below has a step at
    # epoch index 6, which is only reachable with more than 6 epochs, so the
    # configured EPOCHS is used.
    key = jax.random.PRNGKey(0) # Initial PRNG key
    init_key, state_key = jax.random.split(key) # separate keys for init and training
    num_labels = NUM_CLASSES
    batch_size = BATCH_SIZE
    seq_len = MAX_LENGTH
    learning_rate = 1e-4
    num_epochs = EPOCHS

    # Epoch index -> (encoder layers trainable from that epoch on, log message).
    unfreeze_schedule = {
        1: ([('encoder', 'block', '0')], "Unfroze encoder block 0."),
        4: ([('encoder', 'block', str(i)) for i in range(3)], "Unfroze 0-2 encoder blocks."),
        6: ([('encoder', 'block', str(i)) for i in range(4)], "Unfroze 0-3 encoder blocks."),
    }

    def build_optimizer(trainable_mask):
        """AdamW for parameters marked True in the mask, zero updates for the rest.

        `optax.masked` is not used for freezing: it passes the updates of
        masked-out parameters through untransformed (as raw gradients) rather
        than zeroing them.
        """
        param_labels = jax.tree_util.tree_map(
            lambda trainable: 'trainable' if trainable else 'frozen', trainable_mask
        )
        return optax.multi_transform(
            {'trainable': optax.adamw(learning_rate), 'frozen': optax.set_to_zero()},
            param_labels,
        )

    print(f"Initializing model with encoder: {ENCODER}")

    # Initialize model and parameters
    model, initial_params, initial_batch_stats = create_model_and_params(
        init_key, ENCODER, num_labels, batch_size, seq_len
    )

    # AdamW with the encoder frozen initially (the mask is now actually applied).
    # NOTE(review): freezing relies on encoder weights appearing under an
    # 'encoder' path in `initial_params` — verify that they do.
    mask = get_encoder_mask(initial_params, encoder_key='encoder', unfreeze_layers=None)
    tx = build_optimizer(mask)

    # Create initial training state. `create` initialises the optimizer state
    # itself, so `opt_state` must not be passed here as well.
    state = TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=tx,
        batch_stats=initial_batch_stats,
        rng=state_key # Initial PRNGKey for training steps
    )

    print("Model and optimizer initialized successfully.")

    # Training loop using real data
    for epoch in range(num_epochs):

        # Unfreeze layers gradually
        if epoch in unfreeze_schedule:
            unfreeze_layers, message = unfreeze_schedule[epoch]
            mask = get_encoder_mask(state.params, encoder_key='encoder', unfreeze_layers=unfreeze_layers)
            new_tx = build_optimizer(mask)
            # The optimizer state depends on which parameters are trainable, so
            # it is rebuilt together with the optimizer (AdamW moments restart).
            state = state.replace(tx=new_tx, opt_state=new_tx.init(state.params))
            print(message)

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        train_losses = []
        for batch in train_batches:
            state, loss, predictions = train_step(state, batch, model)
            train_losses.append(loss)
        print(f"  Train Loss: {np.mean(train_losses):.4f}")

        eval_losses = []
        for batch in test_batches:
            eval_loss, eval_predictions = eval_step(state, batch, model)
            eval_losses.append(eval_loss)
        print(f"  Eval Loss: {np.mean(eval_losses):.4f}")

    print("\nTraining complete!")

    # Smoke test of inference (no labels passed). TODO: maybe remove.
    print("\nPerforming inference on a new sample:")
    inference_batch = test_batches[0]  # Use the first test batch as an example
    inference_predictions = model.apply(
        {'params': state.params, 'batch_stats': state.batch_stats},
        input_ids=inference_batch['input_ids'],
        attention_mask=inference_batch['attention_mask'],
        training=False
    )
    print(f"Inference prediction for a new sample (first sample): {inference_predictions[0]}")

