
Exploring and visualizing limitations of the message-passing paradigm for GNNs. 📉


camille-004/gnn-long-range


gnn-long-range

This project contains customizable baselines of state-of-the-art message-passing graph neural network (GNN) architectures. Training is handled by PyTorch Lightning (pytorch_lightning) and logging by Weights & Biases. The project also includes experiments that visualize measures of oversmoothing (Rayleigh quotient, Dirichlet energy) and oversquashing (embedding Jacobian) for node classification. The main purpose of this repository is to explore the limitations of message-passing neural networks (MPNNs) for long-range interactions.
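The two oversmoothing measures can be defined in several ways, and the exact formulas used in src/models/utils.py are not reproduced here. As a hedged illustration, the NumPy sketch below uses one common choice: the Dirichlet energy as trace(Xᵀ L X) with the symmetric-normalized Laplacian of the self-loop-augmented graph, and the Rayleigh quotient as that energy divided by ‖X‖²_F. It assumes a PyG-style `edge_index` of shape (2, E) that lists every undirected edge in both directions, and node embeddings `x` of shape (N, d). The function names are illustrative, not the repository's API.

```python
import numpy as np


def dirichlet_energy(x: np.ndarray, edge_index: np.ndarray) -> float:
    """trace(X^T L X) with L = I - D^{-1/2} (A + I) D^{-1/2}.

    Assumption: edge_index has shape (2, E) and lists each undirected
    edge in both directions (PyG convention), without self-loops.
    """
    n = x.shape[0]
    src, dst = edge_index
    # Degrees of the self-loop-augmented adjacency A + I.
    deg = np.bincount(src, minlength=n) + 1.0
    x_norm = x / np.sqrt(deg)[:, None]
    diff = x_norm[src] - x_norm[dst]
    # Each undirected edge is listed twice, hence the factor 0.5.
    # Self-loops would contribute a zero difference, so they are omitted.
    return 0.5 * float(np.sum(diff ** 2))


def rayleigh_quotient(x: np.ndarray, edge_index: np.ndarray) -> float:
    """Dirichlet energy normalized by the squared Frobenius norm of X."""
    return dirichlet_energy(x, edge_index) / float(np.sum(x ** 2))
```

Tracking either quantity layer by layer gives the per-layer curves described under Usage: when embeddings oversmooth, neighboring nodes become nearly indistinguishable and the energy of deeper layers drops toward zero.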

Project Structure

├── config                 # --> Contains all config files.
├── logs                   # --> Contains results.csv from experiments.
├── notebooks              # --> Notebooks, to be used for presentation.
├── references             # --> List of papers and codebases referenced.
├── reports                # --> Reports and figures generated by run script.
├── scripts                # --> Run shell scripts for experiments.
└── src
    ├── data
    │   ├── add_edges.py   # --> Custom PyG transform for data augmentation.
    │   └── data_module.py # --> NodeDataModule definition.
    ├── models
    │   ├── train.py       # --> Training script, based on PyTorch Lightning's `Trainer`.
    │   └── utils.py       # --> Utility functions for computing oversmoothing and oversquashing metrics.
    ├── data_module.py     # --> Definitions of graph and node `LightningDataModule`s.
    ├── utils.py           # --> Utility function for loading a configuration.
    └── visualize.py       # --> Model graphs to save to reports/figures.

Prerequisites

This project is built on conda and Python 3.10.

GPU tools: The models were built with torch v1.13.0 and CUDA v11.7, the versions pinned in environment_gpu.yaml. To use different CUDA or torch versions, edit environment_gpu.yaml.

To install all necessary dependencies, create a new conda environment:

conda env create -f environment_cpu.yaml  # CPU environment
conda env create -f environment_gpu.yaml  # GPU environment

If the pip installations hang during the step above, run the following once all conda dependencies are installed:

pip install -r requirements_cpu.txt  # For CPU environment
pip install -r requirements_gpu.txt  # For GPU environment

Usage

To reproduce the experiments from the report, run scripts/run.sh. The script first clears logs/results.csv and reports/figures (the directory where the oversmoothing and oversquashing plots are saved). The results will differ slightly from those presented in the report, because the experiment has since been updated to address the report's footnotes. Alternatively, to run your own experiment, execute run.py with the following parameters:

  • model - Name of the chosen model. gin_jk is only supported for the graph classification task.
  • -e, --max_epochs - optional, Maximum number of training epochs if early stopping does not trigger first.
  • -d, --dataset - optional, Name of the dataset on which to train the model.
  • -a, --activation - optional, Activation function used by the neural network.
  • -nh, --n_hidden_layers - optional, Number of hidden layers in the neural network.
  • -t, --add_edges_thres - optional, Threshold, as a percentage of the original number of edges, for how many new random edges to add.
  • --n_heads - optional, Number of heads for multi-head attention (GAT only).
  • --jk_mode - optional, Jumping knowledge mode for the graph classification model gin_jk.
  • --plot_energy - optional, Plot the Dirichlet energy of each layer.
  • --plot_rayleigh - optional, Plot the Rayleigh quotient of each layer.
  • --plot_influence - optional, Plot the influence of up to r-th-order neighborhoods on a randomly chosen node.
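The --plot_influence option measures how strongly a node's final embedding depends on each r-th-order neighborhood, based on the embedding Jacobian. How the repository computes this for a trained model is not shown here. As a hedged illustration, the sketch below swaps autograd for the closed-form Jacobian of a linear GCN-style model h = Â^L X W (no nonlinearities): there ∂h_v/∂x_u = (Â^L)_{vu} Wᵀ, so the normalized influence of node u on node v reduces to row v of Â^L. The helpers `hop_distances` and `influence_by_hop` are hypothetical, and `adj` is assumed to be a dense, symmetric 0/1 adjacency matrix.

```python
from collections import deque

import numpy as np


def hop_distances(adj: np.ndarray, source: int) -> np.ndarray:
    """Breadth-first search hop distances from `source` (-1 if unreachable)."""
    dist = np.full(adj.shape[0], -1)
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in np.nonzero(adj[u])[0]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def influence_by_hop(adj: np.ndarray, node: int, n_layers: int, r: int) -> np.ndarray:
    """Share of `node`'s influence coming from hops 0..r (linear-model assumption)."""
    n = adj.shape[0]
    # GCN normalization with self-loops: A_hat = D^{-1/2} (A + I) D^{-1/2}.
    a_tilde = adj + np.eye(n)
    d_inv_sqrt = 1.0 / np.sqrt(a_tilde.sum(axis=1))
    a_hat = d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]
    # For a linear model the Jacobian w.r.t. x_u scales with A_hat^L[node, u].
    row = np.linalg.matrix_power(a_hat, n_layers)[node]
    influence = row / row.sum()
    dist = hop_distances(adj, node)
    # Aggregate per-node influence by hop distance from `node`.
    return np.array([influence[dist == k].sum() for k in range(r + 1)])
```

With L layers, nodes more than L hops away receive exactly zero influence, and on sparse graphs the share from distant hops shrinks as the neighborhood grows, which is the effect the influence plots are meant to expose.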

Model hyperparameters, as well as data and training parameters, can be edited in the files in the config directory. Note: if the Jacobian used for the influence scores comes back empty (resulting in an empty plot), the randomly chosen node is most likely isolated. For now, a simple fix is to change the seed in global_config.
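Instead of changing the seed by hand, one option is to draw the node only from nodes that have at least one neighbor. A minimal NumPy sketch, assuming a PyG-style `edge_index` that lists each undirected edge in both directions; `sample_non_isolated_node` is a hypothetical helper, not a function in this repository.

```python
import numpy as np


def sample_non_isolated_node(edge_index: np.ndarray, num_nodes: int, seed: int) -> int:
    """Pick a random node with at least one neighbor other than itself."""
    src, dst = edge_index
    mask = src != dst  # self-loops do not connect a node to its neighborhood
    deg = np.bincount(src[mask], minlength=num_nodes)
    candidates = np.flatnonzero(deg > 0)
    if candidates.size == 0:
        raise ValueError("graph has no edges between distinct nodes")
    rng = np.random.default_rng(seed)
    return int(rng.choice(candidates))
```

Keeping the seed fixed still makes the choice reproducible; the filter only guarantees that the sampled node has a non-empty neighborhood to measure influence over.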

Example

python run.py gin -d pubmed -nh 2 --plot_energy --plot_rayleigh

will log performance to logs/results.csv and save the Dirichlet energy and Rayleigh quotient plots to reports/figures.
