This project aims to provide re-implementations of neural networks in jax that stay as close to the original authors' implementations as practical. Apart from different default initializations, known deviations from the original implementation logic are documented.
Each implementation lives in its own project subdirectory. Current implementations include:
- APPNP: Approximate Personalized Propagation of Neural Predictions
- DAGNN: Deep Adaptive Graph Neural Networks
- DEQ_GCN: (Stalled WIP) Deep Equilibrium Graph Convolution Networks
- GAT: Graph Attention Networks
- GCN: Graph Convolution Networks
- GCN2: Graph Convolution Networks 2
- IGAT: (Stalled WIP) Inverse Graph Attention Networks
- igcn: Inverse Graph Convolution Networks
- pigcn: Pseudo-inverse Graph Convolution Networks
- sgc: Simple Graph Convolution Networks
See the relevant subdirectory's README.md for more details and example usage.
This project uses the following large open-source projects:
- jax for performant low-level operations;
- haiku for parameter / state management;
- optax for optimizer implementations; and
- gin for configuration.
Additional functionality is provided in smaller repositories:
- spax for sparse jax classes and operations; and
- huf, a minimal framework built on top of haiku and optax.
This library is in early, rapid development; things will break frequently.
After installing jax, install grax from source:
git clone https://github.com/jackd/grax
cd grax
pip install -r requirements.txt
pip install -e . # local install
DGL: Citations, Amazon and Coauthor
These datasets are loaded with dgl, which is installed by the requirements above. You can customize where the relevant files are downloaded and extracted with:
export DGL_DATA=/path/to/dgl_data_dir # otherwise uses ~/.dgl
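Open Graph Benchmark (OGB) datasets additionally require the ogb package:
pip install ogb
As with DGL, you can customize the data directory: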
export OGB_DATA=/path/to/ogb_data_dir # otherwise uses ~/ogb
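The fallback behavior described in the comments above can be checked with a small shell function. This is a hypothetical helper, not part of grax: it only mirrors the documented defaults (`~/.dgl` and `~/ogb`) using standard `${VAR:-default}` expansion, and the actual lookup inside grax may differ.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of grax): print the data directories implied
# by the documented defaults. An unset or empty variable falls back to the
# default path; grax's internal resolution may differ from this sketch.
show_data_dirs() {
  echo "DGL_DATA=${DGL_DATA:-$HOME/.dgl}"
  echo "OGB_DATA=${OGB_DATA:-$HOME/ogb}"
}
```

For example, `DGL_DATA=/tmp/dgl show_data_dirs` prints `DGL_DATA=/tmp/dgl` on its first line. The same prefix form, e.g. `DGL_DATA=/tmp/dgl python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin`, sets the variable for a single command instead of exporting it for the whole session.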
After installing, run experiments from the command line:
# run a single GCN model on pubmed dataset
python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin
# customize configuration
python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin --bindings='
dropout_rate=0.6
seed=1
'
# perform multiple runs
python -m grax grax_config/single/fit_many.gin gcn/config/pubmed.gin
# perform multiple runs with ray
python -m grax grax_config/single/ray/fit_many.gin gcn/config/pubmed.gin
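`fit_many.gin` is the provided way to perform multiple runs. If you instead want to sweep a few hyperparameters by hand, a shell loop over the `--bindings` flag shown above is one option. This is a sketch under stated assumptions rather than part of grax: `run_grid` and the `DRY_RUN` switch are hypothetical names introduced here, and the grid values are arbitrary.

```bash
#!/usr/bin/env bash
# Hypothetical sweep helper (not part of grax): runs the documented
# GCN/pubmed config once per (dropout_rate, seed) pair, passing both
# values through --bindings as newline-separated gin bindings.
# DRY_RUN is a hypothetical switch: when set to 1, commands are printed
# instead of executed.
run_grid() {
  local rate seed bindings
  for rate in 0.5 0.6; do
    for seed in 0 1; do
      # One binding per line, as in the multi-line --bindings example.
      bindings="dropout_rate=${rate}
seed=${seed}"
      if [ "${DRY_RUN:-0}" = "1" ]; then
        # Print a shell-quoted command instead of running it (%q is a bash extension).
        printf 'python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin --bindings=%q\n' "$bindings"
      else
        python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin --bindings="$bindings"
      fi
    done
  done
}
```

Run `DRY_RUN=1 run_grid` to inspect the four generated commands, then `run_grid` to launch them. Each combination runs as a separate `python -m grax` process.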
This package uses pre-commit to ensure commits meet minimum criteria. To install the hooks, run:
pip install pre-commit
pre-commit install
This ensures the git hooks run before each commit. Although it is not advised, you can skip these hooks with:
git commit --no-verify -m "commit message"