In [1]:
%load_ext autoreload
%autoreload 2

import pathlib
import functools
import tempfile

import numpy as np
import pytorch_lightning as pl
import torch
import datamol as dm

import goli

Using backend: pytorch


In [2]:
# Demo only: place the cache file inside a freshly created temporary
# directory. In production, point this at a known, stable path instead.
cache_data_path = pathlib.Path(tempfile.mkdtemp()).joinpath("cache.pkl")

# Load the tiny ZINC dataframe shipped through goli and preview its first rows
df = goli.data.load_tiny_zinc()
df.head()

Unnamed: 0,SMILES,SA,logp,score
0,CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1,2.775189,3.7897,6.564889
1,C=CCOc1ccc(C(F)(F)F)cc1C(=O)NC[C@H]1CCC[C@H]1O,3.071498,3.161,6.232498
2,COc1cc(OC)cc([C@@H](NC(=O)Cn2ccccc2=O)c2nccn2C)c1,2.840009,1.5048,4.344809
3,CN1CCC[C@@]2(CCN(C(=O)CN3CCNC(=O)C3)C2)C1=O,3.605168,-1.1109,2.494268
4,COc1ccc(Cl)cc1NC(=O)c1nn(C)cc1[N+](=O)[O-],2.132665,2.2426,4.375265


In [7]:
# Apart from the dataframe itself, every value below is a primitive type,
# so this configuration can easily be embedded in a YAML config file.

# Featurization options
featurization_args = {
    "atom_property_list_float": [],  # e.g. ["weight", "valence"]
    "atom_property_list_onehot": ["atomic-number", "degree"],
    "edge_property_list": ["ring", "bond-type-onehot"],
    "add_self_loop": False,
    "use_bonds_weights": False,
    "explicit_H": False,
}

# Datamodule options
dm_args = {
    "df": df,
    # Unused at the moment; the alternative value is `cache_data_path`.
    "cache_data_path": None,
    "featurization": featurization_args,
    "smiles_col": "SMILES",
    "label_cols": ["SA"],
    "split_val": 0.2,
    "split_test": 0.2,
    "split_seed": 19,
    "train_val_batch_size": 16,
    "test_batch_size": 16,
    "num_workers": 0,
    "pin_memory": True,
    "featurization_n_jobs": 16,
    "featurization_progress": True,
}

# NOTE(review): `dm` rebinds the alias created by `import datamol as dm`
# in the import cell, so datamol is no longer reachable as `dm` after this.
dm = goli.data.DGLFromSmilesDataModule(**dm_args)
dm

[autoreload of goli.data.datamodule failed: Traceback (most recent call last):
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/hadim/local/conda/envs/goli/lib/python3.8/site-packages/IPython/extensions/autoreload.py",

<goli.data.datamodule.DGLFromSmilesDataModule at 0x7f153e8d95b0>

In [4]:
# Load and prepare the data (presumably where the SMILES are featurized
# into graphs -- TODO confirm in goli.data.datamodule).
dm.prepare_data()

# Create the split torch datasets (presumably driven by the `split_val`,
# `split_test` and `split_seed` values passed to the datamodule).
dm.setup()

2021-03-16 17:42:12.764 | INFO     | goli.data.datamodule:prepare_data:171 - Prepare dataset with 100 data points.


  0%|          | 0/100 [00:00<?, ?it/s]

2021-03-16 17:42:15.590 | INFO     | goli.data.datamodule:prepare_data:203 - Write prepared data to /tmp/tmphnijy1ze/cache.pkl


In [5]:
# Get the training dataloader and pull the first batch from it.
dl = dm.train_dataloader()
it = iter(dl)  # keep the iterator around so more batches can be drawn with next(it)
batch = next(it)
# Display the batch: in the captured output it is a dict whose visible keys
# include 'smiles' (a list of SMILES strings) and 'features' (a Graph object).
batch

{'smiles': ['CCN1C(=O)C(C#N)=C(C)/C(=C\\Nc2ccc([N+](=O)[O-])cc2)C1=O',
  'COc1ccc(-c2noc3ncnc(N4CCC[C@H](C(=O)[O-])C4)c23)cc1',
  'Cc1ccc(C(=O)N2CCC[C@@H](C(=O)N(C)CC(=O)NC(C)C)C2)cc1',
  'O=C(Nc1ccnn1Cc1cccc(Cl)c1)C1CCN(S(=O)(=O)c2ccccc2F)CC1',
  'CCc1onc(C)c1NC(=O)CCCC(C)(C)C',
  'C=CCSc1nnc(C2CCOCC2)n1N',
  'O=c1c2c3nc4ccccc4nc3n(CCC3=CCCCC3)c2ncn1C[C@H]1CCCO1',
  'CC(=O)c1c(C)[nH]c(C(=O)OCC(=O)N2C[C@H](C)C[C@@H](C)C2)c1C',
  'CCOc1ccc(C[NH+]2CCS[C@H]3COCC[C@@H]32)cc1OC',
  'CN1C[C@H](C(=O)NC[C@H]2CCC(C)(C)c3ccccc32)CC1=O',
  'Cc1cc(C(=O)N2CCN(C(=O)N[C@H]3CC(=O)N(C4CC4)C3)CC2)c(C)o1',
  'COc1ccc2c(c1)[C@H]([NH2+][C@H](C)CCN1CCOCC1)CCCO2',
  'C[NH2+]Cc1ccc(-c2cc(C)ccc2F)s1',
  'COC(=O)[C@H](C)N(Cc1ccccc1)C(=O)Cc1ccon1',
  'Cc1ccsc1C(=O)NCCNC(=O)c1ccco1',
  'COc1cccc(OCC(=O)N2CCN(c3nc4ccc(S(C)(=O)=O)cc4s3)CC2)c1'],
 'features': Graph(num_nodes=382, num_edges=824,
       ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32)}
       edata_schemes={'feat': Scheme(shape=(6,), dt