In [136]:
# Use a pipeline as a high-level helper
from transformers import DetrImageProcessor, DetrForObjectDetection
import os
import cv2
from PIL import Image
import torch
import numpy as np
from jupyter_bbox_widget import BBoxWidget
import json
import ipywidgets as widgets
from tqdm import tqdm


In [162]:
base_dir = "ball2_soft"
output_base_dir = os.path.join(base_dir, "output")
image_dir = os.path.join(base_dir, "IMAGE")
depth_dir = os.path.join(base_dir, "DEPTH")
visualize = False

if not os.path.exists(output_base_dir):
    os.makedirs(output_base_dir)
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
if not os.path.exists(depth_dir):
    os.makedirs(depth_dir)

In [163]:
# Organize directory
filenames = sorted(os.listdir(base_dir))

for f in filenames:
    if "depth" in f:
        os.rename(os.path.join(base_dir, f), os.path.join(depth_dir, f))
    elif "left" in f:
        os.rename(os.path.join(base_dir, f), os.path.join(image_dir, f))

In [164]:
annotations = {}
find_manually = []

In [165]:
images = os.listdir(image_dir)
for idx, image_path in tqdm(enumerate(images)):
    full_image_path = os.path.join(image_dir, image_path)
    image2 = cv2.imread(full_image_path)
    image = Image.open(full_image_path).convert("RGB")

    # you can specify the revision tag if you don't want the timm dependency
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # convert outputs (bounding boxes and class logits) to COCO API
    # let's only keep detections with score > 0.9
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)[0]

    temp_box = None
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        if model.config.id2label[label.item()] == "orange" or model.config.id2label[label.item()] == "sports ball":
            temp_box = box
            annotations[image_path] = box
            # image2 = cv2.rectangle(image2, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255, 0, 0), 2)
            # plt.imshow(image2,cmap='gray')
            # plt.show()
            
    if temp_box == None:
        find_manually.append(image_path)


46it [01:03,  1.37s/it]


In [166]:
manual_annotations = {}

In [167]:
# a progress bar to show how far we got
w_progress = widgets.IntProgress(value=0, max=len(find_manually), description='Progress')
# the bbox widget
w_bbox = BBoxWidget(
    image = os.path.join(image_dir, find_manually[0])
)

# combine widgets into a container
w_container = widgets.VBox([
    w_progress,
    w_bbox,
])

IndexError: list index out of range

In [168]:
# when Skip button is pressed we move on to the next file
@w_bbox.on_skip
def skip():
    w_progress.value += 1
    # open new image in the widget
    image_file = find_manually[w_progress.value]
    w_bbox.image = os.path.join(image_dir, image_file)
    # here we assign an empty list to bboxes but 
    # we could also run a detection model on the file
    # and use its output for creating inital bboxes
    w_bbox.bboxes = [] 
    if image_file not in manual_annotations:
        manual_annotations[image_file] = w_bbox.bboxes

# when Submit button is pressed we save current annotations
# and then move on to the next file
@w_bbox.on_submit
def submit():
    image_file = find_manually[w_progress.value]
    # save annotations for current image
    manual_annotations[image_file] = w_bbox.bboxes
    # with open("full_annotations.json", 'w') as f:
    #     json.dump(manual_annotations, f, indent=4)
    # move on to the next file
    skip()

In [169]:
w_container

VBox(children=(IntProgress(value=0, description='Progress', max=1), BBoxWidget(colors=['#1f77b4', '#ff7f0e', '…

In [170]:
edited_manual_annotations = {}
for key in manual_annotations:
    data = manual_annotations[key]
    if len(data) == 0:
        val = []
    else:
        data = data[0]
        val = [data['x'], data['y'], data['x'] + data['width'], data['y'] + data['height']]
    edited_manual_annotations[key] = val

In [171]:
full_annotations = edited_manual_annotations | annotations

In [172]:
with open(os.path.join(output_base_dir, f"{base_dir}_full_annotations.json"), "w") as f:
    json.dump(full_annotations, f, indent=4)

In [173]:
import json
import cv2
import os

def draw_center_on_image(image_path, json_data, output_dir):
    # Load the image from the given path
    image = cv2.imread(image_path)
    
    # Extract the filename from the image path (assumes the image path includes the filename)
    filename = os.path.basename(image_path)
    
    # Check if the filename is in the JSON data
    if filename in json_data and len(json_data[filename]) > 0:
        
        # Get the bounding box coordinates from the JSON data
        x1, y1, x2, y2 = json_data[filename]
        
        # Calculate the center of the bounding box
        center_x = int((x1 + x2) // 2)
        center_y = int((y1 + y2) // 2)
        
        # Draw a red dot (center) on the image
        cv2.circle(image, (center_x, center_y), 5, (0, 0, 255), -1)  # Red circle with radius 5
        
        
        # Save the modified image to the output directory
        output_image_path = os.path.join(output_dir, filename)
        cv2.imwrite(output_image_path, image)
    else:
        print(f"Bounding box not found for {filename} in the JSON data.")

# Example usage:
image_dir = os.path.join(base_dir, "IMAGE")  # Path to the input image
json_file = os.path.join(output_base_dir, f"{base_dir}_full_annotations.json")  # Path to the JSON file with bounding boxes
if visualize:
    output_dir = os.path.join(base_dir, f"{base_dir}_VIS_ANN")  # Directory to save the modified images
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Load the JSON data
    with open(json_file, 'r') as f:
        bounding_boxes = json.load(f)

    # Call the function to process the image
    for image_path in tqdm(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_path)
        draw_center_on_image(image_path, bounding_boxes, output_dir)


100%|██████████| 46/46 [00:05<00:00,  7.86it/s]


In [174]:
import json
import re
def cvt(key):
    match = re.match(r"left(\d+).png", key)
    if match:
        # Convert the matched string to an integer and return
        return int(match.group(1))
    return -1

with open(json_file) as f:
    data = json.load(f)
    
new_data = {}

for key in sorted(data.keys()):
    new_key = int(cvt(key))
    new_data[new_key] = data[key]
    
with open(json_file, "w") as f:
    json.dump(new_data, f, indent=4)