In [3]:
from PIL import Image
import numpy as np
import os
from sklearn.model_selection import train_test_split

In [4]:
def load_image(name):
    """Load an image file and return its pixel data as a NumPy array.

    Args:
        name: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` built from the decoded image.
    """
    # Image.open only reads the header and keeps the file handle around;
    # np.array() forces the pixel data to be decoded, after which the context
    # manager releases the handle deterministically.
    with Image.open(name) as img:
        return np.array(img)

# Suffix that marks a file as the label image of ``<stem>.<ext>``.
_LABEL_SUFFIX = '_L.png'


def _validate_color_encoding(color_encoding):
    """Return *color_encoding* as an ``(n_classes, 3)`` integer array."""
    palette = np.asarray(color_encoding)
    if palette.ndim != 2 or palette.shape[1] != 3:
        raise ValueError(
            "color_encoding must be a non-empty sequence of (r, g, b) triples"
        )
    if len(palette) > 256:
        # Class indices are stored in 8-bit channels.
        raise ValueError("at most 256 classes can be encoded in an 8-bit image")
    return palette


def _pair_images_with_labels(dataset_path):
    """Return ``(image_name, label_name)`` pairs in *dataset_path*, sorted by image name."""
    images_by_stem = {}
    label_stems = set()
    for name in os.listdir(dataset_path):
        # Skip sub-directories (e.g. split folders left by an earlier run).
        if not os.path.isfile(os.path.join(dataset_path, name)):
            continue
        if name.endswith(_LABEL_SUFFIX):
            label_stems.add(name[:-len(_LABEL_SUFFIX)])
            continue
        stem = os.path.splitext(name)[0]
        if stem in images_by_stem:
            raise ValueError(
                f"ambiguous images for stem {stem!r}: "
                f"{images_by_stem[stem]!r} and {name!r}"
            )
        images_by_stem[stem] = name

    unmatched = sorted(label_stems.symmetric_difference(images_by_stem))
    if unmatched:
        raise ValueError(
            f"files without a matching image/label counterpart: {unmatched}"
        )
    return sorted(
        (image_name, stem + _LABEL_SUFFIX)
        for stem, image_name in images_by_stem.items()
    )


def _move_files(dataset_path, names, subfolder):
    """Move *names* from *dataset_path* into ``dataset_path/subfolder``."""
    target_dir = os.path.join(dataset_path, subfolder)
    os.makedirs(target_dir, exist_ok=True)
    for name in names:
        os.rename(os.path.join(dataset_path, name),
                  os.path.join(target_dir, name))


def _encode_label_array(pixels, palette):
    """Map an ``(H, W, 3)`` RGB array to class indices replicated on 3 channels."""
    class_ids = np.zeros(pixels.shape[:2], dtype=np.uint8)
    assigned = np.zeros(pixels.shape[:2], dtype=bool)
    for class_id, rgb in enumerate(palette):
        # ``& ~assigned`` keeps the first matching class when the encoding
        # contains duplicate colors (same result as ``list.index``).
        matches = np.all(pixels == rgb, axis=-1) & ~assigned
        class_ids[matches] = class_id
        assigned |= matches
    if not assigned.all():
        row, col = np.argwhere(~assigned)[0]
        color = tuple(int(c) for c in pixels[row, col])
        raise ValueError(
            f"pixel (x={col}, y={row}) has color {color} "
            "which is not in color_encoding"
        )
    return np.stack([class_ids] * 3, axis=-1)


def _encode_label_file(path, palette):
    """Rewrite the label image at *path* in place using class-index colors."""
    with Image.open(path) as image:
        # Convert so palette-mode and RGBA label files compare as RGB triples.
        pixels = np.array(image.convert('RGB'))
    Image.fromarray(_encode_label_array(pixels, palette)).save(path)


def restructure_dataset(dataset_path, color_encoding):
    """
    Restructures a dataset to make it applicable for the model training pipeline.

    Every label image must be named like its original image with '_L' appended
    to the stem and a '.png' extension (``foo.png`` -> ``foo_L.png``).
    Image/label pairs are split into ``train``, ``valid`` and ``test`` folders
    (each with ``images`` and ``labels`` sub-folders) created inside
    ``dataset_path``. Each label image is then rewritten in place so that a
    pixel of class ``i`` has the color ``(i, i, i)``.

    The working directory of the process is not changed.

    Args:
        dataset_path: Path to the folder holding the images and label images.
        color_encoding: Sequence of RGB triples, one per class; the position
            of a color in the sequence is its class index.

    Raises:
        ValueError: If ``color_encoding`` is malformed, an image or label has
            no counterpart, the dataset is too small to split, or a label
            pixel has a color that is not listed in ``color_encoding``.
    """
    # Validate everything up front so nothing is moved when the input is bad.
    palette = _validate_color_encoding(color_encoding)
    pairs = _pair_images_with_labels(dataset_path)
    images = [image_name for image_name, _ in pairs]
    labels = [label_name for _, label_name in pairs]

    # 60% train, then the remaining 40% is split 40/60 into valid/test.
    train_images, rest_of_images, train_labels, rest_of_labels = train_test_split(
        images, labels, test_size=.4, random_state=42)
    valid_images, test_images, valid_labels, test_labels = train_test_split(
        rest_of_images, rest_of_labels, test_size=.6, random_state=42)

    splits = {
        'train': (train_images, train_labels),
        'valid': (valid_images, valid_labels),
        'test': (test_images, test_labels),
    }
    for split_name, (split_images, split_labels) in splits.items():
        _move_files(dataset_path, split_images, os.path.join(split_name, 'images'))
        _move_files(dataset_path, split_labels, os.path.join(split_name, 'labels'))

    # Encoding the label images
    for split_name, (_, split_labels) in splits.items():
        label_dir = os.path.join(dataset_path, split_name, 'labels')
        for name in split_labels:
            _encode_label_file(os.path.join(label_dir, name), palette)