-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NN] Add TAGCN nn.module and example #788
Changes from 29 commits
6219fe6
66541cc
8258153
21404bb
fb62f10
ba138a3
d725786
e291d6d
eab01bb
5fdc289
2e89c6f
b1af382
85630e3
874352f
b918c9b
47e468a
0769269
10e1d27
e1b4864
4bfe71c
a91b1bb
475c0c3
edf6a0e
ac8f6e4
e3b9dc6
2b90768
5d1c200
0c3fcb1
603e1c7
d263f98
4cf4ec1
c22e913
18a16aa
3883e2f
5dce579
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
Topology Adaptive Graph Convolutional networks (TAGCN) | ||
============ | ||
|
||
- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370) | ||
|
||
Dependencies | ||
------------ | ||
- PyTorch 0.4.1+ | ||
- requests | ||
|
||
``bash | ||
pip install torch requests | ||
`` | ||
|
||
Results | ||
------- | ||
Run with following (available dataset: "cora", "citeseer", "pubmed") | ||
```bash | ||
python3 train.py --dataset cora --gpu 0 --self-loop | ||
``` | ||
|
||
* cora: ~0.812 (0.804-0.823) (paper: 0.833) | ||
* citeseer: ~0.715 (paper: 0.714) | ||
* pubmed: ~0.790 (paper: 0.811) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""GCN using DGL nn package | ||
|
||
References: | ||
- Semi-Supervised Classification with Graph Convolutional Networks | ||
- Paper: https://arxiv.org/abs/1609.02907 | ||
- Code: https://github.com/tkipf/gcn | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from dgl.nn.pytorch.conv import TGConv | ||
|
||
class TAGCN(nn.Module): | ||
def __init__(self, | ||
g, | ||
in_feats, | ||
n_hidden, | ||
n_classes, | ||
n_layers, | ||
activation, | ||
dropout): | ||
super(TAGCN, self).__init__() | ||
self.g = g | ||
self.layers = nn.ModuleList() | ||
# input layer | ||
self.layers.append(TGConv(in_feats, n_hidden, activation=activation)) | ||
# hidden layers | ||
for i in range(n_layers - 1): | ||
self.layers.append(TGConv(n_hidden, n_hidden, activation=activation)) | ||
# output layer | ||
self.layers.append(TGConv(n_hidden, n_classes)) #activation=None | ||
self.dropout = nn.Dropout(p=dropout) | ||
|
||
def forward(self, features): | ||
h = features | ||
for i, layer in enumerate(self.layers): | ||
if i != 0: | ||
h = self.dropout(h) | ||
h = layer(h, self.g) | ||
return h |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import argparse, time | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from dgl import DGLGraph | ||
from dgl.data import register_data_args, load_data | ||
|
||
from tagcn import TAGCN | ||
|
||
#from gcn import GCN | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove these comments |
||
#from gcn_mp import GCN | ||
#from gcn_spmv import GCN | ||
|
||
def evaluate(model, features, labels, mask): | ||
model.eval() | ||
with torch.no_grad(): | ||
logits = model(features) | ||
logits = logits[mask] | ||
labels = labels[mask] | ||
_, indices = torch.max(logits, dim=1) | ||
correct = torch.sum(indices == labels) | ||
return correct.item() * 1.0 / len(labels) | ||
|
||
def main(args): | ||
# load and preprocess dataset | ||
data = load_data(args) | ||
features = torch.FloatTensor(data.features) | ||
labels = torch.LongTensor(data.labels) | ||
train_mask = torch.ByteTensor(data.train_mask) | ||
val_mask = torch.ByteTensor(data.val_mask) | ||
test_mask = torch.ByteTensor(data.test_mask) | ||
in_feats = features.shape[1] | ||
n_classes = data.num_labels | ||
n_edges = data.graph.number_of_edges() | ||
print("""----Data statistics------' | ||
#Edges %d | ||
#Classes %d | ||
#Train samples %d | ||
#Val samples %d | ||
#Test samples %d""" % | ||
(n_edges, n_classes, | ||
train_mask.sum().item(), | ||
val_mask.sum().item(), | ||
test_mask.sum().item())) | ||
|
||
if args.gpu < 0: | ||
cuda = False | ||
else: | ||
cuda = True | ||
torch.cuda.set_device(args.gpu) | ||
features = features.cuda() | ||
labels = labels.cuda() | ||
train_mask = train_mask.cuda() | ||
val_mask = val_mask.cuda() | ||
test_mask = test_mask.cuda() | ||
|
||
# graph preprocess and calculate normalization factor | ||
g = data.graph | ||
# add self loop | ||
if args.self_loop: | ||
g.remove_edges_from(g.selfloop_edges()) | ||
g.add_edges_from(zip(g.nodes(), g.nodes())) | ||
g = DGLGraph(g) | ||
n_edges = g.number_of_edges() | ||
|
||
# create GCN model | ||
model = TAGCN(g, | ||
in_feats, | ||
args.n_hidden, | ||
n_classes, | ||
args.n_layers, | ||
F.relu, | ||
args.dropout) | ||
|
||
if cuda: | ||
model.cuda() | ||
loss_fcn = torch.nn.CrossEntropyLoss() | ||
|
||
# use optimizer | ||
optimizer = torch.optim.Adam(model.parameters(), | ||
lr=args.lr, | ||
weight_decay=args.weight_decay) | ||
|
||
# initialize graph | ||
dur = [] | ||
for epoch in range(args.n_epochs): | ||
model.train() | ||
if epoch >= 3: | ||
t0 = time.time() | ||
# forward | ||
logits = model(features) | ||
loss = loss_fcn(logits[train_mask], labels[train_mask]) | ||
|
||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
if epoch >= 3: | ||
dur.append(time.time() - t0) | ||
|
||
acc = evaluate(model, features, labels, val_mask) | ||
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " | ||
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(), | ||
acc, n_edges / np.mean(dur) / 1000)) | ||
|
||
print() | ||
acc = evaluate(model, features, labels, test_mask) | ||
print("Test Accuracy {:.4f}".format(acc)) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='GCN') | ||
register_data_args(parser) | ||
parser.add_argument("--dropout", type=float, default=0.5, | ||
help="dropout probability") | ||
parser.add_argument("--gpu", type=int, default=-1, | ||
help="gpu") | ||
parser.add_argument("--lr", type=float, default=1e-2, | ||
help="learning rate") | ||
parser.add_argument("--n-epochs", type=int, default=200, | ||
help="number of training epochs") | ||
parser.add_argument("--n-hidden", type=int, default=16, | ||
help="number of hidden tagcn units") | ||
parser.add_argument("--n-layers", type=int, default=1, | ||
help="number of hidden tagcn layers") | ||
parser.add_argument("--weight-decay", type=float, default=5e-4, | ||
help="Weight for L2 loss") | ||
parser.add_argument("--self-loop", action='store_true', | ||
help="graph self-loop (default=False)") | ||
parser.set_defaults(self_loop=False) | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
main(args) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
from ... import function as fn | ||
|
||
__all__ = ['GraphConv'] | ||
__all__ = ['GraphConv', 'TGConv'] | ||
|
||
class GraphConv(nn.Module): | ||
r"""Apply graph convolution over an input signal. | ||
|
@@ -148,3 +148,92 @@ def extra_repr(self): | |
if '_activation' in self.__dict__: | ||
summary += ', activation={_activation}' | ||
return summary.format(**self.__dict__) | ||
|
||
class TGConv(nn.Module): | ||
r"""Apply Topology Adaptive Graph Convolutional Network | ||
|
||
.. math:: | ||
\mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A} | ||
\mathbf{D}^{-1/2}\mathbf{X} \mathbf{\Theta}_{k}, | ||
|
||
where :math:`\mathbf{A}` denotes the adjacency matrix and | ||
:math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix. | ||
|
||
Parameters | ||
---------- | ||
in_feats : int | ||
Number of input features. | ||
out_feats : int | ||
Number of output features. | ||
K: int, optional | ||
yzh119 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Number of hops :math: `K`. (default: 3) | ||
bias: bool, optional | ||
If True, adds a learnable bias to the output. Default: ``True``. | ||
activation: callable activation function/layer or None, optional | ||
If not None, applies an activation function to the updated node features. | ||
Default: ``None``. | ||
|
||
Attributes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Describe attributes here. |
||
---------- | ||
""" | ||
def __init__(self, | ||
in_feats, | ||
out_feats, | ||
k=2, | ||
bias=True, | ||
activation=None): | ||
super(TGConv, self).__init__() | ||
self._in_feats = in_feats | ||
self._out_feats = out_feats | ||
self._k = k | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove unnecessary empty lines. |
||
self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias) | ||
|
||
self.reset_parameters() | ||
|
||
self._activation = activation | ||
|
||
def reset_parameters(self): | ||
"""Reinitialize learnable parameters.""" | ||
self.lin.reset_parameters() | ||
|
||
def forward(self, feat, graph): | ||
r"""Compute graph convolution | ||
|
||
Parameters | ||
---------- | ||
feat : torch.Tensor | ||
The input feature | ||
graph : DGLGraph | ||
The graph. | ||
|
||
Returns | ||
------- | ||
torch.Tensor | ||
The output feature | ||
""" | ||
graph = graph.local_var() | ||
|
||
norm = th.pow(graph.in_degrees().float(), -0.5) | ||
shp = norm.shape + (1,) * (feat.dim() - 1) | ||
norm = th.reshape(norm, shp).to(feat.device) | ||
|
||
#D-1/2 A D -1/2 X | ||
fstack = [feat] | ||
for _ in range(self._k): | ||
|
||
rst = fstack[-1] * norm | ||
graph.ndata['h'] = rst | ||
|
||
graph.update_all(fn.copy_src(src='h', out='m'), | ||
fn.sum(msg='m', out='h')) | ||
rst = graph.ndata['h'] | ||
rst = rst * norm | ||
fstack.append(rst) | ||
|
||
rst = self.lin(th.cat(fstack, dim=-1)) | ||
|
||
if self._activation is not None: | ||
rst = self._activation(rst) | ||
|
||
return rst |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try tuning some hyper-parameters like weight_decay and see if the result is still lower then the number in TAGCN paper.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried weight_decay and lr for pubmed dataset, there is minor performance improvement.
The new test result was updated.