In [5]:
import os
import shutil
import numpy as np
import json
import cv2
import send2trash

# Inputs: raw colour images and the annotation JSON that describes them.
source = '../Dataset/Color'
annotation_dst = '../Dataset/annotation.json'
# Outputs: directories that will receive the sampled train/test image copies.
train_dst = 'models/model/train'
test_dst = 'models/model/test'
# Every entry of the source directory is a sampling candidate.
files = os.listdir(source)

In [6]:
# Toy example: sample 1000 distinct images, 80% for training and 20% for testing.
SEED = 42  # fixed seed so the split is reproducible on Restart & Run All
size = min(1000, len(files))  # never ask for more images than exist
rng = np.random.default_rng(SEED)
# replace=False: np.random.choice samples WITH replacement by default, which yields
# duplicate picks and lets the same image land in both train and test.
# sorted(): os.listdir order is arbitrary, so sort first for a reproducible draw.
sample = rng.choice(sorted(files), size=size, replace=False)
train_size = int(size * 0.8)
test_size = size - train_size  # always matches len(test)
train = sample[:train_size]
test = sample[train_size:]

In [7]:
# Parse the annotation JSON into a dict; `annotation` is a view of its (key, value) pairs.
annotation_data = {}
with open(annotation_dst) as annotation_file:
    annotation_data = json.load(annotation_file)
annotation = annotation_data.items()

In [8]:
def _reset_dir(path):
    """Delete ``path`` (recursively) if it exists, then recreate it empty, including parents."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def generate_dataset(src, training_dst, test_dst, training, test):
    """
    Copy the selected images from ``src`` into freshly recreated training and test directories.

    Both destination directories are deleted (if present) and recreated before any
    copying starts, so afterwards they contain exactly the requested files.

    Args:
        src: directory holding the source images.
        training_dst: output directory for the training images (wiped and recreated).
        test_dst: output directory for the test images (wiped and recreated).
        training: iterable of file names, relative to ``src``, for the training set.
        test: iterable of file names, relative to ``src``, for the test set.
    """
    print("generating dataset...")
    _reset_dir(training_dst)
    _reset_dir(test_dst)

    for name in training:
        shutil.copy(os.path.join(src, name), os.path.join(training_dst, name))
    # NOTE(review): only the images are copied here; no annotation data is copied.
    # Labels are written separately (see generate_label_file).
    print("training set completed!")

    for name in test:
        shutil.copy(os.path.join(src, name), os.path.join(test_dst, name))
    print("test set completed!")

In [9]:
generate_dataset(
    src=source,
    training_dst=train_dst,
    test_dst=test_dst,
    training=train,
    test=test,
)

generating dataset...
training set completed!
test set completed!


In [10]:
# convert annotation data to csv format
import csv
import decimal

In [11]:
def bounds(coordinates):
    """
    Return the axis-aligned bounding box of a sequence of (x, y) points.

    Each coordinate is converted with ``int()`` (floats are truncated toward zero).

    Args:
        coordinates: sequence of pairs, where ``pair[0]`` is x and ``pair[1]`` is y.
            NOTE(review): presumably the hand keypoints from annotation.json; confirm.

    Returns:
        ``(xmin, ymin, xmax, ymax)``. For an empty sequence this returns
        ``(inf, inf, -inf, -inf)``, a sentinel that fails any ``> 0`` validity check.
    """
    xs = [int(pair[0]) for pair in coordinates]
    ys = [int(pair[1]) for pair in coordinates]
    if not xs:
        return float('inf'), float('inf'), float('-inf'), float('-inf')
    # min/max over the actual values: the previous version seeded xmax/ymax with 0,
    # which reported a wrong maximum when every coordinate was negative.
    return min(xs), min(ys), max(xs), max(ys)

In [12]:
def _list_jpg_stems(img_dir):
    """Return the sorted base names (".jpg" stripped) of the .jpg files directly inside img_dir."""
    stems = []
    for entry in sorted(os.listdir(img_dir)):
        # splitext handles names with several dots ("a.b.jpg") and names without one.
        stem, ext = os.path.splitext(entry)
        if ext == ".jpg" and os.path.isfile(os.path.join(img_dir, entry)):
            stems.append(stem)
    return stems


def generate_label_file(img_dir, label_dst, annotation_data, csv_dir="data"):
    """
    Given an image directory and an annotation dictionary, write a CSV of hand bounding boxes.

    For every ``<name>.jpg`` directly inside ``img_dir``, the keys ``<name>_L`` and
    ``<name>_R`` are looked up in ``annotation_data``. Each key that is present yields
    one row ``[filename, width, height, "hand", xmin, ymin, xmax, ymax]``, with the box
    computed by ``bounds``. Boxes with any non-positive coordinate are dropped.
    Images with no matching key produce no row.

    Args:
        img_dir: directory containing the .jpg images. It is not searched recursively,
            because images are read from ``img_dir`` itself.
        label_dst: base name of the CSV file, e.g. ``'train'`` -> ``<csv_dir>/train.csv``.
        annotation_data: mapping from ``<name>_L`` / ``<name>_R`` to a sequence of
            (x, y) points. NOTE(review): assumed to be the dict loaded from
            annotation.json; confirm the schema.
        csv_dir: directory the CSV is written into (created if missing). The default
            ``"data"`` matches the previous hard-coded location.

    An existing CSV at the target path is moved to the trash before the new one is written.
    """
    header = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    csv_rows = [header]

    for stem in _list_jpg_stems(img_dir):
        filename = stem + ".jpg"
        img = cv2.imread(os.path.join(img_dir, filename))
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None, not by raising.
            print("warning: could not read " + filename + ", skipping")
            continue
        height, width = img.shape[:2]

        # Either or both of the left/right hand annotations may exist for an image.
        for key in (stem + "_L", stem + "_R"):
            if key not in annotation_data:
                continue
            xmin, ymin, xmax, ymax = bounds(annotation_data[key])
            if xmin > 0 and ymin > 0 and xmax > 0 and ymax > 0:
                csv_rows.append([filename, width, height, "hand", xmin, ymin, xmax, ymax])

    csv_path = os.path.join(csv_dir, label_dst + ".csv")
    print(csv_path)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)
    if os.path.exists(csv_path):
        # Keep a recoverable copy of the previous labels instead of silently overwriting.
        send2trash.send2trash(csv_path)

    # newline='' is required by the csv module; without it Windows gets blank lines.
    with open(csv_path, 'w', newline='') as csv_file:
        print("Writing data to csv file...")
        csv.writer(csv_file).writerows(csv_rows)
        print("Completed!")

In [13]:
for split_dir, split_name in ((train_dst, 'train'), (test_dst, 'test')):
    generate_label_file(split_dir, split_name, annotation_data)

data/train.csv
Writing data to csv file...
Completed!
data/train.csv
Writing data to csv file...
Completed!
