# Load Libraries

In [1]:
# Standard libraries
import os
import json
import pickle
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') # Known issue with PyTorch and DGL
import IPython

# Data handling
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform

# Machine learning and model evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve, auc, confusion_matrix, classification_report
from sklearn.cluster import HDBSCAN

# Neural Networks and Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import sigmoid
import torch.optim as optim

# Optuna for hyperparameter optimization
import optuna
from optuna.pruners import MedianPruner, HyperbandPruner

# Graph Neural Networks
import dgl
import dgl.nn as dglnn
from dgl import batch
from dgl.data.utils import save_graphs, load_graphs
from dgl.nn import GATConv, GATv2Conv
from dgl import max_nodes

# Cheminformatics
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from chembl_structure_pipeline import standardizer

# Network analysis
import networkx as nx

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns



check if GPU is aviable and the device

In [2]:
import torch
print("PyTorch version:", torch.__version__)
print("Is CUDA Supported?", torch.cuda.is_available())

PyTorch version: 2.1.2
Is CUDA Supported? True


In [3]:
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0)

(True, 1, 'Tesla T4')

# Data loading

In [4]:
def load_and_process_dataset(file_path, source_name):
    # Load dataset
    dataset = pd.read_json(file_path)
    
    # Reset index
    dataset.reset_index(drop=True, inplace=True)
    
    # Drop unnecessary columns, keeping only 'SMILES' and 'source'
    dataset = dataset[['SMILES', 'source']] if 'SMILES' in dataset.columns else dataset[['smiles', 'source']]
    
    # Rename 'SMILES' column to 'smiles'
    dataset.rename(columns={'SMILES': 'smiles'}, inplace=True)
    
    # Assign source name if missing
    if 'source' not in dataset.columns:
        dataset['source'] = source_name
    
    # Add 'binds_to_rna' column based on the source
    dataset['binds_to_rna'] = 0 if source_name == 'enmine_protein' else 1
    
    return dataset

# Define file paths and source names
datasets_info = [
    ('data_mvi/chemdiv_df.json', 'chemdiv'),
    ('data_mvi/enamine_df.json', 'enamine'),
    ('data_mvi/picked_molecules.json', 'enmine_protein'),
    ('data_mvi/life_chemicals_df.json', 'life_chemicals'),
    ('data_mvi/robin_df.json', 'robin')
]

# Load, process, and combine datasets
combined_df = pd.concat([load_and_process_dataset(file_path, source) for file_path, source in datasets_info], ignore_index=True)

combined_df.head()


In [5]:
from rdkit import Chem
import pandas as pd
# Assuming 'standardizer' is an instance of a class with methods 'standardize_mol' and 'get_parent_mol'

def remove_explicit_salts(smiles_list):
    salt_patterns = [
        ".[O-][Cl+3]([O-])([O-])[O-]",
        "[O-][Cl+3]([O-])([O-])[O-].",
        ".[O-][Cl+3](O)(O)O"
    ]
    for salt_pattern in salt_patterns:
        smiles_list = [smiles.replace(salt_pattern, "") for smiles in smiles_list]
    return smiles_list

def process_molecule(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    num_components = len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False))
    if num_components > 1:
        std_mol = standardizer.standardize_mol(mol)
        parent_mol, _ = standardizer.get_parent_mol(std_mol)
        if parent_mol is None or parent_mol.GetNumAtoms() == 0:
            return None
        return Chem.MolToSmiles(parent_mol)
    return Chem.MolToSmiles(mol)

def is_disconnected(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False)) > 1


# Then, process each molecule for standardization and further sanitization
for idx, row in combined_df.iterrows():
    sanitized_smiles = process_molecule(row['smiles'])
    if sanitized_smiles is not None:
        combined_df.at[idx, 'smiles'] = sanitized_smiles

combined_df['smiles'] = remove_explicit_salts(combined_df['smiles'].tolist())


# Finally, apply the check for disconnected molecules
combined_df['is_disconnected'] = combined_df['smiles'].apply(is_disconnected)

# Filtering or analyzing disconnected molecules
disconnected_mols_df = combined_df[combined_df['is_disconnected']]

print(f"Total molecules: {len(combined_df)}")
print(f"Disconnected molecules after sanitization: {len(disconnected_mols_df)}")


In [6]:
# Your combined_df
# combined_df = pd.read_csv("your_dataset.csv")

# Initialize featurizers
atom_featurizer = CanonicalAtomFeaturizer()
bond_featurizer = CanonicalBondFeaturizer(self_loop=True)

# Function to convert SMILES to a graph
def smiles_to_graph(smiles_string):
    graph_constructor = SMILESToBigraph(add_self_loop=True,  # Add self loops
                                        node_featurizer=atom_featurizer,
                                        edge_featurizer=bond_featurizer)
    return graph_constructor(smiles_string)

# Convert SMILES to graphs
combined_df['graph'] = combined_df['smiles'].apply(smiles_to_graph)

In [7]:
import dgl

# Assuming 'binds_to_rna' is your label
graphs = combined_df['graph'].tolist()
labels = {'binds_to_rna': torch.tensor(combined_df['binds_to_rna'].values)}

# Save graphs and labels
dgl.save_graphs("data_mvi/graphs.bin", graphs, labels)



In [8]:
# Drop the 'graph' column and save the DataFrame
combined_df.drop(columns=['graph']).to_csv("data_mvi/combined_df.csv", index=False)


In [9]:
# Load the DataFrame
reloaded_df = pd.read_csv("data_mvi/combined_df.csv")

# Load the graphs
graphs, labels = dgl.load_graphs("data_mvi/graphs.bin")

# Labels are returned as a dictionary, convert to the desired format if necessary
binds_to_rna = labels['binds_to_rna']


In [10]:
from sklearn.model_selection import train_test_split

# Assuming 'binds_to_rna' is your target variable for stratification
labels = reloaded_df['binds_to_rna'].values

# Generate a list of indices from the DataFrame
indices = range(len(reloaded_df))

# Perform stratified split
train_indices, test_indices, train_labels, test_labels = train_test_split(
    indices, labels, test_size=0.2, stratify=labels, random_state=42)

# Now you can use train_indices and test_indices to split your graphs and labels
train_graphs = [graphs[i] for i in train_indices]
test_graphs = [graphs[i] for i in test_indices]

# If you saved labels as a tensor, you can also split it according to indices
train_labels = labels[train_indices]
test_labels = labels[test_indices]


# Graph Attention Networks (GAT)

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [12]:
# train_dataset = list(zip(train_graphs, train_labels))
# test_dataset = list(zip(test_graphs, test_labels))

In [13]:
import torch
import torch.nn as nn
import dgl
from dgl.nn import GATv2Conv, GlobalAttentionPooling
import optuna
import torch.optim as optim
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split

class GraphClsGATv2(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_heads, num_classes):
        super(GraphClsGATv2, self).__init__()
        self.layer1 = GATv2Conv(in_feats=in_dim, out_feats=hidden_dim, num_heads=num_heads)
        self.layer2 = GATv2Conv(in_feats=hidden_dim * num_heads, out_feats=hidden_dim, num_heads=num_heads)
        self.pooling = GlobalAttentionPooling(gate_nn=nn.Linear(hidden_dim * num_heads, 1))
        self.classifier = nn.Linear(hidden_dim * num_heads, num_classes)

    def forward(self, g, h, get_attention=False):
        h = self.layer1(g, h).flatten(1)
        h = self.layer2(g, h).flatten(1)
        hg = self.pooling(g, h)
        logits = self.classifier(hg)
        if get_attention:
            return logits, self.pooling.attn
        else:
            return logits
        

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs).to(device)
    labels = torch.tensor(labels, dtype=torch.long).to(device)
    return batched_graph, labels

In [16]:
def objective(trial):
    hidden_dim = trial.suggest_int('hidden_dim', 32, 128)
    num_heads = trial.suggest_int('num_heads', 2, 8)
    lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])

    model = GraphClsGATv2(in_dim=74, hidden_dim=hidden_dim, num_heads=num_heads, num_classes=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_loader = GraphDataLoader(list(zip(train_graphs, train_labels)), batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = GraphDataLoader(list(zip(test_graphs, test_labels)), batch_size=batch_size, shuffle=False, collate_fn=collate)

    for epoch in range(10):  # Adjust epochs as needed
        model.train()
        for batched_graph, labels in train_loader:
            optimizer.zero_grad()
            h = batched_graph.ndata['h'].float().to(device)  # Ensure node features are on the correct device
            logits = model(batched_graph, h)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batched_graph, labels in val_loader:
            h = batched_graph.ndata['h'].float().to(device)  # Ensure node features are on the correct device
            logits = model(batched_graph, h)
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = correct / total
    return val_accuracy



In [17]:
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

[I 2024-03-05 01:25:03,690] A new study created in memory with name: no-name-59b53552-04d3-46b7-8054-619fa057cb91


  lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)


[I 2024-03-05 01:27:25,605] Trial 0 finished with value: 0.8267903072966022 and parameters: {'hidden_dim': 113, 'num_heads': 6, 'lr': 0.00033117543647138867, 'batch_size': 64}. Best is trial 0 with value: 0.8267903072966022.


  lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)


[I 2024-03-05 01:29:30,084] Trial 1 finished with value: 0.8326113442534182 and parameters: {'hidden_dim': 88, 'num_heads': 5, 'lr': 0.00034234617923300517, 'batch_size': 64}. Best is trial 1 with value: 0.8326113442534182.


[I 2024-03-05 01:33:18,857] Trial 2 finished with value: 0.7808311899282523 and parameters: {'hidden_dim': 119, 'num_heads': 5, 'lr': 0.004847755227413553, 'batch_size': 32}. Best is trial 1 with value: 0.8326113442534182.


[I 2024-03-05 01:35:17,801] Trial 3 finished with value: 0.8098686882360905 and parameters: {'hidden_dim': 72, 'num_heads': 8, 'lr': 0.0002565654077034613, 'batch_size': 128}. Best is trial 1 with value: 0.8326113442534182.


[I 2024-03-05 01:38:47,365] Trial 4 finished with value: 0.8360633545417625 and parameters: {'hidden_dim': 36, 'num_heads': 6, 'lr': 0.0008804206473212662, 'batch_size': 32}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:40:39,447] Trial 5 finished with value: 0.8149451739542439 and parameters: {'hidden_dim': 59, 'num_heads': 3, 'lr': 0.00384418839308737, 'batch_size': 64}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:44:03,326] Trial 6 finished with value: 0.7885474482198457 and parameters: {'hidden_dim': 42, 'num_heads': 2, 'lr': 0.0003072741104865836, 'batch_size': 32}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:47:34,042] Trial 7 finished with value: 0.7650602409638554 and parameters: {'hidden_dim': 108, 'num_heads': 2, 'lr': 0.008742357975815511, 'batch_size': 32}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:49:08,274] Trial 8 finished with value: 0.806348991471504 and parameters: {'hidden_dim': 122, 'num_heads': 4, 'lr': 0.008076634335747422, 'batch_size': 128}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:50:32,102] Trial 9 finished with value: 0.8007987004196562 and parameters: {'hidden_dim': 80, 'num_heads': 5, 'lr': 0.008276939444228366, 'batch_size': 128}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:53:56,400] Trial 10 finished with value: 0.8349126844456478 and parameters: {'hidden_dim': 36, 'num_heads': 7, 'lr': 0.0011542974868178782, 'batch_size': 32}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 01:57:15,953] Trial 11 finished with value: 0.8349803709218898 and parameters: {'hidden_dim': 33, 'num_heads': 7, 'lr': 0.0011848211159779997, 'batch_size': 32}. Best is trial 4 with value: 0.8360633545417625.


[I 2024-03-05 02:00:39,315] Trial 12 finished with value: 0.8449302829294707 and parameters: {'hidden_dim': 51, 'num_heads': 7, 'lr': 0.001217529588475766, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:04:10,022] Trial 13 finished with value: 0.8426966292134831 and parameters: {'hidden_dim': 52, 'num_heads': 8, 'lr': 0.0006694829812299309, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:07:39,899] Trial 14 finished with value: 0.799106538513605 and parameters: {'hidden_dim': 56, 'num_heads': 8, 'lr': 0.00012702781891760911, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:11:03,354] Trial 15 finished with value: 0.8297685122512523 and parameters: {'hidden_dim': 53, 'num_heads': 7, 'lr': 0.001834340979179343, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:15:15,084] Trial 16 finished with value: 0.8421551374035467 and parameters: {'hidden_dim': 67, 'num_heads': 8, 'lr': 0.0005540109780779452, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:18:46,793] Trial 17 finished with value: 0.819480167862461 and parameters: {'hidden_dim': 48, 'num_heads': 6, 'lr': 0.002305628818535247, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:21:14,531] Trial 18 finished with value: 0.8301746311087045 and parameters: {'hidden_dim': 94, 'num_heads': 7, 'lr': 0.0006346900037879755, 'batch_size': 128}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:23:30,904] Trial 19 finished with value: 0.8288209015838636 and parameters: {'hidden_dim': 67, 'num_heads': 8, 'lr': 0.002008230587218132, 'batch_size': 64}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:26:53,359] Trial 20 finished with value: 0.7815757411669149 and parameters: {'hidden_dim': 47, 'num_heads': 7, 'lr': 0.00010159932067440519, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:30:33,824] Trial 21 finished with value: 0.8392446189251388 and parameters: {'hidden_dim': 68, 'num_heads': 8, 'lr': 0.0005999714201200556, 'batch_size': 32}. Best is trial 12 with value: 0.8449302829294707.


[I 2024-03-05 02:34:15,366] Trial 22 finished with value: 0.8453364017869229 and parameters: {'hidden_dim': 66, 'num_heads': 8, 'lr': 0.0005473538685539596, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:37:49,740] Trial 23 finished with value: 0.8204954650060918 and parameters: {'hidden_dim': 61, 'num_heads': 8, 'lr': 0.00017759049658412992, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:41:24,748] Trial 24 finished with value: 0.8362664139704887 and parameters: {'hidden_dim': 83, 'num_heads': 6, 'lr': 0.00045535527007741877, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:44:50,955] Trial 25 finished with value: 0.8408690943549478 and parameters: {'hidden_dim': 48, 'num_heads': 7, 'lr': 0.0008598343839747525, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:49:13,494] Trial 26 finished with value: 0.8142683091918235 and parameters: {'hidden_dim': 100, 'num_heads': 8, 'lr': 0.0014295151334628733, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:52:42,862] Trial 27 finished with value: 0.828076350345201 and parameters: {'hidden_dim': 75, 'num_heads': 4, 'lr': 0.0007666604810184992, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:54:50,682] Trial 28 finished with value: 0.8291593339650738 and parameters: {'hidden_dim': 63, 'num_heads': 7, 'lr': 0.0004397017175607994, 'batch_size': 64}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:55:58,803] Trial 29 finished with value: 0.8332205225395966 and parameters: {'hidden_dim': 43, 'num_heads': 6, 'lr': 0.003412981405373537, 'batch_size': 128}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 02:57:53,042] Trial 30 finished with value: 0.8032354135643699 and parameters: {'hidden_dim': 54, 'num_heads': 6, 'lr': 0.0002102085107057966, 'batch_size': 64}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:01:43,435] Trial 31 finished with value: 0.8205631514823338 and parameters: {'hidden_dim': 76, 'num_heads': 8, 'lr': 0.0005003644068596925, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:05:17,005] Trial 32 finished with value: 0.8408014078787058 and parameters: {'hidden_dim': 67, 'num_heads': 8, 'lr': 0.0006220289604353612, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:09:00,238] Trial 33 finished with value: 0.8351157438743739 and parameters: {'hidden_dim': 87, 'num_heads': 8, 'lr': 0.00040061733968073203, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:12:55,224] Trial 34 finished with value: 0.8270610532015703 and parameters: {'hidden_dim': 63, 'num_heads': 7, 'lr': 0.0014924730044494393, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:16:47,737] Trial 35 finished with value: 0.8218491945309327 and parameters: {'hidden_dim': 71, 'num_heads': 8, 'lr': 0.00035247612484413476, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:18:49,909] Trial 36 finished with value: 0.8192094219574929 and parameters: {'hidden_dim': 54, 'num_heads': 8, 'lr': 0.0002695061910871944, 'batch_size': 64}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:22:10,870] Trial 37 finished with value: 0.8325436577771761 and parameters: {'hidden_dim': 41, 'num_heads': 7, 'lr': 0.0008206789820218511, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:25:35,682] Trial 38 finished with value: 0.8238797888181941 and parameters: {'hidden_dim': 57, 'num_heads': 5, 'lr': 0.0010435585973227375, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:27:04,667] Trial 39 finished with value: 0.8339650737782591 and parameters: {'hidden_dim': 50, 'num_heads': 8, 'lr': 0.0005723317030997537, 'batch_size': 128}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:30:45,606] Trial 40 finished with value: 0.8141329362393394 and parameters: {'hidden_dim': 128, 'num_heads': 4, 'lr': 0.0029310704951522036, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:34:17,425] Trial 41 finished with value: 0.838093948829024 and parameters: {'hidden_dim': 45, 'num_heads': 7, 'lr': 0.0007377622705346461, 'batch_size': 32}. Best is trial 22 with value: 0.8453364017869229.


[I 2024-03-05 03:37:49,583] Trial 42 finished with value: 0.8462163259780696 and parameters: {'hidden_dim': 39, 'num_heads': 7, 'lr': 0.0009646485849361715, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:41:14,363] Trial 43 finished with value: 0.8361310410180046 and parameters: {'hidden_dim': 32, 'num_heads': 6, 'lr': 0.001509125175923574, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:44:43,757] Trial 44 finished with value: 0.8299038852037363 and parameters: {'hidden_dim': 39, 'num_heads': 7, 'lr': 0.0010056517813625413, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:48:23,702] Trial 45 finished with value: 0.8145390550967917 and parameters: {'hidden_dim': 60, 'num_heads': 8, 'lr': 0.0003022591435156923, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:51:51,166] Trial 46 finished with value: 0.79761743603628 and parameters: {'hidden_dim': 39, 'num_heads': 7, 'lr': 0.005766239594208876, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:53:17,674] Trial 47 finished with value: 0.839583051306349 and parameters: {'hidden_dim': 52, 'num_heads': 8, 'lr': 0.001265128125866692, 'batch_size': 128}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:56:55,694] Trial 48 finished with value: 0.8356572356843103 and parameters: {'hidden_dim': 75, 'num_heads': 6, 'lr': 0.0003844844786021545, 'batch_size': 32}. Best is trial 42 with value: 0.8462163259780696.


[I 2024-03-05 03:58:45,146] Trial 49 finished with value: 0.8267226208203601 and parameters: {'hidden_dim': 36, 'num_heads': 7, 'lr': 0.0005363287198046636, 'batch_size': 64}. Best is trial 42 with value: 0.8462163259780696.


In [18]:

print("Best trial:")
trial = study.best_trial
print(f"Value: {trial.value}")
print("Params:")
for key, value in trial.params.items():
    print(f"{key}: {value}")


Best trial:
Value: 0.8462163259780696
Params:
hidden_dim: 39
num_heads: 7
lr: 0.0009646485849361715
batch_size: 32


In [None]:
# Best trial:
# Value: 0.8462163259780696
# Params:
# hidden_dim: 39
# num_heads: 7
# lr: 0.0009646485849361715
# batch_size: 32

test

In [17]:
# Assuming you've got the best hyperparameters as follows (replace these with your actual best hyperparameters)
best_hyperparams = {
    'hidden_dim': 39,  # Example hyperparameter
    'num_heads': 7,    # Example hyperparameter
    'lr': 0.0009646485849361715,  # Example hyperparameter
    'batch_size': 32   # Example hyperparameter
}

# 1. Rebuild the Model with the best hyperparameters
model = GraphClsGATv2(in_dim=74,  # Assuming in_dim is fixed to 74 as per your dataset
                      hidden_dim=best_hyperparams['hidden_dim'],
                      num_heads=best_hyperparams['num_heads'],
                      num_classes=2).to(device)  # Assuming binary classification

# 2. Retrain the Model
optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
criterion = torch.nn.CrossEntropyLoss()

# Convert train labels to a tensor and move it to the appropriate device
train_labels_tensor = torch.LongTensor(train_labels).to(device)

# Create DataLoaders for training
train_loader = GraphDataLoader(list(zip(train_graphs, train_labels_tensor)), 
                               batch_size=best_hyperparams['batch_size'], 
                               shuffle=True, 
                               collate_fn=collate)

# Training loop (simplified for demonstration)
num_epochs = 30  # You can adjust the number of epochs
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batched_graph, labels in train_loader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(batched_graph, batched_graph.ndata['h'].float())
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
    avg_loss = total_loss / num_batches
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')

In [21]:
# save the model
torch.save(model.state_dict(), "data_mvi/gat_model.pth")

In [19]:
import torch
from dgl.dataloading import GraphDataLoader

batch_size = 32  # Use the optimal batch size found by Optuna

# Convert test labels to a tensor and move it to the appropriate device
test_labels_tensor = torch.LongTensor(test_labels).to(device)

# Create a DataLoader for the test set
test_loader = GraphDataLoader(list(zip(test_graphs, test_labels_tensor)), 
                              batch_size=batch_size,  # Use the optimal batch size found by Optuna
                              shuffle=False, 
                              collate_fn=collate)

def evaluate_model(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    total_correct = 0
    total = 0
    with torch.no_grad():  # No need to track gradients for evaluation
        for batched_graph, labels in data_loader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            if 'h' in batched_graph.ndata:  # Use 'h' for node features
                features = batched_graph.ndata['h'].to(device)
            logits = model(batched_graph, features)
            predicted = torch.argmax(logits, dim=1)
            total_correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = total_correct / total
    return accuracy



In [20]:
# Evaluate the model on the test set
test_accuracy = evaluate_model(model, test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")

Test Accuracy: 0.8447


In [None]:
# Test Accuracy: 0.8447