In [None]:
import sys
sys.path.append('../src/')
import torch
from torch.utils.data import DataLoader
from monoforce.models.terrain_encoder.voxelnet import VoxelNet
from monoforce.utils import read_yaml, position
from monoforce.datasets.rough import ROUGH, rough_seq_paths
from monoforce.models.traj_predictor.dphys_config import DPhysConfig

In [None]:
class Data(ROUGH):
    """ROUGH dataset variant whose samples are (point cloud, terrain height map) pairs.

    Construction is delegated to `ROUGH`; only `get_sample` is overridden.
    NOTE(review): presumably `ROUGH.__getitem__` calls `get_sample` — confirm.
    """

    def __init__(self, path, lss_cfg, dphys_cfg=None, is_train=True):
        """Forward all arguments to `ROUGH.__init__`.

        :param path: sequence path handed to `ROUGH`.
        :param lss_cfg: LSS configuration handed to `ROUGH`.
        :param dphys_cfg: physics configuration; when omitted (None) a fresh
            `DPhysConfig()` is created for this instance.
        :param is_train: training flag handed to `ROUGH`.
        """
        # Build the default here rather than in the signature: a default written
        # as `DPhysConfig()` is evaluated once, when the class is defined, and
        # that single object would then be shared by every dataset instance.
        if dphys_cfg is None:
            dphys_cfg = DPhysConfig()
        super().__init__(path, lss_cfg, dphys_cfg=dphys_cfg, is_train=is_train)

    def get_sample(self, i):
        """Return the point cloud and the terrain height map of sample `i`.

        :param i: sample index.
        :return: tuple `(points, terrain)` where `points` holds the cloud
            positions transposed to (3, N) and `terrain` is (2, H, W),
            stacked (height, labeled mask).
        """
        points = torch.as_tensor(position(self.get_cloud(i))).T  # (3, N)
        terrain = self.get_terrain_height_map(i)  # (2, H, W), stacked (height, labeled mask)
        return points, terrain

## Dataset: ROUGH

Data description: [../docs/DATA.md](https://github.com/ctu-vras/monoforce/blob/master/monoforce/docs/DATA.md).
The sequence used in this example can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1vcZSr4BIv7rBXTcu7YkcbVsKCi5wU6Ci?usp=drive_link).

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Read the LSS configuration and build the dataset from the sequence at index 1.
lss_config = read_yaml('../config/lss_cfg.yaml')
seq_path = rough_seq_paths[1]
ds = Data(path=seq_path, lss_cfg=lss_config)

print(f'Dataset length: {len(ds)}')

In [None]:
points, terrain = ds[0]

print(f'Points shape: {points.shape}')  # (3, N), N is the number of points
print(f'Terrain shape: {terrain.shape}')  # (2, H, W), stacked (height, labeled mask)

# visualize the point cloud and terrain height map
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
points_downsampled = points[:, ::10]
ax[0].scatter(points_downsampled[0], points_downsampled[1], s=1)
ax[0].set_title('Point cloud')
ax[0].set_xlabel('x [m]')
ax[0].set_ylabel('y [m]')
ax[0].set_xlim(-6.4, 6.4)
ax[0].set_ylim(-6.4, 6.4)

ax[1].imshow(terrain[0].T, cmap='jet', origin='lower')
ax[1].set_title('Terrain height map')

ax[2].imshow(terrain[1].T, cmap='gray', origin='lower')
ax[2].set_title('Terrain labeled mask')

plt.show()

## Model: VoxelNet

The model takes a voxel-grid representation of the point cloud as input.
This example uses only the point-cloud part of the model.

Reference: https://hanlab.mit.edu/projects/bevfusion

In [None]:
# Instantiate the network from the grid configuration and move it to `device`.
model = VoxelNet(grid_conf=lss_config['grid_conf'], n_features=16, outC=1)
model.to(device)

# Total number of elements over all parameter tensors of the model.
n_params = sum(param.numel() for param in model.parameters())
print(f'Number of model parameters: {n_params}')

In [None]:
# Shuffled loader that yields two samples per batch.
loader = DataLoader(ds, batch_size=2, shuffle=True)

# Draw a single batch to inspect the batched point-cloud shape.
first_batch = next(iter(loader))
points, hm_terrain = first_batch
print(f'Points shape: {points.shape}')  # (B, 3, N), N is the number of points

In [None]:
# Forward one batch without autograd tracking. inference_mode already disables
# gradient recording, so no additional no_grad context is needed.
with torch.inference_mode():
    out = model(points.to(device))

terrain = out['terrain']
print(f'Output shape: {terrain.shape}')  # (B, outC, H, W)

## Training

In [None]:
from tqdm.auto import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# Set training mode explicitly so this cell stays correct when it is re-run
# after a cell that switched the model to evaluation mode.
model.train()

for epoch in range(10):
    epoch_loss = 0.0
    for batch in tqdm(loader, total=len(loader)):
        batch = [b.to(device) for b in batch]
        points, hm_terrain = batch

        optimizer.zero_grad()

        out = model(points)
        terrain_pred = out['terrain']
        # Channel 0 is the height target, channel 1 the labeled mask used as weights.
        terrain, weights = hm_terrain[:, 0:1], hm_terrain[:, 1:2]
        # Both sides are multiplied by the weights, so cells with weight 0
        # contribute no error; the mean is still taken over all cells.
        loss = criterion(terrain_pred * weights, terrain * weights)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Average over the batches; the guard avoids a ZeroDivisionError on an empty loader.
    epoch_loss /= max(len(loader), 1)
    print(f'Epoch: {epoch}, Loss: {epoch_loss}')

## Inference

Visualize the model's prediction next to the ground-truth terrain height map and the labeled mask used as weights.

In [None]:
# Switch to evaluation mode before predicting, so that layers which behave
# differently during training (e.g. dropout or batch normalization, if the
# model contains any) use their inference behaviour.
model.eval()

# inference_mode already disables gradient recording; no nested no_grad needed.
with torch.inference_mode():
    batch = next(iter(loader))
    batch = [b.to(device) for b in batch]
    points, hm_terrain = batch
    # Channel 0 is the height target, channel 1 the labeled mask (weights).
    terrain, weights = hm_terrain[:, 0:1], hm_terrain[:, 1:2]
    out = model(points)
    terrain_pred = out['terrain']
    print(f'Predicted terrain shape: {terrain_pred.shape}')  # (B, outC, H, W)

# visualize the output: first sample of the batch, first channel
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(terrain_pred[0, 0].cpu().numpy().T, cmap='jet', origin='lower', vmin=-1, vmax=1)
ax[0].set_title('Prediction')
ax[1].imshow(terrain[0, 0].cpu().numpy().T, cmap='jet', origin='lower', vmin=-1, vmax=1)
ax[1].set_title('Ground truth')
ax[2].imshow(weights[0, 0].cpu().numpy().T, cmap='gray', origin='lower')
ax[2].set_title('Weights')
plt.show()