In [18]:
import os
import sys
import random
import json
import time
import logging
from typing import Optional

import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
from scipy.stats import pearsonr
from sklearn.metrics import (
    mean_squared_error, 
    mean_absolute_error, 
    r2_score, 
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score, 
    roc_auc_score, 
    average_precision_score, 
    matthews_corrcoef
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import esm

# Custom imports from aggrepred package
top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)
from aggrepred.dataset import *
from aggrepred.model import *
from aggrepred.utils import *

# Seed everything function
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)


# Dataset class



In [8]:
class SeqDataset:
    def __init__(self, df, max_seq_len=1000):
        self.data = df.copy()
        
        self.data['scores'] = self.data['scores'].apply(ast.literal_eval)
        
        def count_pos_neg_values(lst):
            count_pos = sum(1 for x in lst if x > 0)
            count_neg = sum(1 for x in lst if x <= 0)
            return count_pos, count_neg

        # Apply the function to create new columns
        self.data[['count_positive', 'count_negative']] = self.data['scores'].apply(count_pos_neg_values).apply(pd.Series)
        self.data['len'] = self.data['scores'].apply(lambda x: len(x))
        
        self.data['neg_to_pos_ratio'] = self.data['count_negative'] / self.data['count_positive']
        self.max_seq_len = max_seq_len
 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
   
        if idx < 0 or idx >= len(self.data):
            raise IndexError("Index out of range")
        
        row = self.data.iloc[idx]
        code = row['ID']
        seq = row['sequence']
        scores = row['scores']
        
        y  = scores[:self.max_seq_len] + [0] * (self.max_seq_len - len(scores))
        y = torch.tensor(y)

        y_bin =  (y>0).int()

        # Generate binary mask based on sequence length (1 for actual values, 0 for padding)
        mask = torch.zeros(self.max_seq_len, dtype=torch.bool)
        mask[:len(scores)] = True  # Set the first 'len(Hchain_scores)' to 1

        return {
            'code': code,
            'seq': seq,
            'target_reg': y,
            'target_bin': y_bin,
            'mask': mask
        }

# Trainer 

## config

### to train different model, adjust the path and config here

In [9]:
# ----------------
# PARAM
# ----------------

## path here is set in format:  type_embed . type_loss . local block info . global block info .  for easily identified each model 


# Define the configuration dictionary with all the model parameters
# path = "./weights/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(esm35M)_(combinedloss)_()/"
# path = "./weights/seq/(esm)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(protbert)_(combinedloss)_(none)/"
# path = "./weights/seq/()_(combinedloss)_(none)/"
path = "./weights/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"


config = {
    "antibody": False,
    # "use_local": False,
    "use_local": True,
    # "use_global": False,
    "use_global": True,
    "num_localextractor_block": 3,
    "input_dim": 1000,
    "output_dim": 1000,
    "in_channel": 28,
    "out_channel": 256,
    "kernel_size": 23,
    "dilation": 1,
    "stride": 1,
    "rnn_hid_dim": 128,
    "rnn_layers": 1,
    "bidirectional": True,
    "rnn_dropout": 0.2,
    "attention_heads": 4,
    "learning_rate": 1e-3,
    "batch_size": 32,
    "nb_epochs": 20,
    "encode_mode" : 'onehot_meiler'
}

# with open(path+'config.json', 'r') as json_file:
#     config = json.load(json_file)


# ----------------
#  MODEL 
# ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = Aggrepred(config)
model = model.to(device=device)


cuda


In [10]:

# ----------------
#   OPTIMIZER 
# ----------------
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
# optimizer = nn.optim.AdamW(model.parameters(), lr=learning_rate,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


# ----------------
# LOSS
# ----------------

class CombinedLoss(nn.Module):
    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super(CombinedLoss, self).__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        self.mse_loss = nn.MSELoss()  # Regression Loss (MSE)
        
        if pos_weight is not None:
            # Binary Classification Loss (Weighted BCE with logits)
            self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        else:
            self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, outputs, regression_targets):
        # Calculate regression loss
        reg_loss = self.mse_loss(outputs, regression_targets)
        
        # Calculate binary classification loss
        # Convert regression output to binary labels (logits) for classification
        binary_targets = (regression_targets> 0).float()
        bin_loss = self.bce_loss(outputs, binary_targets)
        
        # Combined weighted loss
        total_loss = self.lambda_reg * reg_loss + self.lambda_bin * bin_loss
        return total_loss

mse_loss  = nn.MSELoss()
pos_class_weights = torch.Tensor([4.0]).to(device)
weighted_bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_class_weights)


combined_loss = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=4.0)


# ----------------
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable_params, non_trainable_params

trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

Number of trainable parameters: 3778561
Number of non-trainable parameters: 0


In [11]:
# ----------------
# DATA
# ----------------

def custom_collate(batch):
    regs_tensor = [item['target_reg'] for item in batch]
    mask = [item['mask'] for item in batch]
    max_len = regs_tensor[0].size()[0]  #1000
    
    orig_lens = [item['mask'].sum() for item in batch]
    max_orig_len = min(max(orig_lens), max_len)  # Ensure max_orig_len is at most max_len
    
    # print(max_orig_len)
    # truncated_encoded_seqs = [item['encoded_seq'][:max_orig_len,:] for item in batch]
    codes = [item['code'] for item in batch]
    seqs = [item['seq'] for item in batch]
    truncated_regs_tensor = [item['target_reg'][ :max_orig_len] for item in batch]
    truncated_bins_tensor = [item['target_bin'][:max_orig_len] for item in batch]
    truncated_mask_tensor = [item['mask'][ :max_orig_len] for item in batch]
    
    # encoded_seqs_tensor = torch.stack(truncated_encoded_seqs)
    target_regs_tensor = torch.stack(truncated_regs_tensor)
    target_bins_tensor = torch.stack(truncated_bins_tensor)
    mask_tensor = torch.stack(truncated_mask_tensor)

    return {
        'code': codes,
        'seq': seqs,
        'target_reg': target_regs_tensor,
        'target_bin': target_bins_tensor,
        'mask': mask_tensor
    }



## Dataloader

In [17]:

#########################################################################
#########################################################################
df = pd.read_csv("../data/csv/data60_fixed_split.csv")


# ## smaple down abit for esm
if config['encode_mode'] not in ['onehot', 'onehot_meiler']:
    print("yes")
    train_dataset = SeqDataset(df[df.split=='train'].sample(frac=0.10, random_state=42),1000)
    valid_dataset = SeqDataset(df[df.split=='valid'].sample(frac=0.10, random_state=42),1000)
    test_dataset = SeqDataset(df[df.split=='test'],1000)

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
    valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
else:
    train_dataset = SeqDataset(df[df.split=='train'],1000)
    valid_dataset = SeqDataset(df[df.split=='valid'],1000)
    test_dataset = SeqDataset(df[df.split=='test'],1000)

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)


##collate to flexible max len in batch
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)

### find the propostion of pos/neg class

#### there are about 80% of negative class vs 20% of positive class  , hence 4:1 ratio

In [13]:
sum_one = train_dataset.data['count_positive'].sum()
sum_zero = train_dataset.data['count_negative'].sum()
total = train_dataset.data['len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = valid_dataset.data['count_positive'].sum()
sum_zero = valid_dataset.data['count_negative'].sum()
total = valid_dataset.data['len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

sum_one = test_dataset.data['count_positive'].sum()
sum_zero = test_dataset.data['count_negative'].sum()
total = test_dataset.data['len'].sum() 

print("propotion of position and negative class: " , sum_one/total, sum_zero/total)
print("ratio of position to negative class: " , sum_zero/sum_one)

propotion of position and negative class:  0.19109947183586853 0.8089005281641315
ratio of position to negative class:  4.232876838397961
propotion of position and negative class:  0.19214283922078573 0.8078571607792142
ratio of position to negative class:  4.2044614519874415
propotion of position and negative class:  0.1936254280065119 0.8063745719934882
ratio of position to negative class:  4.164610920660527


In [14]:
print(model)

Aggrepred(
  (local_extractors): ModuleList(
    (0): LocalExtractorBlock(
      (conv): Conv1d(28, 256, kernel_size=(23,), stride=(1,), padding=(11,))
      (BN): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leakyrelu): LeakyReLU(negative_slope=0.01)
      (relu): ReLU()
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (1-2): 2 x LocalExtractorBlock(
      (conv): Conv1d(256, 256, kernel_size=(23,), stride=(1,), padding=(11,))
      (BN): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leakyrelu): LeakyReLU(negative_slope=0.01)
      (relu): ReLU()
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (residue_map): Linear(in_features=28, out_features=256, bias=True)
  (global_extractor): GlobalInformationExtractor(
    (att_bilstm): Att_BiLSTM(
      (lstm): LSTM(28, 128, batch_first=True, bidirectional=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantiza

In [21]:

def format_time(seconds):
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes}m {seconds} s" if minutes>0 else f"{seconds} s"

def train_epoch(model, optimizer, dataloader,loss_function, encode_mode='onehot_meiler', device = 'cuda', printEvery=100):
    
    model.train()
    total_loss = 0.0
    count_iter = 0
    start_time = time.time()
    epoch_start_time = start_time
    batch_size = dataloader.batch_size
    printEvery = printEvery // batch_size if batch_size else 100  # Adjust printEvery based on batch size

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
    protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')

    with tqdm(total=len(dataloader), desc='Training', unit='batch') as pbar:
        for idx, batch in enumerate(dataloader):
                 
            batch_sequences = batch['seq']
            mask = batch['mask'].to(device)

            
            ## different encoding here
            if encode_mode == 'esm':
                x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
                x   = F.pad(x, (0, 0, 0, max(1000 - x.size(1), 0)))
            elif encode_mode == 'protbert':
                x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
                x   = F.pad(x, (0, 0, 0, max(1000 - x.size(1), 0)))
            elif encode_mode == 'onehot':
                x = onehot_encode_batch(batch_sequences,1000).to(device)
            else:
                x = onehot_meiler_encode_batch(batch_sequences,1000).to(device)


            ## convert (bsz,max_len) to  (bsz,max_len,1)
            y_reg = batch['target_reg'].unsqueeze(2).float().to(device)
            y_bin = batch['target_bin'].unsqueeze(2).float().to(device)

            
            y_reg = clean_output_batch(y_reg, mask)
            y_bin = clean_output_batch(y_bin, mask)

            ### predict 
            final_info, output_reg = model(x, mask)
            
            #trim out the padded part
            output_reg = clean_output_batch(output_reg, mask)
            # print(orig_len.sum())
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

            # current_loss = reg_loss 
            current_loss = loss_function(output_reg, y_reg)

            # Backpropagation
            optimizer.zero_grad()
            current_loss.backward()

            optimizer.step()
            total_loss += current_loss.item()
            
            printEvery = int(1000/x.size(0))
            count_iter += 1
            if count_iter % printEvery == 0 or idx == len(dataloader) - 1:
                elapsed_time = time.time() - start_time
                remaining_time = (elapsed_time / count_iter) * (len(dataloader) - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")
                start_time = time.time()
            torch.cuda.empty_cache()
            pbar.update(1)

    epoch_time = time.time() - epoch_start_time
    print(f"==> Average Training loss: mse ={total_loss / len(dataloader)}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")
    return total_loss / len(dataloader)


def evaluate(model, dataloader,loss_function, encode_mode='onehot_meiler', device= 'cuda', mode='valid'):
    model.eval()
    total_loss = 0.0
    
    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []
    orig_lens = []

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')



    with torch.no_grad():
        for idx, batch in enumerate(dataloader):

            batch_sequences = batch['seq']
            mask = batch['mask'].to(device)

            ## different encoding here
            if encode_mode == 'esm':
                x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
                x   = F.pad(x, (0, 0, 0, max(1000 - x.size(1), 0)))
            elif encode_mode == 'protbert':
                x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
                x   = F.pad(x, (0, 0, 0, max(1000 - x.size(1), 0)))
            elif encode_mode == 'onehot':
                x = onehot_encode_batch(batch_sequences,1000).to(device)
            else:
                x = onehot_meiler_encode_batch(batch_sequences,1000).to(device)

            ## convert (bsz,max_len) to  (bsz,max_len,1)
            y_reg = batch['target_reg'].unsqueeze(2).float().to(device)
            y_bin = batch['target_bin'].unsqueeze(2).float().to(device)
            
            y_reg = clean_output_batch(y_reg, mask)
            y_bin = clean_output_batch(y_bin, mask)

            ### predict 
            final_info, output_reg = model(x, mask)
            
            #trim out the padded part
            output_reg = clean_output_batch(output_reg, mask)
            # print(orig_len.sum())
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

            # current_loss = reg_loss 
            current_loss = loss_function(output_reg, y_reg)
     
            total_loss += current_loss.item()

            #append to list of all preds
            predictions.append(output_reg.cpu().numpy())
            targets.append(y_reg.cpu().numpy())

            y_bin = (y_reg.cpu().numpy() > 0).astype(int)
            out_bin = (output_reg.cpu().numpy() > 0).astype(int)

            binary_predictions.append(out_bin)
            binary_targets.append(y_bin)

    # if mode == 'test':
    all_predictions = np.concatenate(predictions, axis=0).reshape(-1)
    all_targets = np.concatenate(targets, axis=0).reshape(-1)
    all_binary_predictions = np.concatenate(binary_predictions, axis=0).reshape(-1)
    all_binary_targets = np.concatenate(binary_targets, axis=0).reshape(-1)

    # Calculate overall metrics
    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets.flatten(), all_predictions.flatten())

    # Calculate binary classification metrics
    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)

    print(f"Overall Regression Metrics")
    print(f"MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}")

    print(f"Overall classification Metrics")
    print(f"Acc: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")  
    
    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }


    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader,loss_function, nb_epochs, encode_mode='onehot_meiler', device= 'cuda', save_directory='./weights/'):
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving losses and metrics
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')
    
    # Initialize lists for losses
    train_losses = []
    val_losses = []
    
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
        print(f'Created directory: {save_directory}')

    checkpoint_path = os.path.join(save_directory, 'model_last.pt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        
        # Load losses from the losses.json file if it exists
        if os.path.exists(loss_output_path):
            with open(loss_output_path, 'r') as f:
                losses = json.load(f)
                train_losses = losses.get('train_losses', [])
                val_losses = losses.get('val_losses', [])
            print(f'Loaded losses from {loss_output_path}.')
            print(train_losses)
            print(val_losses)
        else:
            print(f'No losses file found at {loss_output_path}.')

    else:
        print('No checkpoint found. Starting from beginning.')
    
    model.to(device)


    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")
        
        train_loss = train_epoch(model, optimizer, train_dataloader,loss_function, encode_mode ,device, printEvery=1000)
        train_losses.append(train_loss)

        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,loss_function,encode_mode, device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            best_model_save_path = os.path.join(save_directory, 'model_best.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)
        
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= 5:
                print("\n==> Early stopping triggered. No improvement in validation loss for 3 epochs.")
                break

        last_model_save_path = os.path.join(save_directory, 'model_last.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')

        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')
        
        print("==================================================================================\n")
    
        
        
    return

# train here

In [22]:
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)


loss_function = combined_loss
# loss_function = mse_loss

train_loop(model,optimizer,train_dataloader,valid_dataloader,loss_function, 50, config['encode_mode'],device,path)

Loaded checkpoint from ./weights/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/model_last.pt. Resuming from epoch 16
Loaded losses from ./weights/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/losses.json.
[0.6688351124676865, 0.5371009643697172, 0.5063865821814901, 0.4925348173276678, 0.483561003552228, 0.47253315967898213, 0.42953249032914337, 0.4248547785755734, 0.42157351925134257]
[0.6093979679249428, 0.5327651400018383, 0.5141573870504225, 0.5462429934256786, 0.4919483875100677, 0.5662333429665178, 0.42681065403126384, 0.4370848157921353, 0.43982857908751516]
                            -----EPOCH 16-----


Training:   5%|▌         | 31/589 [00:07<02:14,  4.14batch/s]

Iteration: 31, Time: 9 s, Remaining: 2m 45 s, Training Loss: 0.5000


Training:  11%|█         | 62/589 [00:15<02:08,  4.11batch/s]

Iteration: 62, Time: 7 s, Remaining: 1m 4 s, Training Loss: 0.4977


Training:  16%|█▌        | 93/589 [00:22<01:57,  4.21batch/s]

Iteration: 93, Time: 7 s, Remaining: 38 s, Training Loss: 0.4953


Training:  21%|██        | 124/589 [00:29<01:49,  4.24batch/s]

Iteration: 124, Time: 7 s, Remaining: 27 s, Training Loss: 0.4951


Training:  26%|██▋       | 155/589 [00:37<01:44,  4.16batch/s]

Iteration: 155, Time: 7 s, Remaining: 20 s, Training Loss: 0.4947


Training:  28%|██▊       | 162/589 [00:38<01:42,  4.17batch/s]


KeyboardInterrupt: 

## test on test set and return a result in json format for all the trained model above



In [23]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Loads the model and optimizer state from a checkpoint if it exists.
    
    Args:
    - model (torch.nn.Module): The model to load the state into.
    - optimizer (torch.optim.Optimizer): The optimizer to load the state into.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device to which the model should be moved.
    
    Returns:
    - start_epoch (int): The epoch to start training from.
    - best_validation_loss (float): The best validation loss recorded in the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        # print(f'Best validation loss: {best_validation_loss}')
    else:
        start_epoch = 0
        best_validation_loss = float('inf')  # Assuming lower is better for validation loss
        print('No checkpoint found.')
    
    model = model.to(device)
    return start_epoch, best_validation_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [25]:

# List of model paths
model_paths = [
    
    "./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
   
    # "./weights/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights/seq/(esm35M)_(combinedloss)_(none)/",

    # "./weights/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights/seq/(protbert)_(combinedloss)_(none)/"
    
    
]

for path in model_paths:
    # Load the config for the current model
    with open(path + 'config.json', 'r') as json_file:
        config = json.load(json_file)

    # Initialize the model
    model = Aggrepred(config)
    model = model.to(device=device)

    # Load the model weights from the checkpoint
    _, _ = load_model_from_checkpoint(model, optimizer, path + 'model_best.pt', device)

    # Evaluate the model
    loss, metrics, preds, tar = evaluate(model, test_dataloader,combined_loss, config['encode_mode'] ,device)

    # Save metrics of the best model
    with open(path + 'result.json', 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {path}\n")


Loaded checkpoint from ./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 14




Overall Regression Metrics
MSE: 0.5068, RMSE: 0.7119, MAE: 0.5218, R2: 0.7615, PCC: 0.8738
Overall classification Metrics
Acc: 0.8670, Precision: 0.6567, Recall: 0.6564, F1-Score: 0.6566, AUC-ROC: 0.9064, AUC-PR: 0.7137, MCC: 0.5741
Processed model in path: ./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

Loaded checkpoint from ./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 19




Overall Regression Metrics
MSE: 0.4048, RMSE: 0.6362, MAE: 0.4620, R2: 0.8095, PCC: 0.9002
Overall classification Metrics
Acc: 0.8696, Precision: 0.6309, Recall: 0.7863, F1-Score: 0.7001, AUC-ROC: 0.9276, AUC-PR: 0.7790, MCC: 0.6242
Processed model in path: ./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/

