This repository has been archived by the owner on Dec 15, 2022. It is now read-only.
/
ParticleNet.py
277 lines (226 loc) · 10.3 KB
/
ParticleNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import numpy as np
import torch
import torch.nn as nn
'''Based on https://github.com/WangYueFt/dgcnn/blob/master/pytorch/model.py.'''
def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    Args:
        x: point coordinates laid out channel-first as (batch_size, num_dims, num_points);
            distances are Euclidean over ``num_dims``.
        k: number of neighbours to keep per point.

    Returns:
        LongTensor of shape (batch_size, num_points, k) with per-sample point indices,
        ordered from nearest to farthest.

    Note:
        The closest hit (distance 0, normally the point itself) is dropped. If several
        points share identical coordinates, the dropped index may belong to a duplicate
        rather than to the point itself.
    """
    sq_norm = x.pow(2).sum(dim=1, keepdim=True)  # (batch_size, 1, num_points)
    gram = torch.matmul(x.transpose(2, 1), x)  # (batch_size, num_points, num_points)
    # -(|a|^2 + |b|^2 - 2 a.b): larger means closer, so topk picks the nearest points.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k + 1, dim=-1)
    return nearest[:, :, 1:]
# v1 is faster on GPU
def get_graph_feature_v1(x, k, idx):
    """Build EdgeConv edge features ``[x_i, x_j - x_i]`` by gathering rows of a point-major copy.

    Args:
        x: features of shape (batch_size, num_dims, num_points).
        k: neighbours per point; must match ``idx.size(-1)``.
        idx: neighbour indices of shape (batch_size, num_points, k), local to each sample.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k). The first ``num_dims``
        channels repeat the centre point; the remaining ones hold neighbour minus centre.
    """
    batch_size, num_dims, num_points = x.size()

    # Shift each sample's local indices into the flattened (batch_size*num_points) row space.
    offsets = torch.arange(batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + offsets).view(-1)

    rows = x.transpose(2, 1).reshape(-1, num_dims)  # (batch_size*num_points, num_dims)
    neighbors = (rows[flat_idx]
                 .view(batch_size, num_points, k, num_dims)
                 .permute(0, 3, 1, 2)
                 .contiguous())  # (batch_size, num_dims, num_points, k)

    center = x.unsqueeze(-1).repeat(1, 1, 1, k)
    return torch.cat((center, neighbors - center), dim=1)
# v2 is faster on CPU
def get_graph_feature_v2(x, k, idx):
    """Build EdgeConv edge features ``[x_i, x_j - x_i]`` by gathering columns of a channel-major copy.

    Same contract as ``get_graph_feature_v1``; only the memory layout used for the gather differs.

    Args:
        x: features of shape (batch_size, num_dims, num_points).
        k: neighbours per point; must match ``idx.size(-1)``.
        idx: neighbour indices of shape (batch_size, num_points, k), local to each sample.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k).
    """
    batch_size, num_dims, num_points = x.size()

    # Shift each sample's local indices into the flattened (batch_size*num_points) column space.
    offsets = torch.arange(batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + offsets).view(-1)

    columns = x.transpose(0, 1).reshape(num_dims, -1)  # (num_dims, batch_size*num_points)
    gathered = columns[:, flat_idx].view(num_dims, batch_size, num_points, k)
    neighbors = gathered.transpose(1, 0).contiguous()  # (batch_size, num_dims, num_points, k)

    center = x.unsqueeze(-1).repeat(1, 1, 1, k)
    return torch.cat((center, neighbors - center), dim=1)
class EdgeConvBlock(nn.Module):
    r"""EdgeConv block with mean aggregation and a residual shortcut.

    Adapted from "`Dynamic Graph CNN for Learning on Point Clouds
    <https://arxiv.org/pdf/1801.07829>`__".

    1. The ``k`` nearest neighbours of each point are found in ``points`` space.
    2. Edge features :math:`[x_i, x_j - x_i]` are built from ``features``.
    3. These pass through a stack of 1x1 ``Conv2d`` layers. Each layer is optionally
       followed by ``BatchNorm2d`` and ``ReLU``.
    4. The result is averaged over the neighbours and added to a shortcut of the input
       features.
    5. A final ``ReLU`` is applied when ``activation`` is true:

    .. math::
        x_i^{(l+1)} = \mathrm{ReLU}\Big( s\big(x_i^{(l)}\big)
            + \frac{1}{k} \sum_{j \in \mathcal{N}(i)}
              h\big(x_i^{(l)} \,\|\, x_j^{(l)} - x_i^{(l)}\big) \Big)

    Here :math:`\mathcal{N}(i)` are the ``k`` nearest neighbours of :math:`i` and
    :math:`h` is the conv stack. :math:`s` is the identity when
    ``in_feat == out_feats[-1]``; otherwise it is a 1x1 ``Conv1d`` + ``BatchNorm1d``
    projection.

    Parameters
    ----------
    k : int
        Number of neighbours per point.
    in_feat : int
        Input feature size.
    out_feats : sequence of int
        Output size of each conv layer; ``out_feats[-1]`` is the block's output size.
    batch_norm : bool
        Whether to apply batch normalization after each conv layer (convs then have no bias).
    activation : bool
        Whether to apply ReLU after each conv layer and after the residual sum.
    cpu_mode : bool
        Use the CPU-oriented neighbour gather (``get_graph_feature_v2``) instead of v1.
    """

    def __init__(self, k, in_feat, out_feats, batch_norm=True, activation=True, cpu_mode=False):
        super(EdgeConvBlock, self).__init__()
        self.k = k
        self.batch_norm = batch_norm
        self.activation = activation
        self.num_layers = len(out_feats)
        self.get_graph_feature = get_graph_feature_v2 if cpu_mode else get_graph_feature_v1

        # Layer i maps the previous layer's width (2 * in_feat for the edge features) to out_feats[i].
        in_chns = [2 * in_feat] + list(out_feats[:-1])
        # A conv bias is redundant when BatchNorm follows, since BN re-centres each channel.
        self.convs = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=not batch_norm)
            for c_in, c_out in zip(in_chns, out_feats))

        # Attribute names are kept identical to preserve state_dict keys.
        if batch_norm:
            self.bns = nn.ModuleList(nn.BatchNorm2d(c) for c in out_feats)

        if activation:
            self.acts = nn.ModuleList(nn.ReLU() for _ in out_feats)

        if in_feat == out_feats[-1]:
            self.sc = None
        else:
            self.sc = nn.Conv1d(in_feat, out_feats[-1], kernel_size=1, bias=False)
            self.sc_bn = nn.BatchNorm1d(out_feats[-1])

        if activation:
            self.sc_act = nn.ReLU()

    def forward(self, points, features):
        """Apply the block.

        Args:
            points: coordinates used for the k-NN search, channel-first (N, D, P).
            features: input features (N, in_feat, P).

        Returns:
            Tensor of shape (N, out_feats[-1], P).
        """
        topk_indices = knn(points, self.k)
        x = self.get_graph_feature(features, self.k, topk_indices)  # (N, 2*in_feat, P, K)

        # Index-based loop so that disabled BN / activation (whose ModuleLists were never
        # created) are simply skipped instead of raising AttributeError.
        for i, conv in enumerate(self.convs):
            x = conv(x)  # (N, C', P, K)
            if self.batch_norm:
                x = self.bns[i](x)
            if self.activation:
                x = self.acts[i](x)

        fts = x.mean(dim=-1)  # (N, C, P)

        # shortcut
        if self.sc is not None:
            sc = self.sc_bn(self.sc(features))  # (N, C_out, P)
        else:
            sc = features

        out = sc + fts
        if self.activation:
            out = self.sc_act(out)
        return out  # (N, C_out, P)
class ParticleNet(nn.Module):
    """ParticleNet: stacked EdgeConv blocks over a masked point cloud, followed by a classification head.

    ``forward`` takes channel-first tensors:

    - ``points``: coordinates for the first block's k-NN search.
    - ``features``: per-particle features of shape (N, input_dims, P).
    - ``mask``: optional; when omitted, a slot is treated as padding iff its feature vector is all zeros.

    Parameters
    ----------
    input_dims : int
        Number of feature channels in ``features``.
    num_classes : int
        Number of output classes.
    conv_params : sequence of (int, sequence of int)
        One ``(k, channels)`` pair per EdgeConvBlock; ``channels[-1]`` is that block's output width.
    fc_params : sequence of (int, float)
        One ``(width, dropout_rate)`` pair per hidden head layer. May be empty, in which case the
        head is a single output layer.
    use_fusion : bool
        Concatenate all EdgeConv outputs and project them with a 1x1 Conv1d + BN + ReLU. The
        projection width is the summed block width rounded down to a multiple of 128 and
        clipped to [128, 1024].
    use_fts_bn : bool
        Apply BatchNorm1d to the input features.
    use_counts : bool
        Average-pool over the number of unmasked slots (at least 1) instead of over all P slots.
    for_inference : bool
        Apply softmax over dim 1 of the output, and use the CPU-oriented neighbour gather.
    for_segmentation : bool
        Skip pooling. The head is built from 1x1 Conv1d layers and outputs (N, num_classes, P)
        instead of (N, num_classes).
    """

    def __init__(self,
                 input_dims,
                 num_classes,
                 conv_params=((7, (32, 32, 32)), (7, (64, 64, 64))),
                 fc_params=((128, 0.1),),
                 use_fusion=True,
                 use_fts_bn=True,
                 use_counts=True,
                 for_inference=False,
                 for_segmentation=False,
                 **kwargs):
        # Defaults are tuples (not lists) so they cannot be mutated and shared between instances.
        super(ParticleNet, self).__init__(**kwargs)

        self.use_fts_bn = use_fts_bn
        if self.use_fts_bn:
            self.bn_fts = nn.BatchNorm1d(input_dims)

        self.use_counts = use_counts

        self.edge_convs = nn.ModuleList()
        for idx, (k, channels) in enumerate(conv_params):
            in_feat = input_dims if idx == 0 else conv_params[idx - 1][1][-1]
            self.edge_convs.append(EdgeConvBlock(k=k, in_feat=in_feat, out_feats=channels, cpu_mode=for_inference))

        self.use_fusion = use_fusion
        if self.use_fusion:
            in_chn = sum(channels[-1] for _, channels in conv_params)
            # int(): np.clip returns a NumPy scalar; layer constructors get a plain Python int.
            out_chn = int(np.clip((in_chn // 128) * 128, 128, 1024))
            self.fusion_block = nn.Sequential(nn.Conv1d(in_chn, out_chn, kernel_size=1, bias=False), nn.BatchNorm1d(out_chn), nn.ReLU())
            head_in = out_chn
        else:
            head_in = conv_params[-1][1][-1]

        self.for_segmentation = for_segmentation

        # Head: hidden layers from fc_params, then the output layer. Tracking the running width
        # lets an empty fc_params connect the output layer directly to head_in.
        fcs = []
        in_chn = head_in
        for channels, drop_rate in fc_params:
            if self.for_segmentation:
                fcs.append(nn.Sequential(nn.Conv1d(in_chn, channels, kernel_size=1, bias=False),
                                         nn.BatchNorm1d(channels), nn.ReLU(), nn.Dropout(drop_rate)))
            else:
                fcs.append(nn.Sequential(nn.Linear(in_chn, channels), nn.ReLU(), nn.Dropout(drop_rate)))
            in_chn = channels
        if self.for_segmentation:
            fcs.append(nn.Conv1d(in_chn, num_classes, kernel_size=1))
        else:
            fcs.append(nn.Linear(in_chn, num_classes))
        self.fc = nn.Sequential(*fcs)

        self.for_inference = for_inference

    def forward(self, points, features, mask=None):
        """Run the network.

        Args:
            points: coordinates for the first k-NN search, channel-first (N, D, P).
            features: per-particle features (N, input_dims, P).
            mask: optional (N, 1, P) validity mask; derived from non-zero features when None.

        Returns:
            Logits (or softmax probabilities when ``for_inference``) of shape (N, num_classes),
            or (N, num_classes, P) when ``for_segmentation``.
        """
        if mask is None:
            mask = (features.abs().sum(dim=1, keepdim=True) != 0)  # (N, 1, P)
        # Out-of-place on purpose. In-place `*=` would overwrite the caller's tensors, and it
        # fails on leaf tensors that require grad.
        points = points * mask
        features = features * mask
        # Move padded slots ~1e9 away so real particles do not pick them as neighbours
        # (as long as enough real particles exist).
        coord_shift = (mask == 0) * 1e9
        if self.use_counts:
            counts = mask.float().sum(dim=-1)
            counts = torch.max(counts, torch.ones_like(counts))  # >=1, avoids division by zero

        fts = self.bn_fts(features) * mask if self.use_fts_bn else features

        outputs = []
        for idx, conv in enumerate(self.edge_convs):
            # The first block searches neighbours in `points`; later blocks use the previous block's features.
            pts = (points if idx == 0 else fts) + coord_shift
            fts = conv(pts, fts) * mask
            if self.use_fusion:
                outputs.append(fts)
        if self.use_fusion:
            fts = self.fusion_block(torch.cat(outputs, dim=1)) * mask

        if self.for_segmentation:
            x = fts
        elif self.use_counts:
            x = fts.sum(dim=-1) / counts  # divide by the real counts
        else:
            x = fts.mean(dim=-1)

        output = self.fc(x)
        if self.for_inference:
            output = torch.softmax(output, dim=1)
        return output
class FeatureConv(nn.Module):
    """Per-position feature embedding: BatchNorm1d -> 1x1 Conv1d (no bias) -> BatchNorm1d -> ReLU.

    Expects channel-first input such as (N, in_chn, P) and returns (N, out_chn, P). The 1x1
    convolution mixes channels independently at every position along the last axis.
    """

    def __init__(self, in_chn, out_chn, **kwargs):
        super().__init__(**kwargs)
        stages = [
            nn.BatchNorm1d(in_chn),
            nn.Conv1d(in_chn, out_chn, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_chn),
            nn.ReLU(),
        ]
        # Kept as a single Sequential named `conv` so state_dict keys stay `conv.<i>.*`.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the embedding stack to ``x``."""
        return self.conv(x)
class ParticleNetTagger(nn.Module):
    """ParticleNet over two input collections (``pf_*`` and ``sv_*``).

    1. Each collection's features are embedded to 32 channels by its own ``FeatureConv``.
    2. Points, embedded features and masks of the two collections are concatenated along
       the particle axis (dim 2).
    3. The concatenated tensors are passed to a ``ParticleNet`` with ``input_dims=32``.

    Parameters
    ----------
    pf_features_dims, sv_features_dims : int
        Number of feature channels of each collection.
    num_classes, conv_params, fc_params, use_fusion, use_fts_bn, use_counts, for_inference
        Forwarded to ``ParticleNet``.
    pf_input_dropout, sv_input_dropout : float or None
        If set (non-zero), apply ``nn.Dropout`` with this rate to the collection's mask. This
        randomly removes whole entries during training.
    """

    def __init__(self,
                 pf_features_dims,
                 sv_features_dims,
                 num_classes,
                 conv_params=((7, (32, 32, 32)), (7, (64, 64, 64))),
                 fc_params=((128, 0.1),),
                 use_fusion=True,
                 use_fts_bn=True,
                 use_counts=True,
                 pf_input_dropout=None,
                 sv_input_dropout=None,
                 for_inference=False,
                 **kwargs):
        # Defaults are tuples (not lists) so they cannot be mutated and shared between instances.
        super(ParticleNetTagger, self).__init__(**kwargs)
        self.pf_input_dropout = nn.Dropout(pf_input_dropout) if pf_input_dropout else None
        self.sv_input_dropout = nn.Dropout(sv_input_dropout) if sv_input_dropout else None
        self.pf_conv = FeatureConv(pf_features_dims, 32)
        self.sv_conv = FeatureConv(sv_features_dims, 32)
        self.pn = ParticleNet(input_dims=32,
                              num_classes=num_classes,
                              conv_params=conv_params,
                              fc_params=fc_params,
                              use_fusion=use_fusion,
                              use_fts_bn=use_fts_bn,
                              use_counts=use_counts,
                              for_inference=for_inference)

    def forward(self, pf_points, pf_features, pf_mask, sv_points, sv_features, sv_mask):
        """Embed both collections, concatenate them along dim 2 and run ParticleNet.

        Masks are multiplied into points and features. They presumably have a size-1 channel
        dimension, e.g. (N, 1, P) (TODO confirm against the data loader), so that they broadcast.
        """
        # Masking is done out-of-place so the caller's input tensors are never overwritten.
        if self.pf_input_dropout is not None:
            pf_mask = (self.pf_input_dropout(pf_mask) != 0).float()
            pf_points = pf_points * pf_mask
            pf_features = pf_features * pf_mask
        if self.sv_input_dropout is not None:
            sv_mask = (self.sv_input_dropout(sv_mask) != 0).float()
            sv_points = sv_points * sv_mask
            sv_features = sv_features * sv_mask

        points = torch.cat((pf_points, sv_points), dim=2)
        features = torch.cat((self.pf_conv(pf_features * pf_mask) * pf_mask,
                              self.sv_conv(sv_features * sv_mask) * sv_mask), dim=2)
        mask = torch.cat((pf_mask, sv_mask), dim=2)
        return self.pn(points, features, mask)