In [None]:
from monoforce.vis import show_cloud
from monoforce.utils import normalize, create_model
from monoforce.datasets import RobinGasDataset
from monoforce.segmentation import position
from monoforce.config import Config
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from matplotlib import cm
import os
import cv2
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from tqdm import tqdm

In [None]:
class Dataset(RobinGasDataset):
    """RobinGasDataset variant that also returns a resized, cropped camera image.

    Each item is the tuple ``(img, img_raw, height_opt, height_est)``.

    NOTE(review): the class name shadows the ``Dataset`` imported from
    ``torch.utils.data`` at the top of the notebook.
    """

    def __init__(self, path, cfg=None):
        """
        :param path: forwarded unchanged to ``RobinGasDataset``.
        :param cfg: configuration object; a fresh ``Config()`` is created when omitted.
        """
        # A `Config()` default in the signature is evaluated once, at definition
        # time, and would be shared (and mutated) across all instances.
        if cfg is None:
            cfg = Config()
        super(Dataset, self).__init__(path, cfg)
        # (height, width) of the image returned by __getitem__
        self.img_size = (512, 512)

    def __getitem__(self, i):
        """Return ``(img, img_raw, height_opt, height_est)`` for sample ``i``."""
        # point cloud
        cloud = self.get_cloud(i)
        points = position(cloud)

        # height map: estimated from point cloud
        height_est = self.estimate_heightmap(points, self.cfg)['z']

        # height map: optimized from robot-terrain interaction model
        height_opt = self.get_optimized_terrain(i)['height']

        # image with its channel order reversed (BGR -> RGB)
        img_raw = self.get_image(i)
        img_raw = img_raw[..., (2, 1, 0)]

        # resize so that the image height equals the target height, keeping the
        # aspect ratio (cv2.resize expects the size as (width, height))
        H_raw, W_raw = img_raw.shape[:2]
        h, w = self.img_size
        img = cv2.resize(img_raw, (int(h / H_raw * W_raw), h))

        # crop: keep the bottom `h` rows and the `w` columns around the horizontal centre
        H, W = img.shape[:2]
        img = img[H - h:H, W // 2 - w // 2: W // 2 + w // 2]

        return img, img_raw, height_opt, height_est

In [None]:
path = '/home/ruslan/data/robingas/data/22-09-27-unhost/husky/husky_2022-09-27-15-01-44_trav/'
# path = '/home/ruslan/data/robingas/data/22-08-12-cimicky_haj/marv/ugv_2022-08-12-15-18-34_trav/'

# Validate the sequence directory before anything reads from it; checking after
# the config has been loaded can never report a wrong path.
assert os.path.exists(path), f'Dataset path does not exist: {path}'

# configuration stored next to the terrain training log of this sequence
cfg = Config()
cfg.from_yaml(os.path.join(path, 'terrain', 'train_log', 'cfg.yaml'))

ds = Dataset(path, cfg=cfg)

# number of samples in the sequence (shown as the cell output)
len(ds)

In [None]:
# index of the sample to inspect (a random one can be drawn instead)
i = 11
# i = np.random.choice(range(len(ds)))

# trajectory poses and the data returned by the dataset for this sample
poses = ds.get_traj(i)['poses']
img, img_raw, height_opt, height_est = ds[i]

# x/y translation of every pose, scaled by the grid resolution and shifted by
# half of the height map size
# NOTE(review): the x offset uses the number of rows and the y offset the number
# of columns; this only matters for non-square height maps - confirm.
h_hm, w_hm = height_est.shape
xy_grid = poses[:, :2, 3] / cfg.grid_res + np.array([h_hm / 2, w_hm / 2])

# raw camera image next to its resized and cropped version
if img is not None:
    plt.figure(figsize=(20, 10))
    for slot, picture in ((121, img_raw), (122, img)):
        plt.subplot(slot)
        plt.imshow(picture)

# estimated and optimized height maps, each with the trajectory drawn on top
plt.figure(figsize=(20, 10))
for slot, height_map in ((131, height_est), (132, height_opt)):
    plt.subplot(slot)
    plt.imshow(height_map)
    plt.plot(xy_grid[:, 0], xy_grid[:, 1], 'rx', markersize=4)
plt.show()

## Monolayout Training

In [None]:
# Extend the module search path before importing `monolayout`; the appended
# directory is presumably where that package lives on this machine (the
# location is machine-specific - adjust it when running elsewhere).
import sys
sys.path.append('/home/ruslan/workspaces/traversability_ws/src/thridparty/bev-net/monolayout/')
import monolayout

In [None]:
# The encoder is built for the size of the cropped camera image.
H, W = img.shape[:2]
models = {
    "encoder": monolayout.Encoder(num_layers=18, img_ht=H, img_wt=W, pretrained=True),
}
# The decoder is sized from the channel counts reported by the encoder backbone.
models["decoder"] = monolayout.Decoder(models["encoder"].resnet_encoder.num_ch_enc)
# (a discriminator is not used: models['discriminator'] = monolayout.Discriminator())

In [None]:
# Use the GPU when one is available; fall back to the CPU instead of crashing
# on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# move every model to the selected device
for model in models.values():
    model.to(device)

In [None]:
def pred_to_vis(tv):
    """Turn a two-channel prediction into a binary image for display.

    :param tv: tensor that becomes ``(2, H, W)`` once singleton dimensions are
        squeezed out.
    :return: float array of shape ``(H, W)`` holding 255 where channel 1 is
        greater than channel 0 and 0 elsewhere.
    """
    # detach() first: numpy() raises for tensors that require grad
    tv_np = tv.detach().squeeze().cpu().numpy()
    true_top_view = np.zeros(tv_np.shape[1:3])
    true_top_view[tv_np[1] > tv_np[0]] = 255
    return true_top_view

In [None]:
# Set to True to re-estimate the per-channel statistics from the entire dataset.
RECOMPUTE_STATS = False

if RECOMPUTE_STATS:
    means, stds = [], []
    # Dedicated loop names: reusing `i` / `img` here would overwrite the sample
    # selected earlier in the notebook, which later cells still use.
    for idx in tqdm(range(len(ds))):
        # every item is (img, img_raw, height_opt, height_est); only the
        # cropped image is needed here
        sample_img = ds[idx][0]
        pixels_01 = (sample_img / 255.).reshape([-1, 3])

        means.append(pixels_01.mean(axis=0))
        stds.append(pixels_01.std(axis=0))

    # average of the per-image statistics
    mean = np.asarray(means).mean(axis=0)
    std = np.asarray(stds).mean(axis=0)

    print(f'Estimated mean: {mean} \n and std: {std}')

else:
    print('Using precalculated mean and std')

    mean = np.array([0.4750956,  0.47310572, 0.42155158] )
    std = np.array([0.2212268,  0.23130926, 0.29598755])

In [None]:
def normalize_img(img, mean, std):
    """Min-max scale an image to [0, 1] and standardize it channel by channel.

    :param img: array of shape (H, W, C); it is left unmodified.
    :param mean: array with C elements, subtracted per channel.
    :param std: array with C elements, divided by per channel.
    :return: array of shape (C, H, W).
    """
    _, _, C = img.shape
    # Shift out of place: `img -= ...` would rewrite the caller's array.
    shifted = img - img.min()
    peak = shifted.max()
    if peak > 0:
        img_01 = shifted / peak
    else:
        # constant image: dividing would give 0 / 0 = NaN, keep zeros instead
        img_01 = np.zeros(shifted.shape, dtype=np.float64)
    img_01_CHW = img_01.transpose((2, 0, 1))
    img_CHW_norm = (img_01_CHW - mean.reshape((C, 1, 1))) / std.reshape((C, 1, 1))
    return img_CHW_norm

# plt.figure()
# plt.imshow(img_3HW_norm.transpose((1, 2, 0)))
# plt.show()

In [None]:
# Run the forward pass in eval mode so that any dropout / batch-norm layers
# inside the models behave deterministically and keep their statistics.
# The previous mode of each model is restored afterwards, so cells that train
# the models are not affected.
prev_modes = {name: module.training for name, module in models.items()}
for module in models.values():
    module.eval()

try:
    with torch.no_grad():
        # HWC image -> normalized CHW array -> batch of one on the model device
        img_CHW_norm = normalize_img(img, mean, std)
        inp = torch.as_tensor(img_CHW_norm[None], device=device, dtype=torch.float32)
        features = models['encoder'](inp)
        tv = models['decoder'](features, is_training=False)

        tv_np = tv.squeeze().cpu().numpy()
        pred = pred_to_vis(tv)
        # compare the prediction size with the height map size
        print(tv.shape, height_est.shape)
finally:
    for name, module in models.items():
        module.train(prev_modes[name])

plt.figure()
plt.imshow(pred)
plt.show()