# Clean dataset labels/images (ground truth, GT)

In [36]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
from IPython.display import clear_output, display
import pickle
import json
import ipywidgets as widgets
import threading
import time


In [33]:
def save_checkpoint(subject_id, camera, batch_index, path='checkpoint_clean_gt.json'):
    """Record review progress for one subject/camera pair in the JSON checkpoint file.

    Entries already in the file are preserved; the entry keyed
    ``"{subject_id}_{camera}"`` is created or overwritten with
    ``{'subject_id': ..., 'camera': ..., 'batch_index': ...}``.

    Args:
        subject_id: subject identifier, used in the key and stored in the entry.
        camera: camera identifier, used in the key and stored in the entry.
        batch_index: progress value stored in the entry.
        path: checkpoint file location (defaults to the notebook's checkpoint file).
    """
    checkpoint = {}
    if os.path.exists(path):
        with open(path, 'r') as f:
            checkpoint = json.load(f)

    checkpoint[f"{subject_id}_{camera}"] = {
        'subject_id': subject_id,
        'camera': camera,
        'batch_index': batch_index,
    }

    # Write to a sibling temp file and atomically swap it in. An interrupted
    # save (e.g. a kernel restart mid-write) then never leaves a truncated
    # checkpoint that would make every later json.load fail.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(checkpoint, f, indent=2)
    os.replace(tmp_path, path)

In [20]:
def load_checkpoint(path='checkpoint_clean_gt.json'):
    """Return the saved review checkpoint as a dict, or ``{}`` when the file does not exist.

    Args:
        path: checkpoint file location (defaults to the notebook's checkpoint file).
    """
    if not os.path.exists(path):
        return {}
    with open(path, 'r') as f:
        return json.load(f)

In [21]:
def save_marks_json(marks, fname='marks2302_2.json'):
    """Serialize *marks* to *fname* as JSON indented by two spaces, overwriting the file."""
    with open(fname, 'w') as out_file:
        out_file.write(json.dumps(marks, indent=2))

def load_marks_json(fname='marks2302_2.json'):
    """Read previously saved marks from *fname*; a missing file yields an empty dict."""
    if not os.path.exists(fname):
        return {}
    with open(fname, 'r') as in_file:
        return json.load(in_file)

In [37]:
def display_images(control_img1, control_img2, subject_img, frame_images):
    """Plot the reference images above a numbered grid of frames.

    Layout (15-column GridSpec):
      * row 0: control image set1 in cols 0-2, control image set2 in cols 3-5,
        the first frame in cols 6-8; cols 9-14 hold blank axes.
      * rows 1+: ``frame_images``, ten per row in cols 0-9. Each frame is labelled
        in its bottom-right corner with its index in ``frame_images``; these are
        the numbers the review prompt's start/end ranges refer to.

    Args:
        control_img1, control_img2, subject_img: images passed to ``imshow``.
        frame_images: sequence of images passed to ``imshow``; may be empty.
    """
    frames_per_row = 10
    n_frames = len(frame_images)
    rows = (n_frames + frames_per_row - 1) // frames_per_row  # ceil division
    fig = plt.figure(figsize=(20, (rows + 1) * 2.5))
    gs = gridspec.GridSpec(rows + 1, 15, figure=fig)

    # Header row: three large panels, each spanning three columns.
    headers = (
        (control_img1, slice(0, 3), "Control Image (set1)"),
        (control_img2, slice(3, 6), "Control Image (set2)"),
        (subject_img, slice(6, 9), "First Frame"),
    )
    for img, cols, title in headers:
        ax = fig.add_subplot(gs[0, cols])
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title, fontsize=14, weight='bold')

    # Blank axes fill the rest of the header row.
    for j in range(9, 15):
        fig.add_subplot(gs[0, j]).axis('off')

    # Frame grid; trailing cells in the last row stay blank.
    for idx in range(rows * frames_per_row):
        row, col = divmod(idx, frames_per_row)
        ax = fig.add_subplot(gs[row + 1, col])
        ax.axis('off')
        if idx < n_frames:
            ax.imshow(frame_images[idx])
            ax.text(0.95, 0.05, str(idx), color='white', fontsize=10,
                    transform=ax.transAxes, ha='right', va='bottom', weight='bold')

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly. This function runs once per batch, and
    # backends that do not auto-close after show() would otherwise keep every
    # large figure alive.
    plt.close(fig)

In [38]:
def _list_png_frames(frame_dir):
    """Return the names of the ``.png`` files in *frame_dir*, sorted lexicographically."""
    return sorted(f for f in os.listdir(frame_dir) if f.endswith('.png'))


def _prompt_review_command(subject_id, camera_id, batch_number, total_batches, batch_len):
    """Prompt until the user enters a valid review command for a batch of *batch_len* frames.

    Returns ``('c', None, None)``, or ``(action, start, end)`` with ``action`` in
    ``{'k', 'r'}`` and ``0 <= start <= end < batch_len``.
    """
    prompt_text = (
        f"Mark for subject {subject_id}, camera {camera_id}, batch {batch_number}/{total_batches}: "
        "'c' for correct (keep all), 'k' to keep a range, or 'r' to remove a range, followed by two numbers (start and end of range): "
    )
    while True:
        tokens = input(prompt_text).strip().split()
        if tokens == ['c']:
            return 'c', None, None
        if len(tokens) == 3 and tokens[0] in ('k', 'r'):
            try:
                start, end = int(tokens[1]), int(tokens[2])
            except ValueError:
                # int() raises on non-numeric text; re-prompt instead of
                # aborting the whole review session.
                print("Invalid input. Please enter 'c', or 'k'/'r' followed by two numbers (start and end of range).")
                continue
            if 0 <= start <= end < batch_len:
                return tokens[0], start, end
            print("Invalid range. Please try again.")
            continue
        print("Invalid input. Please enter 'c', or 'k'/'r' followed by two numbers (start and end of range).")


def _remove_frames(frame_dir, names):
    """Delete each file in *names* from *frame_dir*, printing (not raising) on failure."""
    for name in names:
        try:
            os.remove(os.path.join(frame_dir, name))
        except OSError as e:
            print(f"Error removing file {name}: {e}")


def process_frames(subject_id, camera_id, control_img1, control_img2, subject_img, camera_frame_dir, checkpoint, marks):
    """Interactively review one subject/camera's PNG frames in batches, deleting rejected ones.

    Frames are shown up to 500 at a time (via ``display_images``) next to the
    control images. For each batch the user enters:

    * ``c``           -- keep every frame in the batch and move on;
    * ``r START END`` -- delete batch-relative frames START..END (inclusive) and move on;
    * ``k START END`` -- keep batch-relative frames START..END, delete the rest of
      this batch, and stop reviewing this subject/camera. Later batches are left
      on disk unreviewed.

    Invalid input re-prompts without reloading the batch. Once the review ends
    (all frames seen, or ``k`` used), a checkpoint entry for
    ``"{subject_id}_{camera_id}"`` is written with the 0-based index of the last
    batch reviewed.

    Args:
        subject_id, camera_id: identify the recording; used in prompts and the checkpoint key.
        control_img1, control_img2, subject_img: images forwarded to ``display_images``.
        camera_frame_dir: directory holding the ``*.png`` frames; files are deleted in place.
        checkpoint, marks: accepted for call compatibility; not read or modified here.
            NOTE(review): nothing in this function records a mark, so ``marks``
            never grows from a review -- confirm whether that is intended.
    """
    batch_size = 500
    last_reviewed = None  # name of the last frame of the most recently finished batch
    batches_done = 0

    while True:
        # Deleting frames shifts positions in the sorted listing. The next batch
        # is therefore chosen by name (frames sorting after the last reviewed
        # one), not by index. Index-based paging would silently skip as many
        # unseen frames as were just removed.
        remaining = [f for f in _list_png_frames(camera_frame_dir)
                     if last_reviewed is None or f > last_reviewed]
        if not remaining:
            break
        batch = remaining[:batch_size]
        total_batches = batches_done + (len(remaining) + batch_size - 1) // batch_size

        frame_images = [mpimg.imread(os.path.join(camera_frame_dir, name)) for name in batch]
        display_images(control_img1, control_img2, subject_img, frame_images)

        action, start, end = _prompt_review_command(
            subject_id, camera_id, batches_done + 1, total_batches, len(batch))
        batches_done += 1

        if action == 'c':
            print(f"Marked subject {subject_id}, camera {camera_id} as correct for this batch. Showing next batch.")
        elif action == 'r':
            _remove_frames(camera_frame_dir, batch[start:end + 1])
            print(f"Removed frames {start} to {end} for subject {subject_id}, camera {camera_id}.")
        else:  # 'k'
            _remove_frames(camera_frame_dir, batch[:start] + batch[end + 1:])
            print(f"Kept frames {start} to {end} and removed the rest for subject {subject_id}, camera {camera_id}.")
            break  # 'k' ends the review of this subject/camera

        last_reviewed = batch[-1]

    # Checkpoint only when the review is finished. The caller skips any
    # subject/camera already present in the checkpoint, so a per-batch entry
    # would make an interrupted review be skipped (not resumed) on restart.
    if batches_done:
        save_checkpoint(subject_id, camera_id, batches_done - 1)

In [39]:
def process_subjects_by_camera(first_root, second_root, checkpoint, marks, camera_ids=None):
    """Walk every subject directory in *second_root* and review each camera's frames.

    Subjects are processed in sorted order. A subject is reviewed only if both
    control images ``{subject}_set1_wb0_1_0.rs-image-5.png`` and
    ``{subject}_set2_wb0_1_0.rs-image-5.png`` exist in *first_root*. Each camera's
    frames are read from ``second_root/<subject>/<camera>/frames/*.png``, and the
    lexicographically first frame is shown as the "First Frame" reference in
    ``process_frames``.

    A subject/camera is skipped when its ``"{subject}_{camera}"`` key is in
    *checkpoint* or *marks*, when its frames directory is missing, or when it has
    no PNG frames. Control images are decoded only if some camera of the subject
    actually needs review.

    Args:
        first_root: directory containing the control images.
        second_root: directory containing one sub-directory per subject.
        checkpoint: dict from ``load_checkpoint`` (keys ``"{subject}_{camera}"``).
        marks: dict from ``load_marks_json`` (looked up with the same key format).
        camera_ids: cameras to review per subject; defaults to ``['G2302']``.
    """
    if camera_ids is None:
        camera_ids = ['G2302']
    # os.listdir order is arbitrary; sort so runs are reproducible.
    subject_ids = sorted(d for d in os.listdir(second_root) if os.path.isdir(os.path.join(second_root, d)))
    print('Found subject IDs: ', subject_ids)

    # save_checkpoint only writes "{subject}_{camera}" entries, so a top-level
    # 'subject_id' can only come from a differently shaped (presumably older)
    # checkpoint. When present, subjects are skipped until that one is reached.
    start_subject = checkpoint.get('subject_id') if checkpoint else None
    resume = False

    for subject_id in subject_ids:
        if start_subject and subject_id != start_subject and not resume:
            continue
        resume = True

        control_img_path1 = os.path.join(first_root, f"{subject_id}_set1_wb0_1_0.rs-image-5.png")
        control_img_path2 = os.path.join(first_root, f"{subject_id}_set2_wb0_1_0.rs-image-5.png")
        if not (os.path.exists(control_img_path1) and os.path.exists(control_img_path2)):
            print(f"Control images for {subject_id} not found in {first_root}. Skipping...")
            continue
        controls = None  # decoded lazily, only if a camera actually needs review
        subject_frame_dir = os.path.join(second_root, subject_id)

        for camera_id in camera_ids:
            mark_key = f"{subject_id}_{camera_id}"
            if checkpoint and mark_key in checkpoint:
                print(f"Checkpoint found for {mark_key}, skipping...")
                continue
            camera_frame_dir = os.path.join(subject_frame_dir, camera_id, 'frames')
            if not os.path.exists(camera_frame_dir):
                print(f"Frames for subject {subject_id} camera {camera_id} not found. Skipping...")
                continue
            # Sorted so frame_files[0] is really the first frame.
            frame_files = sorted(f for f in os.listdir(camera_frame_dir) if f.endswith('.png'))
            if not frame_files:
                print(f"No PNG frames for subject {subject_id} camera {camera_id}. Skipping...")
                continue
            if mark_key in marks:
                print(f"Already marked {mark_key}: {marks[mark_key]}. Skipping...")
                continue

            if controls is None:
                controls = (mpimg.imread(control_img_path1), mpimg.imread(control_img_path2))
            subject_img = mpimg.imread(os.path.join(camera_frame_dir, frame_files[0]))
            process_frames(subject_id, camera_id, controls[0], controls[1], subject_img,
                           camera_frame_dir, checkpoint, marks)

In [None]:
# Root holding the per-subject control images ({subject}_set1/_set2_wb0_1_0.rs-image-5.png).
first_root = '/home/caio.dasilva/datasets/brc2_rotate/'
# Root laid out as <subject>/<camera>/frames/*.png.
second_root = '/home/caio.dasilva/datasets/extracted_brc2/'

# Restore state persisted by earlier sessions before starting the review.
marks = load_marks_json()
checkpoint = load_checkpoint()

process_subjects_by_camera(first_root, second_root, checkpoint, marks)
print('Done')