In [1]:
from pathlib import Path
import json
from ipywidgets import widgets
from matplotlib import pyplot as plt
import yaml
import torch
from matplotlib import colormaps
from matplotlib.colors import Normalize

from price_net.configs import TrainingConfig, EvaluationConfig
from price_net.datamodule import PriceAssociationDataModule
from price_net.models import PriceAssociatorLightningModule
from price_net.schema import PriceAssociationScene
from price_net.utils import plot_price_scene
from price_net.enums import PredictionStrategy

In [None]:
# Keep prompting until an evaluation config can be read and parsed.
while True:
    # Strip stray whitespace that often comes along when a path is pasted into the prompt.
    config_path = Path(input("Enter the path to the eval config of the model whose predictions will be viewed: ").strip())
    try:
        with open(config_path, "r") as f:
            eval_config = EvaluationConfig(**yaml.safe_load(f))
        break
    except Exception as err:
        # The failure is not necessarily a bad path: the YAML may be malformed or the
        # contents may be rejected by EvaluationConfig. Report the real cause so the
        # user knows what to fix before retrying.
        print(f"Could not load an eval config from '{config_path}' ({type(err).__name__}: {err}). Try again.")

In [3]:
# Paths stored in the configs are resolved against the parent of the working directory.
parent_dir = Path("..")

# Restore the trained model from the checkpoint named in the eval config and put it in eval mode.
ckpt_path = parent_dir / eval_config.ckpt_path
model = PriceAssociatorLightningModule.load_from_checkpoint(ckpt_path)
model = model.eval()

# The training config supplies the dataset location and the model options used to build the datamodule.
with open(parent_dir / eval_config.trn_config_path, "r") as f:
    training_config = TrainingConfig(**yaml.safe_load(f))

model_config = training_config.model
datamodule = PriceAssociationDataModule(
    data_dir=parent_dir / training_config.dataset_dir,
    input_reduction=model_config.input_reduction,
    prediction_strategy=model_config.prediction_strategy,
    featurization_method=model_config.featurization_method,
    use_depth=model_config.use_depth,
)
datamodule.setup("test")

# Index the raw (un-featurized) scenes of the test split by their scene id.
dataset = datamodule.test
with open(dataset.root_dir / dataset.RAW_PRICE_SCENES_FNAME, "r") as f:
    scene_records = json.load(f)
raw_scenes = {
    parsed.scene_id: parsed
    for parsed in (PriceAssociationScene(**record) for record in scene_records)
}

In [4]:
@torch.inference_mode()
def plot_scene(idx: int, threshold: float):
    """Plot the model's predicted price associations for one scene of the test dataset.

    The left panel draws the scene with ``plot_price_scene`` and overlays, for every
    (price, product group) pair whose predicted probability exceeds ``threshold``, a line
    from each product box of the group to the price centroid. Lines use the "Greens"
    colormap when the ground-truth label is 1 and the "Reds" colormap otherwise, shaded
    by the predicted probability. The right panel shows the scene's image.

    Relies on the notebook-level ``dataset``, ``raw_scenes``, ``training_config`` and ``model``.

    Args:
        idx: Index into ``dataset`` of the item to display.
        threshold: Probability cutoff; only pairs with a probability strictly greater
            than this value are drawn.
    """
    plt.close()
    X, y, scene_id = dataset[idx]
    group_ids = dataset.instances[dataset.scene_id_to_indices[scene_id]]["group_id"].to_list()
    # Raw scenes and image files are keyed by the part of the id that precedes any "__".
    clean_scene_id = str(scene_id).split("__")[0]
    scene = raw_scenes[clean_scene_id]
    id_to_product_group = {group.group_id: group for group in scene.product_groups}

    images_dir = dataset.root_dir / dataset.IMAGES_DIR
    # NOTE(review): assumes feature columns 5 and 6 hold the price centroid (x, y) — confirm
    # against the featurization. Converted to plain floats so matplotlib never sees tensors.
    price_centroids = X[:, 5:7].tolist()

    if training_config.model.prediction_strategy == PredictionStrategy.JOINT:
        X = X.unsqueeze(0)

    # The restored model may live on a different device than the dataset tensors; align
    # them before the forward pass, then pull the probabilities back as Python floats.
    X = X.to(model.device)
    probs = model.forward(X).sigmoid().flatten().tolist()

    fig, axs = plt.subplots(1, 2, width_ratios=[1.5, 2])
    axs: list[plt.Axes]

    plot_price_scene(scene, ax=axs[0])

    norm = Normalize(vmin=0.0, vmax=1.0)
    good_cmap = colormaps["Greens"]
    bad_cmap = colormaps["Reds"]
    for price_centroid, group_id, pred_prob, label in zip(price_centroids, group_ids, probs, y):
        if pred_prob > threshold:
            # Green shades mark predictions that agree with the label, red shades those that do not.
            cmap = good_cmap if label.item() == 1 else bad_cmap
            color = cmap(norm(pred_prob))
            for bbox_id in id_to_product_group[group_id].product_bbox_ids:
                prod_bbox = scene.product_bboxes[bbox_id]
                axs[0].plot(
                    [prod_bbox.cx, price_centroid[0]],
                    [prod_bbox.cy, price_centroid[1]],
                    color=color,
                    alpha=0.3,
                )

    axs[0].set_xlim(0.0, 1.0)
    # Inverted y-limits: the origin sits at the top-left, matching image coordinates.
    axs[0].set_ylim(1.0, 0.0)
    axs[0].set_aspect("equal")

    axs[1].imshow(plt.imread(images_dir / f"{clean_scene_id}.jpg"))
    axs[1].axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
def display_func(idx, threshold):
    """Adapter handed to ``interact``: forwards the slider values to ``plot_scene``."""
    return plot_scene(idx=idx, threshold=threshold)


idx_slider = widgets.IntSlider(value=0, min=0, max=len(dataset)-1, step=1, description="Scene Index")
threshold_slider = widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.01, description="Threshold")
# `interact` renders the controls itself and returns the callback, so wrapping it in
# display() would additionally print the function's repr below the widget. The trailing
# semicolon keeps the returned function out of the cell output.
widgets.interact(display_func, idx=idx_slider, threshold=threshold_slider);