## Imports

In [None]:

import numpy as np
import pandas as pd
import torch
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from transformers import AutoTokenizer, AutoModel

## Constants and setup

In [None]:
# Colab setup: make the project folder on Google Drive visible to this runtime.
import os
from google.colab import drive

drive.mount('/content/drive')

# Project root on Drive and the FASTA file read by load_and_prep_data.
DRIVE_PATH = "/content/drive/MyDrive/PBLRost/"
FASTA_PATH = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# ProtBERT tokenizer (amino-acid letters are kept as-is: do_lower_case=False)
# and the encoder whose last hidden state is used as per-residue features.
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
encoder = AutoModel.from_pretrained("Rostlab/prot_bert")
encoder = encoder.to(device)

In [None]:
def encode_sequence(seq):
    """Embed one protein sequence with ProtBERT, one vector per residue.

    Args:
        seq: amino-acid string, one letter per residue.

    Returns:
        NumPy array holding the encoder's last hidden state for this sequence
        with the leading [CLS] and trailing [SEP] positions removed, so that its
        rows line up with the per-residue labels.
    """
    # Insert spaces between amino acids; the [1:-1] slice below relies on the
    # tokenizer then producing exactly one token per residue.
    spaced = " ".join(seq)
    tokens = tokenizer(spaced, return_tensors="pt")
    # The tokenizer returns CPU tensors, but `encoder` was moved to `device`.
    # Feeding CPU inputs to a CUDA model raises a device-mismatch error.
    tokens = {name: tensor.to(device) for name, tensor in tokens.items()}
    with torch.no_grad():
        hidden = encoder(**tokens).last_hidden_state  # [1, seq_len, 1024]
    # Drop [CLS] (first) and [SEP] (last) to match the label length.
    return hidden[0, 1:-1].cpu().numpy()

In [None]:
from torch.nn.utils.rnn import pad_sequence

def pad_array(arr, length, pad_value=0):
    """Right-pad ``arr`` along its first axis to ``length`` entries.

    Works for 1-D label vectors, 2-D [seq_len, hidden] embeddings, and any
    higher-rank array. Trailing axes are left untouched.

    Args:
        arr: NumPy array (or array-like) to pad.
        length: target size of axis 0; must be >= ``arr.shape[0]``.
        pad_value: constant written into the added positions.

    Returns:
        New array of shape ``(length, *arr.shape[1:])``.

    Raises:
        ValueError: if ``arr`` is already longer than ``length`` along axis 0.
    """
    arr = np.asarray(arr)
    extra = length - arr.shape[0]
    if extra < 0:
        raise ValueError(
            f"cannot pad array with {arr.shape[0]} rows to shorter length {length}"
        )
    # Pad only the end of axis 0; every other axis gets (0, 0).
    pad_width = [(0, extra)] + [(0, 0)] * (arr.ndim - 1)
    return np.pad(arr, pad_width, constant_values=pad_value)


In [None]:
def _parse_labelled_fasta(data_path):
    """Parse '>uniprot_ac|kingdom|type' records into dicts.

    Within a record, the first non-blank line is taken as the sequence and the
    second as the per-residue label string; further lines are ignored. Records
    whose header does not split into exactly three '|' fields are skipped.
    """
    records = []
    current = None
    with open(data_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if current is not None:
                    records.append(current)
                header = line[1:].split("|")
                if len(header) == 3:
                    current = {
                        "uniprot_ac": header[0],
                        "kingdom": header[1],
                        "type": header[2],
                        "sequence": "",
                        "label": "",
                    }
                else:
                    current = None
            elif current is not None:
                if not current["sequence"]:
                    current["sequence"] = line
                elif not current["label"]:
                    current["label"] = line
    if current is not None:
        records.append(current)
    return records


def _oversample_minority(df, random_state):
    """Upsample non-"NO_SP" rows (with replacement) to the "NO_SP" count, then shuffle."""
    majority = df[df["type"] == "NO_SP"]
    minority = df[df["type"] != "NO_SP"]
    # resample cannot draw from an empty frame; with one class missing there is
    # nothing to balance, so only shuffle.
    if not majority.empty and not minority.empty:
        minority = resample(minority,
                            replace=True,                 # sample with replacement
                            n_samples=len(majority),      # match the majority class
                            random_state=random_state)    # reproducible results
    combined = pd.concat([majority, minority])
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


def load_and_prep_data(dataPath: str, test_size: float = 0.2, random_state: int = 42):
    """Load a labelled FASTA file and build padded train/test arrays.

    Expected record layout (inferred from how labels are used here and in
    encode_sequence; confirm against the data file)::

        >uniprot_ac|kingdom|type
        SEQUENCE
        PER-RESIDUE LABEL STRING (same length as SEQUENCE)

    Pipeline:
    1. Parse the file.
    2. Keep records whose labels are fully mappable and aligned.
    3. Split into train/test.
    4. Oversample non-"NO_SP" records in the training split only.
    5. Embed with encode_sequence.
    6. Zero-pad everything to one common length.

    Args:
        dataPath: path to the FASTA file.
        test_size: fraction of records held out for testing.
        random_state: seed for the split, the oversampling and the shuffles.

    Returns:
        (X_train, X_test, Y_train, Y_test) as jax arrays:
        - X is [n, max_len, hidden] embeddings.
        - Y is [n, max_len] 0/1 per-residue labels.
        Both are zero-padded past each sequence's end.

    Raises:
        ValueError: if the file yields no usable record.
    """
    label_map = {'S': 1, 'T': 1, 'L': 1, 'I': 0, 'M': 0, 'O': 0}

    df = pd.DataFrame(
        _parse_labelled_fasta(dataPath),
        columns=["uniprot_ac", "kingdom", "type", "sequence", "label"],
    )

    # Keep only records whose label string is non-empty, uses only label_map
    # letters, and has exactly one letter per residue. This replaces the old
    # filter on 'P' in the *sequence* ('P' is proline, present in nearly every
    # protein). 'P' *labels* are now dropped here because they are not in
    # label_map; previously they were silently removed from the label list,
    # misaligning it with the residues.
    mappable = df["label"].map(lambda lab: bool(lab) and set(lab).issubset(label_map))
    aligned = df["label"].str.len() == df["sequence"].str.len()
    df = df[mappable & aligned].reset_index(drop=True)
    if df.empty:
        raise ValueError(f"No usable labelled records found in {dataPath!r}")

    # Split BEFORE oversampling so duplicated minority records cannot appear in
    # both train and test, which would inflate test scores.
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    train_df = _oversample_minority(train_df, random_state)

    print(f"Training records after oversampling: {len(train_df)}")
    print("Class distribution after oversampling:")
    print(train_df["type"].value_counts())

    # Embed each distinct sequence once; oversampled duplicates reuse the result.
    cache = {}

    def embed(seq):
        if seq not in cache:
            cache[seq] = encode_sequence(seq)
        return cache[seq]

    train_emb = [embed(s) for s in train_df["sequence"]]
    test_emb = [embed(s) for s in test_df["sequence"]]
    max_len = max(e.shape[0] for e in train_emb + test_emb)

    def to_arrays(embeddings, frame):
        X = np.stack([pad_array(e, max_len) for e in embeddings])
        Y = np.stack([
            pad_array(np.array([label_map[c] for c in lab]), max_len)
            for lab in frame["label"]
        ])
        return jnp.array(X), jnp.array(Y)  # [n, max_len, hidden], [n, max_len]

    X_train, Y_train = to_arrays(train_emb, train_df)
    X_test, Y_test = to_arrays(test_emb, test_df)

    print(f"Training set size: {len(X_train)}")
    print(f"Test set size: {len(X_test)}")

    return X_train, X_test, Y_train, Y_test

train_seqs, test_seqs, train_types, test_types = load_and_prep_data(FASTA_PATH)


In [None]:
class PerResidueClassifier(nn.Module):
    """Small MLP head producing one logit per position.

    Dense(hidden_size) -> ReLU -> Dense(1) on the last (feature) axis. The
    trailing singleton axis is then dropped, so an input shaped
    [batch, seq_len, features] yields logits shaped [batch, seq_len].
    """

    # Width of the hidden Dense layer.
    hidden_size: int = 256

    @nn.compact
    def __call__(self, x):
        hidden = nn.relu(nn.Dense(self.hidden_size)(x))
        logits = nn.Dense(1)(hidden)
        # One logit per residue: [batch, seq_len, 1] -> [batch, seq_len].
        return jnp.squeeze(logits, axis=-1)



In [None]:
# STEP 5: Training loop
model = PerResidueClassifier()
rng = jax.random.PRNGKey(0)
# Dense kernels are sized from the last (feature) axis only, so one example is
# enough for init. It gives the same params as the full training set without a
# forward pass over every sequence.
params = model.init(rng, train_seqs[:1])
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)


def _padding_mask(X):
    """Return a [batch, seq_len] boolean mask that is True on real residues.

    pad_array fills padded positions with all-zero rows. A real ProtBERT
    embedding row is assumed never to be exactly all-zero.
    """
    return jnp.any(X != 0, axis=-1)


@jax.jit
def loss_fn(params, X, Y):
    """Mean sigmoid binary cross-entropy over real (non-padded) residues.

    Padded positions carry label 0. Including them biased training toward
    predicting 0, so they are masked out of the average.
    """
    logits = model.apply(params, X)
    per_residue = optax.sigmoid_binary_cross_entropy(logits, Y.astype(logits.dtype))
    mask = _padding_mask(X).astype(logits.dtype)
    # max(..., 1) avoids 0/0 on an all-padding batch.
    return (per_residue * mask).sum() / jnp.maximum(mask.sum(), 1.0)


@jax.jit
def update(params, opt_state, X, Y):
    """Take one Adam step on the masked loss; return (new_params, opt_state, loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, X, Y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

In [None]:
# STEP 6: Train
# Full-batch training: each epoch is a single `update` step over the whole
# training set, after which that step's loss is reported.
for epoch in range(10):
    epoch_number = epoch + 1
    params, opt_state, loss = update(params, opt_state, train_seqs, train_types)
    print(f"Epoch {epoch_number}, Loss: {float(loss):.4f}")



In [None]:
# STEP 7: Predict on test set
logits = model.apply(params, test_seqs)
# Threshold per-residue probabilities at 0.5 -> 0/1 predictions [n_test, max_len].
preds = (jax.nn.sigmoid(logits) > 0.5).astype(int)

# pad_array filled padded positions with all-zero embedding rows; those are not
# residues, so exclude them from the accuracy and from the printed prediction.
test_mask = jnp.any(test_seqs != 0, axis=-1)
n_real = int(test_mask.sum())
test_acc = float(((preds == test_types) & test_mask).sum()) / max(n_real, 1)
print(f"Per-residue accuracy on test set (padding excluded): {test_acc:.4f}")

print("Predictions for first test sequence:")
print(preds[0][test_mask[0]])