In [58]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
import re
import sys
from scipy.spatial.distance import cdist
from torch_geometric.data import Data
import ast
from tqdm import tqdm
import json 
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import esm
top_folder_path = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, top_folder_path)


from aggrepred.utils import *

# default values using in training
NEIGHBOUR_RADIUS = 10


# Utils

In [59]:
def onehot_encode(sequence: str, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode an amino acid sequence

    :param sequence:   protein sequence
    :param max_length: specify the maximum length for protein sequence to use, it helps to have have size in batches

    :return: max_length x num_features tensor
    """
    
    seqlen = len(sequence)
    use_length = max_length
    # use_length = min(seqlen,max_length) #variable to len of seq, if want fix_size like 1024, set it to fix
    encoded = torch.zeros((use_length, NUM_AMINOS))
    for i in range(min(seqlen, max_length)):
        aa = sequence[i]
        encoded[i][aa2idx.get(aa, NUM_AMINOS-1)] = 1
    return encoded

def onehot_encode_batch(sequences: list, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode an amino acid sequence

    :param sequence:   protein sequence
    :param max_length: specify the maximum length for protein sequence to use, it helps to have have size in batches

    :return: max_length x num_features tensor
    """
    batch_size = len(sequences)
    batch_encoded = torch.zeros((batch_size, max_length, NUM_AMINOS))
    for i, seq in enumerate(sequences):
        batch_encoded[i] = onehot_encode(seq, max_length)
    return batch_encoded

def onehot_meiler_encode(sequence: str, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode an amino acid sequence, then concatenate with Meiler features.

    :param sequence:   protein sequence
    :param max_length: specify the maximum length for protein sequence to use, it helps to have have size in batches

    :return: max_length x num_features tensor
    """
    
    seqlen = len(sequence)
    use_length = max_length
    # use_length = min(seqlen,max_length) #variable to len of seq, if want fix_size like 1024, set it to fix
    encoded = torch.zeros((use_length, NUM_AMINOS+ NUM_MEILER))
    for i in range(min(seqlen, max_length)):
        aa = sequence[i]
        encoded[i][aa2idx.get(aa, NUM_AMINOS-1)] = 1
        encoded[i][-NUM_MEILER:] = MEILER[aa] if aa in MEILER else MEILER["X"]
    return encoded


def onehot_meiler_encode_batch(sequences: list, max_length: int = 1000) -> torch.Tensor:
    """
    One-hot encode an amino acid sequence, then concatenate with Meiler features.

    :param sequence:   protein sequence
    :param max_length: specify the maximum length for protein sequence to use, it helps to have have size in batches

    :return: max_length x num_features tensor
    """
    batch_size = len(sequences)
    batch_encoded = torch.zeros((batch_size, max_length, NUM_AMINOS+ NUM_MEILER))
    for i, seq in enumerate(sequences):
        batch_encoded[i] = onehot_meiler_encode(seq, max_length)
    return batch_encoded

def embed_esm_batch(batch_sequences, model, alphabet, repr_layer='last'):
    batch_converter = alphabet.get_batch_converter()
    data = [("protein" + str(i), seq) for i, seq in enumerate(batch_sequences)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)

    # Get the embeddings from the last layer
    last_layer = model.num_layers
    token_embeddings = results["representations"][last_layer]
    
    return token_embeddings[:,1:-1,:]

def embed_protbert_batch(sequences, model, tokenizer, device='cuda' ):
    model.eval()

    sequences_w_spaces = [' '.join(list(seq)) for seq in sequences]
    processed_sequences = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_w_spaces]

    ids = tokenizer.batch_encode_plus(processed_sequences, add_special_tokens=True, pad_to_max_length=True)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    with torch.no_grad():
        embedding = model(input_ids=input_ids,attention_mask=attention_mask)[0]

    return embedding[:,1:-1,:]


In [60]:

# ----------------
# Helper functions
# ----------------

# Dictionary to convert 3-letter codes to 1-letter codes
AA_3to1 = {
    'ALA': 'A',
    'ARG': 'R',
    'ASN': 'N',
    'ASP': 'D',
    'CYS': 'C',
    'GLN': 'Q',
    'GLU': 'E',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'LYS': 'K',
    'MET': 'M',
    'PHE': 'F',
    'PRO': 'P',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V'
}

# Dictionary to convert 1-letter codes to 3-letter codes
AA_1to3 = {v: k for k, v in AA_3to1.items()}

def get_Calpha_df(df, chain_id=None):
    '''
    Filters a DataFrame containing PDB data to return only the C-alpha atoms for specified chain IDs.
    
    Parameters:
    df (pd.DataFrame): DataFrame containing PDB data.
    chain_id (list): List of chain IDs to filter by. Default is [None], meaning take all.
    
    Returns:
    pd.DataFrame: DataFrame containing only the C-alpha atoms for all(the whole pdb) or the specified chain IDs.
    '''
    if chain_id is None:
        return df[(df["Atom_Name"].str.strip() == "CA")].reset_index(drop=True)
    else:
        out = df[(df["Atom_Name"].str.strip() == "CA") & (df["Chain"].isin(chain_id))].reset_index(drop=True)
        if len(out) == 0:
            raise ValueError("No matching chain in the PDB file.")
        return out


def get_AA_onehot_features(df,  chain_id=None):
    '''
    Encodes CDR residues types as one-hot vectors for model input
    
    :param H_id: heavy chain ID ('None' if not available)
    :param L_id: light chain ID ('None' if not available)
    :param df: imgt numbered dataframe for specific pdb entry
    :returns: tensor (num_CDR_residues, 20) one-hot encoding for each 20 AA types
    '''
    
    # get CDR C-alpha atoms only
    df_CDRs = get_Calpha_df(df, chain_id)
    df_Calpha = get_Calpha_df(df, chain_id)
    df_CDRs = get_Calpha_df(df, chain_id)
    
    AA_unique_names = get_ordered_AA_3_letter_codes()
    AA_name_dict = {name: idx for idx, name in enumerate(AA_unique_names)}
    
    # nice names to make rest of code more understandable
    num_rows = df_Calpha.shape[0]
    num_AA = len(AA_unique_names)
    
    # convert AA name to one-hot encoding
    AA_onehot_matrix = np.zeros((num_rows, num_AA))
    
    # we will only non-zero elements where residues actually exist
    df_Calpha_not_null = df_Calpha[~df_Calpha["AA"].isna()]
    df_Calpha_not_null_indices = df_Calpha_not_null.index.values
    
    AA_onehot_matrix[df_Calpha_not_null_indices,
                     [AA_name_dict[residue] for residue in df_Calpha_not_null["AA"]]] = 1
    
    # convert from numpy to tensor
    AA_onehot_tensor = torch.tensor(AA_onehot_matrix)
    
    return AA_onehot_tensor


def get_seq_from_df(df, chainID=None):
    '''
    Get the full ordered amino acid seq for a protein chain
    
    :param df: imgt numbered dataframe for specific pdb entry
    :param chainID: chain ID of protein in pdb file
    :return: ordered list of str of all res nums in certain chain
    '''
    if chainID is None:
        df_Calpha_chain_of_interest = get_Calpha_df(df)
    else:
        df_Calpha_chain_of_interest = df[(df["Chain"]==chainID) & (df["Atom_Name"]=="CA")]
    
    amino_acids_3letter_list = df_Calpha_chain_of_interest["AA"].values.tolist()

        # Convert the 3-letter codes to 1-letter codes
    amino_acids_1letter_list = [AA_3to1.get(aa, 'X') for aa in amino_acids_3letter_list]
    
    # Join the list into a single string
    sequence = ''.join(amino_acids_1letter_list)

    return  sequence

def get_bfactor_from_df(df, chainID=None):
    '''
    Get the full ordered amino acid seq for a protein chain
    
    :param df: imgt numbered dataframe for specific pdb entry
    :param chainID: chain ID of protein in pdb file
    :return: ordered list of str of all res nums in certain chain
    '''
    df_Calpha_chain_of_interest = get_Calpha_df(df,chainID)

    # if chainID is None:
    #     df_Calpha_chain_of_interest = get_Calpha_df(df)
    # else:
    #     df_Calpha_chain_of_interest = df[(df["Chain"]==chainID) & (df["Atom_Name"]=="CA")]

    return  df_Calpha_chain_of_interest["bfactor"].values.astype(float).tolist()

def get_coors(df, chain_ids=None):
    '''
    Get CDR C-alpha atom coordinates
    
    :param H_id: heavy chain ID ('None' if not available)
    :param L_id: light chain ID ('None' if not available)
    :param df: imgt numbered dataframe for specific pdb entry
    :returns: tensor (num_CDR_residues, 3) with x, y, z coors of each atom
    '''
    
    # get CDR C-alpha atoms only
    df_CA = get_Calpha_df(df, chain_ids)
    
    # ensure coors are numbers
    df_CA["x"] = df_CA["x"].astype(float)
    df_CA["y"] = df_CA["y"].astype(float)
    df_CA["z"] = df_CA["z"].astype(float)

    # get coors as tensor
    coors = torch.tensor(df_CA[["x", "y", "z"]].values)

    return coors

def get_edge_features(df, chain_ids=None, neighbour_radius=NEIGHBOUR_RADIUS):
    '''
    Get tensor form of adjacency matrix for all CDR C-alpha atoms
    
    :param H_id: heavy chain ID ('None' if not available)
    :param L_id: light chain ID ('None' if not available)
    :param df: imgt numbered dataframe for specific pdb entry
    :param neighbour_radius: max distance in Angstroms neighbours can be
    :returns: tensor (num_CDR_residues, num_CDR_residues, 1) adj matrix 
    '''
    
    xyz_arr = get_coors(df, chain_ids).numpy()
    
    # get distances
    dist_matrix = cdist(xyz_arr, xyz_arr, 'euclidean')
    dist_tensor = torch.tensor(dist_matrix)
    
    # create adjacency matrix from distance info
    adj_matrix = torch.where(dist_tensor <= neighbour_radius, 1, 0)
    
    # remove self loops - do I want to do this???  
    adj_matrix = adj_matrix.fill_diagonal_(0, wrap=False)
    
    # adjust dimensions for model input
    adj_matrix.unsqueeze_(-1)
    
    return adj_matrix


def get_all_node_features(df, chain_ids=None):
    '''
    Get tensor features embedding Amino Acid type and corresponding chain
    for each C-alpha atom in the CDR
    
    :param H_id: heavy chain ID ('None' if not available)
    :param L_id: light chain ID ('None' if not available)
    :param df: imgt numbered dataframe for specific pdb entry
    :returns: tensor (num_CDR_residues, 76||26||22) with multi-hot encoding of selection from
              AA type (20), chain H/L (2), loop L1/.../H3 (6), and imgt num (54)
    '''

    return get_AA_onehot_features(df, chain_ids)
                        
###########################################

def adjacency_matrix_to_edge_index(adj_matrix):
    """
    Convert an adjacency matrix to an edge index representation using tensor operations.

    Args:
        adj_matrix (torch.Tensor): The adjacency matrix.

    Returns:
        torch.Tensor: The edge index representation.
    """
    # Find the indices of the non-zero elements in the adjacency matrix
    edge_index = torch.nonzero(adj_matrix, as_tuple=False).t().contiguous()
    return edge_index


def edge_index_to_adjacency_matrix(edge_index):
    """
    Convert an edge index representation to an adjacency matrix using tensor operations.

    Args:
        edge_index (torch.Tensor): The edge index representation.
        num_nodes (int): The number of nodes in the graph.

    Returns:
        torch.Tensor: The adjacency matrix.
    """
    num_nodes = torch.max(edge_index) + 1
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
    adj_matrix[edge_index[0], edge_index[1]] = 1
    return adj_matrix




In [61]:
def process_pdb2graph(pdb_path,graph_save_path, chain=None,len_cutoff= 500, score_in_bfactor=True):
    
    # Check if pdb_path exists
    if not os.path.exists(pdb_path):
        raise FileNotFoundError(f"The PDB file path {pdb_path} does not exist.")
    
    # Check if graph_save_path exists, if not, create the directory if possible
    save_dir = os.path.dirname(graph_save_path)
    if not os.path.exists(save_dir):
        try:
            os.makedirs(save_dir)
        except Exception as e:
            raise OSError(f"Failed to create directory {save_dir}: {e}")

    pdb_df = format_pdb(pdb_path)

    coors = get_coors(pdb_df,chain).float()
    coors[coors != coors] = 0  # Replace NaNs with zeros
    feats = get_all_node_features(pdb_df,chain).float()
    edges = get_edge_features(pdb_df,chain).float()
    edge_index = adjacency_matrix_to_edge_index(edges.squeeze(-1))
    
    if score_in_bfactor:
        scores = get_bfactor_from_df(pdb_df,chain)
        y = torch.tensor(scores)
    else:
        y = None
    # print(len(y))
    # print(edge_index[0][-20:])
    # print(edge_index[1][-20:])

    # if len(y) > len_cutoff:

    #     # print("trucate ", len(y) , " to ", len_cutoff)
    #     ## trucate graph to just a specific size (not too big that cause GPU problem such as 2000nodes)
    #     feats = feats[:len_cutoff]

    #     # 2. Filter edge indices to include only edges between the first 500 nodes
    #     mask = (edge_index[0] < len_cutoff) & (edge_index[1] < len_cutoff)
    #     edge_index = edge_index[:, mask]

    #     coors = coors[:len_cutoff]
    #     y = y[:len_cutoff]

    # print(len(y))
    # print(edge_index[0][-20:])
    # print(edge_index[1][-20:])

    graph = Data(x=feats, pos=coors, edge_index=edge_index, y=y)
    # Save the graph
    torch.save(graph, graph_save_path)

    return graph


In [62]:
# # data_folder_path  = "/Users/lyanchhay/Documents/stage/Stage_AIDRUG_2024/main/data/"
# # graph_save_path = "/Users/lyanchhay/Documents/stage/Stage_AIDRUG_2024/main/data/graph/"

# data_dir  = "/users/eleves-b/2023/ly-an.chhay/main/data/"
# graph_save_dir = "/users/eleves-b/2023/ly-an.chhay/main/data/graph/"


# ## old code with all.csv 
# df = pd.read_csv(data_folder_path+"csv/all.csv")
# import ast

# for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Processing rows"):
#     code = row["code"]
    
#     graph_file_path = os.path.join(graph_save_folder, f"{code}.pt")
#     pdb_path = os.path.join(data_folder_path, "pdb", f"{code}.pdb")
    
#     # Skip processing if the graph file already exists
#     if os.path.exists(graph_file_path):
#         continue
    
#     _ = process_pdb2graph(pdb_path,graph_file_path,True)
    


In [63]:
# data_folder_path  = "/Users/lyanchhay/Documents/stage/Stage_AIDRUG_2024/main/data/"
# graph_save_path = "/Users/lyanchhay/Documents/stage/Stage_AIDRUG_2024/main/data/graph/"

data_dir  = "/users/eleves-b/2023/ly-an.chhay/main/data/"
pdb_dir = "/users/eleves-b/2023/ly-an.chhay/main/data/pdb/"
graph_dir = "/users/eleves-b/2023/ly-an.chhay/main/data/graph/"

df = pd.read_csv(data_dir+"pisces/data60_fixed_split.csv")

for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Processing rows"):
    code = row["ID"]
    # chain = row.chain
    # code = row["code"]
    
    graph_file_path = os.path.join(graph_dir, f"{code}.pt")
    pdb_path = pdb_dir+ f"{code}.pdb"
    
    # Skip processing if the graph file already exists
    if os.path.exists(graph_file_path):
        continue
    
    _ = process_pdb2graph(pdb_path,graph_file_path)


from concurrent.futures import ThreadPoolExecutor, as_completed



Processing rows: 100%|██████████| 23523/23523 [00:06<00:00, 3557.24it/s]


# Dataset dataloader

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import  DataLoader as graphDataLoader
import pandas as pd
from torch_geometric.data import Data

class GraphDataset(Dataset):
    def __init__(self, df, graph_dir):
        self.data_frame = df.copy()
        self.codes = self.data_frame['ID'].tolist()
        self.graph_dir = graph_dir

        # Check if all graph files exist and filter the codes list accordingly
        self.codes = [code for code in self.codes if os.path.exists(f"{self.graph_dir}/{code}.pt")]

    def __len__(self):
        return len(self.codes)

    def __getitem__(self, idx):
        code_id = self.codes[idx]
        graph_path = f"{self.graph_dir}/{code_id}.pt"
        graph_data = torch.load(graph_path)
        return graph_data



In [8]:
# csv_file = '/users/eleves-b/2023/ly-an.chhay/main/data/csv/all.csv'
graph_dir = "/users/eleves-b/2023/ly-an.chhay/main/data/graph/"

df = pd.read_csv("/users/eleves-b/2023/ly-an.chhay/main/data/pisces/data60_fixed_split.csv")

train_dataset = GraphDataset(df[df.split=='train'], graph_dir)
valid_dataset = GraphDataset(df[df.split=='valid'], graph_dir)
test_dataset = GraphDataset(df[df.split=='test'], graph_dir)

train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=True)
test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=True)

# # Create a DataLoader
# loader = graphDataLoader(dataset, batch_size=2, shuffle=False)

In [9]:
len(train_dataloader)

18818

# Model

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from aggrepred.EGNN import EGNN


class EGNN_Model(nn.Module):
    '''
    Paragraph uses equivariant graph layers with skip connections
    '''
    def __init__(
        self,
        num_feats,
        edge_dim = 1,
        output_dim = 1,
        graph_hidden_layer_output_dims = None,
        linear_hidden_layer_output_dims = None,
        update_coors = False,
        dropout = 0.0,
        m_dim = 16
    ):
        super(EGNN_Model, self).__init__()

        self.input_dim = num_feats
        self.output_dim = output_dim
        current_dim = num_feats
        
        # these will store the different layers of out model
        self.graph_layers = nn.ModuleList()
        self.linear_layers = nn.ModuleList()
        
        # model with 1 standard EGNN and single dense layer if no architecture provided
        if graph_hidden_layer_output_dims == None: graph_hidden_layer_output_dims = [num_feats]
        if linear_hidden_layer_output_dims == None: linear_hidden_layer_output_dims = []
        
        # graph layers
        for hdim in graph_hidden_layer_output_dims:
            self.graph_layers.append(EGNN(dim = current_dim,
                                          edge_dim = edge_dim,
                                          update_coors = update_coors,
                                          dropout = dropout,
                                          m_dim = m_dim))
            current_dim = hdim
            
        # dense layers
        for hdim in linear_hidden_layer_output_dims:
            self.linear_layers.append(nn.Linear(in_features = current_dim,
                                                out_features = hdim))
            current_dim = hdim
        
        # final layer to get to per-node output
        self.linear_layers.append(nn.Linear(in_features = current_dim, out_features = output_dim))
        self.leakyrelu = nn.LeakyReLU(0.1)
        
    def forward(self, feats, coors, edges, mask=None):

        # graph layers
        for layer in self.graph_layers:
            feats = F.hardtanh(layer(feats, coors, edges, mask))
            
        # dense layers
        for layer in self.linear_layers[:-1]:
            feats = self.leakyrelu(layer(feats))
            # feats = F.hardtanh(layer(feats))
            
        # output (i.e. prediction)
        feats = self.linear_layers[-1](feats)
        
        return feats
    
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable_params, non_trainable_params


    

In [11]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch.nn import LayerNorm

class GATModel(torch.nn.Module):
    def __init__(self, num_feats, graph_hidden_layer_output_dims, linear_hidden_layer_output_dims, heads=4, dropout=0.2):
        """
        Initialize the model with the given parameters, including Transformer-like skip connections and layer normalization.
        
        Args:
        - num_feats (int): Number of input features (dimensionality of the input nodes).
        - graph_hidden_layer_output_dims (list of int): List of hidden dimensions for GAT layers.
        - linear_hidden_layer_output_dims (list of int): List of hidden dimensions for fully connected layers.
        - heads (int): Number of attention heads for the GAT layers. Default is 8.
        - dropout (float): Dropout rate applied between layers. Default is 0.6.
        """
        super(GATModel, self).__init__()
        
        self.gat_layers = torch.nn.ModuleList()
        self.layer_norms = torch.nn.ModuleList()
        
        # First GAT layer (input layer)
        self.gat_layers.append(GATConv(num_feats, graph_hidden_layer_output_dims[0], heads=heads))
        self.layer_norms.append(LayerNorm(graph_hidden_layer_output_dims[0] * heads))
        
        # Hidden GAT layers with Transformer-like skip connections and layer normalization
        for i in range(1, len(graph_hidden_layer_output_dims)):
            self.gat_layers.append(GATConv(graph_hidden_layer_output_dims[i-1] * heads, graph_hidden_layer_output_dims[i], heads=heads))
            self.layer_norms.append(LayerNorm(graph_hidden_layer_output_dims[i] * heads))
        
        # Linear layers for matching dimensions in residual connections
        self.match_dim1 = nn.Linear(num_feats, graph_hidden_layer_output_dims[0] * heads)
        # self.match_dim2 = nn.Linear(graph_hidden_channels * heads, graph_hidden_channels * heads)

        # Fully Connected layers
        self.fc_layers = torch.nn.ModuleList()
        self.fc_layer_norms = torch.nn.ModuleList()
        
        for i in range(len(linear_hidden_layer_output_dims)):
            if i == 0:
                self.fc_layers.append(torch.nn.Linear(graph_hidden_layer_output_dims[-1] * heads, linear_hidden_layer_output_dims[i]))
            else:
                self.fc_layers.append(torch.nn.Linear(linear_hidden_layer_output_dims[i-1], linear_hidden_layer_output_dims[i]))
            self.fc_layer_norms.append(LayerNorm(linear_hidden_layer_output_dims[i]))
        
        # Final output layer
        self.fc_layers.append(torch.nn.Linear(linear_hidden_layer_output_dims[-1], 1))  # Output is a single value (aggregation score)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x, edge_index):
        # x, edge_index = data.x, data.edge_index
        x_residual = self.match_dim1(x)

        # Pass through GAT layers with Transformer-like skip connections and layer normalization
        for i, gat_layer in enumerate(self.gat_layers):
            
            x = gat_layer(x, edge_index)
            x = x_residual + x  # Residual connection
            x = self.layer_norms[i](x)  # Layer normalization
            x = F.hardtanh(x)

            x_residual = x  # Save input for skip connection
        
        # Pass through Fully Connected layers 
        for i, fc_layer in enumerate(self.fc_layers[:-1]):
            x = fc_layer(x)
            x = self.leakyrelu(x)
            x = F.dropout(x, p=0.5)
     
        # Final output layer (no activation)
        x = self.fc_layers[-1](x)
        
        return x


In [12]:
# Input feature size (number of input features per node, e.g., 20 for one-hot encoding)
num_feats = 20

# GCN hidden layer dimensions
graph_hidden_layer_output_dims = [20, 20, 20]

# Fully connected (linear) hidden layer dimensions
linear_hidden_layer_output_dims = [10, 10]

# Create the model
model = GATModel(
    num_feats=num_feats,
    graph_hidden_layer_output_dims=graph_hidden_layer_output_dims,
    linear_hidden_layer_output_dims=linear_hidden_layer_output_dims
)

# Check the model architecture
print(model)

GATModel(
  (gat_layers): ModuleList(
    (0): GATConv(20, 20, heads=4)
    (1-2): 2 x GATConv(80, 20, heads=4)
  )
  (layer_norms): ModuleList(
    (0-2): 3 x LayerNorm((80,), eps=1e-05, elementwise_affine=True)
  )
  (match_dim1): Linear(in_features=20, out_features=80, bias=True)
  (fc_layers): ModuleList(
    (0): Linear(in_features=80, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=1, bias=True)
  )
  (fc_layer_norms): ModuleList(
    (0-1): 2 x LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  )
  (leakyrelu): LeakyReLU(negative_slope=0.1)
)


In [13]:


# class GCNModel(torch.nn.Module):
#     def __init__(self, num_feats, graph_hidden_layer_output_dims, linear_hidden_layer_output_dims):
#         """
#         Initialize the model with the given parameters.
        
#         Args:
#         - num_feats (int): Number of input features (dimensionality of the input nodes).
#         - graph_hidden_layer_output_dims (list of int): List of hidden dimensions for GCN layers.
#         - linear_hidden_layer_output_dims (list of int): List of hidden dimensions for fully connected layers.
#         """
#         super(GCNModel, self).__init__()
        
#         self.gcn_layers = torch.nn.ModuleList()
        
#         # First GCN layer (input layer)
#         self.gcn_layers.append(GCNConv(num_feats, graph_hidden_layer_output_dims[0]))
        
#         # Hidden GCN layers
#         for i in range(1, len(graph_hidden_layer_output_dims)):
#             self.gcn_layers.append(GCNConv(graph_hidden_layer_output_dims[i-1], graph_hidden_layer_output_dims[i]))
        
#         # Fully Connected layers
#         self.fc_layers = torch.nn.ModuleList()
        
#         for i in range(len(linear_hidden_layer_output_dims)):
#             if i == 0:
#                 self.fc_layers.append(torch.nn.Linear(graph_hidden_layer_output_dims[-1], linear_hidden_layer_output_dims[i]))
#             else:
#                 self.fc_layers.append(torch.nn.Linear(linear_hidden_layer_output_dims[i-1], linear_hidden_layer_output_dims[i]))
        
#         # Final output layer
#         self.fc_layers.append(torch.nn.Linear(linear_hidden_layer_output_dims[-1], 1))  # Output is a single value (aggregation score)
#         self.leakyrelu = F.leaky_relu(0.1)

#     def forward(self, x, edge_index):
#         # x, edge_index = data.x, data.edge_index
         
#         # Pass through GCN layers
#         for gcn_layer in self.gcn_layers:
#             x = gcn_layer(x, edge_index)
#             x = self.leakyrelu(x)
        
#         # Pass through Fully Connected layers
#         for fc_layer in self.fc_layers[:-1]:
#             x = fc_layer(x)
#             x = self.leakyrelu(x)
#             x = F.dropout(x, p=0.5)
        
#         # Final output layer (no activation)
#         x = self.fc_layers[-1](x)
        
#         return x


In [14]:
# # Input feature size (number of input features per node, e.g., 20 for one-hot encoding)
# num_feats = 20

# # GCN hidden layer dimensions
# graph_hidden_layer_output_dims = [20, 20, 20]

# # Fully connected (linear) hidden layer dimensions
# linear_hidden_layer_output_dims = [20, 20]

# # Create the model
# model = GCNModel(
#     num_feats=num_feats,
#     graph_hidden_layer_output_dims=graph_hidden_layer_output_dims,
#     linear_hidden_layer_output_dims=linear_hidden_layer_output_dims
# )

# # Check the model architecture
# print(model)


In [15]:
# for batch in train_dataloader:
#     print(batch)
#     break

# x , coors, edge_index , batch , y = batch.x , batch.pos, batch.edge_index, batch.batch, batch.y
# # convert edge_index to adjacent matrix, as in EGNN take adj_mat
# x = x.unsqueeze(0)
# coors = coors.unsqueeze(0)
# edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0)

# out = GCNModel(x,)

In [16]:
num_feats = 20
graph_hidden_layer_output_dims = [20,20,20]
linear_hidden_layer_output_dims = [10,10]

dummy_model = EGNN_Model(num_feats = num_feats,
                       graph_hidden_layer_output_dims = graph_hidden_layer_output_dims,
                       linear_hidden_layer_output_dims = linear_hidden_layer_output_dims)


trainable, non_trainable = count_parameters(dummy_model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

Number of trainable parameters: 22147
Number of non-trainable parameters: 0


In [56]:
for batch in train_dataloader:
    print(len(batch.seq[0]))
    print(batch.seq[0])
    print(batch.y.size())
    break

# x , coors, edge_index , batch , y = batch.x , batch.pos, batch.edge_index, batch.batch, batch.y
# # convert edge_index to adjacent matrix, as in EGNN take adj_mat
# x = x.unsqueeze(0)
# coors = coors.unsqueeze(0)
# edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0)


1061
PQQAPYWTHPQRMEKKLHAVPAGNTVKFRCPAAGNPTPTIRWLKDGQAFHGENRIGGIRLRHQHWSLVMESVVPSDRGTYTCLVENAVGSIRYNYLLDVLEEVQLLESGGGLVQPGGSLRLSCAASGFTFSDYYMSWIRQAPGKGLEWVSTISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARLTAYGHVDSWGQGTLVTVSSASTKGPSVFPLAPSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKRVEPKSQSVLTQPPSASGTPGQRVTISCSGSSSNIGTNTVNWYQQLPGTAPKLLIYRNYQRPSGVPDRFSGSKSGTSASLAISGLRSEDEADYYCAAWDDSLSGPHVVFGGGTKLTVLGQPKAAPSVTLFPPSSEELQANKATLVCLISDFYPGAVTVAWKADSSPVKAGVETTTPSKQSNNKYAASSYLSLTPEQWKSHRSYSCQVTHEGSTVEKTVAPTYPQQAPYWTHPQRMEKKLHAVPAGNTVKFRCPAAGNPTPTIRWLKDGQAFHGENRIGGIRLRHQHWSLVMESVVPSDRGTYTCLVENAVGSIRYNYLLDVLEVQLLESGGGLVQPGGSLRLSCAASGFTFSDYYMSWIRQAPGKGLEWVSTISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARLTAYGHVDSWGQGTLVTVSSASTKGPSVFPLAPSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKRVEPKSCVLTQPPSASGTPGQRVTISCSGSSSNIGTNTVNWYQQLPGTAPKLLIYRNYQRPSGVPDRFSGSKSGTSASLAISGLRSEDEADYYCAAWDDSLSGPHVVFGGGTKLTVLGQPKAAPSVTLFPPSSEELQANKATLVCLISDFYPGA

In [57]:
len("PQQAPYWTHPQRMEKKLHAVPAGNTVKFRCPAAGNPTPTIRWLKDGQAFHGENRIGGIRLRHQHWSLVMESVVPSDRGTYTCLVENAVGSIRYNYLLDVLEEVQLLESGGGLVQPGGSLRLSCAASGFTFSDYYMSWIRQAPGKGLEWVSTISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARLTAYGHVDSWGQGTLVTVSSASTKGPSVFPLAPSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKRVEPKSQSVLTQPPSASGTPGQRVTISCSGSSSNIGTNTVNWYQQLPGTAPKLLIYRNYQRPSGVPDRFSGSKSGTSASLAISGLRSEDEADYYCAAWDDSLSGPHVVFGGGTKLTVLGQPKAAPSVTLFPPSSEELQANKATLVCLISDFYPGAVTVAWKADSSPVKAGVETTTPSKQSNNKYAASSYLSLTPEQWKSHRSYSCQVTHEGSTVEKTVAPTYPQQAPYWTHPQRMEKKLHAVPAGNTVKFRCPAAGNPTPTIRWLKDGQAFHGENRIGGIRLRHQHWSLVMESVVPSDRGTYTCLVENAVGSIRYNYLLDVLEVQLLESGGGLVQPGGSLRLSCAASGFTFSDYYMSWIRQAPGKGLEWVSTISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARLTAYGHVDSWGQGTLVTVSSASTKGPSVFPLAPSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKRVEPKSCVLTQPPSASGTPGQRVTISCSGSSSNIGTNTVNWYQQLPGTAPKLLIYRNYQRPSGVPDRFSGSKSGTSASLAISGLRSEDEADYYCAAWDDSLSGPHVVFGGGTKLTVLGQPKAAPSVTLFPPSSEELQANKATLVCLISDFYPGAVTVAWKADSSPVKAGVETTTPSKQSNNKYAASSYLSLTPEQWKSHRSYSCQVTHEGSTVEKTVAPT")

1061

In [18]:
print('size of x:', x.size())
print('size of coors:', coors.size())
print('size of edges:', edges.size())

size of x: torch.Size([1, 255, 20])
size of coors: torch.Size([1, 255, 3])
size of edges: torch.Size([1, 255, 255, 1])


# Trainer
## config

In [64]:
# ----------------
# PARAM
# ----------------


# Define the configuration dictionary with all the model parameters
path = "./weights_antibody/graph/(onehot)_(3EGNN)/"

config = {
    "model": 'EGNN',
    "num_feats" : 21,
    "graph_hidden_layer_output_dims" : [21,21,21],
    "linear_hidden_layer_output_dims" : [20,10],
    "learning_rate": 1e-5,
    "batch_size": 1,
    "nb_epochs": 20,
    "encode_mode" : 'onehot'
}


# ----------------
#  MODEL 
# ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

if config['model']=='GCN':
    model = GCNModel(num_feats = config["num_feats"],
                       graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                       linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])
elif config['model']=='GAT':
    model = GATModel(num_feats = config["num_feats"],
                       graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                       linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])
else:
    model = EGNN_Model(num_feats = config["num_feats"],
                       graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                       linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])



# ----------------
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable_params, non_trainable_params

print(path)
print(model)
trainable, non_trainable = count_parameters(model)
print(f"Number of trainable parameters: {trainable}")
print(f"Number of non-trainable parameters: {non_trainable}")

cuda
./weights_antibody/graph/(onehot)_(3EGNN)/
EGNN_Model(
  (graph_layers): ModuleList(
    (0-2): 3 x EGNN(
      (edge_mlp): Sequential(
        (0): Linear(in_features=44, out_features=88, bias=True)
        (1): Identity()
        (2): SiLU()
        (3): Linear(in_features=88, out_features=16, bias=True)
        (4): SiLU()
      )
      (node_norm): Identity()
      (coors_norm): Identity()
      (node_mlp): Sequential(
        (0): Linear(in_features=37, out_features=42, bias=True)
        (1): Identity()
        (2): SiLU()
        (3): Linear(in_features=42, out_features=21, bias=True)
      )
    )
  )
  (linear_layers): ModuleList(
    (0): Linear(in_features=21, out_features=20, bias=True)
    (1): Linear(in_features=20, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=1, bias=True)
  )
  (leakyrelu): LeakyReLU(negative_slope=0.1)
)
Number of trainable parameters: 24310
Number of non-trainable parameters: 0


In [65]:
# ----------------
#   OPTIMIZER 
# ----------------
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])


# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


class CombinedLoss(nn.Module):
    def __init__(self, lambda_reg=1.0, lambda_bin=1.0, pos_weight=None):
        super(CombinedLoss, self).__init__()
        self.lambda_reg = lambda_reg
        self.lambda_bin = lambda_bin
        self.mse_loss = nn.MSELoss()  # Regression Loss (MSE)
        
        if pos_weight is not None:
            # Binary Classification Loss (Weighted BCE with logits)
            self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        else:
            self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, outputs, regression_targets):
        # Calculate regression loss
        reg_loss = self.mse_loss(outputs, regression_targets)
        
        # Calculate binary classification loss
        # Convert regression output to binary labels (logits) for classification
        binary_targets = (regression_targets> 0).float()
        bin_loss = self.bce_loss(outputs, binary_targets)
        
        # Combined weighted loss
        total_loss = self.lambda_reg * reg_loss + self.lambda_bin * bin_loss
        return total_loss

combined_loss = CombinedLoss(lambda_reg=0.7, lambda_bin=0.3, pos_weight=4.0)



## dataloader

In [66]:
# csv_file = '/users/eleves-b/2023/ly-an.chhay/main/data/csv/all.csv'
graph_dir = "/users/eleves-b/2023/ly-an.chhay/main/data/graph_antibody/"

df = pd.read_csv("/users/eleves-b/2023/ly-an.chhay/main/data/pisces/antibody.csv").sample(frac=0.10, random_state=42)

train_dataset = GraphDataset(df[df.split=='train'], graph_dir)
valid_dataset = GraphDataset(df[df.split=='valid'], graph_dir)
test_dataset = GraphDataset(df[df.split=='test'], graph_dir)

train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=True)
test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=True)

# if config['encode_mode'] not in ['onehot', 'onehot_meiler']:
#     print("yes")
#     train_dataset = GraphDataset(df[df.split=='train'].sample(frac=0.10, random_state=42), graph_dir)
#     valid_dataset = GraphDataset(df[df.split=='valid'].sample(frac=0.10, random_state=42), graph_dir)
#     test_dataset = GraphDataset(df[df.split=='test'].sample(frac=0.10, random_state=42), graph_dir)

#     train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
#     valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=True)
#     test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=True)

# else:
#     train_dataset = GraphDataset(df[df.split=='train'], graph_dir)
#     valid_dataset = GraphDataset(df[df.split=='valid'], graph_dir)
#     test_dataset = GraphDataset(df[df.split=='test'], graph_dir)

#     train_dataloader = graphDataLoader(train_dataset, batch_size=1, shuffle=True)
#     valid_dataloader = graphDataLoader(valid_dataset, batch_size=1, shuffle=True)
#     test_dataloader = graphDataLoader(test_dataset, batch_size=1, shuffle=True)


In [77]:
import os
import torch
import time
import logging
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score,accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, matthews_corrcoef
from scipy.stats import pearsonr



def format_time(seconds):
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes}m {seconds}s"

def train_epoch(model, optimizer, dataloader,encode_mode='onehot', device= 'cuda', printEvery=50):
    model.train()
    total_loss = 0.0
    count_iter = 0
    start_time = time.time()
    epoch_start_time = start_time

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    # protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    # protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')


    with tqdm(total=len(dataloader), desc='Training', unit='batch') as pbar:
        for idx, batch in enumerate(dataloader):
            # x, coors, edge_index, batch, y = batch.x, batch.pos, batch.edge_index, batch.batch, batch.y
            batch_sequences, coors, edge_index, batch, y = batch.seq, batch.pos, batch.edge_index, batch.batch, batch.y
            
            if config['model']== 'EGNN':
                ## different encoding here
                if encode_mode == 'esm':
                    x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
                elif encode_mode == 'protbert':
                    x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
                elif encode_mode == 'onehot':
                    x = onehot_encode_batch(batch_sequences,len(batch_sequences[0])).to(device)
                else:
                    x = onehot_meiler_encode_batch(batch_sequences,len(batch_sequences[0])).to(device)

                

                # x = x.unsqueeze(0).to(device)
                edge_index = edge_index.to(device)
                coors = coors.unsqueeze(0).to(device)
                edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0).to(device)

                print(x.size())
                print(edges.size())
                print(coors.size())
                print(y.size())


                out = model(x, coors, edges).squeeze()
            else:
                x = x.to(device)
                edge_index = edge_index.to(device)
                out = model(x, edge_index).squeeze()

            current_loss = combined_loss(out, y.to(device))
            
            optimizer.zero_grad()
            current_loss.backward()
            optimizer.step()

            total_loss += current_loss.item()
            
            count_iter += 1
            if count_iter % printEvery == 0:
                elapsed_time = time.time() - start_time
                remaining_time = (elapsed_time / count_iter) * (len(dataloader) - count_iter)
                print(f"Iteration: {count_iter}, Time: {format_time(elapsed_time)}, Remaining: {format_time(remaining_time)}, Training Loss: {total_loss / count_iter:.4f}")
                start_time = time.time()
            
            #remove cache to save GPU
            torch.cuda.empty_cache()
            pbar.update(1)
            

    epoch_time = time.time() - epoch_start_time
    print(f"==> Average Training loss: mse ={total_loss / len(dataloader)}")
    print(f"==> Epoch Training Time: {format_time(epoch_time)}")
    print(f"================================================================\n")

    return total_loss / len(dataloader)


def evaluate(model, dataloader,encode_mode='onehot', device='cuda', mode='valid'):
    model.eval()
    total_loss = 0.0
    
    predictions = []
    targets = []
    binary_predictions = []
    binary_targets = []

    # esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    # esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model = esm_model.eval()  # Set the model to evaluation mode

    # protbert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    # protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda')


    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            
            batch_sequences, coors, edge_index, batch, y = batch.seq, batch.pos, batch.edge_index, batch.batch, batch.y
        
            if config['model']== 'EGNN':
                # x = x.unsqueeze(0).to(device)
                ## different encoding here
                if encode_mode == 'esm':
                    x = embed_esm_batch(batch_sequences, esm_model, alphabet).to(device)
                elif encode_mode == 'protbert':
                    x = embed_protbert_batch(batch_sequences, protbert_model, protbert_tokenizer).to(device)
                elif encode_mode == 'onehot':
                    x = onehot_encode_batch(batch_sequences,len(batch_sequences[0])).to(device)
                else:
                    x = onehot_meiler_encode_batch(batch_sequences,len(batch_sequences[0])).to(device)

                edge_index = edge_index.to(device)
                coors = coors.unsqueeze(0).to(device)
                edges = edge_index_to_adjacency_matrix(edge_index).unsqueeze(2).unsqueeze(0).to(device)
                out = model(x, coors, edges).squeeze()
            else:
                x = x.to(device)
                edge_index = edge_index.to(device)
                out = model(x, edge_index).squeeze()

            current_loss = combined_loss(out, y.to(device))
            # current_loss = weighted_bce_loss(out, (y>0).float().to(device)) + mse_loss(out, y.to(device))
            
            total_loss += current_loss.item()
       

            #append to list of all preds
            predictions.append(out.cpu().numpy())
            targets.append(y.cpu().numpy())
            
            ## Convert regression targets to binary labels
            y_bin = (y.cpu().numpy() > 0).astype(int)
            out_bin = (out.cpu().numpy() > 0).astype(int)
            
            binary_predictions.append(out_bin)
            binary_targets.append(y_bin)

    # if mode == 'test':
    all_predictions = np.concatenate(predictions, axis=0)
    all_targets = np.concatenate(targets, axis=0)
    all_binary_predictions = np.concatenate(binary_predictions, axis=0)
    all_binary_targets = np.concatenate(binary_targets, axis=0)

    # Calculate overall metrics
    overall_mse = mean_squared_error(all_targets, all_predictions)
    overall_rmse = np.sqrt(overall_mse)
    overall_mae = mean_absolute_error(all_targets, all_predictions)
    overall_r2 = r2_score(all_targets, all_predictions)
    overall_pcc, _ = pearsonr(all_targets.flatten(), all_predictions.flatten())

    # Calculate binary classification metrics
    overall_accuracy = accuracy_score(all_binary_targets, all_binary_predictions)
    overall_precision = precision_score(all_binary_targets, all_binary_predictions)
    overall_recall = recall_score(all_binary_targets, all_binary_predictions)
    overall_f1 = f1_score(all_binary_targets, all_binary_predictions)
    overall_auc_roc = roc_auc_score(all_binary_targets, all_predictions)
    overall_auc_pr = average_precision_score(all_binary_targets, all_predictions)
    overall_mcc = matthews_corrcoef(all_binary_targets, all_binary_predictions)

    print(f"Overall Reg Metrics - MSE: {overall_mse:.4f}, RMSE: {overall_rmse:.4f}, MAE: {overall_mae:.4f}, R2: {overall_r2:.4f}, PCC: {overall_pcc:.4f}")
    
    print(f"Overall Classification Metrics - Accuracy: {overall_accuracy:.4f}, Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1-Score: {overall_f1:.4f}, AUC-ROC: {overall_auc_roc:.4f}, AUC-PR: {overall_auc_pr:.4f}, MCC: {overall_mcc:.4f}")
    metrics = {
        "Regression Metrics": {
            "MSE": round(float(overall_mse), 4),
            "RMSE": round(float(overall_rmse), 4),
            "MAE": round(float(overall_mae), 4),
            "R2": round(float(overall_r2), 4),
            "PCC": round(float(overall_pcc), 4)
        },
        "Classification Metrics": {
            "Accuracy": round(float(overall_accuracy), 4),
            "Precision": round(float(overall_precision), 4),
            "Recall": round(float(overall_recall), 4),
            "F1-Score": round(float(overall_f1), 4),
            "AUC-ROC": round(float(overall_auc_roc), 4),
            "AUC-PR": round(float(overall_auc_pr), 4),
            "MCC": round(float(overall_mcc), 4)
        }
    }

    return total_loss / len(dataloader),metrics, predictions, targets

def train_loop(model, optimizer, train_dataloader, valid_dataloader, nb_epochs,encode_mode='onehot_meiler', device= 'cuda', save_directory='./weights/'):
    
    start_epoch = 1
    best_validation_loss = float('inf')
    early_stopping_counter = 0

    # Paths for saving losses and metrics
    loss_output_path = os.path.join(save_directory, 'losses.json')
    metric_output_path = os.path.join(save_directory, 'metrics.json')
    
    # Initialize lists for losses
    train_losses = []
    val_losses = []
    
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
        print(f'Created directory: {save_directory}')

    checkpoint_path = os.path.join(save_directory, 'model_last.pt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        
        # # Load losses from the losses.json file if it exists
        # if os.path.exists(loss_output_path):
        #     with open(loss_output_path, 'r') as f:
        #         losses = json.load(f)
        #         train_losses = losses.get('train_losses', [])
        #         val_losses = losses.get('val_losses', [])
        #     print(f'Loaded losses from {loss_output_path}.')
        #     print(train_losses)
        #     print(val_losses)
        # else:
        #     print(f'No losses file found at {loss_output_path}.')

    else:
        print('No checkpoint found. Starting from beginning.')
    
    # print(model)
    # model.to(device)


    
    # Load existing losses if available
    if os.path.exists(loss_output_path):
        with open(loss_output_path, 'r') as json_file:
            existing_losses = json.load(json_file)
            train_losses = existing_losses.get('train_losses', [])
            val_losses = existing_losses.get('val_losses', [])
            print(train_losses)
            print(val_losses)

    for epoch in range(start_epoch, nb_epochs + 1):
        print("==================================================================================")
        print(f'                            -----EPOCH {epoch}-----')
        print("==================================================================================")
        
        train_loss = train_epoch(model, optimizer, train_dataloader,encode_mode,device, printEvery=1000)
        train_losses.append(train_loss)
        
        # # **Print Gradients**
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f'Gradient - {name}: {param.grad.norm()}')  # Prints the norm of gradients

        print("==========================VALIDATION===============================================")
        val_loss ,metrics, _ , _ = evaluate(model, valid_dataloader,encode_mode,device)
        val_losses.append(val_loss)

        print(f'==> Epoch {epoch} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_validation_loss:
            early_stopping_counter = 0
            best_validation_loss = val_loss
            best_model_save_path = os.path.join(save_directory, 'model_best.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_accuracy': val_loss,
            }, best_model_save_path)
            print('\n')
            print(f'Best model checkpoint saved to: {best_model_save_path}')

            # Save metrics of the best model
            with open(metric_output_path, 'w') as json_file:
                json.dump(metrics, json_file, indent=4)
        
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= 3:
                print("\n==> Early stopping triggered. No improvement in validation loss for 3 epochs.")
                break

        last_model_save_path = os.path.join(save_directory, 'model_last.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_accuracy': val_loss,
        }, last_model_save_path)
        print(f'Last epoch model saved to: {last_model_save_path}')
        print("==================================================================================\n")
    
        # Save updated losses to the JSON file
        losses = {
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        with open(loss_output_path, 'w') as json_file:
            json.dump(losses, json_file, indent=4)
        print(f'Losses updated and saved to: {loss_output_path}')
        
    return

   

## Train here

In [78]:
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), 'w') as json_file:
    json.dump(config, json_file, indent=4)

model.to(device)
train_loop(model,optimizer,train_dataloader,valid_dataloader,50,config['encode_mode'] ,device,path)

Loaded checkpoint from ./weights_antibody/graph/(onehot)_(3EGNN)/model_last.pt. Resuming from epoch 13
[21.188981015563467, 21.187766279754054, 21.186478816686463, 21.1860033426705, 21.18563679077616, 21.18391943731527, 21.1831734214706, 21.182343698096002, 21.182362777291587, 21.18133555358397, 21.18088679363901, 21.180064271236287]
[26.907335754363768, 26.906338964739152, 26.906807603374606, 26.90265687627177, 26.90422487258911, 26.903301119804382, 26.9002626557504, 26.900180587845465, 26.89686845771728, 26.89861907497529, 26.898993732467776, 26.900196242717005]
                            -----EPOCH 13-----


Training:   0%|          | 1/261 [00:00<00:30,  8.52batch/s]

torch.Size([1, 426, 21])
torch.Size([1, 426, 426, 1])
torch.Size([1, 426, 3])
torch.Size([426])
torch.Size([1, 428, 21])
torch.Size([1, 428, 428, 1])
torch.Size([1, 428, 3])
torch.Size([428])
torch.Size([1, 430, 21])
torch.Size([1, 430, 430, 1])
torch.Size([1, 430, 3])
torch.Size([430])
torch.Size([1, 430, 21])
torch.Size([1, 430, 430, 1])
torch.Size([1, 430, 3])
torch.Size([430])
torch.Size([1, 427, 21])
torch.Size([1, 427, 427, 1])
torch.Size([1, 427, 3])
torch.Size([427])


Training:   3%|▎         | 9/261 [00:00<00:09, 27.94batch/s]

torch.Size([1, 442, 21])
torch.Size([1, 442, 442, 1])
torch.Size([1, 442, 3])
torch.Size([442])
torch.Size([1, 444, 21])
torch.Size([1, 444, 444, 1])
torch.Size([1, 444, 3])
torch.Size([444])
torch.Size([1, 427, 21])
torch.Size([1, 427, 427, 1])
torch.Size([1, 427, 3])
torch.Size([427])
torch.Size([1, 423, 21])
torch.Size([1, 423, 423, 1])
torch.Size([1, 423, 3])
torch.Size([423])
torch.Size([1, 429, 21])
torch.Size([1, 429, 429, 1])
torch.Size([1, 429, 3])
torch.Size([429])
torch.Size([1, 240, 21])
torch.Size([1, 240, 240, 1])
torch.Size([1, 240, 3])
torch.Size([240])
torch.Size([1, 428, 21])
torch.Size([1, 428, 428, 1])
torch.Size([1, 428, 3])
torch.Size([428])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])


Training:   7%|▋         | 17/261 [00:00<00:07, 33.13batch/s]

torch.Size([1, 454, 21])
torch.Size([1, 454, 454, 1])
torch.Size([1, 454, 3])
torch.Size([454])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 435, 21])
torch.Size([1, 435, 435, 1])
torch.Size([1, 435, 3])
torch.Size([435])
torch.Size([1, 227, 21])
torch.Size([1, 227, 227, 1])
torch.Size([1, 227, 3])
torch.Size([227])
torch.Size([1, 441, 21])
torch.Size([1, 441, 441, 1])
torch.Size([1, 441, 3])
torch.Size([441])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 434, 21])
torch.Size([1, 434, 434, 1])
torch.Size([1, 434, 3])
torch.Size([434])


Training:  10%|▉         | 25/261 [00:00<00:07, 32.24batch/s]

torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 431, 21])
torch.Size([1, 431, 431, 1])
torch.Size([1, 431, 3])
torch.Size([431])
torch.Size([1, 437, 21])
torch.Size([1, 437, 437, 1])
torch.Size([1, 437, 3])
torch.Size([437])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 430, 21])
torch.Size([1, 430, 430, 1])
torch.Size([1, 430, 3])
torch.Size([430])
torch.Size([1, 427, 21])
torch.Size([1, 427, 427, 1])
torch.Size([1, 427, 3])
torch.Size([427])
torch.Size([1, 426, 21])
torch.Size([1, 426, 426, 1])
torch.Size([1, 426, 3])
torch.Size([426])


Training:  13%|█▎        | 33/261 [00:01<00:06, 33.02batch/s]

torch.Size([1, 436, 21])
torch.Size([1, 436, 436, 1])
torch.Size([1, 436, 3])
torch.Size([436])
torch.Size([1, 230, 21])
torch.Size([1, 230, 230, 1])
torch.Size([1, 230, 3])
torch.Size([230])
torch.Size([1, 434, 21])
torch.Size([1, 434, 434, 1])
torch.Size([1, 434, 3])
torch.Size([434])
torch.Size([1, 444, 21])
torch.Size([1, 444, 444, 1])
torch.Size([1, 444, 3])
torch.Size([444])
torch.Size([1, 425, 21])
torch.Size([1, 425, 425, 1])
torch.Size([1, 425, 3])
torch.Size([425])
torch.Size([1, 424, 21])
torch.Size([1, 424, 424, 1])
torch.Size([1, 424, 3])
torch.Size([424])
torch.Size([1, 427, 21])
torch.Size([1, 427, 427, 1])
torch.Size([1, 427, 3])
torch.Size([427])


Training:  16%|█▌        | 41/261 [00:01<00:07, 30.78batch/s]

torch.Size([1, 439, 21])
torch.Size([1, 439, 439, 1])
torch.Size([1, 439, 3])
torch.Size([439])
torch.Size([1, 443, 21])
torch.Size([1, 443, 443, 1])
torch.Size([1, 443, 3])
torch.Size([443])
torch.Size([1, 454, 21])
torch.Size([1, 454, 454, 1])
torch.Size([1, 454, 3])
torch.Size([454])
torch.Size([1, 424, 21])
torch.Size([1, 424, 424, 1])
torch.Size([1, 424, 3])
torch.Size([424])
torch.Size([1, 428, 21])
torch.Size([1, 428, 428, 1])
torch.Size([1, 428, 3])
torch.Size([428])
torch.Size([1, 434, 21])
torch.Size([1, 434, 434, 1])
torch.Size([1, 434, 3])
torch.Size([434])
torch.Size([1, 435, 21])
torch.Size([1, 435, 435, 1])
torch.Size([1, 435, 3])
torch.Size([435])


Training:  17%|█▋        | 45/261 [00:01<00:06, 30.94batch/s]

torch.Size([1, 437, 21])
torch.Size([1, 437, 437, 1])
torch.Size([1, 437, 3])
torch.Size([437])
torch.Size([1, 437, 21])
torch.Size([1, 437, 437, 1])
torch.Size([1, 437, 3])
torch.Size([437])
torch.Size([1, 434, 21])
torch.Size([1, 434, 434, 1])
torch.Size([1, 434, 3])
torch.Size([434])
torch.Size([1, 430, 21])
torch.Size([1, 430, 430, 1])
torch.Size([1, 430, 3])
torch.Size([430])
torch.Size([1, 424, 21])
torch.Size([1, 424, 424, 1])
torch.Size([1, 424, 3])
torch.Size([424])
torch.Size([1, 432, 21])
torch.Size([1, 432, 432, 1])
torch.Size([1, 432, 3])
torch.Size([432])
torch.Size([1, 432, 21])
torch.Size([1, 432, 432, 1])
torch.Size([1, 432, 3])
torch.Size([432])


Training:  20%|██        | 53/261 [00:01<00:06, 30.98batch/s]

torch.Size([1, 429, 21])
torch.Size([1, 429, 429, 1])
torch.Size([1, 429, 3])
torch.Size([429])
torch.Size([1, 420, 21])
torch.Size([1, 420, 420, 1])
torch.Size([1, 420, 3])
torch.Size([420])
torch.Size([1, 439, 21])
torch.Size([1, 439, 439, 1])
torch.Size([1, 439, 3])
torch.Size([439])
torch.Size([1, 432, 21])
torch.Size([1, 432, 432, 1])
torch.Size([1, 432, 3])
torch.Size([432])
torch.Size([1, 438, 21])
torch.Size([1, 438, 438, 1])
torch.Size([1, 438, 3])
torch.Size([438])
torch.Size([1, 426, 21])
torch.Size([1, 426, 426, 1])
torch.Size([1, 426, 3])
torch.Size([426])
torch.Size([1, 429, 21])
torch.Size([1, 429, 429, 1])
torch.Size([1, 429, 3])
torch.Size([429])


Training:  23%|██▎       | 61/261 [00:01<00:06, 31.54batch/s]

torch.Size([1, 429, 21])
torch.Size([1, 429, 429, 1])
torch.Size([1, 429, 3])
torch.Size([429])
torch.Size([1, 439, 21])
torch.Size([1, 439, 439, 1])
torch.Size([1, 439, 3])
torch.Size([439])
torch.Size([1, 422, 21])
torch.Size([1, 422, 422, 1])
torch.Size([1, 422, 3])
torch.Size([422])
torch.Size([1, 440, 21])
torch.Size([1, 440, 440, 1])
torch.Size([1, 440, 3])
torch.Size([440])
torch.Size([1, 326, 21])
torch.Size([1, 326, 326, 1])
torch.Size([1, 326, 3])
torch.Size([326])
torch.Size([1, 425, 21])
torch.Size([1, 425, 425, 1])
torch.Size([1, 425, 3])
torch.Size([425])
torch.Size([1, 439, 21])
torch.Size([1, 439, 439, 1])
torch.Size([1, 439, 3])
torch.Size([439])


Training:  25%|██▍       | 65/261 [00:02<00:06, 31.36batch/s]

torch.Size([1, 429, 21])
torch.Size([1, 429, 429, 1])
torch.Size([1, 429, 3])
torch.Size([429])
torch.Size([1, 434, 21])
torch.Size([1, 434, 434, 1])
torch.Size([1, 434, 3])
torch.Size([434])
torch.Size([1, 441, 21])
torch.Size([1, 441, 441, 1])
torch.Size([1, 441, 3])
torch.Size([441])
torch.Size([1, 439, 21])
torch.Size([1, 439, 439, 1])
torch.Size([1, 439, 3])
torch.Size([439])
torch.Size([1, 452, 21])
torch.Size([1, 452, 452, 1])
torch.Size([1, 452, 3])
torch.Size([452])


Training:  28%|██▊       | 72/261 [00:02<00:06, 29.44batch/s]

torch.Size([1, 427, 21])
torch.Size([1, 427, 427, 1])
torch.Size([1, 427, 3])
torch.Size([427])
torch.Size([1, 433, 21])
torch.Size([1, 433, 433, 1])
torch.Size([1, 433, 3])
torch.Size([433])
torch.Size([1, 432, 21])
torch.Size([1, 432, 432, 1])
torch.Size([1, 432, 3])
torch.Size([432])
torch.Size([1, 436, 21])
torch.Size([1, 436, 436, 1])
torch.Size([1, 436, 3])
torch.Size([436])
torch.Size([1, 436, 21])
torch.Size([1, 436, 436, 1])
torch.Size([1, 436, 3])
torch.Size([436])





KeyboardInterrupt: 

# load best model and test on test-set

In [None]:

def load_model_from_checkpoint(model, optimizer, checkpoint_path, device):
    """
    Loads the model and optimizer state from a checkpoint if it exists.
    
    Args:
    - model (torch.nn.Module): The model to load the state into.
    - optimizer (torch.optim.Optimizer): The optimizer to load the state into.
    - checkpoint_path (str): Path to the checkpoint file.
    - device (torch.device): Device to which the model should be moved.
    
    Returns:
    - start_epoch (int): The epoch to start training from.
    - best_validation_loss (float): The best validation loss recorded in the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}. Resuming from epoch {start_epoch}')
        # print(f'Best validation loss: {best_validation_loss}')
    else:
        start_epoch = 0
        best_validation_loss = float('inf')  # Assuming lower is better for validation loss
        print('No checkpoint found.')
    
    model = model.to(device)
    return start_epoch, best_validation_loss




In [None]:

# List of model paths
model_paths = [
    # "./weights/graph/(onehot)_(regloss)_(3EGNN)/",
    "./weights/graph/(onehot)_(regloss)_(3GAT)/",
    # "./weights/graph/(onehot)_(regloss)_(3GCN)/"
  
]

for path in model_paths:
    # Load the config for the current model
    with open(path + 'config.json', 'r') as json_file:
        config = json.load(json_file)


    # Initialize the model
    if config['model']=='GCN':
        model = GCNModel(num_feats = config["num_feats"],
                       graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                       linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])
    elif config['model']=='GAT':
        model = GATModel(num_feats = config["num_feats"],
                        graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                        linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])
    else:
        model = EGNN_Model(num_feats = config["num_feats"],
                        graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                        linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])


    # Load the model weights from the checkpoint
    _, _ = load_model_from_checkpoint(model, optimizer, path + 'model_best.pt', device)

    # Evaluate the model
    loss, metrics, preds, tar = evaluate(model, test_dataloader ,device)

    # Save metrics of the best model
    with open(path + 'result.json', 'w') as json_file:
        json.dump(metrics, json_file, indent=4)

    print(f"Processed model in path: {path}")

Loaded checkpoint from ./weights/graph/(onehot)_(regloss)_(3GAT)/model_best.pt. Resuming from epoch 31
Overall Reg Metrics - MSE: 0.4989, RMSE: 0.7063, MAE: 0.5328, R2: 0.7734, PCC: 0.8799
Overall Classification Metrics - Accuracy: 0.8632, Precision: 0.6199, Recall: 0.7677, F1-Score: 0.6859, AUC-ROC: 0.9215, AUC-PR: 0.7733, MCC: 0.6053
Processed model in path: ./weights/graph/(onehot)_(regloss)_(3GAT)/


In [None]:
def define_load_graph_model(weight_path,device='cuda'):
    config_path = os.path.join(weight_path, "config.json")

    # Read the JSON file back into a dictionary
    with open(config_path, 'r') as json_file:
        config = json.load(json_file)
    
    model = EGNN_Model(num_feats = config["num_feats"],
                        graph_hidden_layer_output_dims = config["graph_hidden_layer_output_dims"],
                        linear_hidden_layer_output_dims = config["linear_hidden_layer_output_dims"])


    checkpoint_path = os.path.join(weight_path, "model_best.pt")

    # checkpoint_path = os.path.abspath('../aggrepred/weights/both_loss_dif_layer/model_best.pt')

    print(checkpoint_path)
  
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print('Loading model succesfully')
    else:
        print('No model found')

    return model

In [None]:
graph_weight_path = "/users/eleves-b/2023/ly-an.chhay/main/aggrepred/weights/graph/(onehot)_(regloss)_(3EGNN)"
##################################################################
graph_model = define_load_graph_model(graph_weight_path)
graph_model = graph_model.to(device)
# print(graph_model)

torch.cuda.empty_cache() 
loss, metrics, preds, tar = evaluate(graph_model, test_dataloader ,device)

# # Save metrics of the best model
# with open(path + 'result.json', 'w') as json_file:
#     json.dump(metrics, json_file, indent=4)

# print(f"Processed model in path: {path}")

/users/eleves-b/2023/ly-an.chhay/main/aggrepred/weights/graph/(onehot)_(regloss)_(3EGNN)/model_best.pt
Loading model succesfully


TypeError: EGNN_Model.forward() missing 1 required positional argument: 'edges'

In [None]:
g= torch.load("/users/eleves-b/2023/ly-an.chhay/main/application/tmp/g.pt")
dataloader = graphDataLoader([g], batch_size=1, shuffle=False)
            

In [None]:
g

Data(x=[1239, 20], edge_index=[2, 25010], y=[1239], pos=[1239, 3])

In [None]:
num_feats = 20
graph_hidden_layer_output_dims = [20,20,20,20,20,20]
linear_hidden_layer_output_dims = [10,10]


best_pre_trained_model = EGNN_Model(num_feats = num_feats,
                       graph_hidden_layer_output_dims = graph_hidden_layer_output_dims,
                       linear_hidden_layer_output_dims = linear_hidden_layer_output_dims)

weight_dir = "/users/eleves-b/2023/ly-an.chhay/main/aggrepred/weights/"

checkpoint_path = weight_dir+'graph/6egnn_3FC_bothloss/model_best.pt'


def load_model(model, checkpoint_path,device):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.to(device)
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}')
    else:
        print('No checkpoint found')
    
    return model

best_pre_trained_model = load_model(best_pre_trained_model,checkpoint_path,device)

# best_pre_trained_model.to(device)
test_out, preds, trues = evaluate(best_pre_trained_model,test_dataloader,device)

In [None]:
num_feats = 20
graph_hidden_layer_output_dims = [20,20,20]
linear_hidden_layer_output_dims = [10,10]



best_pre_trained_model = EGNN_Model(num_feats = num_feats,
                       graph_hidden_layer_output_dims = graph_hidden_layer_output_dims,
                       linear_hidden_layer_output_dims = linear_hidden_layer_output_dims)

weight_dir = "/users/eleves-b/2023/ly-an.chhay/main/aggrepred/weights/"

checkpoint_path = weight_dir+'graph/3egnn_3FC_bothloss/model_best.pt'


def load_model(model, checkpoint_path,device):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.to(device)
        best_validation_loss = checkpoint['validation_accuracy']
        print(f'Loaded checkpoint from {checkpoint_path}')
    else:
        print('No checkpoint found')
    
    return model

best_pre_trained_model = load_model(best_pre_trained_model,checkpoint_path,device)

# best_pre_trained_model.to(device)
test_out, metric, preds, trues = evaluate(best_pre_trained_model,test_dataloader,device)