In [31]:
import tqdm as notebook_tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from torch.utils.tensorboard import SummaryWriter

MODEL_NAME = "Rostlab/prot_bert" 
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" # using mps instead of cuda for training on mac
print(f"Using device: {DEVICE}")
NUM_CLASSES = 6  # num classes for classification
BATCH_SIZE = 64
EPOCHS = 7
LR = 5e-5

Using device: mps


In [32]:
import pandas as pd

records = []  # uniprot_ac, kingdom, type_, sequence, label
with open("/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.fasta", "r") as f:
    current_record = None
    for line in f:
        if line.startswith(">"):
            if current_record is not None:
                if current_record["sequence"] is not None and current_record["label"] is not None:
                    # Save the previous record before starting a new one
                    records.append(current_record)
                else:
                    # If the previous record is incomplete, skip it
                    print("Skipping incomplete record:", current_record)
            # Start a new record
            uniprot_ac, kingdom, type_ = line[1:].strip().split("|")
            current_record = {"uniprot_ac": uniprot_ac, "kingdom": kingdom, "type": type_, "sequence": None, "label": None}
        else:
            # Check if the line contains a sequence or a label
            if current_record["sequence"] is None:
                current_record["sequence"] = line.strip()
            elif current_record["label"] is None:
                current_record["label"] = line.strip()
            else:
                # If both sequence and label are already set, skip this line
                print("Skipping extra line in record:", current_record)
    # Save the last record if it's complete
    if current_record is not None:
        if current_record["sequence"] is not None and current_record["label"] is not None:
            records.append(current_record)
        else:
            print("Skipping incomplete record:", current_record)

"""
# Save the DataFrame to a CSV file
df_raw.to_csv("/Users/jonas/Desktop/Uni/PBL/data/complete_set_unpartitioned.csv", index=False)
"""
# Print the number of records
print(f"Total records: {len(records)}")
df_raw = pd.DataFrame(records)
df_raw.head()


Total records: 25693


Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
0,Q8TF40,EUKARYA,NO_SP,MAPTLFQKLFSKRTGLGAPGRDARDPDCGFSWPLPEFDPSQIRLIV...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
1,Q1ENB6,EUKARYA,NO_SP,MDFTSLETTTFEEVVIALGSNVGNRMNNFKEALRLMKDYGISVTRH...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
2,P36001,EUKARYA,NO_SP,MDDISGRQTLPRINRLLEHVGNPQDSLSILHIAGTNGKETVSKFLT...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
3,P55317,EUKARYA,NO_SP,MLGTVKMEGHETSDWNSYYADTQEAYSSVPVSNMNSGLGSMNSMNT...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
4,P35583,EUKARYA,NO_SP,MLGAVKMEGHEPSDWSSYYAEPEGYSSVSNMNAGLGMNGMNTYMSM...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...


In [33]:
df = df_raw[~df_raw["label"].str.contains("P")]
df.describe()

Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
count,25580,25580,25580,25580,25580
unique,25580,4,5,24367,1878
top,Q8TF40,EUKARYA,NO_SP,MKLSRRSFMKANAVAAAAAAAGLSVPGVARAVVGQQEAIKWDKAPC...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
freq,1,20423,19036,41,16382


In [34]:
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}

df_encoded = df.copy()
df_encoded["label"] = df_encoded["label"].apply(lambda x: [label_map[c] for c in x if c in label_map])
df_encoded = df_encoded[df_encoded["label"].map(len) > 0]  # Remove rows with empty label lists

# make random smaller dataset
#df_encoded = df_encoded.sample(frac=0.4, random_state=42)

sequences = df_encoded["sequence"].tolist()
label_seqs = df_encoded["label"].tolist()

df_encoded.describe()


Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
count,25580,25580,25580,25580,25580
unique,25580,4,5,24367,1878
top,Q8TF40,EUKARYA,NO_SP,MKLSRRSFMKANAVAAAAAAAGLSVPGVARAVVGQQEAIKWDKAPC...,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ..."
freq,1,20423,19036,41,16382


In [35]:
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False )
encoder = BertModel.from_pretrained(MODEL_NAME)
encoder.to(DEVICE)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30, 1024, padding_idx=0)
    (position_embeddings): Embedding(40000, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-29): 30 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, i

In [36]:
# Stratify by sequence length to avoid ValueError
train_seqs, test_seqs, train_label_seqs, test_label_seqs = train_test_split(
    sequences, label_seqs, test_size=0.3, random_state=42
)

In [37]:
class SPDataset(Dataset):
    def __init__(self, sequences, label_seqs, label_map):
        self.label_map = label_map
        self.label_seqs = label_seqs
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        # preprocess the sequence (insert spaces between amino acids)
        seq_processed = " ".join(list(seq))
        labels = self.label_seqs[idx]
        # Tokenize the sequence
        encoded = tokenizer(seq_processed, return_tensors="pt",
                            padding="max_length", truncation=True, max_length=512)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        
        # Build a label tensor of the same length as input_ids.
        # For tokens beyond the original sequence length, assign -100 so that loss ignores these.
        orig_length = len(seq)
        token_labels = []
        
        for i in range(input_ids.size(0)):
            if i == 0 or i > orig_length:  
                token_labels.append(-100)  # ignore [CLS] or padding tokens
            else:
                # Use the already encoded label directly
                token_labels.append(labels[i-1])
        labels_tensor = torch.tensor(token_labels)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels_tensor
        }

train_dataset = SPDataset(train_seqs, train_label_seqs, label_map)
test_dataset = SPDataset(test_seqs, test_label_seqs, label_map)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# class for the model on top of the prot_bert encoder
# using the encoder as a protein feature extractor
# and adding a (linear) classifier on top

class SPClassifier(nn.Module):
    def __init__(self, encoder_model):
        super().__init__()
        self.encoder = encoder_model # this is just the bert model
        self.dropout = nn.Dropout(0.3) # zero out 30% of the neuron outputs to prevent overfitting
        # Ensure the classifier input dimension matches the ProtBERT hidden size
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(), # negative values are set to 0
            nn.Dropout(0.3),
            nn.Linear(256, NUM_CLASSES)
        )

    # takes encoder output and applies the classifier
    def forward(self, input_ids, attention_mask):
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state  # shape: (batch, seq_len, hidden)
        output = self.classifier(self.dropout(hidden_states))
        return output

In [None]:
from transformers import get_linear_schedule_with_warmup

# Initialize the model
model = SPClassifier(encoder).to(DEVICE)

# optimizer and Loss
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,           # Learning rate
    betas=(0.85, 0.999),  # momentum
    eps=1e-6,          # epsilon
    weight_decay=0.01  # regularization
)

# scheduler for learning rate
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=len(train_loader) * EPOCHS
)

# Counter({'I': 1204001, 'O': 362643, 'S': 85526, 'M': 74445, 'L': 46065, 'T': 22272, 'P': 951})
class_counts = [1204001, 85526, 22272, 46865, 74445, 362643]  # Count for each class (I, S, T, L, M, O)

# hopefully deals with the class imbalance -> loss func acknowledges the imbalance with class weights 
# most frquent class gets the lowest weight here
weights = torch.tensor([1.0 / count for count in class_counts], device=DEVICE)

# loss function that ignores the padding tokens (-100)
loss_fn = nn.CrossEntropyLoss(weight=weights, ignore_index=-100)

In [40]:
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score

writer = SummaryWriter(log_dir="/Users/jonas/Desktop/Uni/PBL/logs/prot_bert_linear_classifier_v1")

# Training Loop
model.train()

for param in model.encoder.parameters():
    param.requires_grad = False

for epoch in range(EPOCHS):

    # unfreeze the last layer of the encoder after 2 epochs abd the rest after 7 epochs
    if epoch == 2:
        for param in model.encoder.encoder.layer[-1:].parameters():
            param.requires_grad = True
    elif epoch == 7:
        for param in model.encoder.parameters():
            param.requires_grad = True

    total_loss = 0 # total epoch loss
    for batch in train_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        token_labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()  # resets gradients
        logits = model(input_ids, attention_mask)  # forward pass
        loss = loss_fn(logits.view(-1, NUM_CLASSES), token_labels.view(-1)) # flatten logits and labels and compute loss
        loss.backward()  # backpropagation
        optimizer.step()  # update weights
        scheduler.step() # update learning rate

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/train", avg_train_loss, epoch)


Epoch 1, Loss: 1.3051
Epoch 2, Loss: 0.5268
Epoch 3, Loss: 0.1495
Epoch 4, Loss: 0.0853
Epoch 5, Loss: 0.0756
Epoch 6, Loss: 0.0706
Epoch 7, Loss: 0.0679


In [41]:
# Evaluation
model.eval()
val_loss = 0
all_preds = [] # predicted labels
all_labels = [] # true types
with torch.no_grad():
    # uses the batches from the test set for eval
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        logits = model(input_ids, attention_mask) # runs forward pass
        preds = torch.argmax(logits, dim=-1) # get per token predicted labels

        loss = loss_fn(logits.view(-1, NUM_CLASSES), labels.view(-1)) # calculate loss
        val_loss += loss.item()

        # flatten the predictions and labels
        preds_flat = preds.view(-1)
        labels_flat = labels.view(-1)
        valid_idx = labels_flat != -100 # exclude padding tokens (-100)
        all_preds.extend(preds_flat[valid_idx].cpu().numpy())
        all_labels.extend(labels_flat[valid_idx].cpu().numpy())

print(classification_report(all_labels, all_preds))

avg_val_loss = val_loss / len(test_loader)
val_acc = accuracy_score(all_labels, all_preds)

print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
writer.add_scalar("Loss/val", avg_val_loss, epoch)
writer.add_scalar("Accuracy/val", val_acc, epoch)

writer.flush()
writer.close()


              precision    recall  f1-score   support

           0       0.98      0.51      0.67     25619
           1       0.74      0.94      0.83      6739
           2       0.66      0.98      0.79     14328
           3       0.94      0.99      0.97    358968
           4       0.80      0.86      0.83     21858
           5       0.97      0.81      0.88    108891

    accuracy                           0.93    536403
   macro avg       0.85      0.85      0.83    536403
weighted avg       0.93      0.93      0.92    536403

Epoch 7, Val Loss: 0.0662, Val Acc: 0.9264


In [42]:

torch.save(model.state_dict(), "/Users/jonas/Desktop/Uni/PBL/models/prot_bert_linear_classifier_v1.pt")

In [None]:


# compute percentage of false predicted labels
def sequence_level_accuracy(predictions, labels):
    correct = 0
    total = 0
    for pred, label in zip(predictions, labels):
        if len(pred) != len(label):
            min_len = min(len(pred), len(label))
            pred = pred[:min_len]
            label = label[:min_len]
        total += 1
        if (pred == label).all():
            correct += 1
    return correct / total
acc = sequence_level_accuracy(all_preds, all_labels)
print(f"Sequence Level Accuracy: {acc:.4f}")

Sequence Level Accuracy: 0.7131
