```yaml
# Example YAML configuration for the OGB datamodule. The keys under `args`
# mirror the keyword arguments assembled as `dm_args` in the Python cell below
# (presumably forwarded to the class named by `module_type` — confirm against
# graphium's config loader).
# NOTE(review): this example targets "ogbg-moltox21" and lists "ring" as an
# extra edge property, whereas the Python cell uses "ogbg-molfreesolv" with only
# "bond-type-onehot" — confirm which settings are intended before reusing it.
data:
  module_type: "GraphOGBDataModule"
  args:
    # No cache path is set here; the Python cell instead points this at a
    # temporary "cache.pkl" file.
    cache_data_path: null

    # One of the OGB dataset names listed in the Python `dataset_names` list.
    dataset_name: "ogbg-moltox21"

    batch_size_training: 16
    batch_size_inference: 16

    # Featurization options; same keys as `featurization_args` in the Python cell.
    featurization:
      atom_property_list_float: []
      atom_property_list_onehot: ["atomic-number", "degree"]
      edge_property_list: ["ring", "bond-type-onehot"]
      add_self_loop: false
      use_bonds_weights: false
      explicit_H: false
```

In [2]:
%load_ext autoreload
%autoreload 2

import pathlib
import functools
import tempfile
import importlib

import numpy as np
import pandas as pd
import lightning
import torch
import datamol as dm

import graphium

Using backend: pytorch


In [3]:
# OGB graph-property-prediction datasets available in this demo;
# index 3 picks "ogbg-molfreesolv".
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
dataset_name = dataset_names[3]

# The featurized-data cache lives in a fresh temporary directory. This is for
# demo purposes only: point it at a known, persistent path in production.
cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl"

# Atom/bond featurization options handed to the datamodule.
featurization_args = {
    "atom_property_list_float": [],  # e.g. ["weight", "valence"]
    "atom_property_list_onehot": ["atomic-number", "degree"],
    "edge_property_list": ["bond-type-onehot"],
    "add_self_loop": False,
    "use_bonds_weights": False,
    "explicit_H": False,
}

# Keyword arguments for the datamodule constructor.
dm_args = {
    "dataset_name": dataset_name,
    "cache_data_path": cache_data_path,
    "featurization": featurization_args,
    "batch_size_training": 16,
    "batch_size_inference": 16,
    "num_workers": 0,
    "pin_memory": True,
    "featurization_n_jobs": 16,
    "featurization_progress": True,
}

ds = graphium.data.GraphOGBDataModule(**dm_args)
ds

2021-04-15 14:08:11.044 | INFO     | graphium.data.datamodule:_load_dataset:585 - Loading /home/hadim/.cache/graphium/ogb/freesolv/mapping/mol.csv.gz in memory.
2021-04-15 14:08:11.053 | INFO     | graphium.data.datamodule:_load_dataset:598 - Saving splits to /home/hadim/.cache/graphium/ogb/freesolv/split/scaffold.csv.gz


dataset_name: ogbg-molfreesolv
name: GraphOGBDataModule
len: 642
batch_size_training: 16
batch_size_inference: 16
num_node_feats: 50
num_edge_feats: 5
collate_fn: graphium_collate_fn
featurization:
  atom_property_list_float: []
  atom_property_list_onehot:
  - atomic-number
  - degree
  edge_property_list:
  - bond-type-onehot
  add_self_loop: false
  use_bonds_weights: false
  explicit_H: false

In [4]:
# Display the OGB dataset metadata exposed by the datamodule
# (task type, eval metric, split strategy, download URL, ...).
ds.metadata

{'num tasks': '1',
 'eval metric': 'rmse',
 'download_name': 'freesolv',
 'version': '1',
 'url': 'http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip',
 'add_inverse_edge': 'True',
 'data type': 'mol',
 'has_node_attr': 'True',
 'has_edge_attr': 'True',
 'task type': 'regression',
 'num classes': '-1',
 'split': 'scaffold',
 'additional node files': 'None',
 'additional edge files': 'None',
 'binary': 'False'}

In [5]:
# Load and prepare the dataset, then write the prepared data to
# `cache_data_path` (see the "Write prepared data to ..." log line below).
# NOTE(review): `setup()` presumably relies on this step — keep the order.
ds.prepare_data()

# Create the split torch datasets (e.g. `ds.train_ds`, used in the next cells).
ds.setup()

2021-04-15 14:08:12.006 | INFO     | graphium.data.datamodule:prepare_data:291 - Prepare dataset with 642 data points.


  0%|          | 0/642 [00:00<?, ?it/s]

2021-04-15 14:08:14.918 | INFO     | graphium.data.datamodule:prepare_data:326 - Write prepared data to /tmp/tmppuh1m6te/cache.pkl


In [6]:
# Inspect the first training sample: a dict holding the SMILES string, the
# molecule's identifier ('indices'), the featurized graph and its label array.
ds.train_ds[0]

{'smiles': 'CN(C)C(=O)c1ccc(cc1)OC',
 'indices': '4-methoxy-N,N-dimethyl-benzamide',
 'features': Graph(num_nodes=13, num_edges=26,
       ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32)}
       edata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}),
 'labels': array([-11.01])}

In [7]:
# Build the training dataloader, keep a handle on its iterator (`it`), and
# pull the first collated batch from it.
dl = ds.train_dataloader()
batch = next(it := iter(dl))
batch

{'smiles': ['CCCCO[N+](=O)[O-]',
  'CC(=O)OC',
  'CC(=O)Oc1ccccc1C(=O)O',
  'CCl',
  'CC(C)(C)c1ccc(cc1)O',
  'C(CBr)Br',
  'c1ccc(cc1)C(=O)N',
  'CCCCC[N+](=O)[O-]',
  'CCCCBr',
  'c1cc(c(cc1c2ccc(cc2F)F)C(=O)O)O',
  'c1ccc(cc1)C=O',
  'CCCc1ccc(c(c1)OC)O',
  'CC[C@@H](C)CO',
  'CCOc1ccccc1',
  'c1c(c(cc(c1Cl)Cl)Cl)Cl',
  'C(CO[N+](=O)[O-])CO[N+](=O)[O-]'],
 'indices': ['butyl nitrate',
  'methyl acetate',
  'acetylsalicylic acid',
  'chloromethane',
  '4-tert-butylphenol',
  '1,2-dibromoethane',
  'benzamide',
  '1-nitropentane',
  '1-bromobutane',
  'diflunisal',
  'benzaldehyde',
  '4-propylguaiacol',
  '2-methylbutan-1-ol',
  'ethoxybenzene',
  '1,2,4,5-tetrachlorobenzene',
  '3-nitrooxypropyl nitrate'],
 'features': Graph(num_nodes=139, num_edges=264,
       ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32)}
       edata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}),
 'labels': tensor([[ -2.0900],
         [ -3.1300],
         [ -9.9400],
         [ -0.