Deep Graph Infomax
Deep Graph Infomax (DGI) is an unsupervised algorithm for learning representations of the nodes in a graph, which can then be used in downstream tasks like node classification.
This is a TensorFlow implementation of DGI, based on the Graph Convolutional Network implementation by Thomas Kipf.
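At a glance, DGI trains an encoder so that node ("patch") representations score highly against a summary vector of the real graph, and poorly against representations computed from a corrupted (feature-shuffled) copy of it. A minimal NumPy sketch of that contrastive objective (all shapes, weights, and the one-layer toy encoder below are illustrative stand-ins, not this repo's code):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy shapes (illustrative only, not the repo's defaults):
N, D, F = 6, 8, 4                  # nodes, input features, embedding size
X = rng.normal(size=(N, D))        # node feature matrix
A = np.ones((N, N))                # dense toy adjacency (self-loops included)
W_enc = rng.normal(size=(D, F))    # toy encoder weights
W_disc = rng.normal(size=(F, F))   # bilinear discriminator weights

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def encode(X):
    """One GCN-like propagation step standing in for the real encoder."""
    return np.tanh((A / A.sum(axis=1, keepdims=True)) @ X @ W_enc)

H = encode(X)          # patch (node) representations
s = H.mean(axis=0)     # graph summary vector

# DGI's corruption: shuffle node features over the same graph structure,
# producing "negative" patches that should not match the true summary.
H_neg = encode(X[rng.permutation(N)])

pos = sigmoid(H @ W_disc @ s)      # scores for real patches (pushed toward 1)
neg = sigmoid(H_neg @ W_disc @ s)  # scores for corrupted patches (pushed toward 0)
loss = -np.mean(np.log(pos + 1e-9)) - np.mean(np.log(1.0 - neg + 1e-9))
```

Training minimizes this binary cross-entropy, which maximizes a lower bound on the mutual information between patch representations and the graph summary.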
Requirements:

- tensorflow (>0.12)

Install with:

```
python setup.py install
```
First, train a DGI model:

```
python train.py --model dgi
```
Once the model is trained, the embeddings are saved as a pickle file in the `runs` folder. Take note of its path (e.g. `runs/2018-11-04-164053/embeddings.p`) and use it to train a logistic regression model on the node classification task:

```
python train.py --model logreg --embeddings_path runs/2018-11-04-164053/embeddings.p
```
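The second command fits a linear probe on the frozen embeddings. The sketch below shows what that stage amounts to; the pickle layout and all data here are made up, so inspect the real embeddings file before relying on this exact format:

```python
import os
import pickle
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the saved embeddings file; the real pickle's contents are
# an assumption here, not taken from this repo.
Z = rng.normal(size=(120, 16))    # N nodes x F embedding dimensions
y = rng.integers(0, 3, size=120)  # toy integer class labels

path = os.path.join(tempfile.mkdtemp(), "embeddings.p")
with open(path, "wb") as f:
    pickle.dump(Z, f)
with open(path, "rb") as f:
    Z = pickle.load(f)            # frozen embeddings, loaded back from disk

# Plain softmax (multinomial logistic) regression via gradient descent.
n_classes = int(y.max()) + 1
W = np.zeros((Z.shape[1], n_classes))
Y = np.eye(n_classes)[y]          # one-hot targets
for _ in range(300):
    logits = Z @ W
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)             # softmax probabilities
    W -= 0.1 * Z.T @ (P - Y) / len(y)             # cross-entropy gradient step

acc = (np.argmax(Z @ W, axis=1) == y).mean()      # training accuracy of the probe
```

Because the embeddings stay fixed, the probe's accuracy is a direct measure of how linearly separable the classes are in the learned representation space.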
In order to use your own data, you have to provide
- an N by N adjacency matrix (N is the number of nodes),
- an N by D feature matrix (D is the number of features per node), and
- an N by E binary label matrix (E is the number of classes).
Have a look at the `load_data()` function in `utils.py` for an example.
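Concretely, the three inputs listed above could look like this toy example (the values are arbitrary; only the shapes and the one-hot label encoding matter):

```python
import numpy as np

# A toy graph with N=4 nodes, D=3 features, E=2 classes.
N, D, E = 4, 3, 2

adj = np.array([[0, 1, 0, 0],   # N x N adjacency (here: a simple path graph)
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=np.float32)

features = np.random.default_rng(0).random((N, D)).astype(np.float32)  # N x D

labels = np.zeros((N, E), dtype=np.float32)  # N x E binary (one-hot) labels
labels[np.arange(N), [0, 1, 1, 0]] = 1.0     # node i belongs to class 0,1,1,0
```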
In this example, we load citation network data (Cora, Citeseer, or Pubmed). The original datasets can be found at http://linqs.cs.umd.edu/projects/projects/lbc/. In our version (see the `data` folder) we use the dataset splits provided by https://github.com/kimiyoung/planetoid (Zhilin Yang, William W. Cohen, Ruslan Salakhutdinov, Revisiting Semi-Supervised Learning with Graph Embeddings, ICML 2016).
You can specify a dataset as follows:
```
python train.py --dataset citeseer
```

(or by editing the default dataset in `train.py`).
You can choose between the following models:
- `dgi`: Deep Graph Infomax (Petar Veličković et al., Deep Graph Infomax, 2018)
- `gcn`: Graph convolutional network (Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016)
- `gcn_cheby`: Chebyshev polynomial version of the graph convolutional network, as described in Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst, Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering, NIPS 2016
- `dense`: Basic multi-layer perceptron that supports sparse inputs