This repository is the official implementation of Revisiting Attention Weights as Interpretations of Message-Passing Neural Networks. In the paper, we show that GAtt provides a better way to calculate edge attribution scores from attention weights in attention-based GNNs!
Just clone the repo and build the Docker image with:

```
docker build <directory_to_cloned_repo> --tag <your:tag>
```
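You can then start a container from the image. The command below is a generic sketch; the GPU flag and the mount path are assumptions to adjust for your setup:

```
docker run --gpus all -it --rm -v <directory_to_cloned_repo>:/workspace <your:tag>
```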
The code was tested with:
- Python 3.10.13
- PyTorch 2.0.1+cu117
- PyTorch Geometric 2.3.1

These versions should be enough to run the demos.
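If you prefer not to use Docker, a roughly equivalent environment can be set up manually. The commands below are a sketch assuming the CUDA 11.7 wheel index, not the repository's official install procedure:

```
pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu117
pip install torch_geometric==2.3.1
```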
- Source code for GAtt
- Demos
  - Demo on the Cora dataset showing how to use the `get_gatt` and `get_gatt_batch` functions (see the usage sketch below)
  - Demo on the BAShapes dataset (generated from `torch_geometric.datasets.ExplainerDataset`): visualizations of GAtt and comparison to AvgAtt
  - Demo on the Infection dataset (generated from the code in the original authors' repo): visualizations of GAtt and comparison to AvgAtt
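For the Cora demo, a minimal usage sketch is shown below. The GAT model definition and the argument names of `get_gatt` are assumptions for illustration; the demo notebook and the source code are the authoritative reference for the exact signature.

```python
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

# Load the Cora dataset used in the demo notebook.
dataset = Planetoid(root="data/Planetoid", name="Cora")
data = dataset[0]

# An illustrative 2-layer GAT; the demo notebook defines the actual model.
class GAT(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GATConv(in_dim, hidden_dim)
        self.conv2 = GATConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

model = GAT(dataset.num_features, 16, dataset.num_classes)
# ... train `model` on Cora as usual ...

# Hypothetical call: get_gatt computes edge attribution scores for a target
# node from the trained model's attention weights. The argument names here
# are assumptions; check the function's definition for the exact signature.
# gatt_scores = get_gatt(model, data.x, data.edge_index, target_node=0)
```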
This is one of the results in the demo notebooks:
Figure (left to right)
- Ground truth explanation (blue edges) for the target node (orange node)
- Edge attribution from GAtt
- Edge attribution from AvgAtt (averaging over the layers)
The figures show that the edge attribution scores from GAtt are more closely aligned with the ground-truth explanation edges than those obtained by simply averaging over the GAT layers (AvgAtt).