In [1]:
# imports
import argparse
import csv
import os

import numpy as np
from PIL import Image
from tqdm import tqdm

In [2]:
# Dataset layout: every directory used by this notebook hangs off DATA_DIR.
DATA_DIR = "/home/jupyter/runwai/data"
IMAGES_DIR = DATA_DIR + "/images"  # image files referenced by the annotation CSV
LABELS_DIR = DATA_DIR + "/labels"  # holds small_labels.csv
SPLIT_DIR = DATA_DIR + "/split"    # destination of train.csv / val.csv

In [3]:
# Remove train/val CSVs left over from a previous run before the split is
# regenerated, reporting for each file whether it was deleted or was absent.
for label, filename in (("Train", "train.csv"), ("Val", "val.csv")):
    stale_path = f"{SPLIT_DIR}/{filename}"
    if not os.path.exists(stale_path):
        print(f"The {filename} file does not exist")
        continue
    os.remove(stale_path)
    print(f"{label} CSV is deleted")

The train.csv file does not exist
The val.csv file does not exist


In [4]:
def save_csv(
    data,
    path,
    fieldnames=[
        'image_path',
        "upper_fabric", "lower_fabric", "outer_fabric",
        "upper_color", "lower_color", "outer_color",
        "sleeve_len", "lower_clothing_len", "neckline",
    ],
):
    """Write ``data`` to ``path`` as a CSV file preceded by a header row.

    Args:
        data: Iterable of rows; each row is an iterable of values that is
            paired positionally with ``fieldnames``. Values beyond the number
            of field names are dropped by ``zip``.
        path: Destination file. It is opened in ``'w'`` mode, so any existing
            content is replaced.
        fieldnames: Column names written as the header, in output order.
    """
    # newline='' lets the csv module control line endings itself.
    with open(path, 'w', newline='') as out_file:
        dict_writer = csv.DictWriter(out_file, fieldnames=fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(dict(zip(fieldnames, record)) for record in data)

In [5]:
def split_data(train_fraction=0.8, seed=42):
    """Build a shuffled train/validation split of the labelled images.

    Reads ``small_labels.csv`` from ``LABELS_DIR``, keeps every row whose
    image file exists under ``IMAGES_DIR`` and opens in ``"RGB"`` mode,
    shuffles the kept rows and writes them to ``train.csv`` and ``val.csv``
    in ``SPLIT_DIR`` via ``save_csv``.

    Args:
        train_fraction: Share of the kept rows written to ``train.csv``; the
            remainder goes to ``val.csv``. Must lie in ``[0, 1]``. With the
            default and 42544 kept rows this yields 34035 / 8509 rows.
        seed: Value passed to ``np.random.seed`` so the shuffle can be
            reproduced. Note that this reseeds NumPy's global generator.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]`` or if no row
            of the annotation file refers to a usable image.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

    input_folder = IMAGES_DIR
    output_folder = SPLIT_DIR
    annotation = f"{LABELS_DIR}/small_labels.csv"

    # The nine attributes copied from each annotation row, in the column order
    # that follows the image path in every output row.
    attribute_columns = [
        'upper_fabric', 'lower_fabric', 'outer_fabric',
        'upper_color', 'lower_color', 'outer_color',
        'sleeve_len', 'lower_clothing_len', 'neckline',
    ]

    # Read the whole annotation file up front: DictReader.line_num is still 0
    # before iteration, so it cannot provide a total for the progress bar,
    # whereas a list lets tqdm show real progress.
    with open(annotation, newline='') as csv_file:
        rows = list(csv.DictReader(csv_file))

    all_data = []
    # each row in the CSV file corresponds to one image
    for row in tqdm(rows):
        # the 'image' column holds the file name relative to the images folder
        img_path = os.path.join(input_folder, row['image'])
        # skip annotations whose image file is missing
        if not os.path.exists(img_path):
            continue
        # Only the colour mode is checked (no size check). The context manager
        # closes the file handle that Image.open would otherwise keep open.
        with Image.open(img_path) as img:
            is_rgb = img.mode == "RGB"
        if is_rgb:
            all_data.append([img_path] + [row[col] for col in attribute_columns])

    n_total = len(all_data)
    if n_total == 0:
        raise ValueError(f"no usable RGB images found for the rows in {annotation}")

    # set the seed of the random numbers generator, so we can reproduce the results later
    np.random.seed(seed)
    # construct a Numpy array from the list
    all_data = np.asarray(all_data)
    # Shuffle exactly the rows that were kept; sizing the permutation from the
    # data avoids out-of-range indices when some images were filtered out.
    inds = np.random.choice(n_total, n_total, replace=False)
    shuffled = all_data[inds]
    n_train = int(n_total * train_fraction)

    # split the data into train/val and save them as csv files
    os.makedirs(output_folder, exist_ok=True)
    save_csv(shuffled[:n_train], os.path.join(output_folder, 'train.csv'))
    save_csv(shuffled[n_train:], os.path.join(output_folder, 'val.csv'))

In [6]:
# Run the split: writes train.csv and val.csv into SPLIT_DIR.
split_data()

42544it [00:10, 4129.50it/s]
