Skip to content

Commit

Permalink
fix #2278
Browse files Browse the repository at this point in the history
  • Loading branch information
VoVAllen committed Nov 22, 2021
1 parent 7b4b812 commit ea8b5d7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
14 changes: 5 additions & 9 deletions python/dgl/nn/pytorch/conv/ginconv.py
Expand Up @@ -76,14 +76,9 @@ def __init__(self,
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'max':
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
if aggregator_type not in ('sum', 'max', 'mean'):
raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type))
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
Expand Down Expand Up @@ -120,6 +115,7 @@ def forward(self, graph, feat, edge_weight=None):
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
"""
_reducer = getattr(fn, self._aggregator_type)
with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
Expand All @@ -129,7 +125,7 @@ def forward(self, graph, feat, edge_weight=None):

feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
graph.update_all(aggregate_fn, _reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
Expand Down
3 changes: 2 additions & 1 deletion tests/pytorch/test_nn.py
Expand Up @@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type):
th.nn.Linear(5, 12),
aggregator_type
)
th.save(gin, tmp_buffer)
feat = F.randn((g.number_of_src_nodes(), 5))
gin = gin.to(ctx)
h = gin(g, feat)

# test pickle
th.save(h, tmp_buffer)
th.save(gin, tmp_buffer)

assert h.shape == (g.number_of_dst_nodes(), 12)

Expand Down

0 comments on commit ea8b5d7

Please sign in to comment.