In [34]:
# Takes a YOLO-format detection dataset and produces a dataset of cropped images (one crop per bounding box), each labeled with its species-level class ID
!pip install opencv-python

# yolo_images_root = Path("/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/images")
# yolo_labels_root = Path("/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/labels")

import cv2
import os
from tqdm import tqdm
import glob
from pathlib import Path

def parse_annotation(annotation_file):
    """Parse a YOLO-format label file into a list of bounding-box tuples.

    Each non-blank line must contain at least five whitespace-separated
    fields: ``class_id x_center y_center width height``. Only the first five
    fields are used; extra trailing fields on a line are ignored. Blank lines
    (including an entirely empty file) are skipped.

    Args:
        annotation_file: Path to the ``.txt`` label file.

    Returns:
        A list of ``(class_id, x_center, y_center, width, height)`` tuples,
        with ``class_id`` as ``int`` and the other four values as ``float``.
        In YOLO convention the floats are normalized to the image size; that
        is not checked here.

    Raises:
        ValueError: If a non-blank line has fewer than five fields, or a field
            cannot be converted to a number.
    """
    annotations = []
    with open(annotation_file, 'r', encoding='utf-8') as file:
        for line_no, line in enumerate(file, start=1):
            parts = line.split()
            if not parts:
                # Empty label files and trailing blank lines carry no box.
                continue
            if len(parts) < 5:
                raise ValueError(
                    f"{annotation_file}:{line_no}: expected at least 5 fields, got {len(parts)}"
                )
            class_id = int(parts[0])
            x_center, y_center, width, height = (float(p) for p in parts[1:5])
            annotations.append((class_id, x_center, y_center, width, height))
    return annotations

def yolo2cv_bbox(yolo_bbox, width, height):
    """Convert one normalized YOLO box into integer pixel bounds.

    Args:
        yolo_bbox: ``(class_id, x_center, y_center, box_width, box_height)``;
            the class id is ignored.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``(x_min, x_max, y_min, y_max)``. Note the order: the x pair comes
        first, then the y pair. Values are truncated toward zero by ``int()``
        and are NOT clamped to the image, so they can be negative or exceed
        the image size.
    """
    _, cx, cy, bw, bh = yolo_bbox
    half_w = bw / 2
    half_h = bh / 2

    left = int((cx - half_w) * width)
    right = int((cx + half_w) * width)
    top = int((cy - half_h) * height)
    bottom = int((cy + half_h) * height)
    return left, right, top, bottom

def annotate_image(image_dir, image_rel_path, annotation, output_dir):
    """Draw all bounding boxes of one image onto a copy and save it for visual checks.

    Args:
        image_dir: Root directory of the source images.
        image_rel_path: Image path relative to ``image_dir``. The same relative
            path, subdirectories included, is used under ``output_dir``.
        annotation: Sequence of YOLO ``(class_id, x_center, y_center, w, h)``
            tuples, as produced by ``parse_annotation``.
        output_dir: Root directory that receives the annotated copy.

    Raises:
        OSError: If OpenCV cannot read the source image or write the output.
    """
    image_path = os.path.join(image_dir, image_rel_path)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports missing/unreadable files by returning None, not by raising.
        raise OSError(f"Could not read image: {image_path}")
    val_image = image.copy()
    height, width = image.shape[:2]

    for annot in annotation:
        x_min, x_max, y_min, y_max = yolo2cv_bbox(annot, width, height)
        # (0, 0, 255) is red in OpenCV's BGR channel order; 2 px line thickness.
        val_image = cv2.rectangle(val_image, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)

    output_path = os.path.join(output_dir, image_rel_path)
    os.makedirs(Path(output_path).parent, exist_ok=True)
    # cv2.imwrite signals failure via its boolean return value rather than an exception.
    if not cv2.imwrite(output_path, val_image):
        raise OSError(f"Could not write annotated image: {output_path}")

def crop_image(image_dir, image_rel_path, annotations, output_dir, crop_square=True):
    """Save one PNG crop per bounding box of an image.

    Each crop is written to
    ``<output_dir>/<image_rel_path without extension>_crop_<i>_class<class_id>.png``,
    where ``i`` is the box's index in ``annotations``. Box bounds are clamped
    to the image. Boxes that have zero area after clamping (fully outside the
    frame, or degenerate) are skipped, and their index is simply not used.

    Args:
        image_dir: Root directory of the source images.
        image_rel_path: Image path relative to ``image_dir``. Its subdirectory
            structure is mirrored under ``output_dir``.
        annotations: Sequence of YOLO ``(class_id, x_center, y_center, w, h)``
            tuples, as produced by ``parse_annotation``.
        output_dir: Root directory that receives the crops.
        crop_square: Currently unused; kept so existing callers keep working.
            NOTE(review): presumably intended to produce square crops. This is
            not implemented; confirm the intent before wiring it up.

    Raises:
        OSError: If OpenCV cannot read the source image or write a crop.
    """
    image_path = os.path.join(image_dir, image_rel_path)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports missing/unreadable files by returning None, not by raising.
        raise OSError(f"Could not read image: {image_path}")
    height, width = image.shape[:2]
    stem = os.path.splitext(image_rel_path)[0]

    for i, annot in enumerate(annotations):
        class_id = annot[0]
        x_min, x_max, y_min, y_max = yolo2cv_bbox(annot, width, height)

        # Clamp to the image: a negative slice start would index from the END
        # of the array in NumPy, giving a wrong or empty crop.
        x_min, y_min = max(x_min, 0), max(y_min, 0)
        x_max, y_max = min(x_max, width), min(y_max, height)
        if x_max <= x_min or y_max <= y_min:
            # Nothing left to crop; cv2.imwrite cannot encode an empty array.
            continue

        cropped_image = image[y_min:y_max, x_min:x_max]
        output_path = os.path.join(output_dir, f"{stem}_crop_{i}_class{class_id}.png")
        os.makedirs(Path(output_path).parent, exist_ok=True)
        # cv2.imwrite signals failure via its boolean return value rather than an exception.
        if not cv2.imwrite(output_path, cropped_image):
            raise OSError(f"Could not write crop: {output_path}")

def process_yolo_dataset(image_dir, annotation_dir, output_dir, img_type=".png", validation_dir=None):
    """Convert a YOLO detection dataset into per-box image crops.

    Every ``*<img_type>`` file under ``image_dir`` (searched recursively) is
    paired with the label file at the same relative path under
    ``annotation_dir``, with the extension replaced by ``.txt``. Images with no
    label file are skipped. For each labeled image, one crop per box is written
    under ``output_dir``. If ``validation_dir`` is given, a copy of the image
    with all boxes drawn is also written there.

    Args:
        image_dir: Root directory of the YOLO images.
        annotation_dir: Root directory of the YOLO ``.txt`` label files.
        output_dir: Root directory for the cropped images (created if missing).
        img_type: Image file extension to search for, including the dot.
            Matching is literal, so its case must match the files.
        validation_dir: Optional root directory for box-overlay images. When
            ``None``, no overlays are produced.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Escape the root so '[', '*' or '?' in directory names are matched literally;
    # sort so processing order (and progress) is reproducible across runs.
    pattern = os.path.join(glob.escape(image_dir), f"**/*{img_type}")
    img_paths = sorted(glob.glob(pattern, recursive=True))

    for image_file in tqdm(img_paths):
        image_rel_path = os.path.relpath(image_file, image_dir)
        annotation_file = os.path.join(annotation_dir, os.path.splitext(image_rel_path)[0] + '.txt')
        if not os.path.exists(annotation_file):
            # No label file means no boxes to crop for this image.
            continue

        annotations = parse_annotation(annotation_file)
        crop_image(image_dir, image_rel_path, annotations, output_dir)
        # validation_dir defaults to None; passing it through would crash in os.path.join.
        if validation_dir is not None:
            annotate_image(image_dir, image_rel_path, annotations, validation_dir)
        
# Run the conversion on the WHOI RS fish dataset: crops go to the
# classification dataset, box overlays to the "-validation" directory.
image_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/images'
annotation_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/labels'
output_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-classification-dataset/'
validation_output_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-classification-dataset-validation'

process_yolo_dataset(
    image_dir=image_directory,
    annotation_dir=annotation_directory,
    output_dir=output_directory,
    validation_dir=validation_output_directory,
)
print("done")



  1%|▉                                                                                                  | 100/10620 [00:18<32:28,  5.40it/s]

done



