
Commit

Merge pull request #457 from datamol-io/pipeline_integration
Pipeline integration + Virtual Nodes Edges bug fix
s-maddrellmander committed Sep 5, 2023
2 parents 8be0f9f + daf011c commit 41a1172
Showing 5 changed files with 53 additions and 1 deletion.
22 changes: 22 additions & 0 deletions expts/hydra-configs/accelerator/ipu_pipeline.yaml
@@ -0,0 +1,22 @@
type: ipu
ipu_config:
  - deviceIterations(60) # IPU would require large batches to be ready for the model.
    # 60 for PCQM4mv2
    # 30 for largemix
  - replicationFactor(4)
  # - enableProfiling("graph_analyser") # The folder where the profile will be stored
  # - enableExecutableCaching("pop_compiler_cache")
  - TensorLocations.numIOTiles(128)
  - _Popart.set("defaultBufferingDepth", 96)
  - Precision.enableStochasticRounding(True)

ipu_inference_config:
  # set device iteration and replication factor to 1 during inference
  # gradient accumulation was set to 1 in the code
  - deviceIterations(60)
  - replicationFactor(1)
  - Precision.enableStochasticRounding(False)

accelerator_kwargs:
  _accelerator: "ipu"
  gnn_layers_per_ipu: [4, 4, 4, 4]
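For context, the new accelerator_kwargs block is what drives the pipelining: _accelerator selects the IPU code path and gnn_layers_per_ipu says how many GNN layers to place on each of the four IPUs. A minimal standalone sketch (an illustration, not graphium's placement code) of how such a list maps layers to pipeline stages:

from typing import Dict, List

def assign_layers_to_ipus(gnn_layers_per_ipu: List[int]) -> Dict[int, int]:
    """Map each GNN layer index to the pipeline stage (IPU) that should host it."""
    stage_of_layer = {}
    layer_idx = 0
    for stage, n_layers in enumerate(gnn_layers_per_ipu):
        for _ in range(n_layers):
            stage_of_layer[layer_idx] = stage
            layer_idx += 1
    return stage_of_layer

print(assign_layers_to_ipus([4, 4, 4, 4]))  # layers 0-3 -> IPU 0, 4-7 -> IPU 1, 8-11 -> IPU 2, 12-15 -> IPU 3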
2 changes: 2 additions & 0 deletions expts/hydra-configs/model/mpnn.yaml
@@ -22,3 +22,5 @@ architecture:
      attn_type: "none" # "full-attention", "none"
      # biased_attention: false
      attn_kwargs: null
    virtual_node: 'sum'
    use_virtual_edges: true
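These two keys enable the virtual-node part of the fix: virtual_node: 'sum' adds a virtual node whose features are a sum-aggregation over all node features of a graph, broadcast back to every node, and use_virtual_edges: true extends the same mechanism to edge features. A rough, standalone sketch of the node-side idea using torch_geometric pooling (not graphium's actual virtual-node layer):

import torch
from torch_geometric.nn import global_add_pool

# Toy example: a single graph with 5 nodes and 8 features per node.
h = torch.randn(5, 8)
batch = torch.zeros(5, dtype=torch.long)  # all 5 nodes belong to graph 0
vn = global_add_pool(h, batch)            # [1, 8] virtual-node features ("sum" aggregation)
h = h + vn[batch]                         # broadcast the virtual node back onto every node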
4 changes: 4 additions & 0 deletions graphium/config/_loader.py
@@ -260,6 +260,10 @@ def load_architecture(
        graph_output_nn_kwargs=graph_output_nn_kwargs,
        task_heads_kwargs=task_heads_kwargs,
    )
    # Get accelerator_kwargs if they exist
    accelerator_kwargs = config["accelerator"].get("accelerator_kwargs", None)
    if accelerator_kwargs is not None:
        model_kwargs["accelerator_kwargs"] = accelerator_kwargs

    if model_class is FullGraphFinetuningNetwork:
        finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None)
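With this addition, any accelerator_kwargs block found in the accelerator config (e.g. the ipu_pipeline.yaml above) is forwarded through model_kwargs to the model constructor. A small self-contained sketch of the same lookup, assuming the config has been loaded with OmegaConf (the inline dict below is an illustration, not the real config):

from omegaconf import OmegaConf

config = OmegaConf.create({
    "accelerator": {
        "type": "ipu",
        "accelerator_kwargs": {"_accelerator": "ipu", "gnn_layers_per_ipu": [4, 4, 4, 4]},
    }
})

model_kwargs = {}
accelerator_kwargs = config["accelerator"].get("accelerator_kwargs", None)
if accelerator_kwargs is not None:
    model_kwargs["accelerator_kwargs"] = accelerator_kwargs

print(model_kwargs["accelerator_kwargs"]["gnn_layers_per_ipu"])  # [4, 4, 4, 4]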
2 changes: 2 additions & 0 deletions graphium/config/zinc_default_multitask_pyg.yaml
@@ -181,3 +181,5 @@ architecture: # The parameters for the full graph network are taken from `co
      dropout: 0.2
      normalization: none
      residual_type: none
accelerator:
  type: cpu
24 changes: 23 additions & 1 deletion graphium/nn/architectures/global_architectures.py
@@ -12,6 +12,7 @@
from torch import Tensor, nn
import torch
from torch_geometric.data import Data
from omegaconf import DictConfig, OmegaConf

# graphium imports
from graphium.data.utils import get_keys
@@ -593,6 +594,26 @@ def _check_bad_arguments(self):
        ) and not self.layer_class.layer_supports_edges:
            raise ValueError(f"Cannot use edge features with class `{self.layer_class}`")

    def get_nested_key(self, d, target_key):
        """
        Get the value associated with a key in a nested dictionary.

        Parameters:
            - d: The dictionary to search in
            - target_key: The key to search for

        Returns:
            - The value associated with the key if found, None otherwise
        """
        if target_key in d:
            return d[target_key]
        for key, value in d.items():
            if isinstance(value, (dict, DictConfig)):
                nested_result = self.get_nested_key(value, target_key)
                if nested_result is not None:
                    return nested_result
        return None

    def _create_layers(self):
        r"""
        Create all the necessary layers for the network.
@@ -639,7 +660,8 @@ def _create_layers(self):
                 this_out_dim_edges = self.full_dims_edges[ii + 1]
                 this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
             else:
-                this_out_dim_edges = self.layer_kwargs.get("out_dim_edges")
+                this_out_dim_edges = self.get_nested_key(self.layer_kwargs, "out_dim_edges")
+                this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
             layer_out_dims_edges.append(this_out_dim_edges)

             # Create the GNN layer
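The switch from self.layer_kwargs.get("out_dim_edges") to get_nested_key is what makes the virtual-edges fix work when out_dim_edges is not a top-level key of layer_kwargs but lives in a nested sub-dictionary (for instance inside an MPNN-specific kwargs block). A standalone illustration of the difference, with a hypothetical nested layer_kwargs and a plain-dict copy of the new helper:

def get_nested_key(d, target_key):
    """Plain-dict copy of the new helper, for illustration only."""
    if target_key in d:
        return d[target_key]
    for value in d.values():
        if isinstance(value, dict):
            found = get_nested_key(value, target_key)
            if found is not None:
                return found
    return None

# Hypothetical layer_kwargs where the edge output dim is nested one level down.
layer_kwargs = {
    "mpnn_type": "pyg:mpnnplus",
    "mpnn_kwargs": {"in_dim_edges": 32, "out_dim_edges": 32},
}

print(layer_kwargs.get("out_dim_edges"))              # None -> the old flat lookup misses it
print(get_nested_key(layer_kwargs, "out_dim_edges"))  # 32   -> the recursive lookup finds it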

