In [1]:
import os
import sys
import time
import random
import json
import logging
from typing import Optional

import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import (
    mean_squared_error, 
    mean_absolute_error, 
    r2_score,
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score, 
    roc_auc_score, 
    average_precision_score, 
    matthews_corrcoef
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import esm
from antiberty import AntiBERTyRunner

# Custom imports from aggrepred package
top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)

from aggrepred.dataset import *
from aggrepred.model import *
from aggrepred.utils import *

# Seed everything function
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)


  from .autonotebook import tqdm as notebook_tqdm


# Dataset loader


In [2]:

class AntibodySeqDataset:
    def __init__(self, df,  max_seq_len=700):
        self.data = df.copy()
        
        # Convert scores from string to list
        self.data['Hchain_scores'] = self.data['Hchain_scores'].apply(ast.literal_eval)
        self.data['Lchain_scores'] = self.data['Lchain_scores'].apply(ast.literal_eval)

        # Calculate positive and negative counts for heavy and light chains
        self.data['Hchain_count_positive'] = self.data['Hchain_scores'].apply(lambda x: sum(1 for score in x if score > 0))
        self.data['Hchain_count_negative'] = self.data['Hchain_scores'].apply(lambda x: sum(1 for score in x if score <= 0))
        self.data['Lchain_count_positive'] = self.data['Lchain_scores'].apply(lambda x: sum(1 for score in x if score > 0))
        self.data['Lchain_count_negative'] = self.data['Lchain_scores'].apply(lambda x: sum(1 for score in x if score <= 0))

        # Compute lengths of heavy and light chains
        self.data['Hchain_len'] = self.data['Hchain_scores'].apply(len)
        self.data['Lchain_len'] = self.data['Lchain_scores'].apply(len)

        # Compute negative-to-positive ratio for heavy and light chains
        self.data['Hchain_neg_to_pos_ratio'] = self.data['Hchain_count_negative'] / self.data['Hchain_count_positive']
        self.data['Lchain_neg_to_pos_ratio'] = self.data['Lchain_count_negative'] / self.data['Lchain_count_positive']

        # Set max sequence length and scaling flag
        self.max_seq_len = max_seq_len
    

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self.data):
            raise IndexError("Index out of range")

        
        row = self.data.iloc[idx]
        code = row['ID']
        H_seq = row['Hchain_sequence']
        L_seq = row['Lchain_sequence']
        Hchain_scores = row['Hchain_scores']
        Lchain_scores = row['Lchain_scores']

        # Prepare target vectors for heavy chain
        H_y = Hchain_scores[:self.max_seq_len] + [0] * (450 - len(Hchain_scores))
        H_y = torch.tensor(H_y)

        # Prepare target vectors for light chain
        L_y = Lchain_scores[:self.max_seq_len] + [0] * (250 - len(Lchain_scores))
        L_y = torch.tensor(L_y)

        H_y_bin = (H_y > 0).int()
        L_y_bin = (L_y > 0).int()

        # Generate binary mask based on sequence length (1 for actual values, 0 for padding)
        H_mask = torch.zeros(450, dtype=torch.bool)
        H_mask[:len(Hchain_scores)] = True  # Set the first 'len(Hchain_scores)' to 1

        L_mask = torch.zeros(250, dtype=torch.bool)
        L_mask[:len(Lchain_scores)] = True  # Set the first 'len(Lchain_scores)' to 1

        return {
            'code': code,
            'H_seq': H_seq,
            'H_target_reg': H_y,
            'H_target_bin': H_y_bin,
            'H_mask': H_mask,

            'L_seq': L_seq,
            'L_target_reg': L_y,
            'L_target_bin': L_y_bin,
            'L_mask': L_mask  }


# model


# Trainer 

In [3]:
# ----------------
# DATA
# ----------------

def custom_collate(batch):
    regs_tensor = [item['target_reg'] for item in batch]
    max_len = regs_tensor[0].size()[0]
    
    orig_lens = [item['orig_len'] for item in batch]
    max_orig_len = min(max(orig_lens), max_len)  # Ensure max_orig_len is at most max_len
    
    # print(max_orig_len)
    # truncated_encoded_seqs = [item['encoded_seq'][:max_orig_len,:] for item in batch]
    codes = [item['code'] for item in batch]
    seqs = [item['seq'] for item in batch]
    truncated_regs_tensor = [item['target_reg'][ :max_orig_len] for item in batch]
    truncated_bins_tensor = [item['target_bin'][:max_orig_len] for item in batch]
    
    # encoded_seqs_tensor = torch.stack(truncated_encoded_seqs)
    target_regs_tensor = torch.stack(truncated_regs_tensor)
    target_bins_tensor = torch.stack(truncated_bins_tensor)

    return {
        'code': codes,
        'seq': seqs,
        'target_reg': target_regs_tensor,
        'target_bin': target_bins_tensor,
        'orig_len': torch.tensor(orig_lens)
    }



### Antibody

In [4]:

def load_model_from_checkpoint(model, checkpoint_path, device):
    """
    Loads the model and optimizer state from a checkpoint if it exists.
    
    Args:
    - model (torch.nn.Module): The model to load the state into.
    - optimizer (torch.optim.Optimizer): The optimizer to load the state into.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device to which the model should be moved.
    
    Returns:
    - start_epoch (int): The epoch to start training from.
    - best_validation_loss (float): The best validation loss recorded in the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        # print(f'Best validation loss: {best_validation_loss}')
    else:
        start_epoch = 0
        best_validation_loss = float('inf')  # Assuming lower is better for validation loss
        print('No checkpoint found.')
    
    model = model.to(device)
    return start_epoch, best_validation_loss



In [5]:

#########################################################################
#########################################################################
df = pd.read_csv("../data/csv/antibody.csv")

# train_dataset = AntibodySeqDataset(df[df.split=='train'].sample(frac=0.10, random_state=42))
# valid_dataset = AntibodySeqDataset(df[df.split=='valid'].sample(frac=0.10, random_state=42))
# test_dataset = AntibodySeqDataset(df[df.split=='test'])
train_dataset = AntibodySeqDataset(df[df.split=='train'])
valid_dataset = AntibodySeqDataset(df[df.split=='valid'])
test_dataset = AntibodySeqDataset(df[df.split=='test'])


train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)



In [6]:
for idx, batch in enumerate(test_dataloader):
    print(batch)
    if idx == 0:
        break

{'code': ['4j8r', '7qny', '6al4', '6hkg', '7kpb', '7kqk', '5jor', '4hzl', '7lop', '2vh5', '7jwp', '5a16', '4hj0', '3thm', '7l7e', '8d3a', '6wo5', '4eow', '2gfb', '6k65', '1nsn', '6mfp', '8ef3', '8dn7', '7zf6', '6mi2', '6bli', '8bbo', '1osp', '6fg2', '6dwi', '6xm2'], 'H_seq': ['VKLQESGGEVVRPGTSVKVSCKASGYAFTNYLIEWVKQRPGQGLEWIGVINPGSGDTNYNEKFKGKATLTADKSSSTAYMQLNSLTSDDSAVYFCARSGAAAPTYYAMDYWGQGVSVTVSSAKTTPPSVYPLAPAAAAANSMVTLGCLVKGYFPEPVTVTWNSGSLSGGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPR', 'EVQLLESGGDLIQPGGSLRLSCAASGVTVSSNYMSWVRQAPGKGLEWVSIIYPGGSTFYADSVKGRFTISRDNSKNTLYLQMHSLRAEDTAVYYCARDLGSGDMDVWGKGTTVTVSSASTKGPSVFPLAPSSSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKS', 'VQLVQSGAEVKKPGSSVKVSCKASGYAFSSYWMNWVRQAPGQGLEWMGQIWPGDSDTNYAQKFQGRVTITADESTSTAYMELSSLRSEDTAVYYCARRETTTVGRYYYAMDYWGQGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPK', 'EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYGMAWVRQAPGKGLEWVSF

### find the propostion of pos/neg class

#### there are about 80% of negative class vs 20% of positive class  , hence 4:1 ratio

In [7]:
sum_one = train_dataset.data['Hchain_count_positive'].sum()
sum_zero = train_dataset.data['Hchain_count_negative'].sum()
total = train_dataset.data['Hchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = valid_dataset.data['Hchain_count_positive'].sum()
sum_zero = valid_dataset.data['Hchain_count_negative'].sum()
total = valid_dataset.data['Hchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = test_dataset.data['Hchain_count_positive'].sum()
sum_zero = test_dataset.data['Hchain_count_negative'].sum()
total = test_dataset.data['Hchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

propotion of position and negative class:  0.09087559688180555 0.9091244031181944
ratio of position to negative class:  10.00405427103404
propotion of position and negative class:  0.08837646378487783 0.9116235362151222
ratio of position to negative class:  10.315229838050056
propotion of position and negative class:  0.09189772296568413 0.9081022770343159
ratio of position to negative class:  9.881662436548224


In [8]:
sum_one = train_dataset.data['Lchain_count_positive'].sum()
sum_zero = train_dataset.data['Lchain_count_negative'].sum()
total = train_dataset.data['Lchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = valid_dataset.data['Lchain_count_positive'].sum()
sum_zero = valid_dataset.data['Lchain_count_negative'].sum()
total = valid_dataset.data['Lchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = test_dataset.data['Lchain_count_positive'].sum()
sum_zero = test_dataset.data['Lchain_count_negative'].sum()
total = test_dataset.data['Lchain_len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

propotion of position and negative class:  0.05469640667558992 0.94530359332441
ratio of position to negative class:  17.282736669176536
propotion of position and negative class:  0.05594240179772623 0.9440575982022738
ratio of position to negative class:  16.875528541226217
propotion of position and negative class:  0.054230682755002625 0.9457693172449974
ratio of position to negative class:  17.43974571586512


## Config


In [9]:
# ----------------
# PARAM
# ----------------


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Define the configuration dictionary with all the model parameters

# path = "./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
path = "./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"

# Define the path to the pretrained weight of same model on protein dataset
# path_old = "./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
path_old = "./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(esm)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"


config = {
    "pooling": False,
    "antibody": True,
    # "use_local": False,
    "use_local": True,
    # "use_global": False,
    "use_global": True,
    "num_localextractor_block": 3,
    "input_dim": 700,
    "output_dim": 700,
    "in_channel": 28,
    "out_channel": 256,
    "kernel_size": 23,
    "dilation": 1,
    "stride": 1,
    "rnn_hid_dim": 128,
    "rnn_layers": 1,
    "bidirectional": True,
    "rnn_dropout": 0.2,
    "attention_heads": 4,
    "learning_rate": 1e-4,
    "batch_size": 32,
    "nb_epochs": 20,
    "encode_mode" : 'onehot_meiler'
}

# with open(path_old+'config.json', 'r') as json_file:
#     config = json.load(json_file)


# Initialize the model
model = Aggrepred(config)
model = model.to(device=device)

model

# Load the model weights from the checkpoint
_, _ = load_model_from_checkpoint(model, path_old + 'model_best.pt', device)


cuda


Loaded checkpoint from ./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 19


In [10]:

# ----------------
#   OPTIMIZER 
# ----------------
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
# optimizer = nn.optim.AdamW(model.parameters(), lr=learning_rate,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


# ----------------
# LOSS
# ----------------

class CombinedLoss(nn.Module):
    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super(CombinedLoss, self).__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        self.mse_loss = nn.MSELoss()  # Regression Loss (MSE)
        
        if pos_weight is not None:
            # Binary Classification Loss (Weighted BCE with logits)
            self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        else:
            self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, outputs, regression_targets):
        # Calculate regression loss
        reg_loss = self.mse_loss(outputs, regression_targets)
        
        # Calculate binary classification loss
        # Convert regression output to binary labels (logits) for classification
        binary_targets = (regression_targets> 0).float()
        bin_loss = self.bce_loss(outputs, binary_targets)
        
        # Combined weighted loss
        total_loss = self.lambda_reg * reg_loss + self.lambda_bin * bin_loss
        return total_loss

mse_loss  = nn.MSELoss()
# mse_loss  = nn.MSELoss(reduction='sum')

# bce_loss = nn.BCELoss()
# bce_loss = nn.BCELoss(weight=class_weights)

# class_weights = torch.Tensor([1.0, 12.0]).cuda()
pos_class_weights = torch.Tensor([4.0]).to(device)
weighted_bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_class_weights)


loss_fn = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=17.0)


# ----------------
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable_params, non_trainable_params

trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

Number of trainable parameters: 3778561
Number of non-trainable parameters: 0


In [11]:
print(model)

Aggrepred(
  (local_extractors): ModuleList(
    (0): LocalExtractorBlock(
      (conv): Conv1d(28, 256, kernel_size=(23,), stride=(1,), padding=(11,))
      (BN): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leakyrelu): LeakyReLU(negative_slope=0.01)
      (relu): ReLU()
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (1-2): 2 x LocalExtractorBlock(
      (conv): Conv1d(256, 256, kernel_size=(23,), stride=(1,), padding=(11,))
      (BN): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leakyrelu): LeakyReLU(negative_slope=0.01)
      (relu): ReLU()
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (residue_map): Linear(in_features=28, out_features=256, bias=True)
  (global_extractor): GlobalInformationExtractor(
    (att_bilstm): Att_BiLSTM(
      (lstm): LSTM(28, 128, batch_first=True, bidirectional=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantiza

In [12]:
def format_time(seconds):
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes}m {seconds} s" if minutes>0 else f"{seconds} s"

def train_epoch(model, optimizer, dataloader, encode_mode='onehot_meiler', device = 'cuda', printEvery=100):
    
    model.train()
    total_loss = 0.0
    count_iter = 0
    start_time = time.time()
    epoch_start_time = start_time
    batch_size = dataloader.batch_size
    printEvery = printEvery // batch_size if batch_size else 100  # Adjust printEvery based on batch size

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
    protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')
    
    antiberty_model = AntiBERTyRunner()

    with tqdm(total=len(dataloader), desc='Training', unit='batch') as pbar:
        for idx, batch in enumerate(dataloader):
                 
            batch_H_sequences = batch['H_seq']
            batch_L_sequences = batch['L_seq']

            ## different encoding here
            if encode_mode == 'esm':
                Hchain_x = embed_esm_batch(batch_H_sequences,  esm_model, alphabet).to(device)
                Lchain_x = embed_esm_batch(batch_L_sequences,  esm_model, alphabet).to(device)
                Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(450 - Hchain_x.size(1), 0)))
                Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(250 - Lchain_x.size(1), 0)))
            
            elif encode_mode == 'protbert':
                Hchain_x = embed_protbert_batch(batch_H_sequences,  protbert_model, protbert_tokenizer).to(device)
                Lchain_x = embed_protbert_batch(batch_L_sequences,  protbert_model, protbert_tokenizer).to(device)
                Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(450 - Hchain_x.size(1), 0)))
                Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(250 - Lchain_x.size(1), 0)))
            
            elif encode_mode == 'antiberty':
                Hchain_x = embed_antiberty_batch(batch_H_sequences,  antiberty_model).to(device)
                Lchain_x = embed_antiberty_batch(batch_L_sequences,  antiberty_model).to(device)
                Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(450 - Hchain_x.size(1), 0)))
                Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(250 - Lchain_x.size(1), 0)))
                 
            elif encode_mode == 'onehot':
                Hchain_x = onehot_encode_batch(batch_H_sequences, 450).to(device)
                Lchain_x = onehot_encode_batch(batch_L_sequences, 250).to(device)
            else:
                Hchain_x = onehot_meiler_encode_batch(batch_H_sequences, 450).to(device)
                Lchain_x = onehot_meiler_encode_batch(batch_L_sequences, 250).to(device)
            
            x = torch.cat((Hchain_x, Lchain_x), dim=1)


            Hchain_mask = batch['H_mask'].to(device)
            Lchain_mask = batch['L_mask'].to(device)

            masks = torch.cat((Hchain_mask, Lchain_mask ), dim=1)

            
            ## convert (bsz,max_len) to  (bsz,max_len,1)
            H_y_reg = batch['H_target_reg'].unsqueeze(2).float().to(device)
            L_y_reg = batch['L_target_reg'].unsqueeze(2).float().to(device)


            y_reg = torch.cat((H_y_reg, L_y_reg ), dim=1)
            y_reg = clean_output_batch(y_reg, masks)
        
            ## prediction
            final_info, output_reg = model(x, masks)
            
            #trim out the padded part
            output_reg = clean_output_batch(output_reg, masks)
            # print(orig_len.sum())
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))
            
            current_loss = loss_fn(output_reg, y_reg)

            # Backpropagation
            optimizer.zero_grad()
            current_loss.backward()

            optimizer.step()
            total_loss += current_loss.item()
            
            printEvery = int(1000/x.size(0))
            count_iter += 1
            if count_iter % printEvery == 0 or idx == len(dataloader) - 1:
                elapsed_time = time.time() - start_time
                remaining_time = (elapsed_time / count_iter) * (len(dataloader) - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")
                start_time = time.time()
            torch.cuda.empty_cache()
            pbar.update(1)

    epoch_time = time.time() - epoch_start_time
    print(f"==> Average Training loss: mse ={total_loss / len(dataloader)}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")
    return total_loss / len(dataloader)


def evaluate(model, dataloader, encode_mode='onehot_meiler', device= 'cuda', mode='valid'):
    model.eval()
    total_loss = 0.0
    
    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []
    orig_lens = []

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')

    antiberty_model = AntiBERTyRunner()
    
    with torch.no_grad():
        with tqdm(total=len(dataloader), unit='batch') as pbar:
            for idx, batch in enumerate(dataloader):
                    
                batch_H_sequences = batch['H_seq']
                batch_L_sequences = batch['L_seq']
                            

                if encode_mode == 'esm':
                    Hchain_x = embed_esm_batch(batch_H_sequences,  esm_model, alphabet).to(device)
                    Lchain_x = embed_esm_batch(batch_L_sequences,  esm_model, alphabet).to(device)
                    H_max_length =  450
                    L_max_length = 250
                    Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(H_max_length - Hchain_x.size(1), 0)))
                    Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(L_max_length - Lchain_x.size(1), 0)))
                
                elif encode_mode == 'protbert':
                    Hchain_x = embed_protbert_batch(batch_H_sequences,  protbert_model, protbert_tokenizer).to(device)
                    Lchain_x = embed_protbert_batch(batch_L_sequences,  protbert_model, protbert_tokenizer).to(device)
                    H_max_length =  450
                    L_max_length = 250
                    Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(H_max_length - Hchain_x.size(1), 0)))
                    Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(L_max_length - Lchain_x.size(1), 0)))
                    
                elif encode_mode == 'antiberty':
                    Hchain_x = embed_antiberty_batch(batch_H_sequences,  antiberty_model).to(device)
                    Lchain_x = embed_antiberty_batch(batch_L_sequences,  antiberty_model).to(device)
                    H_max_length =  450
                    L_max_length = 250
                    Hchain_x   = F.pad(Hchain_x, (0, 0, 0, max(H_max_length - Hchain_x.size(1), 0)))
                    Lchain_x   = F.pad(Lchain_x, (0, 0, 0, max(L_max_length - Lchain_x.size(1), 0)))
                    
                elif encode_mode == 'onehot':
                    Hchain_x = onehot_encode_batch(batch_H_sequences, 450).to(device)
                    Lchain_x = onehot_encode_batch(batch_L_sequences, 250).to(device)
                else:
                    Hchain_x = onehot_meiler_encode_batch(batch_H_sequences, 450).to(device)
                    Lchain_x = onehot_meiler_encode_batch(batch_L_sequences, 250).to(device)

      
                x = torch.cat((Hchain_x, Lchain_x), dim=1)


                Hchain_mask = batch['H_mask'].to(device)
                Lchain_mask = batch['L_mask'].to(device)

                masks = torch.cat((Hchain_mask, Lchain_mask ), dim=1)

                # if not config["pooling"]:

                ## convert (bsz,max_len) to  (bsz,max_len,1)
                H_y_reg = batch['H_target_reg'].unsqueeze(2).float().to(device)
                L_y_reg = batch['L_target_reg'].unsqueeze(2).float().to(device)

                # Hchain_y_reg = clean_output_batch(Hchain_y_reg, Hchain_mask)
                
                y_reg = torch.cat((H_y_reg, L_y_reg ), dim=1)
                
                y_reg = clean_output_batch(y_reg, masks)
            
                
                
                ## prediction
                final_info, output_reg = model(x, masks)
                
                #trim out the padded part
                output_reg = clean_output_batch(output_reg, masks)
                # print(orig_len.sum())
                assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))
                
                current_loss = loss_fn(output_reg, y_reg)
                total_loss += current_loss.item()

                #append to list of all preds
                predictions.append(output_reg.cpu().numpy())
                targets.append(y_reg.cpu().numpy())
 
                y_bin = (y_reg.cpu().numpy() > 0).astype(int)
                out_bin = (output_reg.cpu().numpy() > 0).astype(int)

                binary_predictions.append(out_bin)
                binary_targets.append(y_bin)

                torch.cuda.empty_cache()
                pbar.update(1)

    all_predictions = np.concatenate(predictions, axis=0).reshape(-1)
    all_targets = np.concatenate(targets, axis=0).reshape(-1)

    print("all pred:",all_predictions)
    print("all tar:",all_targets)

    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets.flatten(), all_predictions.flatten())
    overall_spearman, p_value = spearmanr(all_targets, all_predictions)

    print(f"Overall Regression Metrics")
    print(f"MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}, spear: {overall_spearman:.4f}, P-value: {p_value:.4f}")
    
    all_binary_predictions = np.concatenate(binary_predictions, axis=0).reshape(-1)
    all_binary_targets = np.concatenate(binary_targets, axis=0).reshape(-1)
    
    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)
    
    
    print(f"Overall classification Metrics")
    print(f"Acc: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")  

    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }


    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader, nb_epochs,  encode_mode='onehot_meiler', device= 'cuda', save_directory='./weights/'):
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving losses and metrics
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')
    
    # Initialize lists for losses
    train_losses = []
    val_losses = []
    
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
        print(f'Created directory: {save_directory}')

    checkpoint_path = os.path.join(save_directory, 'model_last.pt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        
    else:
        print('No checkpoint found. Starting from beginning.')
    
    model.to(device)


    # Load existing losses if available
    if os.path.exists(loss_output_path):
        with open(loss_output_path, 'r') as json_file:
            existing_losses = json.load(json_file)
            train_losses = existing_losses.get('train_losses', [])
            val_losses = existing_losses.get('val_losses', [])
            print(f'Loaded losses from {loss_output_path}.')
            print(train_losses)
            print(val_losses)

    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")
        
        train_loss = train_epoch(model, optimizer, train_dataloader, encode_mode ,device, printEvery=1000)
        train_losses.append(train_loss)
        
        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,encode_mode, device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            best_model_save_path = os.path.join(save_directory, 'model_best.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)
        
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= 5:
                print("\n==> Early stopping triggered. No improvement in validation loss for 3 epochs.")
                break

        last_model_save_path = os.path.join(save_directory, 'model_last.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')

        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')
        
        print("==================================================================================\n")
    
        
        
    return

## train here

In [13]:
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)

train_loop(model,optimizer,train_dataloader,valid_dataloader, 50, config['encode_mode'],device,path)

Loaded checkpoint from ./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_last.pt. Resuming from epoch 20
Loaded losses from ./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/losses.json.
[0.4576486242406162, 0.4007444068973447, 0.3910627243695436, 0.384471039345235, 0.3798428854824584, 0.37673885145305114, 0.3735851023668124, 0.37019821670320296, 0.36761892320197304, 0.36561967745239354, 0.36342159575886196, 0.3608956108858556, 0.35994238323635525, 0.35772130886713666, 0.35626638303568336, 0.3550759531833507, 0.3538394382706395, 0.3518247464556753, 0.35061384093614273]
[0.3840257633816112, 0.377922537651929, 0.3751929212700237, 0.3652646297758276, 0.36427256194027985, 0.36586505445567047, 0.3608144630085338, 0.3624091094190424, 0.3653991601683877, 0.36567277799953113, 0.3550783856348558, 0.36211435361342, 0.35823058540170843, 0.3521280451254411, 0.35726405815644696, 0.3581189567392

Training:  10%|▉         | 8/81 [00:01<00:18,  4.01batch/s]


KeyboardInterrupt: 

In [20]:
val, metric, preds_val, tar_val = evaluate(model,test_dataloader,config["encode_mode"],device)

100%|██████████| 11/11 [00:01<00:00,  6.18batch/s]


all pred: [-1.8407423  -0.76543653 -1.0598333  ... -2.0413842  -2.2311563
 -1.5531867 ]
all tar: [-1.9357 -0.9904 -1.2494 ... -2.2551 -2.4902 -1.4156]
Overall Regression Metrics
MSE: 0.1164, RMSE: 0.3411, MAE: 0.2352, R2: 0.8352, PCC: 0.9173, spear: 0.9017, P-value: 0.0000
Overall classification Metrics
Acc: 0.9468, Precision: 0.5904, Recall: 0.8951, F1-Score: 0.7115, AUC-ROC: 0.9598, AUC-PR: 0.8160, MCC: 0.7015


In [21]:
preds_val[0]

array([-1.8407423 , -0.76543653, -1.0598333 , ..., -2.0461276 ,
       -2.2356448 , -1.559216  ], dtype=float32)

In [22]:
tar_val[0]

array([-1.9357, -0.9904, -1.2494, ..., -1.773 , -1.8248, -1.2426],
      dtype=float32)

## test

In [23]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Loads the model and optimizer state from a checkpoint if it exists.
    
    Args:
    - model (torch.nn.Module): The model to load the state into.
    - optimizer (torch.optim.Optimizer): The optimizer to load the state into.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device to which the model should be moved.
    
    Returns:
    - start_epoch (int): The epoch to start training from.
    - best_validation_loss (float): The best validation loss recorded in the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        # print(f'Best validation loss: {best_validation_loss}')
    else:
        start_epoch = 0
        best_validation_loss = float('inf')  # Assuming lower is better for validation loss
        print('No checkpoint found.')
    
    model = model.to(device)
    return start_epoch, best_validation_loss



In [24]:
# # Define the configuration dictionary with all the model parameters
# path = "./weights/seq/(onehot)_(regloss)_(global_1layer256_4head)/"

# with open(path+'config.json', 'r') as json_file:
#     config = json.load(json_file)

# # ----------------
# #  MODEL 
# # ----------------
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# List of model paths
model_paths = [
    "./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    
]

for path in model_paths:
    # Load the config for the current model
    with open(path + 'config.json', 'r') as json_file:
        config = json.load(json_file)

    # Initialize the model
    model = Aggrepred(config)
    model = model.to(device=device)

    # Load the model weights from the checkpoint
    _, _ = load_model_from_checkpoint(model, optimizer, path + 'model_best.pt', device)

    # Evaluate the model
    loss, metrics, preds, tar = evaluate(model, test_dataloader, config['encode_mode'] ,device)

    # Save metrics of the best model
    with open(path + 'result.json', 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {path}\n")


cuda
Loaded checkpoint from ./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 15


100%|██████████| 11/11 [00:01<00:00,  6.38batch/s]


all pred: [-2.056324   -0.97252727 -1.254567   ...  0.47811094  1.0779473
  0.16233161]
all tar: [-2.1062 -1.1456 -1.4378 ...  0.1882  0.4603  0.0055]
Overall Regression Metrics
MSE: 0.1166, RMSE: 0.3414, MAE: 0.2384, R2: 0.8349, PCC: 0.9179, spear: 0.9017, P-value: 0.0000
Overall classification Metrics
Acc: 0.9511, Precision: 0.6168, Recall: 0.8807, F1-Score: 0.7255, AUC-ROC: 0.9579, AUC-PR: 0.8100, MCC: 0.7129
Processed model in path: ./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

Loaded checkpoint from ./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 17


100%|██████████| 11/11 [00:01<00:00,  9.97batch/s]


all pred: [ 0.38212374 -1.3036753  -0.10528025 ... -2.7338097  -2.9358606
 -2.9886072 ]
all tar: [ 0.0792 -0.8966  0.     ... -2.6129 -2.8245 -2.8537]
Overall Regression Metrics
MSE: 0.1263, RMSE: 0.3554, MAE: 0.2367, R2: 0.8211, PCC: 0.9157, spear: 0.9043, P-value: 0.0000
Overall classification Metrics
Acc: 0.9502, Precision: 0.6081, Recall: 0.9006, F1-Score: 0.7260, AUC-ROC: 0.9608, AUC-PR: 0.8232, MCC: 0.7161
Processed model in path: ./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

Loaded checkpoint from ./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 23


100%|██████████| 11/11 [00:11<00:00,  1.03s/batch]


all pred: [-2.3063288  -1.0799348  -1.8330145  ... -1.9021101  -2.154097
 -0.85817593]
all tar: [-2.2782  0.     -2.1869 ... -1.9211 -2.1215 -0.6413]
Overall Regression Metrics
MSE: 0.1220, RMSE: 0.3492, MAE: 0.2421, R2: 0.8273, PCC: 0.9154, spear: 0.9023, P-value: 0.0000
Overall classification Metrics
Acc: 0.9498, Precision: 0.6071, Recall: 0.8921, F1-Score: 0.7225, AUC-ROC: 0.9587, AUC-PR: 0.8050, MCC: 0.7116
Processed model in path: ./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

Loaded checkpoint from ./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 22


100%|██████████| 11/11 [00:47<00:00,  4.29s/batch]


all pred: [ 0.6604782  -0.787432   -0.04835248 ... -2.1069932  -2.4169354
 -1.2268027 ]
all tar: [-1.9421 -0.7348 -0.958  ... -2.8083 -2.4804 -2.716 ]
Overall Regression Metrics
MSE: 0.7229, RMSE: 0.8502, MAE: 0.6242, R2: -0.0237, PCC: 0.5036, spear: 0.4408, P-value: 0.0000
Overall classification Metrics
Acc: 0.8910, Precision: 0.3356, Recall: 0.4968, F1-Score: 0.4006, AUC-ROC: 0.8175, AUC-PR: 0.3277, MCC: 0.3511
Processed model in path: ./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

Loaded checkpoint from ./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 19


100%|██████████| 11/11 [00:02<00:00,  5.46batch/s]

all pred: [-1.3935404  -0.87559426 -1.4437468  ... -2.500895   -2.3823454
 -1.0091255 ]
all tar: [-1.7325 -1.6625 -2.2042 ... -2.7051 -2.5862 -1.4152]
Overall Regression Metrics
MSE: 0.1240, RMSE: 0.3522, MAE: 0.2411, R2: 0.8243, PCC: 0.9117, spear: 0.8982, P-value: 0.0000
Overall classification Metrics
Acc: 0.9540, Precision: 0.6393, Recall: 0.8563, F1-Score: 0.7321, AUC-ROC: 0.9466, AUC-PR: 0.7921, MCC: 0.7165
Processed model in path: ./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/




