Skip to content

Commit

Permalink
Merge pull request #670 from umrlastig/tensorboard-visualizer
Browse files Browse the repository at this point in the history
Add support for Tensorbord mesh visualization + Format for PLY
  • Loading branch information
nicolas-chaulet committed Oct 11, 2021
2 parents 8ea1e06 + 8ab8780 commit 84251ec
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 32 deletions.
6 changes: 5 additions & 1 deletion conf/visualization/default.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# @package _group_
activate: False
format: "pointcloud" # image will come later
format: ["pointcloud", "tensorboard"] # image will come later
num_samples_per_epoch: 10
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys:
pos: [['x', 'float'], ['y', 'float'], ['z', 'float']]
y: [['l', 'float']]
pred: [['p', 'float']]
ply_format: 'binary_big_endian'
tensorboard_mesh:
label: 'y'
prediction: 'pred'
7 changes: 5 additions & 2 deletions conf/visualization/eval.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# @package _group_
activate: True
format: "pointcloud" # image will come later
format: ["pointcloud", "tensorboard"] # image will come later
num_samples_per_epoch: -1
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys:
pos: [['x', 'float'], ['y', 'float'], ['z', 'float']]
y: [['l', 'float']]
pred: [['p', 'float']]

ply_format: 'binary_big_endian'
tensorboard_mesh:
label: 'y'
prediction: 'pred'
2 changes: 1 addition & 1 deletion test/test_config/viz/viz_config_deterministic.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
visualization:
activate: True
format: "pointcloud" # image will come later
format: ["pointcloud"] # image will come later
num_samples_per_epoch: 2
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys:
Expand Down
2 changes: 1 addition & 1 deletion test/test_config/viz/viz_config_indices.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
visualization:
activate: True
format: "pointcloud" # image will come later
format: ["pointcloud"] # image will come later
num_samples_per_epoch: 2
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys:
Expand Down
2 changes: 1 addition & 1 deletion test/test_config/viz/viz_config_save_all.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
visualization:
activate: True
format: "pointcloud" # image will come later
format: ["pointcloud"] # image will come later
num_samples_per_epoch: -1
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys:
Expand Down
10 changes: 5 additions & 5 deletions test/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_empty(self):

mock_num_batches = {"train": 9, "test": 3, "val": 0}
config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path, None)

for epoch in range(epochs):
run(9, visualizer, epoch, "train", data)
Expand All @@ -60,7 +60,7 @@ def test_indices(self):

mock_num_batches = {"train": 9, "test": 3, "val": 0}
config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path, None)

for epoch in range(epochs):
run(9, visualizer, epoch, "train", data)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_save_all(self):
mock_num_batches = {"train": num_samples}

config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_save_all.yaml"))
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path, None)

for epoch in range(epochs):
run(num_samples // batch_size, visualizer, epoch, "train", data)
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_pyg_data(self):
mock_num_batches = {"train": num_batches}

config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_non_deterministic.yaml"))
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path, None)

for epoch in range(epochs):
run(num_batches, visualizer, epoch, "train", data)
Expand All @@ -148,7 +148,7 @@ def test_dense_data(self):

mock_num_batches = {"train": 9, "test": 3, "val": 0}
config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_deterministic.yaml"))
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)
visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path, None)

for epoch in range(epochs):
run(9, visualizer, epoch, "train", data)
Expand Down
2 changes: 1 addition & 1 deletion torch_points3d/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _initialize_trainer(self):
self._model = self._model.to(self._device)
if self.has_visualization:
self._visualizer = Visualizer(
self._cfg.visualization, self._dataset.num_batches, self._dataset.batch_size, os.getcwd()
self._cfg.visualization, self._dataset.num_batches, self._dataset.batch_size, os.getcwd(), self._tracker
)

def train(self):
Expand Down
95 changes: 75 additions & 20 deletions torch_points3d/visualization/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import torch
import numpy as np
from matplotlib.cm import get_cmap
from math import log10, ceil
from plyfile import PlyData, PlyElement
import logging

Expand All @@ -15,11 +17,11 @@ class Visualizer(object):
batch_size (int) -- Current batch size usef
save_dir (str) -- The path used by hydra to store the experiment
This class is responsible to save visual into .ply format
This class is responsible to save visual into .ply format and tensorboard mesh
The configuration looks like that:
visualization:
activate: False # Wheter to activate the visualizer
format: "pointcloud" # image will come later
format: ["pointcloud", "tensorboard"] # image will come later
num_samples_per_epoch: 2 # If negative, it will save all elements
deterministic: True # False -> Randomly sample elements from epoch to epoch
saved_keys: # Mapping from Data Object to structured numpy
Expand All @@ -28,9 +30,14 @@ class Visualizer(object):
pred: [['p', 'float']]
indices: # List of indices to be saved (support "train", "test", "val")
train: [0, 3]
ply_format: binary_big_endian # PLY format (support "binary_big_endian", "binary_little_endian", "ascii")
tensorboard_mesh: # Mapping from mesh name and propety use to color
label: 'y'
prediction: 'pred'
"""

def __init__(self, viz_conf, num_batches, batch_size, save_dir):
def __init__(self, viz_conf, num_batches, batch_size, save_dir, tracker):
# From configuration and dataset
for stage_name, stage_num_sample in num_batches.items():
setattr(self, "{}_num_batches".format(stage_name), stage_num_sample)
Expand All @@ -40,17 +47,28 @@ def __init__(self, viz_conf, num_batches, batch_size, save_dir):
self._num_samples_per_epoch = int(viz_conf.num_samples_per_epoch)
self._deterministic = viz_conf.deterministic

self._saved_keys = viz_conf.saved_keys
self._saved_keys = {}
self._tensorboard_mesh = {}

# Internal state
self._stage = None
self._current_epoch = None

# Current experiment path
self._save_dir = save_dir
self._viz_path = os.path.join(self._save_dir, "viz")
if not os.path.exists(self._viz_path):
os.makedirs(self._viz_path)
if "pointcloud" in self._format:
self._saved_keys = viz_conf.saved_keys
self._ply_format = viz_conf.ply_format if viz_conf.ply_format is not None else "binary_big_endian"

# Current experiment path
self._viz_path = os.path.join(save_dir, "viz")
if not os.path.exists(self._viz_path):
os.makedirs(self._viz_path)

if "tensorboard" in self._format:
if tracker._use_tensorboard:
self._tensorboard_mesh = viz_conf.tensorboard_mesh

# SummaryWriter for tensorboard loging
self._writer = tracker._writer

self._indices = {}
self._contains_indices = False
Expand Down Expand Up @@ -109,7 +127,7 @@ def _extract_from_PYG(self, item, pos_idx):
batch_mask = item.batch == pos_idx
out_data = {}
for k in item.keys:
if torch.is_tensor(item[k]) and k in self._saved_keys.keys():
if torch.is_tensor(item[k]) and (k in self._saved_keys.keys() or k in self._tensorboard_mesh.values()):
if item[k].shape[0] == num_samples:
out_data[k] = item[k][batch_mask]
return out_data
Expand All @@ -121,7 +139,7 @@ def _extract_from_dense(self, item, pos_idx):
num_samples = item.y.shape[0]
out_data = {}
for k in item.keys:
if torch.is_tensor(item[k]) and k in self._saved_keys.keys():
if torch.is_tensor(item[k]) and (k in self._saved_keys.keys() or k in self._tensorboard_mesh.values()):
if item[k].shape[0] == num_samples:
out_data[k] = item[k][pos_idx]
return out_data
Expand Down Expand Up @@ -149,6 +167,7 @@ def save_visuals(self, visuals):
Make sure the saved_keys within the config maps to the Data attributes.
"""
if self._stage in self._indices:
stage_num_batches = getattr(self, "{}_num_batches".format(self._stage))
batch_indices = self._indices[self._stage] // self._batch_size
pos_indices = self._indices[self._stage] % self._batch_size
for idx in np.argwhere(self._seen_batch == batch_indices).flatten():
Expand All @@ -158,14 +177,50 @@ def save_visuals(self, visuals):
out_item = self._extract_from_PYG(item, pos_idx)
else:
out_item = self._extract_from_dense(item, pos_idx)
out_item = self._dict_to_structured_npy(out_item)

dir_path = os.path.join(self._viz_path, str(self._current_epoch), self._stage)
if not os.path.exists(dir_path):
os.makedirs(dir_path)

filename = "{}_{}.ply".format(self._seen_batch, pos_idx)
path_out = os.path.join(dir_path, filename)
el = PlyElement.describe(out_item, visual_name)
PlyData([el], byte_order=">").write(path_out)
if hasattr(self, "_viz_path"):
dir_path = os.path.join(self._viz_path, str(self._current_epoch), self._stage)
if not os.path.exists(dir_path):
os.makedirs(dir_path)

filename = "{}_{}.ply".format(self._seen_batch, pos_idx)
path_out = os.path.join(dir_path, filename)

npy_array = self._dict_to_structured_npy(out_item)
el = PlyElement.describe(npy_array, visual_name)
if self._ply_format == "ascii":
PlyData([el], text=True).write(path_out)
elif self._ply_format == "binary_little_endian":
PlyData([el], byte_order="<").write(path_out)
elif self._ply_format == "binary_big_endian":
PlyData([el], byte_order=">").write(path_out)
else:
PlyData([el]).write(path_out)

if hasattr(self, "_writer"):
pos = out_item['pos'].detach().cpu().unsqueeze(0)
colors = get_cmap('tab10')
config_dict = {
"material": {
"size": 0.3
}
}

for label, k in self._tensorboard_mesh.items():
value = out_item[k].detach().cpu()

if len(value.shape) == 2 and value.shape[1] == 3:
if value.min() >= 0 and value.max() <= 1:
value = (255*value).type(torch.uint8).unsqueeze(0)
else:
value = value.type(torch.uint8).unsqueeze(0)
elif len(value.shape) == 1 and value.shape[0] == 1:
value = np.tile((255*colors(value.numpy() % 10))[:,0:3].astype(np.uint8), (pos.shape[0],1)).reshape((1,-1,3))
elif len(value.shape) == 1 or value.shape[1] == 1:
value = (255*colors(value.numpy() % 10))[:,0:3].astype(np.uint8).reshape((1,-1,3))
else:
continue

self._writer.add_mesh(self._stage + "/" + visual_name + "/" + label, pos, colors=value, config_dict=config_dict, global_step=(self._current_epoch-1)*(10**ceil(log10(stage_num_batches+1)))+self._seen_batch)

self._seen_batch += 1

0 comments on commit 84251ec

Please sign in to comment.