```yaml
# Example YAML configuration for the `DGLOGBDataModule` used in the Python cells below;
# the keys under `args` mirror the `dm_args` / `featurization_args` dicts built there.
# NOTE(review): this example selects "ogbg-moltox21" and adds "ring" to
# edge_property_list, whereas the Python cells use "ogbg-molfreesolv" with only
# "bond-type-onehot" — confirm which configuration is intended.
data:
  module_type: "DGLOGBDataModule"
  args:
    # presumably `null` means no on-disk cache is used — TODO confirm
    cache_data_path: null

    dataset_name: "ogbg-moltox21"

    batch_size_train_val: 16
    batch_size_test: 16

    featurization:
      atom_property_list_float: []
      atom_property_list_onehot: ["atomic-number", "degree"]
      edge_property_list: ["ring", "bond-type-onehot"]
      add_self_loop: false
      use_bonds_weights: false
      explicit_H: false
```

In [1]:
%load_ext autoreload
%autoreload 2

import pathlib
import functools
import tempfile
import importlib

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import datamol as dm

import goli

Using backend: pytorch


In [5]:
# Pick the OGB dataset to exercise; index 3 selects "ogbg-molfreesolv".
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
dataset_name = dataset_names[3]

# Setup a temporary cache file. Only for
# demo purposes, use a known path in prod.
cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl"

# Setup the featurization
featurization_args = {}
featurization_args["atom_property_list_float"] = []  # ["weight", "valence"]
featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"]
featurization_args["edge_property_list"] = ["bond-type-onehot"]
featurization_args["add_self_loop"] = False
featurization_args["use_bonds_weights"] = False
featurization_args["explicit_H"] = False

# Config for datamodule
dm_args = {}
dm_args["dataset_name"] = dataset_name
dm_args["cache_data_path"] = cache_data_path
dm_args["featurization"] = featurization_args
dm_args["batch_size_train_val"] = 16
dm_args["batch_size_test"] = 16
dm_args["num_workers"] = 0
dm_args["pin_memory"] = True
dm_args["featurization_n_jobs"] = 16
dm_args["featurization_progress"] = True

ds = goli.data.DGLOGBDataModule(**dm_args)

# test metadata: the OGB record must expose exactly these keys.
expected_metadata_keys = {
    "num tasks",
    "eval metric",
    "download_name",
    "version",
    "url",
    "add_inverse_edge",
    "data type",
    "has_node_attr",
    "has_edge_attr",
    "task type",
    "num classes",
    "split",
    "additional node files",
    "additional edge files",
    "binary",
}
metadata_keys = set(ds.metadata.keys())
# On failure, report the keys that are missing or unexpected.
assert metadata_keys == expected_metadata_keys, metadata_keys ^ expected_metadata_keys

ds.prepare_data()
ds.setup()

# test module: feature sizes observed for the featurization configured above.
assert ds.num_edge_feats == 5, ds.num_edge_feats
assert ds.num_node_feats == 50, ds.num_node_feats
# 642 is the number of molecules in ogbg-molfreesolv; update it if `dataset_name` changes.
assert len(ds) == 642, len(ds)
# Compare against the configured name rather than repeating the literal.
assert ds.dataset_name == dataset_name, ds.dataset_name

# test dataset: every sample is a dict with these keys.
assert set(ds.train_ds[0].keys()) == {"smiles", "indices", "features", "labels"}

# test batch loader: the first training batch holds `batch_size_train_val` entries
# per field, so the check follows the config instead of a hard-coded 16.
batch_size = dm_args["batch_size_train_val"]
batch = next(iter(ds.train_dataloader()))
for key in ("smiles", "labels", "indices"):
    assert len(batch[key]) == batch_size, (key, len(batch[key]))

2021-04-15 13:59:24.451 | INFO     | goli.data.datamodule:_load_dataset:585 - Loading /home/hadim/.cache/goli/ogb/freesolv/mapping/mol.csv.gz in memory.
2021-04-15 13:59:24.456 | INFO     | goli.data.datamodule:_load_dataset:598 - Saving splits to /home/hadim/.cache/goli/ogb/freesolv/split/scaffold.csv.gz


dataset_name: ogbg-molfreesolv
name: DGLOGBDataModule
len: 642
batch_size_train_val: 16
batch_size_test: 16
num_node_feats: 50
num_edge_feats: 5
collate_fn: goli_collate_fn
featurization:
  atom_property_list_float: []
  atom_property_list_onehot:
  - atomic-number
  - degree
  edge_property_list:
  - bond-type-onehot
  add_self_loop: false
  use_bonds_weights: false
  explicit_H: false

In [16]:
# Re-run the checks on the existing datamodule. Calling prepare_data() a second
# time reloads from `cache_data_path` (see the "Reload data from ..." log below),
# so this cell also exercises the cache-reload path.

# test metadata: the OGB record must expose exactly these keys.
expected_metadata_keys = {
    "num tasks",
    "eval metric",
    "download_name",
    "version",
    "url",
    "add_inverse_edge",
    "data type",
    "has_node_attr",
    "has_edge_attr",
    "task type",
    "num classes",
    "split",
    "additional node files",
    "additional edge files",
    "binary",
}
metadata_keys = set(ds.metadata.keys())
# On failure, report the keys that are missing or unexpected.
assert metadata_keys == expected_metadata_keys, metadata_keys ^ expected_metadata_keys

ds.prepare_data()
ds.setup()

# test module: feature sizes observed for the configured featurization.
assert ds.num_edge_feats == 5, ds.num_edge_feats
assert ds.num_node_feats == 50, ds.num_node_feats
# 642 is the number of molecules in ogbg-molfreesolv; update it if `dataset_name` changes.
assert len(ds) == 642, len(ds)
# Compare against the configured name rather than repeating the literal.
assert ds.dataset_name == dataset_name, ds.dataset_name

# test dataset: every sample is a dict with these keys.
assert set(ds.train_ds[0].keys()) == {"smiles", "indices", "features", "labels"}

# test batch loader: the first training batch holds `batch_size_train_val` entries.
batch_size = dm_args["batch_size_train_val"]
batch = next(iter(ds.train_dataloader()))
for key in ("smiles", "labels", "indices"):
    assert len(batch[key]) == batch_size, (key, len(batch[key]))

2021-04-15 14:01:02.872 | INFO     | goli.data.datamodule:prepare_data:270 - Reload data from /tmp/tmp0jwkbcco/cache.pkl.


In [29]:
# Graph object for the whole training batch; its repr (a DGL Graph with 'feat'
# node/edge data) is shown further down.
g = batch["features"]

In [32]:
# Display the batch graph's repr (node count, edge count, feature schemes).
# Must be a complete expression: a dangling `g.` is a SyntaxError that stops Run All.
g

Graph(num_nodes=132, num_edges=248,
      ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)})

In [2]:
# Choose which OGB molecular dataset to load (index 3 -> "ogbg-molfreesolv").
dataset_names = [
    "ogbg-molhiv",
    "ogbg-molpcba",
    "ogbg-moltox21",
    "ogbg-molfreesolv",
]
dataset_name = dataset_names[3]

# Temporary cache file for the featurized data.
# Demo only — use a known, fixed path in production.
cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl"

# Node/edge featurization options forwarded to the datamodule.
featurization_args = dict(
    atom_property_list_float=[],  # e.g. ["weight", "valence"]
    atom_property_list_onehot=["atomic-number", "degree"],
    edge_property_list=["bond-type-onehot"],
    add_self_loop=False,
    use_bonds_weights=False,
    explicit_H=False,
)

# Keyword arguments for the datamodule constructor.
dm_args = dict(
    dataset_name=dataset_name,
    cache_data_path=cache_data_path,
    featurization=featurization_args,
    batch_size_train_val=16,
    batch_size_test=16,
    num_workers=0,
    pin_memory=True,
    featurization_n_jobs=16,
    featurization_progress=True,
)

ds = goli.data.DGLOGBDataModule(**dm_args)
ds

2021-04-15 13:58:25.705 | INFO     | goli.data.datamodule:_load_dataset:585 - Loading /home/hadim/.cache/goli/ogb/freesolv/mapping/mol.csv.gz in memory.
2021-04-15 13:58:25.710 | INFO     | goli.data.datamodule:_load_dataset:598 - Saving splits to /home/hadim/.cache/goli/ogb/freesolv/split/scaffold.csv.gz


dataset_name: ogbg-molfreesolv
name: DGLOGBDataModule
len: 642
batch_size_train_val: 16
batch_size_test: 16
num_node_feats: 50
num_edge_feats: 5
collate_fn: goli_collate_fn
featurization:
  atom_property_list_float: []
  atom_property_list_onehot:
  - atomic-number
  - degree
  edge_property_list:
  - bond-type-onehot
  add_self_loop: false
  use_bonds_weights: false
  explicit_H: false

In [3]:
# Inspect the OGB metadata record for the selected dataset.
ds.metadata

{'num tasks': '1',
 'eval metric': 'rmse',
 'download_name': 'freesolv',
 'version': '1',
 'url': 'http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip',
 'add_inverse_edge': 'True',
 'data type': 'mol',
 'has_node_attr': 'True',
 'has_edge_attr': 'True',
 'task type': 'regression',
 'num classes': '-1',
 'split': 'scaffold',
 'additional node files': 'None',
 'additional edge files': 'None',
 'binary': 'False'}

In [4]:
# Load and featurize the data. The first run writes the prepared data to
# `cache_data_path`; later runs on the same path reload it instead.
ds.prepare_data()

# Create the split torch datasets (e.g. `ds.train_ds`) from the prepared data.
ds.setup()

2021-04-15 13:58:27.338 | INFO     | goli.data.datamodule:prepare_data:291 - Prepare dataset with 642 data points.


  0%|          | 0/642 [00:00<?, ?it/s]

2021-04-15 13:58:30.234 | INFO     | goli.data.datamodule:prepare_data:326 - Write prepared data to /tmp/tmp0l7sfhr5/cache.pkl


In [None]:
# Peek at one training sample: a dict with 'smiles', 'indices', 'features', 'labels' keys.
ds.train_ds[0]

In [None]:
# Load a dataloader and get the first batch from it, for inspection.
# Note: this rebinds the notebook-level `batch` used by earlier cells.
dl = ds.train_dataloader()
it = iter(dl)
batch = next(it)
batch

In [16]:
# Tabulate OGB metadata across datasets (one row per dataset, as in the output below).
# NOTE(review): calls a private method (leading underscore); prefer a public
# accessor if the datamodule exposes one — TODO confirm.
ds._get_ogb_metadata()

Unnamed: 0,num tasks,eval metric,download_name,version,url,add_inverse_edge,data type,has_node_attr,has_edge_attr,task type,num classes,split,additional node files,additional edge files,binary
ogbg-molbace,1,rocauc,bace,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molbbbp,1,rocauc,bbbp,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molclintox,2,rocauc,clintox,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molmuv,17,ap,muv,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molpcba,128,ap,pcba,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molsider,27,rocauc,sider,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-moltox21,12,rocauc,tox21,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-moltoxcast,617,rocauc,toxcast,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molhiv,1,rocauc,hiv,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,binary classification,2,scaffold,,,False
ogbg-molesol,1,rmse,esol,1,http://snap.stanford.edu/ogb/data/graphproppre...,True,mol,True,True,regression,-1,scaffold,,,False
