In [1]:
%load_ext autoreload
%autoreload 2

# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig, OmegaConf


# Current project imports
import goli
from goli.utils.config_loader import (
    config_load_constants,
    config_load_dataset,
    config_load_architecture,
    config_load_metrics,
    config_load_predictor,
    config_load_training,
)


Using backend: pytorch


## Read the config file

In [2]:
# Resolve the directory two levels above `goli/__init__.py` and make it the
# working directory.
# NOTE(review): presumably this is the repository root and the chdir lets relative
# paths inside the config resolve from there -- confirm.
MAIN_DIR = dirname(dirname(abspath(goli.__file__)))
os.chdir(MAIN_DIR)

# Read the experiment configuration. The encoding is given explicitly so that the
# file is decoded the same way on every platform instead of following the locale.
CONFIG_PATH = os.path.join(MAIN_DIR, "expts/config_micro_ZINC.yaml")
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# `yaml.safe_load` yields None for an empty document; fail here with a readable
# message rather than with an opaque error on the first `cfg[...]` lookup.
# No extra copy is made: the loaded object is freshly built and owned by this cell.
if not isinstance(cfg, dict):
    raise TypeError(
        f"Expected a mapping at the top level of {CONFIG_PATH!r}, got {type(cfg).__name__}"
    )

# Unpack the six values returned by `config_load_constants` for the "constants"
# section of the config. The datasets themselves are not created in this cell.
data_device, model_device, dtype, exp_name, seed, raise_train_error = config_load_constants(
    **cfg["constants"], main_dir=MAIN_DIR
)

  return torch._C._cuda_getDeviceCount() > 0


## Load a dataset

In [3]:

# Build the data module from the "datasets" section of the config, then show
# its summary.
dataset_kwargs = cfg["datasets"]
datamodule = config_load_dataset(**dataset_kwargs, main_dir=MAIN_DIR)
print("\ndatamodule:\n", datamodule, "\n")



datamodule:
 name: DGLFromSmilesDataModule
len: 1000
batch_size_train_val: 128
batch_size_test: 256
num_node_feats: 55
num_edge_feats: 13
collate_fn: goli_collate_fn
featurization:
  atom_property_list_onehot:
  - atomic-number
  - valence
  atom_property_list_float:
  - mass
  - electronegativity
  - in-ring
  edge_property_list:
  - bond-type-onehot
  - stereo
  - in-ring
  add_self_loop: false
  explicit_H: false
  use_bonds_weights: false
 



In [4]:
# Build the network from the "architecture" section of the config. The node and
# edge input sizes are read from the data module rather than hard-coded.
architecture_kwargs = cfg["architecture"]
node_feat_dim = datamodule.num_node_feats
edge_feat_dim = datamodule.num_edge_feats
model = config_load_architecture(
    **architecture_kwargs, in_dim_nodes=node_feat_dim, in_dim_edges=edge_feat_dim
)
print("\nmodel:\n", model, "\n")


model:
 DGL_GNN
---------
    pre-trans-NN(depth=1, ResidualConnectionNone)
        [FCLayer[55 -> 32] -> Linear(32)
    
    main-GNN(depth=4, ResidualConnectionSimple(skip_steps=1))
        PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]
        -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)
    
    post-trans-NN(depth=2, ResidualConnectionNone)
        [FCLayer[32 -> 32 -> 32] -> Linear(32) 



In [5]:
# Build the metrics from the "metrics" section of the config and print the
# resulting object.
metrics_cfg = cfg["metrics"]
metrics = config_load_metrics(metrics_cfg)
print(metrics)

{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 5': f1(>5), 'precision > 5': precision(>5), 'auroc > 5': auroc(>5)}
