-
Notifications
You must be signed in to change notification settings - Fork 65
/
visualization.py
1130 lines (999 loc) · 44.9 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import numpy as np
import os.path as osp
import plotly.graph_objects as go
from src.data import Data, NAG, Cluster
from src.transforms import GridSampling3D, SaveNodeIndex
from src.utils import fast_randperm, to_trimmed
from torch_scatter import scatter_mean
from src.utils.color import *
# TODO: To go further with ipwidgets :
# - https://plotly.com/python/figurewidget-app/
# - https://ipywidgets.readthedocs.io/en/stable/
def visualize_3d(
input,
keys=None,
figsize=1000,
width=None,
height=None,
class_names=None,
class_colors=None,
stuff_classes=None,
num_classes=None,
hide_void_pred=False,
voxel=-1,
max_points=50000,
point_size=3,
centroid_size=None,
error_color=None,
centroids=False,
h_edge=False,
h_edge_attr=False,
h_edge_width=None,
v_edge=False,
v_edge_width=None,
gap=None,
radius=None,
center=None,
select=None,
alpha=0.1,
alpha_super=None,
alpha_stuff=0.2,
point_symbol='circle',
centroid_symbol='circle',
colorscale='Agsunset',
**kwargs):
"""3D data interactive visualization.
:param input: `Data` or `NAG` object
:param keys: `List(str)` or `str`
By default, the following attributes will be parsed in `input`
for visualization {`pos`, `rgb`, `y`, `obj`, `semantic_pred`,
`obj_pred`}. Yet, if `input` contains other attributes that you
want to visualize, these can be passed as `keys`. This only
supports point-wise attributes stored as 1D or 2D tensors.
If the tensor contains only 1 channel, the attribute will be
represented with a grayscale colormap. If the tensor contains
2 or 3 channels, these will be represented as RGB, with
an additional all-1 channel if need be. If the tensor contains
more than 3 channels, a PCA projection to RGB will be shown. In
any case, the attribute values will be rescaled with respect to
their statistics before visualization, meaning that colors may
not compare between two different plots
:param figsize: `int`
Figure dimensions will be `(figsize, figsize/2)` if `width` and
`height` are not specified
:param width: `int`
Figure width
:param height: `int`
Figure height
:param class_names: `List(str)`
Names for point labels found in attributes `y` and
`semantic_pred`
:param class_colors: `List(List(int, int, int))`
Colors palette for point labels found in attributes `y` and
`semantic_pred`
:param stuff_classes: `List(int)`
Semantic labels of the classes considered as `stuff` for
instance and panoptic segmentation. If `y` and `obj` are found
in the point attributes, the stuff annotations will appear
accordingly. Otherwise, stuff instance labeling will appear as
any other object
:param num_classes: `int`
Number of valid classes. By convention, we assume
`y ∈ [0, num_classes-1]` are VALID LABELS, while
`y < 0` AND `y >= num_classes` ARE VOID LABELS
:param hide_void_pred: `bool`
Whether predictions on points labeled as VOID be visualized
:param voxel: `float`
Voxel size to subsample the point cloud to facilitate
visualization
:param max_points: `int`
Maximum number of points displayed to facilitate visualization
:param point_size: `int` or `float`
Size of point markers
:param centroid_size: `int` or `float`
Size of superpoint markers
:param error_color: `List(int, int, int)`
Color used to identify mis-predicted points
:param centroids: `bool`
Whether superpoint centroids should be displayed
:param h_edge: `bool`
Whether horizontal edges should be displayed (only if
`centroids=True`)
:param h_edge_attr: `bool`
Whether the edges should be colored by their features found in
`edge_attr` (only if `h_edge=True`)
:param h_edge_width: `float`
Width of the horizontal edges, if `h_edge=True`. Defaults to
`None`, in which case `point_size` will be used for the edge
width
:param v_edge: `bool`
Whether vertical edges should be displayed (only if
`centroids=True` and `gap` is not `None`)
:param v_edge_width: `float`
Width of the vertical edges, if `v_edge=True`. Defaults to
`None`, in which case `point_size` will be used for the edge
width
:param gap: `List(float, float, float)`
If `None`, the hierarchical graphs will be overlaid on the points.
If not `None`, a 3D tensor indicating the offset by which the
hierarchical graphs should be plotted
:param radius: `float`
If not `None`, only visualize a spherical sampling of the input
data, centered on `center` and with size `radius`. This option
is not compatible with `select`
:param center: `List(float, float, float)`
If `radius` is provided, only visualize a spherical sampling of
the input data, centered on `center` and with size `radius`. If
`None`, the center of the scene will be used
:param select: `Tuple(int, Tensor)`
If not `None`, will call `Data.select(*select)` or
`NAG.select(*select)` on the input data (depending on its nature)
and the coloring schemes will illustrate it. This option is not
compatible with `radius`
:param alpha: `float`
Rules the whitening of selected points, nodes and edges (only if
select is not `None`)
:param alpha_super: `float`
Rules the whitening of superpoints (only if select is not
`None`). If `None`, alpha will be used as fallback
:param alpha_stuff: `float`
Rules the whitening of stuff points (only if the input
points have `obj` and `semantic_pred` attributes, and
`stuff_classes` or `num_classes` is specified). If `None`,
`alpha` will be used as fallback
:param point_symbol: `str`
Marker symbol used for points. Must be one of
`{'circle', 'circle-open', 'square', 'square-open', 'diamond',
'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
:param centroid_symbol: `str`
Marker symbol used for centroids. Must be one of
`{'circle', 'circle-open', 'square', 'square-open', 'diamond',
'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
:param colorscale: `str`
Plotly colorscale used for coloring 1D continuous features. See
https://plotly.com/python/builtin-colorscales for options
:param kwargs
:return:
"""
# Data attributes plotted by default if found in the input
_DEFAULT_KEYS = [
'pos',
'rgb',
'y',
'semantic_pred',
'obj',
'obj_pred',
'x',
'super_sampling',
'super_index']
# assert isinstance(input, (Data, NAG))
gap = torch.tensor(gap) if gap is not None else gap
assert gap is None or gap.shape == torch.Size([3])
assert not (radius and (select is not None)), \
"Cannot use both a `radius` and `select` at once"
# We work on copies of the input data, to allow modified in this
# scope
input = input.clone().cpu()
# If the input is a simple Data object, we convert it to a NAG
input = NAG([input]) if isinstance(input, Data) else input
# If the last level of the NAG has super_index, we manually
# construct an additional Data level and append it to the NAG
if input[input.num_levels - 1].is_sub:
data_last = input[input.num_levels - 1]
sub = Cluster(
data_last.super_index, torch.arange(data_last.num_nodes),
dense=True)
obj = data_last.obj.merge(data_last.super_index)
pos = scatter_mean(data_last.pos, data_last.super_index, dim=0)
input = NAG(input.to_list() + [Data(pos=pos, sub=sub, obj=obj)])
is_nag = isinstance(input, NAG)
num_levels = input.num_levels if is_nag else 1
# Make sure alpha is in [0, 1]
alpha = max(0, min(alpha, 1))
alpha_super = max(0, min(alpha_super, 1)) if alpha_super else alpha
alpha_stuff = max(0, min(alpha_stuff, 1)) if alpha_stuff else alpha
# If `radius` is provided, we only visualize a spherical selection
# of size `radius` around the `center`
if radius is not None:
# If no `center` provided, pick the middle of the scene
if center is None:
hi = input[0].pos.max(dim=0).values
lo = input[0].pos.min(dim=0).values
center = (hi + lo) / 2
# For Z, we center on the average Z, because the middle
# value may cause empty samplings for outdoor scenes with
# some very high objects and most of the interesting stuff
# happening near the ground
center[2] = input[0].pos[:, 2].mean()
else:
center = torch.as_tensor(center).cpu()
center = center.view(1, -1)
# Create a mask on level-0 (ie points) to be used for indexing
# the NAG structure
mask = torch.where(
torch.linalg.norm(input[0].pos - center, dim=1) < radius)[0]
# Subselect the hierarchical partition based on the level-0 mask
input = input.select(0, mask)
# If `select` is provided, we will call NAG.select on the input data
# and illustrate the selected/discarded pattern in the figure
if select is not None and is_nag:
# Add an ID to the points before applying NAG.select
nag_temp = input.clone()
for i in range(nag_temp.num_levels):
nag_temp._list[i] = SaveNodeIndex()(nag_temp[i])
# Apply the selection
nag_temp = nag_temp.select(*select)
# Indicate, for each node of the hierarchical graph, whether it
# has been selected
for i in range(num_levels):
selected = torch.zeros(input[i].num_nodes, dtype=torch.bool)
selected[nag_temp[i][SaveNodeIndex.DEFAULT_KEY]] = True
input[i].selected = selected
del nag_temp, selected
elif select is not None and not is_nag:
# Add an ID to the points before applying NAG.select
data_temp = SaveNodeIndex()(Data(pos=input.pos.clone()))
# Apply the selection
data_temp = data_temp.select(select)[0]
# Indicate, for each node of the hierarchical graph, whether it
# has been selected
selected = torch.zeros(input.num_nodes, dtype=torch.bool)
selected[data_temp[SaveNodeIndex.DEFAULT_KEY]] = True
input.selected = selected
del data_temp, selected
elif is_nag:
for i in range(num_levels):
input[i].selected = torch.ones(
input[i].num_nodes, dtype=torch.bool)
else:
input.selected = torch.ones(input.num_nodes, dtype=torch.bool)
# Data_0 accounts for the lowest level of hierarchy, the points
# themselves
data_0 = input[0] if is_nag else input
# Subsample to limit the drawing time
# If the level-0 cloud needs to be voxelized or sampled, a NAG
# structure will be affected too. To maintain NAG consistency, we
# only support 'GridSampling3D' with mode='last' and random sampling
# without replacement. To keep track of the sampled points and index
# the NAG accordingly, we use 'SaveNodeIndex'
idx = torch.arange(data_0.num_points)
# If a voxel size is specified, voxelize the level-0. We first
# isolate the 'pos' and the input indices of data_0 and apply
# voxelization on this. We then recover the original grid-sampled
# points indices to be used with Data.select or NAG.select
if voxel > 0:
data_temp = SaveNodeIndex()(Data(pos=data_0.pos.clone()))
data_temp = GridSampling3D(voxel, mode='last')(data_temp)
idx = data_temp[SaveNodeIndex.DEFAULT_KEY]
del data_temp
# If the cloud is too large with respect to required 'max_points',
# sample without replacement
if idx.shape[0] > max_points:
idx = idx[fast_randperm(idx.shape[0])[:max_points]]
# If a sampling is needed, apply it to the input Data or NAG,
# depending on the structure
if idx.shape[0] < data_0.num_points:
input = input.select(0, idx) if is_nag else input.select(idx)[0]
data_0 = input[0] if is_nag else input
# Round to the cm for cleaner hover info
data_0.pos = (data_0.pos * 100).round() / 100
# Class colors initialization
if class_colors is not None and not isinstance(class_colors[0], str):
class_colors = np.asarray(class_colors)
else:
class_colors = None
# Prepare figure
width = width if width and height else figsize
height = height if width and height else int(figsize / 2)
margin = int(0.02 * min(width, height))
layout = go.Layout(
width=width,
height=height,
scene=dict(aspectmode='data', ), # preserve aspect ratio
margin=dict(l=margin, r=margin, b=margin, t=margin),
uirevision=True)
fig = go.Figure(layout=layout)
# To keep track of which trace should be seen under which mode
# (i.e. button), we build trace_modes. This is a list of dictionaries
# indicating, for each trace (list element), which mode (dict key)
# it should appear in and with which attributes (values are dict of
# parameters for plotly figure updates)
trace_modes = []
i_point_trace = 0
i_unselected_point_trace = 1
# Initialize `void_classes`
void_classes = [num_classes] if num_classes else []
# Draw a trace for position-colored 3D point cloud
mini = data_0.pos.min(dim=0).values
maxi = data_0.pos.max(dim=0).values
colors = (data_0.pos - mini) / (maxi - mini + 1e-6)
colors = rgb_to_plotly_rgb(colors)
data_0.pos_colors = colors
fig.add_trace(
go.Scatter3d(
x=data_0.pos[data_0.selected, 0],
y=data_0.pos[data_0.selected, 1],
z=data_0.pos[data_0.selected, 2],
mode='markers',
marker=dict(
symbol=point_symbol,
size=point_size,
color=colors[data_0.selected]),
hoverinfo='x+y+z+text',
hovertext=None,
showlegend=False,
visible=True, ))
trace_modes.append({
'Position RGB': {
'marker.color': colors[data_0.selected], 'hovertext': None}})
fig.add_trace(
go.Scatter3d(
x=data_0.pos[~data_0.selected, 0],
y=data_0.pos[~data_0.selected, 1],
z=data_0.pos[~data_0.selected, 2],
mode='markers',
marker=dict(
symbol=point_symbol,
size=point_size,
color=colors[~data_0.selected],
opacity=alpha),
hoverinfo='x+y+z+text',
hovertext=None,
showlegend=False,
visible=True, ))
trace_modes.append({
'Position RGB': {
'marker.color': colors[~data_0.selected], 'hovertext': None}})
# Draw a trace for RGB 3D point cloud
if data_0.rgb is not None:
colors = data_0.rgb
colors = rgb_to_plotly_rgb(colors)
data_0.rgb_colors = colors
trace_modes[i_point_trace]['RGB'] = {
'marker.color': colors[data_0.selected], 'hovertext': None}
trace_modes[i_unselected_point_trace]['RGB'] = {
'marker.color': colors[~data_0.selected], 'hovertext': None}
# Color the points with ground truth semantic labels. If labels are
# expressed as histograms, keep the most frequent one
if data_0.y is not None:
y = data_0.y
y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
colors = class_colors[y] if class_colors is not None \
else int_to_plotly_rgb(torch.LongTensor(y))
data_0.y_colors = colors
if class_names is None:
text = np.array([f"Class {i}" for i in range(y.max() + 1)])
else:
text = np.array([str.title(c) for c in class_names])
text = text[y]
trace_modes[i_point_trace]['Semantic'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Semantic'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
# Color the points with predicted semantic labels. If labels are
# expressed as histograms, keep the most frequent one
if data_0.semantic_pred is not None:
pred = data_0.semantic_pred
pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy()
# If the ground truth labels are available, we use them to
# identify void points in the predictions
if data_0.y is not None and hide_void_pred:
# Get the target label
y_gt = data_0.y
y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt
# Create a mask over the points identifying those whose
# ground truth label is void
is_void = np.zeros(y_gt.max() + 1, dtype='bool')
for i in void_classes:
if i < is_void.shape[0]:
is_void[i] = True
is_void = is_void[y_gt]
# Set the predicted label to void if the ground truth is
# void, this avoids visualizing predictions on void
# labels
pred[is_void] = y_gt[is_void]
colors = class_colors[pred] if class_colors is not None else None
data_0.pred_colors = colors
if class_names is None:
text = np.array([f"Class {i}" for i in range(pred.max() + 1)])
else:
text = np.array([str.title(c) for c in class_names])
text = text[pred]
trace_modes[i_point_trace]['Semantic Pred.'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Semantic Pred.'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
# Color the points with ground truth instance labels. If semantic
# labels and stuff_classes/void_classes also passed, the stuff/void
# annotations will be treated accordingly
if data_0.obj is not None and (class_names is None or data_0.y is None):
obj = data_0.obj if isinstance(data_0.obj, torch.Tensor) \
else data_0.obj.major(num_classes=num_classes)[0]
colors = int_to_plotly_rgb(obj)
data_0.obj_colors = colors
text = np.array([f"Object {o}" for o in obj])
trace_modes[i_point_trace]['Panoptic'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Panoptic'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
elif data_0.obj is not None:
# Colors and text for thing points
obj = data_0.obj.major(num_classes=num_classes)[0]
colors_thing = int_to_plotly_rgb(obj)
text_thing = np.array([f"Object {o}" for o in obj])
# For simplicity, we merge void_classes into the stuff_classes,
# the expected behavior is the same, except that we will ensure
# that the hover text distinguishes between stuff and void
stuff_classes = stuff_classes if stuff_classes is not None else []
stuff_classes = list(set(stuff_classes).union(set(void_classes)))
# Colors and text for stuff points
y = data_0.y
y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
colors_stuff = class_colors[y] if class_colors is not None \
else int_to_plotly_rgb(torch.LongTensor(y))
if class_names is None:
text_stuff = np.array([
f"{'Void' if i in void_classes else 'Stuff'} - Class {i}"
for i in range(y.max() + 1)])
else:
text_stuff = np.array([
f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}"
for i, c in enumerate(class_names)])
text_stuff = text_stuff[y]
# Apply alpha-whitening on stuff points
colors_stuff = colors_stuff.astype('float')
white = np.full((colors_stuff.shape[0], 3), 255, dtype='float')
colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff)
colors_stuff = colors_stuff.astype('int64')
# Compute mask for stuff points
stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()])
is_stuff = np.zeros(y.max() + 1, dtype='bool')
for i in stuff_classes:
if i < is_stuff.shape[0]:
is_stuff[i] = True
is_stuff = is_stuff[y]
# Merge thing and stuff colors and text
colors = colors_thing
text = text_thing
colors[is_stuff] = colors_stuff[is_stuff]
text[is_stuff] = text_stuff[is_stuff]
data_0.obj_colors = colors
# Create trace modes
trace_modes[i_point_trace]['Panoptic'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Panoptic'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
# Color the points with predicted instance labels. If semantic
# labels and stuff_classes/void_classes also passed, the
# stuff/void predictions will be treated accordingly. This
# expects `data_0.obj_pred` to be an InstanceData object
if getattr(data_0, 'obj_pred', None) is not None and class_names is None:
obj, _, y = data_0.obj_pred.major(num_classes=num_classes)
colors = int_to_plotly_rgb(obj)
data_0.obj_pred_colors = colors
text = np.array([f"Object {o}" for o in obj])
trace_modes[i_point_trace]['Panoptic Pred.'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
elif getattr(data_0, 'obj_pred', None) is not None:
# Colors and text for thing points
obj, _, y = data_0.obj_pred.major(num_classes=num_classes)
colors_thing = int_to_plotly_rgb(obj)
text_thing = np.array([f"Object {o}" for o in obj])
# For simplicity, we merge void_classes into the stuff_classes,
# the expected behavior is the same, except that we will ensure
# that the hover text distinguishes between stuff and void
stuff_classes = stuff_classes if stuff_classes is not None else []
stuff_classes = list(set(stuff_classes).union(set(void_classes)))
# If the ground truth labels are available, we use them to
# identify void points in the predictions
if data_0.y is not None and hide_void_pred:
# Get the target label
y_gt = data_0.y
y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt
# Create a mask over the points identifying those whose
# ground truth label is void
is_void = np.zeros(y_gt.max() + 1, dtype='bool')
for i in void_classes:
if i < is_void.shape[0]:
is_void[i] = True
is_void = is_void[y_gt]
# Set the predicted label to void if the ground truth is
# void, this avoids visualizing predictions on void
# labels
y[is_void] = y_gt[is_void]
# Colors and text for stuff points
colors_stuff = class_colors[y] if class_colors is not None \
else int_to_plotly_rgb(torch.LongTensor(y))
if class_names is None:
text_stuff = np.array([
f"{'Void' if i in void_classes else 'Stuff'} - Class {i}"
for i in range(y.max() + 1)])
else:
text_stuff = np.array([
f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}"
for i, c in enumerate(class_names)])
text_stuff = text_stuff[y]
# Apply alpha-whitening on stuff points
colors_stuff = colors_stuff.astype('float')
white = np.full((colors_stuff.shape[0], 3), 255, dtype='float')
colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff)
colors_stuff = colors_stuff.astype('int64')
# Compute mask for stuff points
stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()])
is_stuff = np.zeros(y.max() + 1, dtype='bool')
for i in stuff_classes:
if i < is_stuff.shape[0]:
is_stuff[i] = True
is_stuff = is_stuff[y]
# Merge thing and stuff colors and text
colors = colors_thing
text = text_thing
colors[is_stuff] = colors_stuff[is_stuff]
text[is_stuff] = text_stuff[is_stuff]
data_0.obj_pred_colors = colors
# Create trace modes
trace_modes[i_point_trace]['Panoptic Pred.'] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
# Draw a trace for 3D point cloud features
if data_0.x is not None:
colors = feats_to_plotly_rgb(
data_0.x, normalize=True, colorscale=colorscale)
data_0.x_colors = colors
trace_modes[i_point_trace]['Features 3D'] = {
'marker.color': colors[data_0.selected], 'hovertext': None}
trace_modes[i_unselected_point_trace]['Features 3D'] = {
'marker.color': colors[~data_0.selected], 'hovertext': None}
# Draw a trace for each key specified in keys. Only displays
# point-wise tensor attributes that have not already been plotted
# (ie not in `_DEFAULT_KEYS`)
if keys is None:
keys = []
elif isinstance(keys, str):
keys = [keys]
keys = [k for k in keys if k not in _DEFAULT_KEYS]
for key in keys:
val = getattr(data_0, key, None)
if (val is None or not torch.is_tensor(val)
or val.shape[0] != data_0.num_points):
continue
colors = feats_to_plotly_rgb(
val, normalize=True, colorscale=colorscale)
data_0[f"{key}_colors"] = colors
trace_modes[i_point_trace][str(key).title()] = {
'marker.color': colors[data_0.selected], 'hovertext': None}
trace_modes[i_unselected_point_trace][str(key).title()] = {
'marker.color': colors[~data_0.selected], 'hovertext': None}
# Draw a trace for 3D point cloud sampling (for sampling debugging)
if 'super_sampling' in data_0.keys:
colors = data_0.super_sampling
colors = int_to_plotly_rgb(colors)
colors[data_0.super_sampling == -1] = 230
data_0.super_sampling_colors = colors
trace_modes[i_point_trace]['Super sampling'] = {
'marker.color': colors[data_0.selected], 'hovertext': None}
trace_modes[i_unselected_point_trace]['Super sampling'] = {
'marker.color': colors[~data_0.selected], 'hovertext': None}
# Draw a trace for each cluster level
for i_level, data_i in enumerate(input if is_nag else []):
# Exit in case the Data has no 'super_index'
if not data_i.is_sub:
break
# 'Data.super_index' are expressed between levels i and i+1, but
# we need to recover the 'super_index' between level 0 and i+1,
# to draw clusters on the level-0 points. To this end, we
# compute the desired 'super_index' iteratively, with a
# bottom-up approach
if i_level == 0:
super_index = data_i.super_index
else:
super_index = data_i.super_index[super_index]
# Note that we update the 'trace_modes' 0th element here, this
# assumes only it is the trace holding all level-0 points and on
# which all other colors modes are defined
colors = int_to_plotly_rgb(super_index)
data_0[f"{i_level}_level_colors"] = colors
text = np.array([f"↑: {i}" for i in super_index])
trace_modes[i_point_trace][f"Level {i_level + 1}"] = {
'marker.color': colors[data_0.selected],
'hovertext': text[data_0.selected]}
trace_modes[i_unselected_point_trace][f"Level {i_level + 1}"] = {
'marker.color': colors[~data_0.selected],
'hovertext': text[~data_0.selected]}
# Skip to the next level if we do not need to draw the cluster
# centroids
if not centroids:
continue
# To recover centroids of the i+1 level superpoints, we either
# read them from the next NAG level or compute them using the
# level i 'super_index' indices
num_levels = input.num_levels
is_last_level = i_level == num_levels - 1
if is_last_level or input[i_level + 1].pos is None:
super_pos = scatter_mean(data_0.pos, super_index, dim=0)
else:
super_pos = input[i_level + 1].pos
# Add the gap offset, if need be
if gap is not None:
super_pos += gap * (i_level + 1)
# Round to the cm for cleaner hover info
super_pos = (super_pos * 100).round() / 100
# Save the drawing position of centroids to facilitate vertical
# edges drawing later on
input[i_level + 1].draw_pos = super_pos
# Draw the level-i+1 cluster centroids
idx_sp = torch.arange(data_i.super_index.max() + 1)
colors = int_to_plotly_rgb(idx_sp)
text = np.array([f"<b>#: {i}</b>" for i in idx_sp])
ball_size = centroid_size if centroid_size else point_size * 3
fig.add_trace(
go.Scatter3d(
x=super_pos[input[i_level + 1].selected, 0],
y=super_pos[input[i_level + 1].selected, 1],
z=super_pos[input[i_level + 1].selected, 2],
mode='markers+text',
marker=dict(
symbol=centroid_symbol,
size=ball_size,
color=colors[input[i_level + 1].selected.numpy()],
line_width=min(ball_size / 2, 2),
line_color='black'),
textposition="bottom center",
textfont=dict(size=16),
hovertext=text,
hoverinfo='x+y+z+text',
showlegend=False,
visible=gap is not None, ))
fig.add_trace(
go.Scatter3d(
x=super_pos[~input[i_level + 1].selected, 0],
y=super_pos[~input[i_level + 1].selected, 1],
z=super_pos[~input[i_level + 1].selected, 2],
mode='markers+text',
marker=dict(
symbol=centroid_symbol,
size=ball_size,
color=colors[~input[i_level + 1].selected.numpy()],
line_width=min(ball_size / 2, 2),
line_color='black',
opacity=alpha_super),
textposition="bottom center",
textfont=dict(size=16),
hovertext=text,
hoverinfo='x+y+z+text',
showlegend=False,
visible=gap is not None, ))
keys = [f"Level {i_level + 1}"] if gap is None \
else trace_modes[i_point_trace].keys()
trace_modes.append(
{k: {
'marker.color': colors[input[i_level + 1].selected.numpy()],
'hovertext': text[input[i_level + 1].selected.numpy()]}
for k in keys})
trace_modes.append(
{k: {
'marker.color': colors[~input[i_level + 1].selected.numpy()],
'hovertext': text[~input[i_level + 1].selected.numpy()]}
for k in keys})
if i_level > 0 and v_edge and gap is not None and is_nag:
    # Recover the source and target positions for vertical edges
    # between i_level -> i_level+1. Only selected level-i points are
    # connected to their parent centroid
    low_pos = data_i.draw_pos[data_i.selected]
    high_pos = super_pos[data_i.super_index[data_i.selected]]
    # Convert into a plotly-friendly format for 3D lines: each edge is
    # encoded as [source, target, None], the None row breaking the
    # polyline between consecutive edges
    edges = np.full((low_pos.shape[0] * 3, 3), None)
    edges[::3] = low_pos
    edges[1::3] = high_pos
    # Color the vertical edges based on the parent cluster index
    # Plotly is a bit hacky with colors for 3D lines. We cannot
    # directly pass individual edge colors, we must instead give
    # edge color as an int corresponding to a colorscale list
    # holding plotly-friendly colors
    colors = data_i.super_index[data_i.selected]
    colors = np.repeat(colors, 3)
    n_colors = colors.max().item() + 1
    edge_colorscale = int_to_plotly_rgb(torch.arange(n_colors))
    # NB: guard the denominator with max(..., 1) — when all points
    # share a single parent (n_colors == 1) the original
    # 'i / (n_colors - 1)' would raise a ZeroDivisionError
    edge_colorscale = [
        [i / max(n_colors - 1, 1), f"rgb({x[0]}, {x[1]}, {x[2]})"]
        for i, x in enumerate(edge_colorscale)]
    # Since plotly 3D lines do not support opacity, we draw
    # these edges as super thin to limit clutter
    edge_width = 0.5 if v_edge_width is None else v_edge_width
    # Draw the level i -> i+1 vertical edges. NB we only draw
    # edges that are selected and do not draw the unselected
    # edges. This is because plotly does not handle opacity
    # on lines (yet), which means the unselected edges will tend
    # to clutter the figure. For this reason we choose to simply
    # not show them
    fig.add_trace(
        go.Scatter3d(
            x=edges[:, 0],
            y=edges[:, 1],
            z=edges[:, 2],
            mode='lines',
            line=dict(
                width=edge_width,
                color=colors,
                colorscale=edge_colorscale),
            hoverinfo='skip',
            showlegend=False,
            visible=gap is not None, ))
    # NB: at this point, trace_modes contains 'Level i+1' as its
    # last key, but we do not want vertical edges to be seen
    # when 'Level i+1' is selected, because it means 'Level i'
    # nodes are hidden
    keys = list(trace_modes[i_point_trace].keys())[:-1]
    trace_modes.append({k: {} for k in keys})
# Do not draw superedges if not required or if the i+1 level
# does not have any
if not h_edge or is_last_level or not input[i_level + 1].has_edges:
    continue
# Recover the superedge source and target positions
se = input[i_level + 1].edge_index
se_attr = input[i_level + 1].edge_attr
# Since we can only draw one edge direction (they would overlap
# otherwise), we can trim the graph to only keep one direction
# for each undirected edge pair. However, this requires picking
# one direction for the edge attributes so we ARBITRARILY TAKE
# THE MAX EDGE FEATURE for each undirected edge
input[i_level + 1].raise_if_edge_keys()
if se_attr is not None:
    se, se_attr = to_trimmed(se, edge_attr=se_attr, reduce='max')
else:
    se = to_trimmed(se)
# Recover corresponding source and target coordinates using the
# previously-computed 'super_pos' cluster centroid positions
s_pos = super_pos[se[0]].numpy()
t_pos = super_pos[se[1]].numpy()
# Convert into a plotly-friendly format for 3D lines: each edge is
# encoded as [source, target, None], the None row breaking the
# polyline between consecutive edges
edges = np.full((se.shape[1] * 3, 3), None)
edges[::3] = s_pos
edges[1::3] = t_pos
if h_edge_attr and se_attr is not None:
    # Recover edge features and convert them to RGB colors. NB:
    # edge features are assumed to be in [0, 1] or [-1, 1].
    # Since we only draw edges in one direction, we choose to
    # only represent the absolute value of the features. This
    # implies that features are either direction-independent or
    # that the edge direction only changes the sign of the
    # feature
    colors = feats_to_plotly_rgb(
        se_attr.abs(), normalize=True, colorscale=colorscale)
    # Repeat each color 3 times to match the [source, target, None]
    # vertex layout of 'edges'
    colors = np.repeat(colors, 3, axis=0)
    edge_width = point_size if h_edge_width is None else h_edge_width
else:
    # No edge features to display: fall back to a single uniform
    # color (zeros mapped through the colorscale)
    colors = feats_to_plotly_rgb(
        torch.zeros(edges.shape[0]), normalize=True, colorscale=colorscale)
    edge_width = point_size if h_edge_width is None else h_edge_width
# An edge is selected iff both its endpoints are selected; expand the
# mask to the 3-vertex-per-edge layout
selected_edge = input[i_level + 1].selected[se].all(axis=0)
selected_edge = selected_edge.repeat_interleave(3).numpy()
# Draw the level-i+1 superedges. NB we only draw edges that are
# selected and do not draw the unselected edges. This is because
# plotly does not handle opacity on lines (yet), which means the
# unselected edges will tend to clutter the figure. For this
# reason we choose to simply not show them
fig.add_trace(
    go.Scatter3d(
        x=edges[selected_edge, 0],
        y=edges[selected_edge, 1],
        z=edges[selected_edge, 2],
        mode='lines',
        line=dict(
            width=edge_width,
            color=colors[selected_edge]),
        hoverinfo='skip',
        showlegend=False,
        visible=gap is not None, ))
# Superedges stay visible in every mode the reference point trace
# declares (except the lone 'Level i+1' mode when gap is None), but
# require no restyling of their own
keys = [f"Level {i_level + 1}"] if gap is None \
    else trace_modes[i_point_trace].keys()
trace_modes.append({k: {} for k in keys})
# Add a trace for prediction errors. NB: it is important that this
# trace is created last, as the button behavior for this one is
# particular
has_error = data_0.y is not None and data_0.semantic_pred is not None
if has_error:
    # Recover prediction and ground truth and deal with potential
    # histograms: 2D label tensors are reduced to hard labels with an
    # argmax over the class dimension
    y = data_0.y
    y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
    pred = data_0.semantic_pred
    pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy()
    # Identify erroneous point indices, excluding void/ignored classes
    # and the -1 "unlabeled" convention. NB: np.isin replaces the
    # deprecated np.in1d
    ignore = void_classes if void_classes else []
    ignore = ignore + [-1]
    indices = np.where((pred != y) & (~np.isin(y, ignore)))[0]
    # Prepare the color for erroneous points. NB: np.asarray must be
    # CALLED — subscripting it (np.asarray[...]) is a TypeError
    error_color = 'red' if error_color is None \
        else np.asarray(error_color).squeeze()
    # Draw the erroneous points, slightly larger than regular points,
    # hidden by default (toggled by the 'Semantic Errors' button)
    fig.add_trace(
        go.Scatter3d(
            x=data_0.pos[indices, 0],
            y=data_0.pos[indices, 1],
            z=data_0.pos[indices, 2],
            mode='markers',
            marker=dict(
                symbol=point_symbol,
                size=int(point_size * 1.5),
                color=error_color, ),
            showlegend=False,
            visible=False, ))
# Recover the keys for all visualization modes, as an ordered set,
# with respect to their order of first appearance
modes = list(dict.fromkeys([k for m in trace_modes for k in m.keys()]))
# Traces color for interactive point cloud coloring
def trace_update(mode):
    """Build the `args` pair for a dropdown button activating `mode`.

    Returns `[update_dict, trace_indices]` where `update_dict` hides
    every trace and clears its marker color and hover text by default,
    then re-enables and restyles only the traces that declare `mode`
    in `trace_modes`. `trace_indices` covers all registered traces.
    """
    n_traces = len(trace_modes)
    # NB: the optional error trace occupies one extra 'visible' slot
    # (has_error acts as 0/1) but has no color/hovertext entry
    out = {
        'visible': [False] * (n_traces + has_error),
        'marker.color': [None] * n_traces,
        'hovertext': [''] * n_traces}
    # Apply each trace's declared styling for this mode; traces that
    # do not declare the mode keep the hidden defaults
    for idx, modes_of_trace in enumerate(trace_modes):
        if mode in modes_of_trace:
            out['visible'][idx] = True
            for attr, value in modes_of_trace[mode].items():
                out[attr][idx] = value
    return [out, list(range(n_traces))]
# Create the buttons that will serve for toggling trace visibility
updatemenus = [
dict(
buttons=[dict(
label=mode, method='update', args=trace_update(mode))
for mode in modes if mode.lower() != 'errors'],
pad={'r': 10, 't': 10},
showactive=True,
type='dropdown',
direction='right',
xanchor='left',
x=0.02,
yanchor='top',
y=1.02, ),]
if has_error:
updatemenus.append(
dict(
buttons=[dict(
method='restyle',
label='Semantic Errors',
visible=True,
args=[
{'visible': True, 'marker.color': error_color},
[len(trace_modes)]],
args2=[
{'visible': False,},
[len(trace_modes)]],)],