In [None]:
import open3d as o3d
import sys
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

### Depth correction

In [None]:
class Model(nn.Module):
    """Per-pixel linear depth-correction model: ``y = w0 * x0 + w1 * x1 + b``.

    The two input channels are combined with a learnable weight pair and a
    learnable bias (in this notebook the channels are the measured depth and
    the surface angle). Weights start at ``[1.1, 0.2]`` and the bias at ``0.3``.

    Args:
        batch_size: number of independent (w, b) parameter sets. With the
            default of 1 a single set is shared by every sample in the batch;
            otherwise the input batch dimension must equal ``batch_size`` and
            sample ``k`` uses parameter set ``k``.
    """

    def __init__(self, batch_size=1):
        super(Model, self).__init__()
        # `repeat` (rather than `view`) so that batch_size > 1 actually works:
        # a 2-element tensor cannot be viewed as (batch_size, 1, 2).
        self.w = nn.Parameter(torch.tensor([1.1, 0.2]).repeat(batch_size, 1, 1))  # (batch_size, 1, 2)
        self.b = nn.Parameter(torch.tensor([0.3]).repeat(batch_size, 1))          # (batch_size, 1)

    def forward(self, x):
        """Apply the linear correction.

        Args:
            x: tensor of shape (B, H, W, 2).

        Returns:
            Tensor of shape (B, H, W).
        """
        assert x.dim() == 4
        assert x.shape[3] == 2  # x.shape == (B, H, W, 2)
        # Insert a singleton axis so the parameters line up with the batch axis
        # of x instead of being broadcast against its H / W axes.
        w = self.w.unsqueeze(1)  # (batch_size, 1, 1, 2)
        b = self.b.unsqueeze(1)  # (batch_size, 1, 1)
        y = (x * w).sum(dim=3) + b
        return y

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Dataset of per-frame ``.npy`` files stored under ``path``.

    ``path`` must contain the sub-folders ``rgb``, ``depth``, ``point_clouds``
    and ``normals``. The files of each sub-folder are sorted by name so that
    index ``i`` refers to the same position in every modality (``os.listdir``
    alone returns entries in arbitrary order).

    The base class is spelled out as ``torch.utils.data.Dataset`` because this
    class reuses the name ``Dataset``; inheriting from the bare name would make
    the class inherit from its own previous definition when the cell is re-run.

    Args:
        path: root directory of the dataset.

    Raises:
        FileNotFoundError: if ``path`` or one of the sub-folders is missing.
    """

    def __init__(self, path):
        # A missing dataset is an error; creating an empty directory would only
        # postpone the failure to the first listdir call below.
        if not os.path.exists(path):
            raise FileNotFoundError("Dataset directory does not exist: {}".format(path))
        self.root_dir = path
        self.rgbs = self._list_npy('rgb')
        self.depths = self._list_npy('depth')
        self.points = self._list_npy('point_clouds')
        self.normals = self._list_npy('normals')
        # The number of depth files defines the dataset length.
        self.length = len(self.depths)

    def _list_npy(self, subdir):
        """Return the sorted paths of all ``*.npy`` files in ``root_dir/subdir``."""
        folder = os.path.join(self.root_dir, subdir)
        return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.npy'))

    def __getitem__(self, i):
        """Load frame ``i`` as a dict with keys 'rgb' (cast to uint8), 'depth', 'points', 'normals'."""
        sample = {'rgb': np.asarray(np.load(self.rgbs[i]), dtype=np.uint8),
                  'depth': np.load(self.depths[i]),
                  'points': np.load(self.points[i]),
                  'normals': np.load(self.normals[i])}
        return sample

    def __len__(self):
        return self.length

In [None]:
# Root folder of the recorded RGB-D trajectory that is loaded below.
path = "../ros_ws/src/gradslam_ros/data/explorer_x1_rgbd_traj/living_room_traj1_frei_png/"

# One frame per batch (the DataLoader default, spelled out here).
data = Dataset(path)
loader = DataLoader(dataset=data, batch_size=1)

In [None]:
model = Model()

optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# The model is fitted to a single sample; iterate over `loader` instead to
# train on the whole dataset.
sample = next(iter(loader))

# The sample never changes, so the model input is built once, outside the
# optimisation loop.
depth = torch.as_tensor(sample['depth'], dtype=torch.float32)
normals = torch.as_tensor(sample['normals'], dtype=torch.float32)
B, H, W = depth.shape
assert normals.shape == (B, W*H, 3)

# convert normals to angles
# Each normal is scaled to unit length on its own (norm over the last axis);
# the norm of the whole tensor would shrink every vector by the same global
# factor. The clamp avoids a division by zero for zero-length normals.
n = normals / torch.clamp(torch.linalg.norm(normals, dim=-1, keepdim=True), min=1e-12)
# normal vector to image frame
v = torch.tensor([0., 0., 1.], dtype=torch.float32)
# angle between every normal and v, for all batch items at once
angles = torch.arccos(torch.clip(n @ v, -1.0, 1.0)).reshape(B, H, W)
assert angles.shape == (B, H, W)

# input to the model: depth and angle stacked as the last axis, valid for any B
x = torch.stack([depth, angles], dim=-1)
assert x.shape == (B, H, W, 2)

for i in range(1500):
    # reset the gradients every step, otherwise they accumulate across iterations
    optim.zero_grad()

    # inference
    depth_pred = model(x)
    assert depth_pred.shape == (B, H, W)
    assert depth.shape == (B, H, W)

    # just compare the pixels values
    loss = torch.mean(torch.abs(depth - depth_pred))
    loss.backward()
    optim.step()

    if i % 300 == 0:
        print("Loss:", loss.item())
        for p in model.parameters():
            print(p)
        plt.figure(figsize=(18, 6))
        plt.subplot(1, 3, 1)
        plt.title('Depth')
        plt.imshow(x[0, :, :, 0], cmap='gray')
        plt.subplot(1, 3, 2)
        plt.title('Angles')
        plt.imshow(x[0, :, :, 1])
        plt.subplot(1, 3, 3)
        plt.title('Predicted depth')
        plt.imshow(depth_pred[0].detach(), cmap='gray')
        plt.show()
