From ea3afd9f80396b8e3ee82732b72f0a67f349c703 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 4 Dec 2023 11:28:47 +0000 Subject: [PATCH 1/6] Account for possibility of having '0 GNN layers' on an IPU in pipeline split --- .../nn/architectures/global_architectures.py | 57 ++++++++++++++++--- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 6570ca492..6e0bb88c4 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -866,7 +866,7 @@ def _virtual_node_forward( return feat, vn_feat, edge_feat - def forward(self, g: Batch) -> torch.Tensor: + def forward(self, g: Batch, extra_return_names: List[str] = []) -> torch.Tensor: r""" Apply the full graph neural network on the input graph and node features. @@ -1198,12 +1198,24 @@ def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu): f"of {model_depth}" ) + if 0 in gnn_layers_per_ipu[:-1]: + raise ValueError("Only the last IPU can have 0 GNN layers") + begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)] for begin_block_layer_index, ipu_id in zip(begin_block_layer_indices, range(1, pipeline_length)): - self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock( - self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id - ) + if begin_block_layer_index < model_depth: + self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock( + self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id + ) + elif self.task_heads is not None and ipu_id == pipeline_length - 1: + self.task_heads = poptorch.BeginBlock( + self.task_heads, ipu_id=ipu_id + ) + else: + raise ValueError("Invalid pipeline split, nothing to put on last IPU " + "(0 GNN layers on last IPU but no task heads)") + def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]): """ @@ -1286,7 +1298,8 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] = self._module_map[module_name] = module.layers return self._module_map - def forward(self, g: Batch) -> Tensor: + + def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor: r""" Apply the pre-processing neural network, the graph neural network, and the post-processing neural network on the graph features. 
@@ -1341,9 +1354,27 @@ def forward(self, g: Batch) -> Tensor: e = self.pre_nn_edges.forward(e) g["edge_feat"] = e - # Run the graph neural network - g = self.gnn.forward(g) + # Apologies for the similar names here + extras_to_return = [] + + if extra_return_names: + + # Run the graph neural network + g, gnn_extras_to_return = self.gnn.forward(g, extra_return_names=extra_return_names) + extras_to_return.update(gnn_extras_to_return) + if "pre_task_heads" in extra_returns: + extras_to_return.update({"pre_task_heads": g}) + + if self.task_heads is None: + return g, extras_to_return + + final_output, task_head_extras_to_return = self.task_heads.forward(g, extra_return_names=extra_return_names) + extras_to_return.update(task_head_extras_to_return) + + return final_output, extras_to_return + + # Keep original code if no extra_return_names if self.task_heads is not None: return self.task_heads.forward(g) @@ -1851,7 +1882,7 @@ def __init__( filtered_kwargs["in_dim"] = self.graph_output_nn_kwargs[task_level]["out_dim"] self.task_heads[task_name] = FeedForwardNN(**filtered_kwargs) - def forward(self, g: Batch) -> Dict[str, torch.Tensor]: + def forward(self, g: Batch, extra_return_names: List[str] = []) -> Dict[str, torch.Tensor]: r""" forward function of the task head Parameters: @@ -1859,15 +1890,23 @@ def forward(self, g: Batch) -> Dict[str, torch.Tensor]: Returns: task_head_outputs: Return a dictionary: Dict[task_name, Tensor] """ + + extras_to_return = {} + features = {task_level: self.graph_output_nn[task_level](g) for task_level in self.task_levels} + if 'task_level_features' in extra_return_names: + extras_to_return['task_level_features'] = features + task_head_outputs = {} for task_name, head in self.task_heads.items(): task_level = self.task_heads_kwargs[task_name].get( "task_level", None ) # Get task_level without modifying head_kwargs task_head_outputs[task_name] = head.forward(features[task_level]) - + + if extra_return_names: + return task_head_outputs, extras_to_return return task_head_outputs def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]: From 9372268b72c15f64e64b98b7793efd1cb9a72484 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 4 Dec 2023 11:32:41 +0000 Subject: [PATCH 2/6] `black` linting --- .../nn/architectures/global_architectures.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 6e0bb88c4..4d3e8b4eb 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -1209,13 +1209,12 @@ def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu): self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id ) elif self.task_heads is not None and ipu_id == pipeline_length - 1: - self.task_heads = poptorch.BeginBlock( - self.task_heads, ipu_id=ipu_id - ) + self.task_heads = poptorch.BeginBlock(self.task_heads, ipu_id=ipu_id) else: - raise ValueError("Invalid pipeline split, nothing to put on last IPU " - "(0 GNN layers on last IPU but no task heads)") - + raise ValueError( + "Invalid pipeline split, nothing to put on last IPU " + "(0 GNN layers on last IPU but no task heads)" + ) def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]): """ @@ -1298,7 +1297,6 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] = self._module_map[module_name] = module.layers return 
self._module_map - def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor: r""" Apply the pre-processing neural network, the graph neural network, @@ -1358,7 +1356,6 @@ def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor: extras_to_return = [] if extra_return_names: - # Run the graph neural network g, gnn_extras_to_return = self.gnn.forward(g, extra_return_names=extra_return_names) extras_to_return.update(gnn_extras_to_return) @@ -1369,10 +1366,12 @@ def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor: if self.task_heads is None: return g, extras_to_return - final_output, task_head_extras_to_return = self.task_heads.forward(g, extra_return_names=extra_return_names) + final_output, task_head_extras_to_return = self.task_heads.forward( + g, extra_return_names=extra_return_names + ) extras_to_return.update(task_head_extras_to_return) - return final_output, extras_to_return + return final_output, extras_to_return # Keep original code if no extra_return_names if self.task_heads is not None: @@ -1892,11 +1891,11 @@ def forward(self, g: Batch, extra_return_names: List[str] = []) -> Dict[str, tor """ extras_to_return = {} - + features = {task_level: self.graph_output_nn[task_level](g) for task_level in self.task_levels} - if 'task_level_features' in extra_return_names: - extras_to_return['task_level_features'] = features + if "task_level_features" in extra_return_names: + extras_to_return["task_level_features"] = features task_head_outputs = {} for task_name, head in self.task_heads.items(): @@ -1904,7 +1903,7 @@ def forward(self, g: Batch, extra_return_names: List[str] = []) -> Dict[str, tor "task_level", None ) # Get task_level without modifying head_kwargs task_head_outputs[task_name] = head.forward(features[task_level]) - + if extra_return_names: return task_head_outputs, extras_to_return return task_head_outputs From ff6e6a1aa42c3ec7842b77b60b705c7c2d30fa02 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Wed, 20 Dec 2023 09:42:59 +0000 Subject: [PATCH 3/6] fingerprinting working but slow --- graphium/nn/architectures/global_architectures.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 4d3e8b4eb..5258530d3 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -1353,14 +1353,15 @@ def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor: g["edge_feat"] = e # Apologies for the similar names here - extras_to_return = [] + extras_to_return = {} + #g, gnn_extras_to_return = self.gnn.forward(g, extra_return_names=extra_return_names) + #extras_to_return.update(gnn_extras_to_return) + g = self.gnn.forward(g) if extra_return_names: # Run the graph neural network - g, gnn_extras_to_return = self.gnn.forward(g, extra_return_names=extra_return_names) - extras_to_return.update(gnn_extras_to_return) - if "pre_task_heads" in extra_returns: + if "pre_task_heads" in extra_return_names: extras_to_return.update({"pre_task_heads": g}) if self.task_heads is None: From 43b254f1cd097ab09fec0683ae27698a4f0ebec4 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Wed, 20 Dec 2023 11:00:05 +0000 Subject: [PATCH 4/6] Fingerprints --- graphium/cli/get_final_fingerprints.py | 198 +++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 graphium/cli/get_final_fingerprints.py diff --git 
a/graphium/cli/get_final_fingerprints.py b/graphium/cli/get_final_fingerprints.py
new file mode 100644
index 000000000..32075b08f
--- /dev/null
+++ b/graphium/cli/get_final_fingerprints.py
@@ -0,0 +1,198 @@
+from typing import List, Literal, Union
+import os
+import time
+import timeit
+from datetime import datetime
+
+import fsspec
+import hydra
+import numpy as np
+import torch
+import wandb
+import yaml
+from datamol.utils import fs
+from hydra.core.hydra_config import HydraConfig
+from hydra.types import RunMode
+from lightning.pytorch.utilities.model_summary import ModelSummary
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+
+from graphium.config._loader import (
+    load_accelerator,
+    load_architecture,
+    load_datamodule,
+    load_metrics,
+    load_predictor,
+    load_trainer,
+    save_params_to_wandb,
+    get_checkpoint_path,
+)
+from graphium.finetuning import (
+    FINETUNING_CONFIG_KEY,
+    GraphFinetuning,
+    modify_cfg_for_finetuning,
+)
+from graphium.hyper_param_search import (
+    HYPER_PARAM_SEARCH_CONFIG_KEY,
+    extract_main_metric_for_hparam_search,
+)
+from graphium.trainer.predictor import PredictorModule
+from graphium.utils.safe_run import SafeRun
+
+import graphium.cli.finetune_utils
+
+from tqdm import tqdm
+from copy import deepcopy
+from tdc.benchmark_group import admet_group
+import datamol as dm
+import sys
+from torch_geometric.data import Batch
+import random
+
+TESTING_ONLY_CONFIG_KEY = "testing_only"
+
+
+@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
+def cli(cfg: DictConfig) -> None:
+    """
+    CLI endpoint for extracting final fingerprints from a pre-trained Graphium model.
+    """
+    return get_final_fingerprints(cfg)
+
+
+def get_final_fingerprints(cfg: DictConfig) -> None:
+    """
+    Extract final-layer fingerprints from a pre-trained model for the TDC ADMET molecules.
+ """ + + # Get ADMET SMILES strings + + if not os.path.exists("saved_admet_smiles.pt"): + admet = admet_group(path="admet-data/") + admet_mol_ids = set() + for dn in tqdm(admet.dataset_names, desc="Getting IDs for ADMET", file=sys.stdout): + admet_mol_ids |= set(admet.get(dn)["train_val"]["Drug"].apply(dm.unique_id)) + admet_mol_ids |= set(admet.get(dn)["test"]["Drug"].apply(dm.unique_id)) + + smiles_to_process = [] + admet_mol_ids_to_find = deepcopy(admet_mol_ids) + + for dn in tqdm(admet.dataset_names, desc="Matching molecules to IDs", file=sys.stdout): + for key in ["train_val", "test"]: + train_mols = set(admet.get(dn)[key]["Drug"]) + for smiles in train_mols: + mol_id = dm.unique_id(smiles) + if mol_id in admet_mol_ids_to_find: + smiles_to_process.append(smiles) + admet_mol_ids_to_find.remove(mol_id) + + assert set(dm.unique_id(s) for s in smiles_to_process) == admet_mol_ids + torch.save(smiles_to_process, "saved_admet_smiles.pt") + else: + smiles_to_process = torch.load("saved_admet_smiles.pt") + + unresolved_cfg = OmegaConf.to_container(cfg, resolve=False) + cfg = OmegaConf.to_container(cfg, resolve=True) + + st = timeit.default_timer() + + ## == Instantiate all required objects from their respective configs == + # Accelerator + cfg, accelerator_type = load_accelerator(cfg) + assert accelerator_type == "cpu", "get_final_fingerprints script only runs on CPU for now" + + ## Data-module + datamodule = load_datamodule(cfg, accelerator_type) + + + # Featurize SMILES strings + + input_features_save_path = "input_features.pt" + idx_none_save_path = "idx_none.pt" + if not os.path.exists(input_features_save_path): + input_features, idx_none = datamodule._featurize_molecules(smiles_to_process) + + torch.save(input_features, input_features_save_path) + torch.save(idx_none, idx_none_save_path) + else: + input_features = torch.load(input_features_save_path) + + ''' + for _ in range(100): + + index = random.randint(0, len(smiles_to_process) - 1) + features_single, idx_none_single = datamodule._featurize_molecules([smiles_to_process[index]]) + + def _single_bool(val): + if isinstance(val, bool): + return val + if isinstance(val, torch.Tensor): + return val.all() + raise ValueError(f"Type {type(val)} not accounted for") + + assert all(_single_bool(features_single[0][k] == input_features[index][k]) for k in features_single[0].keys()) + + import sys; sys.exit(0) + ''' + + failures = 0 + + # Cast to FP32 + + for input_feature in tqdm(input_features, desc="Casting to FP32"): + try: + if not isinstance(input_feature, str): + for k, v in input_feature.items(): + if isinstance(v, torch.Tensor): + if v.dtype == torch.half: + input_feature[k] = v.float() + elif v.dtype == torch.int32: + input_feature[k] = v.long() + else: + failures += 1 + except Exception as e: + print(f"{input_feature = }") + raise e + + print(f"{failures = }") + + + # Load pre-trained model + predictor = PredictorModule.load_pretrained_model( + name_or_path=get_checkpoint_path(cfg), device=accelerator_type + ) + + logger.info(predictor.model) + logger.info(ModelSummary(predictor, max_depth=4)) + + batch_size = 100 + + results = [] + + + # Run the model to get fingerprints + + for index in tqdm(range(0, len(input_features), batch_size)): + batch = Batch.from_data_list(input_features[index:(index + batch_size)]) + model_fp32 = predictor.model.float() + output, extras = model_fp32.forward(batch, extra_return_names=["pre_task_heads"]) + fingerprint = extras['pre_task_heads']['graph_feat'] + results.extend([fingerprint[i] for i in 
range(batch_size)])
+
+        if index == 0:
+            print(fingerprint.shape)
+
+    torch.save(results, "results.pt")
+
+    # Generate dictionary SMILES -> fingerprint vector
+    smiles_to_fingerprint = dict(zip(smiles_to_process, results))
+    torch.save(smiles_to_fingerprint, "smiles_to_fingerprint.pt")
+
+    # Generate dictionary unique IDs -> fingerprint vector
+    ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
+    ids_to_fingerprint = dict(zip(ids, results))
+    torch.save(ids_to_fingerprint, "ids_to_fingerprint.pt")
+
+
+if __name__ == "__main__":
+    cli()

From c57559a95df3e41f4b89cc6aa3cd4dbdc227420b Mon Sep 17 00:00:00 2001
From: Callum McLean
Date: Fri, 5 Jan 2024 15:48:30 +0000
Subject: [PATCH 5/6] Fix RAM issue

---
 graphium/cli/combine_fingerprints.py   | 28 ++++++++++++++++++++++++++
 graphium/cli/get_final_fingerprints.py | 15 ++++++++++-----
 2 files changed, 38 insertions(+), 5 deletions(-)
 create mode 100644 graphium/cli/combine_fingerprints.py

diff --git a/graphium/cli/combine_fingerprints.py b/graphium/cli/combine_fingerprints.py
new file mode 100644
index 000000000..613bd4ea7
--- /dev/null
+++ b/graphium/cli/combine_fingerprints.py
@@ -0,0 +1,28 @@
+import torch
+from tqdm import tqdm
+import datamol as dm
+
+input_features = torch.load("input_features.pt")
+batch_size = 100
+
+all_results = []
+
+for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
+
+    results = torch.load(f'results/res-{i:04}.pt')
+    all_results.extend(results)
+
+del input_features
+
+torch.save(all_results, 'results/all_results.pt')
+
+smiles_to_process = torch.load("saved_admet_smiles.pt")
+
+# Generate dictionary SMILES -> fingerprint vector
+smiles_to_fingerprint = dict(zip(smiles_to_process, all_results))
+torch.save(smiles_to_fingerprint, "results/smiles_to_fingerprint.pt")
+
+# Generate dictionary unique IDs -> fingerprint vector
+ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
+ids_to_fingerprint = dict(zip(ids, all_results))
+torch.save(ids_to_fingerprint, "results/ids_to_fingerprint.pt")

diff --git a/graphium/cli/get_final_fingerprints.py b/graphium/cli/get_final_fingerprints.py
index 32075b08f..7b2004b29 100644
--- a/graphium/cli/get_final_fingerprints.py
+++ b/graphium/cli/get_final_fingerprints.py
@@ -70,6 +70,7 @@ def get_final_fingerprints(cfg: DictConfig) -> None:
     if not os.path.exists("saved_admet_smiles.pt"):
         admet = admet_group(path="admet-data/")
         admet_mol_ids = set()
+        #for dn in tqdm([admet.dataset_names[0]], desc="Getting IDs for ADMET", file=sys.stdout):
         for dn in tqdm(admet.dataset_names, desc="Getting IDs for ADMET", file=sys.stdout):
             admet_mol_ids |= set(admet.get(dn)["train_val"]["Drug"].apply(dm.unique_id))
             admet_mol_ids |= set(admet.get(dn)["test"]["Drug"].apply(dm.unique_id))
@@ -78,6 +79,7 @@ def get_final_fingerprints(cfg: DictConfig) -> None:
         smiles_to_process = []
         admet_mol_ids_to_find = deepcopy(admet_mol_ids)
         for dn in tqdm(admet.dataset_names, desc="Matching molecules to IDs", file=sys.stdout):
+        #for dn in tqdm([admet.dataset_names[0]], desc="Matching molecules to IDs", file=sys.stdout):
             for key in ["train_val", "test"]:
                 train_mols = set(admet.get(dn)[key]["Drug"])
                 for smiles in train_mols:
@@ -167,21 +169,23 @@ def _single_bool(val):
 
     batch_size = 100
 
-    results = []
-
-
     # Run the model to get fingerprints
 
-    for index in tqdm(range(0, len(input_features), batch_size)):
+    os.makedirs("results", exist_ok=True)  # the per-batch files below are saved under results/
+    for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
        batch = Batch.from_data_list(input_features[index:(index + batch_size)])
        model_fp32 = predictor.model.float()
        output, extras = 
model_fp32.forward(batch, extra_return_names=["pre_task_heads"]) fingerprint = extras['pre_task_heads']['graph_feat'] - results.extend([fingerprint[i] for i in range(batch_size)]) + num_molecules = min(batch_size, fingerprint.shape[0]) + results = [fingerprint[i] for i in range(num_molecules)] + + torch.save(results, f'results/res-{i:04}.pt') if index == 0: print(fingerprint.shape) + ''' torch.save(results, "results.pt") # Generate dictionary SMILES -> fingerprint vector @@ -192,6 +195,7 @@ def _single_bool(val): ids = [dm.unique_id(smiles) for smiles in smiles_to_process] ids_to_fingerprint = dict(zip(ids, results)) torch.save(ids_to_fingerprint, "ids_to_fingerprint.pt") + ''' if __name__ == "__main__": From 5ee8bc1298eec9c6571500677429101d4e7d99df Mon Sep 17 00:00:00 2001 From: kerstink-GC Date: Tue, 12 Mar 2024 12:35:58 +0000 Subject: [PATCH 6/6] remove Gradient link from README --- README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/README.md b/README.md index a83f7ab40..758af2cf2 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ --- -[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS) [![PyPI](https://img.shields.io/pypi/v/graphium)](https://pypi.org/project/graphium/) [![Conda](https://img.shields.io/conda/v/conda-forge/graphium?label=conda&color=success)](https://anaconda.org/conda-forge/graphium) [![PyPI - Downloads](https://img.shields.io/pypi/dm/graphium)](https://pypi.org/project/graphium/) @@ -34,10 +33,6 @@ A deep learning library focused on graph representation learning for real-world Visit https://graphium-docs.datamol.io/. -[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS) - -You can try running Graphium on Graphcore IPUs for free on Gradient by clicking on the button above. - ## Installation for developers ### For CPU and GPU developers
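
Usage sketch: the `extra_return_names` hook added in patches 1-3 lets callers pull intermediate outputs, such as the graph embedding that feeds the task heads, out of a forward pass. The helper below is a minimal illustration of that hook and is not code from the patches themselves; it assumes `model` is a Graphium network carrying these changes (e.g. `predictor.model` from `PredictorModule.load_pretrained_model`) and `data_list` holds featurized molecules as produced by `datamodule._featurize_molecules`.

    from typing import List

    import torch
    from torch_geometric.data import Batch


    def extract_fingerprints(model, data_list: List, batch_size: int = 100) -> torch.Tensor:
        """Collect the graph-level features entering the task heads, batch by batch."""
        fingerprints = []
        model = model.float()  # mirror the FP32 cast done in get_final_fingerprints.py
        with torch.no_grad():
            for start in range(0, len(data_list), batch_size):
                batch = Batch.from_data_list(data_list[start:start + batch_size])
                # "pre_task_heads" asks forward() to also hand back the task heads' input
                _, extras = model.forward(batch, extra_return_names=["pre_task_heads"])
                fingerprints.append(extras["pre_task_heads"]["graph_feat"])
        return torch.cat(fingerprints, dim=0)

Concatenating in memory is fine for datasets the size of TDC ADMET; for larger inputs, the per-batch torch.save plus combine_fingerprints.py approach from patch 5 keeps peak RAM bounded.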