In [None]:
import tqdm as notebook_tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split

MODEL_NAME = "Rostlab/prot_bert"
# Apple-silicon GPU (MPS) when available, otherwise CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {DEVICE}")

NUM_CLASSES = 2  # binary target: no signal peptide ("0") vs. signal peptide ("1")
BATCH_SIZE = 16
EPOCHS = 10
LR = 0.01  # NOTE: currently unused — the optimizer cell sets explicit per-group learning rates

# Seed torch's global RNG so head initialisation, dropout and DataLoader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

In [None]:
import pandas as pd

FASTA_PATH = "/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.fasta"


def parse_signalp_fasta(lines):
    """Parse three-line FASTA entries (header, sequence, per-residue labels) into dicts.

    Each entry is expected to look like::

        >uniprot_ac|kingdom|type
        SEQUENCE
        LABELS

    Only the first three ``|``-separated header fields are used, so headers that carry
    extra fields (e.g. a partition number) are accepted too. Blank lines are ignored.
    Entries missing a sequence or label line are reported and skipped; surplus lines
    after the label are reported and ignored; data lines before the first header are
    reported and ignored.

    Args:
        lines: iterable of text lines (e.g. an open file handle).

    Returns:
        list of dicts with keys ``uniprot_ac``, ``kingdom``, ``type``, ``sequence``, ``label``.

    Raises:
        ValueError: if a header has fewer than three ``|``-separated fields.
    """
    parsed = []
    current = None

    def _finish(record):
        # Keep a record only once both its sequence and its label line have been seen.
        if record is None:
            return
        if record["sequence"] is not None and record["label"] is not None:
            parsed.append(record)
        else:
            print("Skipping incomplete record:", record)

    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue  # blank lines would otherwise be taken as an empty sequence/label
        if line.startswith(">"):
            _finish(current)
            uniprot_ac, kingdom, type_ = line[1:].strip().split("|")[:3]
            current = {"uniprot_ac": uniprot_ac, "kingdom": kingdom, "type": type_, "sequence": None, "label": None}
        elif current is None:
            print("Skipping line before first header:", line)
        elif current["sequence"] is None:
            current["sequence"] = line
        elif current["label"] is None:
            current["label"] = line
        else:
            print("Skipping extra line in record:", current)

    _finish(current)  # the last entry has no following header to trigger its save
    return parsed


with open(FASTA_PATH, "r") as fasta_file:
    records = parse_signalp_fasta(fasta_file)

# Optional cache of the parsed table (disabled):
# df_raw.to_csv("/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.csv", index=False)

print(f"Total records: {len(records)}")
df_raw = pd.DataFrame(records)
df_raw.head()


In [None]:
# Drop entries whose residue-label string contains "P" (a letter absent from label_map in the
# next cell — presumably pilin/SPIII annotations; TODO confirm), then collapse every
# signal-peptide subtype into a single positive class: "0" = no SP, "1" = any SP.
TYPE_TO_BINARY = {"NO_SP": "0", "SP": "1", "LIPO": "1", "TAT": "1", "TATLIPO": "1"}

# .copy() gives df its own data, so the column assignment below no longer triggers
# pandas' SettingWithCopyWarning (assignment into a slice of df_raw).
df = df_raw[~df_raw["label"].str.contains("P")].copy()
# Exact-value mapping; types not listed above are left unchanged.
df["type"] = df["type"].replace(TYPE_TO_BINARY)

df.describe()

In [None]:
# Integer ids for the per-residue annotation letters; letters not in the map are dropped.
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}

df_encoded = df.copy()
df_encoded["label"] = df_encoded["label"].apply(lambda x: [label_map[c] for c in x if c in label_map])
df_encoded = df_encoded[df_encoded["label"].map(len) > 0]  # drop rows with no recognised label letters

sequences = df_encoded["sequence"].tolist()
types = df_encoded["type"].tolist()  # per-sequence type code ("0"/"1" from the previous cell) — the target

# "label" now holds Python lists. describe() on an all-object frame calls value_counts(),
# which raises TypeError on unhashable lists, so summarise the remaining columns only.
df_encoded.drop(columns="label").describe()


In [None]:
# Pretrained ProtBert tokenizer + encoder; do_lower_case=False leaves the residue letters'
# case untouched during tokenisation. The encoder is moved to DEVICE and displayed.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
encoder = BertModel.from_pretrained(MODEL_NAME).to(DEVICE)
encoder

In [None]:
# Hold out 30% of the proteins for testing, stratified by the per-sequence SP type so both
# splits keep the same SP / no-SP ratio.
# NOTE: train_label_seqs / test_label_seqs hold one type label per sequence, not residue labels.
# TODO: consider also stratifying by kingdom.
train_seqs, test_seqs, train_label_seqs, test_label_seqs = train_test_split(
    sequences, types, test_size=0.3, random_state=42, stratify=types
)

In [None]:
# Dataset: one tokenised protein per item, with per-token copies of its binary type label.
class SPDataset(Dataset):
    """Token-level dataset for binary signal-peptide classification.

    Every residue token of a protein receives the protein's sequence-level label;
    all other positions (the leading special token, the trailing special token and
    padding) receive -100.

    Assumes the tokenizer emits one token per residue between one leading and one
    trailing special token (the layout the original ``i == 0`` / ``i > len(seq)``
    logic relied on) — TODO confirm for the tokenizer in use.

    Args:
        sequences: amino-acid strings, one per protein.
        label_seqs: one label per protein, anything ``int()`` accepts (e.g. "0" / "1").
        label_map: residue-letter map; stored but not used for these binary labels.
        tokenizer: optional tokenizer; when None, the notebook-level ``tokenizer`` is used.
        max_length: tokenised length after padding/truncation (default 512).
    """

    def __init__(self, sequences, label_seqs, label_map, tokenizer=None, max_length=512):
        self.label_map = label_map
        self.label_seqs = label_seqs
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and aligned ``labels`` (all 1-D tensors)."""
        seq = self.sequences[idx]
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # Space-separate the residues so each amino acid becomes its own token.
        seq_processed = " ".join(seq)
        label_value = int(self.label_seqs[idx])
        encoded = tok(seq_processed, return_tensors="pt",
                      padding="max_length", truncation=True, max_length=self.max_length)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Residue tokens sit at positions 1..n. Truncation leaves room for the two special
        # tokens, so at most len(input_ids) - 2 residues survive; capping n there keeps the
        # trailing special token of a truncated sequence at -100 instead of labelling it.
        seq_len_tokens = input_ids.size(0)
        n_residues = max(0, min(len(seq), seq_len_tokens - 2))
        labels_tensor = torch.full((seq_len_tokens,), -100, dtype=torch.long)
        labels_tensor[1:1 + n_residues] = label_value

        return {
            'input_ids': input_ids,  # tokenized and padded
            'attention_mask': attention_mask,  # 1 for real tokens, 0 for padding
            'labels': labels_tensor  # aligned label tensor, -100 = ignore
        }

# Wrap both splits; only the training batches are shuffled.
train_dataset, test_dataset = (
    SPDataset(train_seqs, train_label_seqs, label_map),
    SPDataset(test_seqs, test_label_seqs, label_map),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF

class BinaryClassifier(nn.Module):
    """Encoder + 2-layer BiLSTM + linear emission layer + CRF that tags every token.

    Args:
        encoder_model: encoder exposing ``config.hidden_size`` and returning an object
            with ``last_hidden_state`` (here: ProtBert).
        num_labels: number of tag classes (2 for no-SP / SP).
    """

    def __init__(self, encoder_model, num_labels):
        super().__init__()
        self.encoder = encoder_model
        self.dropout = nn.Dropout(0.2)
        # Kept so the module layout stays backward compatible (it has no parameters), but no
        # longer applied: clamping CRF emission scores at 0 removes all negative evidence.
        self.relu = nn.ReLU()
        hidden_size = self.encoder.config.hidden_size

        # 2-layer bidirectional LSTM over the encoder's token embeddings.
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=512, num_layers=2, bidirectional=True, batch_first=True)
        # Per-token emission scores for the CRF.
        self.classifier = nn.Linear(512 * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the mean CRF negative log-likelihood if ``labels`` is given, else decoded tags.

        Args:
            input_ids: token ids, (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding, (batch, seq_len).
            labels: optional per-token tags, (batch, seq_len); -100 marks ignored positions.

        Returns:
            scalar loss tensor, or a list (one per sequence) of tag-id lists covering the
            positions where ``attention_mask`` is 1.
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state  # (batch, seq_len, hidden_size)

        # NOTE(review): the BiLSTM also runs over padding (no packing), so its backward
        # direction starts from padded positions; consider pack_padded_sequence.
        lstm_out, _ = self.lstm(hidden_states)  # (batch, seq_len, 1024)

        # Regularise the features rather than the scores: dropout before the linear layer,
        # and raw (unclamped) emissions into the CRF.
        emissions = self.classifier(self.dropout(lstm_out))  # (batch, seq_len, num_labels)
        crf_mask = attention_mask.bool()

        if labels is not None:
            # torchcrf has no ignore-index: map -100 to tag 0. Positions labelled -100 that lie
            # inside the attention mask (e.g. special tokens) are therefore trained as tag 0.
            crf_labels = labels.masked_fill(labels == -100, 0)
            return -self.crf(emissions, crf_labels, mask=crf_mask, reduction='mean')
        return self.crf.decode(emissions, mask=crf_mask)



In [None]:
from transformers import get_linear_schedule_with_warmup

# Build the tagger around the shared ProtBert encoder.
model = BinaryClassifier(encoder, NUM_CLASSES).to(DEVICE)

# Per-module learning rates: a tiny LR for the top 4 encoder layers, a larger one for the new head.
# NOTE: the training cell sets requires_grad=False on every encoder parameter, so the first
# group currently receives no gradients.
ENCODER_TOP_LR = 5e-6
HEAD_LR = 1e-3
param_groups = [{"params": model.encoder.encoder.layer[-4:].parameters(), "lr": ENCODER_TOP_LR}]
for head_module in (model.classifier, model.lstm, model.crf):
    param_groups.append({"params": head_module.parameters(), "lr": HEAD_LR})
optimizer = torch.optim.AdamW(param_groups)  # default weight_decay; adjust as needed

# Linear schedule with warm-up over the first 10% of all optimisation steps.
total_steps = EPOCHS * len(train_loader)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)


In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import gc

writer = SummaryWriter(log_dir="/Users/jonas/Desktop/Uni/PBL/logs/bert_sp_binary_classifier")
# The forward pass runs without autocast (plain fp32), so loss scaling only matters on CUDA.
# GradScaler() targets CUDA by default and disables itself with a warning elsewhere;
# make that explicit.
scaler = GradScaler(enabled=(DEVICE == "cuda"))


def _release_device_cache():
    """Free cached MPS allocator memory; does nothing when not running on MPS."""
    if DEVICE == "mps":
        torch.mps.empty_cache()


# Freeze the entire encoder: only the BiLSTM, linear layer and CRF receive gradients.
# (Encoder parameters in the optimizer then have no grads, and AdamW skips them.)
for param in model.encoder.parameters():
    param.requires_grad = False


for epoch in range(EPOCHS):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch")
    total_loss = 0.0  # summed loss of the batches that completed
    n_ok_batches = 0  # batches that finished without a RuntimeError

    for batch in pbar:
        try:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            token_labels = batch['labels'].to(DEVICE)

            optimizer.zero_grad()  # reset gradients

            loss = model(input_ids, attention_mask, token_labels)  # single forward pass

            scaler.scale(loss).backward()      # backpropagation
            scaler.unscale_(optimizer)         # unscale before clipping so max_norm sees true grads
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)             # update weights
            scaler.update()                    # update scaler
            scheduler.step()                   # update learning rate

            batch_loss = loss.item()
            total_loss += batch_loss
            n_ok_batches += 1
            pbar.set_postfix(loss=batch_loss)

        except RuntimeError as e:
            # Best effort (e.g. out of memory): report, free memory and skip this batch.
            print("Error during training:", e)
            gc.collect()
            _release_device_cache()

    _release_device_cache()
    if n_ok_batches == 0:
        print(f"Epoch {epoch+1}: no batch completed successfully; nothing to log.")
        continue
    # Average over the batches that actually contributed, not over len(train_loader).
    avg_train_loss = total_loss / n_ok_batches
    print(f"Epoch {epoch+1}, Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/train", avg_train_loss, epoch)

writer.flush()
writer.close()

In [None]:
# Evaluation

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, accuracy_score, classification_report

# Evaluate the model trained above. Do NOT re-create BinaryClassifier here: a new instance
# would pair the shared encoder with freshly initialised LSTM / linear / CRF weights and
# throw away everything learned during training.
from sklearn.metrics import confusion_matrix
import seaborn as sns

CLASS_IDS = [0, 1]
CLASS_NAMES = ["NO_SP", "SP"]  # type "0" -> NO_SP; every SP subtype was mapped to "1"
CONF_MATRIX_PATH = "/Users/jonas/Desktop/Uni/PBL/logs/confusion_matrix.png"

model.eval()  # disable dropout for evaluation

val_loss = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Compute loss using CRF (pass labels)
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        val_loss += loss.item()

        # Decode predictions using CRF (no labels passed) -> List[List[int]]
        predictions = model(input_ids=input_ids, attention_mask=attention_mask)

        # One device->host copy per batch instead of an .item() sync per token.
        labels_host = labels.cpu().tolist()
        mask_host = attention_mask.cpu().tolist()
        for pred_seq, label_seq, mask in zip(predictions, labels_host, mask_host):
            for pred, true, is_valid in zip(pred_seq, label_seq, mask):
                if true != -100 and is_valid == 1:
                    all_preds.append(pred)
                    all_labels.append(true)

avg_val_loss = val_loss / max(len(test_loader), 1)
print(f"Test loss (mean per batch): {avg_val_loss:.4f}")

# Report (per-token); fixed label order so both classes always appear.
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=CLASS_IDS,
                            target_names=CLASS_NAMES, zero_division=0))

# Accuracy
accuracy = accuracy_score(all_labels, all_preds)
print(f"Accuracy: {accuracy:.4f}")

# Confusion Matrix
conf_matrix = confusion_matrix(all_labels, all_preds, labels=CLASS_IDS)
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
ax.set(title="Confusion Matrix", xlabel="Predicted", ylabel="True")
os.makedirs(os.path.dirname(CONF_MATRIX_PATH), exist_ok=True)
fig.savefig(CONF_MATRIX_PATH)
plt.show()

# F1 Score weighted
f1 = f1_score(all_labels, all_preds, average='weighted')
print(f"F1 Score (weighted): {f1:.4f}")

# F1 Score macro
f1_macro = f1_score(all_labels, all_preds, average='macro')
print(f"F1 Score (macro): {f1_macro:.4f}")

# Matthews Correlation Coefficient (MCC)
mcc = matthews_corrcoef(all_labels, all_preds)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

# Log the batch-averaged test loss (same scale as "Loss/train") at the last epoch index.
writer.add_scalar("Loss/test", avg_val_loss, EPOCHS - 1)

writer.flush()
writer.close()
