In [1]:
%load_ext autoreload
%autoreload 2

import pathlib
import functools
import tempfile

import numpy as np
import pytorch_lightning as pl
import torch
import datamol as dm

import goli

Using backend: pytorch


In [15]:
# Setup a temporary cache file. Only for
# demo purposes, use a known path in prod.
cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl"
cache_data_path = "/home/hadim/test-cache.pkl"

# Load a dataframe
df = goli.data.load_tiny_zinc()
df.head()

# Setup the featurization
featurization_args = {}
featurization_args["atom_property_list_onehot"] = ["atomic-number", "valence"]
featurization_args["atom_property_list_float"] = ["mass", "electronegativity", "in-ring"]
featurization_args["edge_property_list"] = ["bond-type-onehot", "stereo"]#, "in-ring"]
featurization_args["add_self_loop"] = False
featurization_args["use_bonds_weights"] = False
featurization_args["explicit_H"] = False

# Config for datamodule
dm_args = {}
dm_args["df"] = df
dm_args["cache_data_path"] = cache_data_path
dm_args["featurization"] = featurization_args
dm_args["smiles_col"] = "SMILES"
dm_args["label_cols"] = ["SA"]
dm_args["split_val"] = 0.2
dm_args["split_test"] = 0.2
dm_args["split_seed"] = 19
dm_args["batch_size_train_val"] = 16
dm_args["batch_size_test"] = 16
dm_args["num_workers"] = 0
dm_args["pin_memory"] = True
dm_args["featurization_n_jobs"] = 16
dm_args["featurization_progress"] = True

datam = goli.data.DGLFromSmilesDataModule(**dm_args)
# datam

In [5]:
# Load and prepare the data
datam.prepare_data()

# Create the split torch datasets
datam.setup()

2021-04-30 12:44:48.070 | INFO     | goli.data.datamodule:prepare_data:305 - Reload data from /home/hadim/test-cache.pkl.
2021-04-30 12:44:48.100 | INFO     | goli.data.datamodule:prepare_data:317 - Cache looks invalid with keys: dict_keys(['dataset', 'train_indices', 'val_indices', 'test_indices', 'featurization_args'])
2021-04-30 12:44:48.101 | INFO     | goli.data.datamodule:prepare_data:318 - Fallback to regular data preparation steps.
2021-04-30 12:44:48.101 | INFO     | goli.data.datamodule:prepare_data:327 - Prepare dataset with 100 data points.


  0%|          | 0/100 [00:00<?, ?it/s]

2021-04-30 12:44:48.192 | INFO     | goli.data.datamodule:_save_to_cache:447 - Write prepared data to /home/hadim/test-cache.pkl


In [16]:
import inspect
from goli.features import mol_to_dglgraph_signature

cache = torch.load(cache_data_path)

In [17]:
# What would be the signature with supplied feature arguments
current_signature = mol_to_dglgraph_signature(dict(featurization_args or {}))
cache_signature = mol_to_dglgraph_signature(cache["featurization_args"])

In [18]:
current_signature == cache_signature

False

In [7]:
cache["featurization_args"]

{'atom_property_list_onehot': ['atomic-number', 'valence'],
 'atom_property_list_float': ['mass', 'electronegativity', 'in-ring'],
 'edge_property_list': ['bond-type-onehot', 'stereo', 'in-ring'],
 'add_self_loop': False,
 'explicit_H': False,
 'use_bonds_weights': False,
 'pos_encoding_as_features': None,
 'pos_encoding_as_directions': None,
 'dtype': torch.float32}

In [57]:
# Get the signature of `mol_to_dglgraph`
signature = inspect.signature(mol_to_dglgraph)

# Filter out empty arguments (without default value)
parameters = list(filter(lambda param: param.default is not param.empty, signature.parameters.values()))

# Convert to dict
parameters = {param.name: param.default for param in parameters}

# Update the parameters with the supplied ones
parameters.update(featurization_args)

In [58]:
parameters

{'atom_property_list_onehot': ['atomic-number', 'valence'],
 'atom_property_list_float': ['mass', 'electronegativity', 'in-ring'],
 'edge_property_list': ['bond-type-onehot', 'stereo', 'in-ring'],
 'add_self_loop': False,
 'explicit_H': False,
 'use_bonds_weights': False,
 'pos_encoding_as_features': None,
 'pos_encoding_as_directions': None,
 'dtype': torch.float32}

In [42]:
args = dict(st.parameters)
args

{'mol': <Parameter "mol: rdkit.Chem.rdchem.Mol">,
 'atom_property_list_onehot': <Parameter "atom_property_list_onehot: List[str] = []">,
 'atom_property_list_float': <Parameter "atom_property_list_float: List[Union[str, Callable]] = []">,
 'edge_property_list': <Parameter "edge_property_list: List[str] = []">,
 'add_self_loop': <Parameter "add_self_loop: bool = False">,
 'explicit_H': <Parameter "explicit_H: bool = False">,
 'use_bonds_weights': <Parameter "use_bonds_weights: bool = False">,
 'pos_encoding_as_features': <Parameter "pos_encoding_as_features: Dict[str, Any] = None">,
 'pos_encoding_as_directions': <Parameter "pos_encoding_as_directions: Dict[str, Any] = None">,
 'dtype': <Parameter "dtype: torch.dtype = torch.float32">}

In [44]:
# args = {k: v.default for k, v in args.items()}
args = dict(filter(lambda x: x[1].default is not x[1].empty, args.items()))
args

{}

In [41]:
p = args["mol"]
type(p)

type

In [6]:
cache.keys()

dict_keys(['dataset', 'train_indices', 'val_indices', 'test_indices'])

In [7]:
# Load a dataloader and get the first batch from it
dl = dm.train_dataloader()
it = iter(dl)
batch = next(it)
batch

{'smiles': ['c1cc2c(cc1N[C@@H]1CCOC3(CCC3)C1)CCC2',
  'CC(=O)N1CC[C@@H]([NH2+][C@@H](C)CSCC(C)C)C1',
  'Cc1ccc(-c2nc3ccc(C)c(C)c3[nH]2)nc1',
  'CCOC[C@H](O)[C@](C)(CC)[NH+]1CCCC1',
  'Cc1cc(CC(=O)N[C@H](c2ccc(F)cc2)C2CCC2)no1',
  'CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1',
  'Cc1ccc(NC(=O)c2ccc(F)cc2F)cc1S(=O)(=O)Nc1ccc(Cl)cc1',
  'CC#CCCC(=O)Nc1cccc2c1C(=O)c1ccccc1C2=O',
  'CCOc1cc(/C=C(\\C#N)C(=O)c2c[nH]c3cc(Cl)ccc23)ccc1OC',
  'CNC(=O)[C@@H]1CCC[NH+]1Cc1ccc(C)c(F)c1',
  'CC(C)(C)OC(=O)N[C@H]1CCN(c2cc(-c3cccs3)n[nH]2)C1',
  'Cc1nn(-c2ccccc2)c(O)c1/C=[NH+]/Cc1ccncc1',
  'COCC[NH+](C)Cc1c(C)cc(C)c(C(C)=O)c1C',
  'Cc1nc(C)c(S(=O)(=O)/N=C(\\[O-])C[C@H]2CCCO2)s1',
  'COc1cccc(CN2CCC[NH+](CC(=O)Nc3ccc(F)cc3)S2(=O)=O)c1',
  'COc1cc(F)cc(CNC(=O)[C@H]2CCCN2C(=O)Cc2ccccc2)c1'],
 'features': Graph(num_nodes=352, num_edges=754,
       ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32)}
       edata_schemes={'feat': Scheme(shape=(6,), dtype=torch.float32)}),
 'labels': tensor

---

## Launch a training

In `goli.cli.train` I have added a `click` CLI command that take a config file as input and build the datamodule. Once I am done with the PL module I will complete the command.

The way to use it is quite simple, you need to install goli with `pip install -e .` (omit`-e` in prod) and then:

```bash
goli train -c my_config.yaml
```

It's not here for now but usually I "augment" the CLI command with various config key you might want to set without having to modify the config file itself:

```bash
goli train -c my_config.yaml --training-path /home/hadim/data/goli/runs/exp_1
```

Later the same strategy could be done to launch an hparams tuning run (with a config file as above + a config file defning the search space).

```bash
goli tune -c my_config.yaml --tune-config tuning_space.yaml
```