In [1]:
import torch
import pandas as pd
import sys
import os
import numpy as np
import random
import pickle
from tqdm import tqdm
import torch
import esm
import numpy as np
import os
import torch
import torch.nn as nn
import json
import torch.nn.functional as F
from scipy.stats import spearmanr

from torch.utils.data import DataLoader
from antiberty import AntiBERTyRunner


top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)

from aggrepred.dataset import *
from aggrepred.model import *
from aggrepred.utils import *


def seed_everything(seed=42):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and the current
    CUDA device), exports ``PYTHONHASHSEED`` and asks cuDNN to use
    deterministic algorithms.

    :param seed: integer seed shared by all generators
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding function takes the seed as its single positional argument.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

import torch
import torch.nn as nn

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from typing import Optional

import os
import torch
import time
import logging
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score,accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, matthews_corrcoef
from scipy.stats import pearsonr


  from .autonotebook import tqdm as notebook_tqdm


# Dataset loader


In [2]:

class AntibodySeqDataset:
    """
    Map-style dataset of paired antibody heavy/light chains with per-residue scores.

    The input frame must provide the columns ``ID``, ``Hchain_sequence``,
    ``Lchain_sequence``, ``Hchain_scores`` and ``Lchain_scores``. Scores may be
    stored as the string repr of a list (as read from CSV) or as real sequences.

    Every item is padded to a fixed width per chain so that the default
    ``DataLoader`` collate function can stack items into a batch.

    :param df:          source dataframe (copied, never modified in place)
    :param max_seq_len: upper bound on the number of residues kept per chain
    :param h_pad_len:   fixed width of the heavy-chain targets and mask
    :param l_pad_len:   fixed width of the light-chain targets and mask
    """

    def __init__(self, df, max_seq_len=700, *, h_pad_len=450, l_pad_len=250):
        # Local import: the notebook never imports ``ast`` explicitly.
        import ast

        self.data = df.copy()

        def parse_scores(value):
            # CSV round-trips store lists as strings; accept real sequences as well.
            return ast.literal_eval(value) if isinstance(value, str) else list(value)

        def count_positive(scores):
            return sum(1 for score in scores if score > 0)

        def count_non_positive(scores):
            return sum(1 for score in scores if score <= 0)

        chains = ('Hchain', 'Lchain')

        # Convert scores from string to list
        for chain in chains:
            self.data[f'{chain}_scores'] = self.data[f'{chain}_scores'].apply(parse_scores)

        # Per-sequence counts of positive (> 0) and non-positive (<= 0) residues
        for chain in chains:
            self.data[f'{chain}_count_positive'] = self.data[f'{chain}_scores'].apply(count_positive)
            self.data[f'{chain}_count_negative'] = self.data[f'{chain}_scores'].apply(count_non_positive)

        # Chain lengths (number of scored residues)
        for chain in chains:
            self.data[f'{chain}_len'] = self.data[f'{chain}_scores'].apply(len)

        # Negative-to-positive ratio; a chain without positive residues divides by zero here
        for chain in chains:
            self.data[f'{chain}_neg_to_pos_ratio'] = (
                self.data[f'{chain}_count_negative'] / self.data[f'{chain}_count_positive']
            )

        self.max_seq_len = max_seq_len
        self.h_pad_len = h_pad_len
        self.l_pad_len = l_pad_len

    def __len__(self):
        return len(self.data)

    def _pad_chain(self, scores, pad_len):
        """Return (regression target, binary target, validity mask), each of length ``pad_len``."""
        # Truncate to the padded width as well as max_seq_len, so the target and
        # its mask always have the same length even for unusually long chains.
        n_valid = min(len(scores), self.max_seq_len, pad_len)
        padded = list(scores[:n_valid]) + [0] * (pad_len - n_valid)
        # Explicit dtype keeps items stackable even if a score list holds only ints.
        target = torch.tensor(padded, dtype=torch.float32)

        mask = torch.zeros(pad_len, dtype=torch.bool)
        mask[:n_valid] = True  # True for real residues, False for padding

        return target, (target > 0).int(), mask

    def __getitem__(self, idx):
        """
        Return one antibody as a dict of sequences, padded targets and masks.

        :raises IndexError: if ``idx`` is negative or past the end of the dataset
        """
        if idx < 0 or idx >= len(self.data):
            raise IndexError("Index out of range")

        row = self.data.iloc[idx]

        H_y, H_y_bin, H_mask = self._pad_chain(row['Hchain_scores'], self.h_pad_len)
        L_y, L_y_bin, L_mask = self._pad_chain(row['Lchain_scores'], self.l_pad_len)

        return {
            'code': row['ID'],
            'H_seq': row['Hchain_sequence'],
            'H_target_reg': H_y,
            'H_target_bin': H_y_bin,
            'H_mask': H_mask,

            'L_seq': row['Lchain_sequence'],
            'L_target_reg': L_y,
            'L_target_bin': L_y_bin,
            'L_mask': L_mask,
        }
        


In [3]:
# df = pd.read_csv("/users/eleves-b/2023/ly-an.chhay/main/data/pisces/tm.csv")
# df

# train_dataset = AntibodySeqDataset(df[df.split=='train'],700, True)

In [4]:
# train_dataset[0]

In [5]:
def onehot_encode(sequence: str, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode a single amino acid sequence into a fixed-size tensor.

    Residues beyond ``max_length`` are dropped; rows past the end of the
    sequence stay all-zero. Residues missing from ``aa2idx`` are mapped to the
    last channel (index ``NUM_AMINOS - 1``).

    :param sequence:   protein sequence
    :param max_length: number of rows in the output, independent of the sequence length

    :return: max_length x NUM_AMINOS tensor
    """
    fallback_idx = NUM_AMINOS - 1
    encoded = torch.zeros((max_length, NUM_AMINOS))

    for position, residue in enumerate(sequence[:max_length]):
        encoded[position, aa2idx.get(residue, fallback_idx)] = 1

    return encoded

def onehot_encode_batch(sequences: list, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode a batch of amino acid sequences.

    :param sequences:  list of protein sequences
    :param max_length: number of positions kept per sequence (shorter sequences are zero-padded)

    :return: len(sequences) x max_length x NUM_AMINOS tensor
    """
    batch_encoded = torch.zeros((len(sequences), max_length, NUM_AMINOS))

    # Iterating the batch tensor yields views, so copy_ writes straight into it.
    for slot, seq in zip(batch_encoded, sequences):
        slot.copy_(onehot_encode(seq, max_length))

    return batch_encoded

def onehot_meiler_encode(sequence: str, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode an amino acid sequence and append its Meiler features.

    Each row holds NUM_AMINOS one-hot channels followed by NUM_MEILER feature
    channels. Residues missing from ``aa2idx`` use the last one-hot channel;
    residues missing from ``MEILER`` use the features stored under "X".
    Rows past the end of the sequence stay all-zero.

    :param sequence:   protein sequence
    :param max_length: number of rows in the output, independent of the sequence length

    :return: max_length x (NUM_AMINOS + NUM_MEILER) tensor
    """
    fallback_idx = NUM_AMINOS - 1
    encoded = torch.zeros((max_length, NUM_AMINOS + NUM_MEILER))

    for position, residue in enumerate(sequence[:max_length]):
        encoded[position, aa2idx.get(residue, fallback_idx)] = 1

        if residue in MEILER:
            meiler_features = MEILER[residue]
        else:
            meiler_features = MEILER["X"]
        encoded[position, -NUM_MEILER:] = meiler_features

    return encoded


def onehot_meiler_encode_batch(sequences: list, max_length: int = 1000) -> torch.Tensor:
    """
    Encode a batch of sequences with one-hot channels plus Meiler features.

    :param sequences:  list of protein sequences
    :param max_length: number of positions kept per sequence (shorter sequences are zero-padded)

    :return: len(sequences) x max_length x (NUM_AMINOS + NUM_MEILER) tensor
    """
    feature_dim = NUM_AMINOS + NUM_MEILER
    batch_encoded = torch.zeros((len(sequences), max_length, feature_dim))

    # Iterating the batch tensor yields views, so copy_ writes straight into it.
    for slot, seq in zip(batch_encoded, sequences):
        slot.copy_(onehot_meiler_encode(seq, max_length))

    return batch_encoded


## model


In [6]:

# Default maximum sequence length; used as the default input/output length of LocalExtractorBlock
SEQ_MAX_LEN = 1000

# Default number of input channels: 21 amino acids + 7 Meiler features
INPUT_FEATURES = 28

# Default convolution kernel size, as per Parapred
KERNEL_SIZE = 11

# Hidden output channels of the CNN
HIDDEN_CHANNELS = 256



def generate_mask(input_tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-position validity mask over the channel axis of a convolution tensor.

    The mask does not need to be a contiguous prefix: any pattern of valid
    positions is carried over unchanged to every channel.

    :param input_tensor: tensor fed to / produced by the convolution (batch_size x features x max_seqlen)
    :param masks: binary mask (batch_size x max_seqlen) marking the valid positions
    :return: mask of shape (batch_size x features x max_seqlen) on the device of ``input_tensor``
    """
    n_samples, n_channels, n_positions = input_tensor.shape

    # Insert a singleton channel axis, then broadcast it without copying data.
    per_channel = torch.unsqueeze(masks, 1).expand(n_samples, n_channels, n_positions)

    return per_channel.to(device=input_tensor.device)

class LocalExtractorBlock(nn.Module):
    """
    Length-preserving 1D convolution block: Conv1d -> BatchNorm1d -> LeakyReLU -> Dropout,
    followed by zeroing of the positions that the binary mask marks as invalid.

    :param input_dim:   sequence length fed to the convolution
    :param output_dim:  sequence length expected after the convolution (defaults to ``input_dim``)
    :param in_channel:  number of input channels
    :param out_channel: number of output channels (defaults to ``in_channel``)
    :param kernel_size: convolution kernel size
    :param dilation:    convolution dilation, must be >= 1
    :param stride:      convolution stride, must be >= 1
    """

    def __init__(self,
                 input_dim: int = SEQ_MAX_LEN,
                 output_dim: int = SEQ_MAX_LEN,
                 in_channel: int = INPUT_FEATURES,
                 out_channel: Optional[int] = None,
                 kernel_size: int = KERNEL_SIZE,
                 dilation: int = 1,
                 stride: int = 1):

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = input_dim if output_dim is None else output_dim

        self.in_channels = in_channel
        self.out_channel = in_channel if out_channel is None else out_channel

        # Determine the padding required for keeping the requested sequence length
        assert dilation >= 1 and stride >= 1, "Dilation and stride must be >= 1."
        self.dilation, self.stride = dilation, stride
        self.kernel_size = kernel_size

        padding = self.determine_padding(self.input_dim, self.output_dim)

        # Use the resolved channel count (the raw argument may be None), and hand
        # stride/dilation to the convolution: the padding above was derived for them,
        # so omitting them would produce a different output length.
        self.conv = nn.Conv1d(
            in_channel,
            self.out_channel,
            self.kernel_size,
            stride=self.stride,
            padding=padding,
            dilation=self.dilation)

        self.BN = nn.BatchNorm1d(self.out_channel)
        self.leakyrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()  # not used in forward; kept so the module's attributes are unchanged
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_tensor: torch.Tensor, binary_mask) -> torch.Tensor:
        """
        Forward pass of the LocalExtractorBlock

        :param input_tensor: an input tensor of (bsz x features x seqlen)
        :param binary_mask: a boolean tensor of (bsz x seqlen) marking valid positions
        :return: tensor of (bsz x out_channel x seqlen) with masked-out positions set to zero
        """
        o = self.conv(input_tensor)
        o = self.BN(o)
        o = self.leakyrelu(o)
        o = self.dropout(o)

        # Zero out the values at positions the mask marks as invalid (padding)
        mask = generate_mask(o, binary_mask)

        return o * mask

    def determine_padding(self, input_shape: int, output_shape: int) -> int:
        """
        Determine the padding required to obtain ``output_shape`` positions after the convolution.

        formula :  L_out = ((L_in + 2 x padding - dilation x (kernel_size - 1) - 1)/stride + 1)

        The result is floored, so the requested length is only reached exactly
        when the numerator below is even (e.g. odd kernel sizes with stride 1).

        :return: An integer defining the amount of padding required to keep the "same" padding effect
        """
        padding = (((output_shape - 1) * self.stride) + 1 - input_shape + (self.dilation * (self.kernel_size - 1))) // 2

        # A negative value would mean the requested output is too short to reach by padding
        assert padding >= 0, f"Padding must be non-negative but got {padding}."
        return padding
    

In [7]:


def generate_attn_mask(batch_size, num_heads, max_length, masks):
    """
    Build a per-head pairwise attention mask from per-position validity masks.

    Entry (i, j) of a sample's mask is True when both positions i and j are
    valid. The diagonal is always True, so every position (padding included)
    may at least attend to itself.

    :param batch_size: int, size of the batch.
    :param num_heads: int, number of attention heads.
    :param max_length: int, maximum sequence length.
    :param masks: a binary mask (batch_size x max_length) indicating the valid positions.
    :return: boolean mask of shape (batch_size * num_heads x max_length x max_length)
    """
    pairwise = torch.zeros((batch_size, max_length, max_length), dtype=torch.bool)

    for sample_idx, valid in enumerate(masks):
        # Outer product: True only where both the query and the key position are valid
        sample_mask = torch.outer(valid, valid)
        sample_mask.fill_diagonal_(True)
        pairwise[sample_idx] = sample_mask

    # Repeat each sample's mask for every head, then merge batch and head axes
    per_head = pairwise[:, None].expand(-1, num_heads, -1, -1)
    return per_head.reshape(batch_size * num_heads, max_length, max_length)

class Att_BiLSTM(nn.Module):
    """
    (Bi)LSTM encoder followed by masked multi-head self-attention over its outputs.

    :param input_size:    number of features per sequence position
    :param hidden_size:   LSTM hidden size per direction
    :param num_layers:    number of stacked LSTM layers
    :param bidirectional: whether the LSTM runs in both directions
    :param rnn_dropout:   dropout between LSTM layers (only applied when num_layers > 1)
    :param num_heads:     number of attention heads
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True, rnn_dropout=0.2, num_heads=1):
        super(Att_BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_heads = num_heads

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                            bidirectional=bidirectional, batch_first=True, dropout=rnn_dropout if num_layers > 1 else 0)
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size * 2 if bidirectional else hidden_size, num_heads=num_heads, batch_first=True)

    def forward(self, x, banary_mask):
        """
        Forward pass through the LSTM and the multi-head self-attention.

        :param x: input of shape (batch x seqlen x input_size)
        :param banary_mask: boolean mask (batch x seqlen), True at valid positions
        :return: (attention output, (hn, cn)) where (hn, cn) are the final LSTM states
        """
        # Packed sequences are not used; padding is handled by the attention mask.
        if self.training:
            # Random initial states are kept during training, exactly as before.
            num_directions = 2 if self.bidirectional else 1
            state_shape = (num_directions * self.num_layers, x.size(0), self.hidden_size)
            h0 = torch.randn(state_shape).to(x.device)
            c0 = torch.randn(state_shape).to(x.device)
            output, (hn, cn) = self.lstm(x, (h0, c0))
        else:
            # In eval mode, omit the states so the LSTM uses its default zero
            # initial states; random states made repeated predictions differ.
            output, (hn, cn) = self.lstm(x)

        # Self-attention over the LSTM outputs. nn.MultiheadAttention treats True
        # entries of a boolean attn_mask as "not allowed", hence the inversion.
        mask = generate_attn_mask(x.size(0), self.num_heads, x.size(1), banary_mask).to(device=output.device)
        attn_output, attn_weight = self.attention(output, output, output, attn_mask=~mask)

        return attn_output, (hn, cn)
    
  
    
class GlobalInformationExtractor(nn.Module):
    """
    Sequence-wide feature extractor: an Att_BiLSTM followed by LeakyReLU(0.1) and Dropout(0.2).
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True, rnn_dropout=0.2, num_heads=1):
        super().__init__()
        self.att_bilstm = Att_BiLSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            rnn_dropout=rnn_dropout,
            num_heads=num_heads,
        )
        self.relu = nn.ReLU()
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, lengths):
        """
        :param x: input features, forwarded unchanged to the Att_BiLSTM
        :param lengths: forwarded to the Att_BiLSTM as its binary mask argument
        :return: (activated attention output, (hn, cn)) with the final LSTM states
        """
        attended, (h_n, c_n) = self.att_bilstm(x, lengths)
        activated = self.dropout(self.leakyrelu(attended))
        return activated, (h_n, c_n)

In [8]:
class Aggrepred(nn.Module):
    """
    Per-position regression model combining an optional local (stacked
    LocalExtractorBlock convolutions) branch and an optional global
    (GlobalInformationExtractor) branch, followed by a small MLP head.

    When both branches are enabled their outputs are concatenated along the
    feature axis; when neither is enabled the head is applied directly to the
    input features.
    """
    def __init__(self, config):
        """
        Initialize the Aggrepred model using a configuration dictionary.

        :param config: A dictionary containing all the parameters for the model.
                       Every key is optional; the defaults are the second
                       arguments of the ``config.get`` calls below.
        """

        super().__init__()

        # Branch switches. "pooling" is stored but not used anywhere in this class.
        self.pooling = config.get("pooling", False) 
        self.use_local = config.get("use_local", True)  # defaults to True if not in config
        self.use_global = config.get("use_global", True)  # defaults to True if not in config
        
        # Local (convolutional) branch parameters
        num_localextractor_block = config.get("num_localextractor_block", 3)
        input_dim = config.get("input_dim", 1000)
        output_dim = config.get("output_dim", 1000)
        in_channel = config.get("in_channel", 28)
        out_channel = config.get("out_channel", None)
        kernel_size = config.get("kernel_size", 23)
        dilation = config.get("dilation", 1)
        stride = config.get("stride", 1)
        
        # Global (LSTM + attention) branch parameters
        rnn_hid_dim = config.get("rnn_hid_dim", 256)
        rnn_layers = config.get("rnn_layers", 1)
        bidirectional = config.get("bidirectional", True)
        rnn_dropout = config.get("rnn_dropout", 0.2)
        attention_heads = config.get("attention_heads", 1)

        # Note: running with both branches disabled is allowed (no assertion here);
        # the head is then applied to the raw input features.

        # An unspecified output width means "same as the input width"
        out_channel = in_channel if out_channel is None else out_channel
        
        if self.use_local:
            assert num_localextractor_block > 0, "Number of local extractor blocks must be greater than 0."
            # Only the first block sees in_channel inputs; later blocks map out_channel -> out_channel
            self.local_extractors = nn.ModuleList([
                LocalExtractorBlock(
                    input_dim=input_dim,
                    output_dim=output_dim,
                    in_channel=in_channel if i == 0 else out_channel,
                    out_channel=out_channel,
                    kernel_size=kernel_size,
                    dilation = dilation,
                    stride = stride
                ) for i in range(num_localextractor_block)
            ])
            # Projects the input to out_channel features so it can be added as a residual
            self.residue_map = nn.Linear(in_channel,out_channel)
        
        if self.use_global:
            self.global_extractor = GlobalInformationExtractor(input_size=in_channel, hidden_size=rnn_hid_dim, num_layers=rnn_layers, bidirectional=bidirectional, rnn_dropout=rnn_dropout, num_heads=attention_heads)

        
        # Feature width produced by the global branch (doubled for a bidirectional LSTM)
        rnn_hid_dim = rnn_hid_dim * 2 if bidirectional else rnn_hid_dim

        # Input width of the regression head depends on which branches are active
        if self.use_local and self.use_global:
            fc_in_dim = out_channel + rnn_hid_dim
        elif self.use_local:
            fc_in_dim = out_channel
        elif self.use_global:
            fc_in_dim = rnn_hid_dim
        else:
            fc_in_dim = in_channel

        # Disabled pooling variant (kept for reference):
        # if self.pooling:
        #     self.downproject = nn.Linear(fc_in_dim, 4)
        #     fc_in_dim = 4* input_dim
        
        # Regression head: three linear layers narrowing fc_in_dim -> fc_in_dim//2 -> fc_in_dim//4 -> 1
        self.reg_layer = nn.Sequential(
            nn.Linear(fc_in_dim, fc_in_dim//2),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(fc_in_dim//2, fc_in_dim//4),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(fc_in_dim//4, 1)
        )

        # Disabled alternative head (kept for reference):
        # self.tm_reg_layer = nn.Sequential(
        #     nn.Linear(fc_in_dim, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, 256),
        #     nn.ReLU(),
        #     nn.Linear(256, 256),
        #     nn.ReLU(),
        #     nn.Linear(256, 256),
        #     nn.ReLU(),
        #     nn.Linear(256, 1)
        # )


    def forward(self, input_tensor: torch.Tensor, binary_mask) -> torch.Tensor:
        """
        Run the enabled branches and the regression head.

        :param input_tensor: input features; the local branch permutes axes (0, 2, 1)
                             before convolving, so the layout is (batch x seqlen x features)
        :param binary_mask: mask forwarded to both branches to mark valid positions
        :return: tuple ``(final_info, reg_output)`` - the features fed to the head and
                 the head's output, whose last dimension has size 1
        """

        #### Local extracted information
        # The projected input is the first residual; after every block the
        # block output plus the running residual becomes the next residual.

        if self.use_local:
            residue = self.residue_map(input_tensor)
            # Conv1d expects channels before positions
            o = input_tensor.permute(0, 2, 1)
            residue = residue.permute(0, 2, 1)
            for extractor in self.local_extractors:
                o = extractor(o, binary_mask)
                o = o + residue
                residue = o

            # Back to positions before channels, matching the global branch output
            local_extracted_info = o.permute(0, 2, 1)


        #### Global extracted information
        if self.use_global:
            global_extracted_info, (hn,cn) = self.global_extractor(input_tensor, binary_mask)

        
        # Combine the enabled branches; fall back to the raw input when both are disabled
        if self.use_local and self.use_global:
            final_info = torch.cat((local_extracted_info, global_extracted_info), dim=-1)
        elif self.use_local:
            final_info = local_extracted_info
        elif self.use_global:
            final_info = global_extracted_info
        else:
            final_info = input_tensor

        

        # Disabled pooling for protein-level prediction (kept for reference):
        # if self.pooling:
        #     expanded_mask = binary_mask.unsqueeze(-1).float()  # (batch_size, max_length, 1)
        #     # Apply the mask to the output
        #     masked_output = final_info * expanded_mask  # (batch_size, max_length, feature_dim)
        #     downproject = self.downproject(masked_output)
        #     final_info = downproject.view(downproject.size(0), -1) 

        reg_output = self.reg_layer(final_info)
        
        return final_info, reg_output
  



def clean_output(output_tensor: torch.Tensor, binary_mask: torch.Tensor) -> torch.Tensor:
    """
    Drop the predictions at padded positions of a single sequence.

    :param output_tensor: model output for one sequence; shape: (max_length x 1)
    :param binary_mask: boolean mask for the sequence; shape: (max_length, ), True at valid positions.

    :return: flat tensor holding only the valid predictions; shape: (sum(binary_mask), )
    """
    valid_predictions = output_tensor[binary_mask]
    return torch.flatten(valid_predictions)

def clean_output_batch(output_tensor: torch.Tensor, binary_masks: torch.Tensor) -> torch.Tensor:
    """
    Drop the predictions at padded positions for every sequence of a batch.

    Valid predictions are kept in batch order: all of sequence 0, then all of
    sequence 1, and so on.

    :param output_tensor: model output; shape: (batch_size, max_length, 1)
    :param binary_masks: boolean masks; shape: (batch_size, max_length), True at valid positions.

    :return: flat tensor of valid predictions; shape: (sum of valid positions across the batch, )
    """
    n_sequences, _, _ = output_tensor.shape

    per_sequence = [
        output_tensor[seq_idx][binary_masks[seq_idx]].view(-1)
        for seq_idx in range(n_sequences)
    ]

    return torch.cat(per_sequence, dim=0)



## Trainer 

In [9]:
# ----------------
# DATA
# ----------------

def custom_collate(batch):
    """
    Collate dataset items, trimming the targets to the longest sequence of the batch.

    Each item must provide the keys 'code', 'seq', 'target_reg', 'target_bin'
    and 'orig_len'. Targets are cut to the largest 'orig_len' in the batch,
    capped at the length of the first item's 'target_reg'.

    :param batch: list of item dictionaries
    :return: dictionary with list-valued 'code' and 'seq', stacked 'target_reg'
             and 'target_bin' tensors, and 'orig_len' as a tensor
    """
    reg_targets = [sample['target_reg'] for sample in batch]
    padded_len = reg_targets[0].size()[0]

    orig_lens = [sample['orig_len'] for sample in batch]
    # Never keep more positions than the padded targets actually contain
    keep_len = min(max(orig_lens), padded_len)

    codes = [sample['code'] for sample in batch]
    seqs = [sample['seq'] for sample in batch]

    trimmed_regs = torch.stack([target[:keep_len] for target in reg_targets])
    trimmed_bins = torch.stack([sample['target_bin'][:keep_len] for sample in batch])

    return {
        'code': codes,
        'seq': seqs,
        'target_reg': trimmed_regs,
        'target_bin': trimmed_bins,
        'orig_len': torch.tensor(orig_lens),
    }



### Protein

In [10]:

# #########################################################################
# #########################################################################
# df = pd.read_csv("/users/eleves-b/2023/ly-an.chhay/main/data/pisces/data60_fixed_split.csv")


# # train_dataset = SeqDataset(df[df.split=='train'],1000)
# # valid_dataset = SeqDataset(df[df.split=='valid'],1000)
# # test_dataset = SeqDataset(df[df.split=='test'],1000)

# ## smaple down abit for esm
# train_dataset = SeqDataset(df[df.split=='train'].sample(frac=0.10, random_state=42),1000)
# valid_dataset = SeqDataset(df[df.split=='valid'].sample(frac=0.10, random_state=42),1000)
# test_dataset = SeqDataset(df[df.split=='test'],1000)

# # train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# # valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
# # test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# # ##collate to flexible max len in batch
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)

### Antibody

In [11]:

def load_model_from_checkpoint(model, checkpoint_path, device):
    """
    Load model weights from a checkpoint file if it exists, then move the model to ``device``.

    Only the model state is restored; optimizer state stored in the checkpoint is ignored.

    Args:
    - model (torch.nn.Module): The model to load the state into.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device to which the model should be moved.

    Returns:
    - start_epoch (int): The checkpoint's 'epoch' plus one, or 0 without a checkpoint.
    - best_validation_loss (float): The checkpoint's 'validation_accuracy' entry,
      or infinity without a checkpoint.
    """
    if not os.path.exists(checkpoint_path):
        print('No checkpoint found.')
        model.to(device)
        # Infinity, assuming lower is better for the validation loss
        return 0, float('inf')

    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    start_epoch = state['epoch'] + 1
    # The stored metric lives under the 'validation_accuracy' key of the checkpoint
    best_validation_loss = state['validation_accuracy']
    print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')

    model.to(device)
    return start_epoch, best_validation_loss



In [14]:

#########################################################################
#########################################################################
# Antibody dataset; the `split` column assigns every row to train / valid / test.
df = pd.read_csv("../data/csv/antibody.csv")

# For a quick run, subsample train/valid first, e.g.
# df[df.split=='train'].sample(frac=0.10, random_state=42)
train_dataset, valid_dataset, test_dataset = (
    AntibodySeqDataset(df[df.split == split_name])
    for split_name in ("train", "valid", "test")
)

# One shuffled loader per split, all with the same batch size
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(split_dataset, batch_size=32, shuffle=True)
    for split_dataset in (train_dataset, valid_dataset, test_dataset)
)



#########################################################################
#########################################################################

# df = pd.read_csv("/users/eleves-b/2023/ly-an.chhay/main/data/pisces/tm.csv")

# # train_dataset = AntibodySeqDataset(df[df.split=='train'].sample(frac=0.10, random_state=42))
# # valid_dataset = AntibodySeqDataset(df[df.split=='valid'].sample(frac=0.10, random_state=42))
# # test_dataset = AntibodySeqDataset(df[df.split=='test'])
# train_dataset = AntibodySeqDataset(df[df.split=='train'],700, True)
# valid_dataset = AntibodySeqDataset(df[df.split=='test'],700, True)
# test_dataset = AntibodySeqDataset(df[df.split=='holdout'],700, True)


# train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)



In [15]:
# Sanity check: print the first test batch only, then stop iterating.
for idx, batch in enumerate(test_dataloader):
    print(batch)
    break

{'code': ['4j8r', '7qny', '6al4', '6hkg', '7kpb', '7kqk', '5jor', '4hzl', '7lop', '2vh5', '7jwp', '5a16', '4hj0', '3thm', '7l7e', '8d3a', '6wo5', '4eow', '2gfb', '6k65', '1nsn', '6mfp', '8ef3', '8dn7', '7zf6', '6mi2', '6bli', '8bbo', '1osp', '6fg2', '6dwi', '6xm2'], 'H_seq': ['VKLQESGGEVVRPGTSVKVSCKASGYAFTNYLIEWVKQRPGQGLEWIGVINPGSGDTNYNEKFKGKATLTADKSSSTAYMQLNSLTSDDSAVYFCARSGAAAPTYYAMDYWGQGVSVTVSSAKTTPPSVYPLAPAAAAANSMVTLGCLVKGYFPEPVTVTWNSGSLSGGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPR', 'EVQLLESGGDLIQPGGSLRLSCAASGVTVSSNYMSWVRQAPGKGLEWVSIIYPGGSTFYADSVKGRFTISRDNSKNTLYLQMHSLRAEDTAVYYCARDLGSGDMDVWGKGTTVTVSSASTKGPSVFPLAPSSSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKS', 'VQLVQSGAEVKKPGSSVKVSCKASGYAFSSYWMNWVRQAPGQGLEWMGQIWPGDSDTNYAQKFQGRVTITADESTSTAYMELSSLRSEDTAVYYCARRETTTVGRYYYAMDYWGQGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPK', 'EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYGMAWVRQAPGKGLEWVSF

### Find the proportion of positive/negative classes

#### Per the outputs below, about 91% of heavy-chain residues are negative vs 9% positive (roughly a 10:1 ratio), and about 95% vs 5% for light chains (roughly 17:1)

In [16]:
# Class balance of heavy-chain residues for every split.
# The loop replaces three copy-pasted stanzas. sum_one / sum_zero / total keep
# their names and end up holding the test-split values, as before.
for split_name, split_dataset in (("train", train_dataset), ("valid", valid_dataset), ("test", test_dataset)):
    sum_one = split_dataset.data['Hchain_count_positive'].sum()
    sum_zero = split_dataset.data['Hchain_count_negative'].sum()
    total = split_dataset.data['Hchain_len'].sum()

    print(f"[{split_name}] proportion of positive and negative class: ", sum_one/total, sum_zero/total)
    # The printed value is negatives divided by positives; the old label had it reversed
    print(f"[{split_name}] ratio of negative to positive class: ", sum_zero/sum_one)

propotion of position and negative class:  0.09087559688180555 0.9091244031181944
ratio of position to negative class:  10.00405427103404
propotion of position and negative class:  0.08837646378487783 0.9116235362151222
ratio of position to negative class:  10.315229838050056
propotion of position and negative class:  0.09189772296568413 0.9081022770343159
ratio of position to negative class:  9.881662436548224


In [17]:
# Class balance of light-chain residues for every split.
# The loop replaces three copy-pasted stanzas. sum_one / sum_zero / total keep
# their names and end up holding the test-split values, as before.
for split_name, split_dataset in (("train", train_dataset), ("valid", valid_dataset), ("test", test_dataset)):
    sum_one = split_dataset.data['Lchain_count_positive'].sum()
    sum_zero = split_dataset.data['Lchain_count_negative'].sum()
    total = split_dataset.data['Lchain_len'].sum()

    print(f"[{split_name}] proportion of positive and negative class: ", sum_one/total, sum_zero/total)
    # The printed value is negatives divided by positives; the old label had it reversed
    print(f"[{split_name}] ratio of negative to positive class: ", sum_zero/sum_one)

propotion of position and negative class:  0.05469640667558992 0.94530359332441
ratio of position to negative class:  17.282736669176536
propotion of position and negative class:  0.05594240179772623 0.9440575982022738
ratio of position to negative class:  16.875528541226217
propotion of position and negative class:  0.054230682755002625 0.9457693172449974
ratio of position to negative class:  17.43974571586512


In [None]:
# ----------------
# PARAM
# ----------------


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Define the configuration dictionary with all the model parameters
# path = "./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/"
path = "./weights_antibody/seq/(antiberty)_(combinedloss)_(none)/"

# path = "./weights_antibody/seq/(onehot)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights/seq/(esm)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path = "./weights_antibody/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"

# Define the configuration dictionary with all the model parameters
# path_old = "./weights/seq/(onehot)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(esm)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
# path_old = "./weights/seq/(protbert)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"


#path to Tm
# path = "./weights_tm/seq/(antiberty)_(regloss)_()/"
# path = "./weights_tm/seq/(onehot_meiler)_(regloss)_(local_5block256dim)_(global_2layer128_4head)/"


# Hyper-parameters handed to Aggrepred(config). The training cell dumps this
# dict verbatim to <path>/config.json, so key names and order are part of the
# saved artefact.
config = {
    "pooling": False,
    "antibody": True,
    "use_local": False,
    # "use_local": True,
    "use_global": False,
    # "use_global": True,
    "num_localextractor_block": 5,
    # NOTE(review): 700 presumably equals the 450 (heavy) + 250 (light) padded
    # positions used by train_epoch/evaluate -- confirm against Aggrepred.
    "input_dim": 700,
    "output_dim": 700,
    # NOTE(review): presumably the per-residue embedding width of the chosen
    # encoder -- confirm when switching "encode_mode".
    "in_channel": 512,
    "out_channel": 256,
    "kernel_size": 23,
    "dilation": 1,
    "stride": 1,
    "rnn_hid_dim": 128,
    "rnn_layers": 2,
    "bidirectional": True,
    "rnn_dropout": 0.2,
    "attention_heads": 4,
    # Read by the Adam optimizer cell.
    "learning_rate": 1e-4,
    "batch_size": 32,
    # NOTE(review): the training cell passes a literal 50 to train_loop, so this
    # value does not bound the number of epochs actually run.
    "nb_epochs": 20,
    # Selects the encoder branch inside train_epoch / evaluate.
    "encode_mode" : 'antiberty'
}

# with open(path_old+'config.json', 'r') as json_file:
#     config = json.load(json_file)


# Initialize the model
model = Aggrepred(config)
model = model.to(device=device)

# Bare expression: last statement of the cell, so the notebook displays the module repr.
model

# Load the model weights from the checkpoint
# _, _ = load_model_from_checkpoint(model, path_old + 'model_best.pt', device)


In [None]:

# ----------------
#   OPTIMIZER 
# ----------------
# Adam over every model parameter; the learning rate comes from the config dict.
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
# Alternative kept for reference. NOTE(review): as written it would not run --
# AdamW lives in torch.optim (not nn.optim) and `learning_rate` is not defined here.
# optimizer = nn.optim.AdamW(model.parameters(), lr=learning_rate,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=0.01)

# NOTE(review): `scheduler` is created here, but neither train_epoch nor train_loop
# in this notebook calls scheduler.step(), so the learning rate is not decayed
# unless it is stepped somewhere outside this view.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


# ----------------
# LOSS
# ----------------

class CombinedLoss(nn.Module):
    """Weighted sum of a regression loss (MSE) and a binary loss (BCE with logits).

    The same ``outputs`` tensor is used both as the regression prediction and as
    the logit of the "score > 0" class.

    Args:
        lambda_reg: weight of the MSE term.
        lambda_bin: weight of the BCE-with-logits term.
        pos_weight: optional weight of the positive class (number, sequence or
            tensor). It is stored as a floating-point buffer of ``self.bce_loss``.
    """

    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super(CombinedLoss, self).__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        self.mse_loss = nn.MSELoss()  # Regression Loss (MSE)

        if pos_weight is not None:
            # as_tensor + clone avoids the copy-construct warning torch.tensor()
            # emits for tensor inputs; integer weights are promoted to float.
            pos_weight = torch.as_tensor(pos_weight).detach().clone()
            if not torch.is_floating_point(pos_weight):
                pos_weight = pos_weight.float()
        # pos_weight=None gives the plain, unweighted BCE-with-logits loss.
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def forward(self, outputs, regression_targets, binary_targets=None):
        """Return ``lambda_reg * MSE + lambda_bin * BCE`` as a scalar tensor.

        Args:
            outputs: raw model outputs (regression values, also used as logits).
            regression_targets: real-valued targets, same shape as ``outputs``.
            binary_targets: optional 0/1 float targets. When omitted they are
                derived as ``regression_targets > 0``.
        """
        # Calculate regression loss
        reg_loss = self.mse_loss(outputs, regression_targets)

        # Derive the binary labels from the regression targets unless the caller
        # supplied them explicitly.
        if binary_targets is None:
            binary_targets = (regression_targets > 0).float()

        # The weight buffer is created on the CPU; keep it on the same device as
        # the logits so a non-scalar pos_weight also works on the GPU.
        weight = self.bce_loss.pos_weight
        if weight is not None and weight.device != outputs.device:
            self.bce_loss.pos_weight = weight.to(outputs.device)

        bin_loss = self.bce_loss(outputs, binary_targets)

        # Combined weighted loss
        total_loss = self.lambda_reg * reg_loss + self.lambda_bin * bin_loss
        return total_loss

# Stand-alone MSE criterion. In this notebook it is only referenced from
# commented-out lines inside train_epoch / evaluate.
mse_loss  = nn.MSELoss()
# mse_loss  = nn.MSELoss(reduction='sum')

# bce_loss = nn.BCELoss()
# bce_loss = nn.BCELoss(weight=class_weights)

# class_weights = torch.Tensor([1.0, 12.0]).cuda()
# Stand-alone weighted BCE criterion (positive class weighted 4x); like mse_loss
# it is only referenced from commented-out code in this notebook.
pos_class_weights = torch.Tensor([4.0]).to(device)
weighted_bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_class_weights)


# `loss_fn` is the module-level criterion that train_epoch and evaluate read.
# pos_weight=17.0 is close to the ~17 class ratio printed by the data cell above.
loss_fn = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=17.0)


# ----------------
def count_parameters(model):
    """Count the scalar parameters of ``model``, split by ``requires_grad``.

    Returns:
        (trainable, frozen): number of elements in parameters that do and do not
        require gradients.
    """
    trainable_total = 0
    frozen_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
        else:
            frozen_total += param.numel()
    return trainable_total, frozen_total

# Report how many scalar parameters of the global `model` will be updated by the
# optimizer and how many are frozen.
trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

In [None]:
# Plain-text dump of the module tree of the global `model`.
print(model)

In [None]:
def embed_esm_batch(batch_sequences, model, alphabet, repr_layer='last', strip_special_tokens=False):
    """Embed a batch of sequences with an ESM model and return per-token features.

    Args:
        batch_sequences: iterable of amino-acid strings.
        model: ESM model exposing ``num_layers`` and called as
            ``model(tokens, repr_layers=[...], return_contacts=False)``.
        alphabet: matching ESM alphabet (provides ``get_batch_converter()``).
        repr_layer: ``'last'``/``None`` for the final layer, or an integer layer
            index (negative values count from the end).
        strip_special_tokens: when True, drop the first and the last token column,
            the same trimming the ProtBert/AntiBERTy helpers in this notebook apply.
            Defaults to False so the returned shape is unchanged for existing callers.
            NOTE(review): ESM batch converters prepend a BOS token, so with the
            default, residue ``i`` sits at index ``i + 1`` of the output -- confirm
            whether callers' masks account for that.

    Returns:
        Tensor of shape (batch, tokens, dim) taken from the requested layer.
    """
    batch_converter = alphabet.get_batch_converter()
    data = [("protein" + str(i), seq) for i, seq in enumerate(batch_sequences)]
    _, _, batch_tokens = batch_converter(data)

    # Resolve the requested layer; previously the argument was silently ignored.
    num_layers = model.num_layers
    if repr_layer is None or repr_layer == 'last':
        layer = num_layers
    else:
        layer = (int(repr_layer) + num_layers + 1) % (num_layers + 1)

    # Run on whichever device the model lives on (tokens are created on the CPU).
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    if first_param is not None:
        batch_tokens = batch_tokens.to(first_param.device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[layer], return_contacts=False)

    token_embeddings = results["representations"][layer]
    if strip_special_tokens:
        token_embeddings = token_embeddings[:, 1:-1, :]

    return token_embeddings

def embed_protbert_batch(sequences, model, tokenizer, device=None):
    """Embed a batch of sequences with a BERT-style protein model.

    Residues are space-separated and the letters U, Z, O and B are replaced by X
    before tokenisation.

    Args:
        sequences: iterable of amino-acid strings.
        model: encoder called as ``model(input_ids=..., attention_mask=...)``;
            element 0 of its output is used as the per-token embedding.
        tokenizer: tokenizer exposing ``batch_encode_plus``.
        device: device for the token tensors. ``None`` (default) uses the device
            of the model's first parameter, so model and inputs always agree.

    Returns:
        Tensor (batch, tokens - 2, dim): the embedding with the first and the last
        token column removed.
    """
    # Local import: `re` is not among the imports visible at the top of the notebook.
    import re

    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    sequences_w_spaces = [' '.join(seq) for seq in sequences]
    processed_sequences = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_w_spaces]

    # padding=True pads to the longest sequence of the batch; it replaces the
    # deprecated pad_to_max_length=True, which did the same when no max_length is given.
    ids = tokenizer.batch_encode_plus(processed_sequences, add_special_tokens=True, padding=True)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    with torch.no_grad():
        embedding = model(input_ids=input_ids, attention_mask=attention_mask)[0]

    # Drop the first and the last token column (the special tokens of the longest sequence).
    return embedding[:, 1:-1, :]


def embed_antiberty_batch(sequences, model):
    """Embed sequences with ``model.embed`` and return one zero-padded batch tensor.

    ``model.embed(sequences)`` must yield one 2-D tensor per sequence. The first
    and last row of each tensor are dropped, then the tensors are right-padded
    with zeros to a common length and stacked batch-first.
    """
    trimmed = []
    for per_sequence in model.embed(sequences):
        # Drop the first and the last row of every embedding.
        trimmed.append(per_sequence[1:-1, :])

    return nn.utils.rnn.pad_sequence(trimmed, batch_first=True)


In [None]:
def format_time(seconds):
    """Render a duration in seconds as ``"<m>m <s> s"``, or ``"<s> s"`` under one minute."""
    whole_minutes, leftover = divmod(seconds, 60)
    whole_minutes = int(whole_minutes)
    leftover = int(leftover)
    if whole_minutes > 0:
        return f"{whole_minutes}m {leftover} s"
    return f"{leftover} s"

def train_epoch(model, optimizer, dataloader, encode_mode='onehot_meiler', device = 'cuda', printEvery=100):
    """Run one training epoch over batches of paired heavy/light chains.

    Every batch must provide the keys 'H_seq', 'L_seq', 'H_mask', 'L_mask',
    'H_target_reg' and 'L_target_reg'. Heavy and light chains are encoded
    separately (heavy padded to 450 positions, light to 250), concatenated along
    dim 1 and passed to ``model(x, masks)``. The loss is the module-level
    ``loss_fn`` applied to outputs and targets after ``clean_output_batch``
    (per the original comment, this trims the padded part).

    Only the encoder selected by ``encode_mode`` is loaded; the previous version
    instantiated ESM, ProtBert (forced onto 'cuda') and AntiBERTy on every call.

    Args:
        model: network returning ``(final_info, output_reg)``.
        optimizer: optimizer stepped once per batch.
        dataloader: iterable of batches exposing ``batch_size`` and ``len()``.
        encode_mode: 'esm', 'protbert', 'antiberty' or 'onehot'; any other value
            selects the one-hot + Meiler encoding.
        device: device for inputs, masks and targets.
        printEvery: approximate number of samples between progress prints; it is
            divided by ``dataloader.batch_size`` to get an iteration interval.

    Returns:
        Mean per-batch loss over the epoch.
    """
    heavy_max_len, light_max_len = 450, 250

    model.train()
    total_loss = 0.0
    count_iter = 0
    epoch_start_time = time.time()
    interval_start_time = epoch_start_time
    batch_size = dataloader.batch_size
    # Never let the interval reach 0: it is used as a modulus below.
    print_interval = max(1, printEvery // batch_size) if batch_size else 100

    def pad_to(embedding, max_len):
        """Zero-pad dim 1 of a (batch, length, dim) embedding up to max_len."""
        return F.pad(embedding, (0, 0, 0, max(max_len - embedding.size(1), 0)))

    # Build the encoder for the requested mode only.
    if encode_mode == 'esm':
        # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm_model = esm_model.eval()  # Set the model to evaluation mode

        def encode_chain(sequences, max_len):
            return pad_to(embed_esm_batch(sequences, esm_model, alphabet).to(device), max_len)

    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

        def encode_chain(sequences, max_len):
            embedding = embed_protbert_batch(sequences, protbert_model, protbert_tokenizer, device)
            return pad_to(embedding.to(device), max_len)

    elif encode_mode == 'antiberty':
        antiberty_model = AntiBERTyRunner()

        def encode_chain(sequences, max_len):
            return pad_to(embed_antiberty_batch(sequences, antiberty_model).to(device), max_len)

    elif encode_mode == 'onehot':
        def encode_chain(sequences, max_len):
            return onehot_encode_batch(sequences, max_len).to(device)

    else:
        def encode_chain(sequences, max_len):
            return onehot_meiler_encode_batch(sequences, max_len).to(device)

    num_batches = len(dataloader)
    with tqdm(total=num_batches, desc='Training', unit='batch') as pbar:
        for idx, batch in enumerate(dataloader):

            # The encoders are frozen feature extractors (not in the optimizer),
            # so no autograd graph is needed for the inputs.
            with torch.no_grad():
                Hchain_x = encode_chain(batch['H_seq'], heavy_max_len)
                Lchain_x = encode_chain(batch['L_seq'], light_max_len)

            x = torch.cat((Hchain_x, Lchain_x), dim=1)

            Hchain_mask = batch['H_mask'].to(device)
            Lchain_mask = batch['L_mask'].to(device)
            masks = torch.cat((Hchain_mask, Lchain_mask), dim=1)

            ## convert (bsz,max_len) to  (bsz,max_len,1)
            H_y_reg = batch['H_target_reg'].unsqueeze(2).float().to(device)
            L_y_reg = batch['L_target_reg'].unsqueeze(2).float().to(device)
            y_reg = torch.cat((H_y_reg, L_y_reg), dim=1)
            y_reg = clean_output_batch(y_reg, masks)

            ## prediction
            final_info, output_reg = model(x, masks)

            #trim out the padded part
            output_reg = clean_output_batch(output_reg, masks)
            assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

            current_loss = loss_fn(output_reg, y_reg)

            # Backpropagation
            optimizer.zero_grad()
            current_loss.backward()
            optimizer.step()
            total_loss += current_loss.item()

            count_iter += 1
            if count_iter % print_interval == 0 or idx == num_batches - 1:
                now = time.time()
                # "Time" is the duration since the previous report; the remaining
                # estimate uses the average iteration time of the whole epoch.
                elapsed_time = now - interval_start_time
                remaining_time = ((now - epoch_start_time) / count_iter) * (num_batches - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")
                interval_start_time = now
            torch.cuda.empty_cache()
            pbar.update(1)

    epoch_time = time.time() - epoch_start_time
    # Guard against an empty dataloader.
    mean_loss = total_loss / max(num_batches, 1)
    print(f"==> Average Training loss: mse ={mean_loss}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")
    return mean_loss


def evaluate(model, dataloader, encode_mode='onehot_meiler', device= 'cuda', mode='valid'):
    """Evaluate ``model`` on ``dataloader`` and compute regression/classification metrics.

    Batches are encoded as in ``train_epoch`` (heavy chain padded to 450, light
    chain to 250, concatenated along dim 1). Outputs and targets are passed
    through ``clean_output_batch`` before the module-level ``loss_fn`` and the
    metrics are computed. Binary labels are obtained by thresholding targets and
    predictions at 0; AUC-ROC and AUC-PR use the raw predictions as scores.

    Only the encoder selected by ``encode_mode`` is loaded; the previous version
    instantiated ESM, ProtBert (forced onto 'cuda') and AntiBERTy on every call.

    Args:
        model: network called as ``model(x, masks)`` returning ``(final_info, output_reg)``.
        dataloader: iterable of batches with 'H_seq', 'L_seq', 'H_mask', 'L_mask',
            'H_target_reg' and 'L_target_reg'.
        encode_mode: 'esm', 'protbert', 'antiberty' or 'onehot'; any other value
            selects the one-hot + Meiler encoding.
        device: device for inputs, masks and targets.
        mode: currently unused; kept for interface compatibility.

    Returns:
        (mean_loss, metrics, predictions, targets): mean per-batch loss, a dict with
        "Regression Metrics" and "Classification Metrics", and the per-batch lists of
        prediction and target numpy arrays.
    """
    heavy_max_len, light_max_len = 450, 250

    model.eval()
    total_loss = 0.0

    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []

    def pad_to(embedding, max_len):
        """Zero-pad dim 1 of a (batch, length, dim) embedding up to max_len."""
        return F.pad(embedding, (0, 0, 0, max(max_len - embedding.size(1), 0)))

    # Build the encoder for the requested mode only.
    if encode_mode == 'esm':
        # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm_model = esm_model.eval()  # Set the model to evaluation mode

        def encode_chain(sequences, max_len):
            return pad_to(embed_esm_batch(sequences, esm_model, alphabet).to(device), max_len)

    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

        def encode_chain(sequences, max_len):
            embedding = embed_protbert_batch(sequences, protbert_model, protbert_tokenizer, device)
            return pad_to(embedding.to(device), max_len)

    elif encode_mode == 'antiberty':
        antiberty_model = AntiBERTyRunner()

        def encode_chain(sequences, max_len):
            return pad_to(embed_antiberty_batch(sequences, antiberty_model).to(device), max_len)

    elif encode_mode == 'onehot':
        def encode_chain(sequences, max_len):
            return onehot_encode_batch(sequences, max_len).to(device)

    else:
        def encode_chain(sequences, max_len):
            return onehot_meiler_encode_batch(sequences, max_len).to(device)

    with torch.no_grad():
        with tqdm(total=len(dataloader), unit='batch') as pbar:
            for batch in dataloader:

                Hchain_x = encode_chain(batch['H_seq'], heavy_max_len)
                Lchain_x = encode_chain(batch['L_seq'], light_max_len)
                x = torch.cat((Hchain_x, Lchain_x), dim=1)

                Hchain_mask = batch['H_mask'].to(device)
                Lchain_mask = batch['L_mask'].to(device)
                masks = torch.cat((Hchain_mask, Lchain_mask), dim=1)

                ## convert (bsz,max_len) to  (bsz,max_len,1)
                H_y_reg = batch['H_target_reg'].unsqueeze(2).float().to(device)
                L_y_reg = batch['L_target_reg'].unsqueeze(2).float().to(device)
                y_reg = torch.cat((H_y_reg, L_y_reg), dim=1)
                y_reg = clean_output_batch(y_reg, masks)

                ## prediction
                final_info, output_reg = model(x, masks)

                #trim out the padded part
                output_reg = clean_output_batch(output_reg, masks)
                assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

                current_loss = loss_fn(output_reg, y_reg)
                total_loss += current_loss.item()

                # Move to numpy once and reuse for both the raw and the binary lists.
                output_np = output_reg.cpu().numpy()
                target_np = y_reg.cpu().numpy()
                predictions.append(output_np)
                targets.append(target_np)

                # Positive class = value above 0, for targets and predictions alike.
                binary_predictions.append((output_np > 0).astype(int))
                binary_targets.append((target_np > 0).astype(int))

                torch.cuda.empty_cache()
                pbar.update(1)

    all_predictions = np.concatenate(predictions, axis=0).reshape(-1)
    all_targets = np.concatenate(targets, axis=0).reshape(-1)

    print("all pred:",all_predictions)
    print("all tar:",all_targets)

    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets.flatten(), all_predictions.flatten())
    overall_spearman, p_value = spearmanr(all_targets, all_predictions)

    print(f"Overall Regression Metrics")
    print(f"MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}, spear: {overall_spearman:.4f}, P-value: {p_value:.4f}")

    all_binary_predictions = np.concatenate(binary_predictions, axis=0).reshape(-1)
    all_binary_targets = np.concatenate(binary_targets, axis=0).reshape(-1)

    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    # Ranking metrics use the raw predictions as scores, not the thresholded labels.
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)

    print(f"Overall classification Metrics")
    print(f"Acc: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")  

    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }


    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader, nb_epochs,  encode_mode='onehot_meiler', device= 'cuda', save_directory='./weights/'):
    """Train with per-epoch validation, checkpointing, resume and early stopping.

    Files written under ``save_directory``:
      * model_best.pt  -- checkpoint with the lowest validation loss so far
      * model_last.pt  -- checkpoint of the most recent epoch (used to resume)
      * metrics.json   -- validation metrics of the best checkpoint
      * losses.json    -- lists of per-epoch training and validation losses

    The checkpoint key 'validation_accuracy' holds the validation *loss*; the
    name is kept so existing checkpoints stay loadable.

    Args:
        model, optimizer: objects to train; their state is restored from
            model_last.pt when that file exists.
        train_dataloader, valid_dataloader: loaders for train_epoch / evaluate.
        nb_epochs: index of the last epoch to run (epochs are numbered from 1).
        encode_mode: forwarded to train_epoch and evaluate.
        device: forwarded to train_epoch / evaluate; also the map_location for resume.
        save_directory: output folder, created when missing.

    Returns:
        None.
    """
    patience = 5  # epochs without validation improvement before stopping
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving losses, metrics and checkpoints
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')
    best_model_save_path = os.path.join(save_directory, 'model_best.pt')
    last_model_save_path = os.path.join(save_directory, 'model_last.pt')

    # Initialize lists for losses
    train_losses = []
    val_losses = []

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, exist_ok=True)
        print(f'Created directory: {save_directory}')

    if os.path.exists(last_model_save_path):
        # map_location lets a checkpoint written on the GPU be resumed on the CPU.
        checkpoint = torch.load(last_model_save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1

        # 'validation_accuracy' of model_last.pt is the loss of the *last* epoch,
        # which is not necessarily the best one. Prefer the stored best value, then
        # the value recorded in model_best.pt, and only then the last-epoch loss.
        if 'best_validation_loss' in checkpoint:
            best_validation_loss = checkpoint['best_validation_loss']
        elif os.path.exists(best_model_save_path):
            best_validation_loss = torch.load(best_model_save_path, map_location=device)['validation_accuracy']
        else:
            best_validation_loss = checkpoint['validation_accuracy']
        early_stopping_counter = checkpoint.get('early_stopping_counter', 0)
        print(f'Loaded checkpoint from {last_model_save_path}. Resuming from epoch {start_epoch}')
    else:
        print('No checkpoint found. Starting from beginning.')

    model.to(device)

    # Load existing losses if available
    if os.path.exists(loss_output_path):
        with open(loss_output_path, 'r') as json_file:
            existing_losses = json.load(json_file)
            train_losses = existing_losses.get('train_losses', [])
            val_losses = existing_losses.get('val_losses', [])
            print(f'Loaded losses from {loss_output_path}.')
            print(train_losses)
            print(val_losses)

    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")

        train_loss = train_epoch(model, optimizer, train_dataloader, encode_mode ,device, printEvery=1000)
        train_losses.append(train_loss)

        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,encode_mode, device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)
        else:
            early_stopping_counter += 1

        # Always persist the epoch (also the one that triggers early stopping) so a
        # later resume neither repeats it nor loses its losses.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
            'best_validation_loss': best_validation_loss,
            'early_stopping_counter': early_stopping_counter,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')

        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')

        if early_stopping_counter >= patience:
            print(f"\n==> Early stopping triggered. No improvement in validation loss for {patience} epochs.")
            break

        print("==================================================================================\n")

    return

## train here

In [None]:
# Persist the exact configuration next to the weights so the run can be rebuilt later.
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)

# NOTE(review): the epoch budget here is the literal 50, not config["nb_epochs"] (20),
# so the saved config.json does not reflect the number of epochs actually allowed.
# train_loop resumes from <path>/model_last.pt when that file already exists.
train_loop(model,optimizer,train_dataloader,valid_dataloader, 50, config['encode_mode'],device,path)

In [None]:
# Evaluate the current in-memory `model` on `test_dataloader`. evaluate returns the
# mean loss, the metrics dict and the per-batch lists of prediction / target arrays.
val, metric, preds_val, tar_val = evaluate(model,test_dataloader,config["encode_mode"],device)

In [None]:
# Predictions array of the first evaluated batch.
preds_val[0]

In [None]:
# Target array of the first evaluated batch (same batch as preds_val[0]).
tar_val[0]

In [None]:
# import matplotlib.pyplot as plt
# x = np.hstack([arr.squeeze() for arr in preds_val])  # Squeeze 2D arrays to 1D and concatenate
# y = np.hstack(tar_val)  # Concatenate 1D arrays

# # Ensure the lengths of x and y are compatible
# assert x.shape[0] == y.shape[0], "The number of elements in x should match the length of y"

# # Plotting the data
# plt.scatter(x, y)
# plt.xlabel("X (from list of 2D arrays)")
# plt.ylabel("Y (from list of 1D arrays)")
# plt.title("Scatter plot of X vs Y")
# plt.xlim(50,80)
# plt.ylim(50,80)
# plt.show()

# # Calculate Pearson Correlation Coefficient (PCC)
# pcc, _ = pearsonr(x, y)
# print(f"Pearson Correlation Coefficient (PCC): {pcc}")

## test

In [None]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Restore model weights from a checkpoint file, if the file exists.

    Only the model state is loaded. ``optimizer`` is accepted for interface
    compatibility, but its state is NOT restored (that line is disabled).

    Args:
    - model (torch.nn.Module): The model to load the weights into; it is moved to ``device``.
    - optimizer (torch.optim.Optimizer): Unused.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Target device, also used as ``map_location`` when loading.

    Returns:
    - start_epoch (int): Stored epoch + 1, or 0 when no checkpoint exists.
    - best_validation_loss (float): Value stored under 'validation_accuracy',
      or ``inf`` when no checkpoint exists.
    """
    # Guard clause: nothing to load.
    if not os.path.exists(checkpoint_path):
        print('No checkpoint found.')
        model.to(device)
        return 0, float('inf')  # lower is better for validation loss

    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    # optimizer.load_state_dict(state['optimizer_state_dict'])
    resume_epoch = state['epoch'] + 1
    stored_loss = state['validation_accuracy']
    print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {resume_epoch}')

    model.to(device)
    return resume_epoch, stored_loss



In [None]:
# # Define the configuration dictionary with all the model parameters
# path = "./weights/seq/(onehot)_(regloss)_(global_1layer256_4head)/"

# with open(path+'config.json', 'r') as json_file:
#     config = json.load(json_file)

# # ----------------
# #  MODEL 
# # ----------------
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

In [None]:

# List of model paths
model_paths = [
    # "./weights_antibody/seq/(onehot_meiler)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(onehot)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(protbert)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(esm35M)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(antiberty)_(combinedloss)_(local_3block256dim)_(global_1layer128_4head)/",
    
    # "./weights_antibody/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(onehot_meiler)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(onehot)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/",
    # "./weights_antibody/seq/(esm35M)_(regloss)_(local_3block256dim)_(global_1layer128_4head)/"
    
    "./weights_tm/seq/(antiberty)_(regloss)_()/"
]

# For every run folder: rebuild the model from its saved config, load the best
# checkpoint, score it on the test loader and store the metrics as result.json.
# Note that the loop rebinds the globals `path`, `config` and `model`.
for path in model_paths:
    # Load the config for the current model
    with open(os.path.join(path, 'config.json'), 'r') as json_file:
        config = json.load(json_file)

    # Initialize the model
    model = Aggrepred(config)
    model = model.to(device=device)

    # Load the model weights from the checkpoint
    _, _ = load_model_from_checkpoint(model, optimizer, os.path.join(path, 'model_best.pt'), device)

    # Evaluate the model. evaluate's positional order is
    # (model, dataloader, encode_mode, device); the earlier call passed
    # config['pooling'] as encode_mode and config['encode_mode'] as device.
    loss, metrics, preds, tar = evaluate(model, test_dataloader, config['encode_mode'], device)

    # Save metrics of the best model
    with open(os.path.join(path, 'result.json'), 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {path}")


In [None]:
# Per-batch prediction arrays from the last model evaluated in the loop above.
preds

In [None]:
# Per-batch target arrays matching `preds`.
tar

In [None]:
# Iterate over batches and collect one row per residue: (protein id, position, residue, score).
# NOTE(review): `model1`, `dataloader` and `clean_output` are not defined in the part of the
# notebook visible here -- confirm they exist before running this cell on a fresh kernel.
model1.eval()
data = []
# Running 1-based protein counter. The previous id formula `i*32+j+1` assumed a batch
# size of exactly 32 and produced colliding ids for larger batches.
protein_index = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for batch in dataloader:
        lengths = [len(seq) for seq in batch]
        print(lengths)
        lengths = torch.tensor(lengths).to(device)
        x = onehot_meiler_encode_batch(batch).to(device)

        _, out, _ = model1(x, lengths)

        for j, sequence in enumerate(batch):
            protein_index += 1
            cleaned_output = clean_output(out[j], lengths[j])
            for k, value in enumerate(cleaned_output):
                data.append([f'protein_{protein_index}', k+1, sequence[k], value.item()])

# Create DataFrame
df = pd.DataFrame(data, columns=['protein_id', 'amino_acid_id', 'amino_acid', 'value'])
df

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def plot_protein_values(protein_df):
    """Line plot of the 'value' column against 'amino_acid_id' for one protein.

    Args:
        protein_df: DataFrame with the columns 'amino_acid_id' and 'value'.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(protein_df['amino_acid_id'], protein_df['value'], linestyle='-', color='#2596be')

    # Horizontal reference line at y = 0.
    ax.axhline(y=0, color='black', linestyle='--')

    ax.set_title(f'Predicted Values ')
    ax.set_xlabel('Amino Acid ID')
    ax.set_ylabel('Predicted Value')

    # Show the plot
    plt.show()

# Example usage: plot the per-residue values of the rows labelled 'protein_1'.

plot_protein_values(df[df["protein_id"]=='protein_1'])

In [None]:
# Rows of the per-residue table that belong to 'protein_1'.
df[df["protein_id"]=='protein_1']

In [None]:
# Create the output folder first: to_csv does not create missing directories.
os.makedirs("tmp", exist_ok=True)
df.to_csv("tmp/output.csv")

In [None]:
import torch
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Weighted sum of an MSE regression loss and a BCE-with-logits loss.

    Both terms are computed on the same ``outputs`` tensor: it is compared to
    ``regression_targets`` with MSE and, interpreted as logits, to
    ``binary_targets`` with binary cross-entropy.

    Args:
        lambda_reg: Multiplier applied to the MSE term.
        lambda_bin: Multiplier applied to the BCE term.
        pos_weight: Optional positive-class weight for the BCE term. May be a
            Python number or a tensor. ``None`` gives an unweighted BCE.
    """

    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super().__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        self.mse_loss = nn.MSELoss()  # Regression Loss (MSE)

        if pos_weight is not None:
            # as_tensor accepts numbers and existing tensors alike, whereas
            # torch.tensor(existing_tensor) emits a copy-construct warning.
            weight = torch.as_tensor(pos_weight)
            # Integer weights (e.g. pos_weight=2) are cast to the default float
            # dtype so the stored weight is a floating-point tensor.
            if not weight.is_floating_point():
                weight = weight.to(torch.get_default_dtype())
            # Binary Classification Loss (Weighted BCE with logits)
            self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=weight)
        else:
            self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, outputs, regression_targets, binary_targets):
        """Return ``lambda_reg * MSE + lambda_bin * BCE`` for one batch.

        Args:
            outputs: Raw model outputs, used both as regression predictions
                and as classification logits.
            regression_targets: Target values for the MSE term.
            binary_targets: Float 0/1 targets for the BCE term.

        Returns:
            The combined weighted loss.
        """
        # Regression term.
        reg_loss = self.mse_loss(outputs, regression_targets)

        # Classification term: the raw outputs are passed directly as logits.
        bin_loss = self.bce_loss(outputs, binary_targets)

        # Combined weighted loss
        return self.lambda_reg * reg_loss + self.lambda_bin * bin_loss

# Example usage
outputs = torch.randn(10)  # Raw model outputs (logits or scores)
regression_targets = torch.randn(10)  # Regression target values
binary_targets = (regression_targets > 0).float()  # Binary targets based on the regression task

# Define the combined loss function with custom weights
loss_fn = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=2.0)

# Compute the combined loss
loss = loss_fn(outputs, regression_targets, binary_targets)
# Show both inputs of the loss (the outputs were previously printed twice).
print(outputs)
print(regression_targets)
print("Combined Loss:", loss.item())


## ARCHIVE

In [None]:
# import os
# import torch
# import time
# from tqdm import tqdm

# def train_and_validate(model, optimizer, train_dataloader, valid_dataloader, nb_epochs, device, save_directory='./weights/', printEvery=50):
#     start_epoch = 1
#     loss = 0
#     losses = []
#     count_iter = 0
#     best_validation_loss = 1000000

#     #if saving directory doesn't exist , create one
#     if not os.path.exists(save_directory):
#         os.makedirs(save_directory)
#         print(f'Created directory: {save_directory}')
    
#     #continue training by loading the last saved model
#     checkpoint_path = save_directory+'/model_last.pt' 
#     if os.path.exists(checkpoint_path):
#         checkpoint = torch.load(checkpoint_path)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         start_epoch = checkpoint['epoch'] + 1
#         best_validation_loss = checkpoint['validation_accuracy']
#         print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
#     else:
#         print('No checkpoint found. Starting from beginning.')

#     for i in range(start_epoch, nb_epochs):
#         print('-----EPOCH{}-----'.format(i))
#         model.train()
#         time1 = time.time()
#         total_len = 0
#         for idx, batch in tqdm(enumerate(train_dataloader)):
#             x = batch['encoded_seq'].to(device) # already padded
#             orig_len = batch['orig_len'].to(device)

#             ## convert (bsz,max_len) to  (bsz,max_len,1)
#             y_reg = batch['target_reg'].unsqueeze(2).to(device)
#             y_bin = batch['target_bin'].unsqueeze(2).float().to(device)
            
#             y_reg = clean_output_batch(y_reg, orig_len)
#             y_bin = clean_output_batch(y_bin, orig_len)

#             ## in case that y from dataset is not padded, can squeeze to 1D directly.
#             # y_reg = batch['target_reg'].squeeze(-1).to(device)
#             # y_bin = batch['target_bin'].squeeze(-1).float().to(device)

#             total_len += orig_len.sum().item()


#             ### predict 
#             final_info, output_reg, logit = model(x, orig_len)
            
#             #trim out the padded part
#             output_reg = clean_output_batch(output_reg, orig_len)
#             logit = clean_output_batch(logit, orig_len)

#             assert len(logit)==len(y_bin) , 'binary output {} and target {} not same length'.format(len(output_bin),len(y_bin))
#             assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

#             reg_loss = mse_loss(output_reg, y_reg)
#             bin_loss = weighted_bce_loss(logit, y_bin)
            
#             # print("batch {} , reg_loss :{}, bin_loss: {}".format(idx,reg_loss, bin_loss))

#             current_loss = reg_loss + bin_loss

#             optimizer.zero_grad()
#             current_loss.backward()
#             optimizer.step()
#             loss += current_loss.item()

#             count_iter += 1
#             if count_iter % printEvery == 0:
#                 time2 = time.time()
#                 print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter, time2 - time1, loss / printEvery))
#                 losses.append(loss)
#                 loss = 0
#                 time1 = time.time()
#         count_iter = 0
#         losses.append(loss)
#         loss = 0

#         model.eval()
#         val_loss = 0.0
#         time1 = time.time()
#         total_len = 0
#         with torch.no_grad():
#             for idx, batch in tqdm(enumerate(valid_dataloader)):
#                 x = batch['encoded_seq'].to(device) # already padded
#                 orig_len = batch['orig_len'].to(device)

#                 ## convert (bsz,max_len) to  (bsz,max_len,1)
#                 y_reg = batch['target_reg'].unsqueeze(2).to(device)
#                 y_bin = batch['target_bin'].unsqueeze(2).float().to(device)
                
#                 y_reg = clean_output_batch(y_reg, orig_len)
#                 y_bin = clean_output_batch(y_bin, orig_len)

#                 ## in case that y from dataset is not padded, can squeeze to 1D directly.
#                 # y_reg = batch['target_reg'].squeeze(-1).to(device)
#                 # y_bin = batch['target_bin'].squeeze(-1).float().to(device)

                
#                 total_len += orig_len.sum().item()


#                 ### predict 
#                 final_info, output_reg, logit = model(x, orig_len)
                
#                 #trim out the padded part
#                 output_reg = clean_output_batch(output_reg, orig_len)
#                 logit = clean_output_batch(logit, orig_len)

#                 assert len(logit)==len(y_bin) , 'binary output {} and target {} not same length'.format(len(output_bin),len(y_bin))
#                 assert len(output_reg)==len(y_reg) , 'reg output {} and target {} not same length'.format(len(output_reg),len(y_reg))

#                 reg_loss = mse_loss(output_reg, y_reg)
#                 bin_loss = weighted_bce_loss(logit, y_bin)
                
#                 current_loss = reg_loss 

#                 ## old code
#                 # x = batch['encoded_seq'].to(device)
#                 # y_reg = batch['target_reg'].unsqueeze(2).to(device)
#                 # y_bin = batch['target_bin'].unsqueeze(2).float().to(device)
#                 # orig_len = batch['orig_len'].to(device)
#                 # total_len += orig_len.sum().item()

#                 # final_info, reg_output, bin_output = model(x, orig_len)
#                 # reg_loss = mse_loss(reg_output, y_reg) / total_len
#                 # current_loss = reg_loss


#                 val_loss += current_loss.item()
#             time2 = time.time()

#             print('best_val', best_validation_loss)
#             print("val loss", val_loss)

#             best_validation_loss = min(best_validation_loss, val_loss)
#             # print('best_val after: ', best_validation_loss)

#             print('-----EPOCH' + str(i) + '----- done.  Validation loss: ', str(val_loss / len(valid_dataloader)))
#             #save model with best valid loss
#             if best_validation_loss == val_loss:
#                 print('validation loss improved saving checkpoint...')
                
#                 best_model_save_path = save_directory+'/model_best.pt' 
#                 # best_model_save_path = os.path.join(save_directory,'/model_best.pt')
#                 torch.save({
#                     'epoch': i,
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                     'validation_accuracy': val_loss,
#                     'loss': loss,
#                 }, best_model_save_path)
#                 print('Best model checkpoint saved to: {}'.format(best_model_save_path))
            
#             #save model of last epoch
#             last_model_save_path = save_directory+'/model_last.pt' 
#             torch.save({
#                 'epoch': i,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'validation_accuracy': val_loss,
#                 'loss': loss,
#             }, last_model_save_path)
#             print('Last epoch model saved to: {}'.format(last_model_save_path))
            
#     return 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def fermi_function(s, alpha=0, beta=0.3):
    """Map a score ``s`` to 1 / (exp(-(s - alpha) * beta) + 1).

    Works element-wise on NumPy arrays as well as on scalars; ``alpha`` is the
    score mapped to 0.5 and ``beta`` scales the exponent.
    """
    exponent = -(s - alpha) * beta
    return 1 / (np.exp(exponent) + 1)

def reverse_fermi_function(y, alpha=0, beta=0.3):
    """Compute alpha - log(1 / y - 1) / beta for a mapped value ``y``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    inner = 1 / y - 1
    return alpha - np.log(inner) / beta

def apply_fermi_function(score_list, alpha=0, beta=0.3):
    """Return fermi_function of every score in ``score_list``, rounded to 4 decimals."""
    mapped_scores = []
    for raw_score in score_list:
        mapped_scores.append(round(fermi_function(raw_score, alpha, beta), 4))
    return mapped_scores

def apply_reverse_fermi_function(mapped_list, alpha=0, beta=0.3):
    """Return reverse_fermi_function of every value in ``mapped_list``, rounded to 4 decimals."""
    recovered_scores = []
    for mapped_value in mapped_list:
        recovered_scores.append(round(reverse_fermi_function(mapped_value, alpha, beta), 4))
    return recovered_scores


# Sample 500 aggregation risk values from -8 to 8 to show the shape of the Fermi function
s = np.linspace(-8, 8, 500)

# Different values of beta to see the effect on the curve
betas = [0.1, 0.3, 0.5, 1, 2]

# One figure, one curve per beta
fig, ax = plt.subplots(figsize=(10, 6))

for beta in betas:
    fermi_values = fermi_function(s, alpha=0, beta=beta)
    ax.plot(s, fermi_values, label=f'beta = {beta}')

ax.set_title('Fermi Function Mapping of Aggregation Risk Values')
ax.set_xlabel('Aggregation Risk (s)')
ax.set_ylabel('Mapped Value (0 to 1)')
ax.set_ylim(-0.05, 1.05)  # Extend y-axis limits to better show the range 0 to 1
# ax.set_xlim(-5, 5)  # Optional: restrict the x-axis to a narrower range
ax.legend()
ax.grid(True)
plt.show()


In [None]:
from Bio import SeqIO
def read_fasta(fasta_text):
    """Parse FASTA records into parallel lists of ids and sequences.

    Args:
        fasta_text: Either FASTA-formatted text (a string whose first
            non-whitespace character is ">"), or any source that
            ``SeqIO.parse`` accepts directly (a file name or an open handle).

    Returns:
        A ``(headers, sequences)`` tuple: the record ids and the sequences as
        plain strings, in the order they appear in the input.
    """
    from io import StringIO  # local import: keeps this cell self-contained

    source = fasta_text
    if isinstance(fasta_text, str):
        stripped = fasta_text.lstrip()
        # SeqIO.parse interprets a plain str as a file name, so raw FASTA
        # content has to be wrapped in a file-like object first. Strings that
        # do not start with ">" are still passed through as paths.
        if stripped.startswith(">"):
            source = StringIO(stripped)

    headers = []
    sequences = []
    for record in SeqIO.parse(source, "fasta"):
        headers.append(record.id)
        sequences.append(str(record.seq))
    return headers, sequences

fasta_text =">prot1\nQPPVPPQRPM"

head, seqs = read_fasta(fasta_text)