
Fixed knn in sparse
guochengqian committed Feb 8, 2022
1 parent 7885181 commit 6e00abd
Showing 4 changed files with 83 additions and 15 deletions.
examples/sem_seg_dense/architecture.py (16 additions, 2 deletions)
@@ -1,3 +1,4 @@
import __init__
import torch
from gcn_lib.dense import BasicConv, GraphConv2d, PlainDynBlock2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph
from torch.nn import Sequential as Seq
@@ -89,8 +90,21 @@ def forward(self, inputs):

inputs = torch.cat((pos, x), 2).transpose(1, 2).unsqueeze(-1)

# net = DGCNNSegDense().to(device)
net = DenseDeepGCN(args).to(device)
print(net)

out = net(inputs)
print(out.shape)

print(inputs.shape, out.shape)
import time
st = time.time()
runs = 1000

with torch.no_grad():
    for i in range(runs):
        out = net(inputs)
    torch.cuda.synchronize()

print(time.time() - st)
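
For context, the benchmark added above follows the standard CUDA timing recipe: run the forward passes under torch.no_grad(), call torch.cuda.synchronize() so the host timer only stops once every queued kernel has finished, and average over the runs. A minimal self-contained sketch of that pattern (the benchmark_inference helper and its arguments are illustrative, not part of this commit):

import time
import torch

def benchmark_inference(net, sample, runs=1000):
    # Illustrative helper, not part of the commit: average seconds per forward pass.
    net.eval()
    with torch.no_grad():
        net(sample)                      # warm-up pass (CUDA context, cuDNN autotuning)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            net(sample)
        torch.cuda.synchronize()         # wait for all queued GPU work before stopping the clock
    return (time.time() - start) / runs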

examples/sem_seg_sparse/architecture.py (63 additions, 1 deletion)
@@ -1,7 +1,10 @@
import __init__
import torch
from torch.nn import Linear as Lin
import torch_geometric as tg
from gcn_lib.sparse import MultiSeq, MLP, GraphConv, PlainDynBlock, ResDynBlock, DenseDynBlock, DilatedKnnGraph
from torch_geometric.data import Data
from torch_scatter import scatter


class SparseDeepGCN(torch.nn.Module):
@@ -63,6 +66,65 @@ def forward(self, data):
feats.append(self.backbone[i](feats[-1], batch)[0])
feats = torch.cat(feats, dim=1)

fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
# fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
# fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
fusion = scatter(self.fusion_block(feats), batch, dim=0, reduce='max')
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
return self.prediction(torch.cat((fusion, feats), dim=1))


if __name__ == "__main__":
import random, numpy as np, argparse
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = 2
N = 1024
device = 'cuda'

parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation')
parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)')
parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)')
parser.add_argument('--k', default=20, type=int, help='neighbor num (default:20)')
parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}')
parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization')
parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks')
parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout')
parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
parser.add_argument('--stochastic', default=False, type=bool, help='stochastic for gcn, True or False')
args = parser.parse_args()

pos = torch.rand((batch_size*N, 3), dtype=torch.float).to(device)
x = torch.rand((batch_size*N, 6), dtype=torch.float).to(device)

data = Data()
data.pos = pos
data.x = x
data.batch = torch.arange(batch_size).unsqueeze(-1).expand(-1, N).contiguous().view(-1).contiguous()
data = data.to(device)

net = SparseDeepGCN(args).to(device)
print(net)

out = net(data)

print('out logits shape', out.shape)
import time
st = time.time()
runs = 1000

with torch.no_grad():
    for i in range(runs):
        # print(i)
        out = net(data)
    torch.cuda.synchronize()
print(time.time() - st)
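
The functional change in this file is the fusion step: the old tg.utils.scatter_ call is replaced by torch_scatter.scatter with reduce='max', which max-pools node features per graph in the batch before repeat_interleave broadcasts the pooled vector back to every node. A toy sketch of that pool-and-broadcast pattern (small made-up shapes, assuming torch_scatter is installed):

import torch
from torch_scatter import scatter

feats = torch.randn(6, 4)                    # 6 nodes with 4-channel features
batch = torch.tensor([0, 0, 0, 1, 1, 1])     # node-to-graph assignment, 2 graphs of 3 nodes

fusion = scatter(feats, batch, dim=0, reduce='max')   # per-graph max pooling -> shape (2, 4)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0] // fusion.shape[0], dim=0)
print(fusion.shape)                          # torch.Size([6, 4]), one pooled vector per node

Note that the repeat_interleave broadcast assumes every graph in the batch has the same number of nodes, which holds here because each point cloud has a fixed N points.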

gcn_lib/dense/torch_edge.py (1 addition, 1 deletion)
@@ -54,7 +54,7 @@ def dense_knn_matrix(x, k=16):
x = x.transpose(2, 1).squeeze(-1)
batch_size, n_points, n_dims = x.shape
_, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k)
center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
center_idx = torch.arange(0, n_points, device=x.device).expand(batch_size, k, -1).transpose(2, 1)
return torch.stack((nn_idx, center_idx), dim=0)
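
The one-line change above builds center_idx with expand instead of repeat: expand returns a broadcast view over the batch and k dimensions rather than allocating batch_size * k copies of the arange, while yielding the same index values. A quick standalone check of that equivalence (toy sizes chosen here for illustration):

import torch

batch_size, n_points, k = 2, 5, 3
base = torch.arange(0, n_points)

old = base.repeat(batch_size, k, 1).transpose(2, 1)    # materializes the copies
new = base.expand(batch_size, k, -1).transpose(2, 1)   # broadcast view, no extra memory

assert torch.equal(old, new)                           # same (batch_size, n_points, k) indices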


gcn_lib/sparse/torch_edge.py (3 additions, 11 deletions)
@@ -79,23 +79,15 @@ def knn_matrix(x, k=16, batch=None):
x = x.view(batch_size, -1, x.shape[-1])

neg_adj = -pairwise_distance(x.detach())

_, nn_idx = torch.topk(neg_adj, k=k)
del neg_adj

n_points = x.shape[1]
start_idx = torch.arange(0, n_points*batch_size, n_points).long().view(batch_size, 1, 1)
if x.is_cuda:
start_idx = start_idx.cuda()
start_idx = torch.arange(0, n_points * batch_size, n_points, device=x.device).view(batch_size, 1, 1)
nn_idx += start_idx
del start_idx

if x.is_cuda:
torch.cuda.empty_cache()

nn_idx = nn_idx.view(1, -1)
center_idx = torch.arange(0, n_points*batch_size).repeat(k, 1).transpose(1, 0).contiguous().view(1, -1)
if x.is_cuda:
center_idx = center_idx.cuda()
center_idx = torch.arange(0, n_points*batch_size, device=x.device).expand(k, -1).transpose(1, 0).contiguous().view(1, -1)
return nn_idx, center_idx
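
The sparse knn_matrix gets the same treatment: the index tensors are created directly on x's device via the device= argument to torch.arange (replacing the CPU arange plus .cuda() copies and the empty_cache call), and the center indices use an expand view followed by contiguous() instead of repeat. A small sketch of the index layout this produces, with pairwise_distance and the top-k neighbor search omitted and toy sizes chosen for illustration:

import torch

batch_size, n_points, k = 2, 4, 2
device = 'cpu'   # in knn_matrix this is x.device, so no explicit .cuda() copies are needed

# Per-cloud offset added to neighbor indices so they address the flattened batch.
start_idx = torch.arange(0, n_points * batch_size, n_points, device=device).view(batch_size, 1, 1)
print(start_idx.view(-1))    # tensor([0, 4]): cloud 0 starts at row 0, cloud 1 at row 4

# Each of the batch_size * n_points centers repeated k times, via a view rather than a copy.
center_idx = torch.arange(0, n_points * batch_size, device=device).expand(k, -1).transpose(1, 0).contiguous().view(1, -1)
print(center_idx)            # tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]])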


