# Load Libraries

In [1]:
# Standard libraries
import os
import json
import pickle
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') # Known issue with PyTorch and DGL
import IPython

# Data handling
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform

# Machine learning and model evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve, auc, confusion_matrix, classification_report
from sklearn.cluster import HDBSCAN

# Neural Networks and Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import sigmoid
import torch.optim as optim

# Optuna for hyperparameter optimization
import optuna
from optuna.pruners import MedianPruner, HyperbandPruner

# Graph Neural Networks
import dgl
import dgl.nn as dglnn
from dgl import batch
from dgl.data.utils import save_graphs, load_graphs
from dgl.nn import GATConv, GATv2Conv
from dgl import max_nodes

# Cheminformatics
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from chembl_structure_pipeline import standardizer

# Network analysis
import networkx as nx

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns



check if GPU is aviable and the device

In [2]:
import torch
print("PyTorch version:", torch.__version__)
print("Is CUDA Supported?", torch.cuda.is_available())

PyTorch version: 2.1.2
Is CUDA Supported? True


In [3]:
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0)

(True, 1, 'Tesla T4')

# Data loading

In [4]:
# def load_and_process_dataset(file_path, source_name):
#     # Load dataset
#     dataset = pd.read_json(file_path)
    
#     # Reset index
#     dataset.reset_index(drop=True, inplace=True)
    
#     # Drop unnecessary columns, keeping only 'SMILES' and 'source'
#     dataset = dataset[['SMILES', 'source']] if 'SMILES' in dataset.columns else dataset[['smiles', 'source']]
    
#     # Rename 'SMILES' column to 'smiles'
#     dataset.rename(columns={'SMILES': 'smiles'}, inplace=True)
    
#     # Assign source name if missing
#     if 'source' not in dataset.columns:
#         dataset['source'] = source_name
    
#     # Add 'binds_to_rna' column based on the source
#     dataset['binds_to_rna'] = 0 if source_name == 'enmine_protein' else 1
    
#     return dataset

# # Define file paths and source names
# datasets_info = [
#     ('data_mvi/chemdiv_df.json', 'chemdiv'),
#     ('data_mvi/enamine_df.json', 'enamine'),
#     ('data_mvi/picked_molecules.json', 'enmine_protein'),
#     ('data_mvi/life_chemicals_df.json', 'life_chemicals'),
#     ('data_mvi/robin_df.json', 'robin')
# ]

# # Load, process, and combine datasets
# combined_df = pd.concat([load_and_process_dataset(file_path, source) for file_path, source in datasets_info], ignore_index=True)

# combined_df.head()


In [5]:
# from rdkit import Chem
# import pandas as pd
# # Assuming 'standardizer' is an instance of a class with methods 'standardize_mol' and 'get_parent_mol'

# def remove_explicit_salts(smiles_list):
#     salt_patterns = [
#         ".[O-][Cl+3]([O-])([O-])[O-]",
#         "[O-][Cl+3]([O-])([O-])[O-].",
#         ".[O-][Cl+3](O)(O)O"
#     ]
#     for salt_pattern in salt_patterns:
#         smiles_list = [smiles.replace(salt_pattern, "") for smiles in smiles_list]
#     return smiles_list

# def process_molecule(smiles):
#     mol = Chem.MolFromSmiles(smiles)
#     if mol is None:
#         return None
#     num_components = len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False))
#     if num_components > 1:
#         std_mol = standardizer.standardize_mol(mol)
#         parent_mol, _ = standardizer.get_parent_mol(std_mol)
#         if parent_mol is None or parent_mol.GetNumAtoms() == 0:
#             return None
#         return Chem.MolToSmiles(parent_mol)
#     return Chem.MolToSmiles(mol)

# def is_disconnected(smiles):
#     mol = Chem.MolFromSmiles(smiles)
#     return len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False)) > 1


# # Then, process each molecule for standardization and further sanitization
# for idx, row in combined_df.iterrows():
#     sanitized_smiles = process_molecule(row['smiles'])
#     if sanitized_smiles is not None:
#         combined_df.at[idx, 'smiles'] = sanitized_smiles

# combined_df['smiles'] = remove_explicit_salts(combined_df['smiles'].tolist())


# # Finally, apply the check for disconnected molecules
# combined_df['is_disconnected'] = combined_df['smiles'].apply(is_disconnected)

# # Filtering or analyzing disconnected molecules
# disconnected_mols_df = combined_df[combined_df['is_disconnected']]

# print(f"Total molecules: {len(combined_df)}")
# print(f"Disconnected molecules after sanitization: {len(disconnected_mols_df)}")


In [6]:
# # Your combined_df
# # combined_df = pd.read_csv("your_dataset.csv")

# # Initialize featurizers
# atom_featurizer = CanonicalAtomFeaturizer()
# bond_featurizer = CanonicalBondFeaturizer(self_loop=True)

# # Function to convert SMILES to a graph
# def smiles_to_graph(smiles_string):
#     graph_constructor = SMILESToBigraph(add_self_loop=True,  # Add self loops
#                                         node_featurizer=atom_featurizer,
#                                         edge_featurizer=bond_featurizer)
#     return graph_constructor(smiles_string)

# # Convert SMILES to graphs
# combined_df['graph'] = combined_df['smiles'].apply(smiles_to_graph)

In [7]:
# import dgl

# # Assuming 'binds_to_rna' is your label
# graphs = combined_df['graph'].tolist()
# labels = {'binds_to_rna': torch.tensor(combined_df['binds_to_rna'].values)}

# # Save graphs and labels
# dgl.save_graphs("data_mvi/graphs.bin", graphs, labels)



In [8]:
# # Drop the 'graph' column and save the DataFrame
# combined_df.drop(columns=['graph']).to_csv("data_mvi/combined_df.csv", index=False)


In [9]:
# Load the DataFrame
reloaded_df = pd.read_csv("data_mvi/combined_df.csv")

# Load the graphs
graphs, labels = dgl.load_graphs("data_mvi/graphs.bin")

# Labels are returned as a dictionary, convert to the desired format if necessary
binds_to_rna = labels['binds_to_rna']


In [10]:
from sklearn.model_selection import train_test_split

# Assuming 'binds_to_rna' is your target variable for stratification
labels = reloaded_df['binds_to_rna'].values

# Generate a list of indices from the DataFrame
indices = range(len(reloaded_df))

# Perform stratified split
train_indices, test_indices, train_labels, test_labels = train_test_split(
    indices, labels, test_size=0.2, stratify=labels, random_state=42)

# Now you can use train_indices and test_indices to split your graphs and labels
train_graphs = [graphs[i] for i in train_indices]
test_graphs = [graphs[i] for i in test_indices]

# If you saved labels as a tensor, you can also split it according to indices
train_labels = labels[train_indices]
test_labels = labels[test_indices]


# Graph Attention Networks (GAT)

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [12]:
# train_dataset = list(zip(train_graphs, train_labels))
# test_dataset = list(zip(test_graphs, test_labels))

In [13]:
import torch
import torch.nn as nn
import dgl
from dgl.nn import GatedGraphConv, GlobalAttentionPooling

class GraphClsGGNN(nn.Module):
    def __init__(self, annotation_size, out_feats, n_steps, n_etypes, num_cls):
        super(GraphClsGGNN, self).__init__()
        self.ggnn = GatedGraphConv(in_feats=annotation_size, out_feats=out_feats, n_steps=n_steps, n_etypes=n_etypes)
        self.pooling_gate_nn = nn.Linear(out_feats, 1)
        self.pooling = GlobalAttentionPooling(self.pooling_gate_nn)
        self.output_layer = nn.Linear(out_feats, num_cls)  # For binary classification, num_cls=2


    def forward(self, graph, feat, get_attention=False):
        h = self.ggnn(graph, feat)
        if get_attention:
            hg, attention = self.pooling(graph, h, get_attention=True)
            logits = self.output_layer(hg)
            return logits, attention
        else:
            hg = self.pooling(graph, h)
            logits = self.output_layer(hg)
            return logits


In [14]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    
    # Check if labels are already tensors (which might be the case if __getitem__ converts them)
    # If labels are not tensors, convert them to a tensor
    if not all(isinstance(label, torch.Tensor) for label in labels):
        labels = torch.tensor(labels, dtype=torch.long)
    else:
        # If labels are already tensors, stack them into a single tensor
        labels = torch.stack(labels, dim=0)
    
    return batched_graph, labels


In [15]:
import optuna
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split

# Assuming `graphs` and `labels` are your dataset
train_graphs, test_graphs, train_labels, test_labels = train_test_split(graphs, labels, test_size=0.2, random_state=42)

# Convert labels to tensors here to avoid reassignment inside the objective function
train_labels_tensor = torch.LongTensor(train_labels)
test_labels_tensor = torch.LongTensor(test_labels)

def objective(trial):
    n_steps = trial.suggest_int('n_steps', 1, 15)
    out_feats = trial.suggest_int('out_feats', 74, 200)  # Ensure out_feats is at least equal to annotation_size
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128, 256, 512])
    
    model = GraphClsGGNN(annotation_size=74, out_feats=out_feats, n_steps=n_steps, n_etypes=1, num_cls=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    
    train_loader = GraphDataLoader(list(zip(train_graphs, train_labels_tensor)), batch_size=batch_size, shuffle=True, collate_fn=collate, num_workers=5)
    val_loader = GraphDataLoader(list(zip(test_graphs, test_labels_tensor)), batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=5)
    
    for epoch in range(100):  # Adjust the number of epochs as needed
        model.train()
        for batched_graph, labels in train_loader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            if 'h' in batched_graph.ndata:
                feat = batched_graph.ndata['h'].to(device)
            else:
                raise KeyError("'h' node features not found. Please ensure the correct key is used.")
            optimizer.zero_grad()
            logits = model(batched_graph, feat)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
    
    # Evaluation
    model.eval()
    total_correct = 0
    total = 0
    with torch.no_grad():
        for batched_graph, labels in val_loader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            if 'h' in batched_graph.ndata:
                feat = batched_graph.ndata['h'].to(device)
            logits = model(batched_graph, feat)
            predicted = torch.argmax(logits, dim=1)
            total_correct += (predicted == labels).sum().item()
            total += labels.size(0)
    
    val_accuracy = total_correct / total
    return val_accuracy





In [16]:
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100, n_jobs=-1)


[I 2024-03-04 22:39:27,456] A new study created in memory with name: no-name-c9b759a3-b982-4048-8506-098d4449d762


[I 2024-03-04 22:41:17,763] Trial 0 finished with value: 0.8198186002436714 and parameters: {'n_steps': 3, 'out_feats': 98, 'lr': 0.0030592682431992914, 'batch_size': 128}. Best is trial 0 with value: 0.8198186002436714.


[I 2024-03-04 22:45:24,268] Trial 1 finished with value: 0.7737241099228375 and parameters: {'n_steps': 1, 'out_feats': 83, 'lr': 8.979195497191822e-05, 'batch_size': 32}. Best is trial 0 with value: 0.8198186002436714.


[I 2024-03-04 22:46:32,971] Trial 2 finished with value: 0.7955868417490185 and parameters: {'n_steps': 1, 'out_feats': 76, 'lr': 0.0006376791827759781, 'batch_size': 128}. Best is trial 0 with value: 0.8198186002436714.


[I 2024-03-04 22:58:17,834] Trial 3 finished with value: 0.8428996886422093 and parameters: {'n_steps': 4, 'out_feats': 74, 'lr': 0.0008118235527526994, 'batch_size': 16}. Best is trial 3 with value: 0.8428996886422093.


[I 2024-03-04 23:00:50,960] Trial 4 finished with value: 0.8011371328008664 and parameters: {'n_steps': 2, 'out_feats': 92, 'lr': 0.0001224521104895872, 'batch_size': 64}. Best is trial 3 with value: 0.8428996886422093.


[I 2024-03-04 23:05:39,621] Trial 5 finished with value: 0.8514281846487072 and parameters: {'n_steps': 2, 'out_feats': 87, 'lr': 0.0004709152159475434, 'batch_size': 32}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:10:31,231] Trial 6 finished with value: 0.8382970082577501 and parameters: {'n_steps': 2, 'out_feats': 74, 'lr': 0.0017451771298170667, 'batch_size': 32}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:13:34,577] Trial 7 finished with value: 0.8509543793150128 and parameters: {'n_steps': 4, 'out_feats': 99, 'lr': 0.00022233459611329585, 'batch_size': 64}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:15:00,480] Trial 8 finished with value: 0.7119263571138487 and parameters: {'n_steps': 3, 'out_feats': 93, 'lr': 0.014116201373285672, 'batch_size': 128}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:18:17,670] Trial 9 finished with value: 0.6091105997021795 and parameters: {'n_steps': 5, 'out_feats': 89, 'lr': 0.07788856396846346, 'batch_size': 64}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:23:14,116] Trial 10 finished with value: 0.7376472180858265 and parameters: {'n_steps': 2, 'out_feats': 83, 'lr': 1.070029054270681e-05, 'batch_size': 32}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:26:14,175] Trial 11 finished with value: 0.8183294977663463 and parameters: {'n_steps': 4, 'out_feats': 100, 'lr': 8.179749240740736e-05, 'batch_size': 64}. Best is trial 5 with value: 0.8514281846487072.


[I 2024-03-04 23:38:29,344] Trial 12 finished with value: 0.8552186273182618 and parameters: {'n_steps': 5, 'out_feats': 85, 'lr': 0.0003279728846994854, 'batch_size': 16}. Best is trial 12 with value: 0.8552186273182618.


[I 2024-03-04 23:50:37,730] Trial 13 finished with value: 0.8159604710978746 and parameters: {'n_steps': 5, 'out_feats': 83, 'lr': 3.0435279468327182e-05, 'batch_size': 16}. Best is trial 12 with value: 0.8552186273182618.


[I 2024-03-05 00:00:06,682] Trial 14 finished with value: 0.7560579396236632 and parameters: {'n_steps': 2, 'out_feats': 87, 'lr': 0.005443490014707386, 'batch_size': 16}. Best is trial 12 with value: 0.8552186273182618.


[I 2024-03-05 00:12:20,941] Trial 15 finished with value: 0.8563016109381345 and parameters: {'n_steps': 5, 'out_feats': 79, 'lr': 0.0003690110013130995, 'batch_size': 16}. Best is trial 15 with value: 0.8563016109381345.


[I 2024-03-05 00:24:52,779] Trial 16 finished with value: 0.8579260863679437 and parameters: {'n_steps': 5, 'out_feats': 79, 'lr': 0.0002787455900253741, 'batch_size': 16}. Best is trial 16 with value: 0.8579260863679437.


[I 2024-03-05 00:37:35,278] Trial 17 finished with value: 0.8224583728171111 and parameters: {'n_steps': 5, 'out_feats': 79, 'lr': 4.5951890189508e-05, 'batch_size': 16}. Best is trial 16 with value: 0.8579260863679437.


[I 2024-03-05 00:49:45,496] Trial 18 finished with value: 0.68485176661703 and parameters: {'n_steps': 4, 'out_feats': 79, 'lr': 0.01296646264197027, 'batch_size': 16}. Best is trial 16 with value: 0.8579260863679437.


[I 2024-03-05 01:02:06,205] Trial 19 finished with value: 0.7869229727900365 and parameters: {'n_steps': 5, 'out_feats': 80, 'lr': 1.4454237486058128e-05, 'batch_size': 16}. Best is trial 16 with value: 0.8579260863679437.


[I 2024-03-05 01:12:53,766] Trial 20 finished with value: 0.8449302829294707 and parameters: {'n_steps': 3, 'out_feats': 77, 'lr': 0.00020221453818348802, 'batch_size': 16}. Best is trial 16 with value: 0.8579260863679437.


[I 2024-03-05 01:25:18,137] Trial 21 finished with value: 0.8592121294165426 and parameters: {'n_steps': 5, 'out_feats': 81, 'lr': 0.00038308871884124647, 'batch_size': 16}. Best is trial 21 with value: 0.8592121294165426.


[I 2024-03-05 01:34:04,570] Trial 22 finished with value: 0.7235007445512387 and parameters: {'n_steps': 5, 'out_feats': 81, 'lr': 0.001908876464300018, 'batch_size': 16}. Best is trial 21 with value: 0.8592121294165426.


[I 2024-03-05 01:41:37,529] Trial 23 finished with value: 0.8045214566129687 and parameters: {'n_steps': 4, 'out_feats': 77, 'lr': 0.001128196479585872, 'batch_size': 16}. Best is trial 21 with value: 0.8592121294165426.


[I 2024-03-05 01:49:58,510] Trial 24 finished with value: 0.8654392852308109 and parameters: {'n_steps': 5, 'out_feats': 82, 'lr': 0.0003388786595148537, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 01:57:23,726] Trial 25 finished with value: 0.8558278056044403 and parameters: {'n_steps': 4, 'out_feats': 82, 'lr': 0.00018280297840227955, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:05:31,921] Trial 26 finished with value: 0.8213077027209963 and parameters: {'n_steps': 5, 'out_feats': 85, 'lr': 4.942546207294039e-05, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:13:06,081] Trial 27 finished with value: 0.8047922025179369 and parameters: {'n_steps': 4, 'out_feats': 85, 'lr': 0.0010354329842995858, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:14:35,445] Trial 28 finished with value: 0.7960606470827128 and parameters: {'n_steps': 5, 'out_feats': 89, 'lr': 0.004344206060837336, 'batch_size': 128}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:15:55,001] Trial 29 finished with value: 0.8432381210234196 and parameters: {'n_steps': 3, 'out_feats': 81, 'lr': 0.002264102082426197, 'batch_size': 128}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:24:34,330] Trial 30 finished with value: 0.7900365506971707 and parameters: {'n_steps': 4, 'out_feats': 76, 'lr': 2.4464317542154085e-05, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:32:54,066] Trial 31 finished with value: 0.8608366048463517 and parameters: {'n_steps': 5, 'out_feats': 78, 'lr': 0.0003772174540714661, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:41:13,577] Trial 32 finished with value: 0.8492622174089617 and parameters: {'n_steps': 5, 'out_feats': 78, 'lr': 0.00013301732061682037, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:49:43,826] Trial 33 finished with value: 0.8490591579802356 and parameters: {'n_steps': 5, 'out_feats': 81, 'lr': 0.0005198401117146547, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 02:57:54,421] Trial 34 finished with value: 0.8218491945309327 and parameters: {'n_steps': 5, 'out_feats': 75, 'lr': 7.407281727185796e-05, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:02:19,547] Trial 35 finished with value: 0.8551509408420198 and parameters: {'n_steps': 4, 'out_feats': 96, 'lr': 0.0002952905783554524, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:08:48,697] Trial 36 finished with value: 0.8223906863408691 and parameters: {'n_steps': 1, 'out_feats': 77, 'lr': 0.0006493250747098867, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:10:18,161] Trial 37 finished with value: 0.8535941518884527 and parameters: {'n_steps': 5, 'out_feats': 84, 'lr': 0.0010702345852843839, 'batch_size': 128}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:12:34,961] Trial 38 finished with value: 0.8191417354812508 and parameters: {'n_steps': 4, 'out_feats': 80, 'lr': 0.0001215802403583266, 'batch_size': 64}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:16:59,573] Trial 39 finished with value: 0.856843102748071 and parameters: {'n_steps': 5, 'out_feats': 75, 'lr': 0.0006379092709413135, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:25:13,139] Trial 40 finished with value: 0.8472993095979423 and parameters: {'n_steps': 5, 'out_feats': 83, 'lr': 0.00024879277367438526, 'batch_size': 16}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:29:52,649] Trial 41 finished with value: 0.8587383240828482 and parameters: {'n_steps': 5, 'out_feats': 74, 'lr': 0.0006414518613755485, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:34:11,233] Trial 42 finished with value: 0.8579260863679437 and parameters: {'n_steps': 5, 'out_feats': 74, 'lr': 0.000427852782115388, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:38:54,142] Trial 43 finished with value: 0.8472993095979423 and parameters: {'n_steps': 5, 'out_feats': 76, 'lr': 0.00016501047203769902, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:42:48,581] Trial 44 finished with value: 0.8301069446324625 and parameters: {'n_steps': 4, 'out_feats': 78, 'lr': 0.0013544181501140397, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:47:08,076] Trial 45 finished with value: 0.8618519019899824 and parameters: {'n_steps': 5, 'out_feats': 74, 'lr': 0.0005848856325856442, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:51:10,317] Trial 46 finished with value: 0.8560985515094084 and parameters: {'n_steps': 4, 'out_feats': 74, 'lr': 0.0006939812787367989, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:55:46,787] Trial 47 finished with value: 0.7845539461215649 and parameters: {'n_steps': 5, 'out_feats': 75, 'lr': 0.002728372051576313, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 03:58:55,427] Trial 48 finished with value: 0.807838093948829 and parameters: {'n_steps': 1, 'out_feats': 76, 'lr': 0.00042474540323356175, 'batch_size': 32}. Best is trial 24 with value: 0.8654392852308109.


[I 2024-03-05 04:02:04,244] Trial 49 finished with value: 0.796196020035197 and parameters: {'n_steps': 5, 'out_feats': 89, 'lr': 0.004088153177600271, 'batch_size': 64}. Best is trial 24 with value: 0.8654392852308109.


In [17]:

# Output best trial information
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(" Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
    

Best trial:
 Value: 0.8654392852308109
 Params: 
    n_steps: 5
    out_feats: 82
    lr: 0.0003388786595148537
    batch_size: 16


In [17]:
# Best trial:
#  Value: 0.8654392852308109
#  Params: 
#     n_steps: 5
#     out_feats: 82
#     lr: 0.0003388786595148537
#     batch_size: 16

In [19]:
# Assuming these are your best hyperparameters from Optuna
best_hyperparams = {
    'n_steps': 5,
    'out_feats': 82,
    'lr': 0.0003388786595148537,
    'batch_size': 16
}

# Reinitialize the model with the best hyperparameters
model = GraphClsGGNN(
    annotation_size=74, # Assuming 'h' feature size is 74
    out_feats=best_hyperparams['out_feats'],
    n_steps=best_hyperparams['n_steps'],
    n_etypes=1, # Adjust this based on your dataset
    num_cls=2
).to(device)

optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
criterion = nn.CrossEntropyLoss()

# Create your DataLoaders
train_loader = GraphDataLoader(list(zip(train_graphs, torch.tensor(train_labels))), batch_size=best_hyperparams['batch_size'], shuffle=True, collate_fn=collate)
test_loader = GraphDataLoader(list(zip(test_graphs, torch.tensor(test_labels))), batch_size=best_hyperparams['batch_size'], shuffle=False, collate_fn=collate)

# Retrain the model
num_epochs = 300 # Adjust based on your needs
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batched_graph, labels in train_loader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(batched_graph, batched_graph.ndata['h'].float().to(device))
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')




Epoch 1/30, Loss: 0.5194
Epoch 2/30, Loss: 0.4417
Epoch 3/30, Loss: 0.4040
Epoch 4/30, Loss: 0.3792
Epoch 5/30, Loss: 0.3610
Epoch 6/30, Loss: 0.3485
Epoch 7/30, Loss: 0.3387
Epoch 8/30, Loss: 0.3301
Epoch 9/30, Loss: 0.3231
Epoch 10/30, Loss: 0.3182
Epoch 11/30, Loss: 0.3115
Epoch 12/30, Loss: 0.3082
Epoch 13/30, Loss: 0.3047
Epoch 14/30, Loss: 0.2998
Epoch 15/30, Loss: 0.2960
Epoch 16/30, Loss: 0.2955
Epoch 17/30, Loss: 0.2937
Epoch 18/30, Loss: 0.2883
Epoch 19/30, Loss: 0.2870
Epoch 20/30, Loss: 0.2842
Epoch 21/30, Loss: 0.2824
Epoch 22/30, Loss: 0.2790
Epoch 23/30, Loss: 0.2770
Epoch 24/30, Loss: 0.2767
Epoch 25/30, Loss: 0.2751
Epoch 26/30, Loss: 0.2723
Epoch 27/30, Loss: 0.2733
Epoch 28/30, Loss: 0.2711
Epoch 29/30, Loss: 0.2706
Epoch 30/30, Loss: 0.2678


In [22]:
# save the model
torch.save(model.state_dict(), "data_mvi/ggnn_model.pth")

In [20]:

# Evaluate the model on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batched_graph, labels in test_loader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        logits = model(batched_graph, batched_graph.ndata['h'].float().to(device))
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()



In [21]:
test_accuracy = correct / total
print(f'Test accuracy: {test_accuracy:.4f}')

Test accuracy: 0.8685


In [None]:
# as % 