# Binary classification of sequences into SP or Non-SP using transformer embedding and Gradient Boosting (XGBoost)

#### Imports and dependancies

In [None]:
%pip install transformers
%pip install xgboost
import pandas as  pd
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
import numpy as np
import xgboost as xgb
from sklearn.metrics import accuracy_score, matthews_corrcoef, f1_score, precision_score, recall_score, roc_auc_score
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
import os




#### Constants

In [None]:
from google.colab import drive
import os
drive.mount('/content/drive')
DRIVE_PATH = "/content/drive/MyDrive/PBLRost/"
FASTA_PATH = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
encoder = AutoModel.from_pretrained("Rostlab/prot_bert").to(device)

#### Embedding function using prot_bert transformer

In [16]:
def get_protbert_embeddings_batched(sequences, batch_size=16):
    embeddings = []

    # space-separated amino acids (ProtBERT requirement)
    formatted = [" ".join(list(seq)) for seq in sequences]

    with torch.no_grad():
        for i in tqdm(range(0, len(formatted), batch_size)):
            batch_seqs = formatted[i:i+batch_size]
            encoded = tokenizer(batch_seqs, return_tensors="pt", padding=True, truncation=True, max_length=512)

            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            batch_embs = outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # average pooling
            embeddings.extend(batch_embs)

    return np.array(embeddings)

#### Data loading and prep (including embedding)

In [None]:
def load_and_prep_data(dataPath: str):
    records = []
    with open(dataPath, "r") as f:
        current_record = {}
        for line in f:
            if line.startswith(">"):
                if current_record:
                    records.append(current_record)
                header = line[1:].strip().split("|")
                if len(header) == 3:
                    current_record = {
                        "uniprot_ac": header[0],
                        "kingdom": header[1],
                        "type": header[2],
                        "sequence": ""
                    }
                else:
                    current_record = {}
            elif current_record:
                if not current_record.get("sequence"):
                    current_record["sequence"] = line.strip()
    if current_record:
        records.append(current_record)
    df_raw = pd.DataFrame(records)

    # drop na rows
    df_raw.dropna(subset=['sequence', 'type'], inplace=True)

    # Remove records with 'P' in sequence (if needed)
    df = df_raw[~df_raw["sequence"].str.contains("P")].copy()

    # binary labeling
    df["type"] = df["type"].replace({
        "NO_SP": "0",
        "LIPO": "1",
        "SP": "1",
        "TAT": "1",
        "TATLIPO": "1"
    })

    df_majority = df[df["type"] == "0"]
    df_minority = df[df["type"] == "1"]

    # randomly oversample the data to equalize the NO_SP to SP ratio
    if not df_minority.empty and not df_majority.empty:
        df_minority_upsampled = resample(
            df_minority,
            replace=True,
            n_samples=len(df_majority),
            random_state=42
        )
        df_balanced = pd.concat([df_majority, df_minority_upsampled])
    else:
        df_balanced = df.copy()

    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"Total records after oversampling: {len(df_balanced)}")
    print("Class distribution after oversampling:")
    print(df_balanced["type"].value_counts())

    """
    # randomly undersample the majority class to match the minority class size
    if not df_minority.empty and not df_majority.empty:
        df_majority_undersampled = resample(
            df_majority,
            replace=False,
            n_samples=len(df_minority),
            random_state=42
        )
        df_balanced = pd.concat([df_majority_undersampled, df_minority])
    else:
        df_balanced = df.copy()

    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"Total records after undersampling: {len(df_balanced)}")
    print("Class distribution after undersampling:")
    print(df_balanced["type"].value_counts())
    """

    sequences = df_balanced["sequence"].tolist()
    labels = df_balanced["type"].astype(int).tolist()

    # get embeddings from bert
    finSeqs = get_protbert_embeddings_batched(sequences, batch_size=16)
    finLabels = np.array(labels)

    # each part
    train_seqs, test_seqs, train_types, test_types = train_test_split(
        finSeqs, finLabels, test_size=0.2, random_state=42, stratify=finLabels
    )

    print(f"Training set size: {len(train_seqs)}")
    print(f"Test set size: {len(test_seqs)}")

    return train_seqs, test_seqs, train_types, test_types

train_seqs, test_seqs, train_types, test_types = load_and_prep_data(FASTA_PATH)


Total records after oversampling: 1710
Class distribution after oversampling:
type
1    855
0    855
Name: count, dtype: int64


100%|██████████| 107/107 [00:34<00:00,  3.09it/s]

Training set size: 1368
Test set size: 342





#### Model def and training

In [18]:
model = xgb.XGBClassifier(
    objective='binary:logistic',
    eval_metric='logloss',
    use_label_encoder=False,
    random_state=42
)

model.fit(train_seqs, train_types)

Parameters: { "use_label_encoder" } are not used.




#### Model eval

In [19]:
pred_types = model.predict(test_seqs)
print("Accuracy:", round(accuracy_score(test_types, pred_types), 3))

# Precision, Recall, F1
print("Precision:", round(precision_score(test_types, pred_types), 3))
print("Recall:", round(recall_score(test_types, pred_types), 3))
print("F1 Score:", round(f1_score(test_types, pred_types), 3))

# Matthews Correlation Coefficient
print("Matthews Correlation Coefficient:", round(matthews_corrcoef(test_types, pred_types), 3))

if hasattr(model, "predict_proba"):
    proba = model.predict_proba(test_seqs)[:, 1]
    print("ROC AUC:", round(roc_auc_score(test_types, proba), 3))

Accuracy: 0.994
Precision: 0.994
Recall: 0.994
F1 Score: 0.994
Matthews Correlation Coefficient: 0.988
ROC AUC: 1.0
