This repository contains Python scripts to perform the Graph Neural Network (GNN) emulation experiments as described in the paper Emulation of Cardiac Mechanics using Graph Neural Networks. Please cite the paper if you use this code.
Experiments were performed with Python version 3.9.7, JAX version 0.3.16 and Flax version 0.3.6. The module pytest is required to run the test file models_test.py, while tensorboard is also required to monitor training. To set up a virtual environment using conda, run the following commands in sequence once the repo has been cloned:
conda create --name gnnEmulEnv python=3.9.7
conda activate gnnEmulEnv
pip install "jax[cuda]==0.3.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade -r requirements.txt
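As an optional sanity check (not part of the repository), the following sketch compares the installed package versions against those listed above, using only the standard library's importlib.metadata. The names EXPECTED, EXPECTED_PYTHON and check_versions are illustrative; the script would be run with python from inside the activated gnnEmulEnv environment.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions used for the experiments, as listed in this README.
EXPECTED = {"jax": "0.3.16", "flax": "0.3.6"}
EXPECTED_PYTHON = "3.9.7"


def check_versions(expected=EXPECTED):
    """Return {package: (installed, expected)} for every package that does not match."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None  # package is not installed in this environment
        if have != want:
            mismatches[pkg] = (have, want)
    return mismatches


if __name__ == "__main__":
    py = ".".join(map(str, sys.version_info[:3]))
    if py != EXPECTED_PYTHON:
        print(f"Python {py} found, experiments used {EXPECTED_PYTHON}")
    for pkg, (have, want) in check_versions().items():
        print(f"{pkg}: installed {have}, experiments used {want}")
```

An empty report means the environment matches the versions used for the paper; other versions may still work, but results are not guaranteed to be identical.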
Note: if the error message "ImportError: cannot import name 'isin' from 'jax._src.numpy.lax_numpy'" appears when running experiments, it can be resolved by changing the import on line 31 of ~/anaconda3/envs/gnnEmulEnv/lib/python3.9/site-packages/flax/linen/module.py to
from jax.numpy import isin
The error occurs because isin is no longer part of the private jax._src.numpy.lax_numpy submodule.
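The path above assumes a default Anaconda installation in the home directory. The short sketch below (a hypothetical helper, not part of the repository) locates flax/linen/module.py for the active environment without importing flax, since importing it would itself trigger the error, and lists the import lines in that file that mention isin. It assumes the installed Flax keeps this file at flax/linen/module.py, as in the path above.

```python
import importlib.util
from pathlib import Path


def flax_module_path():
    """Locate flax/linen/module.py without executing flax (importing it would raise the error)."""
    spec = importlib.util.find_spec("flax")  # top-level lookup only; flax is not imported
    if spec is None or spec.origin is None:
        return None  # flax is not installed in the active environment
    return Path(spec.origin).parent / "linen" / "module.py"


def find_isin_imports(path):
    """Return (line_number, text) pairs for import lines in `path` that mention isin."""
    hits = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if stripped.startswith(("import ", "from ")) and "isin" in stripped:
                hits.append((lineno, stripped))
    return hits


if __name__ == "__main__":
    path = flax_module_path()
    if path is None:
        print("flax not found in this environment")
    elif not path.exists():
        print(f"expected file not found: {path}")
    else:
        print(path)
        for lineno, text in find_isin_imports(path):
            print(f"line {lineno}: {text}")
```

The printed line numbers can then be checked before editing the file by hand.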
The beam data is included in the repository, inside data/beamData. A GNN emulator with K=2 message passing steps for the beam data can be trained for 300 epochs as follows:
python -m main --mode="train" --n_epochs=300 --K=2 --data_path="beamData" --n_shape_coeff=2
The trained emulator can then be used for prediction by running the same command as above with "train" replaced by "evaluate".
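Because training and evaluation differ only in the --mode flag, one way to keep the two invocations consistent is to build both from the same flag set. The sketch below is a hypothetical wrapper (emulator_command, run_emulator and BEAM_FLAGS are illustrative names) around the command-line interface shown above; it uses only flags that appear in this README and assumes it is run from the repository root. Values are converted with str(), so for example lr=1e-5 becomes --lr=1e-05, which should parse to the same number.

```python
import shlex
import subprocess
import sys


def emulator_command(mode, **flags):
    """Build the argument list for `python -m main` with the given mode and flags.

    Each keyword is passed through as a flag, e.g. n_epochs=300 -> --n_epochs=300.
    """
    if mode not in ("train", "evaluate"):
        raise ValueError(f"mode must be 'train' or 'evaluate', got {mode!r}")
    # sys.executable keeps the command inside the currently active environment.
    args = [sys.executable, "-m", "main", f"--mode={mode}"]
    args += [f"--{name}={value}" for name, value in flags.items()]
    return args


def run_emulator(mode, **flags):
    """Run main.py from the repository root; raises CalledProcessError on failure."""
    subprocess.run(emulator_command(mode, **flags), check=True)


# Beam-data settings from this README: K=2 message passing steps, 300 epochs.
BEAM_FLAGS = dict(n_epochs=300, K=2, data_path="beamData", n_shape_coeff=2)

if __name__ == "__main__":
    for mode in ("train", "evaluate"):
        # Print the equivalent shell command; call run_emulator(mode, **BEAM_FLAGS) to execute it.
        print(shlex.join(emulator_command(mode, **BEAM_FLAGS)))
```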
The varying-geometry LV emulation dataset is too large to be included in the repository; an external download link is available here. Assuming the data has been downloaded to data/lvData, training can be performed as follows:
python -m main --mode="train" --n_epochs=3000 --K=5 --data_path="lvData" --n_shape_coeff=32
Again, the trained emulator can then be used for prediction by running the same command as above with "train" replaced by "evaluate".
The emulation dataset for the fixed LV geometry is included in the repository, inside data/lvDataFixedGeom. Training can be performed as follows:
python -m main --mode="train" --n_epochs=1000 --K=5 --lr=1e-5 --data_path="lvDataFixedGeom" --fixed_geom=True --n_shape_coeff=32 --trained_params_dir="emulationResults/trainedParameters/lvData/"
Note that training is initialised from the parameters pre-trained on the varying LV geometry data (lvData). Once training is complete, the emulator can be used to predict on the test data as follows:
python -m main --mode="evaluate" --n_epochs=600 --K=5 --lr=1e-5 --data_path="lvDataFixedGeom" --fixed_geom=True --n_shape_coeff=32
Trained GNN parameters for each of the three datasets are stored in emulationResults/trainedParameters. Detailed instructions on how to use these parameters to replicate the paper results are given in PAPER_REPLICATION.md.
Tensorboard can be used to monitor training, for example for the beam data by running:
tensorboard --logdir=emulationResults/beamData
and then following the instructions printed to the console.
Tests for the GNN implementation in models.py can be run as follows:
pytest models_test.py -v
- Optimisation: performed using flax.optim, which is now deprecated in favour of Optax.
- Batching: a batch size of one is used in all examples. The code can easily be extended to larger batches by "stacking" the graphs of multiple data points into one large graph (and shifting the sender/receiver indices accordingly), on which the existing emulators defined in models.py can then be applied.
- Applying to other datasets: the emulation framework can be applied to datasets beyond those provided in this repository. See the file DATA_FORMAT_REQUIREMENTS.md inside the /data subdirectory for details of the required data format.
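The graph-stacking approach to batching described above can be sketched with NumPy as follows. The dictionary keys (nodes, senders, receivers, n_node), the function names and the toy sum-aggregation step are assumptions for illustration only; they are not the data format used by the repository's data loader, nor the DeepGraphEmulator architecture in models.py. The sketch shows the index shifting, and that a message-passing step on the stacked graph gives the same per-graph results as processing each graph separately, because no edges connect different data points.

```python
import numpy as np


def stack_graphs(graphs):
    """Merge several graphs into one disconnected graph.

    Each graph is a dict with hypothetical keys:
      "nodes":     (n_nodes, n_features) array of node features
      "senders":   (n_edges,) int array of edge source indices
      "receivers": (n_edges,) int array of edge target indices
    Indices of each graph are shifted by the number of nodes in the graphs before it.
    """
    offsets = np.cumsum([0] + [g["nodes"].shape[0] for g in graphs[:-1]])
    return {
        "nodes": np.concatenate([g["nodes"] for g in graphs], axis=0),
        "senders": np.concatenate([g["senders"] + off for g, off in zip(graphs, offsets)]),
        "receivers": np.concatenate([g["receivers"] + off for g, off in zip(graphs, offsets)]),
        # Node counts per graph, so stacked outputs can be split apart again.
        "n_node": np.array([g["nodes"].shape[0] for g in graphs]),
    }


def unstack_nodes(node_values, n_node):
    """Split stacked per-node outputs back into one array per original graph."""
    return np.split(node_values, np.cumsum(n_node)[:-1])


def sum_messages(graph):
    """Toy message-passing step: each node sums the features of the nodes sending to it."""
    out = np.zeros_like(graph["nodes"])
    np.add.at(out, graph["receivers"], graph["nodes"][graph["senders"]])
    return out


if __name__ == "__main__":
    g1 = {"nodes": np.array([[1.0], [2.0]]), "senders": np.array([0]), "receivers": np.array([1])}
    g2 = {"nodes": np.array([[10.0], [20.0], [30.0]]),
          "senders": np.array([0, 1]), "receivers": np.array([1, 2])}
    batch = stack_graphs([g1, g2])
    print(batch["senders"], batch["receivers"])  # [0 2 3] [1 3 4]
    # Message passing on the stacked graph matches running each graph on its own.
    parts = unstack_nodes(sum_messages(batch), batch["n_node"])
    for g, part in zip([g1, g2], parts):
        assert np.allclose(part, sum_messages(g))
```

Any quantities defined per data point rather than per node (for example, simulation-level parameters) would also need to be associated with their own graph, which the n_node counts make possible; this sketch does not handle them.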
main.py: Main script for training and evaluating emulators
models.py: Implements the DeepGraphEmulator GNN emulation architecture
Contains a data loader utility class
Contains utility functions for emulator training and evaluation
requirements.txt: Lists the packages, in addition to JAX, that are required for the experiments to be run
PAPER_REPLICATION.md: Details how the results from the paper can be reproduced
data/: Stores simulation data for training and testing of the GNN emulator. Also contains scripts to process raw simulation data into the augmented graph format described in the manuscript: see the README.md file inside /data for more details.
emulationResults/: Stores the trained emulator parameters and predictions.