[NN] Add TAGCN nn.module and example #788

Merged Aug 25, 2019 (35 commits; changes shown from 29 commits)

Commits
6219fe6
upd
yzh119 Aug 6, 2019
66541cc
fig edgebatch edges
yzh119 Aug 7, 2019
8258153
add test
yzh119 Aug 7, 2019
21404bb
Merge remote-tracking branch 'upstream/master' into fix
yzh119 Aug 7, 2019
fb62f10
trigger
yzh119 Aug 7, 2019
ba138a3
Merge remote-tracking branch 'upstream/master'
yzh119 Aug 12, 2019
d725786
Merge remote-tracking branch 'upstream/master'
yzh119 Aug 15, 2019
e291d6d
Update README.md for pytorch PinSage example.
Aug 16, 2019
eab01bb
Merge branch 'master' of https://github.com/classicsong/dgl
Aug 19, 2019
5fdc289
Provid a frame agnostic API to test nn modules on both CPU and CUDA s…
Aug 19, 2019
2e89c6f
Fix style
classicsong Aug 19, 2019
b1af382
Delete unused code
classicsong Aug 19, 2019
85630e3
Make agnostic test only related to tests/backend
classicsong Aug 19, 2019
874352f
Fix code style
classicsong Aug 19, 2019
b918c9b
Merge remote-tracking branch 'upstream/master'
yzh119 Aug 19, 2019
47e468a
fix
yzh119 Aug 19, 2019
0769269
Merge remote-tracking branch 'zihao/fix-nn'
classicsong Aug 19, 2019
10e1d27
doc
yzh119 Aug 19, 2019
e1b4864
Merge remote-tracking branch 'zihao/fix-nn'
classicsong Aug 19, 2019
4bfe71c
Make all test code under tests.mxnet/pytorch.test_nn.py
classicsong Aug 19, 2019
a91b1bb
Fix syntex
classicsong Aug 19, 2019
475c0c3
Merge branch 'master' into master
yzh119 Aug 21, 2019
edf6a0e
Remove rand
classicsong Aug 21, 2019
ac8f6e4
Merge branch 'master' of https://github.com/classicsong/dgl
classicsong Aug 21, 2019
e3b9dc6
Add TAGCN nn.module and example
classicsong Aug 23, 2019
2b90768
Now tagcn can run on CPU.
classicsong Aug 23, 2019
5d1c200
Add unitest for TGConv
Aug 23, 2019
0c3fcb1
Merge branch 'master' into nn
classicsong Aug 23, 2019
603e1c7
Fix style
classicsong Aug 23, 2019
d263f98
For pubmed dataset, using --lr=0.005 can achieve better acc
classicsong Aug 24, 2019
4cf4ec1
Merge branch 'master' into nn
classicsong Aug 24, 2019
c22e913
Fix style
classicsong Aug 24, 2019
18a16aa
Fix some descriptions
classicsong Aug 25, 2019
3883e2f
trigger
classicsong Aug 25, 2019
5dce579
Fix doc
classicsong Aug 25, 2019
24 changes: 24 additions & 0 deletions examples/pytorch/tagcn/README.md
@@ -0,0 +1,24 @@
Topology Adaptive Graph Convolutional Networks (TAGCN)
============

- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)

Dependencies
------------
- PyTorch 0.4.1+
- requests

```bash
pip install torch requests
```

Results
-------
Run with the following command (available datasets: "cora", "citeseer", "pubmed"):
```bash
python3 train.py --dataset cora --gpu 0 --self-loop
```

* cora: ~0.812 (0.804-0.823) (paper: 0.833)
* citeseer: ~0.715 (paper: 0.714)
* pubmed: ~0.790 (paper: 0.811)
Member
Try tuning some hyper-parameters like weight_decay and see if the result is still lower than the number in the TAGCN paper.

Contributor Author
I tried tuning weight_decay and lr for the pubmed dataset; there is a minor performance improvement. The test results above have been updated.
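
For reference, the pubmed run with that tuning would be invoked as below (flags as defined in train.py; the --lr value comes from commit d263f98):

```bash
python3 train.py --dataset pubmed --gpu 0 --self-loop --lr 0.005
```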

39 changes: 39 additions & 0 deletions examples/pytorch/tagcn/tagcn.py
@@ -0,0 +1,39 @@
"""GCN using DGL nn package

References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import TGConv

class TAGCN(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(TAGCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(TGConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(TGConv(n_hidden, n_hidden, activation=activation))
        # output layer
        self.layers.append(TGConv(n_hidden, n_classes))  # activation=None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(h, self.g)
        return h
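
As a quick smoke test, here is a minimal sketch of driving this module on a toy graph (assuming a DGL version contemporary with this PR; the graph and layer sizes below are illustrative, not part of the example):

```python
import networkx as nx
import torch
from dgl import DGLGraph

from tagcn import TAGCN

# Toy 4-node path graph with self-loops, mirroring train.py's preprocessing.
nxg = nx.path_graph(4)
nxg.add_edges_from(zip(nxg.nodes(), nxg.nodes()))
g = DGLGraph(nxg)

# 5 input features, 8 hidden units, 3 classes, 1 hidden layer.
model = TAGCN(g, 5, 8, 3, 1, torch.relu, dropout=0.5)
logits = model(torch.randn(4, 5))
print(logits.shape)  # torch.Size([4, 3])
```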
135 changes: 135 additions & 0 deletions examples/pytorch/tagcn/train.py
@@ -0,0 +1,135 @@
import argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

from tagcn import TAGCN

#from gcn import GCN
Member
Please remove these comments

#from gcn_mp import GCN
#from gcn_spmv import GCN

def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.sum().item(),
           val_mask.sum().item(),
           test_mask.sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # graph preprocess and calculate normalization factor
    g = data.graph
    # add self loop
    if args.self_loop:
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
    g = DGLGraph(g)
    n_edges = g.number_of_edges()

    # create TAGCN model
    model = TAGCN(g,
                  in_feats,
                  args.n_hidden,
                  n_classes,
                  args.n_layers,
                  F.relu,
                  args.dropout)

    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TAGCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden tagcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden tagcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
91 changes: 90 additions & 1 deletion python/dgl/nn/pytorch/conv.py
@@ -6,7 +6,7 @@

from ... import function as fn

__all__ = ['GraphConv']
__all__ = ['GraphConv', 'TGConv']

class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal.
Expand Down Expand Up @@ -148,3 +148,92 @@ def extra_repr(self):
if '_activation' in self.__dict__:
summary += ', activation={_activation}'
return summary.format(**self.__dict__)

class TGConv(nn.Module):
    r"""Apply Topology Adaptive Graph Convolutional Network

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^{K} \left(\mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2}\right)^{k} \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k : int, optional
        Number of hops :math:`K`. (default: 2)
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
Member
Describe attributes here.
    ----------
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 k=2,
                 bias=True,
                 activation=None):
        super(TGConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._k = k

Member
remove unnecessary empty lines.

        self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)

        self.reset_parameters()

        self._activation = activation

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        self.lin.reset_parameters()

    def forward(self, feat, graph):
        r"""Compute graph convolution

        Parameters
        ----------
        feat : torch.Tensor
            The input feature
        graph : DGLGraph
            The graph.

        Returns
        -------
        torch.Tensor
            The output feature
        """
        graph = graph.local_var()

        norm = th.pow(graph.in_degrees().float(), -0.5)
        shp = norm.shape + (1,) * (feat.dim() - 1)
        norm = th.reshape(norm, shp).to(feat.device)

        # D^{-1/2} A D^{-1/2} X
        fstack = [feat]
        for _ in range(self._k):
            rst = fstack[-1] * norm
            graph.ndata['h'] = rst
            graph.update_all(fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'))
            rst = graph.ndata['h']
            rst = rst * norm
            fstack.append(rst)

        rst = self.lin(th.cat(fstack, dim=-1))

        if self._activation is not None:
            rst = self._activation(rst)

        return rst
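
To make the message-passing loop above concrete, here is a minimal dense-tensor sketch of the same computation, assuming an undirected graph with symmetric adjacency matrix A (the helper name is illustrative; this mirrors what the unit test below verifies for k = 2):

```python
import torch as th

def tagcn_dense(A, X, lin, k=2):
    # S = D^{-1/2} A D^{-1/2}, with D the diagonal (in-)degree matrix of A.
    norm = th.pow(A.sum(dim=1), -0.5)
    S = norm.unsqueeze(1) * A * norm.unsqueeze(0)
    # Stack [X, S X, S^2 X, ..., S^k X] along the feature dimension.
    fstack = [X]
    for _ in range(k):
        fstack.append(S @ fstack[-1])
    # Single linear layer over the concatenated hops, as in TGConv.
    return lin(th.cat(fstack, dim=-1))
```

With `lin = conv.lin` for a `TGConv` instance, this should agree with `conv(X, g)` up to floating-point error.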
48 changes: 48 additions & 0 deletions tests/pytorch/test_nn.py
@@ -69,6 +69,54 @@ def test_graph_conv():
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

def _S2AXWb(A, N, X, W, b):
    # Manually build the 2-hop feature stack [X, SX, S^2 X],
    # where S = D^{-1/2} A D^{-1/2} and N holds D^{-1/2} per node.
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    # Concatenate the hops and apply the linear layer's parameters.
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

def test_tgconv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TGConv(5, 2, bias=True)
    if F.gpu_ctx():
        conv.cuda()
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TGConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
    g = dgl.DGLGraph(nx.path_graph(10))