# Check if the dataset labels are correct (GT)

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from IPython.display import clear_output, display
import pickle
import json
import ipywidgets as widgets
import threading
import time


In [None]:
def save_checkpoint(subject_id, camera, batch_index, fname='checkpoint_gt.pkl'):
    """Persist the current annotation position to a pickle file.

    Args:
        subject_id: ID of the subject being annotated.
        camera: Camera ID being annotated.
        batch_index: Index of the frame batch being shown.
        fname: Destination path. Defaults to ``checkpoint_gt.pkl`` in the
            current working directory, the same default path that
            ``load_checkpoint`` reads.
    """
    checkpoint = {'subject_id': subject_id, 'camera': camera, 'batch_index': batch_index}
    with open(fname, 'wb') as f:
        pickle.dump(checkpoint, f)

In [None]:
def load_checkpoint(fname='checkpoint_gt.pkl'):
    """Load a checkpoint dict previously written with pickle.

    Args:
        fname: Checkpoint path. Defaults to ``checkpoint_gt.pkl`` in the
            current working directory.

    Returns:
        The unpickled checkpoint, or ``{}`` when ``fname`` does not exist.

    Note:
        ``pickle.load`` can execute arbitrary code. Only load checkpoint files
        produced locally by this notebook.
    """
    if not os.path.exists(fname):
        return {}
    with open(fname, 'rb') as f:
        return pickle.load(f)

In [None]:
def save_marks_json(marks, fname='marks2302.json'):
    """Write the marks dict to ``fname`` as indented JSON.

    The JSON is first written to ``fname + '.tmp'`` and then moved over
    ``fname`` with ``os.replace``. Because the real file is only swapped in
    after a complete write, an interruption mid-dump (e.g. stopping the kernel)
    leaves the previously saved marks intact instead of a truncated file.

    Args:
        marks: Mapping of ``"<subject>_<camera>"`` keys to mark strings.
        fname: Destination JSON path.
    """
    tmp_fname = fname + '.tmp'
    with open(tmp_fname, 'w') as f:
        json.dump(marks, f, indent=2)
    os.replace(tmp_fname, fname)

def load_marks_json(fname='marks2302.json'):
    """Return the marks dict stored in ``fname``, or ``{}`` if the file is absent."""
    if not os.path.exists(fname):
        return {}
    with open(fname, 'r') as handle:
        loaded_marks = json.load(handle)
    return loaded_marks

In [None]:
def display_images(control_img1, control_img2, subject_img, frame_images):
    """Show three reference images above a grid of numbered frames.

    Layout (15-column grid):
      * row 0, columns 0-8: control image set1, control image set2 and the
        subject's first frame, three columns each; columns 9-14 are blank;
      * rows 1+, columns 0-9: up to 10 frames per row, each labelled in the
        bottom-right corner with its 1-based position in ``frame_images``.

    Args:
        control_img1: Control image from set1 (anything ``imshow`` accepts).
        control_img2: Control image from set2.
        subject_img: Image shown under the "First Frame" title.
        frame_images: Sequence of frame images to tile below the header row.
    """
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec

    frames_per_row = 10
    n_frames = len(frame_images)
    rows = (n_frames + frames_per_row - 1) // frames_per_row  # ceiling division
    fig = plt.figure(figsize=(20, (rows + 1) * 2.5))
    gs = gridspec.GridSpec(rows + 1, 15, figure=fig)

    # Header row: each reference image spans three grid columns.
    headers = [
        (control_img1, "Control Image (set1)"),
        (control_img2, "Control Image (set2)"),
        (subject_img, "First Frame"),
    ]
    for k, (img, title) in enumerate(headers):
        ax = fig.add_subplot(gs[0, 3 * k:3 * (k + 1)])
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title, fontsize=14, weight='bold')

    # Blank axes for the unused header cells (columns 9-14).
    for j in range(9, 15):
        fig.add_subplot(gs[0, j]).axis('off')

    # Frame grid: trailing cells of the last row are left as blank axes.
    for i in range(1, rows + 1):
        for j in range(frames_per_row):
            idx = (i - 1) * frames_per_row + j
            ax = fig.add_subplot(gs[i, j])
            if idx < n_frames:
                ax.imshow(frame_images[idx])
                ax.text(0.95, 0.05, str(idx + 1), color='white', fontsize=10,
                        transform=ax.transAxes, ha='right', va='bottom', weight='bold')
            ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def _record_mark(marks, mark_key, value, label, subject_id, camera_id):
    """Report a mark, store it in ``marks`` and persist it via ``save_marks_json``."""
    print(f"Marked {subject_id}, {camera_id} as {label}.")
    marks[mark_key] = value
    save_marks_json(marks)


def process_frames(subject_id, camera_id, control_img1, control_img2, subject_img,
                   camera_frame_dir, checkpoint, marks, batch_size=500, frame_step=3):
    """Interactively label one subject/camera pair, one batch of frames at a time.

    The PNG files in ``camera_frame_dir`` are sorted by filename and every
    ``frame_step``-th one is kept (the 3rd, 6th, 9th, ... by default). They are
    shown in batches of ``batch_size`` under the control images. After each
    batch the user answers a prompt:

      * ``m``           -> mark ``"incorrect"`` and stop;
      * ``2``           -> mark ``"two_subjects"`` and stop;
      * Enter           -> next batch, or mark ``"correct"`` on the last batch;
      * anything else   -> next batch. On the last batch the pair is left
                           unmarked, so it is shown again on the next run.

    Answers are case-insensitive and surrounding whitespace is ignored. A mark
    is stored in ``marks`` under ``"{subject_id}_{camera_id}"`` and written
    to disk immediately with ``save_marks_json``.

    Args:
        subject_id: Subject being reviewed.
        camera_id: Camera sub-folder being reviewed.
        control_img1: Control image from set1.
        control_img2: Control image from set2.
        subject_img: Image shown as the subject's first frame.
        camera_frame_dir: Directory containing the PNG frames.
        checkpoint: Currently unused; kept for interface compatibility.
        marks: Existing marks dict; mutated in place.
        batch_size: Number of frames shown per batch.
        frame_step: Keep one frame out of every ``frame_step`` (must be >= 1).
    """
    import os
    import matplotlib.image as mpimg

    all_frames = sorted(f for f in os.listdir(camera_frame_dir) if f.endswith('.png'))
    # Equivalent to keeping files whose 1-based index is a multiple of frame_step.
    frame_files = all_frames[frame_step - 1::frame_step]
    mark_key = f"{subject_id}_{camera_id}"

    if not frame_files:
        # Too few frames survive subsampling: report it instead of returning silently.
        print(f"Only {len(all_frames)} PNG frame(s) for subject {subject_id}, camera {camera_id}; "
              f"none left after keeping every {frame_step}-th frame. Leaving unmarked.")
        return

    total_batches = (len(frame_files) + batch_size - 1) // batch_size  # ceiling division

    # Answers that settle the pair immediately: (stored value, printed label).
    decisive_answers = {
        'm': ('incorrect', 'INCORRECT'),
        '2': ('two_subjects', 'TWO SUBJECTS'),
    }

    for batch_idx in range(total_batches):
        batch = frame_files[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        frame_images = [mpimg.imread(os.path.join(camera_frame_dir, frame)) for frame in batch]
        display_images(control_img1, control_img2, subject_img, frame_images)

        is_last = batch_idx == total_batches - 1
        prompt_text = (
            f"Mark for subject {subject_id}, camera {camera_id}, batch {batch_idx+1}/{total_batches}: "
            "[Enter]=correct, 'm'=incorrect, '2'=two subjects in one"
        )
        prompt_text += ": " if is_last else ", or any key for next batch: "
        choice = input(prompt_text).strip().lower()

        if choice in decisive_answers:
            value, label = decisive_answers[choice]
            _record_mark(marks, mark_key, value, label, subject_id, camera_id)
            return
        if choice == "" and is_last:
            _record_mark(marks, mark_key, "correct", "CORRECT", subject_id, camera_id)
            return
        # Enter on an intermediate batch, or any other answer: go to the next batch.

    # Only reached when an unrecognised answer was given on the last batch.
    print(f"No mark recorded for {subject_id}, {camera_id}.")

In [None]:
def process_subjects_by_camera(first_root, second_root, checkpoint, marks, camera_ids=('G2302',)):
    """Visit every subject folder and annotate each camera that is not yet marked.

    Subjects are the sub-directories of ``second_root``, visited in sorted
    order so that checkpoint resumption is deterministic. If ``checkpoint``
    holds a ``subject_id`` that exists, processing starts at that subject. If
    that subject is missing, a warning is printed and processing starts from
    the beginning.

    Each subject needs two control images in ``first_root``:
      * ``{subject_id}_set1_wb0_1_0.rs-image-5.png``
      * ``{subject_id}_set2_wb0_1_0.rs-image-5.png``

    Frames are read from ``second_root/<subject_id>/<camera_id>/frames``.
    Pairs already present in ``marks`` are skipped. The rest are passed to
    ``process_frames``, with the alphabetically first PNG as the
    "First Frame" image.

    Args:
        first_root: Directory holding the control images.
        second_root: Directory with per-subject frame folders.
        checkpoint: Dict from ``load_checkpoint`` (may be empty).
        marks: Existing marks dict; updated in place by ``process_frames``.
        camera_ids: Camera sub-folders to visit for each subject.
    """
    import matplotlib.image as mpimg

    subject_ids = sorted(d for d in os.listdir(second_root)
                         if os.path.isdir(os.path.join(second_root, d)))
    print('Found subject IDs: ', subject_ids)

    start_subject = checkpoint.get('subject_id') if checkpoint else None
    if start_subject:
        if start_subject in subject_ids:
            subject_ids = subject_ids[subject_ids.index(start_subject):]
        else:
            # Previously this case silently skipped every subject.
            print(f"Checkpoint subject {start_subject} not found in {second_root}. "
                  "Starting from the first subject.")

    for subject_id in subject_ids:
        control_img_path1 = os.path.join(first_root, f"{subject_id}_set1_wb0_1_0.rs-image-5.png")
        control_img_path2 = os.path.join(first_root, f"{subject_id}_set2_wb0_1_0.rs-image-5.png")
        if not (os.path.exists(control_img_path1) and os.path.exists(control_img_path2)):
            print(f"Control images for {subject_id} not found in {first_root}. Skipping...")
            continue
        control_img1 = mpimg.imread(control_img_path1)
        control_img2 = mpimg.imread(control_img_path2)
        subject_frame_dir = os.path.join(second_root, subject_id)

        for camera_id in camera_ids:
            camera_frame_dir = os.path.join(subject_frame_dir, camera_id, 'frames')
            if not os.path.exists(camera_frame_dir):
                print(f"Frames for subject {subject_id} camera {camera_id} not found. Skipping...")
                continue
            # Sorted so that frame_files[0] really is the first frame (os.listdir is unordered).
            frame_files = sorted(f for f in os.listdir(camera_frame_dir) if f.endswith('.png'))
            if not frame_files:
                print(f"No PNG frames for subject {subject_id} camera {camera_id}. Skipping...")
                continue
            mark_key = f"{subject_id}_{camera_id}"
            if mark_key in marks:
                print(f"Already marked {mark_key}: {marks[mark_key]}. Skipping...")
                continue
            subject_img = mpimg.imread(os.path.join(camera_frame_dir, frame_files[0]))
            process_frames(subject_id, camera_id, control_img1, control_img2, subject_img,
                           camera_frame_dir, checkpoint, marks)

In [None]:
first_root = '/home/caio.dasilva/datasets/brc2_rotate/'  # Directory with the control images
second_root = '/home/caio.dasilva/datasets/extracted_brc2/'  # Directory with the frames organized by subject/camera/frames

# Resume state: subject/camera pairs already stored in the marks file are skipped,
# and a checkpoint (if present) selects the subject to start from.
# NOTE(review): nothing in this notebook calls save_checkpoint, so a checkpoint file
# only exists if it was written elsewhere.
checkpoint = load_checkpoint()
marks = load_marks_json()

process_subjects_by_camera(first_root, second_root, checkpoint, marks)
print('Done')