In [None]:
from pathlib import Path
import json
from ipywidgets import widgets
from matplotlib import pyplot as plt

from price_net.dataset import PriceAttributionDataset
from price_net.schema import PriceAttributionScene
from price_net.utils import plot_price_attribution_scene, plot_bboxes

In [2]:
# Ask for the dataset root until an existing directory is given; later cells read the
# shelf images, depth maps and scene annotations relative to `dataset_dir`.
while True:
    # Strip stray whitespace (common when pasting) and expand "~", which Path does not do itself.
    raw_path = input("Input the dataset directory: ").strip()
    dataset_dir = Path(raw_path).expanduser()
    # A blank answer would become Path("."), silently accepting the current working directory.
    if raw_path and dataset_dir.is_dir():
        break
    print("Invalid dataset directory. Please try again.")

In [None]:
def plot_scene(idx: int, scenes: list[PriceAttributionScene]):
    """Render one scene as three panels: association graph, shelf image, depth map.

    The left panel is drawn by ``plot_price_attribution_scene``, which also
    returns the colour key used for product groups. The middle and right panels
    show ``<scene_id>.jpg`` from the dataset's images and depth-map directories.
    Each is overlaid with every product group's boxes (in that group's colour,
    or white if the colour key has no entry for it) and every price box
    (dashed black).

    Reads the notebook-level ``dataset_dir`` chosen in an earlier cell.

    Args:
        idx: Position in ``scenes`` of the scene to draw.
        scenes: Scenes parsed from the dataset's raw price-scenes JSON file.
    """
    # Close the current figure (if any) before building a new one for this scene.
    plt.close()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4), width_ratios=[1.5, 2, 2])
    graph_axis, img_axis, depth_axis = axs.ravel()
    scene = scenes[idx]

    _, color_key = plot_price_attribution_scene(scene, ax=graph_axis)
    graph_axis.set_title("Associations")
    fig.suptitle(scene.scene_id, fontsize=10)

    image = plt.imread(dataset_dir / PriceAttributionDataset.IMAGES_DIR / f"{scene.scene_id}.jpg")
    depth = plt.imread(dataset_dir / PriceAttributionDataset.DEPTH_MAPS_DIR / f"{scene.scene_id}.jpg")

    # Resolve each group's colour and boxes once; both image panels reuse them.
    colored_groups = [
        (
            color_key.get(group.group_id, "white"),
            [scene.product_bboxes[id_] for id_ in group.product_bbox_ids],
        )
        for group in scene.product_groups
    ]

    for ax, pixels, title in (
        (img_axis, image, "Shelf Image"),
        (depth_axis, depth, "Predicted Depth"),
    ):
        # Use this panel's own pixel size: the depth map is not guaranteed to share the
        # shelf image's resolution, and plot_bboxes is given width/height to place boxes
        # (presumably scaling normalised coordinates -- TODO confirm in price_net.utils).
        height, width = pixels.shape[:2]
        ax.imshow(pixels)
        ax.set_title(title)
        ax.axis("off")
        for color, group_bboxes in colored_groups:
            plot_bboxes(group_bboxes, ax=ax, color=color, width=width, height=height)
        plot_bboxes(
            scene.price_bboxes.values(),
            ax=ax,
            linestyle="dashed",
            color="black",
            width=width,
            height=height,
        )

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()


# Parse the raw annotations (a JSON array of objects) into PriceAttributionScene records.
# JSON is UTF-8 by spec; without an explicit encoding, open() would use the platform locale.
with open(dataset_dir / PriceAttributionDataset.RAW_PRICE_SCENES_FNAME, encoding="utf-8") as scenes_file:
    scenes = [PriceAttributionScene(**record) for record in json.load(scenes_file)]

In [None]:
# An empty scene list would make the slider's max (-1) fall below its min (0); fail clearly instead.
if not scenes:
    raise ValueError("No scenes were loaded from the dataset; nothing to display.")


def display_func(idx: int) -> None:
    """Slider callback: draw the scene at ``idx`` from the loaded ``scenes``."""
    plot_scene(idx=idx, scenes=scenes)


idx_slider = widgets.IntSlider(value=0, min=0, max=len(scenes) - 1, step=1, description="Scene Index")
# interact() displays the widget itself and returns `display_func`; wrapping it in display()
# would also render the function's repr, so just call it (the trailing ";" hides the return value).
widgets.interact(display_func, idx=idx_slider);