# Imports

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils import data as torch_data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from typing import Final
import pandas as pd
import dataclasses
from etils import epath
import torch_scatter
import gc
from tqdm.notebook import tqdm
from graphmodels import featurizer
from graphmodels import constants
from graphmodels import datasets
from graphmodels import data_utils
from graphmodels import models
from graphmodels.layers import graph_attention_layers
from graphmodels.models import gat
from torch.utils.data import DataLoader, Dataset
from sklearn import model_selection as sk_modelselection
from sklearn import metrics as sk_metrics

from torch_geometric import nn as geom_nn
from torch_geometric.loader import DataLoader as GeomDataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import PPI
from torch_geometric.nn import models




from rdkit.Chem.Draw import IPythonConsole
# Ask RDKit's notebook integration to draw atom indices on rendered molecules.
IPythonConsole.drawOptions.addAtomIndices = True
# Report whether PyTorch can see a CUDA device (the saved run printed True).
print(torch.cuda.is_available())

True


In [2]:
# Directory holding the chemistry datasets, relative to this notebook.
DATAPATH: epath.Path = epath.Path("../datasets/chemistry/")

# Load the solubility table, parse every SMILES into an RDKit molecule and keep
# only molecules with at least one bond. Built as a single chain so that
# re-running the cell always starts from the CSV instead of a mutated frame.
df = (
    pd.read_csv(DATAPATH / "delaney-processed_clean.csv")
    .assign(mol=lambda frame: frame["RDKIT_SMILES"].apply(Chem.MolFromSmiles))
    # Chem.MolFromSmiles returns None for SMILES it cannot parse; drop those
    # rows so the bond count below never dereferences None.
    .dropna(subset=["mol"])
    .assign(
        num_bonds=lambda frame: frame["mol"].apply(lambda mol: mol.GetNumBonds())
    )
    # Presumably the graph models need at least one edge per molecule -- confirm.
    .loc[lambda frame: frame["num_bonds"] >= 1]
)

In [3]:
# Split data: hold out 30% of the rows for testing, then carve 15% of the
# remainder off as a validation set. Both splits shuffle with a fixed seed so
# the partition is reproducible; no stratification is applied.
train_valid_df, test_df = sk_modelselection.train_test_split(
    df,
    test_size=0.3,
    shuffle=True,
    random_state=42,
)
train_df, valid_df = sk_modelselection.train_test_split(
    train_valid_df,
    test_size=0.15,
    shuffle=True,
    random_state=42,
)

In [4]:
# Column names used to build the datasets.
# NOTE(review): molecules were parsed/filtered from "RDKIT_SMILES" in the
# loading cell, but the datasets are built from "smiles" -- confirm that both
# columns describe the same structures.
LABEL: str = "measured log solubility in mols per litre"
SMILES: str = "smiles"


def _build_dataset(frame: pd.DataFrame):
    """Wrap the SMILES/label columns of ``frame`` in an MPNNDataset (no master node)."""
    return datasets.mpnn_dataset.MPNNDataset(
        smiles=tuple(frame[SMILES]),
        targets=tuple(frame[LABEL]),
        add_master_node=False,
    )


def _build_loader(dset, shuffle: bool) -> DataLoader:
    """Batch ``dset`` in groups of 32 using the project's ``mpnn_collate_diag`` collate."""
    return DataLoader(
        dataset=dset,
        batch_size=32,
        shuffle=shuffle,
        collate_fn=data_utils.mpnn_collate_diag,
    )


train_dset = _build_dataset(train_df)
valid_dset = _build_dataset(valid_df)
test_dset = _build_dataset(test_df)

# Only the training loader reshuffles; evaluation loaders keep a fixed order.
train_dataloader = _build_loader(train_dset, shuffle=True)
valid_dataloader = _build_loader(valid_dset, shuffle=False)
test_dataloader = _build_loader(test_dset, shuffle=False)


#dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')
# train_dset = PPI(root='/tmp/PPI', split='train')
# val_dset = PPI(root='/tmp/PPI', split='val')
# test_dset = PPI(root='/tmp/PPI', split='test')


# # Create data loaders
# train_loader = GeomDataLoader(train_dset, batch_size=16, shuffle=True)
# val_loader = GeomDataLoader(val_dset, batch_size=32, shuffle=False)
# test_loader = GeomDataLoader(test_dset, batch_size=32, shuffle=False)

# # Extract labels for stratified splitting
# labels = [data.y.item() for data in dataset]

# # First split: train vs temp (val + test)
# train_idx, temp_idx = sk_modelselection.train_test_split(
#     list(range(len(dataset))),
#     test_size=0.3,
#     stratify=labels,
#     random_state=42,
# )

# # Second split: val vs test from temp
# temp_labels = [labels[i] for i in temp_idx]
# val_idx, test_idx = sk_modelselection.train_test_split(
#     temp_idx,
#     test_size=0.5,
#     stratify=temp_labels,
#     random_state=42,
# )

# # Create the datasets
# train_dataset = dataset[train_idx]
# val_dataset = dataset[val_idx]
# test_dataset = dataset[test_idx]

# # Create data loaders
# train_loader = GeomDataLoader(train_dataset, batch_size=16, shuffle=True)
# val_loader = GeomDataLoader(val_dataset, batch_size=32, shuffle=False)
# test_loader = GeomDataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
# Pull a single collated mini-batch from the training loader; the cells below
# use it for quick shape checks of the layer and the model.
first_batch = next(iter(train_dataloader))

In [8]:
# Smoke test: push one batch through a standalone multi-head GAT layer and
# look at the resulting node-embedding shape.
layer = graph_attention_layers.MultiHeadGATLayer(
    n_node_features=136,
    n_hidden_features=64,
    num_heads=8,
    dropout=0.25,
    scaling=0.2,
    apply_act=False,
)
layer(first_batch.node_features, first_batch.edge_index).shape

torch.Size([504, 64])

In [35]:
import torch
from jaxtyping import Float, Int
from jaxtyping import jaxtyped as jt
from torch import nn
from typeguard import typechecked as typechecker

from graphmodels.layers import graph_attention_layers
from graphmodels.models import constants as model_constants


@jt(typechecker=typechecker)
class GATModel(nn.Module):
    """Graph Attention Network (GAT) built from ``MultiHeadGATLayer`` blocks.

    The model chains ``max(num_layers, 1)`` multi-head layers (constructed with
    ``apply_act=True``) followed by one single-head output layer (constructed
    with ``apply_act=False``) that maps to ``n_out_channels``. Only node
    features and the edge index are consumed; edge features are not supported.

    This class does not add skip/residual connections between layers; any such
    behaviour would have to live inside ``MultiHeadGATLayer`` itself.

    Args:
        n_node_features: Width of the input node features.
        n_hidden_features: Width requested from every hidden layer, and the
            input width of each subsequent layer.
        n_out_channels: Width requested from the output layer.
        num_layers: Number of hidden layers. One hidden layer is always
            created, so values below 1 behave like 1.
        num_heads: Number of attention heads in each hidden layer (the output
            layer always uses a single head).
        scaling: Forwarded unchanged to every layer as ``scaling``
            (presumably the LeakyReLU slope of the attention score -- confirm
            against ``MultiHeadGATLayer``).
        dropout: Forwarded unchanged to every layer as ``dropout``.
        output_level: Must be a member of
            ``model_constants.ALLOWED_OUTPUT_LEVEL``. ``OutputLevel.GRAPH``
            makes ``forward`` return one sum-pooled row per graph; any other
            allowed value returns the per-node outputs.

    Raises:
        ValueError: If ``output_level`` is not an allowed output level.
    """

    def __init__(
        self,
        n_node_features: int,
        n_hidden_features: int,
        n_out_channels: int,
        num_layers: int,
        num_heads: int,
        scaling: float,
        dropout: float,
        output_level: model_constants.OutputLevel | str,
    ):
        super().__init__()

        if output_level not in model_constants.ALLOWED_OUTPUT_LEVEL:
            raise ValueError(f"{output_level} isnt a valid output.")

        self.output_level = output_level

        # The first layer consumes the raw node features, every later hidden
        # layer consumes the previous layer's hidden representation.
        layer_input_dims = [n_node_features] + [n_hidden_features] * (num_layers - 1)

        self.conv_layers = nn.ModuleList(
            graph_attention_layers.MultiHeadGATLayer(
                n_node_features=input_dim,
                n_hidden_features=n_hidden_features,
                num_heads=num_heads,
                dropout=dropout,
                scaling=scaling,
                apply_act=True,
            )
            for input_dim in layer_input_dims
        )

        # Single-head projection to the requested output width, no activation.
        self.output_layer = graph_attention_layers.MultiHeadGATLayer(
            n_node_features=n_hidden_features,
            n_hidden_features=n_out_channels,
            num_heads=1,
            dropout=dropout,
            scaling=scaling,
            apply_act=False,
        )

    def readout(
        self,
        x: Float[torch.Tensor, "nodes features"],
        edge_index: Int[torch.Tensor, "2 edges"],
        batch_vector: Int[torch.Tensor, " batch"],
    ) -> Float[torch.Tensor, "out n_out_channels"]:
        """Turn per-node outputs into the configured output level.

        Args:
            x: Per-node outputs of the final layer.
            edge_index: Unused here; kept so the signature stays stable.
            batch_vector: For each row of ``x``, the index of the graph that
                row belongs to (used as the ``index_add_`` index).

        Returns:
            For ``OutputLevel.GRAPH``: a tensor with ``batch_vector.max() + 1``
            rows, each the sum of the node rows assigned to that graph.
            Otherwise ``x`` unchanged.
        """
        if self.output_level != model_constants.OutputLevel.GRAPH:
            # Node-level output: no pooling, and no need to inspect batch_vector.
            return x

        # An empty batch has no maximum; treat it as zero graphs.
        num_graphs = int(batch_vector.max()) + 1 if batch_vector.numel() else 0

        # new_zeros inherits both device and dtype from x, so the in-place
        # accumulation also works when activations are not float32.
        mol_embeddings = x.new_zeros((num_graphs, x.size(-1)))
        mol_embeddings.index_add_(0, batch_vector, x)

        return mol_embeddings

    def forward(
        self,
        node_features: Float[torch.Tensor, "nodes node_features"],
        edge_index: Int[torch.Tensor, "2 edges"],
        batch_vector: Int[torch.Tensor, " batch"],
    ) -> Float[torch.Tensor, "out n_out_channels"]:
        """Run the hidden layers, the output layer and the readout.

        Args:
            node_features: Input node features.
            edge_index: Edge list handed to every attention layer.
            batch_vector: Graph index of every node, used by ``readout``.

        Returns:
            Graph-level or node-level outputs, depending on ``output_level``.
        """
        hidden = node_features
        for gat_layer in self.conv_layers:
            hidden = gat_layer(
                node_features=hidden,
                edge_index=edge_index,
            )

        out = self.output_layer(hidden, edge_index)
        return self.readout(out, edge_index, batch_vector)


In [40]:
# Build a graph-level GAT and run the sample batch through it once to check
# that the output has one row per graph.
model = GATModel(
    n_node_features=136,
    n_hidden_features=255,
    n_out_channels=1,
    num_layers=5,
    num_heads=3,
    scaling=0.2,
    dropout=0.25,
    output_level="graph",
)
out = model(
    node_features=first_batch.node_features,
    edge_index=first_batch.edge_index,
    batch_vector=first_batch.batch_vector,
)
out.shape, first_batch.node_features.shape

(torch.Size([32, 1]), torch.Size([504, 136]))

In [41]:
# Display the raw graph-level outputs of the freshly built (untrained) model.
out

tensor([[-0.7260],
        [ 1.4188],
        [-5.2592],
        [ 1.6417],
        [ 0.0660],
        [-1.3101],
        [ 4.7412],
        [-1.0956],
        [ 2.4441],
        [ 0.2318],
        [ 4.0344],
        [ 1.3035],
        [-1.4586],
        [ 0.8494],
        [ 3.4144],
        [-6.1151],
        [ 1.2773],
        [-0.7131],
        [ 1.6294],
        [ 1.4579],
        [ 2.7853],
        [ 1.3618],
        [ 2.4237],
        [ 3.4000],
        [ 3.1446],
        [-5.9464],
        [ 5.5217],
        [10.0874],
        [-4.6524],
        [-1.6082],
        [ 9.5441],
        [-6.9917]], grad_fn=<IndexAddBackward0>)

In [27]:
# NOTE(review): stale output -- the shape saved below ([32, 255]) disagrees with
# the [32, 1] recorded for the current model in the In [40] cell. Re-run or
# delete this cell.
out.shape

torch.Size([32, 255])

In [23]:
# NOTE(review): this cell raises (see the saved IndexError below). `out` holds
# one row per graph (32 rows), while `first_batch.edge_index` presumably refers
# to the node indices of the batch, so the layer looks up rows that do not exist.
# GATModel.forward already applies `output_layer` itself; this looks like a
# leftover debugging cell -- remove it or feed it node-level embeddings.
model.output_layer(out, first_batch.edge_index)

IndexError: index 32 is out of bounds for dimension 0 with size 32

In [20]:
# NOTE(review): duplicate of the earlier `out.shape` check, with the same stale
# saved output ([32, 255] vs. the [32, 1] recorded in the In [40] cell).
out.shape

torch.Size([32, 255])

In [8]:
def _batch_to_device(batch, device: str):
    """Move the tensors ``train_model`` needs from ``batch`` onto ``device``.

    Returns ``(node_features, edge_index, batch_vector, labels)``; the labels
    are squeezed and cast to float32 so they can be compared with the squeezed
    model output.
    """
    atom_feats = batch.x.to(device)
    edge_index = batch.edge_index.to(device)
    batch_vector = batch.batch.to(device)
    labels = batch.y.to(device).squeeze().to(torch.float32)
    return atom_feats, edge_index, batch_vector, labels


def train_model(model: nn.Module,
                train_loader: DataLoader,
                valid_loader: DataLoader,
                loss_fn: nn.Module,
                epochs: int = 10,
                lr: float = 1e-3,
                max_learning_rate: float = 1e-2,
                weight_decay: float = 0.1,
                device: str = 'cuda') -> pd.DataFrame:
    """Train ``model`` with Adam and a one-cycle LR schedule, validating every epoch.

    Batches must expose ``x``, ``edge_index``, ``y`` and ``batch`` attributes
    (the names torch_geometric batches use).
    NOTE(review): the MPNN loaders built earlier in this notebook appear to
    expose ``node_features`` / ``batch_vector`` instead -- verify before passing
    them here.

    Args:
        model: Network called as ``model(node_features=, edge_index=, batch_vector=)``.
        train_loader: Loader iterated once per epoch for optimisation.
        valid_loader: Loader iterated once per epoch without gradients.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate handed to Adam.
        max_learning_rate: Peak learning rate of the OneCycleLR schedule.
        weight_decay: Weight decay handed to Adam.
        device: Device the model and every batch are moved to.

    Returns:
        DataFrame with one row per epoch and the columns ``epoch``,
        ``train_loss`` and ``valid_loss`` (mean loss per batch).
    """
    # Move the model first so the optimizer is constructed over the parameters
    # as they live on the target device.
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # The schedule is stepped once per training batch, hence steps_per_epoch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=epochs,
    )

    history = []

    print("🚀 Starting training...\n")
    # The context manager closes the progress bar even if training aborts
    # (e.g. on a CUDA out-of-memory error).
    with tqdm(total=epochs, desc="Training") as pbar:
        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            for train_batch in train_loader:
                atom_feats, edge_index, batch_vector, labels = _batch_to_device(train_batch, device)

                optimizer.zero_grad()
                outputs = model(node_features=atom_feats, edge_index=edge_index, batch_vector=batch_vector).squeeze()

                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
                scheduler.step()
                train_loss += loss.item()

            train_loss /= len(train_loader)

            # Validation phase
            model.eval()
            valid_loss = 0.0
            with torch.no_grad():
                for val_batch in valid_loader:
                    atom_feats, edge_index, batch_vector, labels = _batch_to_device(val_batch, device)

                    outputs = model(node_features=atom_feats, edge_index=edge_index, batch_vector=batch_vector).squeeze()
                    valid_loss += loss_fn(outputs, labels).item()

            # An empty validation loader yields NaN instead of ZeroDivisionError.
            num_valid_batches = len(valid_loader)
            valid_loss = valid_loss / num_valid_batches if num_valid_batches else float("nan")

            # Update the tqdm progress bar with train and validation loss
            pbar.set_postfix({"train_loss": f"{train_loss:.4f}", "val_loss": f"{valid_loss:.4f}"})

            history.append({'epoch': epoch+1, 'train_loss': train_loss, 'valid_loss': valid_loss})
            pbar.update(1)

    print("\n🎉 Training completed!\n")
    return pd.DataFrame(history)


In [10]:
# NOTE(review): this cell depends on kernel state that no visible cell creates:
#   * `SimpleGATModel` is neither defined nor imported in the visible notebook;
#   * `train_loader` / `val_loader` only occur in commented-out code in the
#     data-loading cell (the live loaders are `train_dataloader` / `valid_dataloader`).
# It will presumably raise NameError on a fresh "Restart & Run All" unless those
# names are defined elsewhere. The sizes used below (50 input features, 121
# outputs, BCE-with-logits loss) also do not match the solubility regression data
# loaded above -- presumably left over from the commented-out PPI experiment; confirm.
# The saved output shows this run aborting with a CUDA OutOfMemoryError.
#multihead_gat_model = MultiheadGATModel(n_node_features=136,
                                  # n_hidden_features=512,
                                  # n_out_features=1,
                                  # dropout=0.1,
                                  # agg_method="mean",
                                  # num_heads=4)
# Free Python garbage and cached CUDA blocks before allocating a new model.
gc.collect()
torch.cuda.empty_cache()
gat_model = SimpleGATModel(n_node_features=50,
                           n_hidden_features=50,
                           n_out_features=121,
                          )
                           #dropout=0.1)

#geom_model = models.GAT(in_channels=3, hidden_channels=12, num_layers=1, out_channels=1)
#geom_model = GATGraphClassifier(in_channels=3, hidden_channels=12, num_layers=3, out_channels=1)
# Train for 20 epochs; `history` is the per-epoch loss table returned by train_model.
history = train_model(model=gat_model,
                      loss_fn=nn.BCEWithLogitsLoss(),
                      train_loader=train_loader,
                      valid_loader=val_loader,
                      epochs=20,
                      lr=5e-4,
                      max_learning_rate=5e-3,
                     weight_decay=1e-3,
                      device="cuda")

# Release memory again once training has finished.
gc.collect()
torch.cuda.empty_cache()

🚀 Starting training...



Training:   0%|          | 0/20 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 364.00 MiB. GPU 0 has a total capacity of 3.63 GiB of which 179.00 MiB is free. Including non-PyTorch memory, this process has 2.71 GiB memory in use. Of the allocated memory 2.48 GiB is allocated by PyTorch, and 164.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Run the garbage collector, then ask PyTorch to release cached, unused CUDA
# memory. The training cell above already ends with the same two calls.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Plot training and validation loss per epoch on the same axes. Each curve gets
# a legend label and the y-axis a neutral name; otherwise the axis would carry
# the name of whichever column was drawn last.
ax = sns.lineplot(history, x="epoch", y="train_loss", color="green", label="train")
sns.lineplot(history, x="epoch", y="valid_loss", color="orange", label="validation", ax=ax)
ax.set(xlabel="epoch", ylabel="loss", title="Training history")
ax.legend();

In [None]:
# Evaluate on the test split on the CPU, with dropout disabled.
# NOTE(review): `test_loader` is only defined in commented-out code in the
# data-loading cell -- confirm which loader is meant here.
device = "cpu"
gat_model.to(device)
gat_model.eval()

pred_chunks = []
label_chunks = []
with torch.no_grad():
    for test_batch in test_loader:
        # Reorder the edges by the first row of edge_index so the scatter-based
        # softmax sees them in a fixed order. NOTE(review): the original comment
        # called this row the target node -- confirm which row holds targets.
        edge_order = test_batch.edge_index[0].argsort()
        logits = gat_model(
            node_features=test_batch.x,
            edge_index=test_batch.edge_index[:, edge_order],
            batch_vector=test_batch.batch,
        ).squeeze()
        # Logits -> probabilities -> hard 0/1 decisions at a 0.5 threshold.
        pred_chunks.append((torch.sigmoid(logits) > 0.5).int())
        label_chunks.append(test_batch.y)

preds = torch.cat(pred_chunks)
target_labels = torch.cat(label_chunks)

In [None]:
# Inspect the ground-truth labels collected from the test batches.
target_labels

In [None]:
# Inspect the thresholded (0/1) predictions collected from the test batches.
preds

In [None]:
# Classification metrics on the test split: Matthews correlation coefficient
# first, then plain accuracy.
mcc = sk_metrics.matthews_corrcoef(target_labels, preds)
print(f"MCC = {mcc}")
acc = sk_metrics.accuracy_score(target_labels, preds)
print(f"ACC = {acc}")

In [None]:
# print(f"MAE = {sk_metrics.mean_absolute_error(target_labels, preds)}")
# print(f"RMSE = {sk_metrics.root_mean_squared_error(target_labels, preds)}")