Skip to content

Commit

Permalink
[Model] Improve GAT models (#348)
Browse files Browse the repository at this point in the history
* two better GAT implementations

* update numbers

* use version switch for spmm

* add missing dropout and output heads
  • Loading branch information
jermainewang committed Jan 11, 2019
1 parent 3a868eb commit efae0f9
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 270 deletions.
1 change: 0 additions & 1 deletion examples/mxnet/gat/gat_batch.py
@@ -1,4 +1,3 @@

"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Expand Down
45 changes: 39 additions & 6 deletions examples/pytorch/gat/README.md
Expand Up @@ -2,15 +2,48 @@ Graph Attention Networks (GAT)
============

- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo:
- Author's code repo (in Tensorflow):
[https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular pytorch implementation:
[https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).

Note that the original code is implemented with Tensorflow for the paper.
Requirements
------------
- torch v1.0: the autograd support for sparse mm is only available in v1.0.
- requests

Results
-------
```bash
pip install torch==1.0.0 requests
```

How to run
----------

Run with following:

```bash
python train.py --dataset=cora --gpu=0
```

Run with following (available dataset: "cora", "citeseer", "pubmed")
```bash
python gat.py --dataset cora --gpu 0 --num-heads 8
python train.py --dataset=citeseer --gpu=0
```

```bash
python train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001
```

Results
-------

| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
| ------- | ------------- | ------- | ------------------- | ------------------- |
| Cora | 84.0% | 0.0127 | 0.0982 (**7.7x**) | 0.0424 (**3.3x**) |
| Citeseer | 70.7% | 0.0123 | n/a | n/a |
| Pubmed | 78.1% | 0.0302 | n/a | n/a |

* All the accuracy numbers are obtained after 300 epochs.
* The time measures how long it takes to train one epoch.
* All time is measured on EC2 p3.2xlarge instance w/ V100 GPU.
* Baseline#1: [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
* Baseline#2: [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
261 changes: 0 additions & 261 deletions examples/pytorch/gat/gat.py

This file was deleted.

0 comments on commit efae0f9

Please sign in to comment.