## This document is meant to compare the output of graphium-2 vs graphium-3 datamodules

In [1]:
import graphium
import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
import os
from os.path import dirname, realpath
from graphium.config._loader import load_accelerator, load_datamodule
import shutil

In [2]:
os.chdir(dirname(dirname(realpath(graphium.__file__))))

In [3]:

def initialize_hydra(config_path, job_name="app"):
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    hydra.initialize(config_path=config_path, job_name=job_name)

def compose_main_config(config_dir):
    initialize_hydra(config_dir)
    # Compose the main configuration
    main_config = hydra.compose(config_name="main")
    return main_config

In [4]:

# Load the main configuration for toymix
CONFIG_DIR = "../../expts/hydra-configs/"
cfg = compose_main_config(CONFIG_DIR)
cfg = OmegaConf.to_container(cfg, resolve=True)
cfg.pop("tasks")

# Adapt the configuration to reduce the time it takes to run the test, less samples, less epochs
cfg["constants"]["max_epochs"] = 4
TINY_DIR = "expts/data/neurips2023/tiny-dataset/"
cfg["constants"]["data_dir"] = TINY_DIR
cfg["trainer"]["trainer"]["check_val_every_n_epoch"] = 1
cfg["trainer"]["trainer"]["max_epochs"] = 4

cfg["datamodule"]["args"]["processed_graph_data_path"] = "datacache/testing_feats"
cfg["datamodule"]["args"]["batch_size_training"] = 1
cfg["datamodule"]["args"]["batch_size_inference"] = 1
cfg["datamodule"]["args"]["task_specific_args"]["qm9"]["df_path"] = TINY_DIR + "qm9.csv.gz"
cfg["datamodule"]["args"]["task_specific_args"]["qm9"]["splits_path"] = TINY_DIR + "qm9_random_splits.pt"
cfg["datamodule"]["args"]["task_specific_args"]["tox21"]["df_path"] = TINY_DIR + "Tox21-7k-12-labels.csv.gz"
cfg["datamodule"]["args"]["task_specific_args"]["tox21"]["splits_path"] = TINY_DIR + "Tox21_random_splits.pt"
cfg["datamodule"]["args"]["task_specific_args"]["zinc"]["df_path"] = TINY_DIR + "ZINC12k.csv.gz"
cfg["datamodule"]["args"]["task_specific_args"]["zinc"]["splits_path"] = TINY_DIR + "ZINC12k_random_splits.pt"

# Initialize the accelerator
cfg, accelerator_type = load_accelerator(cfg)

# Load and initialize the dataset
datamodule = load_datamodule(cfg, accelerator_type)


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path=config_path, job_name=job_name)


In [5]:
datamodule.prepare_data()
datamodule.setup("fit")
datamodule

[32m2025-02-04 14:26:48.378[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36mprepare_data[0m:[36m969[0m - [1mData is already prepared.[0m
[32m2025-02-04 14:26:48.381[0m | [1mINFO    [0m | [36mgraphium.data.dataset[0m:[36m__init__[0m:[36m80[0m - [1mDataloading from DISK[0m
[32m2025-02-04 14:26:48.382[0m | [1mINFO    [0m | [36mgraphium.data.dataset[0m:[36m__init__[0m:[36m80[0m - [1mDataloading from DISK[0m
[32m2025-02-04 14:26:48.383[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1136[0m - [1m-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 15
	num_nodes_total = 267
	max_num_nodes_per_graph = 38
	min_num_nodes_per_graph = 8
	std_num_nodes_per_graph = 9.382252750095079
	mean_num_nodes_per_graph = 17.8
	num_edges_total = 570
	max_num_edges_per_graph = 78
	min_num_edges_per_graph = 16
	std_num_edges_per_graph = 20.199009876724155
	mean_num_edges_per_graph = 38.0
-------------------


name: MultitaskFromSmilesDataModule
len: 60
batch_size_training: 1
batch_size_inference: 1
num_node_feats: 85
num_edge_feats: 13
num_tasks: 3
collate_fn: graphium_collate_fn
featurization:
  atom_property_list_onehot:
  - atomic-number
  - group
  - period
  - total-valence
  atom_property_list_float:
  - degree
  - formal-charge
  - radical-electron
  - aromatic
  - in-ring
  edge_property_list:
  - bond-type-onehot
  - stereo
  - in-ring
  add_self_loop: false
  explicit_H: false
  use_bonds_weights: false
  pos_encoding_as_features:
    pos_types:
      lap_eigvec:
        pos_level: node
        pos_type: laplacian_eigvec
        num_pos: 8
        normalization: none
        disconnected_comp: true
      lap_eigval:
        pos_level: node
        pos_type: laplacian_eigval
        num_pos: 8
        normalization: none
        disconnected_comp: true
      rw_pos:
        pos_level: node
        pos_type: rw_return_probs
        ksteps: 16

In [6]:
val_loader = datamodule.val_dataloader()
val_loader

<torch.utils.data.dataloader.DataLoader at 0x1d65dc99cd0>

In [7]:
iter_loader = iter(val_loader)
elems = [next(iter_loader) for _ in range(30)]

for elem in elems:
    print(elem)

{'labels': DataBatch(graph_qm9=[1, 19], graph_zinc=[1, 3], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 16], edge_weight=[16], num_nodes=8, feat=[8, 85], edge_feat=[16, 13], laplacian_eigvec=[8, 8], laplacian_eigval=[8, 8], rw_return_probs=[8, 16], batch=[8], ptr=[2])}
{'labels': DataBatch(graph_zinc=[1, 3], graph_qm9=[1, 19], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 42], edge_weight=[42], num_nodes=20, feat=[20, 85], edge_feat=[42, 13], laplacian_eigvec=[20, 8], laplacian_eigval=[20, 8], rw_return_probs=[20, 16], batch=[20], ptr=[2])}
{'labels': DataBatch(graph_tox21=[1, 12], graph_zinc=[1, 3], graph_qm9=[1, 19]), 'features': DataBatch(edge_index=[2, 24], edge_weight=[24], num_nodes=12, feat=[12, 85], edge_feat=[24, 13], laplacian_eigvec=[12, 8], laplacian_eigval=[12, 8], rw_return_probs=[12, 16], batch=[12], ptr=[2])}
{'labels': DataBatch(graph_zinc=[1, 3], graph_qm9=[1, 19], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 62], edge_weight=[

In [8]:
# Pickle the list of graphs
import pickle
with open("expts/notebooks/val_graphs_graphium3_new.pkl", "wb") as f:
    pickle.dump(elems, f)

# Load the list of graphs to make sure the pickle works
with open("expts/notebooks/val_graphs_graphium3_new.pkl", "rb") as f:
    elems = pickle.load(f)
for elem in elems:
    print(elem)

{'labels': DataBatch(graph_qm9=[1, 19], graph_zinc=[1, 3], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 16], edge_weight=[16], num_nodes=8, feat=[8, 85], edge_feat=[16, 13], laplacian_eigvec=[8, 8], laplacian_eigval=[8, 8], rw_return_probs=[8, 16], batch=[8], ptr=[2])}
{'labels': DataBatch(graph_zinc=[1, 3], graph_qm9=[1, 19], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 42], edge_weight=[42], num_nodes=20, feat=[20, 85], edge_feat=[42, 13], laplacian_eigvec=[20, 8], laplacian_eigval=[20, 8], rw_return_probs=[20, 16], batch=[20], ptr=[2])}
{'labels': DataBatch(graph_tox21=[1, 12], graph_zinc=[1, 3], graph_qm9=[1, 19]), 'features': DataBatch(edge_index=[2, 24], edge_weight=[24], num_nodes=12, feat=[12, 85], edge_feat=[24, 13], laplacian_eigvec=[12, 8], laplacian_eigval=[12, 8], rw_return_probs=[12, 16], batch=[12], ptr=[2])}
{'labels': DataBatch(graph_zinc=[1, 3], graph_qm9=[1, 19], graph_tox21=[1, 12]), 'features': DataBatch(edge_index=[2, 62], edge_weight=[