In [1]:
%load_ext autoreload
%autoreload 2

import pathlib
import functools
import tempfile

import numpy as np
import pytorch_lightning as pl
import torch
import datamol as dm
import pandas as pd

import goli

Using backend: pytorch


In [25]:
config = goli.load_config(name="zinc_default_fulldgl")

predictor_args = dict(config.predictor)
predictor_args["model_config"] = config.architecture
goli.trainer.Predictor(**predictor_args)

In [18]:
# Load a default config
config = goli.load_config(name="zinc_default_fulldgl")

# Setup a temporary cache file. Only for
# demo purposes, use a known path in prod.
cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl"
cache_data_path = None

# Load a dataframe
df = goli.data.load_tiny_zinc()

# Optionally load the splits
splits_path = "../expts/data/tiny_zinc_splits.csv"

In [3]:
# Config for datamodule
dm_args = dict(config.data.args)
dm_args["df"] = df
dm_args["cache_data_path"] = cache_data_path
dm_args["featurization_n_jobs"] = -1
dm_args["featurization_progress"] = True
dm_args["splits_path"] = splits_path

dm = goli.data.DGLFromSmilesDataModule(**dm_args)
dm.prepare_data()
dm.setup()
dm

2021-03-17 11:20:40.213 | INFO     | goli.data.datamodule:prepare_data:178 - Prepare dataset with 100 data points.


  0%|          | 0/100 [00:00<?, ?it/s]

name: DGLFromSmilesDataModule
len: 100
batch_size_train_val: 16
batch_size_test: 16
num_node_feats: 50
num_edge_feats: 6
collate_fn: goli_collate_fn
featurization:
  atom_property_list_float: []
  atom_property_list_onehot:
  - atomic-number
  - degree
  edge_property_list:
  - ring
  - bond-type-onehot
  add_self_loop: false
  use_bonds_weights: false
  explicit_H: false

In [17]:
# Load a dataloader and get the first batch from it
dl = dm.train_dataloader()
it = iter(dl)
batch = next(it)
batch.keys()

dict_keys(['smiles', 'features', 'labels'])

In [20]:
in_dim_nodes = dm.num_node_feats
in_dim_edges = dm.num_edge_feats

# Modify the model config with `config.architecture.`
config.architecture.gnn.depth = 2

# Build the model
model = goli.config.load_architecture(config, in_dim_nodes, in_dim_edges)
model

DGL_GNN
---------
    pre-trans-NN(depth=1, ResidualConnectionNone)
        [FCLayer[50 -> 32] -> Linear(32)
    
    main-GNN(depth=2, ResidualConnectionSimple(skip_steps=1))
        PNAMessagePassingLayer[32 -> 32 -> 32]
        -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)
    

In [24]:
config

{'data': {'module_type': 'DGLFromSmilesDataModule', 'args': {'smiles_col': 'SMILES', 'label_cols': ['SA'], 'split_val': 0.2, 'split_test': 0.2, 'split_seed': 19, 'batch_size_train_val': 16, 'batch_size_test': 16, 'featurization': {'atom_property_list_float': [], 'atom_property_list_onehot': ['atomic-number', 'degree'], 'edge_property_list': ['ring', 'bond-type-onehot'], 'add_self_loop': False, 'use_bonds_weights': False, 'explicit_H': False}}}, 'architecture': {'model_type': 'fulldglnetwork', 'pre_nn': {'out_dim': 32, 'hidden_dims': 32, 'depth': 1, 'activation': 'relu', 'last_activation': 'none', 'dropout': 0.1, 'batch_norm': True, 'residual_type': 'none'}, 'post_nn': {'out_dim': 32, 'hidden_dims': 32, 'depth': 2, 'activation': 'relu', 'last_activation': 'none', 'dropout': 0.1, 'batch_norm': True, 'residual_type': 'none'}, 'gnn': {'out_dim': 32, 'hidden_dims': 32, 'depth': 2, 'activation': 'relu', 'last_activation': 'none', 'dropout': 0.1, 'batch_norm': True, 'residual_type': 'simple',

In [23]:
model

main-GNN(depth=2, ResidualConnectionSimple(skip_steps=1))
    PNAMessagePassingLayer[32 -> 32 -> 32]
    -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)

In [22]:
predictor_args = {}
predictor_args["model"] = model

predictor = goli.trainer.PredictorModule(**predictor_args)


#         model=model,
#         dataset=train_dt,
#         validation_split=val_dt,
#         collate_fn=DGLCollate(device=device, siamese=siamese),
#         metrics=metrics,
#         metrics_on_progress_bar=metrics_on_progress_bar,
#         additional_hparams={"layer_fullname": layer_name},
#         **reg_kwargs,
#     )

AttributeError: module 'goli.trainer' has no attribute 'PredictorModule'