diff --git a/README.md b/README.md
index a83f7ab40..758af2cf2 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,6 @@
 
 ---
 
-[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)
 [![PyPI](https://img.shields.io/pypi/v/graphium)](https://pypi.org/project/graphium/)
 [![Conda](https://img.shields.io/conda/v/conda-forge/graphium?label=conda&color=success)](https://anaconda.org/conda-forge/graphium)
 [![PyPI - Downloads](https://img.shields.io/pypi/dm/graphium)](https://pypi.org/project/graphium/)
@@ -34,10 +33,6 @@ A deep learning library focused on graph representation learning for real-world
 
 Visit https://graphium-docs.datamol.io/.
 
-[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)
-
-You can try running Graphium on Graphcore IPUs for free on Gradient by clicking on the button above.
-
 ## Installation for developers
 
 ### For CPU and GPU developers
diff --git a/graphium/cli/combine_fingerprints.py b/graphium/cli/combine_fingerprints.py
new file mode 100644
index 000000000..613bd4ea7
--- /dev/null
+++ b/graphium/cli/combine_fingerprints.py
@@ -0,0 +1,28 @@
+import torch
+from tqdm import tqdm
+import datamol as dm
+
+input_features = torch.load("input_features.pt")
+batch_size = 100
+
+all_results = []
+
+# Concatenate the per-batch fingerprint files written by get_final_fingerprints.py
+for i, _ in tqdm(enumerate(range(0, len(input_features), batch_size))):
+    results = torch.load(f"results/res-{i:04}.pt")
+    all_results.extend(results)
+
+del input_features
+
+torch.save(all_results, "results/all_results.pt")
+
+smiles_to_process = torch.load("saved_admet_smiles.pt")
+
+# Generate dictionary SMILES -> fingerprint vector
+smiles_to_fingerprint = dict(zip(smiles_to_process, all_results))
+torch.save(smiles_to_fingerprint, "results/smiles_to_fingerprint.pt")
+
+# Generate dictionary unique IDs -> fingerprint vector
+ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
+ids_to_fingerprint = dict(zip(ids, all_results))
+torch.save(ids_to_fingerprint, "results/ids_to_fingerprint.pt")
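For context, the lookup tables written by `combine_fingerprints.py` are meant to be consumed downstream roughly as follows. This is a minimal sketch, not part of the PR; the benchmark name and variable names are illustrative:

```python
import datamol as dm
import torch
from tdc.benchmark_group import admet_group

# Load the unique-ID -> fingerprint table produced by combine_fingerprints.py
ids_to_fingerprint = torch.load("results/ids_to_fingerprint.pt")

# Look up the pre-computed fingerprint for each molecule in an ADMET split
admet = admet_group(path="admet-data/")
benchmark = admet.get("caco2_wang")  # any ADMET benchmark name works here
train_val = benchmark["train_val"]

# dm.unique_id canonicalizes the molecule, so the lookup is robust to SMILES variants
fingerprints = torch.stack(
    [ids_to_fingerprint[dm.unique_id(s)] for s in train_val["Drug"]]
)
print(fingerprints.shape)  # (num_molecules, fingerprint_dim)
```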
config_path="../../expts/hydra-configs", config_name="main") +def cli(cfg: DictConfig) -> None: + """ + The main CLI endpoint for training, fine-tuning and evaluating Graphium models. + """ + return get_final_fingerprints(cfg) + + +def get_final_fingerprints(cfg: DictConfig) -> None: + """ + The main (pre-)training and fine-tuning loop. + """ + + # Get ADMET SMILES strings + + if not os.path.exists("saved_admet_smiles.pt"): + admet = admet_group(path="admet-data/") + admet_mol_ids = set() + #for dn in tqdm([admet.dataset_names[0]], desc="Getting IDs for ADMET", file=sys.stdout): + for dn in tqdm(admet.dataset_names, desc="Getting IDs for ADMET", file=sys.stdout): + admet_mol_ids |= set(admet.get(dn)["train_val"]["Drug"].apply(dm.unique_id)) + admet_mol_ids |= set(admet.get(dn)["test"]["Drug"].apply(dm.unique_id)) + + smiles_to_process = [] + admet_mol_ids_to_find = deepcopy(admet_mol_ids) + + for dn in tqdm(admet.dataset_names, desc="Matching molecules to IDs", file=sys.stdout): + #for dn in tqdm([admet.dataset_names[0]], desc="Matching molecules to IDs", file=sys.stdout): + for key in ["train_val", "test"]: + train_mols = set(admet.get(dn)[key]["Drug"]) + for smiles in train_mols: + mol_id = dm.unique_id(smiles) + if mol_id in admet_mol_ids_to_find: + smiles_to_process.append(smiles) + admet_mol_ids_to_find.remove(mol_id) + + assert set(dm.unique_id(s) for s in smiles_to_process) == admet_mol_ids + torch.save(smiles_to_process, "saved_admet_smiles.pt") + else: + smiles_to_process = torch.load("saved_admet_smiles.pt") + + unresolved_cfg = OmegaConf.to_container(cfg, resolve=False) + cfg = OmegaConf.to_container(cfg, resolve=True) + + st = timeit.default_timer() + + ## == Instantiate all required objects from their respective configs == + # Accelerator + cfg, accelerator_type = load_accelerator(cfg) + assert accelerator_type == "cpu", "get_final_fingerprints script only runs on CPU for now" + + ## Data-module + datamodule = load_datamodule(cfg, accelerator_type) + + + # Featurize SMILES strings + + input_features_save_path = "input_features.pt" + idx_none_save_path = "idx_none.pt" + if not os.path.exists(input_features_save_path): + input_features, idx_none = datamodule._featurize_molecules(smiles_to_process) + + torch.save(input_features, input_features_save_path) + torch.save(idx_none, idx_none_save_path) + else: + input_features = torch.load(input_features_save_path) + + ''' + for _ in range(100): + + index = random.randint(0, len(smiles_to_process) - 1) + features_single, idx_none_single = datamodule._featurize_molecules([smiles_to_process[index]]) + + def _single_bool(val): + if isinstance(val, bool): + return val + if isinstance(val, torch.Tensor): + return val.all() + raise ValueError(f"Type {type(val)} not accounted for") + + assert all(_single_bool(features_single[0][k] == input_features[index][k]) for k in features_single[0].keys()) + + import sys; sys.exit(0) + ''' + + failures = 0 + + # Cast to FP32 + + for input_feature in tqdm(input_features, desc="Casting to FP32"): + try: + if not isinstance(input_feature, str): + for k, v in input_feature.items(): + if isinstance(v, torch.Tensor): + if v.dtype == torch.half: + input_feature[k] = v.float() + elif v.dtype == torch.int32: + input_feature[k] = v.long() + else: + failures += 1 + except Exception as e: + print(f"{input_feature = }") + raise e + + print(f"{failures = }") + + + # Load pre-trained model + predictor = PredictorModule.load_pretrained_model( + name_or_path=get_checkpoint_path(cfg), device=accelerator_type + ) + + 
+    # Load the pre-trained model
+    predictor = PredictorModule.load_pretrained_model(
+        name_or_path=get_checkpoint_path(cfg), device=accelerator_type
+    )
+
+    logger.info(predictor.model)
+    logger.info(ModelSummary(predictor, max_depth=4))
+
+    batch_size = 100
+    os.makedirs("results", exist_ok=True)
+
+    # Cast the model itself to FP32 once, outside the batch loop
+    model_fp32 = predictor.model.float()
+
+    # Run the model batch-by-batch and save the fingerprints to disk incrementally
+    for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
+        batch = Batch.from_data_list(input_features[index : (index + batch_size)])
+        output, extras = model_fp32.forward(batch, extra_return_names=["pre_task_heads"])
+        fingerprint = extras["pre_task_heads"]["graph_feat"]
+        num_molecules = min(batch_size, fingerprint.shape[0])
+        results = [fingerprint[mol] for mol in range(num_molecules)]
+
+        torch.save(results, f"results/res-{i:04}.pt")
+
+        # Log the fingerprint dimension once, on the first batch
+        if index == 0:
+            logger.info(fingerprint.shape)
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py
index 6570ca492..5258530d3 100644
--- a/graphium/nn/architectures/global_architectures.py
+++ b/graphium/nn/architectures/global_architectures.py
@@ -866,7 +866,7 @@ def _virtual_node_forward(
         return feat, vn_feat, edge_feat
 
-    def forward(self, g: Batch) -> torch.Tensor:
+    def forward(self, g: Batch, extra_return_names: Optional[List[str]] = None) -> torch.Tensor:
         r"""
         Apply the full graph neural network on the input graph and node features.
@@ -1198,12 +1198,23 @@ def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu):
                 f"of {model_depth}"
             )
 
+        if 0 in gnn_layers_per_ipu[:-1]:
+            raise ValueError("Only the last IPU can have 0 GNN layers")
+
         begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)]
 
         for begin_block_layer_index, ipu_id in zip(begin_block_layer_indices, range(1, pipeline_length)):
-            self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock(
-                self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id
-            )
+            if begin_block_layer_index < model_depth:
+                self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock(
+                    self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id
+                )
+            elif self.task_heads is not None and ipu_id == pipeline_length - 1:
+                # 0 GNN layers on the last IPU: place the task heads there instead
+                self.task_heads = poptorch.BeginBlock(self.task_heads, ipu_id=ipu_id)
+            else:
+                raise ValueError(
+                    "Invalid pipeline split, nothing to put on last IPU "
+                    "(0 GNN layers on last IPU but no task heads)"
+                )
 
     def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]):
         """
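To make the new split rule concrete, here is how the indices work out for a hypothetical configuration (the numbers below are illustrative, not taken from the repo's configs):

```python
# 4 GNN layers spread over 4 IPUs, with the last IPU reserved for the task heads
gnn_layers_per_ipu = [2, 1, 1, 0]
pipeline_length = len(gnn_layers_per_ipu)  # 4
model_depth = sum(gnn_layers_per_ipu)      # 4

# Same computation as in _apply_ipu_pipeline_split
begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)]
print(begin_block_layer_indices)  # [2, 3, 4]

# Indices 2 and 3 are below model_depth, so IPUs 1 and 2 begin at GNN layers 2 and 3.
# Index 4 equals model_depth: no GNN layer is left, so the task heads are wrapped in
# poptorch.BeginBlock(ipu_id=3) instead. A zero anywhere except the last position,
# e.g. [2, 0, 1, 1], now raises a ValueError.
```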
@@ -1286,7 +1297,7 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] =
             self._module_map[module_name] = module.layers
         return self._module_map
 
-    def forward(self, g: Batch) -> Tensor:
+    def forward(self, g: Batch, extra_return_names: Optional[List[str]] = None) -> Tensor:
         r"""
         Apply the pre-processing neural network, the graph neural network,
         and the post-processing neural network on the graph features.
@@ -1341,9 +1352,29 @@ def forward(self, g: Batch) -> Tensor:
             e = self.pre_nn_edges.forward(e)
             g["edge_feat"] = e
 
+        extras_to_return = {}
+
         # Run the graph neural network
         g = self.gnn.forward(g)
 
+        if extra_return_names:
+            # Collect the requested intermediate outputs
+            if "pre_task_heads" in extra_return_names:
+                extras_to_return.update({"pre_task_heads": g})
+            if self.task_heads is None:
+                return g, extras_to_return
+
+            final_output, task_head_extras_to_return = self.task_heads.forward(
+                g, extra_return_names=extra_return_names
+            )
+            extras_to_return.update(task_head_extras_to_return)
+
+            return final_output, extras_to_return
+
+        # Keep the original return behaviour when no extra_return_names are given
         if self.task_heads is not None:
             return self.task_heads.forward(g)
@@ -1851,7 +1882,7 @@ def __init__(
         filtered_kwargs["in_dim"] = self.graph_output_nn_kwargs[task_level]["out_dim"]
         self.task_heads[task_name] = FeedForwardNN(**filtered_kwargs)
 
-    def forward(self, g: Batch) -> Dict[str, torch.Tensor]:
+    def forward(self, g: Batch, extra_return_names: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
         r"""
         forward function of the task head
         Parameters:
@@ -1859,8 +1890,14 @@
             Returns:
             task_head_outputs: Return a dictionary: Dict[task_name, Tensor]
         """
+        extras_to_return = {}
+
         features = {task_level: self.graph_output_nn[task_level](g) for task_level in self.task_levels}
 
+        if extra_return_names and "task_level_features" in extra_return_names:
+            extras_to_return["task_level_features"] = features
+
         task_head_outputs = {}
         for task_name, head in self.task_heads.items():
             task_level = self.task_heads_kwargs[task_name].get(
                 "task_level", None
             )  # Get task_level without modifying head_kwargs
             task_head_outputs[task_name] = head.forward(features[task_level])
 
+        if extra_return_names:
+            return task_head_outputs, extras_to_return
         return task_head_outputs
 
     def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
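Putting the pieces together, the new `extra_return_names` argument is used like this. A sketch based on the call in `get_final_fingerprints.py`, assuming `model` is the multi-task network modified above and `input_features` is a list of featurized graphs as produced there:

```python
from torch_geometric.data import Batch

batch = Batch.from_data_list(input_features[:100])

# Plain call: returns only the task-head outputs, exactly as before
task_outputs = model.forward(batch)

# With extra_return_names: returns (outputs, extras). "pre_task_heads" exposes the
# graph representation right after the GNN, "task_level_features" the per-level
# readout features computed inside the task heads.
output, extras = model.forward(
    batch, extra_return_names=["pre_task_heads", "task_level_features"]
)
fingerprint = extras["pre_task_heads"]["graph_feat"]
```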