In [None]:
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured
import torch
from torch.utils.data import DataLoader, Dataset
from matplotlib import cm
import os
from differentiable_physics.vis import show_cloud
from differentiable_physics.utils import normalize, create_model
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from tqdm import tqdm

In [None]:
# data augmentation
def horizontal_shift(img, shift):
    """Circularly shift an image along its last (horizontal) axis.

    Columns pushed out on one side re-enter on the other side. For a range
    image whose columns span the full 360 degrees of yaw this corresponds to
    a rotation of the scan around the vertical axis.

    :param img: array-like whose last axis is the horizontal one.
    :param shift: number of columns to shift by. Positive values move the
                  content to the right, negative values to the left; values
                  larger than the image width wrap around.
    :return: a new array with the shifted content. The input is never
             modified and never returned by reference, also for ``shift == 0``.
    """
    # np.roll wraps modulo the axis length, so negative and oversized shifts
    # are handled, and it always returns a copy.
    return np.roll(img, int(shift), axis=-1)
    
class TraversabilityData(object):
    """Point clouds with per-point traversability, served as spherical range images.

    Every sample is read from ``<path>/points/<id>.npz`` (array stored under
    the key ``'cloud'``), projected to a range image together with its
    traversability values and randomly rotated around the vertical axis.
    """

    def __init__(self, path):
        """
        :param path: root folder of the data set; it must contain a ``points``
                     sub-folder holding the ``.npz`` files.
        """
        self.path = path
        # Sample ids are the .npz file names without extension. Sorting makes
        # the index -> sample mapping deterministic (os.listdir order is
        # arbitrary), and the extension filter keeps stray files in the folder
        # from turning into ids that cannot be loaded.
        self.ids = sorted(
            name
            for name, ext in map(os.path.splitext, os.listdir(os.path.join(path, 'points')))
            if ext == '.npz'
        )
        # vertical field of view of the projection, in degrees
        self.proj_fov_up = 45
        self.proj_fov_down = -45
        # size of the range image in pixels
        self.proj_H = 128
        self.proj_W = 1024
        # label value written to pixels that received no point
        self.ignore_label = 255

    def range_projection(self, points, labels):
        """Project a point cloud and its labels onto a spherical range image.

        :param points: array with one point per row; columns 0, 1, 2 are read
                       as x, y, z.
        :param labels: one label value per point, indexed like ``points``.
        :return: tuple ``(proj_range, proj_sem_label)`` of float32 arrays with
                 shape ``(proj_H, proj_W)``. ``proj_range`` holds the distance
                 of the point seen in each pixel and -1 where there is none;
                 ``proj_sem_label`` holds the label of that point and
                 ``ignore_label`` where there is none.
        """
        # field of view in radians
        fov_up = self.proj_fov_up / 180.0 * np.pi
        fov_down = self.proj_fov_down / 180.0 * np.pi
        fov = abs(fov_down) + abs(fov_up)

        # distance of every point from the origin
        depth = np.linalg.norm(points, 2, axis=1)

        scan_x = points[:, 0]
        scan_y = points[:, 1]
        scan_z = points[:, 2]

        # viewing angles of every point; the epsilon avoids a division by zero
        # for points located at the origin
        yaw = -np.arctan2(scan_y, scan_x)
        pitch = np.arcsin(scan_z / (depth + 1e-8))

        # normalized image coordinates
        proj_x = 0.5 * (yaw / np.pi + 1.0)  # in [0.0, 1.0]
        proj_y = 1.0 - (pitch + abs(fov_down)) / fov  # in [0.0, 1.0] inside the field of view

        # scale to image size
        proj_x *= self.proj_W  # in [0.0, W]
        proj_y *= self.proj_H  # in [0.0, H]

        # round down and clamp for use as pixel indices; points outside the
        # vertical field of view end up in the first or last row
        proj_x = np.clip(np.floor(proj_x), 0, self.proj_W - 1).astype(np.int32)  # in [0, W-1]
        proj_y = np.clip(np.floor(proj_y), 0, self.proj_H - 1).astype(np.int32)  # in [0, H-1]

        # Write points in order of decreasing depth: when several points fall
        # into the same pixel the closest one is written last and is kept.
        order = np.argsort(depth)[::-1]
        depth = depth[order]
        proj_y = proj_y[order]
        proj_x = proj_x[order]

        # assign to image
        proj_range = np.full((self.proj_H, self.proj_W), -1, dtype=np.float32)
        proj_range[proj_y, proj_x] = depth

        # for each pixel the index of its point in the cloud (-1 is no data)
        proj_idx = np.full((self.proj_H, self.proj_W), -1, dtype=np.int32)
        proj_idx[proj_y, proj_x] = order
        # only pixels that received a point get a label
        mask = proj_idx >= 0

        proj_sem_label = np.full((self.proj_H, self.proj_W), self.ignore_label, dtype=np.float32)
        proj_sem_label[mask] = labels[proj_idx[mask]]

        return proj_range, proj_sem_label

    def __getitem__(self, i, visualize=False):
        """Load sample ``i`` and return it as a randomly rotated range image.

        :param i: index into ``self.ids``.
        :param visualize: if True, the points and their traversability are
                          passed to ``show_cloud`` before projecting.
        :return: tuple ``(depth_range, label_range, points)``. The two images
                 carry a leading singleton axis; ``points`` holds the x, y, z
                 fields of the cloud, one row per point, and is returned
                 without the rotation applied to the images.
        """
        ind = self.ids[i]
        # the context manager closes the .npz archive as soon as the array is read
        with np.load(os.path.join(self.path, 'points', '%s.npz' % ind)) as data:
            cloud = data['cloud']

        # 2-D structured clouds are flattened to a plain list of points
        if cloud.ndim == 2:
            cloud = cloud.reshape((-1,))

        points = structured_to_unstructured(cloud[['x', 'y', 'z']])
        trav = np.asarray(cloud['traversability'], dtype=points.dtype)
        if visualize:
            show_cloud(points, trav, min=0, max=1, colormap=cm.jet)

        depth_range, label_range = self.range_projection(points, trav)

        # data augmentation: rotation around the vertical axis (Z), realised
        # as the same random column shift of both images
        W = depth_range.shape[-1]
        shift = np.random.randint(W)
        depth_range = horizontal_shift(depth_range, shift=shift)
        label_range = horizontal_shift(label_range, shift=shift)

        return depth_range[None], label_range[None], points

    def __len__(self):
        """Number of samples found in the ``points`` folder."""
        return len(self.ids)

In [None]:
class Trainer(object):
    """Trains a single-channel segmentation model on the range images of a data set."""

    def __init__(self, dataset, batch_size=4, lr=1e-3, epochs=1):
        """
        :param dataset: data set whose samples unpack to (depth, label, points).
        :param batch_size: number of samples per batch.
        :param lr: initial learning rate of the Adam optimizer.
        :param epochs: number of passes over the data set done by ``train``.
        """
        self.ds = dataset
        self.dataloader = DataLoader(self.ds, batch_size=batch_size, shuffle=True)
        self.epochs = epochs
        # use the first GPU when there is one, otherwise fall back to the CPU
        # instead of failing when the model is moved to the device
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = create_model('deeplabv3_resnet50', n_inputs=1, n_outputs=1)
        self.model = self.model.train()
        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(lr=lr, params=self.model.parameters())
        # NOTE(review): the data set fills pixels without points with
        # ds.ignore_label, but the loss is built without ignore_index (a variant
        # with ignore_index=self.ds.ignore_label was commented out here) -
        # confirm that this is intended.
        self.loss_fn = smp.losses.LovaszLoss(mode='multilabel', from_logits=False)

    @staticmethod
    def _vis_period(n_samples):
        """Number of batches between two visualizations; never less than 1."""
        return max(1, n_samples // 10)

    def train(self, vis=False):
        """Run the training loop for ``self.epochs`` epochs.

        After every epoch the learning rate of the first parameter group is
        divided by 10 and the losses of all batches seen so far are plotted.

        :param vis: if True, predictions are visualized periodically during
                    each epoch.
        """
        losses = []
        # a data set with fewer than 10 samples would otherwise give a modulus
        # of zero below
        vis_period = self._vis_period(len(self.ds))
        # other code may have switched the model to evaluation mode
        self.model.train()

        for e in range(self.epochs):
            print('Training epoch %i' % e)

            # tqdm wraps the loader itself so that it can read its length
            for i, sample in enumerate(tqdm(self.dataloader)):

                depth, label, points = sample
                depth = depth.to(self.device)
                label = label.to(self.device)

                pred = self.model(depth)['out']

                self.optimizer.zero_grad()
                loss = self.loss_fn(pred, label)
                loss.backward()
                self.optimizer.step()

                losses.append(loss.item())

                if vis and i % vis_period == 0:
                    visualize(pred, label, depth)

            self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] / 10.
            print('Decrease decoder learning rate to: %f' % self.optimizer.param_groups[0]['lr'])

            # plot losses
            plt.figure()
            plt.grid()
            plt.plot(losses)
            plt.show()

In [None]:
def visualize(pred, label, depth_range, points=None, ignore_label=255):
    """Plot prediction, label and range image of the first sample of a batch.

    :param pred: tensor with the model output for a batch.
    :param label: tensor with the target of the batch; entries equal to
                  ``ignore_label`` are displayed as 0.
    :param depth_range: tensor with the input range images of the batch.
    :param points: optional batch of point clouds; when given, the first cloud
                   is shown twice with ``show_cloud``, colored once by the
                   prediction and once by the label.
    :param ignore_label: label value marking pixels without data.
    """
    plt.figure(figsize=(20, 10))
    plt.subplot(3, 1, 1)
    plt.title('Prediction')
    pred_vis = torch.clone(pred)
    pred_vis = normalize(pred_vis.detach().cpu())
    pred_vis = pred_vis[0].squeeze()
    plt.imshow(pred_vis)

    plt.subplot(3, 1, 2)
    plt.title('Label')
    label_cpu = label.detach().cpu()
    # normalize gets a copy so that the tensor used for the mask stays untouched
    # NOTE(review): normalize still sees the ignore value; if it rescales by
    # the value range, masking before normalizing may give a better picture -
    # confirm against its implementation.
    label_vis = normalize(torch.clone(label_cpu))
    # The mask is computed from the CPU copy: label itself may live on the GPU,
    # and a mask on another device cannot index the CPU tensor label_vis.
    label_vis[label_cpu == ignore_label] = 0
    label_vis = label_vis[0].squeeze()
    plt.imshow(label_vis)

    plt.subplot(3, 1, 3)
    plt.title('Range image')
    depth_vis = normalize(torch.clone(depth_range)[0].squeeze().detach().cpu().numpy())
    plt.imshow(depth_vis)

    plt.show()

    if points is not None:
        show_cloud(points[0], pred[0].squeeze().detach().cpu().numpy().reshape((-1,)), min=0, max=1)
        show_cloud(points[0], label[0].squeeze().detach().cpu().numpy().reshape((-1,)), min=0, max=1)

In [None]:
# root folder of the recorded sequence; the data set reads its 'points' sub-folder
path = '/home/ruslan/data/robingas/data/22-09-27-unhost/husky/husky_2022-09-27-15-01-44/'
assert os.path.exists(path)
ds = TraversabilityData(path)

# show a handful of randomly drawn samples (indices are drawn with replacement)
n_preview = 5
preview_ids = np.random.choice(range(len(ds)), n_preview)
for i in preview_ids:
    _ = ds.__getitem__(i, visualize=True)

In [None]:
# training settings, collected in one place so that they are easy to tune
train_config = dict(batch_size=8, lr=1e-3, epochs=2)
trainer = Trainer(ds, **train_config)
trainer.train(vis=True)

In [None]:
# number of samples in the data set
len(ds)

In [None]:
# test the trained model

# Inference runs on the CPU. The model is moved there once and switched to
# evaluation mode, so that layers such as batch normalization and dropout
# behave deterministically and do not keep updating their statistics.
device = torch.device('cpu')
model = trainer.model.to(device).eval()

for _ in range(5):
    # a fresh iterator over the shuffling loader yields a new random batch
    sample = next(iter(trainer.dataloader))

    depth, label, points = sample
    depth = depth.to(device)
    label = label.to(device)

    # no gradients are needed for visualization; skip building the autograd graph
    with torch.no_grad():
        pred = model(depth)['out']

    visualize(pred, label, depth)