# YOLO style dataset parser

### The generated output includes the following annotation data:
* Classes
* Bounding Boxes

### Example application(s) (as demonstrated in Plum et al. 2023):
* [YOLO v4](https://github.com/AlexeyAB/darknet) _(multi-animal detection and tracking)_

### Output structure:
* target_dir
    * data
      * obj
        * images
        * annotation files
      * ###-train.txt _(contains list of training samples)_
      * ###-test.txt _(contains list of validation samples)_
      * obj.data _(relative dataset path info)_
      * obj.names _(all class names)_
    
### Notes:

* While we used [YOLO v4](https://github.com/AlexeyAB/darknet) in all our examples, the dataset structure is generally applicable to most YOLO detectors (excluding those that additionally infer pose e.g. YOLO v7)
* We also demonstrate how trained models can be used for real-time multi animal tracking applications in our [Blender Multi-Animal Tracking and Pose Estimation Addon - **OmniTrax**](https://github.com/FabianPlum/OmniTrax)

In [None]:
import cv2
import numpy as np
import matplotlib as plt
import pathlib

from os import listdir
from os.path import isfile, join

import json
import threading
import queue
import sys
import os
import time

### Required parameters

Specify the location of your **generated dataset** and in which **output directory** you wish to save it.

**Notes:**
* do not include trailing forward slashes in your paths (see examples below)
* Your **dataset** name should **NOT include underscores** as they are used to separate passes into their categories. Instead, use hyphens in your naming convention where applicable.

In [None]:
# Input folder holding the generated dataset and output folder for the YOLO dataset.
dataset_location = "../example_data/input-multi"
target_dir = "../example_data/YOLO"

# Keypoint labels to ignore. By default, every keypoint is used when bounding boxes
# are recomputed; this example leaves out all wing keypoints.
# See the base_rig documentation for the naming conventions.
omit_labels = [
    'w_1_l', 'w_1_l_end',
    'w_2_l', 'w_2_l_end',
    'w_1_r', 'w_1_r_end',
    'w_2_r', 'w_2_r_end',
]

# Only relevant for cross-validation with specific naming conventions.
# Leaving dataset_name as None derives the name from the original BatchData file.
dataset_name = "eq"
dataset_img_name = "Atta-hist-eq"

### Optional parameters

In [None]:
# DEBUG = True displays the processing result of every image; parsing then runs
# with a single worker thread instead of in parallel.
DEBUG = False

# When True, all instances are grouped into one class, overriding multiple classes.
enforce_single_class = False

# Passed on as the k_fold argument of createCustomFiles.
# NOTE(review): looks like [number of splits, index of the split] - confirm in the helper.
cross_validation_split = [5, 0]

# Minimum fraction of a bounding box that must be filled by the respective subject
# before its visibility is considered sufficient (0.05 = at least 5 %).
# WARNING: At the moment the ID shown in segmentation maps does not always correspond
# to the ID in the data file (off by 1).
visibility_threshold = 0.05

# Additionally plot the points in the data files to check joint locations.
plot_joints = False

# Remember to define an export folder when saving out your dataset.
generate_dataset = True

# Centre the bounding box on the individual instead of letting its orientation
# influence it. The ground truth in real recordings is annotated the same way, so
# this should boost the average accuracy.
# enforced_centre lists the keypoint(s) between (or on) which the centre is placed;
# in the default insect rig that is between thorax (b_t) and abdomen (b_a_1).
enforce_centred_bboxes = False
enforced_centre = ["b_t", "b_a_1"]

# Alternatively, draw tighter bounding boxes without enforced centres, based on the
# 2D keypoints listed in "labels".
# Centred OR tight: this option will overwrite "enforce_centred" if True.
enforce_tight_bboxes = False

The following lines load the generated dataset from your drive and prepare it for the multi-threaded parsing process.

In [None]:
# Collect every regular file in the dataset folder.
# os.listdir() returns entries in arbitrary order, but the parsing step pairs the
# data / image / ID files purely by their position in the lists built below.
# Sorting by the underscore-separated stem (extension removed) puts all passes of
# one frame next to each other and gives every list the same frame order.
all_files = sorted(
    (f for f in listdir(dataset_location) if isfile(join(dataset_location, f))),
    key=lambda name: name.rsplit(".", 1)[0].split("_"),
)

# next, sort files into images, depth maps, segmentation maps, data, and colony info
# we only need the location and name of the data files, as all passes follow the same naming convention
dataset_data = []
dataset_img = []
dataset_ID = []
dataset_depth = []
dataset_norm = []
dataset_colony = None

for file in all_files:
    loc = dataset_location + "/" + file
    file_info = file.split("_")

    # files without an underscore do not follow the dataset naming convention
    # (e.g. OS metadata files); skip them instead of failing on file_info[1]
    if len(file_info) < 2:
        continue

    if file_info[1] == "BatchData":
        dataset_colony = loc

    elif len(file_info) == 2:
        # images are available in various formats, but annotation data is always written as json files
        if file_info[-1].split(".")[-1] == "json":
            dataset_data.append(loc)
        else:
            dataset_img.append(loc)

    else:
        # the third name component identifies the render pass
        pass_name = file_info[2].split(".")[0]
        if pass_name == "ID":
            dataset_ID.append(loc)
        elif pass_name == "depth":
            dataset_depth.append(loc)
        elif pass_name == "norm":
            dataset_norm.append(loc)

print("Found",len(dataset_data),"samples...")

# Load the colony (BatchData) JSON file; it provides the subject variations and the seed.
if dataset_colony is None:
    # fail with a clear message instead of the TypeError raised by open(None)
    raise FileNotFoundError("No BatchData file found in " + str(dataset_location))

# the context manager closes the file even if parsing the JSON fails
with open(dataset_colony) as colony_file:
    colony = json.load(colony_file)

if not enforce_single_class:
    # unique class names with spaces replaced by underscores,
    # sorted for repeatability between datasets
    subject_class_names = sorted({
        variation["Class"].replace(" ", "_")
        for variation in colony["Subject Variations"].values()
    })

    # map every class name to its class ID (= position in the sorted list)
    subject_classes = {
        class_name: class_id
        for class_id, class_name in enumerate(subject_class_names)
    }
else:
    # single-class mode: one entry, so exactly one class is reported below
    # NOTE(review): what createCustomFiles writes for this placeholder entry should be confirmed
    subject_class_names = np.array([0])
    subject_classes = {"insect" : 0}

print("\nA total of",len(subject_class_names),"unique classes have been found.")
print("The classes and respective class IDs are:\n\n",subject_classes)

print("\nLoaded colony file with seed", colony['Seed'])

Now that we have the cleaned colony info, we can start loading the data associated with each frame.
For simplicity, we will store this as a list of lists (one inner list per individual).

We will therefore access "data" as [frame] [individual] [attribute], where attributes will include [ID,bbox_x_0,bbox_y_0,...]

In this instance, we use the bounding box of each individual to first determine its visibility through its occupancy in the ID pass and, second, to produce a cropped sample for every individual, including its keypoints in the cropped ROI.

As there may be animals for which we don't use all bones we can return a list of all labels and exclude the respective locations from the pose data. As all animals use the same convention, we can simply read in one example and remove the corresponding indices from all animals.
For simplicity we'll assume that at this stage all subjects use the same armature and therefore report the same keypoints.
We therefore load the first sample from the list and find the subject's keypoint hierarchy.

In [None]:
# The keypoint names are read from the first subject of the first sample.
if not dataset_data:
    # fail with a clear message instead of an IndexError on dataset_data[0]
    raise FileNotFoundError("No annotation (json) files found in " + str(dataset_location))

# the context manager closes the file even if parsing the JSON fails
with open(dataset_data[0]) as sample_file:
    sample = json.load(sample_file)

# every "subject Data" entry is a dictionary whose first key leads to the subject's data
first_subject = sample["iterationData"]["subject Data"][0]
first_entry_key = next(iter(first_subject))
labels = list(first_subject[first_entry_key]["keypoints"])

# show all used labels:
print("\nAll labels:  ",labels)

print("\nOmitting labels:  ", omit_labels)

# remove all occurrences of omitted labels; the remaining labels are used as keys below
labels = [label for label in labels if label not in omit_labels]

print("\nUsing labels:  ", labels)

With all dataset-related parameters configured, we have provided a multi-threaded parsing solution below to minimise the processing time it takes to bring the entire dataset into the required output format. Currently, we instantiate one processing thread per (virtual) CPU core, but you can adjust this value if you wish by changing:

```
threadList_export = createThreadList(#NumDesiredThreads)
```

**Note:** To see the process of mask generation from ID passes in action, set the **DEBUG** mode to **"True"**. This will however slow down the processing speed considerably and only run in single-threaded mode!

In [None]:
def fix_bounding_boxes(coords, max_val=(1024, 1024), ind_key=None, labels=None):
    """Return a bounding box clamped to the image as ``[xmin, ymin, xmax, ymax]``.

    Either pass the four bounding box coordinates directly, or pass the whole
    subject dictionary to recompute a tighter box from its keypoints. In the
    latter case only the keypoints named in ``labels`` are considered, which
    allows e.g. wings to be ignored.

    Args:
        coords: Sequence ``[xmin, ymin, xmax, ymax]``, or a subject dictionary
            providing ``coords[ind_key]["keypoints"][label]["2DPos"]`` with
            ``"x"`` and ``"y"`` entries.
        max_val: Image dimensions as ``(height, width, ...)``, e.g. the
            ``shape`` of an image array. x values are limited by ``max_val[1]``
            and y values by ``max_val[0]``.
        ind_key: Key of the subject inside ``coords`` (keypoint mode only).
        labels: Keypoint names used to compute the box (keypoint mode only).

    Returns:
        List of four ints, each limited to the range from 0 to the image size
        along its axis.

    Raises:
        ValueError: If keypoint mode is used without any labels.
    """
    img_height, img_width = max_val[0], max_val[1]

    if len(coords) == 4:
        coords_bbox = list(coords[:4])
    else:
        if not labels:
            raise ValueError("labels must name at least one keypoint to recompute a bounding box")

        keypoints = coords[ind_key]["keypoints"]
        key_x = [keypoints[key]["2DPos"]["x"] for key in labels]
        key_y = [keypoints[key]["2DPos"]["y"] for key in labels]

        # extent of the selected keypoints; clamping to the image happens below
        # with the correct axis for each coordinate
        coords_bbox = [min(key_x), min(key_y), max(key_x), max(key_y)]

    fixed_coords = []
    for c, coord in enumerate(coords_bbox):
        # even positions hold x values (limited by the width),
        # odd positions hold y values (limited by the height)
        upper_limit = img_width if c % 2 == 0 else img_height
        fixed_coords.append(int(min(max(coord, 0), upper_limit)))

    return fixed_coords

def getThreads():
    """Return the number of logical CPUs of this machine (at least 1).

    Uses ``os.cpu_count()``, which is available on all platforms. Falls back to
    1 when the count cannot be determined, so the result can always be used as
    a worker count.
    """
    return os.cpu_count() or 1

class customThread(threading.Thread):
    """Thread that runs process_detections on the queue it was given."""

    def __init__(self, threadID, name, q):
        super().__init__()
        # identifier, display name and work queue handed over by the caller
        self.threadID, self.name, self.q = threadID, name, q

    def run(self):
        """Announce start, process the queue, then announce exit."""
        print(f"Starting {self.name}")
        process_detections(self.name, self.q)
        print(f"Exiting {self.name}")
        
def createThreadList(num_threads):
    """Return the thread names "Thread_0" up to "Thread_<num_threads - 1>"."""
    return [f"Thread_{index}" for index in range(num_threads)]

def process_detections(threadName, q):
    """Worker loop: turn queued samples into YOLO images and annotation files.

    Runs until the module-level ``exitFlag`` becomes truthy. Every queue item
    is ``[i, data_loc, img, ID]``: sample index, path of the JSON annotation
    file, path of the image and path of the ID (segmentation) pass.

    For each individual of a sample the bounding box is clamped to the image
    and its occupancy in the ID pass is measured. If the occupancy exceeds
    ``visibility_threshold``, a line
    ``[class_ID, centre_x, centre_y, width, height]`` (normalised by the image
    size) is recorded. When ``generate_dataset`` is set and at least one
    individual is visible, the unmodified image and the annotation file are
    written once per sample to ``target_dir + "/data/obj/"``.

    Reads the module-level names ``exitFlag``, ``queueLock``, ``threadList``,
    ``ID_colours``, ``font``, ``fontScale``, ``lineType``, ``labels``,
    ``colony``, ``subject_classes``, ``target_dir``, ``dataset_img_name`` and
    the configuration flags of the parameter cells.

    Args:
        threadName: Name of the calling worker (currently unused).
        q: Queue holding the samples to process.
    """
    while not exitFlag:
        # fetch the next sample under the lock; "with" guarantees the release
        with queueLock:
            data_input = None if q.empty() else q.get()

        if data_input is None:
            # nothing to do (yet): yield the CPU instead of spinning on the lock
            time.sleep(0.001)
            continue

        i, data_loc, img, ID = data_input

        # display_img receives the debug drawings, display_img_out stays untouched
        display_img = cv2.imread(img)
        # the ID pass is used to compute the visibility of each individual
        seg_img = cv2.imread(ID)

        # cv2.imread returns None for unreadable files; skip the sample
        # rather than letting the worker die on the first attribute access
        if display_img is None or seg_img is None:
            print("Could not read image or segmentation pass for sample",data_loc.split("/")[-1],"!")
            continue

        display_img_out = display_img.copy()

        with open(data_loc) as data_file:
            data = json.load(data_file)

        # check if the size of the image and segmentation pass match
        if display_img.shape != seg_img.shape:
            print("Size mismatch of image and segmentation pass for sample",data_input[1].split("/")[-1],"!")
            continue

        img_dim = display_img.shape
        img_info = []
        individual_visible = False

        for individual in data["iterationData"]["subject Data"]:
            ind_key = list(individual.keys())[0]
            ind_ID = int(ind_key)

            fontColor = (int(ID_colours[ind_ID,0]),
                         int(ID_colours[ind_ID,1]),
                         int(ID_colours[ind_ID,2]))

            bbox_orig = [individual[ind_key]["2DBounds"]["xmin"],
                         individual[ind_key]["2DBounds"]["ymin"],
                         individual[ind_key]["2DBounds"]["xmax"],
                         individual[ind_key]["2DBounds"]["ymax"]]

            if enforce_tight_bboxes:
                bbox = fix_bounding_boxes(individual, max_val=display_img.shape, ind_key=ind_key, labels=labels)
            else:
                bbox = fix_bounding_boxes(bbox_orig, max_val=display_img.shape)

            # only process an individual if its bounding box width and height are not zero
            if bbox[2] - bbox[0] == 0 or bbox[3] - bbox[1] == 0:
                continue

            try:
                # count the pixels inside the bounding box whose value is exactly (0, 0, ind_ID)
                ID_mask = cv2.inRange(seg_img[bbox[1]:bbox[3],bbox[0]:bbox[2]],
                                      np.array([0, 0, ind_ID]), np.array([0, 0, ind_ID]))
                indivual_occupancy = cv2.countNonZero(ID_mask)
            except Exception:
                # best effort: keep going with a minimal occupancy if the mask cannot be computed
                if len(threadList) == 1:
                    print("Individual fully occluded:",ind_ID,"in",ID)
                indivual_occupancy = 1

            # +1 keeps the division below well-defined
            bbox_area = abs((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) + 1
            bbox_occupancy = indivual_occupancy / bbox_area

            cv2.putText(display_img, "ID: " + str(ind_ID), (bbox[0] + 10,bbox[3] - 10),
                        font, fontScale, fontColor, lineType)

            if not enforce_single_class:
                class_ID = subject_classes[colony['Subject Variations'][ind_key]["Class"].replace(" ","_")]
            else:
                # here we use a single class, otherwise this can be replaced by size / scale values
                class_ID = 0

            if bbox_occupancy > visibility_threshold:
                individual_visible = True

                if generate_dataset:
                    # convert the bounding box info into the desired format:
                    # [class_ID, centre_x, centre_y, bounding_box_width, bounding_box_height]
                    valid_new_x = False
                    valid_new_y = False

                    if enforce_centred_bboxes:
                        # the enforced centre is the mean position of the listed keypoints
                        centre_points = []
                        for key in enforced_centre:
                            centre_points.append([individual[ind_key]["keypoints"][key]["2DPos"]["x"],
                                                  individual[ind_key]["keypoints"][key]["2DPos"]["y"]])

                        centre_points_arr = np.array(centre_points)
                        new_centre_x = np.mean(centre_points_arr[:,0])
                        new_centre_y = np.mean(centre_points_arr[:,1])

                        # the new centre is only accepted if it lies inside the image
                        if new_centre_x < img_dim[1] and new_centre_x > 0:
                            centre_x = new_centre_x / img_dim[1]
                            valid_new_x = True

                        if new_centre_y < img_dim[0] and new_centre_y > 0:
                            centre_y = new_centre_y / img_dim[0]
                            valid_new_y = True

                        cv2.circle(display_img, (int(new_centre_x),int(new_centre_y)),
                                   radius=3, color=fontColor, thickness=-1)

                    if plot_joints:
                        # draw the used keypoints onto the debug image to check joint locations
                        for key in labels:
                            joint_pos = individual[ind_key]["keypoints"][key]["2DPos"]
                            cv2.circle(display_img, (int(joint_pos["x"]), int(joint_pos["y"])),
                                       radius=3, color=fontColor, thickness=-1)

                    bounding_box_width = abs(bbox[2] - bbox[0]) / img_dim[1]
                    bounding_box_height = abs(bbox[3] - bbox[1]) / img_dim[0]

                    # fall back to the centre of the bounding box unless both
                    # enforced centre coordinates were valid
                    if not valid_new_x or not valid_new_y:
                        centre_x = bbox[0] / img_dim[1] + bounding_box_width / 2
                        centre_y = bbox[1] / img_dim[0] + bounding_box_height / 2

                    img_info.append([class_ID,centre_x,centre_y,bounding_box_width,bounding_box_height])

                    cv2.rectangle(display_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), fontColor, 2)

            else:
                if len(threadList) == 1:
                    print("Ah shit, can't see",ind_ID,class_ID)

        # write image and annotations once per sample, after all individuals were processed
        if generate_dataset and individual_visible:
            output_base = target_dir + "/data/obj/" + img.split('/')[-1][:-4] + "_" + dataset_img_name

            cv2.imwrite(output_base + ".JPG", display_img_out)

            with open(output_base + ".txt", "w") as f:
                f.writelines(' '.join([str(value) for value in line]) + "\n" for line in img_info)

        # single-threaded run (as configured by DEBUG): show the annotated image at half size
        if len(threadList) == 1:
            cv2.imshow("labeled image", cv2.resize(display_img, (int(display_img.shape[1] / 2),
                                                                 int(display_img.shape[0] / 2))))
            cv2.waitKey(0)
            
# setup as many threads as there are (virtual) CPUs; DEBUG mode uses a single thread
exitFlag = 0
if DEBUG:
    threadList = createThreadList(1)
else:
    threadList = createThreadList(getThreads())
print("Using", len(threadList), "threads to parse data...")
queueLock = threading.Lock()

# the maximum number of items in the queue is equivalent to the number of samples
workQueue = queue.Queue(len(dataset_data))
threads = []
threadID = 1

np.random.seed(seed=1)
# once colony size can be read from the BatchData file, set the size of ID_colours equal to the colony size
ID_colours = np.random.randint(255, size=(255, 3))

font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 0.5
lineType = 2

if generate_dataset:
    from helper.Generate_YOLO_training import createCustomFiles
    createCustomFiles(output_folder=target_dir+"/",obIDs=subject_class_names, k_fold=cross_validation_split)

# zip() below silently stops at the shortest list, so make a mismatch visible
if not (len(dataset_data) == len(dataset_img) == len(dataset_ID)):
    print("WARNING: found", len(dataset_data), "data files,", len(dataset_img), "images and",
          len(dataset_ID), "ID passes - only complete positions are processed!")

timer = time.time()

# Create new threads
for tName in threadList:
    thread = customThread(threadID, tName, workQueue)
    thread.start()
    threads.append(thread)
    threadID += 1

# Fill the queue with samples; "with" releases the lock even if filling fails
with queueLock:
    for i, (data, img, ID) in enumerate(zip(dataset_data , dataset_img, dataset_ID)):
        workQueue.put([i, data, img, ID])

# Wait for queue to empty. Sleeping avoids pegging a CPU core while waiting, and the
# is_alive() check prevents an endless wait if every worker has stopped unexpectedly.
while not workQueue.empty() and any(t.is_alive() for t in threads):
    time.sleep(0.01)

# Notify threads it's time to exit
exitFlag = 1

# Wait for all threads to complete
for t in threads:
    t.join()
print("Exiting Main Thread")

# close all windows if they were opened
cv2.destroyAllWindows()

print("Total time elapsed:",time.time()-timer,"seconds")

Generating cross validation splits:

In [None]:
# Fall back to the name stored in the colony (BatchData) file if none was configured.
if dataset_name is None:
    dataset_name = colony["name"]

from helper.Generate_YOLO_training import createCustomFiles

# One createCustomFiles call per split; the output name gets the 1-based split number appended.
num_splits = cross_validation_split[0]
for split_idx in range(num_splits):
    print("Generating split", split_idx + 1, "...")
    createCustomFiles(
        output_folder=target_dir + "/",
        obIDs=subject_class_names,
        k_fold=[num_splits, split_idx],
        custom_name=dataset_name + str(split_idx + 1),
    )

And finally, displaying example detections:

In [None]:
# show example detection: draw the boxes of one annotation file onto its image
ID_colours = np.random.randint(255, size=(255, 3))
example_sample_path = "../example_data/YOLO/data/obj/input-multi_01_example_multi.JPG"

test_img = cv2.imread(example_sample_path)
if test_img is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError("Could not read example image: " + example_sample_path)
test_img_height, test_img_width = test_img.shape[:2]

# the annotation file shares the image name, with "txt" as extension
with open(example_sample_path[:-3] + "txt") as f:
    for l, line in enumerate(f):
        fields = line.split()
        # skip blank or incomplete lines
        if len(fields) < 5:
            continue

        # the colours do not correspond to their original IDs as they are not present in this dataset style
        # different colours are only used for visualisation purposes to better distinguish between adjacent bounding boxes
        # (the modulo reuses colours when a file holds more lines than there are colours)
        colour = ID_colours[l % len(ID_colours)]
        ind_color = (int(colour[0]), int(colour[1]), int(colour[2]))

        # annotations are relative to the image size: class, centre x, centre y, width, height
        x = int(float(fields[1]) * test_img_width)
        y = int(float(fields[2]) * test_img_height)
        w = int(float(fields[3]) * test_img_width)
        h = int(float(fields[4]) * test_img_height)
        cv2.rectangle(test_img, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), ind_color, 2)

cv2.imshow("test_img",cv2.resize(test_img,(1024,1024)))
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
# Call createCustomFiles once more with the configured cross_validation_split.
# NOTE(review): presumably restores the default file set after the per-split runs - confirm in the helper.
if generate_dataset:
    from helper.Generate_YOLO_training import createCustomFiles

    createCustomFiles(
        output_folder=target_dir + "/",
        obIDs=subject_class_names,
        k_fold=cross_validation_split,
    )