In [None]:
import tqdm as notebook_tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split

MODEL_NAME = "Rostlab/prot_bert"
# Apple-silicon GPU (MPS) when available, otherwise CPU; this notebook does not probe for CUDA.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {DEVICE}")

NUM_CLASSES = 6   # per-residue classes (S, T, L, I, M, O in label_map order)
BATCH_SIZE = 44
EPOCHS = 10
LR = 0.001        # NOTE(review): high for the encoder layers that get unfrozen later — confirm
SEED = 42

# Seed torch so head initialisation, dropout and DataLoader shuffling are reproducible
# (train_test_split uses its own random_state further down).
torch.manual_seed(SEED)

In [None]:
import pandas as pd

def parse_signalp_fasta(path):
    """Parse a three-lines-per-record FASTA file into a list of record dicts.

    Expected layout of each record::

        >uniprot_ac|kingdom|type
        SEQUENCE
        LABELS

    Returns a list of dicts with keys ``uniprot_ac``, ``kingdom``, ``type``,
    ``sequence`` and ``label``.
    - Records missing a sequence or label line are reported and skipped.
    - Surplus lines inside a record are reported and ignored.
    - Blank lines are ignored.
    - Headers with more than three ``|`` fields keep the first three.

    Raises:
        ValueError: if a header has fewer than three ``|``-separated fields.
    """
    records = []
    current_record = None

    def _finish(record):
        # Keep a record only once both its sequence and its label line were seen.
        if record is None:
            return
        if record["sequence"] is not None and record["label"] is not None:
            records.append(record)
        else:
            print("Skipping incomplete record:", record)

    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise be stored as an empty sequence/label.
                continue
            if line.startswith(">"):
                _finish(current_record)
                fields = line[1:].split("|")
                if len(fields) < 3:
                    raise ValueError(f"Malformed FASTA header (expected 'ac|kingdom|type'): {line!r}")
                uniprot_ac, kingdom, type_ = fields[:3]
                current_record = {"uniprot_ac": uniprot_ac, "kingdom": kingdom, "type": type_,
                                  "sequence": None, "label": None}
            elif current_record is None:
                # Data before the first header used to crash with a TypeError.
                print("Skipping line before first header:", line)
            elif current_record["sequence"] is None:
                current_record["sequence"] = line
            elif current_record["label"] is None:
                current_record["label"] = line
            else:
                print("Skipping extra line in record:", current_record)
    # The last record has no following header to trigger _finish.
    _finish(current_record)
    return records


records = parse_signalp_fasta("/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.fasta")
print(f"Total records: {len(records)}")
df_raw = pd.DataFrame(records)

# Optional CSV cache of the parsed table (only possible once df_raw exists).
SAVE_RAW_CSV = False
if SAVE_RAW_CSV:
    df_raw.to_csv("/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.csv", index=False)

df_raw.head()


In [None]:
# Drop every protein whose label string contains 'P'; that class has no entry in label_map / NUM_CLASSES.
has_p_label = df_raw["label"].str.contains("P", regex=False)
df = df_raw.loc[~has_p_label]
df.describe()

In [None]:
# Per-residue label character -> class index ('P' proteins were removed in the previous cell).
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}

# Optional down-sampling for quick experiments: e.g. 0.4; None keeps the full dataset.
SAMPLE_FRAC = None

# Keep only proteins whose label string has exactly one *known* label per residue.
# Silently dropping unknown characters would shift every later label off its residue,
# so such rows are rejected instead of being encoded with a misaligned label list.
is_aligned = pd.Series(
    [
        len(lbl) > 0 and len(lbl) == len(seq) and all(c in label_map for c in lbl)
        for seq, lbl in zip(df["sequence"], df["label"])
    ],
    index=df.index,
    dtype=bool,
)
if not is_aligned.all():
    print(f"Dropping {(~is_aligned).sum()} records whose labels do not align with their sequence")

df_encoded = df.loc[is_aligned].copy()
df_encoded["label"] = df_encoded["label"].map(lambda lbl: [label_map[c] for c in lbl])

if SAMPLE_FRAC:
    df_encoded = df_encoded.sample(frac=SAMPLE_FRAC, random_state=42)

sequences = df_encoded["sequence"].tolist()
label_seqs = df_encoded["label"].tolist()

df_encoded.describe()


In [None]:
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
encoder = BertModel.from_pretrained(MODEL_NAME)
encoder.to(DEVICE)

In [None]:
# Stratify by sequence length to avoid ValueError
train_seqs, test_seqs, train_label_seqs, test_label_seqs = train_test_split(
    sequences, label_seqs, test_size=0.3, random_state=42
)

In [None]:
# Dataset wrapping (sequence, per-residue labels) pairs for ProtBERT token classification.
class SPDataset(Dataset):
    """Per-residue signal-peptide dataset.

    Each item is one protein. The amino-acid string is space-separated (one token per
    residue for ProtBERT), tokenized, then padded/truncated to ``max_length`` tokens.
    It is paired with a label tensor aligned to the token positions: position 0
    ([CLS]), [SEP], padding and anything truncated away get -100 so the loss ignores them.

    Args:
        sequences: amino-acid strings.
        label_seqs: per-residue integer label lists (already encoded with ``label_map``).
        label_map: label character -> class index (stored for reference).
        tokenizer: tokenizer to use; ``None`` falls back to the notebook-level ``tokenizer``.
        max_length: padded/truncated token length, including [CLS] and [SEP].

    Raises:
        ValueError: if ``sequences`` and ``label_seqs`` differ in length.
    """

    def __init__(self, sequences, label_seqs, label_map, tokenizer=None, max_length=512):
        if len(sequences) != len(label_seqs):
            raise ValueError(f"{len(sequences)} sequences but {len(label_seqs)} label sequences")
        self.label_map = label_map
        self.label_seqs = label_seqs
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        labels = self.label_seqs[idx]
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        # ProtBERT expects residues separated by spaces ("MKV" -> "M K V").
        encoded = tok(" ".join(seq), return_tensors="pt",
                      padding="max_length", truncation=True, max_length=self.max_length)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Token layout: [CLS] r1 .. rn [SEP] [PAD]...; after truncation at most
        # len(input_ids) - 2 residues survive. Labelling beyond that would put a residue
        # label on [SEP], and more labels than residues would raise an IndexError.
        n_residues = min(len(seq), len(labels), input_ids.size(0) - 2)
        labels_tensor = torch.full((input_ids.size(0),), -100, dtype=torch.long)
        if n_residues > 0:
            labels_tensor[1:n_residues + 1] = torch.as_tensor(labels[:n_residues], dtype=torch.long)

        return {
            'input_ids': input_ids,          # tokenized and padded
            'attention_mask': attention_mask,  # 1 for real tokens, 0 for padding
            'labels': labels_tensor          # aligned per-token labels, -100 = ignored
        }

# One dataset per split, built the same way.
train_dataset, test_dataset = (
    SPDataset(split_seqs, split_labels, label_map)
    for split_seqs, split_labels in ((train_seqs, train_label_seqs), (test_seqs, test_label_seqs))
)

# Only the training loader shuffles; the test loader keeps test_label_seqs order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# class for the model on top of the prot_bert encoder
# using the encoder as a protein feature extractor
# and adding a one dimensional convolutional layer on top of it 

import torch
import torch.nn as nn
import torch.nn.functional as F

class SPCNNClassifier(nn.Module):
    """Per-token classifier: a transformer encoder followed by two dilated 1D convolutions.

    The encoder's ``last_hidden_state`` (batch, seq_len, hidden) is treated as a
    ``hidden``-channel signal. conv1 mixes local context; conv2 projects to one logit per
    class at every position. Both convolutions are padded so seq_len is preserved.

    Args:
        encoder_model: encoder exposing ``config.hidden_size`` and returning an object with
            ``last_hidden_state`` (ProtBERT in this notebook).
        num_classes: logits per token; ``None`` uses the notebook's ``NUM_CLASSES``.
        conv_channels: width of the intermediate convolution.
        dropout: dropout probability on the encoder features and the conv1 activations.
    """

    def __init__(self, encoder_model, num_classes=None, conv_channels=1024, dropout=0.2):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.encoder = encoder_model
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.encoder.config.hidden_size
        # kernel 6, dilation 2 -> receptive field 11; padding 5 keeps seq_len unchanged
        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_channels,
                               kernel_size=6, dilation=2, padding=5)
        # kernel 3, dilation 2 -> receptive field 5; padding 2 keeps seq_len unchanged
        self.conv2 = nn.Conv1d(in_channels=conv_channels, out_channels=num_classes,
                               kernel_size=3, dilation=2, padding=2)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits of shape (batch, seq_len, num_classes)."""
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state  # (batch, seq_len, hidden)

        x = self.dropout(hidden_states).transpose(1, 2)  # (batch, hidden, seq_len) for Conv1d
        x = F.relu(self.conv1(x))                         # (batch, conv_channels, seq_len)
        x = self.dropout(x)
        x = self.conv2(x)                                 # (batch, num_classes, seq_len)
        # Dropout is deliberately NOT applied to the logits: randomly zeroing and rescaling
        # class scores distorts the softmax during training.
        return x.transpose(1, 2)                          # (batch, seq_len, num_classes)

In [None]:
from transformers import get_linear_schedule_with_warmup

# Initialize the model: CNN head on top of the ProtBERT encoder.
model = SPCNNClassifier(encoder).to(DEVICE)

# AdamW over all parameters; encoder params only receive updates once they are unfrozen.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.85, 0.999),
    eps=1e-6,
    weight_decay=0.01,  # decoupled weight decay (regularization)
)

# 100 warm-up steps, then linear decay to zero over the whole run (stepped per batch).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=len(train_loader) * EPOCHS,
)

# Inverse-frequency class weights against the heavy imbalance.
# Reference counts on the full set: Counter({'I': 1204001, 'O': 362643, 'S': 85526,
# 'M': 74445, 'L': 46065, 'T': 22272, 'P': 951}).
# Weight index i must belong to class index i, i.e. label_map order S, T, L, I, M, O.
# The former hard-coded list was in I, S, T, L, M, O order, so 'S' got the weight meant for
# 'I', and so on. Counting the encoded training labels keeps the order right by construction.
train_label_tensor = torch.tensor([c for seq in train_label_seqs for c in seq], dtype=torch.long)
class_counts = torch.bincount(train_label_tensor, minlength=NUM_CLASSES).clamp(min=1)
weights = (1.0 / class_counts.float()).to(DEVICE)  # float32 (MPS has no float64)

# Weighted mean loss; -100 marks [CLS]/[SEP]/padding positions that must not count.
loss_fn = nn.CrossEntropyLoss(weight=weights, ignore_index=-100)

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm


def evaluate(model, dataloader, device, criterion=None):
    """Evaluate a per-token classifier.

    Each batch must be a dict with ``input_ids``, ``attention_mask`` and ``labels`` (as
    produced by SPDataset). Label -100 marks positions that are not scored.

    Args:
        model: called as ``model(input_ids, attention_mask)``; returns logits of shape
            (batch, seq_len, num_classes).
        dataloader: iterable of batch dicts.
        device: device the batch tensors are moved to.
        criterion: loss for the validation loss; ``None`` uses the notebook's ``loss_fn``.

    Returns:
        ``(val_loss, token_acc, seq_acc, report)``:
        - ``val_loss``: mean per-batch loss.
        - ``token_acc``: accuracy over scored positions.
        - ``seq_acc``: fraction of sequences whose every scored position is correct.
        - ``report``: the sklearn classification report as a dict.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = loss_fn
    model.eval()

    total_loss = 0.0
    n_batches = 0
    n_seqs = 0
    n_seqs_correct = 0
    token_preds = []
    token_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Batches are dicts; positional unpacking would iterate over the keys.
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids, attention_mask)  # (batch, seq_len, num_classes)
            # reshape (not view): the model's logits come out of a transpose.
            total_loss += criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)).item()
            n_batches += 1

            preds = logits.argmax(dim=-1)  # (batch, seq_len)
            scored = labels != -100
            # A sequence is correct when every scored position matches its label.
            n_seqs_correct += ((preds == labels) | ~scored).all(dim=1).sum().item()
            n_seqs += labels.size(0)

            token_preds.extend(preds[scored].cpu().tolist())
            token_labels.extend(labels[scored].cpu().tolist())

    if n_batches == 0:
        raise ValueError("evaluate() got an empty dataloader")

    val_loss = total_loss / n_batches
    token_acc = accuracy_score(token_labels, token_preds)
    seq_acc = n_seqs_correct / n_seqs
    report = classification_report(token_labels, token_preds, output_dict=True, zero_division=np.nan)
    return val_loss, token_acc, seq_acc, report

In [None]:
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score

writer = SummaryWriter(log_dir="/Users/jonas/Desktop/Uni/PBL/logs/prot_bert_linear_classifier_v2")

# Gradual unfreezing: epoch -> index (counted from the end) of the encoder layer that
# becomes trainable at the start of that epoch. Only the top layers are ever unfrozen.
UNFREEZE_AT_EPOCH = {4: -1, 7: -2}


def _train_one_epoch():
    """Run one optimisation pass over train_loader and return the mean per-batch loss."""
    running_loss = 0
    for batch in train_loader:
        input_ids, attention_mask, token_labels = (
            batch[key].to(DEVICE) for key in ("input_ids", "attention_mask", "labels")
        )
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        # Flatten (batch, seq_len, classes) -> (batch*seq_len, classes); -100 positions are ignored.
        loss = loss_fn(logits.reshape(-1, NUM_CLASSES), token_labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()  # LR schedule advances once per batch
        running_loss += loss.item()
    return running_loss / len(train_loader)


# Start with the whole ProtBERT encoder frozen so only the CNN head learns.
model.train()
model.encoder.requires_grad_(False)

for epoch in range(EPOCHS):
    model.train()

    layer_idx = UNFREEZE_AT_EPOCH.get(epoch)
    if layer_idx is not None:
        model.encoder.encoder.layer[layer_idx].requires_grad_(True)

    avg_train_loss = _train_one_epoch()
    print(f"Epoch {epoch+1}, Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/train", avg_train_loss, epoch)

    # End-of-epoch evaluation on the held-out split.
    model.eval()
    val_loss, token_acc, seq_acc, report = evaluate(model, test_loader, DEVICE)
    print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}, Token Acc: {token_acc:.4f}, Seq Acc: {seq_acc:.4f}")
    writer.add_scalar("Loss/val", val_loss, epoch)
    writer.add_scalar("Accuracy/val", token_acc, epoch)
    writer.add_scalar("Seq_Accuracy/val", seq_acc, epoch)


In [None]:
# Evaluation

import matplotlib.pyplot as plt

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, matthews_corrcoef

model.eval()
val_loss = 0
all_preds = []   # predicted class per scored residue, in test_loader order
all_labels = []  # true class per scored residue, in test_loader order
class_ids = list(label_map.values())
class_names = list(label_map.keys())

with torch.no_grad():
    # test_loader is not shuffled, so the flattened order follows test_label_seqs.
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        logits = model(input_ids, attention_mask)
        preds = torch.argmax(logits, dim=-1)  # per-token predicted class

        # The model returns transposed (non-contiguous) logits, so .view() raises here;
        # .reshape() copies when needed, matching the training loop.
        loss = loss_fn(logits.reshape(-1, NUM_CLASSES), labels.reshape(-1))
        val_loss += loss.item()

        preds_flat = preds.reshape(-1)
        labels_flat = labels.reshape(-1)
        valid_idx = labels_flat != -100  # drop [CLS]/[SEP]/padding positions
        all_preds.extend(preds_flat[valid_idx].cpu().numpy())
        all_labels.extend(labels_flat[valid_idx].cpu().numpy())

print("Classification Report:")
# labels= keeps target_names aligned even if some class never occurs in labels/preds.
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

mcc = matthews_corrcoef(all_labels, all_preds)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="Blues", xticks_rotation=45)
plt.title("Confusion Matrix")
plt.show()

# TensorBoard: logged at the last training epoch's step (uses `epoch` from the training loop).
avg_val_loss = val_loss / len(test_loader)
val_acc = accuracy_score(all_labels, all_preds)
print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
writer.add_scalar("Loss/val", avg_val_loss, epoch)
writer.add_scalar("Accuracy/val", val_acc, epoch)

writer.flush()
writer.close()


In [None]:

from pathlib import Path

# Persist only the weights (state_dict); rebuild SPCNNClassifier(encoder) before loading them.
MODEL_PATH = Path("/Users/jonas/Desktop/Uni/PBL/models/prot_bert_linear_classifier_v2.pt")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)  # torch.save fails if the folder is missing
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Sequence-level accuracy: a protein counts only if all of its scored residues are correct.
def sequence_level_accuracy(preds_flat, labels_flat, test_label_seqs):
    """Fraction of sequences whose every scored position is predicted correctly.

    ``preds_flat`` and ``labels_flat`` hold the per-residue predictions and labels of all
    sequences, concatenated in dataset order. ``test_label_seqs`` supplies each sequence's
    length so the flat arrays can be cut back into sequences. Positions labelled -100
    are ignored.

    Returns:
        The accuracy as a float, or 0.0 when there are no sequences.

    Raises:
        ValueError: if the flat arrays do not hold exactly sum(len(seq)) entries. Cutting
            them anyway would silently misalign every sequence after the mismatch.
    """
    seq_lengths = [len(seq) for seq in test_label_seqs]
    expected = sum(seq_lengths)
    if len(preds_flat) != expected or len(labels_flat) != expected:
        raise ValueError(
            f"Expected {expected} flattened positions, got {len(preds_flat)} predictions "
            f"and {len(labels_flat)} labels"
        )
    if not seq_lengths:
        return 0.0

    correct = 0
    start = 0
    for length in seq_lengths:
        stop = start + length
        pairs = zip(preds_flat[start:stop], labels_flat[start:stop])
        if all(p == lab for p, lab in pairs if lab != -100):
            correct += 1
        start = stop
    return correct / len(seq_lengths)

# Fraction of test proteins whose every scored residue was predicted correctly.
acc = sequence_level_accuracy(
    preds_flat=all_preds,
    labels_flat=all_labels,
    test_label_seqs=test_label_seqs,
)
print(f"Sequence Level Accuracy: {acc:.4f}")
