In [None]:
import os
import sys
import re
import json
import ast
import time
import logging
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr

# PyTorch and PyTorch Geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as graphDataLoader
from torch_geometric.nn import GATConv, GCNConv

# Transformers and ESM
from transformers import BertModel, BertTokenizer
import esm

# Custom imports from aggrepred
top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)

from aggrepred.graph_utils import *
from aggrepred.graph_model import  EGNN_Model

# Default values used in training
NEIGHBOUR_RADIUS = 10

# Metrics and evaluation
from sklearn.metrics import (
    mean_squared_error, 
    mean_absolute_error, 
    r2_score,
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score, 
    roc_auc_score, 
    average_precision_score, 
    matthews_corrcoef
)


# Utils

### Convert every PDB file into a graph file (`.pt`)
### Make sure all PDB files are stored in the "../data/pdb/" folder

In [None]:
# Folder layout, relative to the notebook: the split table is read from
# data/csv/, structures from data/pdb/, and graphs are written to data/graph/.
data_dir  = "../data/"
pdb_dir = "../data/pdb/"
graph_dir = "../data/graph/"

df = pd.read_csv(os.path.join(data_dir, "csv", "data60_fixed_split.csv"))

# Make sure the output folder exists before any graph is written into it.
os.makedirs(graph_dir, exist_ok=True)

# Only the "ID" column is needed, so iterate over it directly instead of
# building a full row object per entry with iterrows().
for code in tqdm(df["ID"], total=df.shape[0], desc="Processing rows"):
    graph_file_path = os.path.join(graph_dir, f"{code}.pt")
    pdb_path = os.path.join(pdb_dir, f"{code}.pdb")

    # Skip processing if the graph file already exists
    if os.path.exists(graph_file_path):
        continue

    process_pdb2graph(pdb_path, graph_file_path)


from concurrent.futures import ThreadPoolExecutor, as_completed



# Dataset dataloader

In [5]:

class GraphDataset(Dataset):
    """Dataset of pre-computed graphs stored as ``<ID>.pt`` files.

    Only the IDs whose graph file exists in ``graph_dir`` when the dataset is
    built are kept, so ``len(dataset)`` can be smaller than ``len(df)``.
    """

    def __init__(self, df, graph_dir):
        self.graph_dir = graph_dir
        self.data_frame = df.copy()

        # Keep the IDs in table order, dropping those without a graph on disk.
        available_codes = []
        for code in self.data_frame['ID'].tolist():
            if os.path.exists(self._graph_path(code)):
                available_codes.append(code)
        self.codes = available_codes

    def _graph_path(self, code):
        """Return the path of the serialized graph belonging to ``code``."""
        return f"{self.graph_dir}/{code}.pt"

    def __len__(self):
        return len(self.codes)

    def __getitem__(self, idx):
        # Graphs are read from disk on every access; nothing is cached.
        return torch.load(self._graph_path(self.codes[idx]))



In [None]:
graph_dir = "../data/graph/"

df = pd.read_csv("../data/csv/data60_fixed_split.csv")

# Fraction of the train/valid splits used for quick experiments.
# Set to 1.0 to use the complete splits.
SUBSAMPLE_FRAC = 0.10

# The train and valid splits are subsampled with a fixed seed so the subset is
# the same on every run; the test split is always used in full.
train_dataset = GraphDataset(df[df["split"] == 'train'].sample(frac=SUBSAMPLE_FRAC, random_state=42), graph_dir)
valid_dataset = GraphDataset(df[df["split"] == 'valid'].sample(frac=SUBSAMPLE_FRAC, random_state=42), graph_dir)
test_dataset = GraphDataset(df[df["split"] == 'test'], graph_dir)

# Only the training loader is shuffled. Evaluation loaders keep dataset order
# so the per-graph predictions returned by evaluate() line up with
# `valid_dataset.codes` / `test_dataset.codes`.
train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=False)
test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=False)



In [None]:
# Number of mini-batches the training loader yields per epoch.
len(train_dataloader)

# Model

In [None]:
def count_parameters(model):
    """Return ``(trainable, non_trainable)`` element counts over ``model.parameters()``."""
    trainable_params = 0
    non_trainable_params = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_params += param.numel()
        else:
            non_trainable_params += param.numel()
    return trainable_params, non_trainable_params

# Throw-away model used to inspect the architecture and its size.
num_feats = 20
graph_hidden_layer_output_dims = [20,20,20]
linear_hidden_layer_output_dims = [10,10]

model = EGNN_Model(
    num_feats=num_feats,
    graph_hidden_layer_output_dims=graph_hidden_layer_output_dims,
    linear_hidden_layer_output_dims=linear_hidden_layer_output_dims,
)

# Report how many parameter elements require gradients and how many do not.
trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

In [None]:
# Show the module tree of the model built in the previous cell.
print(model)

In [None]:
# Pull a single mini-batch out of the training loader to inspect its fields.
for batch in train_dataloader:
    print(batch)
    break

# Read every field before `batch` is rebound to the node-to-graph index.
y = batch.y
edge_index = batch.edge_index
# A leading axis of size 1 is added to the features and coordinates.
x = batch.x.unsqueeze(0)
coors = batch.pos.unsqueeze(0)
batch = batch.batch

# convert edge_index to adjacent matrix, as in EGNN take adj_mat
adjacency = edge_index_to_adjacency_matrix(edge_index)
edges = adjacency.unsqueeze(2).unsqueeze(0)


In [None]:
# Shapes of the tensors prepared from the sample batch.
print('size of x:', x.shape)
print('size of coors:', coors.shape)
print('size of edges:', edges.shape)

# Trainer
## config

In [None]:
# ----------------
# PARAM
# ----------------

# Folder that receives the checkpoints, metrics and config of this run.
# path = "./weights/graph/(esm8M)_(1EGNN)/"
path = "./weights/graph/(onehot)_(3EGNN)/"

# All settings of the run in one place. "encode_mode" selects the node
# features: 'esm' or 'protbert' embeddings; any other value (here 'onehot')
# feeds the stored features unchanged.
config = dict(
    model='EGNN',
    num_feats=20,
    graph_hidden_layer_output_dims=[20,20,20],
    linear_hidden_layer_output_dims=[20,10],
    learning_rate=1e-5,
    batch_size=1,
    nb_epochs=20,
    encode_mode='onehot',
)


# ----------------
#  MODEL
# ----------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# The three architecture settings are forwarded to the model as keyword arguments.
model = EGNN_Model(**{
    key: config[key]
    for key in ("num_feats",
                "graph_hidden_layer_output_dims",
                "linear_hidden_layer_output_dims")
})



# ----------------
def count_parameters(model):
    """Count parameter elements of ``model``, split by ``requires_grad``.

    Returns a ``(trainable, non_trainable)`` tuple of integers.
    """
    totals = {True: 0, False: 0}
    for p in model.parameters():
        totals[bool(p.requires_grad)] += p.numel()
    return totals[True], totals[False]

# Summarise the run: output folder, module tree and parameter counts.
print(path)
print(model)
trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

In [12]:
# ----------------
#   OPTIMIZER
# ----------------
# Adam over all model parameters, using the learning rate from `config`.
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

# Exponential decay: each scheduler.step() multiplies the learning rate by 0.9.
# NOTE(review): no cell visible in this notebook calls scheduler.step(), so the
# learning rate appears to stay constant during training — confirm intent.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


class CombinedLoss(nn.Module):
    """Weighted sum of an MSE term and a BCE-with-logits term.

    Both terms are computed from the same model outputs. The binary target
    for the BCE term is derived from the regression target: values above 0
    are positives.

    Args:
        lambda_reg: weight of the MSE term.
        lambda_bin: weight of the BCE term.
        pos_weight: optional weight of the positive class in the BCE term.
    """

    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super().__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        # Regression term.
        self.mse_loss = nn.MSELoss()
        # Classification term; passing pos_weight=None is the unweighted default.
        positive_weight = None if pos_weight is None else torch.tensor(pos_weight)
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=positive_weight)

    def forward(self, outputs, regression_targets):
        # The model outputs are used directly as logits for the binary term.
        is_positive = (regression_targets > 0).float()
        regression_term = self.lambda_reg * self.mse_loss(outputs, regression_targets)
        classification_term = self.lambda_bin * self.bce_loss(outputs, is_positive)
        return regression_term + classification_term


# Loss used by the training and evaluation functions of this notebook.
combined_loss = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=4.0)



## dataloader

In [None]:
graph_dir = "../data/graph/"
df = pd.read_csv("../data/csv/data60_fixed_split.csv")

# Fraction of the train/valid splits used for quick experiments.
# Set to 1.0 to use the complete splits.
SUBSAMPLE_FRAC = 0.10

# The train and valid splits are subsampled with a fixed seed so the subset is
# the same on every run; the test split is always used in full.
train_dataset = GraphDataset(df[df["split"] == 'train'].sample(frac=SUBSAMPLE_FRAC, random_state=42), graph_dir)
valid_dataset = GraphDataset(df[df["split"] == 'valid'].sample(frac=SUBSAMPLE_FRAC, random_state=42), graph_dir)
test_dataset = GraphDataset(df[df["split"] == 'test'], graph_dir)

# Only the training loader is shuffled. Evaluation loaders keep dataset order
# so the per-graph predictions returned by evaluate() line up with
# `valid_dataset.codes` / `test_dataset.codes`.
# batch_size stays at 1: the training and evaluation functions in this
# notebook add a leading axis of size 1 and treat each batch as one graph.
train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=False)
test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=False)


In [23]:



def format_time(seconds):
    """Render a duration given in seconds as ``"<minutes>m <seconds>s"``.

    Fractions of a second are dropped; minutes are not folded into hours.
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}m {int(leftover)}s"

def train_epoch(model, optimizer, dataloader, device, printEvery=50):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Reads the notebook-level ``config`` (keys ``"encode_mode"`` and
    ``"model"``) and the notebook-level ``combined_loss``. Every batch is
    handled as a single graph: a leading axis of size 1 is added to the
    features and coordinates.

    Args:
        model: network being trained. Called as ``model(x, coors, edges)``
            when ``config['model'] == 'EGNN'`` and as ``model(x, edge_index)``
            otherwise.
        optimizer: optimizer that updates ``model``'s parameters.
        dataloader: loader yielding graph batches with ``x``, ``pos``,
            ``edge_index`` and ``y`` attributes.
        device: device on which the forward and backward passes run.
        printEvery: a progress line is printed every ``printEvery`` iterations.

    Returns:
        float: sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    total_loss = 0.0
    count_iter = 0
    epoch_start_time = time.time()
    num_batches = len(dataloader)

    # Load a pretrained sequence encoder only when the configured encoding
    # needs it; one-hot runs need neither of the two models.
    encode_mode = config["encode_mode"]
    if encode_mode == 'esm':
        esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm_model = esm_model.eval()  # Set the model to evaluation mode
    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        # Follow `device` instead of a hard-coded 'cuda' so CPU runs work too.
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

    with tqdm(total=num_batches, desc='Training', unit='batch') as pbar:
        for graph in dataloader:
            x, coors, edge_index, y = graph.x, graph.pos, graph.edge_index, graph.y

            if encode_mode == 'esm':
                x = embed_esm_batch([onehot_to_sequence(x)], esm_model, alphabet).to(device)
            elif encode_mode == 'protbert':
                x = embed_protbert_batch([onehot_to_sequence(x)], protbert_model, protbert_tokenizer).to(device)
            else:
                x = x.unsqueeze(0).to(device)

            edge_index = edge_index.to(device)
            if config['model'] == 'EGNN':
                # The EGNN takes an adjacency matrix rather than an edge list.
                coors = coors.unsqueeze(0).to(device)
                edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0).to(device)
                out = model(x, coors, edges).squeeze()
            else:
                x = x.squeeze(0).to(device)
                out = model(x, edge_index).squeeze()

            current_loss = combined_loss(out, y.to(device))

            optimizer.zero_grad()
            current_loss.backward()
            optimizer.step()

            total_loss += current_loss.item()

            count_iter += 1
            if count_iter % printEvery == 0:
                # Both values are measured from the start of the epoch, so the
                # per-iteration average behind the estimate is consistent.
                elapsed_time = time.time() - epoch_start_time
                remaining_time = (elapsed_time / count_iter) * (num_batches - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")

            #remove cache to save GPU
            torch.cuda.empty_cache()
            pbar.update(1)

    epoch_time = time.time() - epoch_start_time
    # The value printed below is the combined loss, despite the "mse" label.
    print(f"==> Average Training loss: mse ={total_loss / num_batches}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")

    return total_loss / num_batches


def evaluate(model, dataloader, device, mode='valid'):
    """Score ``model`` on ``dataloader`` and compute regression and classification metrics.

    Reads the notebook-level ``config`` (keys ``"encode_mode"`` and
    ``"model"``) and the notebook-level ``combined_loss``. Every batch is
    handled as a single graph. For the classification metrics, targets and
    predictions above 0 count as the positive class; the raw predictions are
    used as scores for AUC-ROC and AUC-PR.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: loader yielding graph batches with ``x``, ``pos``,
            ``edge_index`` and ``y`` attributes.
        device: device on which the forward pass runs.
        mode: accepted for call compatibility; not used by the function.

    Returns:
        tuple: ``(mean_loss, metrics, predictions, targets)`` where
        ``mean_loss`` is the summed loss divided by ``len(dataloader)``,
        ``metrics`` is a dict with "Regression Metrics" and
        "Classification Metrics" (values rounded to 4 decimals), and
        ``predictions`` / ``targets`` are lists with one NumPy array per batch.
    """
    model.eval()
    total_loss = 0.0

    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []

    # Load a pretrained sequence encoder only when the configured encoding
    # needs it; one-hot runs need neither of the two models.
    encode_mode = config["encode_mode"]
    if encode_mode == 'esm':
        esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm_model = esm_model.eval()  # Set the model to evaluation mode
    elif encode_mode == 'protbert':
        protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        # Follow `device` instead of a hard-coded 'cuda' so CPU runs work too.
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

    with torch.no_grad():
        for graph in dataloader:
            x, coors, edge_index, y = graph.x, graph.pos, graph.edge_index, graph.y

            if encode_mode == 'esm':
                x = embed_esm_batch([onehot_to_sequence(x)], esm_model, alphabet).to(device)
            elif encode_mode == 'protbert':
                x = embed_protbert_batch([onehot_to_sequence(x)], protbert_model, protbert_tokenizer).to(device)
            else:
                x = x.unsqueeze(0).to(device)

            edge_index = edge_index.to(device)
            if config['model'] == 'EGNN':
                # The EGNN takes an adjacency matrix rather than an edge list.
                coors = coors.unsqueeze(0).to(device)
                edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0).to(device)
                out = model(x, coors, edges).squeeze()
            else:
                x = x.squeeze(0).to(device)
                out = model(x, edge_index).squeeze()

            current_loss = combined_loss(out, y.to(device))
            total_loss += current_loss.item()

            # squeeze() turns a single-value output into a 0-d tensor, which
            # np.concatenate rejects; keep every prediction at least 1-D.
            out_np = np.atleast_1d(out.cpu().numpy())
            y_np = y.cpu().numpy()

            predictions.append(out_np)
            targets.append(y_np)

            # Binary labels: strictly positive values are the positive class.
            binary_predictions.append((out_np > 0).astype(int))
            binary_targets.append((y_np > 0).astype(int))

    all_predictions = np.concatenate(predictions, axis=0)
    all_targets = np.concatenate(targets, axis=0)
    all_binary_predictions = np.concatenate(binary_predictions, axis=0)
    all_binary_targets = np.concatenate(binary_targets, axis=0)

    # Calculate overall metrics
    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets.flatten(), all_predictions.flatten())

    # Calculate binary classification metrics
    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)

    print(f"Overall Reg Metrics - MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}")

    print(f"Overall Classification Metrics - Accuracy: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")
    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }

    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader, nb_epochs, device, save_directory='./weights/', patience=3):
    """Train ``model`` for up to ``nb_epochs`` epochs with checkpointing and early stopping.

    Files written inside ``save_directory``:
      * ``model_last.pt``  - checkpoint of the most recent epoch
      * ``model_best.pt``  - checkpoint with the lowest validation loss
      * ``metrics.json``   - validation metrics of the best checkpoint
      * ``losses.json``    - per-epoch training and validation losses

    If ``model_last.pt`` already exists, training resumes after its epoch.
    The checkpoint key ``'validation_accuracy'`` holds the validation *loss*;
    the name is kept so existing checkpoint files stay readable.

    Args:
        model: network to train.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: loader passed to ``train_epoch``.
        valid_dataloader: loader passed to ``evaluate``.
        nb_epochs: number of the last epoch to run (epochs count from 1).
        device: device used for training and for loading checkpoints.
        save_directory: folder for checkpoints and JSON logs.
        patience: stop after this many consecutive epochs without a new
            best validation loss.

    Returns:
        None.
    """
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving checkpoints, losses and metrics
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')
    best_model_save_path = os.path.join(save_directory, 'model_best.pt')
    last_model_save_path = os.path.join(save_directory, 'model_last.pt')

    # Initialize lists for losses
    train_losses = []
    val_losses = []

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
        print(f'Created directory: {save_directory}')

    # Move the model before restoring the optimizer: load_state_dict places
    # the optimizer state on the device its parameters are on at that moment.
    model.to(device)

    if os.path.exists(last_model_save_path):
        # map_location lets a checkpoint written on GPU be resumed on CPU.
        checkpoint = torch.load(last_model_save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        # model_last.pt stores the loss of the last epoch, which is not
        # necessarily the best one. Use the value saved with model_best.pt so
        # a worse epoch cannot overwrite the best checkpoint after a resume.
        if os.path.exists(best_model_save_path):
            best_checkpoint = torch.load(best_model_save_path, map_location=device)
            best_validation_loss = best_checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {last_model_save_path}. Resuming from epoch {start_epoch}')
    else:
        print('No checkpoint found. Starting from beginning.')

    # Load existing losses if available
    if os.path.exists(loss_output_path):
        with open(loss_output_path, 'r') as json_file:
            existing_losses = json.load(json_file)
            train_losses = existing_losses.get('train_losses', [])
            val_losses = existing_losses.get('val_losses', [])
            print(train_losses)
            print(val_losses)

    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")

        train_loss = train_epoch(model, optimizer, train_dataloader,device, printEvery=1000)
        train_losses.append(train_loss)

        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)
        else:
            early_stopping_counter += 1

        # Persist the epoch before deciding whether to stop, so the final
        # epoch is also recorded in model_last.pt and losses.json.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')
        print("==================================================================================\n")

        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')

        if early_stopping_counter >= patience:
            print(f"\n==> Early stopping triggered. No improvement in validation loss for {patience} epochs.")
            break

    return

   

## Train here

In [None]:
# Store the run configuration next to the weights it produces.
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)

model.to(device)
# Take the epoch budget from the config so the saved config.json describes the
# run that was actually executed (the call previously hard-coded 50 while the
# config recorded 20).
train_loop(model, optimizer, train_dataloader, valid_dataloader, config["nb_epochs"], device, path)

# Load the best model and evaluate it on the test set

In [24]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Restore model weights from a checkpoint file, if the file exists.

    The model is moved to ``device`` in both cases. The optimizer state is
    not restored: the ``optimizer`` argument is accepted but not used.

    Args:
    - model (torch.nn.Module): The model to load the state into.
    - optimizer (torch.optim.Optimizer): Not used by this function.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device used for loading and for the model.

    Returns:
    - start_epoch (int): Checkpoint epoch + 1, or 0 when no file was found.
    - best_validation_loss (float): Value stored under 'validation_accuracy'
      in the checkpoint, or ``inf`` when no file was found.
    """
    # Guard clause: nothing to restore.
    if not os.path.exists(checkpoint_path):
        print('No checkpoint found.')
        model = model.to(device)
        return 0, float('inf')  # Assuming lower is better for validation loss

    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    next_epoch = state['epoch'] + 1
    stored_loss = state['validation_accuracy']
    print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {next_epoch}')

    model = model.to(device)
    return next_epoch, stored_loss




In [None]:

# List of model paths
model_paths = [
    "./weights/graph/(onehot)_(3EGNN)/",
    "./weights/graph/(esm8M)_(1EGNN)/"
]

for path in model_paths:
    config_path = os.path.join(path, 'config.json')
    checkpoint_path = os.path.join(path, 'model_best.pt')
    result_path = os.path.join(path, 'result.json')

    # Without trained weights the model would be scored with its random
    # initialisation and the numbers written to result.json would be
    # meaningless, so such runs are skipped explicitly.
    if not os.path.exists(checkpoint_path):
        print(f"Skipping {path}: no checkpoint found at {checkpoint_path}")
        continue

    # Load the config for the current model. evaluate() reads the
    # notebook-level `config`, so it has to be rebound here.
    with open(config_path, 'r') as json_file:
        config = json.load(json_file)

    # Initialize the model
    model = EGNN_Model(num_feats = config["num_feats"],
                        graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                        linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])

    # Load the model weights from the checkpoint. The loader does not restore
    # optimizer state, so no optimizer is needed for evaluation.
    _, _ = load_model_from_checkpoint(model, None, checkpoint_path, device)

    # Evaluate the model
    loss, metrics, preds, tar = evaluate(model, test_dataloader ,device)

    # Save metrics of the best model
    with open(result_path, 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {path}")