Skip to content

Commit

Permalink
Fixed import
Browse files Browse the repository at this point in the history
  • Loading branch information
guochengqian committed Feb 8, 2022
2 parents 6e00abd + a15c43a commit fc83d08
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
7 changes: 2 additions & 5 deletions examples/sem_seg_sparse/architecture.py
@@ -1,10 +1,9 @@
import __init__
import torch
from torch.nn import Linear as Lin
import torch_geometric as tg
from gcn_lib.sparse import MultiSeq, MLP, GraphConv, PlainDynBlock, ResDynBlock, DenseDynBlock, DilatedKnnGraph
from utils.pyg_util import scatter_
from torch_geometric.data import Data
from torch_scatter import scatter


class SparseDeepGCN(torch.nn.Module):
Expand Down Expand Up @@ -66,9 +65,7 @@ def forward(self, data):
feats.append(self.backbone[i](feats[-1], batch)[0])
feats = torch.cat(feats, dim=1)

# fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
# fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
fusion = scatter(self.fusion_block(feats), batch, dim=0, reduce='max')
fusion = scatter_('max', self.fusion_block(feats), batch)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
return self.prediction(torch.cat((fusion, feats), dim=1))

Expand Down
3 changes: 2 additions & 1 deletion gcn_lib/sparse/torch_vertex.py
Expand Up @@ -5,6 +5,7 @@
from .torch_nn import MLP, act_layer, norm_layer, BondEncoder
from .torch_edge import DilatedKnnGraph
from .torch_message import GenMessagePassing, MsgNorm
from utils.pyg_util import scatter_
from torch_geometric.utils import remove_self_loops, add_self_loops


Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True,

def forward(self, x, edge_index):
""""""
x_j = tg.utils.scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0])
x_j = scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0])
return self.nn(torch.cat([x, x_j], dim=1))


Expand Down
35 changes: 35 additions & 0 deletions utils/pyg_util.py
@@ -0,0 +1,35 @@
import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
r"""Aggregates all values from the :attr:`src` tensor at the indices
specified in the :attr:`index` tensor along the first dimension.
If multiple indices reference the same location, their contributions
are aggregated according to :attr:`name` (either :obj:`"add"`,
:obj:`"mean"` or :obj:`"max"`).
Args:
name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
:obj:`"min"`, :obj:`"max"`).
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index. (default: :obj:`0`)
dim_size (int, optional): Automatically create output tensor with size
:attr:`dim_size` in the first dimension. If set to :attr:`None`, a
minimal sized output tensor is returned. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""

assert name in ['add', 'mean', 'min', 'max']

op = getattr(torch_scatter, 'scatter_{}'.format(name))
out = op(src, index, dim, None, dim_size)
out = out[0] if isinstance(out, tuple) else out

if name == 'max':
out[out < -10000] = 0
elif name == 'min':
out[out > 10000] = 0

return out

0 comments on commit fc83d08

Please sign in to comment.