<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [19]</a>'.</span>

# Load Libraries

In [123]:
# Standard libraries
import os
import json
import pickle
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') # Known issue with PyTorch and DGL
import IPython

# Data handling
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform

# Machine learning and model evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve, auc, confusion_matrix, classification_report
from sklearn.cluster import HDBSCAN

# Neural Networks and Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import sigmoid
import torch.optim as optim

# Optuna for hyperparameter optimization
import optuna
from optuna.pruners import MedianPruner, HyperbandPruner

# Graph Neural Networks
import dgl
import dgl.nn as dglnn
from dgl import batch
from dgl.data.utils import save_graphs, load_graphs
from dgl.nn import GATConv, GATv2Conv
from dgl import max_nodes

# Cheminformatics
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from chembl_structure_pipeline import standardizer

# Network analysis
import networkx as nx

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns



check if GPU is aviable and the device

In [118]:
import torch
print("PyTorch version:", torch.__version__)
print("Is CUDA Supported?", torch.cuda.is_available())

PyTorch version: 2.1.2
Is CUDA Supported? True


In [119]:
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0)

(True, 1, 'Tesla T4')

# Data loading

In [120]:
def load_and_process_dataset(file_path, source_name):
    # Load dataset
    dataset = pd.read_json(file_path)
    
    # Reset index
    dataset.reset_index(drop=True, inplace=True)
    
    # Drop unnecessary columns, keeping only 'SMILES' and 'source'
    dataset = dataset[['SMILES', 'source']] if 'SMILES' in dataset.columns else dataset[['smiles', 'source']]
    
    # Rename 'SMILES' column to 'smiles'
    dataset.rename(columns={'SMILES': 'smiles'}, inplace=True)
    
    # Assign source name if missing
    if 'source' not in dataset.columns:
        dataset['source'] = source_name
    
    # Add 'binds_to_rna' column based on the source
    dataset['binds_to_rna'] = 0 if source_name == 'enmine_protein' else 1
    
    return dataset

# Define file paths and source names
datasets_info = [
    ('data_mvi/chemdiv_df.json', 'chemdiv'),
    ('data_mvi/enamine_df.json', 'enamine'),
    ('data_mvi/picked_molecules.json', 'enmine_protein'),
    ('data_mvi/life_chemicals_df.json', 'life_chemicals'),
    ('data_mvi/robin_df.json', 'robin')
]

# Load, process, and combine datasets
combined_df = pd.concat([load_and_process_dataset(file_path, source) for file_path, source in datasets_info], ignore_index=True)

combined_df.head()


Unnamed: 0,smiles,source,binds_to_rna
0,O=C(Nc1ccc2ccccc2c1)c1ccc2c(c1)C(=O)N(c1cccc(N...,chemdiv,1
1,O=C(CSc1nnc(-c2ccccc2Cl)n1-c1ccccc1)c1ccc2c(c1...,chemdiv,1
2,Cc1ccc(-n2c(=O)c3c4c(sc3n3c(SCC(=O)c5ccccc5)nn...,chemdiv,1
3,O=C(Nc1ccc(C(=O)c2ccccc2)cc1)c1ccc(Oc2ccc(C(=O...,chemdiv,1
4,O=C(Nc1ccc(Oc2cccc(Oc3ccc(NC(=O)c4ccccc4Cl)cc3...,chemdiv,1


In [121]:
combined_df.shape

(73868, 3)

In [None]:
from rdkit import Chem
import pandas as pd
# Assuming 'standardizer' is an instance of a class with methods 'standardize_mol' and 'get_parent_mol'

def remove_explicit_salts(smiles_list):
    salt_patterns = [
        ".[O-][Cl+3]([O-])([O-])[O-]",
        "[O-][Cl+3]([O-])([O-])[O-].",
        ".[O-][Cl+3](O)(O)O"
    ]
    for salt_pattern in salt_patterns:
        smiles_list = [smiles.replace(salt_pattern, "") for smiles in smiles_list]
    return smiles_list

def process_molecule(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    num_components = len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False))
    if num_components > 1:
        std_mol = standardizer.standardize_mol(mol)
        parent_mol, _ = standardizer.get_parent_mol(std_mol)
        if parent_mol is None or parent_mol.GetNumAtoms() == 0:
            return None
        return Chem.MolToSmiles(parent_mol)
    return Chem.MolToSmiles(mol)

def is_disconnected(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False)) > 1


# Then, process each molecule for standardization and further sanitization
for idx, row in combined_df.iterrows():
    sanitized_smiles = process_molecule(row['smiles'])
    if sanitized_smiles is not None:
        combined_df.at[idx, 'smiles'] = sanitized_smiles

combined_df['smiles'] = remove_explicit_salts(combined_df['smiles'].tolist())


# Finally, apply the check for disconnected molecules
combined_df['is_disconnected'] = combined_df['smiles'].apply(is_disconnected)

# Filtering or analyzing disconnected molecules
disconnected_mols_df = combined_df[combined_df['is_disconnected']]

print(f"Total molecules: {len(combined_df)}")
print(f"Disconnected molecules after sanitization: {len(disconnected_mols_df)}")


In [None]:
print(disconnected_mols_df['smiles'].values)

[]


In [None]:
# dimensions of the dataset
combined_df.shape

(73868, 4)

In [None]:
combined_df.head(1)


Unnamed: 0,smiles,source,binds_to_rna,is_disconnected
0,O=C(Nc1ccc2ccccc2c1)c1ccc2c(c1)C(=O)N(c1cccc(N...,chemdiv,1,False


In [None]:
# Your combined_df
# combined_df = pd.read_csv("your_dataset.csv")

# Initialize featurizers
atom_featurizer = CanonicalAtomFeaturizer()
bond_featurizer = CanonicalBondFeaturizer(self_loop=True)

# Function to convert SMILES to a graph
def smiles_to_graph(smiles_string):
    graph_constructor = SMILESToBigraph(add_self_loop=True,  # Add self loops
                                        node_featurizer=atom_featurizer,
                                        edge_featurizer=bond_featurizer)
    return graph_constructor(smiles_string)

# Convert SMILES to graphs
combined_df['graph'] = combined_df['smiles'].apply(smiles_to_graph)

Unnamed: 0,smiles,source,binds_to_rna,is_disconnected,graph
0,O=C(Nc1ccc2ccccc2c1)c1ccc2c(c1)C(=O)N(c1cccc(N...,chemdiv,1,False,"Graph(num_nodes=54, num_edges=178,\n ndat..."
1,O=C(CSc1nnc(-c2ccccc2Cl)n1-c1ccccc1)c1ccc2c(c1...,chemdiv,1,False,"Graph(num_nodes=35, num_edges=115,\n ndat..."


In [None]:
combined_df['graph']


0        Graph(num_nodes=54, num_edges=178,\n      ndat...
1        Graph(num_nodes=35, num_edges=115,\n      ndat...
2        Graph(num_nodes=35, num_edges=115,\n      ndat...
3        Graph(num_nodes=47, num_edges=151,\n      ndat...
4        Graph(num_nodes=40, num_edges=128,\n      ndat...
                               ...                        
73863    Graph(num_nodes=34, num_edges=108,\n      ndat...
73864    Graph(num_nodes=25, num_edges=79,\n      ndata...
73865    Graph(num_nodes=22, num_edges=72,\n      ndata...
73866    Graph(num_nodes=48, num_edges=146,\n      ndat...
73867    Graph(num_nodes=24, num_edges=76,\n      ndata...
Name: graph, Length: 73868, dtype: object

In [None]:
import dgl

# Assuming 'binds_to_rna' is your label
graphs = combined_df['graph'].tolist()
labels = {'binds_to_rna': torch.tensor(combined_df['binds_to_rna'].values)}

# Save graphs and labels
dgl.save_graphs("data_mvi/graphs.bin", graphs, labels)



In [None]:
# Drop the 'graph' column and save the DataFrame
combined_df.drop(columns=['graph']).to_csv("data_mvi/combined_df.csv", index=False)


In [None]:
combined_df

Unnamed: 0,smiles,source,binds_to_rna,is_disconnected,graph
0,O=C(Nc1ccc2ccccc2c1)c1ccc2c(c1)C(=O)N(c1cccc(N...,chemdiv,1,False,"Graph(num_nodes=54, num_edges=178,\n ndat..."
1,O=C(CSc1nnc(-c2ccccc2Cl)n1-c1ccccc1)c1ccc2c(c1...,chemdiv,1,False,"Graph(num_nodes=35, num_edges=115,\n ndat..."
2,Cc1ccc(-n2c(=O)c3c4c(sc3n3c(SCC(=O)c5ccccc5)nn...,chemdiv,1,False,"Graph(num_nodes=35, num_edges=115,\n ndat..."
3,O=C(Nc1ccc(C(=O)c2ccccc2)cc1)c1ccc(Oc2ccc(C(=O...,chemdiv,1,False,"Graph(num_nodes=47, num_edges=151,\n ndat..."
4,O=C(Nc1ccc(Oc2cccc(Oc3ccc(NC(=O)c4ccccc4Cl)cc3...,chemdiv,1,False,"Graph(num_nodes=40, num_edges=128,\n ndat..."
...,...,...,...,...,...
73863,C=CC(=O)Nc1cccc(Nc2nc(N[C@H]3CC[C@H](N(C)C)CC3...,robin,1,False,"Graph(num_nodes=34, num_edges=108,\n ndat..."
73864,N#C/C(C(=O)c1ccc(Cl)cc1Cl)=C1\NC(=O)c2ccc(Cl)c...,robin,1,False,"Graph(num_nodes=25, num_edges=79,\n ndata..."
73865,C[C@H](N[C@H]1C[C@H]1c1ccccc1)c1ccc2c(c1)OCCO2,robin,1,False,"Graph(num_nodes=22, num_edges=72,\n ndata..."
73866,NCCC[C@@H](N)CC(=O)N[C@H]1CNC(=O)[C@@H]([C@@H]...,robin,1,False,"Graph(num_nodes=48, num_edges=146,\n ndat..."


In [124]:
# Load the DataFrame
reloaded_df = pd.read_csv("data_mvi/combined_df.csv")

# Load the graphs
graphs, labels = dgl.load_graphs("data_mvi/graphs.bin")

# Labels are returned as a dictionary, convert to the desired format if necessary
binds_to_rna = labels['binds_to_rna']


In [125]:
reloaded_df

Unnamed: 0,smiles,source,binds_to_rna,is_disconnected
0,O=C(Nc1ccc2ccccc2c1)c1ccc2c(c1)C(=O)N(c1cccc(N...,chemdiv,1,False
1,O=C(CSc1nnc(-c2ccccc2Cl)n1-c1ccccc1)c1ccc2c(c1...,chemdiv,1,False
2,Cc1ccc(-n2c(=O)c3c4c(sc3n3c(SCC(=O)c5ccccc5)nn...,chemdiv,1,False
3,O=C(Nc1ccc(C(=O)c2ccccc2)cc1)c1ccc(Oc2ccc(C(=O...,chemdiv,1,False
4,O=C(Nc1ccc(Oc2cccc(Oc3ccc(NC(=O)c4ccccc4Cl)cc3...,chemdiv,1,False
...,...,...,...,...
73863,C=CC(=O)Nc1cccc(Nc2nc(N[C@H]3CC[C@H](N(C)C)CC3...,robin,1,False
73864,N#C/C(C(=O)c1ccc(Cl)cc1Cl)=C1\NC(=O)c2ccc(Cl)c...,robin,1,False
73865,C[C@H](N[C@H]1C[C@H]1c1ccccc1)c1ccc2c(c1)OCCO2,robin,1,False
73866,NCCC[C@@H](N)CC(=O)N[C@H]1CNC(=O)[C@@H]([C@@H]...,robin,1,False


In [126]:
from sklearn.model_selection import train_test_split

# Assuming 'binds_to_rna' is your target variable for stratification
labels = reloaded_df['binds_to_rna'].values

# Generate a list of indices from the DataFrame
indices = range(len(reloaded_df))

# Perform stratified split
train_indices, test_indices, train_labels, test_labels = train_test_split(
    indices, labels, test_size=0.2, stratify=labels, random_state=42)

# Now you can use train_indices and test_indices to split your graphs and labels
train_graphs = [graphs[i] for i in train_indices]
test_graphs = [graphs[i] for i in test_indices]

# If you saved labels as a tensor, you can also split it according to indices
train_labels = labels[train_indices]
test_labels = labels[test_indices]


In [127]:
train_graphs

[Graph(num_nodes=18, num_edges=56,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=20, num_edges=62,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=25, num_edges=77,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=35, num_edges=109,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=18, num_edges=56,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=22, num_edges=70,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes

# Graph Attention Networks (GAT)

In [217]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [218]:
# train_dataset = list(zip(train_graphs, train_labels))
# test_dataset = list(zip(test_graphs, test_labels))

In [219]:
class GATv2Model(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout):
        super(GATv2Model, self).__init__()
        self.conv1 = dglnn.GATv2Conv(in_dim, hidden_dim, num_heads=num_heads)
        self.conv2 = dglnn.GATv2Conv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads)
        self.conv3 = dglnn.GATv2Conv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads)
        
        # Define the gate_nn for GlobalAttentionPooling
        self.gate_nn = nn.Linear(hidden_dim * num_heads, 1)
        self.pool = GlobalAttentionPooling(self.gate_nn)
        
        self.classifier = nn.Linear(hidden_dim * num_heads, out_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, g, h):
        h = self.dropout(h)
        h = self.conv1(g, h)
        h = F.elu(h)
        h = self.dropout(h)
        
        h = self.conv2(g, h)
        h = F.elu(h)
        h = self.dropout(h)
        
        h = self.conv3(g, h)
        h = F.elu(h)

        # Apply global attention pooling to get graph-level representation
        hg = self.pool(g, h)
        
        h = self.classifier(hg)
        
        return h


In [220]:
# from dgl.nn.pytorch.glob import GlobalAttentionPooling

# class GATv2Model(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout):
#         super(GATv2Model, self).__init__()
#         self.conv1 = dglnn.GATv2Conv(in_dim, hidden_dim, num_heads=num_heads, allow_zero_in_degree=True)
#         self.conv2 = dglnn.GATv2Conv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads, allow_zero_in_degree=True)
#         self.conv3 = dglnn.GATv2Conv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads, allow_zero_in_degree=True)
        
#         # Define global attention pooling layer
#         self.attention_pool = GlobalAttentionPooling(gate_nn=nn.Linear(hidden_dim * num_heads, 1))
        
#         self.dropout = nn.Dropout(dropout)
#         self.classifier = nn.Linear(hidden_dim * num_heads, out_dim)

#     def forward(self, g, h):
#         h = self.dropout(h)
#         h, attn_weights1 = self.conv1(g, h)
#         h = self.dropout(h)
#         h, attn_weights2 = self.conv2(g, h)
#         h = self.dropout(h)
#         h, attn_weights3 = self.conv3(g, h)

#         # Apply global attention pooling to aggregate node features
#         hg = self.attention_pool(g, h)

#         # Classifier
#         hg = self.dropout(hg)
#         out = self.classifier(hg)
#         return out, attn_weights1, attn_weights2, attn_weights3



In [221]:
from torch.utils.data import DataLoader

class GraphDataset(Dataset):
    def __init__(self, graphs, labels):
        self.graphs = graphs
        self.labels = labels

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        label = self.labels[idx]
        return graph, label




In [222]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    
    # Check if labels are already tensors (which might be the case if __getitem__ converts them)
    # If labels are not tensors, convert them to a tensor
    if not all(isinstance(label, torch.Tensor) for label in labels):
        labels = torch.tensor(labels, dtype=torch.long)
    else:
        # If labels are already tensors, stack them into a single tensor
        labels = torch.stack(labels, dim=0)
    
    return batched_graph, labels


In [223]:
batch_size = 64  # Adjust batch size as necessary

train_dataset = GraphDataset(train_graphs, train_labels)
test_dataset = GraphDataset(test_graphs, test_labels)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=4)



In [224]:
for batched_graphs, labels in train_loader:
    print(type(batched_graphs))
    # To get the number of graphs in the batch, you can check the length of the batch_num_nodes list
    print(batched_graphs.batch_num_nodes().shape[0])
    break  # Just to check the first batch


<class 'dgl.heterograph.DGLGraph'>
64


In [225]:
def check_graph_features(graphs):
    for i, g in enumerate(graphs):
        if 'h' not in g.ndata:
            return False, f"Graph {i} does not have 'h' node feature."
    return True, "All graphs have 'h' node feature."

# Assuming 'graphs' is a list of your DGL Graphs
check_result, message = check_graph_features(train_graphs)
print(message)


All graphs have 'h' node feature.


In [226]:
# def objective(trial):
#     # Hyperparameters
#     in_dim = 74  # Confirm this matches your dataset
#     hidden_dim = trial.suggest_int('hidden_dim', 4, 128)
#     out_dim = 2  # For binary classification
#     num_heads = trial.suggest_int('num_heads', 1, 15)
#     dropout = trial.suggest_float('dropout', 0.1, 0.5)
#     lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
#     weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-1, log=True)

#     model = GATv2Model(in_dim, hidden_dim, out_dim, num_heads, dropout).to(device)
#     criterion = torch.nn.BCEWithLogitsLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

#     # Debug: Print model summary (optional, remove in production)
#     print(model)

#     for epoch in range(3):  # Reduced epoch for quick debugging
#         model.train()
#         for batched_graphs, labels in train_loader:
#             batched_graphs = batched_graphs.to(device)
#             labels = labels.to(device)

#             # Ensure 'h' node feature exists and has correct dimension
#             if 'h' in batched_graphs.ndata:
#                 h = batched_graphs.ndata['h'].to(device)
#                 if h.shape[1] != in_dim:  # Debugging check
#                     raise ValueError(f"Feature size mismatch. Expected {in_dim}, got {h.shape[1]}")
#             else:
#                 raise ValueError("Graphs do not have 'h' node feature. Please check data preprocessing.")

#             optimizer.zero_grad()
#             out, attn_weights1, attn_weights2, attn_weights3 = model(batched_graphs, h)  # change back to model(batched_graphs, h)          
#             loss = criterion(out, labels)
#             loss.backward()
#             optimizer.step()

#             # Debug: Print shapes for troubleshooting (optional, remove in production)
#             print(f"Output shape: {out.shape}, Labels shape: {labels.shape}")

#     # Simplified return for debugging
#     return 0  # Placeholder return for debugging purposes

# # Assuming train_loader is defined somewhere else
# # Make sure to adjust your DataLoader to use the correct collate function


In [232]:
def objective(trial):
    # Hyperparameters
    in_dim = 74  # Adjust based on your dataset
    hidden_dim = trial.suggest_int('hidden_dim', 4, 128)
    out_dim = 2  # For binary classification
    num_heads = trial.suggest_int('num_heads', 1, 15)
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-1, log=True)
    patience = 20  # Early stopping patience

    # Model, loss, and optimizer
    model = GATv2Model(in_dim, hidden_dim, out_dim, num_heads, dropout).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Metrics
    best_accuracy = 0
    epochs_no_improve = 0

    # Training loop
    for epoch in range(500):  # Adjust the number of epochs
        model.train()
        train_loss = 0
        for batched_graphs, labels in train_loader:
            batched_graphs = batched_graphs.to(device)
            if 'h' in batched_graphs.ndata:
                h = batched_graphs.ndata['h'].to(device)
            else:
                raise ValueError("Graphs do not have 'h' node feature. Please check data preprocessing.")
            labels = labels.to(device)

            optimizer.zero_grad()
            out = model(batched_graphs, h)  # Adjusted to receive a single tensor
            loss = criterion(out.squeeze(), labels.float())  # Ensure labels are float for BCEWithLogitsLoss
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0  # To calculate accuracy correctly
        with torch.no_grad():
            for batched_graphs, labels in test_loader:
                batched_graphs = batched_graphs.to(device)
                h = batched_graphs.ndata['h'].to(device)
                labels = labels.to(device)
                out = model(batched_graphs, h)  # Adjusted to receive a single tensor
                val_loss += criterion(out, labels.float()).item()  # Ensure labels are float
                predicted = torch.round(torch.sigmoid(out))  # Use sigmoid and round for binary classification
                correct += (predicted == labels.unsqueeze(1)).sum().item()
                total += labels.size(0)

        accuracy = correct / total

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

        # Optional: Reduce logging frequency
        if epoch % 10 == 0:
            print(f'Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.4f}')

    return best_accuracy


In [233]:
# def train(model, optimizer, criterion, train_loader, device):
#     model.train()
#     total_loss = 0
#     for batched_graphs, labels in train_loader:
#         batched_graphs.to(device)
#         h = batched_graphs.ndata['h'].to(device)
#         labels = labels.to(device)
        
#         optimizer.zero_grad()
#         out, _, _, _ = model(batched_graphs, h)
#         loss = criterion(out, labels)
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
#     return total_loss / len(train_loader)


In [234]:
# def evaluate(model, criterion, test_loader, device):
#     model.eval()
#     total_loss = 0
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for batched_graphs, labels in test_loader:
#             batched_graphs.to(device)
#             h = batched_graphs.ndata['h'].to(device)
#             labels = labels.to(device)
            
#             out, _, _, _ = model(batched_graphs, h)
#             loss = criterion(out, labels)
#             total_loss += loss.item()
#             _, predicted = torch.max(out, 1)
#             correct += (predicted == labels).sum().item()
#             total += labels.size(0)
#     accuracy = correct / total
#     return accuracy, total_loss / len(test_loader)


In [235]:
# def objective(trial):
#     # Hyperparameters
#     in_dim = 74  # Adjust based on your dataset
#     hidden_dim = trial.suggest_int('hidden_dim', 4, 128)
#     out_dim = 2  # For binary classification
#     num_heads = trial.suggest_int('num_heads', 1, 15)
#     dropout = trial.suggest_float('dropout', 0.1, 0.5)
#     lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
#     weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-1, log=True)
#     patience = 20
#     best_accuracy = 0
#     epochs_no_improve = 0
    
#     for epoch in range(500):
#         train_loss = train(model, optimizer, criterion, train_loader, device)
#         accuracy, val_loss = evaluate(model, criterion, test_loader, device)
        
#         if accuracy > best_accuracy:
#             best_accuracy = accuracy
#             epochs_no_improve = 0
#         else:
#             epochs_no_improve += 1

#         if epochs_no_improve == patience:
#             print(f"Early stopping triggered after {epoch + 1} epochs.")
#             break

#         if epoch % 10 == 0:  # Logging
#             print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {accuracy:.4f}')
    
#     return best_accuracy


<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [236]:

# Optimize hyperparameters
criterion = torch.nn.CrossEntropyLoss()
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50, n_jobs=1)


[I 2024-03-04 17:12:40,698] A new study created in memory with name: no-name-861c0c8c-5fe9-4f7b-ad03-b256e7e9bc2a


[W 2024-03-04 17:12:41,274] Trial 0 failed with parameters: {'hidden_dim': 122, 'num_heads': 5, 'dropout': 0.3715454387122761, 'lr': 0.0004424168105456859, 'weight_decay': 0.0023873636666766557} because of the following error: RuntimeError('mat1 and mat2 shapes cannot be multiplied (8335x122 and 610x610)').
Traceback (most recent call last):
  File "/home/xfulop/miniconda3/envs/gnn/lib/python3.8/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_28386/406744428.py", line 34, in objective
    out = model(batched_graphs, h)  # Adjusted to receive a single tensor
  File "/home/xfulop/miniconda3/envs/gnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xfulop/miniconda3/envs/gnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8335x122 and 610x610)

In [None]:
# Ensure consistent style
sns.set_style("whitegrid")

plt.figure(figsize=(14, 5))

# Plot training and validation loss
plt.subplot(1, 2, 1)
sns.lineplot(range(len(train_losses)), train_losses, label='Training Loss')
sns.lineplot(range(len(val_losses)), val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

# Plot accuracy
plt.subplot(1, 2, 2)
sns.lineplot(range(len(accuracies)), accuracies, label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

plt.show()

In [None]:
# Assuming the best parameters are stored in study.best_params
best_params = study.best_params
best_params


In [None]:
import numpy as np

train_losses = []
test_losses = []
test_accuracies = []

# Adjusted training loop to track training loss
for epoch in range(500):  # Adjust the number of epochs as needed
    model.train()
    epoch_train_loss = 0
    for batched_graph, labels in train_loader:  # Corrected variable names for clarity
        batched_graph = batched_graph.to(device)
        labels = labels.to(device).float()  # Ensure labels are floats
        optimizer.zero_grad()
        out, _, _, _ = model(batched_graph, batched_graph.ndata['h'])
        loss = criterion(out, labels.unsqueeze(1))  # Adjust label shape if necessary
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item() * labels.size(0)  # Accumulate loss correctly
    epoch_train_loss /= len(train_loader.dataset)  # Calculate mean loss for the epoch
    train_losses.append(epoch_train_loss)

    # Evaluation loop
    model.eval()
    epoch_test_loss = 0
    correct = 0
    with torch.no_grad():
        for batched_graph, labels in test_loader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device).float()
            out, _, _, _ = model(batched_graph, batched_graph.ndata['h'])
            loss = criterion(out, labels.unsqueeze(1))
            epoch_test_loss += loss.item() * labels.size(0)  # Accumulate loss correctly
            predictions = torch.sigmoid(out) >= 0.5  # Convert to binary predictions
            correct += predictions.eq(labels.unsqueeze(1)).sum().item()  # Calculate correct predictions
    epoch_test_loss /= len(test_loader.dataset)
    test_losses.append(epoch_test_loss)
    test_accuracy = correct / len(test_loader.dataset)  # Calculate accuracy
    test_accuracies.append(test_accuracy)

    print(f"Epoch {epoch+1}: Train Loss = {epoch_train_loss:.4f}, Test Loss = {epoch_test_loss:.4f}, Test Accuracy = {test_accuracy:.4f}")


In [None]:
# save the model
model_path = 'model.pth'
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

In [None]:
sns.set(style="whitegrid")

plt.figure(figsize=(16, 6))

plt.subplot(1, 2, 1)
sns.lineplot(x=np.arange(1, 501), y=train_losses, label='Train Loss')
sns.lineplot(x=np.arange(1, 501), y=test_losses, label='Test Loss')
plt.title('Train vs. Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1, 2, 2)
sns.lineplot(x=np.arange(1, 501), y=test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

plt.show()


In [None]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_trainable_parameters(model)
print(f"Total Trainable Parameters: {total_params}")


In [None]:
sns.set(style="whitegrid")

plt.figure(figsize=(16, 6))

plt.subplot(1, 2, 1)
sns.lineplot(x=np.arange(1, len(train_losses) + 1), y=train_losses, label='Train Loss')
sns.lineplot(x=np.arange(1, len(test_losses) + 1), y=test_losses, label='Test Loss')
plt.title(f'Train vs. Test Loss (Total Parameters: {total_params})')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1, 2, 2)
sns.lineplot(x=np.arange(1, len(test_accuracies) + 1), y=test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

plt.tight_layout()
plt.show()


# Other

In [None]:
# Assuming `model` is your trained model, `test_graphs` is your list of test graphs, and `test_labels` is your list of test labels
accuracy, attn_weights1, attn_weights2, attn_weights3 = evaluate_model(model, test_graphs, test_labels)

print(f"Test Accuracy: {accuracy}")

# Example of analyzing or visualizing the attention weights
# This is a simple example that prints the shape of the attention weights
print(f"Attention Weights Shape (Layer 1): {attn_weights1.shape}")
print(f"Attention Weights Shape (Layer 2): {attn_weights2.shape}")
print(f"Attention Weights Shape (Layer 3): {attn_weights3.shape}")

In [None]:

# Create a DataFrame from the metrics dictionary
metrics_df = pd.DataFrame([metrics])

# Print the metrics in table format
print(metrics_df)

# Plot the metrics using Seaborn
plt.figure(figsize=(10, 5))
sns.barplot(x=metrics_df.columns, y=metrics_df.iloc[0], data=metrics_df)
plt.title('Model Metrics')
plt.ylabel('Value')
plt.show()


