In [None]:
# Takes a YOLO detection dataset and produces a dataset of cropped images labelled with species-level classes
!pip install opencv-python

# yolo_images_root = Path("/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/images")
# yolo_labels_root = Path("/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/labels")

import cv2
import os
from tqdm import tqdm
import glob
from pathlib import Path

In [None]:
def parse_annotation(annotation_file):
    """Read a YOLO label file into a list of boxes.

    Each non-blank line must start with ``class_id x_center y_center width height``;
    any further fields on the line are ignored. Blank or whitespace-only lines
    (e.g. a trailing newline at end of file) are skipped.

    Args:
        annotation_file: Path to the ``.txt`` label file.

    Returns:
        List of ``(class_id, x_center, y_center, width, height)`` tuples in file
        order, with ``class_id`` as ``int`` and the remaining values as ``float``.

    Raises:
        ValueError: If a non-blank line has fewer than five fields, or a field
            cannot be converted to the expected number type.
    """
    annotations = []
    with open(annotation_file, 'r') as file:
        for line_no, line in enumerate(file, start=1):
            parts = line.split()
            if not parts:
                # Blank line: nothing to parse (indexing parts[0] would raise).
                continue
            if len(parts) < 5:
                raise ValueError(
                    f"{annotation_file}:{line_no}: expected at least 5 fields, got {len(parts)}"
                )
            class_id = int(parts[0])
            x_center, y_center, width, height = (float(p) for p in parts[1:5])
            annotations.append((class_id, x_center, y_center, width, height))
    return annotations

def yolo2cv_bbox(yolo_bbox, width, height):
    """Convert one normalized YOLO box into integer pixel coordinates.

    Args:
        yolo_bbox: ``(class_id, x_center, y_center, box_width, box_height)`` where the
            center and size are fractions of the image size; ``class_id`` is ignored.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``(x_min, x_max, y_min, y_max)`` -- note the x, x, y, y ordering. Each value is
        truncated toward zero by ``int`` and is NOT clipped, so it can lie outside
        ``[0, width]`` / ``[0, height]``.
    """
    _, cx, cy, bw, bh = yolo_bbox
    half_w = bw / 2
    half_h = bh / 2
    left, right = int((cx - half_w) * width), int((cx + half_w) * width)
    top, bottom = int((cy - half_h) * height), int((cy + half_h) * height)
    return (left, right, top, bottom)

def annotate_image(image_dir, image_rel_path, annotation, output_dir, color=(0, 0, 255), thickness=2):
    """Draw every box of one image and save the result for visual QA.

    Args:
        image_dir: Root directory of the source images.
        image_rel_path: Image path relative to ``image_dir``; the annotated copy is
            written to the same relative path under ``output_dir``.
        annotation: Iterable of ``(class_id, x_center, y_center, width, height)``
            tuples with normalized coordinates, as produced by ``parse_annotation``.
        output_dir: Root directory for the annotated images (created as needed).
        color: Box color passed to ``cv2.rectangle``; ``cv2.imread`` yields BGR
            images, so the default ``(0, 0, 255)`` draws red.
        thickness: Box line thickness in pixels.

    Raises:
        OSError: If the source image cannot be read or the output cannot be written.
    """
    image_path = os.path.join(image_dir, image_rel_path)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    height, width = image.shape[:2]

    # `image` is local, so drawing on it in place is safe.
    for annot in annotation:
        x_min, x_max, y_min, y_max = yolo2cv_bbox(annot, width, height)
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, thickness)

    output_path = os.path.join(output_dir, image_rel_path)
    os.makedirs(Path(output_path).parent, exist_ok=True)
    if not cv2.imwrite(output_path, image):
        raise OSError(f"Could not write image: {output_path}")

def crop_image(image_dir, image_rel_path, annotations, output_dir, crop_square=True):
    """Save one PNG crop per box of an image.

    Crop ``i`` is written to
    ``<output_dir>/<image_rel_path without extension>_crop_<i>_class<class_id>.png``;
    ``crop_stats`` parses the class id back out of this name. Boxes are clipped to
    the image bounds, and boxes that are empty after clipping are skipped (``i``
    still counts them, so indices stay aligned with the annotation list).

    Args:
        image_dir: Root directory of the source images.
        image_rel_path: Image path relative to ``image_dir``.
        annotations: Iterable of ``(class_id, x_center, y_center, width, height)``
            tuples with normalized coordinates, as produced by ``parse_annotation``.
        output_dir: Root directory for the crops (subdirectories created as needed).
        crop_square: Currently unused; crops are the clipped box as annotated.
            Kept for backward compatibility with existing callers.

    Raises:
        OSError: If the source image cannot be read or a crop cannot be written.
    """
    image_path = os.path.join(image_dir, image_rel_path)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    height, width = image.shape[:2]
    stem = os.path.splitext(image_rel_path)[0]

    for i, annot in enumerate(annotations):
        class_id = annot[0]
        x_min, x_max, y_min, y_max = yolo2cv_bbox(annot, width, height)
        # Clip to the image: a negative slice start would wrap around in NumPy and
        # yield an empty (or wrong) crop instead of a truncated one.
        x_min, x_max = max(x_min, 0), min(x_max, width)
        y_min, y_max = max(y_min, 0), min(y_max, height)
        if x_max <= x_min or y_max <= y_min:
            # Degenerate box; cv2.imwrite cannot encode an empty image.
            continue

        cropped_image = image[y_min:y_max, x_min:x_max]
        output_path = os.path.join(output_dir, f"{stem}_crop_{i}_class{class_id}.png")
        os.makedirs(Path(output_path).parent, exist_ok=True)
        if not cv2.imwrite(output_path, cropped_image):
            raise OSError(f"Could not write image: {output_path}")

def process_yolo_dataset(image_dir, annotation_dir, output_dir, img_type=".png", validation_dir=None):
    """Crop every labelled image of a YOLO dataset into per-box images.

    Images are found recursively under ``image_dir`` by the extension ``img_type``.
    The label file for ``<image_dir>/<rel>/<name><img_type>`` is
    ``<annotation_dir>/<rel>/<name>.txt``; images without a label file are skipped.

    Args:
        image_dir: Root directory of the images.
        annotation_dir: Root directory of the YOLO ``.txt`` label files, mirroring
            the layout of ``image_dir``.
        output_dir: Root directory for the crops (created if missing).
        img_type: Image file extension to search for, including the dot.
        validation_dir: If given, a copy of each labelled image with its boxes drawn
            is written here. If None (the default), no validation images are written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so the processing order does not depend on filesystem listing order.
    img_paths = sorted(glob.glob(os.path.join(image_dir, f"**/*{img_type}"), recursive=True))

    for image_file in tqdm(img_paths):
        image_rel_path = os.path.relpath(image_file, image_dir)
        annotation_file = os.path.join(annotation_dir, os.path.splitext(image_rel_path)[0] + '.txt')
        if not os.path.exists(annotation_file):
            # Unlabelled image: nothing to crop.
            continue

        annotations = parse_annotation(annotation_file)
        crop_image(image_dir, image_rel_path, annotations, output_dir)
        # QA rendering is optional; os.path.join(None, ...) would raise TypeError.
        if validation_dir is not None:
            annotate_image(image_dir, image_rel_path, annotations, validation_dir)


In [None]:
# Perform the crop extraction
# Example usage
# image_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/images'
# annotation_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-yolo-dataset/labels'
# output_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-classification-dataset/'
# validation_output_directory = '/srv/warplab/shared/datasets/WHOI_RS_Fish_Detector/whoi-rsi-fish-detection-species-classification-dataset-validation'

# Every input and output of this run lives under one root; edit DATA_ROOT to relocate.
DATA_ROOT = '/media/data/warp_data/wrsi-datasets'
YOLO_DATASET_DIR = os.path.join(DATA_ROOT, 'whoi-rsi-fish-detection-species-yolo-dataset')

image_directory = os.path.join(YOLO_DATASET_DIR, 'images')
annotation_directory = os.path.join(YOLO_DATASET_DIR, 'labels')
output_directory = os.path.join(DATA_ROOT, 'whoi-rsi-fish-detection-species-crops-dataset')
validation_output_directory = os.path.join(DATA_ROOT, 'whoi-rsi-fish-detection-species-validations')

process_yolo_dataset(image_directory, annotation_directory, output_directory, validation_dir=validation_output_directory)
print("done")

In [None]:
import labelbox as lb
import labelbox.types as lb_types
import uuid
import base64
import requests

# Setup Labelbox naming and ID schemes
ONTOLOGY_ID = "clqo6bd8v0jc407ybc1r9ehlb"
PROJECT_ID = 'clqoh3ylw1o8s070hd6ch5z7o' # WHOI RSI USVI Fish
# PROJECT_ID = 'clqo7auln0mpo07wphorp0t2e' # Test WHOI RSI USVI Fish
DATASET_ID = "clqh7v7qi001r07886j6aws7i"

# Setup client. Prefer the LABELBOX_API_KEY environment variable so the key does not
# have to sit in a plaintext file next to the notebook; fall back to that file otherwise.
API_KEY = os.environ.get("LABELBOX_API_KEY", "").strip()
if not API_KEY:
    with open("labelbox_api_key.txt", "r") as f:
        API_KEY = f.read().strip()
client = lb.Client(api_key=API_KEY)

# Get ontology (its tools provide the species classification options used below)
print("===ONTOLOGY DETAILS===")
ontology = client.get_ontology(ONTOLOGY_ID)
print("Name: ", ontology.name)
tools = ontology.tools()

# Get project
print("\n===PROJECT DETAILS===")
project = client.get_project(PROJECT_ID)
print("Name: ", project.name)

# Get dataset
dataset = client.get_dataset(DATASET_ID)
print("\n===DATASET DETAILS===")
print("Name: ", dataset.name)

In [None]:
# Enumerate species labels for YOLO class formatting.
# Class 0 is the generic "fish" label; the options of the ontology's first tool's
# first classification follow, numbered in ontology order.
ordered_class_names = ["fish"]
classes_option = {"fish": {"label": "Fish", "value": "fish"}}
classes_enum = {"fish": 0}

for species_option in tools[0].classifications[0].options:
    species_value = species_option.value
    classes_option[species_value] = species_option
    classes_enum[species_value] = len(ordered_class_names)
    ordered_class_names.append(species_value)
print(ordered_class_names)

In [None]:
# Count the crops of each class (class id parsed from the crop file names)
# and print the classes sorted by how many crops they have

from pathlib import Path

# Per-class accumulators filled by crop_stats(): number of crops and crop file paths.
cls_counts = [0 for _ in ordered_class_names]
cls_images = {name: [] for name in ordered_class_names}

def crop_stats(image_dir, img_type=".png"):
    """Tally crop images per class into the module-level ``cls_counts`` and ``cls_images``.

    Crops are found recursively under ``image_dir`` by extension. The class id is
    the integer after the last ``_class`` in the file stem, matching the
    ``<stem>_crop_<i>_class<id>.png`` names written by ``crop_image``.

    Results are accumulated into the existing globals, so calling this twice
    without re-initialising ``cls_counts`` / ``cls_images`` double counts.

    Args:
        image_dir: Root directory containing the crop images.
        img_type: Image file extension to search for, including the dot.

    Raises:
        ValueError: If a file name does not end in ``_class<non-negative int>``.
        IndexError: If a parsed class id is not a valid index into ``cls_counts``.
    """
    # Sorted so the per-class path lists (and any grids built from them) are
    # reproducible; raw glob order depends on the filesystem.
    img_paths = sorted(glob.glob(os.path.join(image_dir, f"**/*{img_type}"), recursive=True))

    for img_path in tqdm(img_paths):
        img_stem = Path(img_path).stem
        _, sep, cls_suffix = img_stem.rpartition("_class")
        if not sep or not cls_suffix.isdecimal():
            raise ValueError(f"Cannot parse class id from crop file name: {img_path}")
        cls_id = int(cls_suffix)

        cls_counts[cls_id] += 1
        cls_images[ordered_class_names[cls_id]].append(img_path)

# Tally the crops written by the extraction step, then list (class, count) pairs
# from most to least frequent.
crops_directory = '/media/data/warp_data/wrsi-datasets/whoi-rsi-fish-detection-species-crops-dataset'
crop_stats(crops_directory)

cls_counts_list = sorted(
    zip(ordered_class_names, cls_counts),
    key=lambda name_count: name_count[1],
    reverse=True,
)
print(cls_counts_list)


In [None]:
# Tile all the crops of each class into one grid image per class and save the grids
!pip install torchvision

from torchvision.utils import make_grid
from torchvision.io import read_image
from torchvision import transforms
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F

plt.rcParams.update({"savefig.bbox": "tight"})  # trim surrounding whitespace in plt.savefig output

def show(imgs):
    """Display one image tensor, or a list of them, side by side in a single row.

    Args:
        imgs: A tensor accepted by ``torchvision.transforms.functional.to_pil_image``
            (e.g. the C x H x W grid returned by ``make_grid``), or a list of them.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots`` so the caller can save or
        close the figure.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for ax, img in zip(axs[0], imgs):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig, axs

def grid_from_crop_paths(crop_paths, output_path=None, img_transforms=None, output_img_size=(144, 144), display_image=False, nrow=8):
    """Load images of possibly different sizes, standardize them and tile them into one grid.

    Args:
        crop_paths: Paths of the images to tile, in grid order.
        output_path: If given, the grid is saved to this path with ``plt.imsave``.
        img_transforms: Transform applied to each loaded image tensor. It must
            produce tensors of identical shape so they can be tiled. Defaults to
            ``transforms.Resize(output_img_size)``.
        output_img_size: ``(height, width)`` used by the default resize.
        display_image: If True, also display the grid inline via ``show``.
        nrow: Number of images per grid row, passed to ``make_grid`` (8 is its default).

    Returns:
        None if ``crop_paths`` is empty or no image could be read. Otherwise the grid
        as a C x H x W tensor, or as an H x W x C numpy array when ``output_path`` is given.
    """
    if len(crop_paths) == 0:
        return

    # Compose transforms
    if img_transforms is None:
        img_transforms = transforms.Compose([transforms.Resize(output_img_size)])

    # Load and standardize images; unreadable files are reported and skipped.
    imgs = []
    for img_path in crop_paths:
        try:
            img = read_image(img_path)
        except Exception as exc:
            print("Error reading: ", img_path, exc)
            continue
        imgs.append(img_transforms(img))

    if not imgs:
        # make_grid cannot build a grid from an empty list.
        print("No readable images; no grid produced.")
        return

    # Make grid
    grid = make_grid(imgs, nrow=nrow)

    # Show grid
    if display_image:
        show(grid)

    # Save grid (plt.imsave expects channels last)
    if output_path:
        grid = grid.permute(1, 2, 0).numpy()
        plt.imsave(output_path, grid)

    return grid

# Write one grid image per class that has at least one crop.
class_grids_directory = "/media/data/warp_data/wrsi-datasets/whoi-rsi-fish-detection-species-crops-validation-grids"
os.makedirs(class_grids_directory, exist_ok=True)

for cls_name, cls_count in tqdm(cls_counts_list):
    if not cls_count > 0:
        continue

    # Bound at module scope under this exact name; keep it, other code may read it.
    img_paths = cls_images[cls_name]

    gridded_img_name = "grid_{}_".format(len(img_paths)) + cls_name + ".png"
    gridded_img_path = os.path.join(class_grids_directory, gridded_img_name)
    grid_from_crop_paths(img_paths, output_path=gridded_img_path)

print("done")