WIP: 0 GNN layers on last IPU + Fingerprinting #488

Open · wants to merge 6 commits into main
5 changes: 0 additions & 5 deletions README.md
@@ -5,7 +5,6 @@

---

[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)
[![PyPI](https://img.shields.io/pypi/v/graphium)](https://pypi.org/project/graphium/)
[![Conda](https://img.shields.io/conda/v/conda-forge/graphium?label=conda&color=success)](https://anaconda.org/conda-forge/graphium)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/graphium)](https://pypi.org/project/graphium/)
@@ -34,10 +33,6 @@ A deep learning library focused on graph representation learning for real-world

Visit https://graphium-docs.datamol.io/.

[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)

You can try running Graphium on Graphcore IPUs for free on Gradient by clicking on the button above.

## Installation for developers

### For CPU and GPU developers
28 changes: 28 additions & 0 deletions graphium/cli/combine_fingerprints.py
@@ -0,0 +1,28 @@
import torch
from tqdm import tqdm
import datamol as dm

input_features = torch.load("input_features.pt")
batch_size = 100

all_results = []

for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):

results = torch.load(f'results/res-{i:04}.pt')
all_results.extend(results)

del input_features

torch.save(all_results, 'results/all_results.pt')

smiles_to_process = torch.load("saved_admet_smiles.pt")

# Generate dictionary SMILES -> fingerprint vector
smiles_to_fingerprint = dict(zip(smiles_to_process, all_results))
torch.save(smiles_to_fingerprint, "results/smiles_to_fingerprint.pt")

# Generate dictionary unique IDs -> fingerprint vector
ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
ids_to_fingerprint = dict(zip(ids, all_results))
torch.save(ids_to_fingerprint, "results/ids_to_fingerprint.pt")
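For context (not part of the PR's diff): a minimal sketch of how the saved lookup tables could be consumed downstream, assuming the files above exist and each fingerprint is a 1-D tensor. The query molecule is hypothetical.

```python
import torch
import datamol as dm

# Load the unique-ID -> fingerprint table written by combine_fingerprints.py
ids_to_fingerprint = torch.load("results/ids_to_fingerprint.pt")

# Look up the fingerprint of a molecule by its datamol unique ID
query_smiles = "CCO"  # hypothetical query; use any SMILES from the ADMET set
fingerprint = ids_to_fingerprint.get(dm.unique_id(query_smiles))
if fingerprint is not None:
    print(fingerprint.shape)
```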
202 changes: 202 additions & 0 deletions graphium/cli/get_final_fingerprints.py
@@ -0,0 +1,202 @@
from typing import List, Literal, Union
import os
import time
import timeit
from datetime import datetime

import fsspec
import hydra
import numpy as np
import torch
import wandb
import yaml
from datamol.utils import fs
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode
from lightning.pytorch.utilities.model_summary import ModelSummary
from loguru import logger
from omegaconf import DictConfig, OmegaConf

from graphium.config._loader import (
load_accelerator,
load_architecture,
load_datamodule,
load_metrics,
load_predictor,
load_trainer,
save_params_to_wandb,
get_checkpoint_path,
)
from graphium.finetuning import (
FINETUNING_CONFIG_KEY,
GraphFinetuning,
modify_cfg_for_finetuning,
)
from graphium.hyper_param_search import (
HYPER_PARAM_SEARCH_CONFIG_KEY,
extract_main_metric_for_hparam_search,
)
from graphium.trainer.predictor import PredictorModule
from graphium.utils.safe_run import SafeRun

import graphium.cli.finetune_utils

from tqdm import tqdm
from copy import deepcopy
from tdc.benchmark_group import admet_group
import datamol as dm
import sys
from torch_geometric.data import Batch
import random

TESTING_ONLY_CONFIG_KEY = "testing_only"


@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
def cli(cfg: DictConfig) -> None:
"""
The main CLI endpoint for computing fingerprints of the ADMET benchmark molecules with a pre-trained Graphium model.
"""
return get_final_fingerprints(cfg)


def get_final_fingerprints(cfg: DictConfig) -> None:
"""
Featurize the ADMET SMILES strings and extract per-molecule fingerprints from a pre-trained model.
"""

# Get ADMET SMILES strings

if not os.path.exists("saved_admet_smiles.pt"):
admet = admet_group(path="admet-data/")
admet_mol_ids = set()
#for dn in tqdm([admet.dataset_names[0]], desc="Getting IDs for ADMET", file=sys.stdout):
for dn in tqdm(admet.dataset_names, desc="Getting IDs for ADMET", file=sys.stdout):
admet_mol_ids |= set(admet.get(dn)["train_val"]["Drug"].apply(dm.unique_id))
admet_mol_ids |= set(admet.get(dn)["test"]["Drug"].apply(dm.unique_id))

smiles_to_process = []
admet_mol_ids_to_find = deepcopy(admet_mol_ids)

for dn in tqdm(admet.dataset_names, desc="Matching molecules to IDs", file=sys.stdout):
#for dn in tqdm([admet.dataset_names[0]], desc="Matching molecules to IDs", file=sys.stdout):
for key in ["train_val", "test"]:
train_mols = set(admet.get(dn)[key]["Drug"])
for smiles in train_mols:
mol_id = dm.unique_id(smiles)
if mol_id in admet_mol_ids_to_find:
smiles_to_process.append(smiles)
admet_mol_ids_to_find.remove(mol_id)

assert set(dm.unique_id(s) for s in smiles_to_process) == admet_mol_ids
torch.save(smiles_to_process, "saved_admet_smiles.pt")
else:
smiles_to_process = torch.load("saved_admet_smiles.pt")

unresolved_cfg = OmegaConf.to_container(cfg, resolve=False)
cfg = OmegaConf.to_container(cfg, resolve=True)

st = timeit.default_timer()

## == Instantiate all required objects from their respective configs ==
# Accelerator
cfg, accelerator_type = load_accelerator(cfg)
assert accelerator_type == "cpu", "get_final_fingerprints script only runs on CPU for now"

## Data-module
datamodule = load_datamodule(cfg, accelerator_type)


# Featurize SMILES strings

input_features_save_path = "input_features.pt"
idx_none_save_path = "idx_none.pt"
if not os.path.exists(input_features_save_path):
input_features, idx_none = datamodule._featurize_molecules(smiles_to_process)

torch.save(input_features, input_features_save_path)
torch.save(idx_none, idx_none_save_path)
else:
input_features = torch.load(input_features_save_path)

'''
for _ in range(100):

index = random.randint(0, len(smiles_to_process) - 1)
features_single, idx_none_single = datamodule._featurize_molecules([smiles_to_process[index]])

def _single_bool(val):
if isinstance(val, bool):
return val
if isinstance(val, torch.Tensor):
return val.all()
raise ValueError(f"Type {type(val)} not accounted for")

assert all(_single_bool(features_single[0][k] == input_features[index][k]) for k in features_single[0].keys())

import sys; sys.exit(0)
'''

failures = 0

# Cast to FP32

for input_feature in tqdm(input_features, desc="Casting to FP32"):
try:
if not isinstance(input_feature, str):
for k, v in input_feature.items():
if isinstance(v, torch.Tensor):
if v.dtype == torch.half:
input_feature[k] = v.float()
elif v.dtype == torch.int32:
input_feature[k] = v.long()
else:
failures += 1
except Exception as e:
print(f"{input_feature = }")
raise e

print(f"{failures = }")


# Load pre-trained model
predictor = PredictorModule.load_pretrained_model(
name_or_path=get_checkpoint_path(cfg), device=accelerator_type
)

logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))

batch_size = 100

# Run the model to get fingerprints

for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
batch = Batch.from_data_list(input_features[index:(index + batch_size)])
model_fp32 = predictor.model.float()
output, extras = model_fp32.forward(batch, extra_return_names=["pre_task_heads"])
fingerprint = extras['pre_task_heads']['graph_feat']
num_molecules = min(batch_size, fingerprint.shape[0])
results = [fingerprint[i] for i in range(num_molecules)]

torch.save(results, f'results/res-{i:04}.pt')

if index == 0:
print(fingerprint.shape)

'''
torch.save(results, "results.pt")

# Generate dictionary SMILES -> fingerprint vector
smiles_to_fingerprint = dict(zip(smiles_to_process, results))
torch.save(smiles_to_fingerprint, "smiles_to_fingerprint.pt")

# Generate dictionary unique IDs -> fingerprint vector
ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
ids_to_fingerprint = dict(zip(ids, results))
torch.save(ids_to_fingerprint, "ids_to_fingerprint.pt")
'''


if __name__ == "__main__":
cli()
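A note on the ID-matching step above (not part of the diff): the script keys molecules by `dm.unique_id` rather than by raw SMILES, on the assumption that the ID is derived from the standardized molecule and is therefore stable across different SMILES spellings. A quick, self-contained way to check that assumption:

```python
import datamol as dm

# Two SMILES spellings of ethanol; a canonical molecule ID should treat them as the same key.
id_a = dm.unique_id("CCO")
id_b = dm.unique_id("OCC")
print(id_a, id_b, id_a == id_b)
```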
53 changes: 46 additions & 7 deletions graphium/nn/architectures/global_architectures.py
@@ -866,7 +866,7 @@ def _virtual_node_forward(

return feat, vn_feat, edge_feat

def forward(self, g: Batch) -> torch.Tensor:
def forward(self, g: Batch, extra_return_names: List[str] = []) -> torch.Tensor:
r"""
Apply the full graph neural network on the input graph and node features.

@@ -1198,12 +1198,23 @@ def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu):
f"of {model_depth}"
)

if 0 in gnn_layers_per_ipu[:-1]:
raise ValueError("Only the last IPU can have 0 GNN layers")

begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)]

for begin_block_layer_index, ipu_id in zip(begin_block_layer_indices, range(1, pipeline_length)):
self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock(
self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id
)
if begin_block_layer_index < model_depth:
self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock(
self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id
)
elif self.task_heads is not None and ipu_id == pipeline_length - 1:
self.task_heads = poptorch.BeginBlock(self.task_heads, ipu_id=ipu_id)
else:
raise ValueError(
"Invalid pipeline split, nothing to put on last IPU "
"(0 GNN layers on last IPU but no task heads)"
)

def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]):
"""
@@ -1286,7 +1297,7 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] =
self._module_map[module_name] = module.layers
return self._module_map

def forward(self, g: Batch) -> Tensor:
def forward(self, g: Batch, extra_return_names: List[str] = None) -> Tensor:
r"""
Apply the pre-processing neural network, the graph neural network,
and the post-processing neural network on the graph features.
@@ -1341,9 +1352,29 @@ def forward(self, g: Batch) -> Tensor:
e = self.pre_nn_edges.forward(e)
g["edge_feat"] = e

# Run the graph neural network
# Apologies for the similar names here
extras_to_return = {}

#g, gnn_extras_to_return = self.gnn.forward(g, extra_return_names=extra_return_names)
#extras_to_return.update(gnn_extras_to_return)
g = self.gnn.forward(g)
if extra_return_names:
# Run the graph neural network

if "pre_task_heads" in extra_return_names:
extras_to_return.update({"pre_task_heads": g})

if self.task_heads is None:
return g, extras_to_return

final_output, task_head_extras_to_return = self.task_heads.forward(
g, extra_return_names=extra_return_names
)
extras_to_return.update(task_head_extras_to_return)

return final_output, extras_to_return

# Keep original code if no extra_return_names
Contributor comment:
Maybe I'm missing something obvious, in which case apologies.
Aren't you missing the original forward pass in the case where extra_return_names is missing?

Suggested change
# Keep original code if no extra_return_names
# Keep original code if no extra_return_names
else:
g = self.gnn.forward(g)

My suspicion is that this will fix the failing tests.
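To make the concern concrete, here is a tiny runnable toy (not Graphium code; all names hypothetical) of the pattern the suggestion restores: when the extras branch returns early, the default path still has to run the step it would otherwise skip.

```python
# Toy illustration of the missing-else issue raised in the comment above.
def forward(x, extra_return_names=()):
    extras = {}
    if extra_return_names:
        y = x + 1                     # the "GNN" step
        extras["pre_task_heads"] = y
        return y * 2, extras          # "task heads" applied, extras returned
    else:
        y = x + 1                     # the fix: run the same step on the default path too
    return y * 2                      # original single-output behaviour

assert forward(3) == 8
assert forward(3, ("pre_task_heads",)) == (8, {"pre_task_heads": 4})
```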

if self.task_heads is not None:
return self.task_heads.forward(g)

@@ -1851,23 +1882,31 @@ def __init__(
filtered_kwargs["in_dim"] = self.graph_output_nn_kwargs[task_level]["out_dim"]
self.task_heads[task_name] = FeedForwardNN(**filtered_kwargs)

def forward(self, g: Batch) -> Dict[str, torch.Tensor]:
def forward(self, g: Batch, extra_return_names: List[str] = []) -> Dict[str, torch.Tensor]:
r"""
forward function of the task head
Parameters:
g: pyg Batch graph
Returns:
task_head_outputs: Return a dictionary: Dict[task_name, Tensor]
"""

extras_to_return = {}

features = {task_level: self.graph_output_nn[task_level](g) for task_level in self.task_levels}

if "task_level_features" in extra_return_names:
extras_to_return["task_level_features"] = features

task_head_outputs = {}
for task_name, head in self.task_heads.items():
task_level = self.task_heads_kwargs[task_name].get(
"task_level", None
) # Get task_level without modifying head_kwargs
task_head_outputs[task_name] = head.forward(features[task_level])

if extra_return_names:
return task_head_outputs, extras_to_return
return task_head_outputs

def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]: