In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Root directory holding every dataset variant explored in this notebook.
path = '../data/kkt/data/'
# os.listdir returns entries in filesystem-dependent order; sort so the
# displayed listing is identical on every Restart & Run All.
sorted(os.listdir(path))

## Sparse to Dense

In [None]:
import sys

# Make the project's dataset loaders importable. Guard the append so that
# re-running this cell does not keep adding the same entry to sys.path.
SCRIPTS_DIR = '../../pose-consistency-KKT-loss/scripts/'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)
from dataset_s2d import Dataset

# Sparse-to-dense training split.
ds = Dataset(path=os.path.join(path, 's2d_trn/'))
len(ds)

In [None]:
# Draw one random sample. An int argument to np.random.choice samples exactly
# as if it were np.arange(n), i.e. the same draw as choice(range(n)).
data = ds[int(np.random.choice(len(ds)))]

# Pull out the fields visualised below (note: 'label_d' is the key for the label).
input, label, input_mask, weights = (
    data[key] for key in ('input', 'label_d', 'mask', 'weights')
)

In [None]:
def normalize(x, eps=1e-8):
    """Min-max rescale ``x`` to roughly [0, 1] for display.

    NaN entries are ignored when computing the range and remain NaN in the
    output. Previously a single NaN made the min/max NaN, which turned the
    whole image into NaN.

    Args:
        x: numeric array-like.
        eps: small constant added to the denominator so a constant input
            maps to zeros instead of triggering a division by zero.

    Returns:
        ``(x - min) / (max - min + eps)``, with min/max taken over non-NaN values.
    """
    shifted = x - np.nanmin(x)
    return shifted / (np.nanmax(shifted) + eps)

In [None]:
# Sparse-to-dense sample. Input and label are min-max rescaled for display;
# the mask and weights are shown as-is.
fig, axes = plt.subplots(1, 4, figsize=(20, 10))
s2d_panels = (
    ('Input', normalize(input[0])),
    ('Label', normalize(label[0])),
    ('Input mask', input_mask[0]),
    ('Weights', weights[0]),
)
for ax, (title, image) in zip(axes, s2d_panels):
    ax.set_title(title)
    ax.imshow(image)

## Dense to RPZ

In [None]:
from dataset_d2rpz import Dataset

# Dense DEM -> roll / pitch / z labels.
d2rpz_root = os.path.join(path, 'd2rpz_labels/')
ds = Dataset(path=d2rpz_root)
len(ds)

In [None]:
# Random sample; choice(n) draws exactly like choice(range(n)).
data = ds[int(np.random.choice(len(ds)))]
input, label = (data[key] for key in ('input', 'label'))

print(input.shape, label.shape)

In [None]:
# Dense-to-RPZ sample. The DEM input is rescaled; the three label channels
# are shown raw.
fig, axes = plt.subplots(1, 4, figsize=(20, 10))
rpz_panels = (
    ('Input', normalize(input[0])),
    ('Label 1: Roll', label[0]),
    ('Label 2: Pitch', label[1]),
    ('Label 3: Z', label[2]),
)
for ax, (title, image) in zip(axes, rpz_panels):
    ax.set_title(title)
    ax.imshow(image)

## Soft to Dense

In [None]:
from dataset_sf2d import Dataset

# Soft-to-dense training split.
sf2d_root = os.path.join(path, 'sf2d_trn/')
ds = Dataset(path=sf2d_root)
len(ds)

In [None]:
# Random sample; choice(n) draws exactly like choice(range(n)).
data = ds[int(np.random.choice(len(ds)))]

# Same fields as the sparse-to-dense plots, plus the image list.
input, label, input_mask, weights, imgs = (
    data[key] for key in ('input', 'label', 'mask', 'weights', 'images')
)

# Additional soft-to-dense fields, kept for interactive inspection.
T_baselink_zpr, features = data['T_baselink_zpr'], data['features']

In [None]:
# Show each image in turn. The (2, 1, 0) channel reorder suggests the images are
# stored BGR (TODO confirm); matplotlib expects RGB.
for bgr_image in imgs:
    rgb_image = bgr_image[..., (2, 1, 0)]
    plt.imshow(rgb_image)
    plt.show()

In [None]:
# Soft-to-dense sample. Input and label are min-max rescaled for display;
# the mask and weights are shown as-is.
fig, axes = plt.subplots(1, 4, figsize=(20, 10))
sf2d_panels = (
    ('Input', normalize(input[0])),
    ('Label', normalize(label[0])),
    ('Input mask', input_mask[0]),
    ('Weights', weights[0]),
)
for ax, (title, image) in zip(axes, sf2d_panels):
    ax.set_title(title)
    ax.imshow(image)

## KKT

In [None]:
# Pick a random KKT sample file. Sort the listing so the candidate order does
# not depend on the filesystem, which keeps seeded runs reproducible.
kkt_dir = os.path.join(path, 'tomas_pose_all')
i = np.random.choice(sorted(os.listdir(kkt_dir)))
i_path = os.path.join(kkt_dir, i)

# np.load on an .npz returns a lazily-reading NpzFile that holds the file open
# until closed. Copy every array into a plain dict inside `with` so the handle
# is released. Key order matches NpzFile.files.
with np.load(i_path) as npz:
    data = dict(npz)
print(list(data))

In [None]:
# Arrays stored in each KKT sample file, read in this fixed order.
kkt_keys = ('input', 'label_rpz', 'dem_interp', 'dem_s2d2rpz', 'dem_s2d2kkt')
input, label_rpz, dem_interp, dem_s2d2rpz, dem_s2d2kkt = (data[key] for key in kkt_keys)

In [None]:
tuple(arr.shape for arr in (input, label_rpz, dem_interp, dem_s2d2rpz, dem_s2d2kkt))

In [None]:
# Replace NaNs in a copy of the RPZ label with -1 so they render as a solid
# value rather than the colormap's "bad" (transparent) colour.
label_rpz_vis = label_rpz.copy()
label_rpz_vis[np.isnan(label_rpz_vis)] = -1

# (title, image) per panel, in row-major order on a 3x2 grid.
kkt_panels = (
    ('Input', input),
    ('Label 1: Roll', label_rpz_vis[0]),
    ('Label 2: Pitch', label_rpz_vis[1]),
    ('DEM Interp', normalize(dem_interp)),
    ('DEM S2D2RPZ', normalize(dem_s2d2rpz)),
    ('DEM S2D2KKT', normalize(dem_s2d2kkt)),
)
fig, axes = plt.subplots(3, 2, figsize=(10, 15))
for ax, (title, image) in zip(axes.ravel(), kkt_panels):
    ax.set_title(title)
    ax.imshow(image)

## Real RPZ

In [None]:
from dataset_real_rpz import Dataset

# The real-RPZ loader reads from the same s2d_trn split as the sparse-to-dense section.
real_rpz_root = os.path.join(path, 's2d_trn/')
ds = Dataset(path=real_rpz_root)
len(ds)

In [None]:
# Random sample; choice(n) draws exactly like choice(range(n)).
data = ds[int(np.random.choice(len(ds)))]

input, label_rpz, mask = data['input'], data['label_rpz'], data['mask']
# Beware the naming: `label_dem` holds the 'label_dem_d' key (plotted as
# "DEM Dense"), while `label_dem_p` holds 'label_dem' (plotted as "DEM Sparse").
label_dem, label_dem_p = data['label_dem_d'], data['label_dem']
weights, yaw = data['weights'], data['yaw']

In [None]:
tuple(arr.shape for arr in (input, label_rpz, mask, label_dem, label_dem_p, weights, yaw))

In [None]:
# Fill NaN entries with sentinel values so imshow draws them as a solid value
# rather than the colormap's "bad" (transparent) colour. NaNs presumably mark
# cells without a valid label; TODO confirm against dataset_real_rpz.
label_rpz_vis = label_rpz.copy()
label_rpz_vis[np.isnan(label_rpz_vis)] = -1

yaw_vis = yaw.copy()
# NOTE(review): -pi/2 is the original author's sentinel for missing yaw; meaning unconfirmed.
yaw_vis[np.isnan(yaw_vis)] = -np.pi/2

# (title, image) per panel, in row-major order on a 4x2 grid.
# Channel 2 is the third label (Z). Its title previously read 'Label 2: Z',
# which duplicated Pitch's number and disagreed with the Dense-to-RPZ plot.
real_rpz_panels = (
    ('Input', input[0]),
    ('Label 1: Roll', label_rpz_vis[0]),
    ('Label 2: Pitch', label_rpz_vis[1]),
    ('Label 3: Z', label_rpz_vis[2]),
    ('DEM Dense', normalize(label_dem[0])),
    ('DEM Sparse', normalize(label_dem_p[0])),
    ('Weights', weights[0]),
    ('Yaw', yaw_vis),
)

fig, axes = plt.subplots(4, 2, figsize=(10, 15))
for ax, (title, image) in zip(axes.flat, real_rpz_panels):
    ax.set_title(title)
    ax.imshow(image)