In [None]:
from pathlib import Path
import json

import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np

from grconvnet.utils import visualization as vis
from grconvnet.datatypes import ImageGrasp, RealGrasp

In [None]:
# Result directories to compare side by side; each entry becomes one column in
# the comparison figures of the next cell. Add "ycb_3_b10" to the tuple below
# to include the third result directory as well.
results_root = Path.cwd().parent / "grconvnet" / "results"
run_names = ("ycb_1_b10", "ycb_2_b10")
result_paths = [results_root / run_name for run_name in run_names]


In [None]:
def _load_result(result_dir: Path):
    """Load the saved grasps and images of one result sub-directory.

    Args:
        result_dir: Directory containing ``data.json``, ``rgb_cropped.png`` and
            ``original_rgb.png``.

    Returns:
        Tuple ``(grasps_data, grasps_image, grasps_world, rgb_cropped,
        rgb_original)``, where ``grasps_data`` is the parsed ``data.json`` dict
        (its camera entries are still needed for the world-frame plot).

    Raises:
        FileNotFoundError: if any of the three files is missing.
    """
    with open(result_dir / "data.json", encoding="utf-8") as f:
        grasps_data = json.load(f)

    grasps_image = [
        ImageGrasp(np.array(g["center"]), g["quality"], g["angle"], g["width"])
        for g in grasps_data["grasps_img"]
    ]
    grasps_world = [
        RealGrasp(np.array(g["center"]), g["quality"], g["angle"], g["width"])
        for g in grasps_data["grasps_world"]
    ]
    rgb_cropped = mpl.image.imread(result_dir / "rgb_cropped.png")
    rgb_original = mpl.image.imread(result_dir / "original_rgb.png")
    return grasps_data, grasps_image, grasps_world, rgb_cropped, rgb_original


# Sub-directory names are taken from the first result directory; every other
# entry of `result_paths` must contain the same sub-directories, otherwise
# opening its data.json raises FileNotFoundError.
result_names = sorted(p.name for p in result_paths[0].iterdir() if p.is_dir())
n_cols = len(result_paths)

for result_name in result_names:
    # squeeze=False keeps `axes` 2-D even when only one result directory is
    # listed, so `axes[row, col]` indexing works for any number of columns.
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 5), squeeze=False)
    fig.suptitle(result_name)

    for col, result_path in enumerate(result_paths):
        grasps_data, grasps_image, grasps_world, rgb_cropped, rgb_original = (
            _load_result(result_path / result_name)
        )

        # Top row: image-space grasps drawn on the cropped RGB image.
        vis.image_grasps_ax(axes[0, col], rgb_cropped, grasps_image)
        axes[0, col].set_title(f"{result_path.name} image")

        # Bottom row: world-space grasps drawn on the original RGB image, using
        # the camera intrinsics / rotation / position stored in data.json.
        vis.world_grasps_ax(
            axes[1, col],
            rgb_original,
            grasps_world,
            np.array(grasps_data["cam_intrinsics"]),
            np.array(grasps_data["cam_rot"]),
            np.array(grasps_data["cam_pos"]),
        )
        axes[1, col].set_title(f"{result_path.name} world")

    # Lay out only after all images and titles exist; calling tight_layout on
    # empty axes (as before) had nothing to account for.
    fig.tight_layout()
    plt.show()
    # Release the figure so looping over many result directories does not
    # accumulate open figures (harmless if the backend already closed it).
    plt.close(fig)
