In [11]:
import torch
import pandas as pd
import sys
import os
import numpy as np
import random
import pickle
from tqdm import tqdm
import torch
import esm
import numpy as np
import os

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import json

top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)

from aggrepred.dataset import *
from aggrepred.model import *
from aggrepred.utils import *


def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and every visible CUDA device), and switches cuDNN to deterministic
    kernel selection.

    :param seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects hashing in interpreters started after this point
    # (e.g. DataLoader worker processes), not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed the non-current GPUs so multi-GPU runs are reproducible.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner may pick different algorithms between runs; keep it off.
    torch.backends.cudnn.benchmark = False


seed_everything(seed=42)

import torch
import torch.nn as nn

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from typing import Optional

import os
import torch
import time
import logging
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score,accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, matthews_corrcoef
from scipy.stats import pearsonr


# Dataset class



In [12]:
class SeqDataset:
    """Map-style dataset of sequences with per-position scores.

    Each row of the input frame must provide the columns ``ID``, ``sequence``
    and ``scores``. ``scores`` may be a Python-literal string (as read from a
    CSV, e.g. ``"[0.1, -0.3]"``) or an already-parsed list/tuple of numbers.

    On construction the following helper columns are added to ``self.data``:
    ``count_positive`` (scores > 0), ``count_negative`` (scores <= 0),
    ``len`` (number of scores) and ``neg_to_pos_ratio``.
    """

    def __init__(self, df, max_seq_len=1000):
        """
        :param df: pandas DataFrame with ``ID``, ``sequence`` and ``scores`` columns.
        :param max_seq_len: length every target / mask tensor is padded or truncated to.
        """
        # Local import: the notebook's import cell does not import `ast`.
        import ast

        self.data = df.copy()

        def parse_scores(value):
            # Strings come from CSV storage; anything else is assumed to be
            # an iterable of numbers that was parsed upstream.
            if isinstance(value, str):
                value = ast.literal_eval(value)
            return list(value)

        self.data['scores'] = self.data['scores'].apply(parse_scores)

        # Per-sequence class counts (positive means score > 0).
        self.data['count_positive'] = self.data['scores'].apply(
            lambda lst: sum(1 for x in lst if x > 0))
        self.data['count_negative'] = self.data['scores'].apply(
            lambda lst: sum(1 for x in lst if x <= 0))
        self.data['len'] = self.data['scores'].apply(len)

        # Yields inf / NaN for sequences without any positive position.
        self.data['neg_to_pos_ratio'] = self.data['count_negative'] / self.data['count_positive']
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``code`` (row ID), ``seq`` (raw sequence, not truncated),
        ``target_reg`` (scores padded with 0 / truncated to ``max_seq_len``),
        ``target_bin`` (1 where ``target_reg > 0``), ``mask`` (True on real
        positions, False on padding).

        :raises IndexError: if ``idx`` is negative or past the end.
        """
        if idx < 0 or idx >= len(self.data):
            raise IndexError("Index out of range")

        row = self.data.iloc[idx]
        code = row['ID']
        seq = row['sequence']
        scores = row['scores']

        # A negative pad count yields an empty list, so over-long sequences
        # are simply truncated by the slice.
        y = scores[:self.max_seq_len] + [0] * (self.max_seq_len - len(scores))
        y = torch.tensor(y)

        y_bin = (y > 0).int()

        # True for the first len(scores) positions (capped at max_seq_len by
        # the tensor size), False for padding.
        mask = torch.zeros(self.max_seq_len, dtype=torch.bool)
        mask[:len(scores)] = True

        return {
            'code': code,
            'seq': seq,
            'target_reg': y,
            'target_bin': y_bin,
            'mask': mask
        }

# model


In [13]:

# Maximum sequence length; also the default input/output length of LocalExtractorBlock.
SEQ_MAX_LEN = 1000

# Default number of input channels: 21 amino acids + 7 Meiler features.
INPUT_FEATURES = 28

# Default convolution kernel size, as per Parapred.
KERNEL_SIZE = 11

# Hidden output channels of the CNN.
# NOTE(review): not referenced anywhere in the visible part of this notebook.
HIDDEN_CHANNELS = 256



def generate_mask(input_tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-position validity mask across the channel axis of a conv tensor.

    :param input_tensor: tensor whose shape is used as the target (batch_size x features x max_seqlen)
    :param masks: mask (batch_size x max_seqlen) marking the valid positions
    :return: view of ``masks`` expanded to (batch_size x features x max_seqlen),
             placed on the same device as ``input_tensor``
    """
    n_batch, n_channels, seq_len = input_tensor.shape

    # Insert a singleton channel axis, then expand it (no copy) over all channels.
    per_position = masks[:, None, :]
    broadcast = per_position.expand(n_batch, n_channels, seq_len)

    return broadcast.to(device=input_tensor.device)

class LocalExtractorBlock(nn.Module):
    """Masked 1D convolution block: Conv1d -> BatchNorm1d -> LeakyReLU -> Dropout.

    The padding is chosen so that a sequence of length ``input_dim`` comes out
    with length ``output_dim``; positions flagged as padding by the binary
    mask are zeroed in the output.
    """

    def __init__(self,
                 input_dim: int = SEQ_MAX_LEN,
                 output_dim: int = SEQ_MAX_LEN,
                 in_channel: int = INPUT_FEATURES,
                 out_channel: Optional[int] = None,
                 kernel_size: int = KERNEL_SIZE,
                 dilation: int = 1,
                 stride: int = 1):
        """
        :param input_dim: sequence length assumed for the padding computation.
        :param output_dim: desired output sequence length (``None`` -> ``input_dim``).
        :param in_channel: number of input channels.
        :param out_channel: number of output channels (``None`` -> ``in_channel``).
        :param kernel_size: convolution kernel size.
        :param dilation: convolution dilation, must be >= 1.
        :param stride: convolution stride, must be >= 1.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = input_dim if output_dim is None else output_dim

        self.in_channels = in_channel
        self.out_channel = in_channel if out_channel is None else out_channel

        # Determine the padding required for keeping the requested sequence length
        assert dilation >= 1 and stride >= 1, "Dilation and stride must be >= 1."
        self.dilation, self.stride = dilation, stride
        self.kernel_size = kernel_size

        padding = self.determine_padding(self.input_dim, self.output_dim)

        # Use the resolved channel count (out_channel may legitimately be None)
        # and hand stride/dilation to the convolution: the padding above was
        # derived for exactly these values.
        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channel,
            self.kernel_size,
            stride=self.stride,
            padding=padding,
            dilation=self.dilation)

        self.BN = nn.BatchNorm1d(self.out_channel)
        self.leakyrelu = nn.LeakyReLU()
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_tensor: torch.Tensor, binary_mask) -> torch.Tensor:
        """
        Forward pass of the LocalExtractorBlock

        :param input_tensor: an input tensor of (bsz x features x seqlen)
        :param binary_mask: a mask of (bsz x seqlen); it is expanded over the
                            channel axis by ``generate_mask``
        :return: tensor of (bsz x out_channel x seqlen) with masked-out positions set to zero
        """
        o = self.conv(input_tensor)
        o = self.BN(o)
        o = self.leakyrelu(o)
        o = self.dropout(o)

        # mask to zero-out values beyond the sequence length
        mask = generate_mask(o, binary_mask)

        return o * mask

    def determine_padding(self, input_shape: int, output_shape: int) -> int:
        """
        Determine the padding required to obtain ``output_shape`` positions from
        ``input_shape`` positions with the block's kernel size, stride and dilation.

        formula :  L_out = ((L_in + 2 x padding - dilation x (kernel_size - 1) - 1)/stride + 1)

        The total padding is halved with integer division, so when it is odd
        (e.g. an even kernel size with stride 1) the realised output length is
        one short of ``output_shape``.

        :return: An integer defining the amount of padding applied on each side
        """
        padding = (((output_shape - 1) * self.stride) + 1 - input_shape + (self.dilation * (self.kernel_size - 1))) // 2

        # Ensure padding is non-negative
        assert padding >= 0, f"Padding must be non-negative but got {padding}."
        return padding
    

In [14]:

def generate_attn_mask(batch_size, num_heads, max_length, masks):
    """
    Generate an attention mask from a provided binary mask.

    Entry ``[b * num_heads + h, i, j]`` is True when positions ``i`` and ``j``
    of sample ``b`` are both valid, or when ``i == j``. The diagonal is always
    True so that every query row, including padded ones, has at least one
    attendable key.

    The mask is built with vectorised tensor ops on the device of ``masks``
    rather than one sample at a time on the CPU.

    :param batch_size: int, size of the batch (must equal ``masks.shape[0]``).
    :param num_heads: int, number of attention heads.
    :param max_length: int, maximum sequence length (must equal ``masks.shape[1]``).
    :param masks: a binary mask (batch_size x max_length) indicating the valid positions.
    :return: boolean mask for multi-head attention (batch_size * num_heads x max_length x max_length)
    """
    valid = masks.to(torch.bool)

    # Pairwise validity: outer product of each sample's mask with itself.
    attn_mask = valid.unsqueeze(2) & valid.unsqueeze(1)

    # Always allow self-attention on the diagonal.
    diagonal = torch.eye(max_length, dtype=torch.bool, device=valid.device)
    attn_mask = attn_mask | diagonal

    # Repeat for every head (batch-major ordering) and merge batch and head dimensions.
    attn_mask = attn_mask.unsqueeze(1).expand(-1, num_heads, -1, -1)
    attn_mask = attn_mask.reshape(batch_size * num_heads, max_length, max_length)

    return attn_mask

class Att_BiLSTM(nn.Module):
    """(Bi)LSTM followed by masked multi-head self-attention over its outputs."""

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True, rnn_dropout=0.2, num_heads=1):
        """
        :param input_size: feature size of each input position.
        :param hidden_size: LSTM hidden size per direction.
        :param num_layers: number of stacked LSTM layers.
        :param bidirectional: whether the LSTM runs in both directions.
        :param rnn_dropout: inter-layer LSTM dropout (only applied when ``num_layers > 1``).
        :param num_heads: number of attention heads.
        """
        super(Att_BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_heads = num_heads

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                            bidirectional=bidirectional, batch_first=True, dropout=rnn_dropout if num_layers > 1 else 0)
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size * 2 if bidirectional else hidden_size, num_heads=num_heads, batch_first=True)

    def forward(self, x, banary_mask):
        """
        Forward pass through BiLSTM with Multi-Head Attention

        :param x: input of (batch x seqlen x input_size).
        :param banary_mask: validity mask of (batch x seqlen); the parameter name
                            is kept as-is so existing callers are unaffected.
        :return: ``(attn_output, (hn, cn))`` - the attention output and the final LSTM states.
        """
        # Start from the LSTM's default zero state. Drawing h0/c0 from
        # torch.randn on every call made the output non-deterministic, even
        # in eval() mode, for identical inputs.
        # NOTE(review): padded positions are still fed to the LSTM; only the
        # attention step below is masked.
        output, (hn, cn) = self.lstm(x)

        # Apply MultiHeadAttention; nn.MultiheadAttention treats True as "do not attend",
        # hence the inversion of the validity mask.
        mask = generate_attn_mask(x.size(0), self.num_heads, x.size(1), banary_mask).to(device=output.device)
        attn_output, attn_weight = self.attention(output, output, output, attn_mask=~mask)

        return attn_output, (hn, cn)
    
  
    
class GlobalInformationExtractor(nn.Module):
    """Wraps ``Att_BiLSTM`` and post-processes its output with LeakyReLU(0.1) and Dropout(0.2)."""

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True, rnn_dropout=0.2, num_heads=1):
        super().__init__()
        self.att_bilstm = Att_BiLSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional,
            rnn_dropout,
            num_heads,
        )
        # `relu` is registered but forward() only uses the leaky variant.
        self.relu = nn.ReLU()
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, lengths):
        """
        :param x: input forwarded unchanged to ``Att_BiLSTM``.
        :param lengths: second argument forwarded unchanged to ``Att_BiLSTM``
                        (its binary mask argument).
        :return: ``(output, (hn, cn))`` with activation and dropout applied to ``output``.
        """
        attended, (hn, cn) = self.att_bilstm(x, lengths)
        activated = self.dropout(self.leakyrelu(attended))
        return activated, (hn, cn)

In [15]:
class Aggrepred(nn.Module):
    """
    Per-position regression model combining an optional convolutional ("local")
    branch and an optional BiLSTM + attention ("global") branch, followed by a
    three-layer fully connected regression head.
    """
    def __init__(self, config):
        """
        Build the model from a configuration mapping.

        Keys read (with their defaults): ``pooling`` (False), ``use_local`` (True),
        ``use_global`` (True), ``num_localextractor_block`` (3), ``input_dim`` (1000),
        ``output_dim`` (1000), ``in_channel`` (28), ``out_channel`` (None -> ``in_channel``),
        ``kernel_size`` (23), ``dilation`` (1), ``stride`` (1), ``rnn_hid_dim`` (256),
        ``rnn_layers`` (1), ``bidirectional`` (True), ``rnn_dropout`` (0.2),
        ``attention_heads`` (1). Any other key is ignored by this class.

        :param config: mapping exposing ``get(key, default)``.
        """
        super().__init__()

        cfg = config.get

        # Branch switches; both branches are enabled unless the config says otherwise.
        self.pooling = cfg("pooling", False)
        self.use_local = cfg("use_local", True)
        self.use_global = cfg("use_global", True)

        # Local (convolutional) branch settings.
        n_blocks = cfg("num_localextractor_block", 3)
        input_dim = cfg("input_dim", 1000)
        output_dim = cfg("output_dim", 1000)
        in_channel = cfg("in_channel", 28)
        out_channel = cfg("out_channel", None)
        kernel_size = cfg("kernel_size", 23)
        dilation = cfg("dilation", 1)
        stride = cfg("stride", 1)

        # Global (recurrent + attention) branch settings.
        rnn_hid_dim = cfg("rnn_hid_dim", 256)
        rnn_layers = cfg("rnn_layers", 1)
        bidirectional = cfg("bidirectional", True)
        rnn_dropout = cfg("rnn_dropout", 0.2)
        attention_heads = cfg("attention_heads", 1)

        if out_channel is None:
            out_channel = in_channel

        if self.use_local:
            assert n_blocks > 0, "Number of local extractor blocks must be greater than 0."
            blocks = []
            for block_idx in range(n_blocks):
                # Only the first block sees the raw input width.
                block_in = in_channel if block_idx == 0 else out_channel
                blocks.append(LocalExtractorBlock(
                    input_dim=input_dim,
                    output_dim=output_dim,
                    in_channel=block_in,
                    out_channel=out_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    stride=stride,
                ))
            self.local_extractors = nn.ModuleList(blocks)
            # Projects the input to out_channel so it can serve as the first residual.
            self.residue_map = nn.Linear(in_channel, out_channel)

        if self.use_global:
            self.global_extractor = GlobalInformationExtractor(
                input_size=in_channel,
                hidden_size=rnn_hid_dim,
                num_layers=rnn_layers,
                bidirectional=bidirectional,
                rnn_dropout=rnn_dropout,
                num_heads=attention_heads,
            )

        # Width of the global branch output (both directions concatenated).
        global_dim = rnn_hid_dim * 2 if bidirectional else rnn_hid_dim

        # Head input width: sum of the enabled branch widths, or the raw
        # input width when both branches are disabled.
        if self.use_local or self.use_global:
            fc_in_dim = (out_channel if self.use_local else 0) + (global_dim if self.use_global else 0)
        else:
            fc_in_dim = in_channel

        self.reg_layer = nn.Sequential(
            nn.Linear(fc_in_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )

    def forward(self, input_tensor: torch.Tensor, binary_mask) -> torch.Tensor:
        """
        Run the enabled branches and the regression head.

        :param input_tensor: features of (batch x seqlen x in_channel).
        :param binary_mask: validity mask passed on to both branches.
        :return: ``(final_info, reg_output)`` - the features fed to the head and
                 the head output (last dimension 1).
        """
        branch_outputs = []

        #### Local extracted information
        # Each block's output is added to a running residual; the first
        # residual is the linearly projected input.
        if self.use_local:
            skip = self.residue_map(input_tensor).permute(0, 2, 1)
            hidden = input_tensor.permute(0, 2, 1)
            for block in self.local_extractors:
                hidden = block(hidden, binary_mask) + skip
                skip = hidden
            branch_outputs.append(hidden.permute(0, 2, 1))

        #### Global extracted information
        if self.use_global:
            global_extracted_info, (hn, cn) = self.global_extractor(input_tensor, binary_mask)
            branch_outputs.append(global_extracted_info)

        # Local features come first when both branches are active.
        if len(branch_outputs) == 2:
            final_info = torch.cat((branch_outputs[0], branch_outputs[1]), dim=-1)
        elif len(branch_outputs) == 1:
            final_info = branch_outputs[0]
        else:
            final_info = input_tensor

        reg_output = self.reg_layer(final_info)

        return final_info, reg_output
  

def clean_output(output_tensor: torch.Tensor, binary_mask: torch.Tensor) -> torch.Tensor:
    """
    Drop the predictions at padded positions of a single sequence.

    :param output_tensor: model output for one sequence; shape: (max_length x 1)
    :param binary_mask: boolean mask for the sequence; shape: (max_length, ), True marks valid positions.

    :return: flat tensor holding only the valid positions; shape: (sum(binary_mask), )
    """
    # Boolean indexing keeps the rows whose mask entry is True.
    kept = output_tensor[binary_mask]
    return kept.view(-1)

def clean_output_batch(output_tensor: torch.Tensor, binary_masks: torch.Tensor) -> torch.Tensor:
    """
    Drop the predictions at padded positions for a whole batch.

    Valid positions are returned in batch-major order: all kept positions of
    sample 0, then sample 1, and so on. An empty batch yields an empty tensor
    instead of raising.

    :param output_tensor: batched output; shape: (batch_size, max_length, 1)
    :param binary_masks: boolean masks; shape: (batch_size, max_length), True marks valid positions.

    :return: flat tensor; shape: (sum of valid positions across the batch, )
    :raises ValueError: if ``output_tensor`` is not 3-dimensional.
    """
    if output_tensor.dim() != 3:
        raise ValueError(
            f"expected a 3D tensor (batch_size, max_length, 1), got shape {tuple(output_tensor.shape)}")

    # One boolean index over the two leading axes replaces the per-sample
    # loop + torch.cat and keeps the same element order.
    return output_tensor[binary_masks].reshape(-1)



# Trainer 

## config

In [16]:
# ----------------
# PARAM
# ----------------
# Experiment configuration: output directory for this run plus the model /
# training hyper-parameters. Exactly one `path` line should be active.



# Define the configuration dictionary with all the model parameters
# path = "./weights/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
path = "./weights/seq/(esm35M)_(combinedloss)_()/"
# path = "./weights/seq/(esm)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(protbert)_(combinedloss)_(none)/"
# path = "./weights/seq/()_(combinedloss)_(none)/"
# path = "./weights/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"


# With both "use_local" and "use_global" set to False, Aggrepred consists of
# the regression head only and is applied directly to the input embedding.
# "in_channel" must match the feature width produced by the chosen encoder
# (presumably 480 for the ESM-2 35M model used below - TODO confirm).
# Aggrepred itself does not read "antibody", "learning_rate", "batch_size",
# "nb_epochs" or "encode_mode"; those are for the surrounding training code.
config = {
    "antibody": False,
    "use_local": False,
    # "use_local": True,
    "use_global": False,
    # "use_global": True,
    "num_localextractor_block": 3,
    "input_dim": 1000,
    "output_dim": 1000,
    "in_channel": 480,
    "out_channel": 256,
    "kernel_size": 23,
    "dilation": 1,
    "stride": 1,
    "rnn_hid_dim": 128,
    "rnn_layers": 1,
    "bidirectional": True,
    "rnn_dropout": 0.2,
    "attention_heads": 4,
    "learning_rate": 1e-3,
    "batch_size": 32,
    "nb_epochs": 20,
    "encode_mode" : 'esm'
}

# Alternative: restore the configuration saved next to a previous run.
# with open(path+'config.json', 'r') as json_file:
#     config = json.load(json_file)


# ----------------
#  MODEL 
# ----------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = Aggrepred(config)
model = model.to(device=device)


cuda


In [17]:

# ----------------
#   OPTIMIZER 
# ----------------
# Adam over all model parameters with the learning rate from `config`.
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
# Alternative (AdamW lives in torch.optim, not nn.optim):
# optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"],
#                                 betas=(0.9, 0.999),
#                                 weight_decay=0.01)

# Multiplies the learning rate by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


# ----------------
# LOSS
# ----------------

class CombinedLoss(nn.Module):
    """Weighted sum of an MSE regression loss and a BCE-with-logits classification loss.

    The same ``outputs`` tensor is used both as the regression prediction and
    as the classification logit; the binary labels are derived from the
    regression targets (``target > 0``).
    """

    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        """
        :param lambda_reg: weight of the MSE term.
        :param lambda_bin: weight of the BCE term.
        :param pos_weight: optional positive-class weight for the BCE term.
        """
        super().__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin

        # Regression term (MSE)
        self.mse_loss = nn.MSELoss()

        # Classification term (BCE with logits), weighted only when requested
        bce_kwargs = {}
        if pos_weight is not None:
            bce_kwargs["pos_weight"] = torch.tensor(pos_weight)
        self.bce_loss = nn.BCEWithLogitsLoss(**bce_kwargs)

    def forward(self, outputs, regression_targets):
        """
        :param outputs: model predictions, same shape as ``regression_targets``.
        :param regression_targets: real-valued targets.
        :return: ``lambda_reg * MSE + lambda_bin * BCE`` as a scalar tensor.
        """
        regression_term = self.mse_loss(outputs, regression_targets)

        # Positive class = strictly positive regression target.
        positive_labels = (regression_targets > 0).float()
        classification_term = self.bce_loss(outputs, positive_labels)

        return self.lambda_reg * regression_term + self.lambda_bin * classification_term

# Stand-alone loss objects; pos_weight 4.0 reflects the roughly 4:1
# negative-to-positive class ratio measured on the data splits.
mse_loss  = nn.MSELoss()
pos_class_weights = torch.Tensor([4.0]).to(device)
weighted_bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_class_weights)


# Combined objective: 0.7 * MSE + 0.3 * weighted BCE.
combined_loss = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=4.0)


# ----------------
def count_parameters(model):
    """Return ``(trainable, non_trainable)`` element counts over ``model.parameters()``."""
    trainable_params = 0
    non_trainable_params = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_params += param.numel()
        else:
            non_trainable_params += param.numel()
    return trainable_params, non_trainable_params

# Report the size of the model built from `config`.
trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

Number of trainable parameters: 156161
Number of non-trainable parameters: 0


In [18]:
# ----------------
# DATA
# ----------------

def custom_collate(batch):
    """Collate samples, trimming the padded tensors to the longest sequence in the batch.

    Every item is a dict with the keys ``code``, ``seq``, ``target_reg``,
    ``target_bin`` and ``mask``. The three tensors of each item are cut to
    ``min(longest valid length in the batch, padded length)`` and stacked;
    ``code`` and ``seq`` are returned as plain lists.
    """
    # Padded length of the stored tensors, taken from the first item.
    padded_len = batch[0]['target_reg'].size()[0]

    # Number of valid positions per item is the sum of its mask.
    longest_valid = max([item['mask'].sum() for item in batch])
    keep = min(longest_valid, padded_len)  # never exceed the padded length

    def stack_trimmed(key):
        return torch.stack([item[key][:keep] for item in batch])

    collated = {
        'code': [item['code'] for item in batch],
        'seq': [item['seq'] for item in batch],
    }
    collated['target_reg'] = stack_trimmed('target_reg')
    collated['target_bin'] = stack_trimmed('target_bin')
    collated['mask'] = stack_trimmed('mask')
    return collated



## Dataloader

In [19]:

#########################################################################
# Build the train / valid / test datasets and their DataLoaders.
#########################################################################
df = pd.read_csv("../data/csv/data60_fixed_split.csv")

# Batch size comes from the config (previously hard-coded to the same value).
batch_size = config.get("batch_size", 32)

# Encoders other than the one-hot variants are costly, so train / valid are
# sampled down to 10% for them; the test split is always used in full.
subsample = config['encode_mode'] not in ['onehot', 'onehot_meiler']

train_df = df[df['split'] == 'train']
valid_df = df[df['split'] == 'valid']
test_df = df[df['split'] == 'test']

if subsample:
    print("yes")
    train_df = train_df.sample(frac=0.10, random_state=42)
    valid_df = valid_df.sample(frac=0.10, random_state=42)
    # Trim each batch to its longest sequence.
    collate_fn = custom_collate
else:
    # The one-hot encoders pad to the full length, so the default collate is used.
    collate_fn = None

train_dataset = SeqDataset(train_df, SEQ_MAX_LEN)
valid_dataset = SeqDataset(valid_df, SEQ_MAX_LEN)
test_dataset = SeqDataset(test_df, SEQ_MAX_LEN)

# NOTE(review): the valid / test loaders are shuffled as in the original;
# shuffle=False would make evaluation order reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)




# ##collate to flexible max len in batch
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)

yes


In [21]:
# esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
# # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
# # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# esm_model = esm_model.eval()  # Set the model to evaluation mode

# protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
# protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')


# for idx, batch in enumerate(test_dataloader):
#     print(batch['mask'].size())
#     print(batch['mask'])
#     x = embed_protbert_batch(batch['seq'],protbert_model,protbert_tokenizer)
#     print(x.size())
#     x = embed_esm_batch(batch['seq'],esm_model, alphabet)
#     print(x.size())
#     if idx == 0:
#         break

In [22]:
# x = embed_esm_batch(['D'],esm_model, alphabet)
# print(x.size())
# x

### Find the proportion of positive/negative classes

#### About 80% of positions are negative and about 20% are positive, hence a roughly 4:1 negative-to-positive ratio

In [23]:
# Class balance per split: share of positive (score > 0) and negative
# (score <= 0) positions, and the negative-to-positive ratio.
# The loop variables stay at module level, so after the loop `sum_one`,
# `sum_zero` and `total` hold the test-split values, as before.
for split_name, split_dataset in (("train", train_dataset),
                                  ("valid", valid_dataset),
                                  ("test", test_dataset)):
    sum_one = split_dataset.data['count_positive'].sum()
    sum_zero = split_dataset.data['count_negative'].sum()
    total = split_dataset.data['len'].sum()

    print(f"[{split_name}] proportion of positive and negative class: ", sum_one / total, sum_zero / total)
    print(f"[{split_name}] ratio of negative to positive class: ", sum_zero / sum_one)

propotion of position and negative class:  0.1874447907578202 0.8125552092421798
ratio of position to negative class:  4.334904192093585
propotion of position and negative class:  0.19428551851724965 0.8057144814827504
ratio of position to negative class:  4.147064009874802
propotion of position and negative class:  0.1936254280065119 0.8063745719934882
ratio of position to negative class:  4.164610920660527


In [25]:
print(model)

Aggrepred(
  (reg_layer): Sequential(
    (0): Linear(in_features=480, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.1)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): LeakyReLU(negative_slope=0.1)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
)


In [26]:

def format_time(seconds):
    """Format a duration in seconds as ``"<m>m <s> s"``, or ``"<s> s"`` when under a minute."""
    whole_minutes, leftover = divmod(seconds, 60)
    minutes = int(whole_minutes)
    seconds = int(leftover)
    if minutes > 0:
        return f"{minutes}m {seconds} s"
    return f"{seconds} s"

def train_epoch(model, optimizer, dataloader,loss_function, encode_mode='onehot_meiler', device = 'cuda', printEvery=100):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    :param model: network called as ``model(x, mask)`` and returning ``(features, regression_output)``.
    :param optimizer: optimizer over the parameters of ``model``.
    :param dataloader: yields dicts with the keys ``seq``, ``mask`` and ``target_reg``.
    :param loss_function: callable ``loss_function(prediction, target)`` on the flattened valid positions.
    :param encode_mode: ``'esm'``, ``'protbert'``, ``'onehot'``; any other value selects the
                        one-hot + Meiler encoding.
    :param device: device for the model inputs, targets and the ProtBERT encoder.
    :param printEvery: accepted for backward compatibility; progress is logged roughly every
                       1000 samples (``1000 // batch_size`` iterations), as before.
    :return: total loss divided by the number of batches (0.0 for an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    count_iter = 0
    epoch_start_time = time.time()
    interval_start_time = epoch_start_time
    num_batches = len(dataloader)

    # Load only the pretrained encoder this run needs. Previously both ESM
    # and ProtBERT were loaded on every call, whatever the encode_mode.
    esm_model = alphabet = None
    protbert_model = protbert_tokenizer = None
    if encode_mode == 'esm':
        # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm_model = esm_model.eval()  # Set the encoder to evaluation mode
    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
        # Follow the `device` argument instead of a hard-coded 'cuda'.
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

    with tqdm(total=num_batches, desc='Training', unit='batch') as pbar:
        for idx, batch in enumerate(dataloader):

            batch_sequences = batch['seq']
            mask = batch['mask'].to(device)

            ## different encoding here
            # The encoders are not being trained (the optimizer only holds the
            # model's parameters), so no autograd graph is needed for them.
            with torch.no_grad():
                if encode_mode == 'esm':
                    x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
                elif encode_mode == 'protbert':
                    x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
                elif encode_mode == 'onehot':
                    x = onehot_encode_batch(batch_sequences,1000).to(device)
                else:
                    x = onehot_meiler_encode_batch(batch_sequences,1000).to(device)

            ## convert (bsz,max_len) to  (bsz,max_len,1), then keep the valid positions only
            y_reg = batch['target_reg'].unsqueeze(2).float().to(device)
            y_reg = clean_output_batch(y_reg, mask)

            ### predict
            final_info, output_reg = model(x, mask)

            # trim out the padded part
            output_reg = clean_output_batch(output_reg, mask)
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

            current_loss = loss_function(output_reg, y_reg)

            # Backpropagation
            optimizer.zero_grad()
            current_loss.backward()

            optimizer.step()
            total_loss += current_loss.item()

            # Log about every 1000 samples; the max() guards keep the modulo
            # below valid for batches larger than 1000.
            log_interval = max(1, 1000 // max(1, x.size(0)))
            count_iter += 1
            if count_iter % log_interval == 0 or idx == num_batches - 1:
                now = time.time()
                elapsed_time = now - interval_start_time
                # Estimate from the mean iteration time since the epoch started;
                # the interval time alone would underestimate it.
                remaining_time = ((now - epoch_start_time) / count_iter) * (num_batches - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")
                interval_start_time = time.time()
            torch.cuda.empty_cache()
            pbar.update(1)

    epoch_time = time.time() - epoch_start_time
    # Avoid a ZeroDivisionError when the dataloader yields no batches.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    print(f"==> Average Training loss: mse ={mean_loss}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")
    return mean_loss


def evaluate(model, dataloader,loss_function, encode_mode='onehot_meiler', device= 'cuda', mode='valid'):
    """
    Run the model over a dataloader without gradient tracking and report metrics.

    Args:
    - model (torch.nn.Module): Called as ``model(x, mask)``; must return a
      2-tuple whose second element is the regression output.
    - dataloader: Iterable of batches indexed by the keys 'seq', 'mask' and
      'target_reg'.
    - loss_function (callable): Called as ``loss_function(output, target)`` on
      the padding-free tensors; must return a scalar tensor.
    - encode_mode (str): 'esm', 'protbert' or 'onehot'; any other value falls
      back to the one-hot + Meiler encoding.
    - device: Device the inputs, targets and (if needed) ProtBERT are moved to.
    - mode (str): Currently unused; kept so existing callers keep working.

    Returns:
    - tuple: (mean loss per batch, metrics dict, list of per-batch prediction
      arrays, list of per-batch target arrays).

    Raises:
    - ValueError: If the dataloader yields no batches.
    """
    # Fail early with a clear message instead of the opaque np.concatenate
    # error the metric computation would raise on empty lists.
    if len(dataloader) == 0:
        raise ValueError("evaluate() needs a dataloader with at least one batch")

    model.eval()
    total_loss = 0.0

    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []

    # Load only the pretrained encoder that `encode_mode` actually needs.
    # This function runs once per epoch during training, so loading both
    # language models unconditionally wasted time and memory.
    esm_model = alphabet = None
    protbert_model = protbert_tokenizer = None
    if encode_mode == 'esm':
        # Other ESM-2 sizes (esm2_t6_8M, esm2_t30_150M, esm2_t33_650M) can be swapped in here.
        esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        esm_model = esm_model.eval()  # Set the model to evaluation mode
    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        # Honour the requested device rather than hard-coding 'cuda'.
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):

            batch_sequences = batch['seq']
            mask = batch['mask'].to(device)

            ## different encoding here
            if encode_mode == 'esm':
                x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
            elif encode_mode == 'protbert':
                x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
            elif encode_mode == 'onehot':
                x = onehot_encode_batch(batch_sequences).to(device)
            else:
                x = onehot_meiler_encode_batch(batch_sequences).to(device)

            ## convert (bsz,max_len) to  (bsz,max_len,1)
            y_reg = batch['target_reg'].unsqueeze(2).float().to(device)
            y_reg = clean_output_batch(y_reg, mask)
            # batch['target_bin'] is intentionally not read: the binary labels
            # used for the metrics are derived from y_reg below.

            ### predict
            final_info, output_reg = model(x, mask)

            #trim out the padded part
            output_reg = clean_output_batch(output_reg, mask)
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

            current_loss = loss_function(output_reg, y_reg)
            total_loss += current_loss.item()

            # Move to host memory once and reuse for both metric families.
            output_np = output_reg.cpu().numpy()
            y_np = y_reg.cpu().numpy()

            #append to list of all preds
            predictions.append(output_np)
            targets.append(y_np)

            # Binary view of the regression values: strictly positive -> class 1.
            binary_predictions.append((output_np > 0).astype(int))
            binary_targets.append((y_np > 0).astype(int))

    all_predictions = np.concatenate(predictions, axis=0).reshape(-1)
    all_targets = np.concatenate(targets, axis=0).reshape(-1)
    all_binary_predictions = np.concatenate(binary_predictions, axis=0).reshape(-1)
    all_binary_targets = np.concatenate(binary_targets, axis=0).reshape(-1)

    # Calculate overall metrics
    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets, all_predictions)

    # Calculate binary classification metrics
    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    # Ranking metrics use the continuous predictions as scores.
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)

    print(f"Overall Regression Metrics")
    print(f"MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}")

    print(f"Overall classification Metrics")
    print(f"Acc: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")  

    # Plain floats rounded to 4 decimals so the dict can be dumped to JSON.
    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }


    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader,loss_function, nb_epochs, encode_mode='onehot_meiler', device= 'cuda', save_directory='./weights/', patience=5):
    """
    Train for up to `nb_epochs` epochs with checkpointing and early stopping.

    Resumes from ``model_last.pt`` in `save_directory` when it exists. After
    every epoch the function validates, writes ``model_last.pt`` and
    ``losses.json``; whenever the validation loss improves it also writes
    ``model_best.pt`` and ``metrics.json``.

    Args:
    - model (torch.nn.Module): Model to train; moved to `device`.
    - optimizer (torch.optim.Optimizer): Optimizer bound to the model parameters.
    - train_dataloader / valid_dataloader: Passed to train_epoch / evaluate.
    - loss_function (callable): Forwarded to train_epoch and evaluate.
    - nb_epochs (int): Last epoch number to run (epochs are counted from 1).
    - encode_mode (str): Sequence encoding name forwarded to train_epoch / evaluate.
    - device: Device used for training and for mapping a loaded checkpoint.
    - save_directory (str): Folder for checkpoints, losses and metrics.
    - patience (int): Consecutive non-improving epochs tolerated before stopping.

    Returns:
    - None
    """
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving losses and metrics
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')

    # Initialize lists for losses
    train_losses = []
    val_losses = []

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, exist_ok=True)
        print(f'Created directory: {save_directory}')

    # Move the model first so the optimizer state restored below is placed on
    # the same device as the parameters it belongs to.
    model.to(device)

    checkpoint_path = os.path.join(save_directory, 'model_last.pt')
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on another device be resumed here.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # NOTE: despite the key name this value is a validation *loss*, and in
        # model_last.pt it is the loss of the last epoch, not the best one.
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')

        # Load losses from the losses.json file if it exists
        if os.path.exists(loss_output_path):
            with open(loss_output_path, 'r') as f:
                losses = json.load(f)
                train_losses = losses.get('train_losses', [])
                val_losses = losses.get('val_losses', [])
            print(f'Loaded losses from {loss_output_path}.')
            print(train_losses)
            print(val_losses)
            # Recover the true best loss from the history; otherwise a resumed
            # run could overwrite model_best.pt with a worse model.
            if val_losses:
                best_validation_loss = min(best_validation_loss, min(val_losses))
        else:
            print(f'No losses file found at {loss_output_path}.')

    else:
        print('No checkpoint found. Starting from beginning.')

    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")

        train_loss = train_epoch(model, optimizer, train_dataloader,loss_function, encode_mode ,device, printEvery=1000)
        train_losses.append(train_loss)

        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,loss_function,encode_mode, device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            best_model_save_path = os.path.join(save_directory, 'model_best.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)

        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                # Report the real threshold (the old text said 3 while the check used 5).
                print(f"\n==> Early stopping triggered. No improvement in validation loss for {patience} epochs.")
                break

        last_model_save_path = os.path.join(save_directory, 'model_last.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')

        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')

        print("==================================================================================\n")

    return

# train here

In [None]:
# Store the run configuration alongside the weights so the model can be rebuilt later.
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)

# Training objective (mse_loss is the alternative that was tried).
loss_function = combined_loss

train_loop(
    model,
    optimizer,
    train_dataloader,
    valid_dataloader,
    loss_function,
    nb_epochs=50,
    encode_mode=config['encode_mode'],
    device=device,
    save_directory=path,
)

## test

In [29]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Restore model weights from a checkpoint file when one is present.

    Args:
    - model (torch.nn.Module): Receives the saved 'model_state_dict'.
    - optimizer: Accepted for interface compatibility; its state is not restored.
    - checkpoint_path (str): Location of the checkpoint file.
    - device (torch.device): Used as map_location and as the model's target device.

    Returns:
    - start_epoch (int): Saved epoch + 1, or 0 when no checkpoint exists.
    - best_validation_loss (float): Value stored under 'validation_accuracy',
      or +inf when no checkpoint exists.
    """
    # Guard clause: nothing on disk, so hand back the "fresh start" defaults.
    if not os.path.exists(checkpoint_path):
        print('No checkpoint found.')
        model = model.to(device)
        return 0, float('inf')  # lower is better for validation loss

    saved = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(saved['model_state_dict'])
    resume_epoch = saved['epoch'] + 1
    recorded_loss = saved['validation_accuracy']
    print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {resume_epoch}')

    model = model.to(device)
    return resume_epoch, recorded_loss

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [31]:

# Directories of the trained models to score on the test set
model_paths = [
    
    "./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
   

    "./weights/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights/seq/(esm35M)_(combinedloss)_(none)/",

    "./weights/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    "./weights/seq/(protbert)_(combinedloss)_(none)/"
    
    
]

# `model_dir` (not `path`) is used as the loop variable so this cell does not
# overwrite the `path` that the training cell writes its config/weights to.
for model_dir in model_paths:
    checkpoint_file = os.path.join(model_dir, 'model_best.pt')

    # Without a checkpoint the model would be scored with random weights and
    # a misleading result.json written, so skip the entry instead.
    if not os.path.exists(checkpoint_file):
        print(f"No checkpoint at {checkpoint_file}; skipping.")
        continue

    # Load the config for the current model
    with open(os.path.join(model_dir, 'config.json'), 'r') as json_file:
        config = json.load(json_file)

    # Initialize the model
    model = Aggrepred(config)
    model = model.to(device=device)

    # Load the model weights from the checkpoint. The loader does not restore
    # optimizer state, so None is passed and this cell no longer depends on an
    # `optimizer` left over from the training cells.
    _, _ = load_model_from_checkpoint(model, None, checkpoint_file, device)

    # Evaluate the model
    loss, metrics, preds, tar = evaluate(model, test_dataloader,combined_loss, config['encode_mode'] ,device)

    # Save metrics of the best model
    with open(os.path.join(model_dir, 'result.json'), 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {model_dir}")


Loaded checkpoint from ./weights/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/model_best.pt. Resuming from epoch 14




RuntimeError: The expanded size of the tensor (1000) must match the existing size (908) at non-singleton dimension 2.  Target sizes: [32, 256, 1000].  Tensor sizes: [32, 1, 908]

In [None]:
# Peek at the first 20 entries of the prediction array at index 10 of `preds`
# (the per-batch list returned by the last evaluate(...) call that completed).
preds[10][:20]

In [None]:
# Matching slice of the target array at index 10 of `tar`, for side-by-side
# comparison with the predictions (same evaluate(...) call — TODO confirm).
tar[10][:20]