In [None]:
import numpy as np
import torch
from traversability_estimation.utils import show_cloud, normalize, create_model
from traversability_estimation.datasets import TraversabilityDataset
from traversability_estimation.segmentation import filter_grid, filter_range
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [None]:
# Location of the point cloud data used to build the data set.
# NOTE(review): this is an absolute, machine-specific path; adjust it for your setup.
path = '/home/ruslan/data/bags/traversability/marv/ugv_2022-08-12-15-18-34_trav/os_cloud_node/destaggered_points/'
# Fail early, and say which path is missing, instead of raising a bare AssertionError.
assert os.path.exists(path), f'Dataset path does not exist: {path}'
ds = TraversabilityDataset(path)

# visualize a sample from the data set
# for i in np.random.choice(range(len(ds)), 1):
#     _ = ds.__getitem__(i, visualize=True)

In [None]:
# Display the number of samples available in the data set.
len(ds)

In [None]:
# Draw one sample index uniformly at random (no seed is set, so the pick
# changes from run to run) and unpack the corresponding data set item.
i = np.random.choice(np.arange(len(ds)))
depth, label, points = sample = ds[i]

In [None]:
# Display the shapes of the three arrays unpacked from the sample.
depth.shape, label.shape, points.shape

In [None]:
# Filtering parameters (also reused by the following cells).
min_dist, max_dist = 1.0, 10.0  # limits handed to filter_range
grid_res = 0.5                  # resolution handed to filter_grid
h_max = 0.5                     # largest z value that is kept

# Work on a copy of the points and on flattened labels, so that the same
# masks can be applied to both.
# NOTE(review): assumes `points` is (N, 3) with one label per point - confirm.
points_filtered = points.copy()
label_filtered = label.squeeze().reshape((-1,))

# 1) range filter
_, mask = filter_range(points, min_dist, max_dist, return_mask=True)
points_filtered, label_filtered = points_filtered[mask], label_filtered[mask]

# 2) grid filter at the map resolution
_, mask = filter_grid(points_filtered, grid_res, return_mask=True)
points_filtered, label_filtered = points_filtered[mask], label_filtered[mask]

# 3) height filter: remove points which are above the robot (z > h_max)
mask = points_filtered[:, 2] <= h_max
points_filtered, label_filtered = points_filtered[mask], label_filtered[mask]

In [None]:
# Square grid whose side spans 2 * max_dist at grid_res per cell.
H = W = round(2 * max_dist / grid_res)

# Heights start at zero; traversability starts as NaN ("no data").
height_map = np.zeros((H, W))
trav_map = np.full_like(height_map, np.nan)

height_map.shape

In [None]:
# Convert metric x/y coordinates into integer cell indices of the maps.
# The grid origin is fixed at (-max_dist, -max_dist), so that cell (0, 0)
# matches the lower bound of the [-max_dist, max_dist) extent the map size
# was derived from. Anchoring at the data minimum instead would shift the
# map by a different amount for every sample.
ids = (points_filtered[:, :2] + max_dist) // grid_res
ids = np.asarray(ids, dtype=int)
# Guard against points lying exactly on (or beyond) the map border: without
# the clip they would raise an IndexError or, for negative indices, silently
# wrap around to the opposite side of the map.
ids = np.clip(ids, 0, np.array(height_map.shape) - 1)
ids.shape

In [None]:
# Rasterise the filtered points into the two maps. Heights are stored
# relative to the lowest remaining point; when several points fall into the
# same cell, the one visited last wins.
min_height = points_filtered[:, 2].min()
for i in range(len(ids)):
    idx, idy = ids[i]
    height_map[idx, idy] = points_filtered[i, 2] - min_height

    # Label value 255 is skipped, leaving the cell's traversability unchanged.
    if label_filtered[i] == 255:
        continue
    trav_map[idx, idy] = label_filtered[i]

In [None]:
# Show the height map (left) and the traversability map (right) side by side.
fig, (ax_height, ax_trav) = plt.subplots(1, 2, figsize=(20, 10))

for ax, grid_map in ((ax_height, height_map), (ax_trav, trav_map)):
    ax.imshow(grid_map)
    ax.grid()

In [None]:
# Importing mplot3d registers the '3d' projection on older Matplotlib
# releases; newer releases register it automatically, so the import is
# kept only for compatibility.
from mpl_toolkits.mplot3d import Axes3D

# Metric coordinates of the grid cells along each axis.
x = y = np.arange(-max_dist, max_dist, grid_res)
# height_map is indexed as [x-cell, y-cell], so the mesh must use matrix
# ('ij') indexing. The default 'xy' indexing would pair every height with
# swapped x/y coordinates and draw the surface mirrored about the diagonal.
X, Y = np.meshgrid(x, y, indexing='ij')

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, height_map)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlim3d([-0.1, 10*h_max])
plt.show()