In [None]:
import os
import random
import numpy as np
import torch
import torchvision.transforms.functional as VF
import hydra
from omegaconf import OmegaConf
import cv2
import matplotlib.pyplot as plt
import h5py
from hydra import initialize, compose

from gazebot.test import make_gaze_agent
from gazebot.utils import array2image

## Load config

In [None]:
# Compose the Hydra config (config path is relative to this notebook).
with initialize(version_base=None, config_path="../config"):
    args = compose(config_name="config.yaml")
    args.hydra_base_dir = os.getcwd()
    # Fall back to CPU when no CUDA runtime is available.
    args.device = args.cuda_device if torch.cuda.is_available() else "cpu"

# Seed every RNG used below so a Restart & Run All is reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Select the device. All CUDA-specific setup lives behind the CUDA check:
# torch.cuda.set_device() raises when handed a CPU device, so it must not
# run on the CPU fallback path chosen above.
device = torch.device(args.device)
if device.type == "cuda":
    if args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    # Change the default CUDA device (affects e.g. .cuda() or torch.cuda.IntTensor(x)).
    torch.cuda.set_device(device)

print(OmegaConf.to_yaml(args))

## Load trained gaze model

In [None]:
# Build the gaze agent from the composed config `args`.
# NOTE(review): presumably make_gaze_agent also restores the trained weights
# from a checkpoint referenced in the config — confirm in gazebot.test.
gaze_agent = make_gaze_agent(args)

## Load demonstration data

In [None]:
# Expand "~" once, up front, so that both the directory listing and the joined
# file paths refer to the same location (os.listdir does not expand "~" itself).
data_dir = os.path.expanduser(args.expert.test_path[0])
# Sorted list of every entry in the first test directory whose name contains "h5".
data_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir) if "h5" in f)
print(data_files)

## Visualize

In [None]:
# Evaluate the gaze agent on every demonstration file.
# For each step the predicted gaze is compared with the recorded human gaze;
# the per-step score is the L2 distance over both (x, y) pairs, normalized by
# the longer image edge. Optionally an annotated GIF of the left image is saved.
save_gif = True

data_scores = []
for i, data_file in enumerate(data_files):
    print(f"Evaluation Demonstration ({i}): {data_file}")
    images = []
    scores = []
    with h5py.File(data_file, "r") as e:
        eps_steps = len(e["left_img"])

        # change_steps: steps at which a gaze transition occurs (first entry skipped).
        # Converted to an array once here instead of on every step.
        change_steps = np.asarray(e["change_steps"][1:])
        print("change_steps:", change_steps)

        for step in range(eps_steps):
            # Read each frame from the HDF5 file once and reuse it for drawing below.
            left_frame = e["left_img"][step]
            right_frame = e["right_img"][step]

            image = np.transpose(np.stack([left_frame, right_frame]), (0, 3, 1, 2)) / 255.0  # (2, C, H, W)
            if args.expert.bgr:
                image = image[:, [2, 1, 0]]  # BGR2RGB

            _, _, H, W = image.shape

            # Gaze: recorded gaze in pixel coordinates, one (x, y) row per image.
            human_gaze = np.array(e["gaze"][step]).reshape(2, 2)
            human_gaze = np.clip(human_gaze, [0, 0], [W - 1, H - 1])  # (2, 2)
            human_gaze = np.round(human_gaze).astype(np.int64)  # (2, 2)

            # Segment index: number of gaze transitions that happened up to this step.
            seg_idx = int(np.sum(change_steps <= step))

            # To tensor
            image = torch.as_tensor(image, dtype=torch.float, device=device)  # (2, C, H, W)
            image = VF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            image = image.unsqueeze(0)  # (1, 2, C, H, W)

            # Create the index on the same device as the image to avoid a
            # CPU/GPU mismatch inside the agent.
            seg_idx = torch.tensor([[seg_idx]], dtype=torch.long, device=device)  # (1, 1)

            # Predict gaze
            with torch.no_grad():
                predict_gaze, _ = gaze_agent(image, seg_idx)  # (1, 2, 2)
            predict_gaze = gaze_agent.model.denormalize(predict_gaze, H, W).detach().cpu().numpy()[0]  # (2, 2)
            predict_gaze = np.round(predict_gaze).astype(np.int64)  # (2, 2)

            score = np.linalg.norm(human_gaze - predict_gaze) / max(H, W)  # Normalize by longer edge length
            scores.append(score)

            # Visualize: draw on a BGR copy of the left image (red = human, white = prediction).
            if save_gif:
                im = np.array(left_frame)
                if not args.expert.bgr:
                    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)  # RGB2BGR
                # Plain Python ints keep OpenCV's argument parsing happy.
                human_pt = (int(human_gaze[0, 0]), int(human_gaze[0, 1]))
                predict_pt = (int(predict_gaze[0, 0]), int(predict_gaze[0, 1]))
                im = cv2.circle(im, human_pt, 20, (0, 0, 255), 5)
                im = cv2.circle(im, predict_pt, 20, (255, 255, 255), 5)
                # Back to RGB channel order before converting to an image.
                images.append(array2image(im[:, :, [2, 1, 0]]).resize((320, 180)))

    if not scores:
        # Nothing to score, save or plot for an empty episode.
        print("episode score: n/a (empty episode)")
        continue

    if save_gif:
        out_dir = f"outputs/{args.expert.task_type}"
        os.makedirs(out_dir, exist_ok=True)
        # data_file is a full path; use only its base name (without extension)
        # so the GIF lands inside out_dir instead of a non-existent nested path.
        episode_name = os.path.splitext(os.path.basename(data_file))[0]
        gif_path = f"{out_dir}/gaze_{episode_name}.gif"
        images[0].save(
            gif_path,
            save_all=True,
            append_images=images[1:],
            optimize=False,
            duration=40,
            loop=1,
        )
        print(f"Save {gif_path}")

    data_scores.append(scores)
    print("episode score:", np.mean(scores))

    fig, ax = plt.subplots()
    ax.plot(scores, label="score")
    ax.set(xlabel="step", ylabel="normalized gaze error", title=os.path.basename(data_file))
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across episodes

if data_scores:
    all_scores = np.concatenate(data_scores)
    print("Total score:", np.mean(all_scores), np.median(all_scores))
else:
    # np.concatenate raises on an empty list; report the situation instead.
    print("Total score: n/a (no demonstrations evaluated)")