In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import pickle
import seaborn as sns

from torch_geometric.utils import to_networkx
#install required packages
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric.data import Dataset
import torch_geometric.utils as pyg_utils
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool,GATv2Conv
from torch_geometric.nn.models import GCN, GAT
from torch.nn import Linear

from torch_geometric.utils import degree

import torch.nn as nn
from torch_geometric.utils import softmax
import math
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr
import random
from sklearn.metrics import root_mean_squared_error,mean_absolute_error
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
import copy


2.4.1+cu118


In [2]:

def set_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Covers Python's ``random``, NumPy's legacy global RNG, and PyTorch on the
    CPU, on the current CUDA device and on all CUDA devices. cuDNN autotuning
    is switched off because it may select different kernels between runs,
    which would break reproducibility.
    """
    seeders = (
        random.seed,                 # Python random
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch CPU
        torch.cuda.manual_seed,      # PyTorch current GPU
        torch.cuda.manual_seed_all,  # PyTorch every GPU (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True  # deterministic cuDNN algorithms
    cudnn_backend.benchmark = False     # no dynamic kernel autotuning


# Global seed for the whole notebook.
set_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
from random import sample

class DeltaDataset(Dataset):
    """Map-style dataset of wild-type / mutant embedding pairs.

    Each element of ``data`` must be a mapping with the keys ``'id'``,
    ``'wild_type'``, ``'mut_type'`` and ``'length'``. The two embeddings and
    the length are converted to ``float32`` tensors; ``'id'`` is passed through.

    Args:
        data: indexable collection of per-mutation records.
        dim_embedding: embedding width; stored on the instance, not used here.
        inv: when True the wild-type and mutant embeddings trade places, so the
            dataset yields the reverse mutation (mutant -> wild type).
    """

    def __init__(self, data, dim_embedding, inv = False):
        self.data = data
        self.dim_embedding = dim_embedding
        self.inv = inv

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        def to_f32(value):
            return torch.tensor(value, dtype=torch.float32)

        # For the reverse mutation the two embeddings simply swap roles.
        if self.inv:
            wild_key, mut_key = 'mut_type', 'wild_type'
        else:
            wild_key, mut_key = 'wild_type', 'mut_type'

        return {
            'id': record['id'],
            'wild_type': to_f32(record[wild_key]),
            'mut_type': to_f32(record[mut_key]),
            'length': to_f32(record['length']),
            }

In [4]:
import random
from torch.utils.data import DataLoader


import torch
import torch.nn.functional as F

def collate_fn(batch):
    """Zero-pad a list of dataset items into a single batch.

    Every ``wild_type`` / ``mut_type`` tensor is padded at the end of its rows
    and of its columns up to the largest row count and the largest column
    count found among the ``wild_type`` tensors of the batch.

    Returns a dict with:
        'id':        list of item ids, in batch order;
        'wild_type': tensor (batch_size, max_len, max_features);
        'mut_type':  tensor (batch_size, max_len, max_features);
        'length':    the items' ``length`` tensors stacked along a new dim 0.
    """
    max_len = max(item['wild_type'].shape[0] for item in batch)
    max_features = max(item['wild_type'].shape[1] for item in batch)  # widest feature size

    def pad_to_max(t):
        # F.pad consumes (left, right) pairs starting from the LAST dimension:
        # first the feature columns, then the sequence rows.
        n_rows, n_cols = t.shape[0], t.shape[1]
        return F.pad(t, (0, max_features - n_cols, 0, max_len - n_rows))

    return {
        'id': [item['id'] for item in batch],
        'wild_type': torch.stack([pad_to_max(item['wild_type']) for item in batch]),
        'mut_type': torch.stack([pad_to_max(item['mut_type']) for item in batch]),
        'length': torch.stack([item['length'] for item in batch]),
        }


In [5]:
class Cross_Attention_DDG(nn.Module):
    """Antisymmetric ("Janus") wrapper around a DDG regression module.

    The wrapped module is evaluated twice:
      * direct mutation:  delta = x_wild - x_mut, context = x_wild;
      * reverse mutation: delta = x_mut - x_wild, context = x_mut.
    The result is half the difference of the two outputs. Swapping ``x_wild``
    and ``x_mut`` therefore negates the prediction, provided the wrapped module
    is deterministic (e.g. in eval mode / no dropout).

    Args:
        base_module: regressor class. It is instantiated with
            ``transf_parameters`` plus ``cross_att`` / ``dual_cross_att`` and
            moved to the module-level ``device``.
        cross_att, dual_cross_att: forwarded to ``base_module``.
        **transf_parameters: extra keyword arguments for ``base_module``.
    """

    def __init__(self, base_module, cross_att=False, dual_cross_att= False,**transf_parameters):
        super().__init__()
        regressor = base_module(**transf_parameters,
                                cross_att=cross_att,
                                dual_cross_att=dual_cross_att)
        self.base_ddg = regressor.to(device)

    def forward(self, x_wild, x_mut, length, train = True):
        # ``train`` is accepted for call-site compatibility and is not used.
        direct_pred = self.base_ddg(x_wild - x_mut, x_wild, length)
        reverse_pred = self.base_ddg(x_mut - x_wild, x_mut, length)  # inverse Janus pass
        return (direct_pred - reverse_pred) / 2



In [6]:
def output_model_from_batch(batch, model, device,train=True):
    '''Run ``model`` on one collated batch and return its output.

    ``batch['wild_type']`` and ``batch['mut_type']`` are cast to float32 and,
    like ``batch['length']``, moved to ``device``; the model is then called as
    ``model(x_wild, x_mut, length, train=train)``. Only the model output is
    returned (no labels).
    '''
    x_wild, x_mut = (batch[key].float().to(device) for key in ('wild_type', 'mut_type'))
    seq_lengths = batch['length'].to(device)
    return model(x_wild, x_mut, seq_lengths, train=train)



In [7]:
import torch
import torch.nn as nn


def apply_masked_pooling(position_attn_output, padding_mask):
    """Masked global average and global max pooling over the sequence axis.

    Args:
        position_attn_output: tensor (batch_size, seq_len, feature_dim).
        padding_mask: bool / 0-1 tensor (batch_size, seq_len); True (1) marks
            padded positions that must be ignored.

    Returns:
        (gap, gmp): two tensors (batch_size, feature_dim) with the mean and the
        max over the non-padded positions. A row with no valid position gets a
        zero mean (instead of 0/0 = NaN) and a max of -1e10 in every feature.
    """
    pad = padding_mask.float()          # 1.0 on padding, 0.0 on real tokens
    keep = (1 - pad).unsqueeze(-1)      # (batch_size, seq_len, 1), broadcasts over features

    # Global Average Pooling (GAP): sum over valid tokens / number of valid tokens.
    sum_output = torch.sum(position_attn_output * keep, dim=1)  # (batch_size, feature_dim)
    # clamp prevents 0/0 -> NaN for fully padded rows; counts >= 1 are unchanged.
    valid_count = torch.sum(1 - pad, dim=1).clamp(min=1)         # (batch_size,)
    gap = sum_output / valid_count.unsqueeze(-1)

    # Global Max Pooling (GMP): padded positions are pushed to -1e10 (not -inf)
    # so they never win the max.
    masked = position_attn_output * keep + pad.unsqueeze(-1) * (-1e10)
    gmp, _ = torch.max(masked, dim=1)  # (batch_size, feature_dim)

    return gap, gmp


class SinusoidalPositionalEncoding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encoding.

    Precomputes a (1, max_len, embedding_dim) table: sin on even feature
    indices, cos on odd ones, with frequencies 10000^(-2i/embedding_dim).
    ``forward`` adds the first ``seq_len`` rows of the table to its input.
    The table is a registered buffer, so it follows ``.to(device)`` and is
    stored in the state dict, but it is not a trainable parameter.

    Args:
        embedding_dim: feature size of the input (last dimension). May be odd.
        max_len: longest sequence length the table supports.
    """

    def __init__(self, embedding_dim, max_len=3700):

        super(SinusoidalPositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd embedding_dim there is one fewer odd column than even column;
        # slicing div_term keeps the shapes aligned (no-op for even dims).
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        pe = pe.unsqueeze(0)  # Shape (1, max_len, embedding_dim)
        self.register_buffer('pe', pe)  # fixed tensor, not a parameter

    def forward(self, x):
        """Return ``x`` (batch, seq_len, embedding_dim) plus the encoding; seq_len must be <= max_len."""
        return x + self.pe[:, :x.size(1), :]


class TransformerRegression(nn.Module):
    """Conv + attention + switch-FFN regressor returning one scalar per sequence.

    Pipeline (see ``forward``):
      1. Two separate Conv1d layers (stride 1, no padding) project the
         ``delta`` embeddings and the wild-type embeddings from ``input_dim``
         to 128 channels; sinusoidal positional encodings are then added.
      2. Attention, selected by ``cross_att`` / ``dual_cross_att``:
           - cross_att=False: self-attention on delta;
           - cross_att=True, dual_cross_att=False: q=wild, k=v=delta;
           - cross_att=True, dual_cross_att=True: q=wild/k=v=delta and
             q=delta/k=v=wild, concatenated along the feature axis.
         Each attention output gets a residual add and a LayerNorm.
      3. A top-1 "mixture of experts" FFN, residual add and LayerNorm.
      4. Masked global average + max pooling, then Linear -> 1 output.

    Args:
        input_dim: feature size of the input embeddings.
        num_heads: heads of each nn.MultiheadAttention (must divide 128).
        dropout_rate: attention dropout probability.
        num_experts: number of expert FFNs.
        f_activation: activation module used inside every expert.
        kernel_size: Conv1d kernel size (default 20).
        cross_att, dual_cross_att: attention variant, see above.
    """

    def __init__(self, input_dim=1280, num_heads=8, dropout_rate=0., num_experts=1, f_activation = nn.ReLU(), kernel_size=20, cross_att = True,
                dual_cross_att=True):

        super(TransformerRegression, self).__init__()
        self.cross_att = cross_att
        self.dual_cross_att = dual_cross_att

        print(f'Cross Attention: {cross_att}')
        print(f'Dual Cross Attention: {dual_cross_att}')

        self.embedding_dim = input_dim
        self.act = f_activation
        self.max_len = 3700    # longest sequence supported by the positional encoding
        out_channels = 128     # number of Conv1d filters = attention embed_dim
        # ``kernel_size`` comes from the constructor argument (it used to be
        # overwritten with 20 here, silently ignoring the caller's value).
        padding = 0

        # Submodules are created in the same order as before, so parameter
        # initialisation under a fixed seed and state_dict key order are unchanged.
        self.conv1d = nn.Conv1d(in_channels=self.embedding_dim,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                padding=padding)

        self.conv1d_wild = nn.Conv1d(in_channels=self.embedding_dim,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     padding=padding)

        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)

        # Attention layers
        self.positional_encoding = SinusoidalPositionalEncoding(out_channels, 3700)
        # One-shot flag: forward() prints the attention variant on its first call.
        self.speach_att_type = True
        self.multihead_attention = nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.inverse_attention = nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)

        # Dual cross-attention concatenates two out_channels-wide outputs.
        if cross_att and dual_cross_att:
            dim_position_wise_FFN = out_channels * 2
        else:
            dim_position_wise_FFN = out_channels

        self.norm3 = nn.LayerNorm(dim_position_wise_FFN)
        # norm4 is not used in forward(); kept so existing checkpoints still match.
        self.norm4 = nn.LayerNorm(dim_position_wise_FFN)
        # Router: scores each token for each expert.
        self.router = nn.Linear(dim_position_wise_FFN, num_experts)
        # Mixture of Experts (switch FFN)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(dim_position_wise_FFN, 512),
            self.act,
            nn.Linear(512, dim_position_wise_FFN)
        ) for _ in range(num_experts)])

        # Input is [GAP, GMP] concatenated -> twice the token feature size.
        self.Linear_ddg = nn.Linear(dim_position_wise_FFN*2, 1)

    def create_padding_mask(self, length, seq_len, batch_size):
        """
        Create a key padding mask for multihead attention.
        length: tensor (batch_size,) with the actual length of each sequence.
        seq_len: sequence length of the tensors being attended over.
        batch_size: number of sequences in the batch (not used).

        Returns a bool mask (batch_size, seq_len) that is True where
        position >= length, i.e. on positions to be ignored.
        """
        mask = torch.arange(seq_len, device=length.device).unsqueeze(0) >= length.unsqueeze(1)
        return mask

    def forward(self, delta_w_m, x_wild, length):
        """Predict one value per sequence.

        Args:
            delta_w_m: delta embeddings; assumed (batch, seq, input_dim) -- TODO confirm.
            x_wild: context embeddings with the same shape as ``delta_w_m``.
            length: tensor (batch,) of sequence lengths used for masking.

        Returns:
            tensor (batch,).
        """
        # Conv1d wants (batch, channels, seq): transpose in, convolve, transpose back.
        delta_w_m = delta_w_m.transpose(1, 2)      # (batch, seq, feat) -> (batch, feat, seq)
        C_delta_w_m = self.conv1d(delta_w_m)
        C_delta_w_m = C_delta_w_m.transpose(1, 2)  # (batch, 128, seq') -> (batch, seq', 128)
        C_delta_w_m = self.positional_encoding(C_delta_w_m)

        x_wild = x_wild.transpose(1, 2)            # (batch, seq, feat) -> (batch, feat, seq)
        C_x_wild = self.conv1d_wild(x_wild)
        C_x_wild = C_x_wild.transpose(1, 2)        # (batch, 128, seq') -> (batch, seq', 128)
        C_x_wild = self.positional_encoding(C_x_wild)

        batch_size, seq_len, feature_dim = C_x_wild.size()

        # NOTE(review): seq_len here is the Conv1d output length
        # (input_len - kernel_size + 1, since padding=0), while ``length``
        # presumably counts positions of the un-convolved input -- TODO confirm.
        # If so, the last kernel_size - 1 windows of a shorter sequence overlap
        # padding but are still treated as valid.
        padding_mask = self.create_padding_mask(length, seq_len, batch_size)

        if self.cross_att:
            if self.dual_cross_att:

                if self.speach_att_type:
                    print('ATTENTION TYPE: Dual cross Attention\n q = wild , k = delta, v = delta and q = delta , k = wild, v = wild \n ----------------------------------')
                    self.speach_att_type = False

                # NOTE(review): the query is C_x_wild but the residual adds
                # C_delta_w_m (and vice versa below). This is unusual, but the trained
                # checkpoints depend on it, so it is intentionally left unchanged.
                direct_attn_output, _ = self.multihead_attention(C_x_wild, C_delta_w_m, C_delta_w_m, key_padding_mask=padding_mask)
                direct_attn_output += C_delta_w_m
                direct_attn_output = self.norm1(direct_attn_output)

                inverse_attn_output, _ = self.inverse_attention(C_delta_w_m, C_x_wild, C_x_wild, key_padding_mask=padding_mask)
                inverse_attn_output += C_x_wild
                inverse_attn_output = self.norm2(inverse_attn_output)

                attn_output = torch.cat([direct_attn_output, inverse_attn_output], dim=-1)

            else:
                if self.speach_att_type:
                    print('ATTENTION TYPE: Cross Attention \n q = wild , k = delta, v = delta  \n ----------------------------------')
                    self.speach_att_type = False

                attn_output, _ = self.multihead_attention(C_x_wild, C_delta_w_m, C_delta_w_m, key_padding_mask=padding_mask)
                attn_output += C_delta_w_m
                attn_output = self.norm1(attn_output)

        else:
            if self.speach_att_type:
                print('ATTENTION TYPE: Self Attention \n q = delta , k = delta, v = delta  \n ----------------------------------')
                self.speach_att_type = False

            attn_output, _ = self.multihead_attention(C_delta_w_m, C_delta_w_m, C_delta_w_m, key_padding_mask=padding_mask)
            attn_output += C_delta_w_m
            attn_output = self.norm1(attn_output)

        # Route every token to its highest-scoring expert (top-1 switch).
        # NOTE(review): routing_weights only feed argmax and never scale the
        # expert output, so the router receives no gradient and cannot learn.
        routing_logits = self.router(attn_output)            # (batch, seq_len, num_experts)
        routing_weights = F.softmax(routing_logits, dim=-1)
        expert_indices = torch.argmax(routing_weights, dim=-1)

        # Every expert runs on all tokens; the mask keeps only its assigned ones.
        output = torch.zeros_like(attn_output)
        for i in range(self.num_experts):
            mask = (expert_indices == i).unsqueeze(-1).float()
            output += self.experts[i](attn_output) * mask

        position_attn_output = attn_output + output          # residual around the FFN
        position_attn_output = self.norm3(position_attn_output)

        gap, gmp = apply_masked_pooling(position_attn_output, padding_mask)

        pooled_output = torch.cat([gap, gmp], dim=-1)        # (batch, 2 * feature_dim)

        x = self.Linear_ddg(pooled_output)

        return x.squeeze(-1)

In [8]:
def dataloader_generation_pred(dataset_test, batch_size = 128, dataloader_shuffle = True, inv= False):
    """Wrap raw test records in a DeltaDataset and return a padding DataLoader.

    Args:
        dataset_test: indexable collection of per-mutation records (see DeltaDataset).
        batch_size: number of items per batch.
        dataloader_shuffle: shuffle the items; pass False to keep the
            predictions in the same order as ``dataset_test``.
        inv: if True, yield reverse mutations (wild / mutant swapped).

    Returns:
        torch DataLoader using ``collate_fn`` to pad each batch.
    """
    wrapped_dataset = DeltaDataset(dataset_test, dim_embedding=1280, inv=inv)
    return DataLoader(
        wrapped_dataset,
        batch_size=batch_size,
        shuffle=dataloader_shuffle,
        collate_fn=collate_fn,
    )


In [9]:
def model_performance_test(model, dataloader_test):
    """Run ``model`` in eval mode on every batch and collect the outputs.

    Gradients are disabled. Returns a list with one output tensor per batch,
    in loader order, left on the device where the model produced it.
    """
    model.eval()
    with torch.no_grad():
        batch_predictions = [
            output_model_from_batch(batch, model, device, train=False)
            for batch in dataloader_test
        ]
    return batch_predictions

In [10]:
lr = 1e-4
input_dim = 1280

transf_parameters = {'input_dim': 1280, 'num_heads': 8,
                     'dropout_rate': 0.,}

i = 4

# The model fine-tuned on single mutations is
# JanusDDG_300epochs_plus25_hydra_slim (transitive).
# Other checkpoints tried during development:
#   JanusDDG_300epochs.pth, JanusDDG_300epochs_plus15_hydra_slim.pth,
#   JanusDDG_300_all_train.pth, DDGemb_Cross_4.pth
MODEL_PATH = '../DeltaDelta_BELLO/JanusDDG_300epochs_plus25_hydra_slim.pth'

# The checkpoint is a whole pickled nn.Module (torch.load returns the model itself):
# - weights_only=False is required (PyTorch >= 2.6 defaults to True and would
#   refuse to unpickle it); only load checkpoints you trust.
# - map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# - the model classes defined above must exist before this cell runs.
best_model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
best_model.eval()


Cross_Attention_DDG(
  (base_ddg): TransformerRegression(
    (act): ReLU()
    (conv1d): Conv1d(1280, 128, kernel_size=(20,), stride=(1,))
    (conv1d_wild): Conv1d(1280, 128, kernel_size=(20,), stride=(1,))
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (positional_encoding): SinusoidalPositionalEncoding()
    (multihead_attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (inverse_attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (router): Linear(in_features=256, out_features=1, bias=True)
    (experts): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=256, out_features=512,

In [11]:

# Test set to score: an indexable collection of per-mutation records read by
# DeltaDataset ('id', 'wild_type', 'mut_type', 'length' keys).
# Other sets evaluated during development:
#   ../DeltaDelta_BELLO/ptmul_test.pkl, zeroshot_q3421.pkl,
#   Ssym_correct_by_KORPM.pkl, test_TS16.pkl, ptmul_test.pkl,
#   s669_hydra_Castrense.pkl, s461_Castrense.pkl, s669_Castrense.pkl,
#   ../DeltaDelta_BELLO/cdna117k_fold_1.pkl + ../DeltaDelta_BELLO/cdna117k_fold_2.pkl,
#   dataset_doppie.pkl
# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
dataset_test = pd.read_pickle('Data/S669_GITHUB.pkl')

In [12]:
# Score the direct mutations in file order: no shuffling, and one sequence per
# batch, so each batch is padded only to its own length.
dataloader_test_dir = dataloader_generation_pred(
    dataset_test=dataset_test,
    batch_size=1,
    dataloader_shuffle=False,
    inv=False,
)
all_predictions_test_dir = model_performance_test(best_model, dataloader_test_dir)

In [13]:
# Concatenate the per-batch outputs (on CPU) into one array, one value per test record.
pd.Series(torch.cat(all_predictions_test_dir, dim=0).cpu()).values

array([-0.11431189, -0.13115619,  0.1607447 , -0.13510053, -0.34920874,
        0.13511008, -0.6159836 , -0.02062377,  0.11663852,  0.2492003 ,
       -0.48888952,  0.5399512 , -0.07175331, -0.42336938, -0.76117593,
       -0.39408144,  0.16212992,  0.35123318, -1.1982888 ,  0.14625174,
        0.11561026,  0.08850323,  0.17889091, -0.26694697, -0.4221142 ,
       -0.5527453 , -0.56517804, -1.6603603 , -0.58505076, -0.13119891,
       -0.1985518 , -0.24318878, -1.6383183 , -1.266849  , -0.3179284 ,
       -3.301221  , -3.0439787 , -1.7461777 , -0.60497785, -3.421257  ,
       -3.1162083 , -2.320097  ,  0.33576518, -0.1985998 ,  0.02073852,
        0.24327345,  0.3730581 , -0.2630698 , -2.8110585 , -2.17335   ,
       -2.3282282 , -2.2220244 , -2.9223785 , -2.972989  , -2.3053417 ,
       -1.7682157 , -1.3394463 , -0.33359408, -0.69262743, -2.3853426 ,
        0.06818868, -0.16025242, -2.6745303 , -2.795452  , -2.0306082 ,
       -2.8934853 , -2.1347373 , -1.8698704 , -0.27625477, -0.83

In [16]:
# Pearson correlation between the predictions and the 'ddg' labels.
# NOTE(review): the predictions were computed on 'Data/S669_GITHUB.pkl'
# (dataset_test), but the labels are read from a different file,
# '../DeltaDelta_BELLO/s669_Castrense.pkl'. This is only valid if both files
# list the same mutations in the same order -- TODO confirm, or take the labels
# from dataset_test itself if its records carry a 'ddg' field.
pearsonr(pd.Series(torch.cat(all_predictions_test_dir, dim=0).cpu()).values,[i['ddg'] for i in pd.read_pickle('../DeltaDelta_BELLO/s669_Castrense.pkl')])

PearsonRResult(statistic=0.5489732584234097, pvalue=6.379444575246027e-54)

In [None]:
# Predictions as a one-column DataFrame (column 0), one row per test record.
ddg = pd.DataFrame(pd.Series(torch.cat(all_predictions_test_dir, dim=0).cpu()).values)
ddg

In [4]:
# Precomputed results file (presumably double mutants of the Tsuboyama set,
# judging from the file name -- TODO confirm). Rows whose 'DDG' is the
# placeholder '-' are dropped. Because of that placeholder, 'DDG' is read as
# strings (object dtype) and is converted to float where it is used below.
# NOTE(review): this cell's execution count (In [4]) is lower than the cells
# above, so the notebook was not run top to bottom.
preds = pd.read_csv('Results/Result_Tsuboyama_doppie_first50k.csv')
preds = preds[preds['DDG'] != '-']

In [14]:
# Inspect the raw 'DDG' column: still strings (dtype=object) at this point.
preds['DDG'].values

array(['0.01366685967324699', '-0.25930853889603966',
       '0.14123238152892093', ..., '-2.733004344088733',
       '-1.9413717963098467', '-2.6021286581702276'], dtype=object)

In [15]:
# Pearson correlation between the 'DDG_JanusDDG' column and the 'DDG' column
# (parsed from strings to floats).
pearsonr(preds['DDG_JanusDDG'], preds['DDG'].astype(float))

PearsonRResult(statistic=0.1968380761732687, pvalue=0.0)