In [1]:
import tqdm as notebook_tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# --- Run configuration -------------------------------------------------------
MODEL_NAME = "Rostlab/ProstT5"  # Hugging Face checkpoint identifier
NUM_CLASSES = 6  # number of per-residue label classes
BATCH_SIZE = 64
EPOCHS = 10
LR = 0.001

# Accelerator preference: Apple MPS first, then CUDA, otherwise plain CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# DEVICE = "cpu"  # uncomment to force CPU execution
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# Mount Google Drive inside the Colab runtime. Later cells read the dataset from,
# and write logs and checkpoints to, paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [3]:
import pandas as pd

FASTA_PATH = "/content/drive/MyDrive/PBL Rost/data/complete_set_unpartitioned.fasta"


def parse_labelled_fasta(lines):
    """Parse a labelled FASTA stream into a list of record dicts.

    Each entry is expected to consist of a header line
    ``>uniprot_ac|kingdom|type`` followed by one sequence line and one label line.

    Args:
        lines: Iterable of text lines (e.g. an open file handle).

    Returns:
        List of dicts with keys ``uniprot_ac``, ``kingdom``, ``type``,
        ``sequence`` and ``label``. Entries missing a sequence or a label are
        reported on stdout and left out.

    Raises:
        ValueError: If a header does not split into exactly three ``|`` fields.
    """
    parsed = []

    def _flush(record):
        # Keep a finished record only when both payload lines were seen.
        if record is None:
            return
        if record["sequence"] is not None and record["label"] is not None:
            parsed.append(record)
        else:
            print("Skipping incomplete record:", record)

    current_record = None
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            # Blank lines carry no data; storing them would shift the
            # sequence/label slots of the current record.
            continue
        if line.startswith(">"):
            _flush(current_record)
            uniprot_ac, kingdom, type_ = line[1:].strip().split("|")
            current_record = {"uniprot_ac": uniprot_ac, "kingdom": kingdom, "type": type_,
                              "sequence": None, "label": None}
        elif current_record is None:
            # Data before the first header cannot be attributed to any entry.
            print("Skipping line before first header:", line)
        elif current_record["sequence"] is None:
            current_record["sequence"] = line
        elif current_record["label"] is None:
            current_record["label"] = line
        else:
            # Both slots are already filled; anything further is surplus.
            print("Skipping extra line in record:", current_record)
    _flush(current_record)  # the last entry has no following header to trigger a flush
    return parsed


with open(FASTA_PATH, "r") as f:
    records = parse_labelled_fasta(f)  # uniprot_ac, kingdom, type, sequence, label

print(f"Total records: {len(records)}")
df_raw = pd.DataFrame(records)
# Optional: df_raw.to_csv(<path>, index=False) to cache the parsed table as CSV.
df_raw.head()


Total records: 25693


Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
0,Q8TF40,EUKARYA,NO_SP,MAPTLFQKLFSKRTGLGAPGRDARDPDCGFSWPLPEFDPSQIRLIV...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
1,Q1ENB6,EUKARYA,NO_SP,MDFTSLETTTFEEVVIALGSNVGNRMNNFKEALRLMKDYGISVTRH...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
2,P36001,EUKARYA,NO_SP,MDDISGRQTLPRINRLLEHVGNPQDSLSILHIAGTNGKETVSKFLT...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
3,P55317,EUKARYA,NO_SP,MLGTVKMEGHETSDWNSYYADTQEAYSSVPVSNMNSGLGSMNSMNT...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
4,P35583,EUKARYA,NO_SP,MLGAVKMEGHEPSDWSSYYAEPEGYSSVSNMNAGLGMNGMNTYMSM...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...


In [4]:
# Keep only the entries whose label string has no 'P' position.
contains_p = df_raw["label"].str.contains("P")
df = df_raw.loc[~contains_p]
df.describe()

Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
count,25580,25580,25580,25580,25580
unique,25580,4,5,24367,1878
top,B3GZ85,EUKARYA,NO_SP,MKLSRRSFMKANAVAAAAAAAGLSVPGVARAVVGQQEAIKWDKAPC...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
freq,1,20423,19036,41,16382


In [5]:
# perform oversampling
from sklearn.utils import resample
# Split the frame into the dominant NO_SP class and every other type.
is_no_sp = df["type"] == "NO_SP"
df_majority = df[is_no_sp]
df_minority = df[~is_no_sp]

# Draw (with replacement, fixed seed) as many non-NO_SP rows as there are NO_SP rows.
# The minority types are resampled as one pool, so their proportions relative to
# each other are not equalised.
df_minority_upsampled = resample(
    df_minority,
    replace=True,
    n_samples=len(df_majority),
    random_state=42,
)

# Merge both parts and shuffle the rows reproducibly.
df_upsampled = (
    pd.concat([df_majority, df_minority_upsampled])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
# NOTE(review): the encoding cell below copies `df`, not `df_upsampled`, so this
# frame is currently informational only.

print(f"Total records after oversampling: {len(df_upsampled)}")
print("Class distribution after oversampling:")
print(df_upsampled["type"].value_counts())

Total records after oversampling: 38072
Class distribution after oversampling:
type
NO_SP      19036
SP         10614
LIPO        6614
TAT         1712
TATLIPO       96
Name: count, dtype: int64


In [6]:
# Integer class id for every label character.
label_map = {'S': 0, 'T': 1, 'L': 2, 'I': 3, 'M': 4, 'O': 5}


def _encode_label_string(label_string):
    """Translate a label string into class ids; characters not in label_map are dropped."""
    return [label_map[symbol] for symbol in label_string if symbol in label_map]


df_encoded = df.copy()
df_encoded["label"] = df_encoded["label"].apply(_encode_label_string)
has_labels = df_encoded["label"].map(len) > 0
df_encoded = df_encoded[has_labels]  # discard rows whose label list ended up empty

# Optional: work on a random subset for quick experiments
#df_encoded = df_encoded.sample(frac=0.4, random_state=42)

sequences = df_encoded["sequence"].tolist()
label_seqs = df_encoded["label"].tolist()

df_encoded.describe()


Unnamed: 0,uniprot_ac,kingdom,type,sequence,label
count,25580,25580,25580,25580,25580
unique,25580,4,5,24367,1878
top,B3GZ85,EUKARYA,NO_SP,MKLSRRSFMKANAVAAAAAAAGLSVPGVARAVVGQQEAIKWDKAPC...,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ..."
freq,1,20423,19036,41,16382


In [7]:
from transformers import T5Tokenizer, T5EncoderModel

# Load the tokenizer and the encoder-only T5 model for the MODEL_NAME checkpoint.
# Both calls resolve the checkpoint through `from_pretrained`.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
encoder = T5EncoderModel.from_pretrained(MODEL_NAME)
# Move the encoder weights to the selected device; as the last expression of the
# cell, the returned module is also echoed as cell output.
encoder.to(DEVICE)

KeyboardInterrupt: 

In [None]:
# Random 70/30 hold-out split with a fixed seed. No stratification is applied.
_split = train_test_split(sequences, label_seqs, test_size=0.3, random_state=42)
train_seqs, test_seqs, train_label_seqs, test_label_seqs = _split

In [None]:
# Load the data
class SPDataset(Dataset):
    """Per-residue tagging dataset for a ProstT5 encoder.

    Every item is tokenized to a fixed length and paired with a label tensor of
    the same length. Token layout assumed by the label alignment:

        index 0        -> "<AA2fold>" prefix token (label 0)
        index 1..L     -> residues 1..L (their class ids)
        index L+1      -> end-of-sequence token (label 0)
        remaining      -> padding (label 0)

    ProstT5 expects amino-acid input as space-separated residues preceded by the
    "<AA2fold>" tag, with the rare residues U/Z/O/B replaced by X; the raw
    sequence is converted to that form before tokenization.

    Args:
        sequences: List of amino-acid strings.
        label_seqs: List of per-residue class-id lists, parallel to `sequences`.
        label_map: Mapping from label character to class id (stored, not used here).
        tokenizer: Optional tokenizer callable. When omitted, the module-level
            `tokenizer` is looked up at item-access time.
        max_length: Fixed token length every item is padded/truncated to.

    Raises:
        ValueError: If `sequences` and `label_seqs` differ in length.
    """

    _PREFIX = "<AA2fold>"
    _RARE_TO_X = str.maketrans("UZOB", "XXXX")

    def __init__(self, sequences, label_seqs, label_map, tokenizer=None, max_length=512):
        if len(sequences) != len(label_seqs):
            raise ValueError(
                f"sequences ({len(sequences)}) and label_seqs ({len(label_seqs)}) differ in length"
            )
        self.label_map = label_map
        self.label_seqs = label_seqs
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def _format_sequence(self, seq):
        """Return `seq` in ProstT5 input form: prefix tag + space-separated residues."""
        return f"{self._PREFIX} " + " ".join(seq.translate(self._RARE_TO_X))

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        labels = self.label_seqs[idx]

        # Fall back to the notebook-level tokenizer when none was injected.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        encoded = tok(self._format_sequence(seq), return_tensors="pt",
                      padding="max_length", truncation=True, max_length=self.max_length)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Residues occupy slots 1..L; one slot each is reserved for prefix and EOS,
        # so at most (token length - 2) residues can carry a label after truncation.
        n_residues = max(min(len(seq), input_ids.size(0) - 2), 0)
        if len(labels) < n_residues:
            raise ValueError(
                f"item {idx}: {len(labels)} labels for {n_residues} labelled residues"
            )

        # Non-residue slots keep class id 0: the CRF indexes its transition table
        # with every label value, so an ignore marker such as -100 is not usable.
        labels_tensor = torch.zeros(input_ids.size(0), dtype=torch.long)
        labels_tensor[1:1 + n_residues] = torch.tensor(labels[:n_residues], dtype=torch.long)

        return {
            'input_ids': input_ids,            # tokenized and padded
            'attention_mask': attention_mask,  # 1 for real tokens, 0 for padding
            'labels': labels_tensor            # class ids aligned with input_ids
        }

# Wrap both splits in SPDataset objects that share the same label map.
train_dataset, test_dataset = (
    SPDataset(split_seqs, split_labels, label_map)
    for split_seqs, split_labels in (
        (train_seqs, train_label_seqs),
        (test_seqs, test_label_seqs),
    )
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [13]:
!pip install pytorch-crf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF

class SPCNNClassifier(nn.Module):
    """Sequence tagger: encoder -> dilated Conv1d -> BatchNorm/ReLU -> BiLSTM -> Linear -> CRF.

    Args:
        encoder_model: Encoder whose output exposes `last_hidden_state` and whose
            `config.hidden_size` gives the embedding width.
        num_labels: Number of tag classes scored by the linear layer and the CRF.
    """

    def __init__(self, encoder_model, num_labels):
        super().__init__()
        self.encoder = encoder_model
        self.dropout = nn.Dropout(0.2)
        hidden_size = self.encoder.config.hidden_size
        # Dilated convolution over the sequence axis. With kernel_size=8, dilation=2
        # and padding=7 the output has the same length as the input.
        self.conv = nn.Conv1d(in_channels=hidden_size, out_channels=1024, kernel_size=8, dilation=2, padding=7)
        # Normalizes the convolution output (expects shape: (batch, 1024, seq_len))
        self.bn_conv = nn.BatchNorm1d(1024)
        # 3-layer bidirectional LSTM; output width is 2 * 512
        self.lstm = nn.LSTM(input_size=1024, hidden_size=512, num_layers=3, bidirectional=True, batch_first=True)
        # Per-token emission scores
        self.classifier = nn.Linear(512 * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the CRF loss when `labels` is given, otherwise decoded tag paths.

        Args:
            input_ids: Token ids, (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding, (batch, seq_len).
            labels: Optional tag ids, (batch, seq_len).

        Returns:
            With `labels`: scalar tensor, the negative CRF log-likelihood
            averaged over the batch. Without: list of per-sequence tag-id lists
            covering the unmasked positions.
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state  # (batch, seq_len, hidden_size)

        # Conv1d wants channels first; transpose in and back out around the CNN stage.
        x_conv = self.conv(hidden_states.transpose(1, 2))  # (batch, 1024, seq_len)
        x_conv = F.relu(self.bn_conv(x_conv))              # (batch, 1024, seq_len)
        lstm_out, _ = self.lstm(x_conv.transpose(1, 2))    # (batch, seq_len, 1024)

        # Regularize the features, not the emission scores: dropout applied after the
        # linear layer would randomly zero and rescale class logits fed to the CRF.
        features = self.dropout(lstm_out)
        # Keep CRF arithmetic in fp32 even when the surrounding code runs under autocast.
        logits = self.classifier(features).float()         # (batch, seq_len, num_labels)

        mask = attention_mask.bool()
        if labels is not None:
            return -self.crf(logits, labels, mask=mask, reduction='mean')
        return self.crf.decode(logits, mask=mask)



[0mCollecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
[0mInstalling collected packages: pytorch-crf
[0mSuccessfully installed pytorch-crf-0.7.2


In [None]:
from transformers import get_linear_schedule_with_warmup

# Initialize the model
model = SPCNNClassifier(encoder, NUM_CLASSES).to(DEVICE)

ENCODER_LR = 5e-6  # small step size for the pretrained encoder
HEAD_LR = LR       # step size for the freshly initialised head

# Every trainable sub-module of the head gets its own group, including the
# BatchNorm affine parameters.
head_modules = (model.conv, model.bn_conv, model.lstm, model.classifier, model.crf)

# The whole encoder is registered once. Parameters with requires_grad=False
# produce no gradient and are skipped by the optimizer, so registering them is
# harmless while frozen and makes them trainable as soon as they are unfrozen.
optimizer = torch.optim.AdamW(
    [{"params": list(model.encoder.parameters()), "lr": ENCODER_LR}]
    + [{"params": list(module.parameters()), "lr": HEAD_LR} for module in head_modules],
    weight_decay=0.01,
)

total_steps = len(train_loader) * EPOCHS
# Linear warm-up over the first 10% of steps, then linear decay to zero.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)


In [None]:
# Compute sequence-level accuracy, skipping -100 (ignored) positions
def sequence_level_accuracy(preds_flat, labels_flat, test_label_seqs):
    """Fraction of sequences whose predictions match the labels at every kept position.

    The flat lists are cut back into per-sequence chunks using the lengths of
    `test_label_seqs`; positions whose label is -100 are ignored.

    Args:
        preds_flat: Predicted class ids of all sequences, concatenated.
        labels_flat: True class ids in the same order and layout.
        test_label_seqs: One sized item per sequence; only `len(item)` is used.

    Returns:
        Float in [0, 1]; 0.0 when `test_label_seqs` is empty.

    Warns:
        RuntimeWarning: If the flat lists do not contain exactly
            `sum(len(item))` entries, because the chunks are then misaligned
            and the score is unreliable.
    """
    import warnings

    seq_lengths = [len(seq) for seq in test_label_seqs]
    expected_total = sum(seq_lengths)
    if len(preds_flat) != expected_total or len(labels_flat) != expected_total:
        warnings.warn(
            "sequence_level_accuracy: flat inputs hold "
            f"{len(preds_flat)} predictions / {len(labels_flat)} labels but the "
            f"sequence lengths sum to {expected_total}; chunks are misaligned",
            RuntimeWarning,
            stacklevel=2,
        )

    correct = 0
    start = 0
    for length in seq_lengths:
        stop = start + length
        pred = preds_flat[start:stop]
        label = labels_flat[start:stop]
        start = stop

        # Drop ignored positions from both sides before comparing whole sequences.
        keep = [true != -100 for true in label]
        valid_preds = [p for p, kept in zip(pred, keep) if kept]
        valid_labels = [t for t, kept in zip(label, keep) if kept]
        if valid_preds == valid_labels:
            correct += 1

    total = len(seq_lengths)
    return correct / total if total > 0 else 0.0


In [None]:
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from torch.amp import autocast, GradScaler

def _empty_device_cache(device):
    """Release cached allocator memory on the active accelerator (no-op on CPU)."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()


def _unfreeze(parameters):
    """Mark every parameter in `parameters` as trainable."""
    for param in parameters:
        param.requires_grad = True


writer = SummaryWriter(log_dir="/content/drive/MyDrive/PBL Rost/logs/bert_classifier_v6")
# Loss scaling is only meaningful for CUDA fp16 autocast.
scaler = GradScaler(enabled=(DEVICE == "cuda"))

# Freeze encoder parameters initially
for param in model.encoder.parameters():
    param.requires_grad = False

model.encoder.gradient_checkpointing_enable()  # trade compute for activation memory

# T5 keeps its transformer blocks in `encoder.block` (there is no `.layer` attribute).
encoder_blocks = model.encoder.encoder.block

for epoch in range(EPOCHS):
    model.train()

    # Gradually unfreeze the encoder: last 4 blocks, last 7 blocks, then everything.
    # Note: unfrozen parameters are only updated if they belong to an optimizer param group.
    if epoch == 4:
        # Reentrant checkpointing propagates gradients only when a block input requires
        # grad; with frozen embeddings this hook is what lets the unfrozen blocks train.
        model.encoder.enable_input_require_grads()
        _unfreeze(encoder_blocks[-4:].parameters())
        print("Unfreezing last 4 layers of encoder")
    elif epoch == 7:
        _unfreeze(encoder_blocks[-7:].parameters())
        print("Unfreezing last 7 layers of encoder")
    elif epoch == 9:
        _unfreeze(model.encoder.parameters())
        print("Unfreezing all layers of encoder")

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch")
    total_loss = 0         # sum of per-batch losses in this epoch
    completed_batches = 0  # batches that finished without a RuntimeError

    for batch in pbar:
        try:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            token_labels = batch['labels'].to(DEVICE)

            optimizer.zero_grad()  # reset gradients
            # Single forward pass under autocast; its loss is the one that is back-propagated.
            with autocast(device_type=DEVICE):
                loss = model(input_ids, attention_mask, token_labels)

            scaler.scale(loss).backward()  # backpropagation
            scaler.unscale_(optimizer)     # unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)         # update weights
            scaler.update()                # update scaler
            scheduler.step()               # update learning rate

            total_loss += loss.item()
            completed_batches += 1
            pbar.set_postfix(loss=loss.item())

        except RuntimeError as e:
            # Best effort: report the failure (e.g. out-of-memory), free cached
            # memory and carry on with the next batch.
            print("Error during training:", e)
            _empty_device_cache(DEVICE)
            continue

    _empty_device_cache(DEVICE)  # clear cache at the end of each epoch
    # Average over the batches that actually contributed a loss.
    avg_train_loss = total_loss / max(completed_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/train", avg_train_loss, epoch)

writer.flush()
writer.close()



In [None]:

# Persist the model's state_dict (weights only) to Drive; the evaluation cell
# reloads this same file.
torch.save(model.state_dict(), "/content/drive/MyDrive/PBL Rost/models/bert_classifier_v6.pt")

In [None]:
# Evaluation

import os
import matplotlib.pyplot as plt
import numpy as np
import torch


from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
)

# Rebuild the model around the shared encoder and restore the saved weights.
model = SPCNNClassifier(encoder, NUM_CLASSES).to(DEVICE)
model.load_state_dict(torch.load("/content/drive/MyDrive/PBL Rost/models/bert_classifier_v6.pt", map_location=DEVICE))
# A freshly built module is in training mode; switch off dropout and use the
# BatchNorm running statistics for evaluation.
model.eval()

val_loss = 0.0
num_batches = 0
all_preds = []
all_labels = []
evaluated_label_seqs = []  # per-sequence true labels, in the order they were flattened
seq_idx = 0                # position in test_label_seqs (test_loader is not shuffled)

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Compute loss using CRF (pass labels)
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        val_loss += loss.item()
        num_batches += 1

        # Decode predictions using CRF (no labels passed)
        predictions = model(input_ids=input_ids, attention_mask=attention_mask)  # List[List[int]]

        # SPDataset stores residue labels at token slots 1..L and fills every other
        # slot with class id 0, so non-residue slots cannot be recognised by value.
        # Select the residue slots by position instead.
        for pred_seq, label_row in zip(predictions, labels.cpu().tolist()):
            n_residues = len(test_label_seqs[seq_idx])
            seq_idx += 1
            pred_residues = pred_seq[1:1 + n_residues]
            true_residues = label_row[1:1 + n_residues]
            # Truncated inputs can yield fewer decoded positions than residues.
            n_kept = min(len(pred_residues), len(true_residues))
            pred_residues = pred_residues[:n_kept]
            true_residues = true_residues[:n_kept]

            all_preds.extend(pred_residues)
            all_labels.extend(true_residues)
            evaluated_label_seqs.append(true_residues)

class_ids = list(label_map.values())
class_names = list(label_map.keys())

# Report (explicit label list keeps names and ids aligned even if a class is absent)
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids,
                            target_names=class_names, zero_division=0))

# F1 Score weighted
f1 = f1_score(all_labels, all_preds, average='weighted')
print(f"F1 Score (weighted): {f1:.4f}")
# F1 Score macro
f1_macro = f1_score(all_labels, all_preds, average='macro')
print(f"F1 Score (macro): {f1_macro:.4f}")

# Sequence Level Accuracy (lengths taken from exactly what was flattened above)
seq_acc_val = sequence_level_accuracy(all_preds, all_labels, evaluated_label_seqs)
print(f"Sequence Level Accuracy: {seq_acc_val:.4f}")

# Matthews Correlation Coefficient (MCC)
mcc = matthews_corrcoef(all_labels, all_preds)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

# Token-level Accuracy
token_acc = accuracy_score(all_labels, all_preds)
print(f"Token-level Accuracy: {token_acc:.4f}")

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="OrRd", xticks_rotation=45)
plt.title("Confusion Matrix")
plt.show()

# Log the mean batch loss so it is on the same scale as "Loss/train".
avg_val_loss = val_loss / max(num_batches, 1)
writer.add_scalar("Loss/test", avg_val_loss)

writer.flush()
writer.close()
