# Building and training on IPU from configurations

This tutorial walks through using a configuration file to define all the parameters of a model and of its trainer. It focuses on training from SMILES data in CSV format.

There are multiple examples of YAML files located in the folder `goli/expts` that one can refer to when training a new model. The file `docs/tutorials/model_training/config_ipu_tutorials.yaml` shows an example of multi-task regression from a CSV file provided by goli.



### First, the IPU config file
The IPU config file can be found at `expts/configs/ipu.config` and is shown below.

In [1]:
with open("expts/configs/ipu.config") as f:
    lines = f.readlines()
print("".join(lines))

FileNotFoundError: [Errno 2] No such file or directory: 'expts/configs/ipu.config'

## Creating the YAML file

The first step is to create a YAML file containing all the required configurations, with an example given at `goli/docs/tutorials/model_training/config_ipu_tutorials.yaml`. We will go through each part of the configurations.

In [None]:
import yaml
import omegaconf

In [None]:
def print_config_with_key(config, key):
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))

In [None]:
# First, let's read the yaml configuration file
with open("config_ipu_tutorials.yaml", "r") as file:
    yaml_config = yaml.load(file, Loader=yaml.FullLoader)

print("Yaml file loaded")

Yaml file loaded


### Constants

First, we define the constants such as the random seed and whether the model should raise or ignore an error.
The `name` here will be used to log the metrics into WandB and log the models.
The `accelerator` is used to define the device. It supports `cpu`, `ipu` or `gpu`. This is the only part that needs to change when working with IPUs.

In [None]:
print_config_with_key(yaml_config, "constants")

constants:
  name: tutorial_model
  seed: 42
  raise_train_error: true
  accelerator:
    type: ipu
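
Because the device is selected by a single key, it can be switched on the loaded configuration before anything is built. The sketch below is only illustrative: `with_accelerator` is a hypothetical helper (not a goli function), and the dictionary simply mirrors the `constants` block printed above.

```python
from copy import deepcopy

# Device types listed in this tutorial.
SUPPORTED_ACCELERATORS = {"cpu", "ipu", "gpu"}


def with_accelerator(config, accelerator_type):
    """Return a copy of `config` with `constants.accelerator.type` replaced.

    Hypothetical helper for illustration only; it is not part of goli.
    """
    if accelerator_type not in SUPPORTED_ACCELERATORS:
        raise ValueError(f"Unsupported accelerator: {accelerator_type!r}")
    new_config = deepcopy(config)  # leave the caller's config untouched
    new_config["constants"]["accelerator"]["type"] = accelerator_type
    return new_config


# Mirrors the `constants` block shown above.
example_config = {
    "constants": {
        "name": "tutorial_model",
        "seed": 42,
        "raise_train_error": True,
        "accelerator": {"type": "ipu"},
    }
}

cpu_config = with_accelerator(example_config, "cpu")
print(cpu_config["constants"]["accelerator"]["type"])
```

Working on a copy keeps the original IPU configuration available, which is convenient when comparing runs on different devices.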



### Datamodule

Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.

The `MultitaskIPUFromSmilesDataModule` allows us to define a set of tasks stored in different CSV or parquet files, and to train the model on all of them simultaneously. The tasks are declared under the path `datamodule: args: task_specific_args`.

For more details, see the class `goli.data.datamodule.MultitaskIPUFromSmilesDataModule`.

#### Reading a CSV file and train/val/test splits
Here is an example configuration for the task named "homo". It reads the file located at `https://storage.googleapis.com/goli-public/datasets/QM9/norm_micro_qm9.csv`, selects the columns "homo" and "lumo", and holds out validation and test sets with ratios of 20% each.

In [None]:
print_config_with_key(yaml_config["datamodule"]["args"]["task_specific_args"], "homo")

homo:
  df: null
  df_path: https://storage.googleapis.com/goli-public/datasets/QM9/norm_micro_qm9.csv
  smiles_col: smiles
  label_cols:
  - homo
  - lumo
  split_val: 0.2
  split_test: 0.2
  split_seed: 42
  splits_path: null
  sample_size: null
  idx_col: null
  weights_col: null
  weights_type: null
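
To make the split parameters concrete, here is a minimal sketch of a seeded random split. It assumes each held-out set receives `round(n * fraction)` samples and that the indices are shuffled once with the seed; goli's actual splitting code may shuffle or round differently, and `split_indices` is a hypothetical name.

```python
import random


def split_indices(n_samples, split_val, split_test, seed):
    """Shuffle indices with a fixed seed and cut them into train/val/test.

    Illustrative only: the rounding rule is an assumption, not goli's code.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # reproducible for a given seed
    n_val = round(n_samples * split_val)
    n_test = round(n_samples * split_test)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test


train, val, test = split_indices(1005, split_val=0.2, split_test=0.2, seed=42)
print(len(train), len(val), len(test))
```

With the 1005 data points reported in the data-loading logs, this gives 603, 201 and 201 samples, which matches the training, validation and test set sizes printed later in this tutorial.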



#### Featurizing a molecule
Molecules can be featurized using various properties and positional / structural encoding. A list of all features is available here:
- `goli.features.featurizer.get_mol_atomic_features_onehot`
- `goli.features.featurizer.get_mol_atomic_features_float`
- `goli.features.featurizer.get_mol_edge_features`
- `goli.features.spectral.compute_laplacian_positional_eigvecs`
- `goli.features.rw.compute_rwse`

An example featurization configuration is shown below. Notice the lists of atomic and edge properties, and the `pos_encoding_as_features` entry, which defines both the Laplacian and the random-walk positional encodings.

In [None]:
print_config_with_key(yaml_config["datamodule"]["args"], "featurization")

featurization:
  atom_property_list_onehot:
  - atomic-number
  - valence
  atom_property_list_float:
  - mass
  - electronegativity
  - in-ring
  edge_property_list:
  - bond-type-onehot
  - stereo
  - in-ring
  add_self_loop: false
  explicit_H: false
  use_bonds_weights: false
  pos_encoding_as_features:
    pos_types:
      la_pos:
        pos_type: laplacian_eigvec_eigval
        num_pos: 3
        normalization: none
        disconnected_comp: true
      rw_pos:
        pos_type: rwse
        ksteps: 16
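
The `rwse` entry refers to random-walk structural encodings. As an illustration of the general idea (not of goli's `compute_rwse`, whose implementation is not shown here), the encoding of a node can be taken as the probability that a random walk starting at that node is back at it after 1, 2, ..., `ksteps` steps. The pure-Python sketch below computes this for a small graph; the function names are hypothetical.

```python
def matmul(a, b):
    """Multiply two square matrices given as lists of rows."""
    size = len(a)
    return [
        [sum(a[i][k] * b[k][j] for k in range(size)) for j in range(size)]
        for i in range(size)
    ]


def rwse(adjacency, ksteps):
    """Return, per node, the return probabilities after 1..ksteps steps."""
    # Row-normalise the adjacency matrix into a transition matrix.
    transition = []
    for row in adjacency:
        degree = sum(row)
        # Isolated nodes keep an all-zero row to avoid dividing by zero.
        transition.append([v / degree if degree else 0.0 for v in row])

    power = [row[:] for row in transition]  # transition matrix to the power 1
    encodings = [[] for _ in adjacency]
    for _ in range(ksteps):
        for node in range(len(adjacency)):
            encodings[node].append(power[node][node])  # diagonal entry
        power = matmul(power, transition)
    return encodings


# A triangle: every node is connected to the two others.
triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
print(rwse(triangle, ksteps=3))
```

In the triangle, a walk can never return after one step, returns with probability 0.5 after two steps, and with probability 0.25 after three, so every node gets the same encoding.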



### Architecture

In the architecture, we define all the layers of the model, including the pre-processing MLP (input layers, `pre_nn`), the post-processing MLP (output layers, `post_nn`), and the main GNN (graph neural network, `gnn`).

The parameters let you choose the feature size, the depth, the skip connections, the pooling and the virtual node. Different GNN layers are also supported, such as `pyg:gcn`, `pyg:gin`, `pyg:gine`, `pyg:gated-gcn`, `pyg:pna-msgpass`, and `pyg:gps`.

For more details, see the following classes:

-  `goli.nn.global_architecture.FullGraphNetwork`: Main class for the architecture
-  `goli.nn.global_architecture.FeedForwardNN`: Main class for the inputs and outputs MLP
-  `goli.nn.pyg_architecture.FeedForwardPyg`: Main class for the GNN layers

#### Parameters for the node pre-processing NN

In [None]:
print_config_with_key(yaml_config["architecture"], "pre_nn")

pre_nn:
  out_dim: 32
  hidden_dims: 32
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none



#### Parameters for the edge pre-processing NN

In [None]:
print_config_with_key(yaml_config["architecture"], "pre_nn_edges")

pre_nn_edges:
  out_dim: 16
  hidden_dims: 16
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none



#### Parameters for the GNN
Here is an example of a GraphGPS layer, with its MPNN being a GINE model.

In [None]:
print_config_with_key(yaml_config["architecture"], "gnn")

gnn:
  out_dim: 32
  hidden_dims: 32
  depth: 3
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: simple
  pooling:
  - sum
  - mean
  - max
  virtual_node: none
  layer_type: pyg:gps
  layer_kwargs:
    mpnn_type: pyg:gine
    mpnn_kwargs: null
    attn_type: full-attention
    attn_kwargs: null
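
With three pooling functions, the graph-level vector is the concatenation of three pooled vectors, so its size is three times the node feature size (3 x 32 = 96 here, consistent with the `FCLayer(96 -> 32)` that appears in the model summary later in this tutorial). The sketch below illustrates the concatenation on toy node features; it is plain Python written for this explanation, not goli's pooling code.

```python
def pool_graph(node_features, pooling=("sum", "mean", "max")):
    """Concatenate several poolings of per-node feature vectors.

    `node_features` holds one feature list per node of a single graph.
    """
    reducers = {
        "sum": sum,
        "mean": lambda column: sum(column) / len(column),
        "max": max,
    }
    columns = list(zip(*node_features))  # one tuple per feature dimension
    pooled = []
    for name in pooling:
        pooled.extend(reducers[name](column) for column in columns)
    return pooled


# Three nodes with two features each.
toy_nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
pooled = pool_graph(toy_nodes)
print(pooled)       # [9.0, 6.0, 3.0, 2.0, 5.0, 4.0]
print(len(pooled))  # 3 poolings x 2 features = 6
```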



#### Parameters for the node post-processing NN (after the GNN)

In [None]:
print_config_with_key(yaml_config["architecture"], "post_nn")

post_nn:
  out_dim: 32
  hidden_dims: 32
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none



#### Parameters for the multi-task output heads
Here is an example of the task heads. Note that each `task_name` must match a task defined in the section `datamodule: args: task_specific_args`.

In [None]:
print_config_with_key(yaml_config["architecture"], "task_heads")


task_heads:
- task_name: homo
  out_dim: 2
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
- task_name: alpha
  out_dim: 1
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
- task_name: cv
  out_dim: 1
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
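
In the configuration above, the "homo" head has `out_dim: 2`, while the "homo" task selects two label columns. Assuming that a head is expected to output one value per label column (an assumption made for this sketch, not something checked against goli's code), a small hypothetical helper can flag mismatches before training. Only the "homo" task is used below, because the label columns of the other tasks are not shown in this tutorial.

```python
def check_head_out_dims(task_specific_args, task_heads):
    """Return a list of problems; an empty list means no mismatch was found.

    Hypothetical check: assumes one output per label column.
    """
    problems = []
    for head in task_heads:
        name = head["task_name"]
        task = task_specific_args.get(name)
        if task is None:
            continue  # tasks without a datamodule entry are not checked here
        n_labels = len(task["label_cols"])
        if head["out_dim"] != n_labels:
            problems.append(
                f"{name}: out_dim={head['out_dim']} but {n_labels} label columns"
            )
    return problems


# Values copied from the "homo" task and its head shown in this tutorial.
tasks = {"homo": {"label_cols": ["homo", "lumo"]}}
heads = [{"task_name": "homo", "out_dim": 2}]

print(check_head_out_dims(tasks, heads))  # []
```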



### Predictor

In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.

Again, the loss functions and progress-bar metrics are defined per task, and the task names must match those in `datamodule: args: task_specific_args`.

In [None]:
print_config_with_key(yaml_config, "predictor")

predictor:
  metrics_on_progress_bar:
    homo:
    - mae
    - pearsonr
    alpha:
    - mae
    cv:
    - mae
    - pearsonr
  loss_fun:
    homo: mse_ipu
    alpha: mse_ipu
    cv: mse_ipu
  random_seed: 42
  optim_kwargs:
    lr: 0.001
  torch_scheduler_kwargs: null
  scheduler_kwargs: null
  target_nan_mask: null
  flag_kwargs:
    n_steps: 0
    alpha: 0.0
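
Since the same task names must appear in several sections, a quick consistency check on the loaded configuration can catch typos early. The helper below is a hypothetical sketch (it is not part of goli) that compares the task names of the datamodule with those of the task heads and of the predictor; the example dictionary keeps only the keys needed for the check.

```python
def find_task_mismatches(config):
    """Compare task names across config sections.

    Returns a dict mapping each inconsistent section to the task names
    that are missing from it or unexpected in it.
    """
    expected = set(config["datamodule"]["args"]["task_specific_args"])
    sections = {
        "architecture.task_heads": {
            head["task_name"] for head in config["architecture"]["task_heads"]
        },
        "predictor.loss_fun": set(config["predictor"]["loss_fun"]),
        "predictor.metrics_on_progress_bar": set(
            config["predictor"]["metrics_on_progress_bar"]
        ),
    }
    return {
        section: {
            "missing": sorted(expected - names),
            "unexpected": sorted(names - expected),
        }
        for section, names in sections.items()
        if names != expected
    }


# Trimmed-down config using the three task names from this tutorial.
example = {
    "datamodule": {"args": {"task_specific_args": {"homo": {}, "alpha": {}, "cv": {}}}},
    "architecture": {
        "task_heads": [{"task_name": "homo"}, {"task_name": "alpha"}, {"task_name": "cv"}]
    },
    "predictor": {
        "loss_fun": {"homo": "mse_ipu", "alpha": "mse_ipu", "cv": "mse_ipu"},
        "metrics_on_progress_bar": {
            "homo": ["mae", "pearsonr"],
            "alpha": ["mae"],
            "cv": ["mae", "pearsonr"],
        },
    },
}

print(find_task_mismatches(example))  # {} when every section agrees
```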



### Metrics

All the metrics are defined here. For a classification metric, a threshold can also be defined.

See class `goli.trainer.metrics.MetricWrapper` for more details.

See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for dictionaries of accepted metrics.

Again, the metrics are task-dependent, and the task names must match those in `datamodule: args: task_specific_args`.

In [None]:
print_config_with_key(yaml_config, "metrics")

metrics:
  homo:
  - name: mae
    metric: mae
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr
    threshold_kwargs: null
    target_nan_mask: ignore-mean-label
  alpha:
  - name: mae
    metric: mae
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr
    threshold_kwargs: null
  cv:
  - name: mae
    metric: mae
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr
    threshold_kwargs: null
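
For reference, the two regression metrics used in this configuration have simple definitions. The sketch below writes them out in plain Python so that the reported numbers are easier to interpret; it is not how goli computes them (the logs later in this tutorial show library functions named `mean_absolute_error` and `pearson_corrcoef`), and the function names here are illustrative.

```python
import math


def mae(preds, targets):
    """Mean absolute error: average size of the prediction errors."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def pearson_r(preds, targets):
    """Pearson correlation: strength of the linear relation, in [-1, 1]."""
    n = len(preds)
    mean_p = sum(preds) / n
    mean_t = sum(targets) / n
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(preds, targets))
    var_p = sum((p - mean_p) ** 2 for p in preds)
    var_t = sum((t - mean_t) ** 2 for t in targets)
    return cov / math.sqrt(var_p * var_t)


# Predictions that are always 0.5 too low: the error is constant,
# but the correlation with the targets is still perfect.
preds = [1.0, 2.0, 3.0, 4.0]
targets = [1.5, 2.5, 3.5, 4.5]
print(mae(preds, targets))
print(pearson_r(preds, targets))
```

The example shows why both metrics are tracked: a model can rank molecules perfectly (high correlation) while still being biased (non-zero MAE), and vice versa.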



### Trainer

Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience.

In [None]:
print_config_with_key(yaml_config, "trainer")

trainer:
  logger:
    save_dir: logs/QM9
    name: tutorial_model
  model_checkpoint:
    dirpath: models_checkpoints/QM9/
    filename: tutorial_model
    save_top_k: 1
    every_n_epochs: 1
  trainer:
    precision: 32
    max_epochs: 2
    min_epochs: 1



## Training the model

Now that we have defined the full configuration, we can train the model. The steps are straightforward using the config loaders, and are given below.

### First, let's do our imports

In [None]:
# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig
import timeit
from loguru import logger
from pytorch_lightning.utilities.model_summary import ModelSummary

# Current project imports
import goli
from goli.config._loader import load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer
from goli.utils.safe_run import SafeRun

# WandB
import wandb

### Then, let's load the configuration file

In [None]:

# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(goli.__file__)))
CONFIG_FILE = "docs/tutorials/model_training/config_ipu_tutorials.yaml"
os.chdir(MAIN_DIR)

with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f:
    cfg = yaml.safe_load(f)

### Now let's process the data

In [None]:
# Load and initialize the dataset
datamodule = load_datamodule(cfg)
datamodule.prepare_data()

2022-09-16 21:25:05.422 | INFO     | goli.data.datamodule:prepare_data:699 - Reading data for task 'homo'
2022-09-16 21:25:05.558 | INFO     | goli.data.datamodule:prepare_data:699 - Reading data for task 'alpha'
2022-09-16 21:25:05.679 | INFO     | goli.data.datamodule:prepare_data:699 - Reading data for task 'cv'
2022-09-16 21:25:05.803 | INFO     | goli.data.datamodule:prepare_data:721 - Done reading datasets
2022-09-16 21:25:05.803 | INFO     | goli.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'homo' with 1005 data points.
2022-09-16 21:25:05.804 | INFO     | goli.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'alpha' with 1005 data points.
2022-09-16 21:25:05.804 | INFO     | goli.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'cv' with 1005 data points.


mols to ids:   0%|          | 0/3015 [00:00<?, ?it/s]

featurizing_smiles:   0%|          | 0/1005 [00:00<?, ?it/s]

### Let's build the architecture, metrics, and set-up the trainer

In [None]:
# Initialize the network
model_class, model_kwargs = load_architecture(
    cfg,
    in_dims=datamodule.in_dims,
)

metrics = load_metrics(cfg)
logger.info(metrics)

predictor = load_predictor(cfg, model_class, model_kwargs, metrics)

logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))

trainer = load_trainer(cfg, "tutorial-run")

2022-09-16 21:25:15.740 | INFO     | __main__:<cell line: 8>:8 - {'homo': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'alpha': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'cv': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}}
2022-09-16 21:25:15.784 | INFO     | __main__:<cell line: 12>:12 - Multitask_GNN
---------------
    pre-NN(depth=1, ResidualConnectionNone)
        [FCLayer[87 -> 32]
    
    pre-NN-edges(depth=1, ResidualConnectionNone)
        [FCLayer[13 -> 16]
    
    GNN(depth=3, ResidualConnectionSimple(skip_steps=1))
        GPSLayerPyg[32 -> 32 -> 32 -> 32]
        -> Pooling(['sum', 'mean', 'max']) -> FCLayer(96 -> 32, activation=None)
    
    post-NN(depth=1, ResidualConnectionNone)
        [FCLayer[32 -> 32]
2022-09-16 21:25:15.786 | INFO     | __main__:<cell line: 13>:13 -    | Name                                | Type                      | Params
--------------------------------------------------------------------------

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: True, using: 1 IPUs
HPU available: False, using: 0 HPUs


### Finally, let's run the model

In [None]:
# Run the model training
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
    trainer.fit(model=predictor, datamodule=datamodule)

# Exit WandB
wandb.finish()



mols to ids:   0%|          | 0/1809 [00:00<?, ?it/s]

mols to ids:   0%|          | 0/603 [00:00<?, ?it/s]


  | Name  | Type                      | Params
----------------------------------------------------
0 | model | FullGraphMultiTaskNetwork | 47.6 K
----------------------------------------------------
47.6 K    Trainable params
0         Non-trainable params
47.6 K    Total params
0.190     Total estimated model params size (MB)


-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 603
	num_nodes_total = 5285
	max_num_nodes_per_graph = 9
	min_num_nodes_per_graph = 1
	std_num_nodes_per_graph = 0.6989218547197285
	mean_num_nodes_per_graph = 8.764510779436153
	num_edges_total = 11316
	max_num_edges_per_graph = 26
	min_num_edges_per_graph = 0
	std_num_edges_per_graph = 2.6821780989470896
	mean_num_edges_per_graph = 18.766169154228855
-------------------

-------------------
MultitaskDataset
	about = validation set
	num_graphs_total = 201
	num_nodes_total = 1756
	max_num_nodes_per_graph = 9
	min_num_nodes_per_graph = 2
	std_num_nodes_per_graph = 0.7949619835770694
	mean_num_nodes_per_graph = 8.7363184079602
	num_edges_total = 3736
	max_num_edges_per_graph = 24
	min_num_edges_per_graph = 2
	std_num_edges_per_graph = 2.7704251592456193
	mean_num_edges_per_graph = 18.587064676616915
-------------------



Sanity Checking: 0it [00:00, ?it/s]

2022-09-16 21:25:20.077 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:231 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-16 21:25:20.079 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:232 - Provided `max_num_nodes=72`
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
2022-09-16 21:25:28.739 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:231 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-16 21:25:28.741 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:232 - Provided `max_num_nodes=72`


Training: 0it [00:00, ?it/s]

Graph compilation: 100%|██████████| 100/100 [00:02<00:00]


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Run history:
alpha/MSELossIPU/train,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
alpha/MSELossIPU/val,▁█
alpha/loss/val,▁█
alpha/mae/train,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
alpha/mae/val,▁█
alpha/mean_pred/train,█▆▁▁▁▁▁▁▁▁▁▁▁▁▁
alpha/mean_pred/val,▁█
alpha/mean_target/train,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
alpha/mean_target/val,▁▁
alpha/median_pred/train,██▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
alpha/MSELossIPU/train,10.27925
alpha/MSELossIPU/val,1.24586
alpha/loss/val,1.24586
alpha/mae/train,1.92692
alpha/mae/val,0.74558
alpha/mean_pred/train,-0.0
alpha/mean_pred/val,0.54309
alpha/mean_target/train,-1.47591
alpha/mean_target/val,-0.05393
alpha/median_pred/train,-0.0


## Testing the model
Once the model is trained, we can use the same datamodule to get the results on the test set. Here, `ckpt_path` refers to the checkpoint saved at the best validation step. The test results therefore correspond to the early-stopped model rather than to the model at the final epoch.

All the metrics that were computed on the validation set are then computed on the test set, printed, and saved into the `metrics.yaml` file.

In [None]:
ckpt_path = trainer.checkpoint_callbacks[0].best_model_path
trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path)



mols to ids:   0%|          | 0/603 [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/dom/goli/models_checkpoints/QM9/tutorial_model-v2.ckpt
Loaded model weights from checkpoint at /home/dom/goli/models_checkpoints/QM9/tutorial_model-v2.ckpt


-------------------
MultitaskDataset
	about = test set
	num_graphs_total = 201
	num_nodes_total = 1751
	max_num_nodes_per_graph = 9
	min_num_nodes_per_graph = 1
	std_num_nodes_per_graph = 0.9443829267824616
	mean_num_nodes_per_graph = 8.711442786069652
	num_edges_total = 3748
	max_num_edges_per_graph = 26
	min_num_edges_per_graph = 0
	std_num_edges_per_graph = 3.0647412282392468
	mean_num_edges_per_graph = 18.64676616915423
-------------------



2022-09-16 21:26:11.968 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:231 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-16 21:26:11.969 | INFO     | goli.ipu.ipu_dataloader:create_ipu_dataloader:232 - Provided `max_num_nodes=72`


Testing: 0it [00:00, ?it/s]

Graph compilation: 100%|██████████| 100/100 [00:01<00:00]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  alpha/MSELossIPU/test     1.3306002616882324
     alpha/loss/test        1.3306002616882324
     alpha/mae/test         0.7145935893058777
  alpha/mean_pred/test      0.4449523687362671
 alpha/mean_target/test     -0.1731886863708496
 alpha/median_pred/test     0.4823043942451477
alpha/median_target/test     -0.05999755859375
   alpha/pearsonr/test      0.6725651025772095
   alpha/std_pred/test      0.6750865578651428
  alpha/std_target/test     1.2930139303207397
   cv/MSELossIPU/test        1.053282618522644
      cv/loss/test           1.053282618522644
       cv/mae/test          0.7169217467308044
    cv/mean_pred/test       0.37237051129341125
   cv/mean_target/test     -0.08974086493253708
   cv/

[{'homo/mean_pred/test': 0.09660027176141739,
  'homo/std_pred/test': 0.2802477180957794,
  'homo/median_pred/test': -0.013679077848792076,
  'homo/mean_target/test': 0.1092241033911705,
  'homo/std_target/test': 1.07249116897583,
  'homo/median_target/test': 0.1390380859375,
  'homo/mae/test': 0.7971344590187073,
  'homo/pearsonr/test': 0.12582433223724365,
  'homo/MSELossIPU/test': 1.1376948356628418,
  'homo/loss/test': 1.1376948356628418,
  'alpha/mean_pred/test': 0.4449523687362671,
  'alpha/std_pred/test': 0.6750865578651428,
  'alpha/median_pred/test': 0.4823043942451477,
  'alpha/mean_target/test': -0.1731886863708496,
  'alpha/std_target/test': 1.2930139303207397,
  'alpha/median_target/test': -0.05999755859375,
  'alpha/mae/test': 0.7145935893058777,
  'alpha/pearsonr/test': 0.6725651025772095,
  'alpha/MSELossIPU/test': 1.3306002616882324,
  'alpha/loss/test': 1.3306002616882324,
  'cv/mean_pred/test': 0.37237051129341125,
  'cv/std_pred/test': 0.5452650785446167,
  'cv/medi