In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import yaml
import omegaconf
from datetime import datetime

from typing import Union, List
from copy import deepcopy
from tdc.utils import retrieve_benchmark_names

from graphium.config._loader import (
    load_datamodule,
    load_metrics,
    load_architecture,
    load_predictor,
    load_trainer,
    save_params_to_wandb,
    load_accelerator,
    load_yaml_config,
)

# Fine-tuning on the TDC ADMET benchmarking group

[TDC](https://tdcommons.ai/) hosts a variety of ML-ready datasets and benchmarks for machine learning in drug discovery. The [TDC ADMET benchmarking group](https://tdcommons.ai/benchmark/admet_group/overview/) is a popular collection of benchmarks for evaluating new _foundation models_ (see e.g. [MolE](https://arxiv.org/abs/2211.02657)) due to the variety and relevance of the included tasks.

The ADMET benchmarking group is integrated in `graphium` through the `TDCBenchmarkDataModule` data-module. This notebook shows how to easily fine-tune and test a model using that data-module. 

<div style="background-color: #fff3cd; border-radius: 10px; border-color: #ffeeba; padding: 20px; margin: 20px 0;  color: #856404">
    <b>NOTE:</b> This notebook is still <i>work in progress</i>. While the <b>fine-tuning logic is unfinished</b>, the notebook does demo how one could use the data-module to easily loop over each of the datasets in the benchmarking group and get the prescribed train-test split. Once the fine-tuning logic is finalized, we will finish this notebook and officially provide it as a tutorial within Graphium. 
</div>

In [3]:
# Start by parsing the demo YAML configuration file into `config`
with open("../expts/configs/config_tdc_admet_demo.yaml", mode="r") as file:
    config = yaml.full_load(file)

FileNotFoundError: [Errno 2] No such file or directory: '../expts/configs/config_tdc_admet_demo.yaml'

## Get all TDC benchmark names

In [4]:
# Look up the names of all benchmarks registered under the "admet_group" group.
# The trailing expression makes the cell display how many names were returned.
benchmarks = retrieve_benchmark_names("admet_group")
len(benchmarks)

22

While there are 22 benchmarks in total, let's use just two for practicality's sake: one regression task and one classification task!

In [5]:
# Replace the full list retrieved above with a two-benchmark subset to keep the demo short
benchmarks = ["caco2_wang", "hia_hou"]

## Initialize all training components per task
**NOTE**: Since the fine-tuning logic is not implemented yet, this currently just creates a new model for each task. Ultimately, we will want to use fine-tuning code to evaluate how well the pre-trained model transfers to downstream tasks.

In [4]:
def training_testing_loop(cfg):
    """
    Build all training components from `cfg`, fit a model from scratch and test it.

    Parameters:
        cfg: Configuration object handed to each of the `load_*` helpers.

    Returns:
        The value returned by `trainer.test` for the fitted predictor.
    """

    # `load_accelerator` returns the config again alongside the accelerator type,
    # so the rebound `cfg` is the one used by every loader below
    cfg, accel_type = load_accelerator(cfg)

    # Components built directly from the config
    data = load_datamodule(cfg, accel_type)
    arch_class, arch_kwargs = load_architecture(cfg, in_dims=data.in_dims)
    metric_fns = load_metrics(cfg)

    # The data is prepared before the predictor is built from the datamodule's attributes
    data.prepare_data()

    predictor = load_predictor(
        cfg,
        arch_class,
        arch_kwargs,
        metric_fns,
        data.get_task_levels(),
        accel_type,
        data.featurization,
        data.task_norms,
    )

    # The trainer receives a timestamp suffix (day.month.year_hour.minute.second)
    timestamp = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
    trainer = load_trainer(cfg, "tdc-admet", accel_type, timestamp)

    # Fitting stage: graph size limits are set for the train and val stages first
    predictor.set_max_nodes_edges_per_graph(data, stages=["train", "val"])
    trainer.fit(model=predictor, datamodule=data)

    # Testing stage: limits are set again for the test stage before evaluating
    predictor.set_max_nodes_edges_per_graph(data, stages=["test"])
    return trainer.test(model=predictor, datamodule=data)


In [5]:
def filter_cfg_based_on_benchmark_name(config, names: Union[List[str], str]):
    """
    Filter a base config for the full TDC ADMET benchmarking group so that it only
    contains the settings related to a subset of the endpoints.

    The input `config` is left untouched: a deep copy is filtered and returned.

    Parameters:
        config: Nested mapping holding the base configuration. It must contain
            `datamodule.module_type`, `datamodule.args`, `architecture.task_heads`,
            `predictor.metrics_on_progress_bar`, `predictor.loss_fun` and `metrics`.
        names: A single benchmark name, or a list of benchmark names, to keep.

    Returns:
        A deep copy of `config` where `datamodule.args.tdc_benchmark_names` is set to
        the selected names and the per-benchmark sections listed above only keep the
        entries whose key is one of those names.

    Raises:
        ValueError: If `config["datamodule"]["module_type"]` is not
            `"TDCBenchmarkDataModule"`.
    """

    if config["datamodule"]["module_type"] != "TDCBenchmarkDataModule":
        raise ValueError("You can only use this method for the `TDCBenchmarkDataModule`")

    # Normalise to a list owned by this call: a lone string becomes a one-element
    # list, and a caller-supplied list is copied so that mutating it afterwards
    # cannot silently change the returned config.
    names = [names] if isinstance(names, str) else list(names)

    def _filter(d):
        # Keep only the entries keyed by one of the selected benchmark names
        return {k: v for k, v in d.items() if k in names}

    cfg = deepcopy(config)

    # Update the datamodule arguments
    cfg["datamodule"]["args"]["tdc_benchmark_names"] = names

    # Filter the relevant nested config sections
    nested_sections = (
        ("architecture", "task_heads"),
        ("predictor", "metrics_on_progress_bar"),
        ("predictor", "loss_fun"),
    )
    for parent, child in nested_sections:
        cfg[parent][child] = _filter(cfg[parent][child])
    cfg["metrics"] = _filter(cfg["metrics"])

    return cfg

In [8]:
# Maps "<benchmark>/<metric>" to the test value reported for that benchmark
results = {}

for name in benchmarks:
    # Train and test on a config restricted to this single benchmark
    cfg = filter_cfg_based_on_benchmark_name(config, name)
    benchmark_results = training_testing_loop(cfg)

    # The first progress-bar metric configured for the benchmark is the one reported
    metric = cfg["predictor"]["metrics_on_progress_bar"][name][0]
    results[f"{name}/{metric}"] = benchmark_results[0][f"graph_{name}/{metric}/test"]

[32m2023-07-13 14:02:22.438[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36m__init__[0m:[36m2425[0m - [1mPreparing the TDC ADMET Benchmark Group splits for each of the 1 benchmarks.[0m
[34m[1mwandb[0m: Currently logged in as: [33mcwognum[0m ([33mvalence-ood[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m2023-07-13 14:02:28.065[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1168[0m - [1m-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 634
	num_nodes_total = 18351
	max_num_nodes_per_graph = 67
	min_num_nodes_per_graph = 2
	std_num_nodes_per_graph = 11.441048173903344
	mean_num_nodes_per_graph = 28.944794952681388
	num_edges_total = 39466
	max_num_edges_per_graph = 144
	min_num_edges_per_graph = 2
	std_num_edges_per_graph = 25.32877270037731
	mean_num_edges_per_graph = 62.24921135646688
-------------------
[0m
[32m2023-07-13 14:02:28.065[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1169[0m - [1m-------------------
MultitaskDataset
	about = validation set
	num_graphs_total = 91
	num_nodes_total = 2791
	max_num_nodes_per_graph = 67
	mi

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
[32m2023-07-13 14:02:49.408[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1168[0m - [1m-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 634
	num_nodes_total = 18351
	max_num_nodes_per_graph = 67
	min_num_nodes_per_graph = 2
	std_num_nodes_per_graph = 11.441048173903344
	mean_num_nodes_per_graph = 28.944794952681388
	num_edges_total = 39466
	max_num_edges_per_graph = 144
	min_num_edges_per_graph = 2
	std_num_edges_per_graph = 25.32877270037731
	mean_num_edges_per_graph = 62.24921135646688
-------------------
[0m
[32m2023-07-13 14:02:49.409[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1169[0m - [1m-------------------
MultitaskDataset
	about = validation set
	num_graphs_total = 91
	num_nodes_total = 2791
	max_num_nodes_per_graph = 67
	min_num_nodes_per_graph = 13
	std_num_nodes_per_graph = 11.461406761225364
	mean_num_nodes_per_graph = 30.

Testing: 0it [00:00, ?it/s]

[32m2023-07-13 14:02:50.286[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36m__init__[0m:[36m2425[0m - [1mPreparing the TDC ADMET Benchmark Group splits for each of the 1 benchmarks.[0m
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m2023-07-13 14:02:50.421[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1168[0m - [1m-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 403
	num_nodes_total = 9065
	max_num_nodes_per_graph = 101
	min_num_nodes_per_graph = 7
	std_num_nodes_per_graph = 8.379156239363269
	mean_num_nodes_per_graph = 22.49379652605459
	num_edges_total = 19420
	max_num_edges_per_graph = 220
	min_num_edges_per_graph = 14
	std_num_edges_per_graph = 18.678530787820403
	mean_num_edges_per_graph = 48.188585607940446
-------------------
[0m
[32m2023-07-13 14:02:50.422[0m | [1mINFO    [0m | [

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
[32m2023-07-13 14:03:06.289[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1168[0m - [1m-------------------
MultitaskDataset
	about = training set
	num_graphs_total = 403
	num_nodes_total = 9065
	max_num_nodes_per_graph = 101
	min_num_nodes_per_graph = 7
	std_num_nodes_per_graph = 8.379156239363269
	mean_num_nodes_per_graph = 22.49379652605459
	num_edges_total = 19420
	max_num_edges_per_graph = 220
	min_num_edges_per_graph = 14
	std_num_edges_per_graph = 18.678530787820403
	mean_num_edges_per_graph = 48.188585607940446
-------------------
[0m
[32m2023-07-13 14:03:06.290[0m | [1mINFO    [0m | [36mgraphium.data.datamodule[0m:[36msetup[0m:[36m1169[0m - [1m-------------------
MultitaskDataset
	about = validation set
	num_graphs_total = 58
	num_nodes_total = 1431
	max_num_nodes_per_graph = 67
	min_num_nodes_per_graph = 9
	std_num_nodes_per_graph = 9.610317607573146
	mean_num_nodes_per_graph = 24.6

Testing: 0it [00:00, ?it/s]



In [9]:
# Render the collected scores as YAML text and print them
results_yaml = omegaconf.OmegaConf.to_yaml(results)
print(results_yaml)

caco2_wang/mae: 2.312683582305908
hia_hou/auroc: 0.7008230686187744



The End. 