# Building and training a simple model from configurations

This tutorial walks you through using a configuration file to define all the parameters of a model and of its trainer. It focuses on training from SMILES data stored in a CSV file.

There are multiple example YAML files in the folder `goli/expts` that you can refer to when training a new model. The file `config_ZINC_bench_gnn.yaml` shows single-task regression from a CSV file provided by goli, while `config_molpcba.yaml` shows multi-task classification on an OGB dataset with some missing labels.

## Creating the yaml file

The first step is to create a YAML file containing all the required configurations, with an example given at `goli/expts/config_micro_ZINC.yaml`. We will go through each part of the configurations.

In [1]:
import yaml
import omegaconf

In [2]:
def print_config_with_key(config, key):
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))

In [3]:
# First, let's read the yaml configuration file
with open("../../../expts/config_micro_ZINC.yaml", "r") as file:
    yaml_config = yaml.load(file, Loader=yaml.FullLoader)

print("Yaml file loaded")

Yaml file loaded
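Before building anything, it can help to confirm that the loaded dictionary contains every top-level section walked through below. A minimal sketch, using a hypothetical helper `missing_sections` and a small hand-written dict standing in for `yaml_config`:

```python
# Top-level sections covered in this tutorial
EXPECTED_SECTIONS = ("constants", "datamodule", "architecture", "predictor", "metrics", "trainer")


def missing_sections(config):
    """Hypothetical helper: return the expected top-level keys absent from `config`."""
    return [key for key in EXPECTED_SECTIONS if key not in config]


# Stand-in for the loaded YAML; a real config holds full sub-dictionaries
example_config = {"constants": {"seed": 42, "raise_train_error": True}, "trainer": {}}
print(missing_sections(example_config))  # ['datamodule', 'architecture', 'predictor', 'metrics']
```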


### Constants

First, we define the constants, such as the random seed and whether errors occurring during training should be raised or ignored (`raise_train_error`).

In [4]:
print_config_with_key(yaml_config, "constants")

constants:
  seed: 42
  raise_train_error: true
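The `seed` is what makes runs reproducible. As a stdlib-only illustration (goli and PyTorch seed their own random generators, which this sketch does not touch), re-seeding Python's `random` module replays exactly the same sequence:

```python
import random


def draws_with_seed(seed, n=5):
    """Return n pseudo-random floats after seeding Python's global RNG."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draws_with_seed(42)
second = draws_with_seed(42)
print(first == second)  # True: same seed, same sequence
```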



### Datamodule

Here, we define all the parameters the datamodule needs, such as the dataset path, where to cache the processed data, the SMILES and label columns, the molecular featurization, the train/val/test splits and the batch size.

For more details, see the class `goli.data.datamodule.DGLFromSmilesDataModule`.

In [5]:
print_config_with_key(yaml_config, "datamodule")

datamodule:
  module_type: DGLFromSmilesDataModule
  args:
    df_path: goli/data/micro_ZINC/micro_ZINC.csv
    cache_data_path: goli/data/cache/micro_ZINC/full.cache
    label_cols:
    - score
    smiles_col: SMILES
    featurization_n_jobs: -1
    featurization_progress: true
    featurization:
      atom_property_list_onehot:
      - atomic-number
      - valence
      atom_property_list_float:
      - mass
      - electronegativity
      - in-ring
      edge_property_list:
      - bond-type-onehot
      - stereo
      - in-ring
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false
      pos_encoding_as_features:
        pos_type: laplacian_eigvec
        num_pos: 3
        normalization: none
        disconnected_comp: true
      pos_encoding_as_directions:
        pos_type: laplacian_eigvec
        num_pos: 3
        normalization: none
        disconnected_comp: true
    split_val: 0.2
    split_test: 0.2
    split_seed: 42
    splits_path: null
    b
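With `split_val: 0.2` and `split_test: 0.2`, 60% of the molecules are left for training, and `split_seed` fixes which molecules land in each split. The sketch below shows one way such a seeded random split can work on 1000 molecules (the size of micro_ZINC reported by the datamodule summary in the training section). It uses a hypothetical `split_indices` helper and is not goli's own splitting code:

```python
import random


def split_indices(n, split_val, split_test, seed):
    """Hypothetical helper: shuffle range(n) with a fixed seed, then cut it into train/val/test."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG, so the global random state is untouched
    n_val = int(n * split_val)
    n_test = int(n * split_test)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test


train, val, test = split_indices(1000, split_val=0.2, split_test=0.2, seed=42)
print(len(train), len(val), len(test))  # 600 200 200
```

Because the shuffle depends only on `seed`, calling the helper again with the same arguments returns identical splits.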

### Architecture

In the architecture, we define all the layers of the model: the pre-processing MLP for the node features (input layers, `pre_nn`), the pre-processing MLP for the edge features (`pre_nn_edges`), the main GNN (graph neural network, `gnn`), and the post-processing MLP (output layers, `post_nn`).

The parameters let you choose the feature sizes, the depth, the skip connections (`residual_type`), the pooling and the virtual node. Several GNN layer types are supported, such as `gcn`, `gin`, `gat`, `gated-gcn`, `pna-conv` and `pna-msgpass`; this example uses `dgn-msgpass`.

For more details, see the following classes:

-  `goli.nn.architecture.FullDGLNetwork`: Main class for the architecture
-  `goli.nn.architecture.FeedForwardNN`: Main class for the input and output MLPs
-  `goli.nn.architecture.FeedForwardDGL`: Main class for the GNN layers
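The `pooling` option turns per-node embeddings into a single graph embedding. Here is a plain-Python sketch of `sum` and `max` readouts, assuming the pooled vectors are concatenated into one graph-level vector; the direction-based `dir1` pooling from the config is not reproduced:

```python
def sum_pool(node_feats):
    """Sum each feature dimension over all nodes of one graph."""
    return [sum(column) for column in zip(*node_feats)]


def max_pool(node_feats):
    """Take the maximum of each feature dimension over all nodes."""
    return [max(column) for column in zip(*node_feats)]


def readout(node_feats, poolings=("sum", "max")):
    """Concatenate the requested pooled vectors into one graph-level vector."""
    pool_fns = {"sum": sum_pool, "max": max_pool}
    graph_vec = []
    for name in poolings:
        graph_vec.extend(pool_fns[name](node_feats))
    return graph_vec


# Three nodes with 2 features each
nodes = [[1.0, -2.0], [0.5, 4.0], [2.5, 0.0]]
print(readout(nodes))  # [4.0, 2.0, 2.5, 4.0]
```

Under this assumption, the graph vector grows with the number of pooling types: two poolings of 2-dimensional node features give a 4-dimensional output.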

In [6]:
print_config_with_key(yaml_config, "architecture")

architecture:
  model_type: fulldglnetwork
  pre_nn:
    out_dim: 32
    hidden_dims: 32
    depth: 1
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: batch_norm
    last_normalization: batch_norm
    residual_type: none
  pre_nn_edges:
    out_dim: 16
    hidden_dims: 16
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: batch_norm
    last_normalization: batch_norm
    residual_type: none
  gnn:
    out_dim: 32
    hidden_dims: 32
    depth: 4
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: batch_norm
    last_normalization: batch_norm
    residual_type: simple
    pooling:
    - sum
    - max
    - dir1
    virtual_node: sum
    layer_type: dgn-msgpass
    layer_kwargs:
      aggregators:
      - mean
      - max
      - dir1/dx_abs
      - dir1/smooth
      scalers:
      - identity
      - amplification
      - attenuation
  post_nn:
    out_dim: 1
    hidden_dims: 32
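Each MLP block (`pre_nn`, `pre_nn_edges`, `gnn`, `post_nn`) is described by an input size, `hidden_dims`, `depth` and `out_dim`. Assuming an integer `hidden_dims` is reused for every hidden layer and `depth` counts the layers, a hypothetical helper can list the (input, output) size of each layer. The input sizes 58 and 13 are the node and edge feature counts printed by the datamodule in the training section:

```python
def layer_dims(in_dim, hidden_dims, out_dim, depth):
    """Hypothetical helper: (in, out) sizes of each layer in a `depth`-layer MLP."""
    sizes = [in_dim] + [hidden_dims] * (depth - 1) + [out_dim]
    return list(zip(sizes[:-1], sizes[1:]))


# pre_nn: 58 node features in, depth 1
print(layer_dims(58, 32, 32, depth=1))  # [(58, 32)]
# pre_nn_edges: 13 edge features in, depth 2
print(layer_dims(13, 16, 16, depth=2))  # [(13, 16), (16, 16)]
```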


### Predictor

In the predictor, we define the loss function, the metrics to display on the progress bar, and the parameters of the optimizer and of the learning-rate scheduler.

In [7]:
print_config_with_key(yaml_config, "predictor")

predictor:
  metrics_on_progress_bar:
  - mae
  - pearsonr
  - f1 > 3
  - precision > 3
  loss_fun: mse
  random_seed: 42
  optim_kwargs:
    lr: 0.01
    weight_decay: 1.0e-07
  lr_reduce_on_plateau_kwargs:
    factor: 0.5
    patience: 7
  scheduler_kwargs:
    monitor: loss/val
    frequency: 1
  target_nan_mask: 0
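With `lr_reduce_on_plateau_kwargs`, the learning rate is multiplied by `factor: 0.5` once the monitored `loss/val` has not improved for more than `patience: 7` epochs. The following is a simplified sketch of that rule, loosely modelled on PyTorch's `ReduceLROnPlateau` but without its threshold, cooldown and minimum-learning-rate options:

```python
class PlateauScheduler:
    """Simplified reduce-on-plateau rule for a loss that should decrease."""

    def __init__(self, lr, factor=0.5, patience=7):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after MORE than `patience` epochs without improvement
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=0.01, factor=0.5, patience=7)
# One initial value, then 8 epochs with no improvement
lrs = [sched.step(loss) for loss in [1.0] + [1.0] * 8]
print(lrs[-2], lrs[-1])  # 0.01 0.005
```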



### Metrics

All the metrics are defined here. A classification metric can also be applied to this regression target by defining a threshold in `threshold_kwargs`, which converts the predictions and/or the targets into binary labels.

See the class `goli.trainer.metrics.MetricWrapper` for more details.

See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for a dictionary of accepted metrics.

In [8]:
print_config_with_key(yaml_config, "metrics")

metrics:
- name: mae
  metric: mae
  threshold_kwargs: null
- name: pearsonr
  metric: pearsonr
  threshold_kwargs: null
- name: f1 > 3
  metric: f1
  num_classes: 2
  average: micro
  threshold_kwargs:
    operator: greater
    threshold: 3
    th_on_preds: true
    th_on_target: true
    target_to_int: true
- name: f1 > 5
  metric: f1
  num_classes: 2
  average: micro
  threshold_kwargs:
    operator: greater
    threshold: 5
    th_on_preds: true
    th_on_target: true
    target_to_int: true
- name: precision > 3
  metric: precision
  class_reduction: micro
  threshold_kwargs:
    operator: greater
    threshold: 3
    th_on_preds: true
    th_on_target: true
    target_to_int: true
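With `operator: greater` and `threshold: 3`, both the predictions (`th_on_preds`) and the targets (`th_on_target`) become 0/1 labels before the metric is computed. Below is a plain-Python sketch of that idea using hypothetical helpers. For simplicity it scores only the positive class, while the config's `num_classes` and `average` options control how goli actually aggregates the score:

```python
def binarize(values, threshold):
    """Map each value to 1 if it is greater than `threshold`, else 0."""
    return [1 if v > threshold else 0 for v in values]


def f1_binary(preds, targets):
    """F1 score of the positive class; returns 0.0 when there are no true positives."""
    tp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, targets) if p == 0 and t == 1)
    if tp == 0:
        return 0.0
    return 2 * tp / (2 * tp + fp + fn)


preds = [3.5, 1.2, 4.1, 2.9]
targets = [3.2, 0.5, 2.8, 3.6]
print(f1_binary(binarize(preds, 3), binarize(targets, 3)))  # 0.5
```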



### Trainer

Finally, the trainer section defines the logger, the early-stopping criterion and its patience, the model checkpointing, and the training loop itself (minimum and maximum number of epochs, number of GPUs).

In [9]:
print_config_with_key(yaml_config, "trainer")

trainer:
  logger:
    save_dir: logs/micro_ZINC
  early_stopping:
    monitor: loss/val
    min_delta: 0
    patience: 10
    mode: min
  model_checkpoint:
    dirpath: models_checkpoints/micro_ZINC/
    filename: model
    monitor: loss/val
    mode: min
    save_top_k: 1
    every_n_val_epochs: 1
  trainer:
    max_epochs: 25
    min_epochs: 5
    gpus: 1
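The early-stopping and checkpoint options work together: training stops once `loss/val` has not decreased by more than `min_delta` for `patience` consecutive validation epochs, and `save_top_k: 1` keeps only the checkpoint with the best validation loss. Here is a minimal sketch of that bookkeeping, run on a made-up sequence of validation losses:

```python
def early_stop(val_losses, patience=10, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a loss that should decrease.

    stop_epoch is None if training would run through all given epochs.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0  # new best: this checkpoint would be kept
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Made-up losses: improvement until epoch 2, then a plateau
losses = [1.0, 0.8, 0.7] + [0.75] * 10
print(early_stop(losses, patience=10))  # (12, 2)
```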



## Training the model

Now that all the configurations are defined, we can train the model. The config loaders make this straightforward, and the steps are given below.

In [10]:
import os
from os.path import dirname, abspath
from copy import deepcopy

import goli
from goli.config._loader import (load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer)

MAIN_DIR = dirname(dirname(abspath(goli.__file__)))
os.chdir(MAIN_DIR)

cfg = dict(deepcopy(yaml_config))

# Load and initialize the dataset
datamodule = load_datamodule(cfg)
print("\ndatamodule:\n", datamodule, "\n")

# Initialize the network
model_class, model_kwargs = load_architecture(
    cfg,
    in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,
    in_dim_edges=datamodule.num_edge_feats,
)

# Load and print the metrics
metrics = load_metrics(cfg)
print(metrics)

# Load the predictor, print the model, and print a summary of the number of parameters
predictor = load_predictor(cfg, model_class, model_kwargs, metrics)
print(predictor.model)
print(predictor.summarize(mode=4, to_print=False))

# Load the trainer, and start the training
trainer = load_trainer(cfg)
trainer.fit(model=predictor, datamodule=datamodule)

Using backend: pytorch
  eigvecs[comp, :] = this_eigvecs
  eigvals_tile[comp, :] = this_eigvals

datamodule:
 name: DGLFromSmilesDataModule
len: 1000
train_size: null
val_size: null
test_size: null
batch_size_train_val: 128
batch_size_test: 128
num_node_feats: 55
num_node_feats_with_positional_encoding: 58
num_edge_feats: 13
num_labels: 1
collate_fn: goli_collate_fn
featurization:
  atom_property_list_onehot:
  - atomic-number
  - valence
  atom_property_list_float:
  - mass
  - electronegativity
  - in-ring
  edge_property_list:
  - bond-type-onehot
  - stereo
  - in-ring
  add_self_loop: false
  explicit_H: false
  use_bonds_weights: false
  pos_encoding_as_features:
    pos_type: laplacian_eigvec
    num_pos: 3
    normalization: none
    disconnected_comp: true
  pos_encoding_as_directions:
    pos_type: laplacian_eigvec
    num_pos: 3
    normalization: none
    disconnected_comp: true
 

{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 3': f1(>3), 'f1 > 5': f1(>5), 'preci
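The printed sizes suggest how the network's input dimension is obtained: `num_node_feats_with_positional_encoding` (58) appears to be `num_node_feats` (55) plus the `num_pos: 3` Laplacian eigenvectors requested by `pos_encoding_as_features`, and this is the value passed as `in_dim_nodes` to `load_architecture`. As a quick check of that arithmetic:

```python
num_node_feats = 55  # atom features, from the datamodule summary
num_pos = 3          # pos_encoding_as_features.num_pos in the config
in_dim_nodes = num_node_feats + num_pos
print(in_dim_nodes)  # 58, matching num_node_feats_with_positional_encoding
```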

## Testing the model
Once the model is trained, we can use the same datamodule to get the results on the test set. Here, `ckpt_path` is the path of the checkpoint saved at the best validation step, so the test results are those of the early-stopped model rather than the model from the last epoch.

All the metrics that were computed on the validation set are then computed on the test set, printed, and saved into the `metrics.yaml` file.

In [11]:
ckpt_path = trainer.checkpoint_callbacks[0].best_model_path
trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path)

Testing: 100%|██████████| 2/2 [00:00<00:00,  3.52it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MSELoss/test': 0.6393043398857117,
 'f1 > 3/test': 0.0,
 'f1 > 5/test': 0.0,
 'loss/test': 0.6393043398857117,
 'mae/test': 0.5762802362442017,
 'mean_pred/test': -0.48094016313552856,
 'mean_target/test': -0.6447566151618958,
 'pearsonr/test': 0.9359009265899658,
 'precision > 3/test': nan,
 'std_pred/test': 1.7718788385391235,
 'std_target/test': 2.1336562633514404}
--------------------------------------------------------------------------------


[{'mean_pred/test': -0.48094016313552856,
  'std_pred/test': 1.7718788385391235,
  'mean_target/test': -0.6447566151618958,
  'std_target/test': 2.1336562633514404,
  'mae/test': 0.5762802362442017,
  'pearsonr/test': 0.9359009265899658,
  'f1 > 3/test': 0.0,
  'f1 > 5/test': 0.0,
  'precision > 3/test': nan,
  'MSELoss/test': 0.6393043398857117,
  'loss/test': 0.6393043398857117}]
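The `nan` reported for `precision > 3` is what precision gives when its denominator is zero: precision is TP / (TP + FP), which is undefined if no sample is predicted above the threshold. A small sketch of that edge case, using a hypothetical `precision` helper on binary labels:

```python
import math


def precision(preds, targets):
    """Precision of the positive class; NaN when nothing is predicted positive."""
    tp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 0)
    if tp + fp == 0:
        return math.nan  # 0/0: precision is undefined
    return tp / (tp + fp)


print(precision([0, 0, 0], [1, 0, 1]))  # nan
print(precision([1, 1, 0], [1, 0, 1]))  # 0.5
```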