<a href="https://colab.research.google.com/github/jonas-tfo/sp-prediction-models/blob/main/2state/2state_sp_classifier_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
!pip install transformers
!pip install jax
!pip install flax
!pip install optax
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import jax.numpy as jnp
import jax
from flax import linen as nn
import optax
from sklearn.utils import resample
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import pandas as pd

## Constants and setup

In [None]:
# Mount Google Drive so the dataset stored there is reachable from the Colab runtime.
from google.colab import drive
import os
drive.mount('/content/drive')
DRIVE_PATH = "/content/drive/MyDrive/PBLRost/"
# Input file that is later passed to load_and_prep_data.
FASTA_PATH = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")

# Prefer the GPU when one is available; the encoder weights are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# do_lower_case=False: the tokenizer must not lower-case the amino-acid letters.
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
encoder = AutoModel.from_pretrained("Rostlab/prot_bert").to(device)

In [None]:
def encode_sequence(seq):
    """Embed one amino-acid sequence with the pretrained encoder.

    Args:
        seq: Amino-acid sequence as a plain string of one-letter codes.

    Returns:
        A NumPy array with one embedding row per residue, taken from the
        encoder's last hidden state with the first ([CLS]) and last ([SEP])
        positions removed so the rows line up with the per-residue labels.
    """
    seq = " ".join(seq)  # insert spaces between amino acids
    # The tokenizer returns CPU tensors. Move them to the encoder's device,
    # otherwise the forward pass fails whenever the encoder lives on the GPU.
    tokens = tokenizer(seq, return_tensors="pt").to(device)
    with torch.no_grad():
        output = encoder(**tokens)
    embedding = output.last_hidden_state  # [1, seq_len, 1024]
    return embedding[0, 1:-1].cpu().numpy()

In [None]:
from torch.nn.utils.rnn import pad_sequence

def pad_array(arr, length, pad_value=0):
    """Pad ``arr`` at the end of its first axis until it has ``length`` entries.

    Only the first axis grows; all trailing axes are left untouched, so the
    function works for label vectors (1-D), embedding matrices (2-D) and any
    higher-rank array alike.

    Args:
        arr: Array (or array-like) with at least one dimension.
        length: Target size of the first axis.
        pad_value: Constant used to fill the added entries.

    Returns:
        A new NumPy array whose first axis has size ``length``.

    Raises:
        ValueError: If ``arr`` is zero-dimensional or already longer than
            ``length``.
    """
    arr = np.asarray(arr)
    if arr.ndim == 0:
        raise ValueError("pad_array expects an array with at least one dimension")

    missing = length - arr.shape[0]
    if missing < 0:
        raise ValueError(
            f"cannot pad an array of length {arr.shape[0]} to the shorter length {length}"
        )

    # (before, after) padding per axis: extend only the end of axis 0.
    pad_width = [(0, missing)] + [(0, 0)] * (arr.ndim - 1)
    return np.pad(arr, pad_width, constant_values=pad_value)


In [None]:
def _parse_fasta_records(data_path):
    """Read header/sequence/label triplets from the three-line FASTA-style file."""
    records = []  # uniprot_ac, kingdom, type, sequence, label
    current_record = None

    def _finish(record):
        # Keep a record only once both its sequence and its label line were seen.
        if record is None:
            return
        if record["sequence"] is not None and record["label"] is not None:
            records.append(record)
        else:
            print("Skipping incomplete record:", record)

    with open(data_path, "r") as f:
        for raw_line in f:
            if raw_line.startswith(">"):
                _finish(current_record)
                uniprot_ac, kingdom, type_ = raw_line[1:].strip().split("|")
                current_record = {
                    "uniprot_ac": uniprot_ac,
                    "kingdom": kingdom,
                    "type": type_,
                    "sequence": None,
                    "label": None,
                }
                continue

            line = raw_line.strip()
            if not line:
                continue  # blank lines carry no data
            if current_record is None:
                print("Skipping line before first header:", line)
            elif current_record["sequence"] is None:
                current_record["sequence"] = line
            elif current_record["label"] is None:
                current_record["label"] = line
            else:
                print("Skipping extra line in record:", current_record)
        _finish(current_record)

    return records


def load_and_prep_data(dataPath: str):
    """Load the dataset and return padded, embedded train/test arrays.

    Pipeline:
      1. Parse ``>uniprot_ac|kingdom|type`` headers, each followed by one
         sequence line and one per-residue label line.
      2. Drop records whose *label* string contains 'P' (a state that
         ``label_map`` does not cover) and map the remaining label characters
         to 0/1. Records whose mapped label is empty or no longer matches the
         sequence length are dropped, because they cannot be aligned
         residue-by-residue.
      3. Split 80/20 into train and test FIRST, then oversample the non-NO_SP
         records of the training split only, so that no duplicated record can
         end up on both sides of the split. The test split keeps its natural
         class distribution.
      4. Embed every distinct sequence once with ``encode_sequence`` and pad
         embeddings and labels with zeros to the longest sequence.

    Args:
        dataPath: Path to the three-line FASTA-style input file.

    Returns:
        ``(train_seqs, test_seqs, train_types, test_types)`` as JAX arrays:
        the two ``*_seqs`` arrays hold padded embeddings, the two ``*_types``
        arrays hold the padded per-residue 0/1 labels.

    Raises:
        ValueError: If a header does not have exactly three '|'-separated
            fields, or if no usable record remains after filtering.
    """
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.utils import resample

    records = _parse_fasta_records(dataPath)
    print(f"Total records: {len(records)}")
    df_raw = pd.DataFrame(
        records, columns=["uniprot_ac", "kingdom", "type", "sequence", "label"]
    )
    df_raw.dropna(subset=['sequence', 'label', 'type'], inplace=True)

    # Filter on the LABEL column: 'P' has no entry in label_map below.
    # Filtering on the sequence column would instead discard every protein
    # that contains a proline residue, i.e. almost the whole dataset.
    df = df_raw[~df_raw["label"].str.contains("P")].copy()

    label_map = {'S': 1, 'T': 1, 'L': 1, 'I': 0, 'M': 0, 'O': 0}
    df["label"] = df["label"].apply(lambda x: [label_map[c] for c in x if c in label_map])

    # encode_sequence yields one embedding row per residue, so a label list of
    # a different length cannot be aligned with its sequence.
    label_lengths = df["label"].map(len)
    aligned = label_lengths == df["sequence"].map(len)
    if not aligned.all():
        print(f"Skipping {int((~aligned).sum())} records whose label and sequence lengths differ")
    df = df[aligned & (label_lengths > 0)]

    if df.empty:
        raise ValueError(f"No usable records found in {dataPath!r}")

    # Split before oversampling to keep the test set free of training duplicates.
    df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)

    train_majority = df_train[df_train["type"] == "NO_SP"]
    train_minority = df_train[df_train["type"] != "NO_SP"]
    if len(train_majority) > 0 and len(train_minority) > 0:
        # Upsample minority class (training split only)
        train_minority = resample(train_minority,
                                  replace=True,
                                  n_samples=len(train_majority),
                                  random_state=42)
    df_train = pd.concat([train_majority, train_minority])
    df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)

    print(f"Training records after oversampling: {len(df_train)}")
    print("Training class distribution after oversampling:")
    print(df_train["type"].value_counts())

    print("Encoding...")

    # Oversampling repeats sequences; embed each distinct sequence only once.
    embedding_cache = {}
    for seq in pd.concat([df_train["sequence"], df_test["sequence"]]).unique():
        embedding_cache[seq] = encode_sequence(seq)

    # One shared length so train and test arrays have identical shapes.
    max_len = max(len(emb) for emb in embedding_cache.values())

    def _to_arrays(frame):
        # Stack zero-padded embeddings and labels of one split.
        X = np.stack([pad_array(embedding_cache[seq], max_len) for seq in frame["sequence"]])
        Y = np.stack([pad_array(np.array(lbl), max_len) for lbl in frame["label"]])
        return jnp.array(X), jnp.array(Y)

    train_seqs, train_types = _to_arrays(df_train)
    test_seqs, test_types = _to_arrays(df_test)

    print(f"Training set size: {len(train_seqs)}")
    print(f"Test set size: {len(test_seqs)}")

    return train_seqs, test_seqs, train_types, test_types

# Usage: build the padded train/test arrays once; the cells below read them as globals.
# Note: train_types/test_types receive the label arrays (the split of Y inside
# load_and_prep_data), not the record "type" column.
train_seqs, test_seqs, train_types, test_types = load_and_prep_data(FASTA_PATH)

In [None]:
class PerResidueClassifier(nn.Module):
    """Dense -> ReLU -> Dense(1) head that emits a single logit per position."""

    # Number of units in the hidden Dense layer.
    hidden_size: int = 256

    @nn.compact
    def __call__(self, x):
        """Return logits with the trailing size-1 feature axis squeezed away."""
        hidden = nn.relu(nn.Dense(features=self.hidden_size)(x))
        logits = nn.Dense(features=1)(hidden)
        # Drop the singleton output axis -> shape: [batch, seq_len]
        return jnp.squeeze(logits, axis=-1)



In [None]:
# STEP 5: Training loop
model = PerResidueClassifier()
rng = jax.random.PRNGKey(0)
# The Dense parameter shapes depend only on the feature (last) dimension, so a
# single example is enough to initialise them; tracing the whole training set
# here would run a needless forward pass over the full dataset.
params = model.init(rng, train_seqs[:1])
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def loss_fn(params, X, Y):
    """Mean sigmoid binary cross-entropy over the real (non-padded) residues.

    Args:
        params: Model parameters as produced by ``model.init``.
        X: Embeddings, shape [batch, seq_len, features]; padded positions are
            all-zero rows (the data pipeline pads embeddings with 0).
        Y: Per-residue 0/1 labels, shape [batch, seq_len].

    Returns:
        Scalar loss averaged over the non-padded positions only.
    """
    logits = model.apply(params, X)
    per_residue = optax.sigmoid_binary_cross_entropy(logits, Y)
    # Padded positions are labelled 0 and would otherwise dilute the loss and
    # reward the model for predicting "0" on padding. A row counts as padding
    # when every feature in it is exactly zero.
    valid = jnp.any(X != 0, axis=-1)
    masked_sum = jnp.where(valid, per_residue, 0.0).sum()
    # Guard against a batch that consists of padding only.
    return masked_sum / jnp.maximum(valid.sum(), 1)

@jax.jit
def update(params, opt_state, X, Y):
    """Run one optimizer step on a batch.

    Returns the updated parameters, the updated optimizer state and the loss
    evaluated on the batch with the parameters that were passed in.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    batch_loss, gradients = loss_and_grad(params, X, Y)
    param_deltas, next_opt_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(params, param_deltas), next_opt_state, batch_loss

In [None]:
# STEP 6: Train with mini-batches and evaluate on test set
batch_size = 32

# Create a simple data loader
def data_loader(X, Y, batch_size, shuffle=True, key=None):
    """Yield ``(X_batch, Y_batch)`` mini-batches, optionally in shuffled order.

    Args:
        X: Array of inputs, indexed along axis 0.
        Y: Array of targets with the same leading size as ``X``.
        batch_size: Maximum number of examples per batch; the last batch may
            be smaller.
        shuffle: Whether to visit the examples in a random order.
        key: Optional JAX PRNG key used for the shuffle. When omitted, the
            module-level ``rng`` is split and advanced, so that consecutive
            calls (e.g. one per epoch) produce different permutations.

    Yields:
        Tuples ``(X[batch_indices], Y[batch_indices])``.
    """
    global rng
    dataset_size = X.shape[0]
    indices = jnp.arange(dataset_size)
    if shuffle:
        if key is None:
            # Reusing the same key would yield the identical "shuffled" order
            # on every call, so consume a fresh subkey each time.
            rng, key = jax.random.split(rng)
        indices = jax.random.permutation(key, indices)

    for start in range(0, dataset_size, batch_size):
        batch_indices = indices[start:start + batch_size]
        yield X[batch_indices], Y[batch_indices]

# Re-initialize parameters and optimizer state for a fresh training run
rng = jax.random.PRNGKey(0)  # fixed seed so repeated runs start from the same weights
params = model.init(rng, train_seqs[:1])  # Initialize with a small batch
opt_state = optimizer.init(params)


num_epochs = 10  # You can adjust the number of epochs

for epoch in range(num_epochs):
    total_loss = 0.0
    count = 0
    for batch_X, batch_Y in data_loader(train_seqs, train_types, batch_size):
        params, opt_state, loss = update(params, opt_state, batch_X, batch_Y)
        # Weight each batch loss by its size so the smaller last batch does
        # not skew the epoch average.
        total_loss += float(loss) * batch_X.shape[0]
        count += batch_X.shape[0]

    avg_loss = total_loss / max(count, 1)  # guard against an empty training set
    print(f"Epoch {epoch+1}, Training Loss: {avg_loss:.4f}")

# STEP 7: Predict and evaluate on the test set
test_logits = model.apply(params, test_seqs)
test_preds = (jax.nn.sigmoid(test_logits) > 0.5).astype(int)

# Padded positions (all-zero embedding rows, labelled 0) are not real residues;
# counting them would inflate the accuracy, so they are masked out.
valid_mask = jnp.any(test_seqs != 0, axis=-1)
correct_predictions = jnp.sum((test_preds == test_types) & valid_mask)
total_predictions = jnp.maximum(jnp.sum(valid_mask), 1)
accuracy = correct_predictions / total_predictions

print("\nEvaluation on Test Set:")
print(f"Accuracy: {accuracy:.4f}")

# Print predictions for the first sequence in the test set (real residues only)
first_valid = valid_mask[0]
print("\nPredictions for first sequence in test set:")
print(test_preds[0][first_valid])
print("Actual labels for first sequence in test set:")
print(test_types[0][first_valid])