# MMLAB / MMSEG style dataset parser

### The generated output includes the following annotation data:
* Classes
* Pixel masks

### Example application(s) (as demonstrated in Plum et al. 2023):
* PSPNet _(semantic segmentation)_
* UperNet + SWIN Transformers _(semantic segmentation)_

### Output structure:
* target_dir
  * **images**  _(containing a copy of the original input images)_
  * **annotations**  _(containing the segmentation maps)_
  * **imageLists**  _(containing three txt files which contain the names of the images, separated into all.txt / test.txt / train.txt)_
    
### Notes:

* In our examples, we use the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) repo by [OpenMMLAB](https://openmmlab.com/)
* This parser in particular has been used to train models for binary segmentation problems, grouping images into background (0) and specimen (1), but should generally be applicable to arbitrary numbers of classes (<255)
* As the pixel values of the masks correspond to the respective class IDs (0 to 255), the masks - binary ones in particular - may appear entirely black in an image viewer, but they contain all required information

In [None]:
import cv2
import json
import time
import threading
import queue
import sys
import os
import pathlib

import numpy as np
import matplotlib as plt

from PIL import Image
from sklearn.utils import shuffle
from sklearn.model_selection import KFold
from os import listdir
from os.path import isfile, join

### Required parameters

Specify the location of your **generated dataset** and in which **output directory** you wish to save it.

**Notes:**
* do not include trailing forward slashes in your paths (see examples below)
* Your **dataset** name should **NOT include underscores** as they are used to separate passes into their categories. Instead, use hyphens in your naming convention where applicable.

In [None]:
# Input: directory holding the generated dataset (images, passes and json files).
# Output: directory the parsed MMSEG-style dataset is written to.
# Both paths are joined with "/" further down, so do not add trailing slashes.
dataset_location = "../example_data/input-multi"
target_dir = "../example_data/MMSEG"

### Optional parameters

In [None]:
# set True to show processing results for each image (disables parallel processing,
# as only a single worker thread is created in DEBUG mode)
DEBUG = False

# number of pixels added to pad masks to avoid cutting off contours
# in our examples, we use a padding of 5 pixels during training and erode the resulting masks by 5 pixels during inference
mask_padding = 0

enforce_single_class = False # when True, all individuals are written with class ID 1 instead of their own class

# minimum ratio of an individual's mask pixels to its bounding box area;
# individuals at or below this ratio are considered not visible enough and are left out of the mask
visibility_threshold = 0.1

# test split (fraction of the images that is withheld from training)
amount_test = 0.1

The following lines will load the generated dataset from your drive and prepare it for the multi-threaded parsing process

In [None]:
# Collect every regular file in the dataset directory. Sorting keeps the
# per-sample lists built below in the same order as each other.
all_files = sorted(f for f in listdir(dataset_location) if isfile(join(dataset_location, f)))

# next, sort files into images, depth maps, segmentation maps, data, and colony info
# we only need the location and name of the data files, as all passes follow the same naming convention
dataset_data = []
dataset_img = []
dataset_ID = []
dataset_depth = []
dataset_norm = []
dataset_colony = None

for file in all_files:
    loc = dataset_location + "/" + file
    file_info = file.split("_")

    # Files without any underscore cannot be classified by the checks below
    # (and would raise an IndexError on file_info[1]), so they are skipped.
    if len(file_info) < 2:
        continue

    # name of the pass, i.e. the last underscore-separated part without its extension
    pass_name = file_info[-1].split(".")[0]

    if file_info[1] == "BatchData":
        dataset_colony = loc

    elif len(file_info) == 2:
        # images are available in various formats, but annotation data is always written as json files
        if file_info[-1].split(".")[-1] == "json":
            dataset_data.append(loc)
        else:
            dataset_img.append(loc)

    elif pass_name == "ID":
        dataset_ID.append(loc)
    elif pass_name == "depth":
        dataset_depth.append(loc)
    elif pass_name == "norm":
        dataset_norm.append(loc)

print("Found",len(dataset_data),"samples...")

# next sort the colony info into its IDs to determine the colony size and individual scales
# Fail with a clear message instead of an opaque TypeError from open(None).
if dataset_colony is None:
    raise FileNotFoundError("No BatchData (colony) file found in " + dataset_location)

# Opening colony (BatchData) JSON file; json.load returns it as a dictionary
with open(dataset_colony) as colony_file:
    colony = json.load(colony_file)

# NOTE: the colony file is expected to provide a "Seed" entry and a
# "Subject Variations" dictionary whose entries each carry a "Class" name.

if not enforce_single_class:
    # get provided classes to create a dictionary of class IDs and class names
    all_classes = [subject["Class"] for subject in colony["Subject Variations"].values()]

    # unique class names in order of first appearance, with spaces replaced by underscores
    subject_class_names = list(dict.fromkeys(name.replace(" ", "_") for name in all_classes))

    # class name -> zero-based class ID
    subject_classes = {class_name: class_idx for class_idx, class_name in enumerate(subject_class_names)}
else:
    # NOTE(review): the name list and the dictionary key differ ("insect_0" vs "insect");
    # both are only printed below, so they are kept as they were.
    subject_class_names = ["insect_0"]
    subject_classes = {"insect" : 0}

print("\nA total of",len(subject_class_names),"unique classes have been found.")
print("The classes and respective class IDs are:\n",subject_classes,"\n")


print("Loaded colony file with seed", colony['Seed'])

if len(colony['Subject Variations']) > 1:
    multi_animal = True
    print("Generating MULTI-animal dataset! Containing",len(colony['Subject Variations']),"individuals")
else:
    multi_animal = False
    print("Generating SINGLE-animal dataset!")

Create required output folders used in the mmlab segmentation conventions
- **/images/** (containing a copy of the original input images)
- **/annotations/** (containing the segmentation maps)
- **/imageLists/** (containing three txt files which contain the names of the images, separated into all.txt / test.txt / train.txt)

In [None]:
# create output folders
# Images and annotations are written into flat folders; the train/test split is
# expressed through the txt files in imageLists rather than through sub-folders.
output_folders = ["images", "annotations", "imageLists"]
for folder in output_folders:
    # makedirs also creates target_dir itself when it is missing, and
    # exist_ok=True makes re-running the cell on an existing dataset safe.
    os.makedirs(target_dir + "/" + folder, exist_ok=True)

Create a test and a train image list, according to the specified split.

When writing out the images, automatically place them into their respective folders.

In [None]:
# paths of the three list files
imageLists_all = target_dir + "/imageLists/all.txt"
imageLists_train = target_dir + "/imageLists/train.txt"
imageLists_test = target_dir + "/imageLists/test.txt"

# Output name of every image: file name without its 4-character extension plus
# "_synth". This has to match the names the export step gives the written files.
imageLists_all_orig = [img.split('/')[-1][:-4] + "_synth" for img in dataset_img]

# fixed random_state so the split is reproducible between runs
imageLists_all_orig_shuffle = shuffle(imageLists_all_orig, random_state=0)
num_train_examples = int(np.floor(len(imageLists_all_orig_shuffle) * (1 - amount_test)))

files_train = imageLists_all_orig_shuffle[:num_train_examples]
files_test = imageLists_all_orig_shuffle[num_train_examples:]

# Report the size of the actual test list. Flooring the test share separately
# can disagree with the real split (e.g. 7 images at 0.1 -> 6 train, 1 test,
# but floor(0.7) would report 0 test images).
print("Using", num_train_examples, "training images and",
      len(files_test), "test images. (" + str(amount_test * 100),
      "%)")

# one image name per line in each list file
for list_path, image_names in ((imageLists_all, imageLists_all_orig_shuffle),
                               (imageLists_train, files_train),
                               (imageLists_test, files_test)):
    with open(list_path, 'w') as f:
        f.writelines(image_name + '\n' for image_name in image_names)

With all dataset-related parameters configured, we have provided a multi-threaded parsing solution below to minimise the processing time it takes to bring the entire dataset into the required output format. Currently, we instantiate one processing thread per (virtual) CPU core, but you can adjust this value if you wish by changing:

```
threadList_export = createThreadList(#NumDesiredThreads)
```

**Note:** To see the process of mask generation from ID passes in action, set the **DEBUG** mode to **"True"**. This will however slow down the processing speed considerably and only run in single-threaded mode!

In [None]:
# transform between sRGB and linear colour space (optional)

def to_linear(srgb):
    """Convert sRGB values on a 0..255 scale to linear values on a 0..255 scale.

    The input is converted to a float32 array and scaled to 0..1, the piecewise
    sRGB decoding is applied (linear segment up to 0.04045, power curve with
    exponent 2.4 above it), and the result is scaled back by 255.

    The input array itself is not modified.
    """
    scaled = np.float32(srgb) / 255.0

    # split the values into the linear toe and the power-curve segment
    in_toe = scaled <= 0.04045
    in_curve = ~in_toe

    scaled[in_toe] /= 12.92
    scaled[in_curve] = np.power((scaled[in_curve] + 0.055) / 1.055, 2.4)

    return scaled * 255.0

    
def from_linear(linear):
    """Encode linear values given on a 0..1 scale as sRGB values on a 0..255 scale.

    Values up to 0.0031308 are scaled by 12.92, larger values follow the sRGB
    power curve (exponent 1/2.4); the result is multiplied by 255. The work is
    done on a copy, so ``linear`` (which has to offer ``.copy()`` and boolean
    mask indexing, e.g. a NumPy array) is left untouched.

    NOTE(review): the thresholds here operate on 0..1 input, whereas to_linear
    returns values scaled to 0..255, so the two functions are not direct
    inverses of each other - confirm the intended scale with the callers.
    """
    encoded = linear.copy()

    # split the values into the linear toe and the power-curve segment
    in_toe = linear <= 0.0031308
    in_curve = ~in_toe

    encoded[in_toe] = linear[in_toe] * 12.92
    encoded[in_curve] = 1.055 * np.power(linear[in_curve], 1.0 / 2.4) - 0.055

    return encoded * 255.0

def getThreads():
    """Return the number of logical CPUs to use for worker threads.

    Uses os.cpu_count(), which works on every platform Python supports. The
    previous approach (grepping /proc/cpuinfo for "cores") only worked on
    Linux systems whose cpuinfo lists such lines and could yield an error or 0
    elsewhere.

    Returns:
        int: number of logical CPUs, or 1 when it cannot be determined, so
        that at least one worker thread is always created.
    """
    return os.cpu_count() or 1

class exportThread(threading.Thread):
    """Worker thread that runs process_detections on a shared work queue."""

    def __init__(self, threadID, name, q):
        """Store the numeric thread ID, the thread name and the work queue."""
        super().__init__()
        self.threadID = threadID
        self.name = name
        self.q = q

    def run(self):
        """Announce start, process queue items until told to exit, announce exit."""
        print(f"Starting {self.name}")
        process_detections(self.name, self.q)
        print(f"Exiting {self.name}")
        
def createThreadList(num_threads):
    """Return the thread names "Thread_0" ... "Thread_<num_threads - 1>" as a list."""
    return ["Thread_" + str(index) for index in range(num_threads)]

def _export_sample(data_loc, img, ID):
    """Build the class mask for one sample and write image and mask to target_dir.

    Args:
        data_loc: path of the sample's json data file.
        img: path of the rendered image.
        ID: path of the ID (segmentation) pass belonging to the image.

    The mask starts out as all zeros (background). For every individual listed
    in the data file whose visibility exceeds ``visibility_threshold``, the
    pixels of its smoothed ID mask are set to its class ID (class index + 1, or
    1 when ``enforce_single_class`` is set). A sample whose image and ID pass
    differ in size is still written, but with an all-background mask.
    """
    display_img = cv2.imread(img)
    seg_img = cv2.imread(ID)

    # cv2.imread returns None instead of raising when a file cannot be read
    if display_img is None or seg_img is None:
        print("Could not read image or segmentation pass for sample", data_loc.split("/")[-1], "- skipping!")
        return

    seg_img_display = seg_img.copy()

    # json.load returns the sample data as a dictionary
    with open(data_loc) as data_file:
        data = json.load(data_file)

    img_height, img_width = display_img.shape[:2]
    output_mask = np.zeros((img_height, img_width), dtype=np.uint8)

    # check if the size of the image and segmentation pass match
    if display_img.shape != seg_img.shape:
        print("Size mismatch of image and segmentation pass for sample",data_loc.split("/")[-1],"!")
    else:
        for im, individual in enumerate(data["iterationData"]["subject Data"]):
            # each entry is a dictionary keyed by the individual's ID
            # WARNING ID numbering begins at 1
            ind_key = next(iter(individual))
            ind_ID = int(ind_key)

            bounds = individual[ind_key]["2DBounds"]
            bbox_orig = [bounds["xmin"], bounds["ymin"], bounds["xmax"], bounds["ymax"]]

            # fix_bounding_boxes clamps entries 0 and 2 (x) to max_val[0] and the
            # others (y) to max_val[1], so the limits are passed as
            # (width, height) rather than the (height, width, channels) shape.
            bbox = fix_bounding_boxes(bbox_orig, max_val=(img_width, img_height))

            # only process an individual if its bounding box width and height are not zero
            if bbox[2] - bbox[0] == 0 or bbox[3] - bbox[1] == 0:
                continue

            # pixels whose colour is exactly (0, 0, ind_ID); cv2.imread delivers
            # BGR, so the ID is read from the third (red) channel
            ID_mask = cv2.inRange(seg_img, np.array([0, 0, ind_ID]), np.array([0, 0, ind_ID]))
            indivual_occupancy = cv2.countNonZero(ID_mask)

            # the kernel size for both dilation and median blur is determined by the
            # bounding box size relative to the image (width by width, height by height)
            rel_size = ((bbox[2] - bbox[0]) / img_width + (bbox[3] - bbox[1]) / img_height) / 2
            # rel_size ranges from 0 (tiny) to 1 (huge); map it to an odd kernel size (1, 3 or 5)
            rel_size_root = int(round((5 * rel_size) / 2.) * 2 + 1)

            # to simplify the generated masks and counter compression artifacts the original mask is dilated
            # https://docs.opencv.org/3.4/db/df6/tutorial_erosion_dilatation.html
            kernel = np.ones((rel_size_root, rel_size_root), 'uint8')
            ID_mask_dilated = cv2.dilate(ID_mask, kernel, iterations=1)
            # use median blur to further smooth the edges of the binary mask
            ID_mask_dilated = cv2.medianBlur(ID_mask_dilated, rel_size_root)

            if mask_padding > 0:
                # NOTE(review): the previous implementation enlarged the mask array by
                # mask_padding on every side, which made merging it into the
                # image-sized output mask fail with a shape mismatch. The mask is now
                # grown by mask_padding pixels in place instead (matching the advice
                # to erode predictions by the same amount) - confirm this is the
                # intended meaning of the padding.
                pad_kernel = np.ones((2 * mask_padding + 1, 2 * mask_padding + 1), 'uint8')
                ID_mask_dilated = cv2.dilate(ID_mask_dilated, pad_kernel, iterations=1)

            if DEBUG:
                # contours of the mask inside the bounding box; [-2:] copes with
                # OpenCV versions that return either two or three values
                mask_window = np.ascontiguousarray(ID_mask_dilated[bbox[1]:bbox[3], bbox[0]:bbox[2]])
                contours, hierarchy = cv2.findContours(mask_window, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]

                print("\nindividual",im,ID_mask_dilated.dtype)
                print(hierarchy)
                # draw the contours on a fresh copy of the ID pass; the offset moves
                # them from bounding box coordinates back to image coordinates
                seg_img_display = seg_img.copy()
                cv2.imshow("mask: ", ID_mask_dilated)
                cv2.drawContours(seg_img_display, contours, -1, (255,0,0), 3, offset=(bbox[0], bbox[1]))
                cv2.imshow("segmentation: ", np.ascontiguousarray(seg_img_display[bbox[1]:bbox[3],bbox[0]:bbox[2]]))
                cv2.waitKey(0)

            # ratio of the individual's mask pixels to its bounding box area
            # (+ 1 keeps the divisor non-zero)
            bbox_area = abs((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) + 1
            bbox_occupancy = indivual_occupancy / bbox_area

            if not enforce_single_class:
                class_ID = subject_classes[colony['Subject Variations'][ind_key]["Class"].replace(" ","_")] + 1
            else:
                # here we use a single class, otherwise this can be replaced by size / scale values
                class_ID = 1 # as the background = 0

            if bbox_occupancy > visibility_threshold:
                # set all mask pixels belonging to the individual to its class ID;
                # later individuals overwrite earlier ones where they overlap
                output_mask[ID_mask_dilated > 0] = class_ID

    # show the ID pass and the image at half size
    if DEBUG:
        cv2.imshow("segmentation: " ,cv2.resize(seg_img_display, (int(seg_img.shape[1] / 2),
                                                          int(seg_img.shape[0] / 2))))
        cv2.imshow("labeled image", cv2.resize(display_img, (int(display_img.shape[1] / 2),
                                                             int(display_img.shape[0] / 2))))
        cv2.waitKey(1)

    # file name without its 4-character extension, as used in the image lists
    img_name = img.split('/')[-1][:-4] + "_synth"

    img_out_path = target_dir + "/images/" + img_name + ".jpg"
    mask_out_path = target_dir + "/annotations/" + img_name + "_labelTrainIds.png"

    cv2.imwrite(img_out_path, display_img)
    Image.fromarray(output_mask).save(mask_out_path, 'PNG')
    print("Saved", img_name)


def process_detections(threadName, q):
    """Worker loop: take samples from the queue and export them until told to exit.

    Args:
        threadName: name of the calling thread (not used by the loop itself).
        q: queue holding ``[index, data_path, image_path, ID_pass_path]`` items.

    The loop runs until the module-level ``exitFlag_export`` becomes truthy.
    Queue access is guarded by the module-level ``queueLock``.
    """
    while not exitFlag_export:
        data_input = None
        # the context manager releases the lock even if the queue access raises
        with queueLock:
            if not q.empty():
                data_input = q.get()

        if data_input is None:
            # nothing to do yet: yield instead of spinning at full speed
            time.sleep(0.01)
            continue

        i, data_loc, img, ID = data_input
        try:
            _export_sample(data_loc, img, ID)
        except Exception as err:
            # Report and carry on: a worker that dies here would leave the
            # remaining queue items unprocessed.
            print("Failed to process sample", data_loc.split("/")[-1], "-", repr(err))
            
# Flag polled by the worker loop; set to 1 once all work has been handed out.
exitFlag_export = 0

# One worker per (virtual) CPU, or a single worker in DEBUG mode.
threadList_export = createThreadList(1 if DEBUG else getThreads())
print("Using", len(threadList_export), "threads for export...")

# Lock guarding access to the work queue.
queueLock = threading.Lock()

# The queue can hold as many items as there are images.
workQueue_export = queue.Queue(maxsize=len(dataset_img))
threads = []
threadID = 1

# Random colour per ID; the fixed seed makes the colours reproducible.
np.random.seed(seed=1)
ID_colours = np.random.randint(255, size=(255, 3))

# Settings for OpenCV text overlays.
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 0.5
lineType = 2

# we can additionally plot the points in the data files to check joint locations
plot_joints = True

# remember to define an export folder when saving out your dataset
generate_dataset = True

def fix_bounding_boxes(coords, max_val=(1024, 1024)):
    """Clamp bounding box coordinates so they do not reach beyond the image.

    Args:
        coords: sequence of coordinates, e.g. ``[xmin, ymin, xmax, ymax]``.
        max_val: upper limits; ``max_val[0]`` applies to the entries at index 0
            and 2, ``max_val[1]`` to all other entries. For
            ``[xmin, ymin, xmax, ymax]`` input this means ``(width, height)``.
            Only the first two elements are read. The default is a tuple so
            that no mutable object is shared between calls.

    Returns:
        list[int]: the coordinates limited to ``0 .. limit`` and truncated to
        integers.
    """
    fixed_coords = []
    for c, coord in enumerate(coords):
        limit = max_val[0] if c in (0, 2) else max_val[1]
        fixed_coords.append(int(min(max(coord, 0), limit)))

    return fixed_coords

timer = time.time()

# Create new threads
for tName in threadList_export:
    thread = exportThread(threadID, tName, workQueue_export)
    thread.start()
    threads.append(thread)
    threadID += 1

# zip() stops at the shortest list, so point out incomplete samples instead of
# dropping them silently
if not (len(dataset_data) == len(dataset_img) == len(dataset_ID)):
    print("WARNING: found", len(dataset_data), "data files,", len(dataset_img), "images and",
          len(dataset_ID), "ID passes - only complete samples are exported!")

# Fill the queue with samples (the lock keeps the workers off the queue meanwhile)
with queueLock:
    for i, (data, img, ID) in enumerate(zip(dataset_data , dataset_img, dataset_ID)):
        workQueue_export.put([i, data, img, ID])

# Wait for queue to empty. Sleeping instead of spinning leaves the CPU to the
# workers, and the liveness check prevents waiting forever on a queue that no
# worker is left to drain.
while not workQueue_export.empty() and any(t.is_alive() for t in threads):
    time.sleep(0.05)

if not workQueue_export.empty():
    print("WARNING: all export threads stopped before the queue was emptied!")

# Notify threads it's time to exit
exitFlag_export = 1

# Wait for all threads to complete
for t in threads:
    t.join()
print("Exiting Main export Thread")

# close all windows if they were opened
cv2.destroyAllWindows()

print("Total time elapsed:",time.time()-timer,"seconds")