In [None]:
import json
import os
from argparse import Namespace

from matplotlib import pyplot as plt
import numpy as np
import torch

In [None]:
from scripts.train import DataModule, _to_image, KeypointModule

In [None]:
# Dataset root is machine-specific; override it with the CUPS_DATA_DIR
# environment variable instead of editing the notebook.
DATA_DIR = os.environ.get('CUPS_DATA_DIR', '/home/ken/data')

with open('../config/cups.json', 'rt') as f:
    keypoint_config = json.load(f)

# Named `data_module` (not `module`) because `module` is rebound to the
# KeypointModule when the checkpoint is loaded in a later cell.
data_module = DataModule(
    Namespace(
        train=f'{DATA_DIR}/cups_train/',
        val=f'{DATA_DIR}/cups_test/',
        batch_size=1,
        workers=1,
        pool=32,
    ),
    keypoint_config=keypoint_config,
)
data_module.setup('fit')

# NOTE: despite its name, `train_iterator` walks the *validation* split
# (val_dataloader); the name is kept because later cells call
# next(train_iterator).
dataloader = data_module.val_dataloader()
train_iterator = iter(dataloader)

In [None]:
import cv2


def resize(target, width=320, height=180):
    """Resize an array to the notebook's display size (320x180 by default).

    Args:
        target: image or heatmap array accepted by ``cv2.resize``.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        The resized array. OpenCV expects the size as (width, height),
        i.e. columns first.
    """
    output_size = (width, height)
    return cv2.resize(target, output_size)

In [None]:
# Checkpoint to evaluate; update this when pointing the notebook at a new run.
CHECKPOINT_PATH = '../lightning_logs/version_0/checkpoints/epoch=15-step=33567.ckpt'

# map_location='cpu': the cells below feed CPU tensors from the dataloader
# straight into the model and call .numpy() on its outputs, so the weights
# must end up on the CPU even if the checkpoint was written from a GPU run.
module = KeypointModule.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location='cpu',
    keypoint_config=keypoint_config,
)

In [None]:
# Put the wrapped network in inference mode (affects layers such as dropout
# or batch-norm, if present), then keep a short handle to it. eval() mutates
# and returns the same module, so `model` is module.model itself.
module.model.eval()
model = module.model

In [None]:
# Pull the next sample and overlay each of the first four ground-truth
# heatmap channels on the input image (2x2 grid, channel index = row*2+col).
frame, target, depth, centers = next(train_iterator)

image = _to_image(frame[0].numpy())
background = resize(image)  # loop-invariant: resize once, reuse per panel
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
for channel, axis in enumerate(axes.flat):
    axis.imshow(background)
    # Fixed [0, 1] colour range, matching the predicted-heatmap plot, so the
    # two figures are directly comparable instead of each channel being
    # auto-scaled. NOTE(review): assumes targets lie in [0, 1] — confirm.
    axis.imshow(resize(target[0, channel].numpy()), alpha=0.7, vmin=0.0, vmax=1.0)
    axis.axis('off')
fig.tight_layout()

In [None]:
# Run the network on the current frame; squash the heatmap logits with a
# sigmoid so the plotted values lie in (0, 1).
with torch.no_grad():
    heatmap_p, depth_p, centers_p = model(frame)
    heatmap_p = torch.sigmoid(heatmap_p)

image = _to_image(frame[0].numpy())
background = resize(image)
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
# axes.flat visits panels row-major, so panel k shows heatmap channel k.
for channel, axis in enumerate(axes.flat):
    prediction = heatmap_p[0, channel].detach().numpy()
    axis.imshow(background)
    axis.imshow(resize(prediction), alpha=0.7, vmin=0.0, vmax=1.0)
    axis.axis('off')
fig.tight_layout()

In [None]:
# Only the depth head is needed here. Re-running the network and rebinding
# heatmap_p would replace the sigmoid-activated heatmaps from the previous
# cell with raw logits, silently changing the heatmap thresholding used in
# the centre-visualisation cell.
with torch.no_grad():
    depth_p = model(frame)[1]

image = _to_image(frame[0].numpy())
fig, axis = plt.subplots(figsize=(14, 8))
# Overlay depth channel 3 on the input image (the 0.7 alpha implies an
# overlay); colour range fixed to [0, 2].
axis.imshow(resize(image))
axis.imshow(resize(depth_p[0, 3].numpy()), alpha=0.7, vmin=0.0, vmax=2.0)
axis.axis('off')
fig.tight_layout()

In [None]:
# Pixel-centre coordinates on the 320x180 grid, shape (2, 180, 320), float32:
# channel 0 holds x = column + 0.5, channel 1 holds y = row + 0.5.
# Vectorised equivalent of filling the grid with a Python double loop.
rows, cols = np.indices((180, 320), dtype=np.float32)
pixel_indices = np.stack([cols + 0.5, rows + 0.5])

In [None]:
# Ground-truth centres: each pixel's 2-vector from `centers` is added to its
# pixel coordinate. Pixels whose vector norm is <= 0.1 are skipped
# (presumably unlabelled background — confirm against the data pipeline).
vectors = centers.numpy()[0]
norms = np.linalg.norm(vectors, axis=0)
where_non_zero = norms > 1e-1  # norms are non-negative, so no abs() needed
gt_centers = np.zeros_like(pixel_indices)
gt_centers[:, where_non_zero] = pixel_indices[:, where_non_zero] + vectors[:, where_non_zero]

# Predicted centres, kept only where the summed heatmap activation exceeds
# 0.1. NOTE(review): this threshold assumes heatmap_p holds sigmoid outputs
# (as produced by the predicted-heatmap cell), not raw logits.
where_heatmap_non_zero = heatmap_p[0].numpy().sum(axis=0) > 0.1
p_centers = np.zeros_like(pixel_indices)
p_centers[:, where_heatmap_non_zero] = (
    pixel_indices[:, where_heatmap_non_zero]
    + centers_p[0].detach().numpy()[:, where_heatmap_non_zero]
)


def draw_points(background, points, color=(255, 0, 0), radius=2):
    """Return a 320x180 copy of `background` with a filled dot at each point.

    Args:
        background: image to draw on; it is not modified.
        points: (2, N) array of (x, y) pixel coordinates.
        color: dot colour passed to cv2.circle.
        radius: dot radius in pixels.
    """
    canvas = cv2.resize(background.copy(), (320, 180))
    for x, y in points.T:
        # cv2.circle requires integer centres; float32 coordinates are
        # rejected by recent OpenCV versions.
        cv2.circle(canvas, (int(round(x)), int(round(y))), radius, color, -1)
    return canvas


dotted_image = draw_points(image, gt_centers[:, where_non_zero])
# Predicted centres were filled in on the heatmap mask, so read them back with
# the same mask; the ground-truth mask would draw spurious dots at (0, 0)
# wherever the prediction is below threshold and hide real predictions.
dotted_image_pred = draw_points(image, p_centers[:, where_heatmap_non_zero])

fig, (gt_axis, pred_axis) = plt.subplots(1, 2, figsize=(10, 5))
gt_axis.imshow(dotted_image)
gt_axis.set_title('Ground-truth centers')
gt_axis.axis('off')
pred_axis.imshow(dotted_image_pred)
pred_axis.set_title('Predicted centers')
pred_axis.axis('off')
fig.tight_layout()