In [None]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import pickle
import seaborn as sns
from torch_geometric.utils import to_networkx
import os
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib.pyplot as plt
from torch_geometric.data import Dataset
import torch_geometric.utils as pyg_utils
import torch.nn.functional as F
from torch.nn import Linear
import torch.nn as nn
from torch_geometric.utils import softmax
import math
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr
import random
from sklearn.metrics import root_mean_squared_error,mean_absolute_error


In [None]:

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy, and PyTorch (CPU, current CUDA device and
    all CUDA devices), then configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    seeders = (
        random.seed,                 # Python standard library
        np.random.seed,              # NumPy global generator
        torch.manual_seed,           # PyTorch CPU generator
        torch.cuda.manual_seed,      # PyTorch, current GPU
        torch.cuda.manual_seed_all,  # PyTorch, every GPU (multi-GPU setups)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # ask cuDNN for deterministic algorithms
    cudnn.benchmark = False     # no dynamic algorithm auto-tuning (hurts reproducibility)

# Set the global seed for reproducibility
set_seed(42)


In [None]:
import esm

# Load the pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) together with its alphabet.
model_esm, alphabet_esm = esm.pretrained.esm2_t33_650M_UR50D()

# Use the GPU when one is available; `device` is also read by the functions defined below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_esm = model_esm.to(device)
# The batch converter turns (label, sequence) pairs into token tensors for the model.
batch_converter_esm = alphabet_esm.get_batch_converter()
# Switch to evaluation mode: the model is only used to extract embeddings.
model_esm.eval()

def Esm2_embedding(seq, model_esm = model_esm, batch_converter_esm = batch_converter_esm):
    """Return the per-residue ESM-2 embedding of a single protein sequence.

    Args:
        seq: Amino-acid sequence as a string.
        model_esm: ESM model used for the forward pass (defaults to the
            notebook-level model).
        batch_converter_esm: Converter from (label, sequence) pairs to tokens
            (defaults to the notebook-level converter).

    Returns:
        NumPy array holding the layer-33 representations of the single input
        sequence with the first and last token positions removed
        (presumably the BOS/EOS tokens added by the converter — confirm).
    """
    _, _, tokens = batch_converter_esm([("protein", seq)])
    tokens = tokens.to(device)

    # Inference only: no gradients are needed to extract embeddings.
    with torch.no_grad():
        per_token = model_esm(tokens, repr_layers=[33])["representations"][33]

    # Keep the only sequence in the batch and drop the first/last positions.
    return per_token[0, 1:-1].cpu().numpy()

In [None]:
from random import sample

class DeltaDataset(Dataset):
    """Dataset of wild-type / mutant embedding pairs with their ddG label.

    Each element of `data` is a mapping with the keys 'id', 'wild_type',
    'mut_type', 'length', 'ddg' and 'pos_mut'. With `inv=True` every sample is
    returned as the reverse mutation: wild-type and mutant embeddings are
    swapped and the sign of ddG is flipped.
    """

    def __init__(self, data, dim_embedding, inv = False):
        self.data = data
        self.dim_embedding = dim_embedding  # stored but not read by this class
        self.inv = inv

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample `idx` as a dict of tensors (plus the raw 'id')."""
        record = self.data[idx]

        # Pick which raw field plays the "wild" and "mutant" role, and the ddG sign.
        if self.inv:
            wild_key, mut_key = 'mut_type', 'wild_type'
            ddg_value = -float(record['ddg'])
        else:
            wild_key, mut_key = 'wild_type', 'mut_type'
            ddg_value = float(record['ddg'])

        return {
            'id': record['id'],
            'wild_type': torch.tensor(record[wild_key], dtype=torch.float32),
            'mut_type': torch.tensor(record[mut_key], dtype=torch.float32),
            'length': torch.tensor(record['length'], dtype=torch.float32),
            'ddg': torch.tensor(ddg_value, dtype=torch.float32),
            'pos_mut': torch.tensor(record['pos_mut'], dtype=torch.int64),
            # Placeholder intermediate state: always built from the raw
            # 'mut_type' entry multiplied by zero, regardless of `inv`.
            'hydra_slim': torch.tensor(record['mut_type']*0, dtype=torch.float32),
        }


In [None]:
from torch_geometric.loader import DataLoader
import random

import torch
import torch.nn.functional as F

def collate_fn(batch):
    """Pad variable-length samples and assemble them into one batch dict.

    The 2-D tensors 'wild_type', 'mut_type' and 'hydra_slim' are zero-padded at
    the end of both axes up to the largest 'wild_type' shape in the batch, then
    stacked along a new leading batch axis.

    Args:
        batch: Non-empty list of sample dicts as produced by `DeltaDataset`.

    Returns:
        Dict with the keys
        'id' (list), 'wild_type', 'mut_type', 'hydra_slim' (stacked padded
        tensors), 'length', 'ddg' (stacked tensors) and 'pos_mut' (list with
        one tensor per sample).

    Note:
        The original implementation declared 'pos_mut' but never filled it, so
        the mutation positions were silently dropped. They are now forwarded as
        a list (not stacked) so that samples with a different number of mutated
        positions can share a batch. Samples without a 'pos_mut' key are
        tolerated and simply contribute nothing to that list.
    """
    max_len = max(sample['wild_type'].shape[0] for sample in batch)       # longest sequence in the batch
    max_features = max(sample['wild_type'].shape[1] for sample in batch)  # widest feature dimension

    def _pad(matrix):
        # F.pad takes pads for the last dimension first: (feat_left, feat_right, seq_top, seq_bottom).
        return F.pad(matrix, (0, max_features - matrix.shape[1],
                              0, max_len - matrix.shape[0]))

    return {
        'id': [sample['id'] for sample in batch],
        'wild_type': torch.stack([_pad(sample['wild_type']) for sample in batch]),  # (batch, max_len, max_features)
        'mut_type': torch.stack([_pad(sample['mut_type']) for sample in batch]),
        'length': torch.stack([sample['length'] for sample in batch]),
        'ddg': torch.stack([sample['ddg'] for sample in batch]),
        'pos_mut': [sample['pos_mut'] for sample in batch if 'pos_mut' in sample],
        'hydra_slim': torch.stack([_pad(sample['hydra_slim']) for sample in batch]),
    }


def dataloader_generation(path, collate_fn, batch_size = 128, dataloader_shuffle = True, inv= False):
    """Build a DataLoader over the samples stored in one or more pickle files.

    Args:
        path: Iterable of pickle file paths; a single string / path-like is
            also accepted. Each file must unpickle to a list of sample dicts.
        collate_fn: Function used to merge a list of samples into a batch.
        batch_size: Number of samples per batch.
        dataloader_shuffle: Whether to reshuffle the samples every epoch.
        inv: Forwarded to `DeltaDataset`; when True samples are returned as
            reverse mutations.

    Returns:
        A `torch.utils.data.DataLoader` over the concatenated samples.
    """
    # Import the plain PyTorch loader locally. The notebook also imports
    # `torch_geometric.loader.DataLoader` under the same name, and that class
    # replaces any user-supplied `collate_fn` with its own collater, so relying
    # on whichever `DataLoader` was imported last would make the result depend
    # on cell execution order.
    from torch.utils.data import DataLoader as TorchDataLoader

    dim_embedding = 1280

    # A bare string would otherwise be iterated character by character.
    if isinstance(path, (str, os.PathLike)):
        path = [path]

    dataset = []
    for pickle_path in path:
        with open(pickle_path, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code — only load trusted files.
            dataset += pickle.load(f)

    delta_dataset = DeltaDataset(dataset, dim_embedding, inv = inv)
    return TorchDataLoader(delta_dataset, batch_size=batch_size, shuffle=dataloader_shuffle, collate_fn=collate_fn)


In [None]:
from torch.utils.data import DataLoader  # Use standard PyTorch DataLoader
import random
from itertools import chain
from collections import Counter

# Training files: the five s2450 folds plus the matching "_inv" files
# (presumably the same mutations in the reverse direction — confirm against the data preparation).
train_path =[f'train_data/s2450_fold_{i}_hydra_slim.pkl' for i in [0,1,2,3,4]]+[f'train_data/s2450_fold_{i}_hydra_slim_inv.pkl' for i in [0,1,2,3,4]]#+[f's2450_fold_{i}_hydra_slim.pkl' for i in [0,1,2,3,4]]+[f's2450_fold_{i}_hydra_slim_inv.pkl' for i in [0,1,2,3,4]]

# NOTE(review): validation and test point to the SAME file. The training loop selects the
# best epoch and applies early stopping on the validation Pearson r, so the reported test
# metrics are not independent of model selection.
val_path = ['train_data/M28_test.pkl']
test_path = ['train_data/M28_test.pkl']

# All loaders use inv=False, i.e. no on-the-fly inversion of the samples.
# Training is shuffled with small batches; evaluation runs one sample at a time, in file order.
dataloader_train = dataloader_generation(path = train_path, batch_size = 6,collate_fn=collate_fn, dataloader_shuffle = True, inv= False)
dataloader_validation = dataloader_generation(path = val_path, batch_size = 1, collate_fn=collate_fn, dataloader_shuffle = False, inv= False)
dataloader_test = dataloader_generation(path = test_path, batch_size = 1, collate_fn=collate_fn, dataloader_shuffle = False, inv= False)


In [None]:
import copy


def output_model_from_batch(batch, model, device, hydra = False,train=False):
    """Run `model` on one collated batch and return (predictions, labels).

    Args:
        batch: Dict with tensor entries 'wild_type', 'mut_type', 'hydra_slim',
            'ddg' and 'length'.
        model: Callable invoked as
            `model(x_wild, x_mut, hydra_slim, length, hydra=..., train=...)`.
        device: Device the tensors are moved to before the call.
        hydra: Forwarded to the model as keyword `hydra`.
        train: Forwarded to the model as keyword `train`.

    Returns:
        Tuple `(output_ddg, labels)`: the model output and the float32 'ddg'
        tensor on `device`.
    """
    def _as_float_on_device(key):
        return batch[key].float().to(device)

    wild_emb = _as_float_on_device('wild_type')
    mut_emb = _as_float_on_device('mut_type')
    slim_emb = _as_float_on_device('hydra_slim')
    target = _as_float_on_device('ddg')
    # 'length' is only moved to the device; its dtype is left untouched.
    seq_lengths = batch['length'].to(device)

    prediction = model(wild_emb, mut_emb, slim_emb, seq_lengths, hydra=hydra, train = train)
    return prediction, target



def _collect_ddg_predictions(model, dataloader):
    """Run `model` over `dataloader` without gradients and gather its outputs.

    The caller is responsible for putting the model in evaluation mode.

    Returns:
        Tuple `(direct_preds, hydra_preds, labels)` of flat Python lists: the
        direct predictions (hydra=False), the predictions through the
        intermediate state (hydra=True) and the ground-truth ddG values.
    """
    direct_preds, hydra_preds, labels = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            output_direct, batch_labels = output_model_from_batch(batch, model, device, hydra=False, train=False)
            output_hydra, _ = output_model_from_batch(batch, model, device, hydra=True, train=False)

            direct_preds.extend(output_direct.cpu().reshape(-1).tolist())
            hydra_preds.extend(output_hydra.cpu().reshape(-1).tolist())
            labels.extend(batch_labels.cpu().tolist())
    return direct_preds, hydra_preds, labels


def training_and_validation_loop_ddg(model, dataloader_train, dataloader_test, dataloader_validation, path_save_fig, epochs=20, lr =0.001, patience=10, restore_best=False):
    """Train `model` on ddG regression and evaluate it after every epoch.

    The training objective is MSE(direct prediction, label) plus a consistency
    term MSE(hydra prediction, direct prediction). After each epoch the model
    is evaluated on the test and validation loaders; epochs that reach the best
    validation Pearson r so far are printed in red. Training stops early once
    more than `patience` epochs have passed since the best validation epoch.

    Args:
        model: Module called through `output_model_from_batch`; updated in place.
        dataloader_train: Loader providing the training batches.
        dataloader_test: Loader used for the per-epoch test metrics.
        dataloader_validation: Loader used for model selection / early stopping.
        path_save_fig: Accepted for interface compatibility; not used here.
        epochs: Maximum number of epochs.
        lr: Learning rate of the Adam optimizer.
        patience: Number of epochs without a new best validation Pearson r
            tolerated before stopping.
        restore_best: When True, the weights of the best validation epoch are
            loaded back into `model` before returning. The default (False)
            leaves the weights of the last executed epoch in `model`, which is
            what the original implementation did; that implementation
            deep-copied the whole model on every improvement and then discarded
            the copy.

    Returns:
        Twelve lists with one entry per executed epoch, in this order:
        pearson_r_train, pearson_r_validation, pearson_r_test,
        loss_ddg_train, loss_ddg_validation, loss_ddg_test,
        loss_ddg_train_TRANS, loss_ddg_validation_TRANS, loss_ddg_test_TRANS,
        loss_ddg_train_TOT, loss_ddg_validation_TOT, loss_ddg_test_TOT.
        The `_TRANS` losses are the MSE between hydra and direct predictions;
        the `_TOT` losses are the sum of the plain and `_TRANS` losses.
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    pearson_r_train, pearson_r_test, pearson_r_validation = [], [], []
    loss_ddg_train, loss_ddg_train_TRANS, loss_ddg_train_TOT = [], [], []
    loss_ddg_validation, loss_ddg_validation_TRANS, loss_ddg_validation_TOT = [], [], []
    loss_ddg_test, loss_ddg_test_TRANS, loss_ddg_test_TOT = [], [], []

    best_state = None  # only populated when restore_best is True
    num_epochs = epochs
    for epoch in range(num_epochs):

        # ---- Training ----
        model.train()
        preds_ddg_train = []
        preds_ddg_train_TRANS = []
        labels_tot_epoch = []

        for batch in dataloader_train:
            optimizer.zero_grad()
            output_ddg_train, labels_train = output_model_from_batch(batch, model, device, hydra=False, train=True)
            output_ddg_HYDRA_SLIM_train, _ = output_model_from_batch(batch, model, device, hydra=True, train=True)

            # Supervised loss on the direct prediction ...
            loss_ddg = criterion(output_ddg_train, labels_train)
            # ... plus consistency between the hydra path and the direct path.
            tot_loss = loss_ddg + criterion(output_ddg_HYDRA_SLIM_train, output_ddg_train)

            tot_loss.backward()
            optimizer.step()

            preds_ddg_train.extend(output_ddg_train.detach().cpu().reshape(-1).tolist())
            preds_ddg_train_TRANS.extend(output_ddg_HYDRA_SLIM_train.detach().cpu().reshape(-1).tolist())
            labels_tot_epoch.extend(labels_train.cpu().tolist())

        train_loss = mean_squared_error(preds_ddg_train, labels_tot_epoch)
        # The "labels" of the consistency term are the direct predictions themselves.
        train_loss_TRANS = mean_squared_error(preds_ddg_train_TRANS, preds_ddg_train)
        train_correlation = pearsonr(preds_ddg_train, labels_tot_epoch)[0]
        train_spearman = spearmanr(preds_ddg_train, labels_tot_epoch)[0]

        loss_ddg_train.append(train_loss)
        loss_ddg_train_TRANS.append(train_loss_TRANS)
        loss_ddg_train_TOT.append(train_loss_TRANS+train_loss)
        pearson_r_train.append(train_correlation)

        # ---- Evaluation ----
        model.eval()
        all_preds_test, all_preds_test_TRANS, all_labels_test = _collect_ddg_predictions(model, dataloader_test)
        all_preds_validation, all_preds_validation_TRANS, all_labels_validation = _collect_ddg_predictions(model, dataloader_validation)

        test_loss = mean_squared_error(all_preds_test, all_labels_test)
        test_loss_TRANS = mean_squared_error(all_preds_test_TRANS, all_preds_test)
        test_correlation, _ = pearsonr(all_preds_test, all_labels_test)
        test_spearman = spearmanr(all_preds_test, all_labels_test)[0]
        loss_ddg_test.append(test_loss)
        loss_ddg_test_TRANS.append(test_loss_TRANS)
        loss_ddg_test_TOT.append(test_loss+test_loss_TRANS)
        pearson_r_test.append(test_correlation)

        val_loss = mean_squared_error(all_preds_validation, all_labels_validation)
        val_loss_TRANS = mean_squared_error(all_preds_validation_TRANS, all_preds_validation)
        val_correlation, _ = pearsonr(all_preds_validation, all_labels_validation)
        val_spearman = spearmanr(all_preds_validation, all_labels_validation)[0]
        loss_ddg_validation.append(val_loss)
        loss_ddg_validation_TRANS.append(val_loss_TRANS)
        loss_ddg_validation_TOT.append(val_loss+val_loss_TRANS)
        pearson_r_validation.append(val_correlation)

        # ---- Reporting / model selection ----
        # The current value is already in the list, so this is True exactly
        # when this epoch ties or beats every previous validation Pearson r.
        if val_correlation >= max(pearson_r_validation):
            if restore_best:
                # Snapshot only the weights; cheaper than copying the whole module.
                best_state = copy.deepcopy(model.state_dict())
            print(f'\033[91mEpoch {epoch+1}/{num_epochs}')
            print(f'Train -  trans_loss={train_loss_TRANS:.4f},    Loss: {train_loss:.4f}, Pearson r: {train_correlation:.4f}, Rho spearman: {train_spearman:.4f}')
            print(f'Validation - Loss: {val_loss:.4f}, Pearson r: {val_correlation:.4f}, Rho spearman: {val_spearman:.4f}',)
            print(f'Test - trans_loss={test_loss_TRANS:.4f},      Loss: {test_loss:.4f}, Pearson r: {test_correlation:.4f}, Rho spearman: {test_spearman:.4f}\033[0m\n')
        else:
            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train -    trans_loss={train_loss_TRANS:.4f},    Loss: {train_loss:.4f}, Pearson r: {train_correlation:.4f}, Rho spearman: {train_spearman:.4f}')
            print(f'Validation - Loss: {val_loss:.4f}, Pearson r: {val_correlation:.4f}, Rho spearman: {val_spearman:.4f}',)
            print(f'Test -  trans_loss={test_loss_TRANS:.4f}      Loss: {test_loss:.4f}, Pearson r: {test_correlation:.4f}, Rho spearman: {test_spearman:.4f}\n')

        # Early stopping: too many epochs since the best validation epoch.
        if epoch > (np.argmax(pearson_r_validation) + patience):
            print(f'\033[91mEarly stopping at epoch {epoch+1}\033[0m')
            break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    return pearson_r_train, pearson_r_validation, pearson_r_test, loss_ddg_train, loss_ddg_validation, loss_ddg_test, loss_ddg_train_TRANS, loss_ddg_validation_TRANS, loss_ddg_test_TRANS, loss_ddg_train_TOT, loss_ddg_validation_TOT, loss_ddg_test_TOT

In [None]:

class Cross_Attention_DDG(nn.Module):
    """Wrapper that turns a base ddG regressor into a (optionally two-step) predictor.

    The base module is called as `base_ddg(delta, start_embedding, length)`
    where `delta = start_embedding - end_embedding`.

    * `hydra=False`: one step, wild type -> mutant.
    * `hydra=True`: two steps, wild type -> `hydra_slim` and
      `hydra_slim` -> mutant, whose predictions are summed.
    * `train=True`: each step is a single forward call.
    * `train=False`: each step is antisymmetrised as
      `(f(start -> end) - f(end -> start)) / 2`.
    """

    def __init__(self, base_module, cross_att=False, dual_cross_att= False, hydra=True ,**transf_parameters):
        super().__init__()
        # `device` is the notebook-level device; the base module is moved there on creation.
        self.base_ddg = base_module(**transf_parameters, cross_att=cross_att, dual_cross_att= dual_cross_att).to(device)
        self.hydra=hydra  # stored but not read in forward (the `hydra` argument is used instead)

    def _step_ddg(self, start, end, length, antisymmetric):
        """Predict ddG for one start -> end step, optionally antisymmetrised."""
        forward_ddg = self.base_ddg(start - end, start, length)
        if not antisymmetric:
            return forward_ddg
        backward_ddg = self.base_ddg(end - start, end, length)
        return (forward_ddg - backward_ddg) / 2

    def forward(self, x_wild, x_mut, hydra_slim, length, hydra=False, train = True):
        """Return the predicted ddG for the wild type -> mutant change."""
        antisymmetric = not train

        if hydra:
            # Wild type -> intermediate state, then intermediate state -> mutant.
            wild_half_DDG = self._step_ddg(x_wild, hydra_slim, length, antisymmetric)
            half_mut_DDG = self._step_ddg(hydra_slim, x_mut, length, antisymmetric)
            return wild_half_DDG + half_mut_DDG

        # Direct wild type -> mutant prediction.
        return self._step_ddg(x_wild, x_mut, length, antisymmetric)


In [None]:

def apply_masked_pooling(position_attn_output, padding_mask):
    """Masked global average and max pooling over the sequence axis (dim 1).

    Args:
        position_attn_output: Tensor indexed as (batch, position, feature).
        padding_mask: Tensor indexed as (batch, position); truthy entries mark
            padded positions that must be ignored.

    Returns:
        Tuple `(gap, gmp)`, each of shape (batch, feature): the mean and the
        maximum over the non-padded positions.
    """
    pad = padding_mask.float().unsqueeze(-1)  # 1.0 where padded, broadcast over features
    keep = 1 - pad                            # 1.0 where valid

    # Average pooling: sum the valid positions and divide by how many there are.
    gap = torch.sum(position_attn_output * keep, dim=1) / torch.sum(keep, dim=1)

    # Max pooling: padded positions are replaced by -1e10 (a large negative
    # sentinel, not -inf) so they never win the max.
    sentinel_filled = position_attn_output * keep + (pad * (- 1e10))
    gmp = torch.max(sentinel_filled, dim=1).values

    return gap, gmp


class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sine/cosine positional encoding added to a (batch, seq, dim) input.

    Even feature indices hold `sin(pos * w_i)` and odd indices `cos(pos * w_i)`
    with `w_i = exp(-2i * ln(10000) / embedding_dim)`. The table is stored as
    the non-trainable buffer `pe` of shape (1, max_len, embedding_dim).

    Args:
        embedding_dim: Size of the feature dimension.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, embedding_dim, max_len=3700):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        even_dims = torch.arange(0, embedding_dim, 2).float()
        inv_freq = torch.exp(even_dims * (-torch.log(torch.tensor(10000.0)) / embedding_dim))
        angles = positions * inv_freq  # (max_len, number of even feature indices)

        table = torch.zeros(max_len, embedding_dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Registered as a buffer: follows .to(device) and is saved with the
        # module state, but is not a trainable parameter.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first `x.size(1)` positions to `x`."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class TransformerRegression(nn.Module):
    """Conv1d + dual cross-attention regressor that predicts one ddG per sample.

    Pipeline of `forward`:
      1. The embedding difference and the reference embedding are each passed
         through their own Conv1d (no padding) and receive a sinusoidal
         positional encoding.
      2. Two cross-attention blocks run in parallel:
         query = reference, key/value = difference, and
         query = difference, key/value = reference. Each is followed by a
         residual connection and LayerNorm.
      3. Both outputs are concatenated, passed through a position-wise
         feed-forward block with a residual connection and LayerNorm.
      4. Masked average and max pooling over positions are concatenated and
         mapped to a scalar by a linear layer.

    Args:
        input_dim: Feature size of the input embeddings.
        num_heads: Number of heads of both attention layers.
        dropout_rate: Dropout used inside the attention layers.
        num_experts: Output size of the `router` layer (the router is created
            but not used in `forward`).
        f_activation: Activation module of the position-wise feed-forward block.
        kernel_size: Kernel size of both Conv1d layers. Previously this
            argument was silently overwritten by a hard-coded 20; it is now
            honoured (the default is still 20).
        cross_att: Accepted for interface compatibility; not used.
        dual_cross_att: Accepted for interface compatibility; not used.

    NOTE(review): the padding mask is built from the original sequence
    `length`, while the sequence axis after the un-padded convolution is
    `kernel_size - 1` positions shorter. For sequences shorter than the batch
    maximum, the last convolution windows that overlap zero padding are
    therefore still treated as valid positions — confirm this is intended.
    """

    def __init__(self, input_dim=1280, num_heads=8, dropout_rate=0., num_experts=1, f_activation = nn.ReLU(), kernel_size=20, cross_att = True,
                dual_cross_att=True):

        super(TransformerRegression, self).__init__()

        self.embedding_dim = input_dim
        self.act = f_activation
        self.max_len = 3700  # maximum protein length covered by the positional encoding
        out_channels = 128   # number of Conv1d filters
        padding = 0

        # Attribute names and creation order are kept unchanged so that existing
        # checkpoints and seeded initialisation stay compatible.
        self.conv1d = nn.Conv1d(in_channels=self.embedding_dim,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                padding=padding)

        self.conv1d_wild = nn.Conv1d(in_channels=self.embedding_dim,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     padding=padding)

        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)

        # Cross-attention layers
        self.positional_encoding = SinusoidalPositionalEncoding(out_channels, self.max_len)
        self.speach_att_type = True  # print the attention description only on the first forward call
        self.multihead_attention = nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True )
        self.inverse_attention = nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, dropout=dropout_rate, batch_first =True)

        # Both attention outputs are concatenated along the feature axis.
        dim_position_wise_FFN = out_channels*2

        self.norm3 = nn.LayerNorm(dim_position_wise_FFN)
        self.router = nn.Linear(dim_position_wise_FFN, num_experts)

        self.pw_ffnn = nn.Sequential(
            nn.Linear(dim_position_wise_FFN, 512),
            self.act,
            nn.Linear(512, dim_position_wise_FFN)
            )

        # Average- and max-pooled features are concatenated before the final projection.
        self.Linear_ddg = nn.Linear(dim_position_wise_FFN*2, 1)

    def create_padding_mask(self, length, seq_len, batch_size):
        """
        Create a padding mask for multihead attention.
        length: Tensor of shape (batch_size,) containing the actual lengths of the sequences.
        seq_len: The maximum sequence length.
        batch_size: The number of sequences in the batch (not used by the computation).

        Returns a boolean padding mask of shape (batch_size, seq_len) that is
        True at positions whose index is >= the sequence length.
        """
        mask = torch.arange(seq_len, device=length.device).unsqueeze(0) >= length.unsqueeze(1)
        return mask

    def forward(self, delta_w_m, x_wild, length):
        """Predict ddG from an embedding difference and its reference embedding.

        Args:
            delta_w_m: Embedding difference, shape (batch, seq_len, input_dim).
            x_wild: Reference embedding, same shape as `delta_w_m`.
            length: Tensor of shape (batch,) with the sequence lengths.

        Returns:
            Tensor of shape (batch,) with one prediction per sample.
        """
        # Conv1d expects channels first: (batch, seq_len, feat) -> (batch, feat, seq_len).
        delta_w_m = delta_w_m.transpose(1, 2)
        C_delta_w_m = self.conv1d(delta_w_m)
        # Back to (batch, conv_seq_len, out_channels) for the attention layers.
        C_delta_w_m = C_delta_w_m.transpose(1, 2)
        C_delta_w_m = self.positional_encoding(C_delta_w_m)

        # Same treatment for the reference embedding, with its own convolution.
        x_wild = x_wild.transpose(1, 2)
        C_x_wild = self.conv1d_wild(x_wild)
        C_x_wild = C_x_wild.transpose(1, 2)
        C_x_wild = self.positional_encoding(C_x_wild)

        batch_size, seq_len, feature_dim = C_x_wild.size()

        padding_mask = self.create_padding_mask(length, seq_len, batch_size)

        if self.speach_att_type:
            print('ATTENTION TYPE: Dual cross Attention\n q = wild , k = delta, v = delta and q = delta , k = wild, v = wild \n ----------------------------------')
            self.speach_att_type = False

        # Query = reference, key/value = difference; the residual adds the difference branch.
        direct_attn_output, _ = self.multihead_attention(C_x_wild, C_delta_w_m, C_delta_w_m, key_padding_mask=padding_mask)
        direct_attn_output += C_delta_w_m
        direct_attn_output = self.norm1(direct_attn_output)

        # Query = difference, key/value = reference; the residual adds the reference branch.
        inverse_attn_output, _ = self.inverse_attention(C_delta_w_m, C_x_wild, C_x_wild, key_padding_mask=padding_mask)
        inverse_attn_output += C_x_wild
        inverse_attn_output = self.norm2(inverse_attn_output)

        attn_output = torch.cat([direct_attn_output, inverse_attn_output], dim=-1)

        # Position-wise feed-forward block with residual connection.
        output = self.pw_ffnn(attn_output)
        position_attn_output = attn_output + output
        position_attn_output = self.norm3(position_attn_output)

        gap, gmp = apply_masked_pooling(position_attn_output, padding_mask)

        # Concatenate GAP and GMP: (batch, 2 * feature_dim)
        pooled_output = torch.cat([gap, gmp], dim=-1)

        # Final linear projection to a single ddG value per sample.
        x = self.Linear_ddg(pooled_output)

        return x.squeeze(-1)

In [None]:
# Compute device used by the model and the batches: GPU when available, otherwise CPU.
# (Same expression as in the ESM set-up cell; repeated here, presumably so this cell can run on its own.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:

# Fine-tuning configuration.
lr = 1e-4
input_dim = 1280
transf_parameters={'input_dim':1280, 'num_heads':8,
                    'dropout_rate':0.,}

patience = 300
DDG_model = TransformerRegression

# The checkpoint stores the whole pickled module (not just a state_dict), so:
#  * `weights_only=False` is required on recent PyTorch versions, whose default
#    refuses to unpickle arbitrary objects. Only load checkpoints you trust:
#    unpickling can execute arbitrary code.
#  * the model classes defined above must already exist in this session.
# Load directly onto `device`: batches are moved to `device` inside
# `output_model_from_batch`, so a CPU-mapped model would fail with a device
# mismatch whenever CUDA is available.
Final_model = torch.load('JanusDDG.pth', map_location=device, weights_only=False)
path_save_fig = 'JanusDDG \n ----------------------------------'
print(path_save_fig)
pearson_r_train, pearson_r_validation, pearson_r_test, loss_ddg_train, loss_ddg_validation, loss_ddg_test, loss_ddg_train_TRANS, loss_ddg_validation_TRANS, loss_ddg_test_TRANS, loss_ddg_train_TOT, loss_ddg_validation_TOT, loss_ddg_test_TOT = training_and_validation_loop_ddg(Final_model, dataloader_train, dataloader_test,
                                                                                   dataloader_validation,
                                                                                   path_save_fig, epochs=28, lr =lr,patience = patience)


In [None]:
#torch.save(Final_model, 'JanusDDG_fine_tuned.pth')