In [None]:
import json
import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
from pathlib import Path

# Make the parent directory importable (added at most once) so the
# project-local `utils` package imported below resolves. The relative path is
# resolved against the current working directory (for Jupyter, normally the
# notebook's own folder).
parent_dir = os.path.abspath(os.pardir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from utils.image_handling import crop_image
from utils.keypoints import crop_and_resize_keypoints

In [None]:
# Dataset root: two directories above the current working directory. Image
# paths and annotation files used later in the notebook are joined onto it.
BASE_PATH = Path("..", "..")

In [None]:
def load_sample(sample_path, encoding='utf-8'):
    """Read a JSON annotation file and return its parsed contents.

    Parameters
    ----------
    sample_path : str or os.PathLike
        Path of the JSON file to read.
    encoding : str, optional
        Text encoding of the file. Defaults to UTF-8, the encoding required
        for interchanged JSON. Relying on the platform default (e.g. cp1252 on
        Windows) can garble non-ASCII strings.

    Returns
    -------
    object
        Whatever ``json.load`` produces. Presumably a list of per-image
        annotation dicts, since the notebook passes it to ``random.sample``;
        confirm against the annotation writer.
    """
    with open(sample_path, 'r', encoding=encoding) as f:
        return json.load(f)

In [None]:
def prepare_cropped_image(sample_entry, target_size=(128, 128)):
    """Load a sample's image, crop it to the object box and resize the crop.
    The sample's 2D keypoints are re-expressed relative to that crop.

    Parameters
    ----------
    sample_entry : dict
        Annotation record. The keys read are:
        - ``'rgb_path'``: joined onto ``BASE_PATH``; backslash separators are
          converted to ``/``.
        - ``'bbox_obj'``: exactly four values, unpacked as ``x, y, w, h`` and
          handed to ``crop_image``.
        - ``'keypoints_2D'``
        - ``'image_id'``
    target_size : tuple of int, optional
        Passed to ``cv2.resize`` as ``dsize`` (which OpenCV reads as
        ``(width, height)``) and to ``crop_and_resize_keypoints``.

    Returns
    -------
    resized_img : numpy.ndarray
        The crop in RGB channel order, resized to ``target_size``.
    keypoints_2D_resized
        Whatever ``crop_and_resize_keypoints`` returns. Presumably the
        keypoints in the resized crop's pixel frame; confirm in
        ``utils.keypoints``.
    image_id
        ``sample_entry['image_id']``, unchanged.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the image (missing file or undecodable data).
    ValueError
        If cropping yields an empty image, e.g. the box lies outside the image.
    """
    target_size = tuple(target_size)

    # Annotations may store Windows-style separators; normalise them so the
    # relative path also resolves on POSIX systems.
    rel_path = Path(sample_entry['rgb_path'].replace('\\', '/'))
    image_path = BASE_PATH / rel_path

    # cv2.imread needs a str filename; older OpenCV builds raise TypeError for
    # a pathlib.Path. It reports failure by returning None instead of raising.
    img_bgr = cv2.imread(str(image_path))
    if img_bgr is None:
        raise FileNotFoundError(f"Image not found: {image_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # Crop to the object's bounding box.
    x, y, w, h = sample_entry['bbox_obj']
    crop_box = (x, y, w, h)
    cropped_img = crop_image(img_rgb, bbox=crop_box)
    if np.size(cropped_img) == 0:
        # Without this check cv2.resize fails with an opaque assertion error.
        raise ValueError(f"Empty crop from bbox {crop_box} in {image_path}")

    resized_img = cv2.resize(cropped_img, target_size, interpolation=cv2.INTER_LINEAR)

    # Re-express the keypoints relative to the crop and the new size.
    keypoints_2D = np.array(sample_entry['keypoints_2D'])
    keypoints_2D_resized = crop_and_resize_keypoints(
        keypoints_2D, crop_box=crop_box, target_size=target_size
    )

    return resized_img, keypoints_2D_resized, sample_entry['image_id']

In [None]:
def visualize_multiple_samples(samples, target_size=(128, 128)):
    """Plot samples side by side, one panel each.

    Each panel shows the sample's cropped, resized image with its 2D
    keypoints overlaid as red dots, labelled in blue by index.

    Parameters
    ----------
    samples : iterable of dict
        Annotation records, each accepted by ``prepare_cropped_image``.
    target_size : tuple of int, optional
        Forwarded to ``prepare_cropped_image``. The default equals that
        function's default, so existing calls render the same images.

    Raises
    ------
    ValueError
        If ``samples`` is empty; there is nothing to lay out.
    FileNotFoundError
        Propagated from ``prepare_cropped_image`` when an image cannot be read.
    """
    samples = list(samples)
    if not samples:
        raise ValueError("visualize_multiple_samples() needs at least one sample")

    n_panels = len(samples)
    # squeeze=False always returns a 2-D array of Axes, so a single sample
    # needs no special-casing.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)

    text_offset = 2  # data-unit nudge so index labels don't cover the marker
    for ax, sample_entry in zip(axes[0], samples):
        resized_img, keypoints_2D_resized, image_id = prepare_cropped_image(
            sample_entry, target_size=target_size
        )
        ax.imshow(resized_img)
        ax.axis('off')

        kps = np.asarray(keypoints_2D_resized, dtype=float)
        if kps.size:
            # First two columns are plotted as x, y; any extra columns are ignored.
            xs, ys = kps[:, 0], kps[:, 1]
            # One scatter call for all points instead of one artist per keypoint.
            ax.scatter(xs, ys, c='red', s=15)
            for idx, (x_kp, y_kp) in enumerate(zip(xs, ys)):
                ax.text(x_kp + text_offset, y_kp - text_offset, str(idx), color='blue', fontsize=8)

        ax.set_title(f"Image ID: {image_id}")

    fig.tight_layout()
    plt.show()

In [None]:
obj_id = 1       # object to inspect; annotation files are named by zero-padded id
NUM_SAMPLES = 3  # panels to show (capped at the number of available annotations)
SEED = 36        # fixed so the same samples are drawn on every run

sample_path = BASE_PATH / f"data/annotations/{obj_id:06d}.json"
samples = load_sample(sample_path)
assert samples, f"no annotations found in {sample_path}"

# A dedicated RNG picks exactly what `random.seed(SEED); random.sample(...)`
# would, without mutating the module-global random state. min() avoids the
# ValueError random.sample raises when fewer than NUM_SAMPLES entries exist.
rng = random.Random(SEED)
selected_samples = rng.sample(samples, min(NUM_SAMPLES, len(samples)))

visualize_multiple_samples(selected_samples)
