<a href="https://colab.research.google.com/github/kmheckel/MPNN_QC/blob/main/GGNNs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Things to do: Data loading onto GPU, Tuning learning rate, make networks and settings as general as possible, use validation data set, showable loss improvement, standardise the input data
### Preliminaries and loading data

In [1]:


#@title [RUN] Sanity check torch version and GPU runtime

import torch
assert torch.cuda.is_available(), "WARNING! You are running on a non-GPU instance. For this practical a GPU is highly recommended."
REQUIRED_VERSION = "2.1.0+cu121"
TORCH_VERSION = torch.__version__
CUDA_VERSION = TORCH_VERSION.split("+")

if TORCH_VERSION != REQUIRED_VERSION:
  print(f"Detected torch version {TORCH_VERSION}, but notebook was created for {REQUIRED_VERSION}")
  print(f"Attempting installation of {REQUIRED_VERSION}")
  !pip install torch==2.1.0+cu121
print("Correct version of torch detected. You are running on a machine with GPU.")


Correct version of torch detected. You are running on a machine with GPU.


In [2]:
#@title [RUN] Install required python libraries
import os

# Install PyTorch Geometric and other libraries
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print("Installing PyTorch Geometric")
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-geometric
    print("Installing other libraries")
    !pip install -q rdkit-pypi==2021.9.4
    !pip install -q py3Dmol==1.8.0
    !pip install networkx
    !pip install mycolorpy
    !pip install colorama

Installing PyTorch Geometric
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m58.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling other libraries
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.8/20.8 MB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting mycolorpy
  Downloading mycolorpy-1.5.1.tar.gz (2.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mycolorpy
  Building wheel for mycolorpy (setup.py) ... [?25l[?25hdone
  Created wheel for mycolorpy: filename=mycolorpy-1.5.1-py3-none-any.whl size=3851 sha256=979f9f2cce48bc78f4ee544631b14dbcb9a23ef9ab23f942eeb02aa9dc86a60a
  Stored in directory: /root/.cache/pip/wheels/d6/1d/04/f8bfb8b04e3daba3

In [3]:
#@title [RUN] Import python modules

import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
import networkx as nx
import numpy as np
import scipy as sp
from scipy.stats import ortho_group
from scipy.linalg import block_diag

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, Module, ModuleList, Sequential

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse
from torch_geometric.datasets import Planetoid, QM9
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

import rdkit.Chem as Chem
from rdkit.Geometry.rdGeometry import Point3D
from rdkit.Chem import QED, Crippen, rdMolDescriptors, rdmolops
from rdkit.Chem.Draw import IPythonConsole

import py3Dmol
from rdkit.Chem import AllChem

import matplotlib.pyplot as plt
import seaborn as sns
from mycolorpy import colorlist as mcp
import matplotlib.cm as cm
import colorama

from google.colab import files
from IPython.display import HTML

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

print("All imports succeeded.")
print("Python version {}".format(sys.version))
print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

All imports succeeded.
Python version 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch version 2.1.0+cu121
PyG version 2.5.0


In [4]:
#@title [RUN] Set random seed for deterministic results

def seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed(0)
print("All seeds set.")

All seeds set.


In [5]:
#@title [RUN] Helper functions for data preparation

class SetTarget:
    """
    This transform modifies the labels vector per data sample to only keep
    the label for a specific target (there are 19 targets in QM9).

    Note: for this practical, we have hardcoded the target to be target #0,
    i.e. the electric dipole moment of a drug-like molecule.
    (https://en.wikipedia.org/wiki/Electric_dipole_moment)
    """
    def __init__(self, target=0):
      self.target = target
    def __call__(self, data):
        data.y = data.y[:, self.target]
        return data


class CompleteGraph:
    """
    This transform adds all pairwise edges into the edge index per data sample,
    then removes self loops, i.e. it builds a fully connected or complete graph
    """
    def __call__(self, input):
        data = input.clone()
        device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

print("Helper functions loaded.")

Helper functions loaded.


In [6]:
# @title [RUN] Helper functions for plots and visualisations


def draw_one_graph(ax, edges, label=None, node_emb=None, layout=None, special_color=False):
    """draw a graph with networkx based on adjacency matrix (edges)
    graph labels could be displayed as a title for each graph
    node_emb could be displayed in colors
    """
    graph = nx.Graph()
    edges = zip(edges[0], edges[1])
    graph.add_edges_from(edges)
    node_pos = layout(graph)
    #add colors according to node embeding
    if (node_emb is not None) or special_color:
        color_map = []
        node_list = [node[0] for node in graph.nodes(data = True)]
        for i,node in enumerate(node_list):
            #just ignore this branch
            if special_color:
                if len(node_list) == 3:
                    crt_color = (1,0,0)
                elif len(node_list) == 5:
                    crt_color = (0,1,0)
                elif len(node_list) == 4:
                    crt_color = (1,1,0)
                else:
                  special_list = [(1,0,0)] * 3 + [(0,1,0)] * 5 + [(1,1,0)] * 4
                  crt_color = special_list[i]
            else:
                crt_node_emb = node_emb[node]
                #map float number (node embeding) to a color
                crt_color = cm.gist_rainbow(crt_node_emb, bytes=True)
                crt_color = (crt_color[0]/255.0, crt_color[1]/255.0, crt_color[2]/255.0, crt_color[3]/255.0)
            color_map.append(crt_color)

        nx.draw_networkx_nodes(graph,node_pos, node_color=color_map,
                        nodelist = node_list, ax=ax)
        nx.draw_networkx_edges(graph, node_pos, ax=ax)
        nx.draw_networkx_labels(graph,node_pos, ax=ax)
    else:
        nx.draw_networkx(graph, node_pos, ax=ax)


def gallery(graphs, labels=None, node_emb=None, special_color=False, max_graphs=4, max_fig_size=(40, 10), layout=nx.layout.kamada_kawai_layout):
    ''' Draw multiple graphs as a gallery
    Args:
      graphs: torch_geometrics.dataset object/ List of Graph objects
      labels: num_graphs
      node_emb: num_graphs* [num_nodes x num_ch]
      max_graphs: maximum graphs display
    '''
    num_graphs = min(len(graphs), max_graphs)
    ff, axes = plt.subplots(1, num_graphs,
                            figsize=max_fig_size,
                            subplot_kw={'xticks': [], 'yticks': []})
    if num_graphs == 1:
        axes = [axes]
    if node_emb is None:
        node_emb = num_graphs*[None]
    if labels is None:
        labels = num_graphs * [" "]


    for i in range(num_graphs):
        draw_one_graph(axes[i], graphs[i].edge_index.numpy(), labels[i], node_emb[i], layout, special_color)
        if labels[i] != " ":
            axes[i].set_title(f"Target: {labels[i]}", fontsize=28)
        axes[i].set_axis_off()
    plt.show()


def hash_node_embedings(node_emb):
  """ Hash the tensor representing nodes' features
  to a number in [0,1] used to represent a color

  Args:
    node_emb: list of num_graphs arrays, each of dim (num_nodes x num_feats)
  Returns:
    list of num_graphs arrays in [0,1], each of dim (num_nodes)
  """
  chunk_size_graph = [x.shape[0] for x in node_emb]
  start_idx_graph = [0] + list(itertools.accumulate(chunk_size_graph))[:-1]

  node_emb_flatten = np.concatenate(node_emb).mean(-1)

  min_emb = node_emb_flatten.min()
  max_emb = node_emb_flatten.max()
  node_emb_flatten = (node_emb_flatten-min_emb)/(max_emb-min_emb)

  #split in graphs again according to (start_idx_graph, chunk_size_graph)
  node_emb_hashed = [node_emb_flatten[i:i+l] for (i,l) in zip(start_idx_graph, chunk_size_graph)]
  return node_emb_hashed


def update_stats(training_stats, epoch_stats):
    """ Store metrics along the training
    Args:
      epoch_stats: dict containg metrics about one epoch
      training_stats: dict containing lists of metrics along training
    Returns:
      updated training_stats
    """
    if training_stats is None:
        training_stats = {}
        for key in epoch_stats.keys():
            training_stats[key] = []
    for key,val in epoch_stats.items():
        training_stats[key].append(val)
    return training_stats


def plot_stats(training_stats, figsize=(5, 5), name=""):
    """ Create one plot for each metric stored in training_stats
    """
    stats_names = [key[6:] for key in training_stats.keys() if key.startswith('train_')]
    f, ax = plt.subplots(len(stats_names), 1, figsize=figsize)
    if len(stats_names)==1:
        ax = np.array([ax])
    for key, axx in zip(stats_names, ax.reshape(-1,)):
        axx.plot(
            training_stats['epoch'],
            training_stats[f'train_{key}'],
            label=f"Training {key}")
        axx.plot(
            training_stats['epoch'],
            training_stats[f'val_{key}'],
            label=f"Validation {key}")
        axx.set_xlabel("Training epoch")
        axx.set_ylabel(key)
        axx.legend()
    plt.title(name)


def get_color_coded_str(i, color):
    return "\033[3{}m{}\033[0m".format(int(color), int(i))


def print_color_numpy(map, list_graphs):
    """ print matrix map in color according to list_graphs
    """
    list_blocks = []
    for i,graph in enumerate(list_graphs):
        block_i = (i+1)*np.ones((graph.num_nodes,graph.num_nodes))
        list_blocks += [block_i]
    block_color = block_diag(*list_blocks)

    map_modified = np.vectorize(get_color_coded_str)(map, block_color)
    print("\n".join([" ".join(["{}"]*map.shape[0])]*map.shape[1]).format(*[x for y in map_modified.tolist() for x in y]))


#############################################
# Visualization helpers for molecular graphs
#############################################

allowable_atoms = [
    "H",
    "C",
    "N",
    "O",
    "F",
    "C",
    "Cl",
    "Br",
    "I",
    "H",
    "Unknown",
]

def to_atom(t):
    try:
        return allowable_atoms[int(t.argmax())]
    except:
        return "C"


def to_bond_index(t):
    t_s = t.squeeze()
    return [1, 2, 3, 4][
        int(
            torch.dot(
                t_s,
                torch.tensor(
                    range(t_s.size()[0]), dtype=torch.float, device=t.device
                ),
            ).item()
        )
    ]

def to_rdkit(data, device=None):
    has_pos = False
    node_list = []
    for i in range(data.x.size()[0]):
        node_list.append(to_atom(data.x[i][:5]))

    # create empty editable mol object
    mol = Chem.RWMol()
    # add atoms to mol and keep track of index
    node_to_idx = {}
    invalid_idx = set([])
    for i in range(len(node_list)):
        if node_list[i] == "Stop" or node_list[i] == "H":
            invalid_idx.add(i)
            continue
        a = Chem.Atom(node_list[i])
        molIdx = mol.AddAtom(a)
        node_to_idx[i] = molIdx

    added_bonds = set([])
    for i in range(0, data.edge_index.size()[1]):
        ix = data.edge_index[0][i].item()
        iy = data.edge_index[1][i].item()
        bond = to_bond_index(data.edge_attr[i])  # <font color='red'>TODO</font> fix this
        # bond = 1
        # add bonds between adjacent atoms

        if data.edge_attr[i].sum() == 0:
          continue

        if (
            (str((ix, iy)) in added_bonds)
            or (str((iy, ix)) in added_bonds)
            or (iy in invalid_idx or ix in invalid_idx)
        ):
            continue
        # add relevant bond type (there are many more of these)

        if bond == 0:
            continue
        elif bond == 1:
            bond_type = Chem.rdchem.BondType.SINGLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 2:
            bond_type = Chem.rdchem.BondType.DOUBLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 3:
            bond_type = Chem.rdchem.BondType.TRIPLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 4:
            bond_type = Chem.rdchem.BondType.SINGLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        added_bonds.add(str((ix, iy)))

    if has_pos:
        conf = Chem.Conformer(mol.GetNumAtoms())
        for i in range(data.pos.size(0)):
            if i in invalid_idx:
                continue
            p = Point3D(
                data.pos[i][0].item(),
                data.pos[i][1].item(),
                data.pos[i][2].item(),
            )
            conf.SetAtomPosition(node_to_idx[i], p)
        conf.SetId(0)
        mol.AddConformer(conf)

    # Convert RWMol to Mol object
    mol = mol.GetMol()
    mol_frags = rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
    return largest_mol


def MolTo3DView(mol, size=(300, 300), style="stick", surface=False, opacity=0.5):
    """Draw molecule in 3D

    Args:
    ----
        mol: rdMol, molecule to show
        size: tuple(int, int), canvas size
        style: str, type of drawing molecule
               style can be 'line', 'stick', 'sphere', 'carton'
        surface, bool, display SAS
        opacity, float, opacity of surface, range 0.0-1.0
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    """
    assert style in ('line', 'stick', 'sphere', 'carton')

    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol)
    AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({style:{}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

def smi2conf(smiles):
    '''Convert SMILES to rdkit.Mol with 3D coordinates'''
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        mol = Chem.AddHs(mol)
        AllChem.EmbedMolecule(mol)
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        return mol
    else:
        return None

print("Helper functions loaded.")

Helper functions loaded.


In [7]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    path = './qm9'
    target = 0

    # Transforms which are applied during data loading:
    # Select the target/label
    transform = T.Compose([SetTarget(target=target)])

    # Load the QM9 dataset with the transforms defined
    # dataset = QM9(path, transform=transform)
    dataset = QM9(path)

    # Normalize targets per data sample to mean = 0 and std = 1.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    # dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean[:, target].item(), std[:, target].item()

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting qm9/raw/qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|██████████| 133885/133885 [02:37<00:00, 850.39it/s] 
Done!


In [8]:
print(f"Total number of samples: {len(dataset)}.")

# Split datasets (in case of using the full dataset)
# test_dataset = dataset[:10000]
# val_dataset = dataset[10000:20000]
# train_dataset = dataset[20000:]

# Split datasets (our 3K subset)
train_pyg_dataset = dataset[:1000]
val_pyg_dataset = dataset[1000:2000]
test_pyg_dataset = dataset[2000:3000]
print(f"Created dataset splits with {len(train_pyg_dataset)} training, {len(val_pyg_dataset)} validation, {len(test_pyg_dataset)} test samples.")

Total number of samples: 130831.
Created dataset splits with 1000 training, 1000 validation, 1000 test samples.


In [9]:
train_pyg_dataset[0]

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], z=[5], smiles='[H]C([H])([H])[H]', name='gdb_1', idx=[1])

In [10]:
train_pyg_dataset[50]

Data(x=[9, 11], edge_index=[2, 18], edge_attr=[18, 4], y=[1, 19], pos=[9, 3], z=[9], smiles='[H]C1=NC([H])=C([H])N1[H]', name='gdb_51', idx=[1])

In [11]:
data = train_pyg_dataset[0]

print(f"\nThis molecule has {data.x.shape[0]} atoms, and {data.edge_attr.shape[0]} edges.")

print(f"\nFor each atom, we are given a feature vector with {data.x.shape[1]} entries (described above).")

print(f"\nFor each edge, we are given a feature vector with {data.edge_attr.shape[1]} entries (also described above).")

print(f"\nIn the next section, we will learn how to build a GNN in the Message Passing flavor to \n\
process the node and edge features of molecular graphs and predict their properties.")

print(f"\nEach atom also has a {data.pos.shape[1]}-dimensional coordinate associated with it. \n\
We will talk about their importance later in the practical.")

print(f"\nFinally, we have {data.y.shape[0]} regression target for the entire molecule.")


This molecule has 5 atoms, and 8 edges.

For each atom, we are given a feature vector with 11 entries (described above).

For each edge, we are given a feature vector with 4 entries (also described above).

In the next section, we will learn how to build a GNN in the Message Passing flavor to 
process the node and edge features of molecular graphs and predict their properties.

Each atom also has a 3-dimensional coordinate associated with it. 
We will talk about their importance later in the practical.

Finally, we have 1 regression target for the entire molecule.


In [20]:
# Create dataloaders with batch size = 32
train_loader = DataLoader(train_pyg_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_pyg_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_pyg_dataset, batch_size=32, shuffle=False)

In [21]:
data = train_pyg_dataset[1]
data.y, data.x, data.pos

(tensor([[ 1.6256e+00,  9.4600e+00, -6.9933e+00,  2.2558e+00,  9.2491e+00,
           2.6156e+01,  9.3493e-01, -1.5381e+03, -1.5381e+03, -1.5380e+03,
          -1.5387e+03,  6.3160e+00, -1.2006e+01, -1.2082e+01, -1.2159e+01,
          -1.1246e+01,  2.9361e+02,  2.9354e+02,  1.9139e+02]]),
 tensor([[0., 0., 1., 0., 0., 7., 0., 0., 0., 0., 3.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
 tensor([[-0.0404,  1.0241,  0.0626],
         [ 0.0173,  0.0125, -0.0274],
         [ 0.9158,  1.3587, -0.0288],
         [-0.5203,  1.3435, -0.7755]]))

### Implementing GGNNs

In [22]:
#Load necessary packages
import torch
from torch_geometric.nn import GatedGraphConv, global_mean_pool, global_add_pool, GRUAggregation
#from torch_geometric.datasets import QM9 (already loaded QM9)
from torch_geometric.loader import DataLoader
from torch.nn import Linear
import torch.nn.functional as F

In [23]:
def train(model, train_loader, optimizer, device):
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        y_pred = model(data)
        loss = F.mse_loss(y_pred, data.y)
        loss.backward()
        loss_all += loss.item() * data.num_graphs
        optimizer.step()
    return loss_all / len(train_loader.dataset)

In [24]:
def eval(model, loader, device):
    model.eval()
    total_loss = 0

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            output = model(data)
            loss = F.l1loss(output, data.y)
            total_loss += loss.item()
    return total_loss / len(loader.dataset)


In [25]:
def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100, patience=float('inf')):
    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training on {device}.')

    print("\nModel architecture:")
    print(model)
    total_param = 0
    for param in model.parameters():
        total_param += np.prod(list(param.data.size()))
    print(f'Total parameters: {total_param}')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)

    print("\nStart training:")
    best_val_error = float('inf')
    perf_per_epoch = []
    t = time.time()
    patience_counter = 0
    for epoch in range(1, n_epochs + 1):
        lr = scheduler.optimizer.param_groups[0]['lr']
        loss = train(model, train_loader, optimizer, device)
        val_error = eval(model, val_loader, device)

        if val_error < best_val_error:
            best_val_error = val_error
            test_error = eval(model, test_loader, device)  # Evaluate model on test set if validation metric improves
            patience_counter = 0  # Reset patience counter

            # Generate the current date string
            current_date = datetime.now().strftime('%Y-%m-%d_%H-%M')
            model_save_path = f'models/{current_date}/{model_name}_{model.num_layers}layers_{type(model.aggr).__name__}_aggr.pt'

            # Create directory if it does not exist
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

            # Save the model
            torch.save(model.state_dict(), model_save_path)
        else:
            patience_counter += 1

        if epoch % 10 == 0:
            print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Loss: {loss:.7f}, '
                  f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')

        scheduler.step(val_error)
        perf_per_epoch.append((test_error, val_error, epoch, model_name))

        if patience_counter >= patience:
            print(f"Stopping early due to no improvement in validation loss for {patience} epochs.")
            break

    t = time.time() - t
    train_time = t / 60
    print(f"\nDone! Training took {train_time:.2f} mins. Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.")

    return best_val_error, test_error, train_time, perf_per_epoch

In [26]:
class GGNNPyG(torch.nn.Module):
  def __init__(self, num_node_features, num_classes):
    super(GGNNPyG, self).__init__()
    self.ggnn = GatedGraphConv(out_channels=num_node_features, num_layers=3)
    self.lin = Linear(num_node_features, num_classes)

  def forward(self, data):
    x, edge_index = data.x, data.edge_index
    x = self.ggnn(x, edge_index)
    x = GRUAggregation(x, data.batch)
    #x = global_add_pool(x, data.batch) # using add pooling to aggregate the features
    x = F.relu(self.lin(x))
    return F.log_softmax(x, dim=1)

In [27]:
# Initialize the GGNN model
model = GGNNPyG(num_node_features=dataset.num_node_features, num_classes=1)

In [30]:
model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name,
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)

Running experiment for GGNNPyG, training on 1000 samples for 100 epochs.
Training on cuda.

Model architecture:
GGNNPyG(
  (ggnn): GatedGraphConv(11, num_layers=3)
  (lin): Linear(in_features=11, out_features=1, bias=True)
)
Total parameters: 1167

Start training:


TypeError: hidden_size should be of type int, got: Tensor

In [None]:
RESULTS[model_name] = (best_val_error, test_error, train_time)
df_temp = pd.DataFrame(perf_per_epoch, columns=["Test MAE", "Val MAE", "Epoch", "Model"])
DF_RESULTS = DF_RESULTS.append(df_temp, ignore_index=True)

In [None]:
RESULTS

### Extras (not important)

In [29]:
# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
# Training loop
for epoch in range(50):
  model.train()
  for data in train_loader:
    #GPUtil.showUtilization()
    optimizer.zero_grad()
    output = model(data)
    loss = F.mse_loss(output, data.y)
    loss.backward()
    optimizer.step()
  print(f'Epoch {epoch+1}, Loss: {loss.item()}')

In [None]:
#Testing the model
model.eval()
total_loss = 0
with torch.no_grad():
  for data in test_loader:
    output = model(data)
    loss = F.mse_loss(output, data.y)
    total_loss += loss.item()

print(f'Test Loss: {total_loss/len(test_loader)}')