Skip to content

Commit

Permalink
Merge pull request #668 from mitchellwest/viz-mp-fix
Browse files Browse the repository at this point in the history
Segmentation_MP visualizer fix
  • Loading branch information
CCInc committed Oct 4, 2021
2 parents b78ab86 + 8bfe497 commit c0e750f
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch_points3d/models/segmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(self, option, model_type, dataset, modules):

self.loss_names = ["loss_seg"]

self.visual_names = ["data_visual"]

def set_input(self, data, device):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Expand All @@ -55,6 +57,9 @@ def forward(self, *args, **kwargs) -> Any:
if self.labels is not None:
self.loss_seg = F.nll_loss(self.output, self.labels, ignore_index=IGNORE_LABEL) + self.get_internal_loss()

self.data_visual = self.input
self.data_visual.y = self.labels
self.data_visual.pred = torch.max(self.output, -1)[1]
return self.output

def backward(self):
Expand Down

0 comments on commit c0e750f

Please sign in to comment.