## Imports

In [14]:
!pip install transformers
!pip install jax
!pip install flax
!pip install optax
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import jax.numpy as jnp
import jax
from flax import linen as nn
import optax
from sklearn.utils import resample
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import pandas as pd



## Constants and setup

In [15]:
from google.colab import drive
import os
drive.mount('/content/drive')
DRIVE_PATH = "/content/drive/MyDrive/PBLRost/"
FASTA_PATH = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
encoder = AutoModel.from_pretrained("Rostlab/prot_bert").to(device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [16]:
def encode_sequence(seq):
    seq = " ".join(seq)  # insert spaces between amino acids
    tokens = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        output = encoder(**tokens)
        embedding = output.last_hidden_state  # [1, seq_len, 1024]
    print(f"Encoded sequence: {seq}")
    return embedding[0, 1:-1].cpu().numpy()  # remove [CLS] and [SEP] to match label length

In [17]:
from torch.nn.utils.rnn import pad_sequence

def pad_array(arr, length, pad_value=0):

    if len(arr.shape) == 1:
        return np.pad(arr, (0, length - len(arr)), constant_values=pad_value)
    return np.pad(arr, ((0, length - arr.shape[0]), (0, 0)), constant_values=pad_value)


In [None]:
def load_and_prep_data(dataPath: str):
    import pandas as pd

    records = []  # uniprot_ac, kingdom, type_, sequence, label
    with open(dataPath, "r") as f:
        current_record = None
        for line in f:
            if line.startswith(">"):
                if current_record is not None:
                    if current_record["sequence"] is not None and current_record["label"] is not None:
                        records.append(current_record)
                    else:
                        print("Skipping incomplete record:", current_record)
                uniprot_ac, kingdom, type_ = line[1:].strip().split("|")
                current_record = {"uniprot_ac": uniprot_ac, "kingdom": kingdom, "type": type_, "sequence": None, "label": None}
            else:
                if current_record["sequence"] is None:
                    current_record["sequence"] = line.strip()
                elif current_record["label"] is None:
                    current_record["label"] = line.strip()
                else:
                    print("Skipping extra line in record:", current_record)
        if current_record is not None:
            if current_record["sequence"] is not None and current_record["label"] is not None:
                records.append(current_record)
            else:
                print("Skipping incomplete record:", current_record)

    print(f"Total records: {len(records)}")
    df_raw = pd.DataFrame(records)
    df_raw.dropna(subset=['sequence', 'label', 'type'], inplace=True)

    # Remove records with 'P' in sequence (if needed)
    df = df_raw[~df_raw["sequence"].str.contains("P")].copy()

    df_majority = df[df["type"] == "NO_SP"]
    df_minority = df[df["type"] != "NO_SP"]

    # Upsample minority class
    from sklearn.utils import resample
    df_minority_upsampled = resample(df_minority,
                                    replace=True,
                                    n_samples=len(df_majority),
                                    random_state=42)
    df_upsampled = pd.concat([df_majority, df_minority_upsampled])
    df_upsampled = df_upsampled.sample(frac=1, random_state=42).reset_index(drop=True)

    label_map = {'S': 1, 'T': 1, 'L': 1, 'I': 0, 'M': 0, 'O': 0}
    df_encoded = df_upsampled.copy()
    df_encoded["label"] = df_encoded["label"].apply(lambda x: [label_map[c] for c in x if c in label_map])
    df_encoded = df_encoded[df_encoded["label"].map(len) > 0]

    sequences = df_encoded["sequence"].tolist()
    labels = df_encoded["label"].tolist()

    print(f"Total records after oversampling: {len(df_encoded)}")
    print("Class distribution after oversampling:")
    print(df_encoded["type"].value_counts())

    # get embeddings from bert
    encoded_seqs = [encode_sequence(seq) for seq in sequences]
    encoded_labels = [np.array(lbl) for lbl in labels]

    max_len = max(len(seq) for seq in encoded_seqs)
    hidden_dim = encoded_seqs[0].shape[1]

    X = np.stack([pad_array(seq, max_len) for seq in encoded_seqs])
    Y = np.stack([pad_array(lbl, max_len) for lbl in encoded_labels])

    X = jnp.array(X)
    Y = jnp.array(Y)

    from sklearn.model_selection import train_test_split
    train_seqs, test_seqs, train_types, test_types = train_test_split(
        X, Y, test_size=0.2, random_state=42
    )

    print(f"Training set size: {len(train_seqs)}")
    print(f"Test set size: {len(test_seqs)}")

    return train_seqs, test_seqs, train_types, test_types

# Usage:
train_seqs, test_seqs, train_types, test_types = load_and_prep_data(FASTA_PATH)

Total records: 25693
Total records after oversampling: 1710
Class distribution after oversampling:
type
NO_SP    855
LIPO     468
SP       330
PILIN     44
TAT       13
Name: count, dtype: int64
Encoded sequence: M N K Q S G M T L L E V L L A M S I F T A V A L T L M S S M Q G Q R N A I E R M R N E T L A L W I A D N Q L Q S Q D S F G E E N T S S S G K
Encoded sequence: M N Y L V V I C F A L L L M T G V E S G R D A Y I A D N L N C A Y T C G S N S Y C N T E C T K N G A V S G Y C Q W L G K Y G N A C W C I N L
Encoded sequence: M R S K K L W I S L L F A L T L I F T M A F S N M S A Q A A G K S S T E K K Y I V G F K Q T M S A M S S A K K K D V I S E K G G K V Q K Q F
Encoded sequence: M A G V R S L R C S R G C A G G C E C G D K G K C S D S S L L G K R L S E D S S R H Q L L Q K W A S M W S S M S E D A S V A D M E R A Q L E
Encoded sequence: M A K N T T N R H Y S L R K L K T G T A S V A V A L T V V G A G L V A G Q T V R A D H S D L V A E K Q R L E D L G Q K F E R L K Q R S E L Y
Encoded sequenc

In [None]:
class PerResidueClassifier(nn.Module):
    hidden_size: int = 256

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        x = x.squeeze(-1)  # shape: [batch, seq_len]
        return x



In [None]:
# STEP 5: Training loop
model = PerResidueClassifier()
rng = jax.random.PRNGKey(0)
params = model.init(rng, X)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def loss_fn(params, X, Y):
    logits = model.apply(params, X)
    loss = optax.sigmoid_binary_cross_entropy(logits, Y).mean()
    return loss

@jax.jit
def update(params, opt_state, X, Y):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, Y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

In [None]:
# STEP 6: Train
for epoch in range(10):
    params, opt_state, loss = update(params, opt_state, X, Y)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

# STEP 7: Predict
logits = model.apply(params, X)
preds = (jax.nn.sigmoid(logits) > 0.5).astype(int)

print("Predictions for first sequence:")
print(preds[0])
