In [None]:
import sys
sys.path.append('../../pose-consistency-KKT-loss/scripts/')
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import network_s2d
import network_d2rpz
from dataset_s2d import Dataset

In [None]:
# Root directory of the KKT data. List its contents, sorted so the cell output does not
# depend on the filesystem's (arbitrary) enumeration order.
path = '../data/kkt/data/'
sorted(os.listdir(path))

In [None]:
# Run on the first GPU when available (assign torch.device("cpu") here to force CPU).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained s2d network from the KKT-loss project. The weights are remapped onto `device` at load time.
model_s2d = network_s2d.Net()
model_s2d.load_state_dict(torch.load("../config/weights/kkt/network_weights_s2d", map_location=device))
model_s2d.to(device)
# This notebook only runs inference. Switch the model to eval mode so layers such as dropout or
# batch-norm (if the network contains any) use inference behaviour instead of training behaviour.
model_s2d.eval()

# The `s2d_tst/` split under the KKT data root, one sample per batch.
dataset_val = Dataset(os.path.join(path, "s2d_tst/"))
valloader = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=True, num_workers=0)

In [None]:
from tqdm import tqdm

# Visual sanity check of the pretrained s2d model on the validation loader.
# Every `vis_every`-th batch is plotted as input | prediction | label. Iteration stops after
# `max_batches` batches (set it to None to sweep the whole loader). With the defaults below only
# the first batch is processed and shown, and its `pred` / `label` arrays stay available to later cells.
vis_every = 100
max_batches = 1

with torch.no_grad():
    # Wrap the loader itself (not the enumerate object) so tqdm knows the total length.
    for i, data in enumerate(tqdm(valloader)):
        inp = data['input'].to(device)
        input_mask = data['mask'].to(device)

        # The network receives the input and its mask stacked along the channel axis.
        input_w_mask = torch.cat([inp, input_mask], 1)
        output_DEM = model_s2d(input_w_mask)

        # Channel 0 of the network output is the dense prediction.
        dense = output_DEM[:, 0:1]

        if i % vis_every == 0:
            inpt = inp.squeeze().cpu().numpy()
            pred = dense.squeeze().cpu().numpy()
            label = data['label'].squeeze().cpu().numpy()

            # One row of three panels; figsize is (width, height).
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            for ax, title, img in zip(axes, ('Input', 'Prediction', 'Label'), (inpt, pred, label)):
                ax.set_title(title)
                ax.imshow(img)
            plt.show()

        # Stop right after the last requested batch, without fetching another one.
        if max_batches is not None and i + 1 >= max_batches:
            break

In [None]:
def set_axes_equal(ax):
    '''Give a 3D matplotlib axis the same scale on x, y and z.

    Matplotlib's ax.set_aspect('equal') and ax.axis('equal') do not work for 3D
    axes. As a workaround, every axis is re-centred on the midpoint of its current
    limits and given a common half-width: half of the largest current axis range.
    The bounding box is then a cube (in the infinity-norm sense), so spheres look
    like spheres and cubes like cubes.

    Input
      ax: a matplotlib 3D axis, e.g., as output from plt.gca().
    '''
    getters = (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)

    limits = [get_lim() for get_lim in getters]
    centres = [np.mean(lim) for lim in limits]
    # abs() makes the span independent of whether an axis is inverted.
    half_width = 0.5 * max(abs(hi - lo) for lo, hi in limits)

    for set_lim, centre in zip(setters, centres):
        set_lim([centre - half_width, centre + half_width])


In [None]:
from mpl_toolkits.mplot3d import Axes3D

# 3D comparison of the last plotted KKT prediction and its label (`pred` / `label` are set by
# the evaluation loop above).
grid_res = 0.1  # cell size used to scale row/column indices into plot coordinates
h, w = pred.shape
assert label.shape == (h, w), f'label shape {label.shape} does not match prediction shape {(h, w)}'
# Row/column indices centred on the middle of the map, scaled by the cell size.
x_grid, y_grid = np.mgrid[-h//2:h//2, -w//2:w//2] * grid_res

# Visualization of the data: one 3D surface panel per map, each with equal axis scaling.
fig = plt.figure(figsize=(24, 12))
for subplot_pos, (title, surface) in zip((121, 122), (('Prediction', pred), ('Label', label))):
    ax = fig.add_subplot(subplot_pos, projection='3d')
    ax.set_title(title)
    ax.plot_surface(x_grid, y_grid, surface, alpha=0.7)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    set_axes_equal(ax)
plt.show()

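## RobinGas data

Apply the pretrained s2d model to a heightmap from a RobinGas sequence and compare the input with the prediction.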

In [None]:
from differentiable_physics.datasets import RobinGasDataset
from differentiable_physics.segmentation import position
from differentiable_physics.config import Config

# Heightmap settings for the RobinGas dataset.
cfg = Config()
cfg.grid_res = 0.1
cfg.d_max = 12.8
cfg.d_min = 1.
# NOTE(review): presumably disables heightmap interpolation, so unobserved cells stay
# unfilled and are reflected in heightmap['mask'] — confirm against RobinGasDataset.
cfg.hm_interp_method = None

# Sequence directory. Set the ROBINGAS_PATH environment variable to point at your own copy
# instead of editing this machine-specific default.
path = os.environ.get(
    'ROBINGAS_PATH',
    '/home/ruslan/data/robingas/data/22-08-12-cimicky_haj/marv/ugv_2022-08-12-15-18-34_trav/',
)
ds = RobinGasDataset(path, cfg=cfg)

# One sample: point cloud, trajectory and a heightmap dict ('z' holds the heights used below).
i = 0
cloud, traj, heightmap = ds[i]
height = heightmap['z']

In [None]:
# Run the pretrained s2d network on the RobinGas heightmap.
h, w = height.shape

with torch.no_grad():
    # Heights and mask become (1, 1, h, w) float32 tensors on the inference device.
    input = torch.as_tensor(height, dtype=torch.float32).view((1, 1, h, w)).to(device)
    input_mask = torch.as_tensor(heightmap['mask'], dtype=torch.float32).view((1, 1, h, w)).to(device)

    # Channels are stacked as [heights, mask], the layout the network receives.
    input_w_mask = torch.cat((input, input_mask), dim=1)
    output_DEM = model_s2d(input_w_mask)

In [None]:
# Bring the network output (channel 0) and the input mask back to numpy for plotting.
pred = output_DEM[0, 0].squeeze().cpu().numpy()
mask = input_mask[0].squeeze().cpu().numpy()

# Side by side: the RobinGas heights, the mask fed to the network, and the network output.
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
panels = (('Input height', height, None), ('Mask', mask, 'gray'), ('Prediction', pred, None))
for ax, (title, img, cmap) in zip(axes, panels):
    im = ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
# Render explicitly so the cell does not also print a Colorbar repr.
plt.show()

In [None]:
# 3D comparison of the RobinGas input heights and the network prediction.
h, w = pred.shape
assert height.shape == (h, w), f'input shape {height.shape} does not match prediction shape {(h, w)}'
# Row/column indices centred on the middle of the map, scaled by the configured cell size.
x_grid, y_grid = np.mgrid[-h//2:h//2, -w//2:w//2] * cfg.grid_res

# Visualization of the data: one 3D surface panel per map, each with equal axis scaling.
fig = plt.figure(figsize=(24, 12))
for subplot_pos, (title, surface) in zip((121, 122), (('Input', height), ('Prediction', pred))):
    ax = fig.add_subplot(subplot_pos, projection='3d')
    ax.set_title(title)
    ax.plot_surface(x_grid, y_grid, surface, alpha=0.7)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    set_axes_equal(ax)
plt.show()