In [1]:
# Load IPython's autoreload extension and use mode 2, which re-imports all
# modules before each cell runs, so edits to the project's .py files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Add the parent directory to the import path.
# NOTE(review): presumably this is what makes `from run import _init_model_loss`
# resolvable below; the path is relative to the kernel's working directory.
import sys
sys.path.append('../')

In [3]:
import hydra
from hydra import compose, initialize
import matplotlib.pyplot as plt
import collections

import json
import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as TF


from run import _init_model_loss

In [None]:
# Hydra overrides applied on top of the `run_vqbet` config for this evaluation.
config_overrides = [
    "task=tissue_pick_up",
    "run_offline=True",
    "dataset.test.shuffle_mode=SEQUENTIAL",
    "device=cuda",
]

# Compose the config inside Hydra's initialize() context (config_path is
# resolved relative to this notebook).
with initialize(config_path="../configs/"):
    cfg = compose(config_name="run_vqbet", overrides=config_overrides)

In [None]:
# Build the model from the composed config using the project's helper.
# NOTE(review): whether `_init_model_loss` also loads trained weights is not
# visible from this notebook - confirm before trusting the predictions below.
model = _init_model_loss(cfg)
# Switch to evaluation mode; the trailing `;` suppresses the module repr.
model.eval();

In [None]:
# Instantiate the test split described by the config.
test_dataset = hydra.utils.instantiate(cfg.dataset.test)
# Project-specific dataset switch; presumably keeps the final steps of each
# trajectory in the samples - TODO confirm against the dataset class.
test_dataset.set_include_trajectory_end(True)
# Notebook-level settings that `run_model` reads as globals:
# the maximum number of frames kept in its rolling image buffer ...
buffer_size = cfg["image_buffer_size"]
# ... and the device the model inputs are moved to.
device = cfg["device"]

In [None]:
# Sanity check: fetch one dataset item and display its input frames as a grid.
# Each item unpacks as ((images, terminate), ..., actions); any elements in
# between are ignored here.
i = 0
(input_images, terminate), *_, gt_actions = test_dataset[i]
# The last expression of the cell is rendered as a PIL image.
TF.to_pil_image(torchvision.utils.make_grid(input_images))

In [8]:
def run_model(model, test_dataset, num_steps, repeat_num=1, image_buffer_size=None, target_device=None):
    """Step the model through the first ``num_steps`` dataset items and collect its outputs.

    For every item, the most recent frame is pushed into a rolling buffer; the
    buffered frames are stacked, repeated ``repeat_num`` times along a new
    leading batch axis and passed to ``model.step`` together with the item's
    (equally repeated) ground-truth actions.

    Args:
        model: Object exposing ``step(model_input)`` that returns a pair whose
            first element is the predicted action tensor.
        test_dataset: Indexable dataset whose items unpack as
            ``((images, terminate), ..., actions)``. ``images`` is indexed as
            (frame, c, h, w) and divided by 255 - presumably uint8 frames,
            TODO confirm. ``actions`` must be 2-D (one row per time step).
        num_steps: Number of consecutive items, starting at index 0, to run.
        repeat_num: Number of identical copies of the input placed in the
            batch. With stochastic layers (e.g. dropout in train mode) each
            copy yields a different sample.
        image_buffer_size: Maximum number of frames kept in the rolling
            buffer. Defaults to the notebook-level ``buffer_size``.
        target_device: Device the model inputs are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Tuple ``(action_preds, ground_truth, images)`` of NumPy arrays, each
        with one entry per step: the squeezed model output, the last
        ground-truth action of the item, and the last frame of the item as an
        (h, w, c) float array scaled by 1/255.
    """
    # Fall back to the notebook-level settings when not passed explicitly, so
    # existing calls keep working unchanged.
    if image_buffer_size is None:
        image_buffer_size = buffer_size
    if target_device is None:
        target_device = device

    action_preds = []
    ground_truth = []
    images = []
    # Rolling window over the latest frames; older frames fall out automatically.
    image_buffer = collections.deque(maxlen=image_buffer_size)
    for i in range(num_steps):
        # get step from dataset (the terminate flag is not needed here)
        (input_images, _terminate), *_, gt_actions = test_dataset[i]

        # prepare input for forward pass
        input_images = input_images.float() / 255.0
        img = input_images[-1]
        image_buffer.append(img)
        # (c, h, w) -> (h, w, c) so the frame can be plotted / written directly.
        images.append(img.permute(1, 2, 0).cpu().detach().numpy())

        # as_tensor accepts arrays and tensors alike and, unlike torch.tensor,
        # does not warn or copy when the dataset already returns tensors.
        gt_actions = torch.as_tensor(gt_actions)
        # Store ground truth as NumPy so np.array() below works for any device.
        ground_truth.append(gt_actions[-1].detach().cpu().numpy())

        model_input = (
            torch.stack(tuple(image_buffer), dim=0).unsqueeze(0).repeat(repeat_num, 1, 1, 1, 1).to(target_device),
            gt_actions.unsqueeze(0).repeat(repeat_num, 1, 1).to(target_device),
        )

        # forward pass; gradients are never needed for evaluation. Dropout
        # sampling is unaffected by no_grad (it depends on train/eval mode).
        with torch.no_grad():
            out, _ = model.step(model_input)
        action_preds.append(out.squeeze().cpu().detach().numpy())

    return np.array(action_preds), np.array(ground_truth), np.array(images)

In [9]:
# Deterministic baseline: one forward pass per step over the first 100 items.
action_preds, ground_truth, images = run_model(model, test_dataset, 100, repeat_num=1)

In [None]:
# Per-step difference between the single-pass prediction and the ground-truth
# action. With 2-D inputs, matplotlib draws one line per column (presumably
# one per action dimension - TODO confirm the output shape of model.step).
fig, ax = plt.subplots()
ax.plot(action_preds - ground_truth)
ax.set(
    title="Action prediction error",
    xlabel="Step",
    ylabel="Prediction - ground truth",
);

# Uncertainty quantification

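- Put the model in `eval()` mode, then switch only the dropout layers back to training mode so they stay stochastic (Monte Carlo dropout).
- For every step, draw 64 stochastic predictions, implemented as a single batched forward pass over 64 copies of the input (`repeat_num=64`).
- Use the standard deviation across those 64 outputs as the uncertainty estimate.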

In [11]:
# Monte Carlo dropout setup: everything runs in eval mode, except that the
# torch.nn.Dropout layers are flipped back to train mode so they keep
# sampling masks during inference.
model.eval()
dropout_layers = (layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout))
for dropout_layer in dropout_layers:
    dropout_layer.train()

In [12]:
# Stochastic run: every step is evaluated on a batch of 64 identical inputs,
# so the active dropout layers produce 64 different predictions per step.
# NOTE: this overwrites the single-pass `action_preds` / `ground_truth` /
# `images` from above; re-running the earlier error plot after this cell will
# see the batched results instead.
action_preds, ground_truth, images = run_model(model, test_dataset, 100, repeat_num=64)

In [13]:
# Euclidean norm over the first three components of the last axis; the result
# drops that axis. Presumably these components are the x/y/z part of the
# action and the remaining axes are (step, sample) - TODO confirm.
action_pred_xyz_norm = np.linalg.norm(action_preds[:, :, :3], axis=-1)

In [None]:
# Shape check of the norm array before reducing over axis 1 below.
action_pred_xyz_norm.shape

In [None]:
# Spread of the predicted norm along axis 1 (presumably the 64 dropout
# samples - TODO confirm), one value per step: the raw uncertainty signal.
fig, ax = plt.subplots()
ax.plot(action_pred_xyz_norm.std(axis=1))
ax.set(
    title="Std of predicted action norm",
    xlabel="Step",
    ylabel="Std of ||action[:3]||",
);

In [None]:
# Standard deviation of the predicted norm along axis 1, one value per step.
action_pred_xyz_norm_std = action_pred_xyz_norm.std(axis=1)

# Min-max scale to [0, 1]. When every step has the same std the range is zero;
# dividing would yield NaN (and later break the int() conversion used for the
# circle radius), so fall back to all zeros in that case.
std_min = action_pred_xyz_norm_std.min()
std_range = action_pred_xyz_norm_std.max() - std_min
if std_range > 0:
    normalized_action_pred_xyz_std = (action_pred_xyz_norm_std - std_min) / std_range
else:
    normalized_action_pred_xyz_std = np.zeros_like(action_pred_xyz_norm_std)

fig, ax = plt.subplots()
ax.plot(normalized_action_pred_xyz_std)
ax.set(
    title="Normalized uncertainty",
    xlabel="Step",
    ylabel="Min-max scaled std",
);

In [19]:
# Preview of the overlay: draw a circle outline (radius 50 px, thickness 3)
# at pixel (x=128, y=200) on a copy of the first frame.
# NOTE(review): `images` holds floats scaled by 1/255 (see run_model), while the
# colour here is given on a 0-255 scale; the video-frame cell converts to uint8
# before drawing - confirm whether this preview should do the same.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
import cv2

circled_img = cv2.circle(images[0].copy(), center=(128, 200), radius=50, color=(255, 0, 0), thickness=3)

In [28]:
# Overlay settings for the uncertainty indicator.
CIRCLE_CENTER = (128, 200)  # (x, y) pixel position of the circle centre
MAX_CIRCLE_RADIUS = 50      # radius in pixels drawn for a normalized value of 1

# One annotated uint8 frame per step; the circle radius grows with the
# normalized uncertainty of that step. Built with a comprehension so the loop
# variables do not overwrite notebook-level names such as the sample index `i`.
# nan_to_num maps a NaN uncertainty (e.g. from a zero-range normalization) to a
# radius of 0 instead of letting int() raise.
video_frames = [
    cv2.circle(
        (img * 255).astype(np.uint8),  # floats scaled by 1/255 -> 0-255 uint8 (new array)
        center=CIRCLE_CENTER,
        radius=int(MAX_CIRCLE_RADIUS * np.nan_to_num(normalized_action_pred_xyz_std[idx])),
        color=(0, 0, 255),
        thickness=3,
    )
    for idx, img in enumerate(images)
]

In [33]:
# write video frames to video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')

# The writer must be opened with the exact (width, height) of the frames it
# will receive, so derive the size from the frames instead of hard-coding it.
# Fall back to the previous 256x256 default when there is nothing to write.
if video_frames:
    frame_height, frame_width = video_frames[0].shape[:2]
else:
    frame_height, frame_width = 256, 256

out = cv2.VideoWriter('uncertainty_quantification.mp4', fourcc, 1.0, (frame_width, frame_height))
if not out.isOpened():
    # Without this check a missing codec or unwritable path fails silently.
    raise RuntimeError("Could not open video writer for 'uncertainty_quantification.mp4'")

try:
    for frame in video_frames:
        # Reverse the channel axis before writing (OpenCV expects BGR order;
        # the frames are presumably RGB - TODO confirm).
        out.write(frame[..., ::-1])
finally:
    # Always finalize the file, even if writing a frame fails.
    out.release()
