In [1]:
import os
import re
import shutil

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# File suffixes treated as images. Matching is case-sensitive, so only the
# exact spellings listed here are recognised.
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The check is a plain case-sensitive suffix test (e.g. ``'x.Jpg'`` is not
    accepted).
    """
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def fetch_images(image_location):
    """Recursively collect the paths of all image files under a directory.

    Args:
        image_location: Directory to search; all subdirectories are included.

    Returns:
        A sorted list of ``os.path.join(root, filename)`` paths for every file
        accepted by ``is_image_file``.

    Raises:
        AssertionError: If ``image_location`` is not an existing directory.
    """
    # Interpolate the path so the failure message actually names it.
    assert os.path.isdir(image_location), '%s is not a valid directory' % (image_location,)

    # Walk order does not matter: the final result is sorted once below.
    images = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(image_location)
        for filename in filenames
        if is_image_file(filename)
    ]
    return sorted(images)

def GenerateData(dataPath, startIdx=1, items=None, tileSize=256):
    """Compose image triples side by side and save them as numbered PNGs.

    For each entry, the images at keys 'A', 'B' and 'C' are pasted left to
    right onto a new RGBA canvas of size ``(3 * tileSize, tileSize)``. The
    result is saved as ``<dataPath>/img_<n>.png``, with ``n`` counting up
    from ``startIdx``.

    Args:
        dataPath: Output directory; created (including parents) if missing.
        startIdx: Number used in the first output filename.
        items: Sized iterable of dicts with keys 'A', 'B' and 'C' holding image
            paths. Defaults to the notebook-global ``list`` for backward
            compatibility. Pass it explicitly to avoid depending on hidden
            kernel state.
        tileSize: Width/height in pixels of each slot. The sources are pasted
            at x offsets 0, tileSize and 2 * tileSize. Presumably the sources
            are tileSize x tileSize; TODO confirm.
    """
    if items is None:
        # Backward compatibility: `list` is not a local here, so this reads the
        # module-global that the calling cell assigns (shadowing the builtin).
        items = list

    # Unlike os.mkdir, this also creates missing parents and tolerates an
    # existing directory.
    os.makedirs(dataPath, exist_ok=True)

    with tqdm(desc='Generating Images', total=len(items), leave=False, unit='image', position=0) as progressBar:
        for index, item in enumerate(items, start=startIdx):
            output_image = Image.new('RGBA', (3 * tileSize, tileSize))
            # Paste A | B | C left to right. `with` closes each source file
            # deterministically once its pixels have been pasted.
            for slot, key in enumerate(('A', 'B', 'C')):
                with Image.open(item[key]) as tile:
                    output_image.paste(tile, (slot * tileSize, 0))

            output_image.save(os.path.join(dataPath, 'img_{}.png'.format(index)))
            progressBar.update(1)
            
def ShuffleDataSet(dataPath, test_size=0.20, random_state=None):
    """Randomly move the images directly inside ``dataPath`` into train/ and val/.

    Only image files at the top level of ``dataPath`` are split. Files already
    in subdirectories (for example ``train/``, ``val/`` or ``test/`` left over
    from an earlier run) are not touched. Re-running generation therefore
    cannot pull held-out images into the split or reshuffle a previous split.

    Args:
        dataPath: Directory holding the freshly generated images.
        test_size: Fraction (float) or absolute count (int) of images moved to
            ``val/``; forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``. ``None`` keeps the
            original non-deterministic behaviour; pass an int for a
            reproducible split.

    Raises:
        AssertionError: If ``dataPath`` is not an existing directory.
    """
    assert os.path.isdir(dataPath), '%s is not a valid directory' % (dataPath,)

    # Non-recursive on purpose (see docstring). Sorting gives a stable input
    # order, so a fixed random_state yields the same split every time.
    data = sorted(
        os.path.join(dataPath, name)
        for name in os.listdir(dataPath)
        if is_image_file(name) and os.path.isfile(os.path.join(dataPath, name))
    )
    train_data, val_data = train_test_split(
        data, test_size=test_size, shuffle=True, random_state=random_state)

    trainDir = os.path.join(dataPath, 'train')
    valDir = os.path.join(dataPath, 'val')
    os.makedirs(trainDir, exist_ok=True)
    os.makedirs(valDir, exist_ok=True)

    for train_file in train_data:
        shutil.move(train_file, os.path.join(trainDir, os.path.basename(train_file)))

    for val_file in val_data:
        shutil.move(val_file, os.path.join(valDir, os.path.basename(val_file)))


In [3]:
# Input directories: object renders and the material sphere images.
rendersPath, materialsPath = './renders', './materials'

In [4]:
# Training/validation data. For one object and viewing angle, each entry holds:
#   'A': render with the input material  (<rendersPath>/<object>_<inputMaterial>_<angle>.png)
#   'B': sphere image of the output material (<materialsPath>/sphere_<outputMaterial>.png)
#   'C': render with the output material (<rendersPath>/<object>_<outputMaterial>_<angle>.png)
# Every (input, output) material combination is used, including input == output.
materialNums = [0, 3, 7, 13, 16, 21, 23, 31, 37, 42]
angles = [0, 45, 90, 135, 180, 225, 270, 315]
images = ['armadillo', 'blenderSphere', 'dragon', 'eagle', 'fandisk', 'fishBigmouth', 'frog', 'gear_knee', 'Handle', 'maskHorror', 'monster01', 'pufferfish', 'teapot']

# ShuffleDataSet calls train_test_split without a random_state, and sklearn then
# falls back to numpy's global RNG. Seeding that RNG makes the train/val split
# reproducible under Restart & Run All.
SPLIT_SEED = 42

# NOTE: `list` shadows the builtin, but GenerateData reads this global name,
# so the name is kept.
list = [
    {'A': rendersPath + '/{0}_{1}_{2}.png'.format(image, inputMaterial, inputAngle),
     'B': materialsPath + '/sphere_{0}.png'.format(outputMaterial),
     'C': rendersPath + '/{0}_{1}_{2}.png'.format(image, outputMaterial, inputAngle)}
    for image in images
    for inputAngle in angles
    for inputMaterial in materialNums
    for outputMaterial in materialNums
]

GenerateData('./data')
np.random.seed(SPLIT_SEED)
ShuffleDataSet('./data')

                                                                                                                       

In [5]:
# Reserved for test set: these materials and objects do not appear in the
# training/validation cell. The images go to ./data/test, and ShuffleDataSet
# is not called here, so they are not split into train/val.

materialNums = [1, 4, 12, 15, 20, 22, 26, 27, 32, 38]
angles = [0, 45, 90, 135, 180, 225, 270, 315]
images = ['bunny', 'squirrel', 'plane']

# Same 'A'/'B'/'C' layout as the training data; GenerateData reads the global
# name `list`, so it is reused here.
list = [
    {'A': rendersPath + '/{0}_{1}_{2}.png'.format(image, inputMaterial, inputAngle),
     'B': materialsPath + '/sphere_{0}.png'.format(outputMaterial),
     'C': rendersPath + '/{0}_{1}_{2}.png'.format(image, outputMaterial, inputAngle)}
    for image in images
    for inputAngle in angles
    for inputMaterial in materialNums
    for outputMaterial in materialNums
]

testDir = './data/test'
if not os.path.exists(testDir):
    os.mkdir(testDir)

# Numbering starts at 12000, above the 10,400 indices (13 objects x 8 angles x
# 10 x 10 materials) produced by the training cell. That appears intended to
# keep test filenames distinct from train/val filenames; TODO confirm.
GenerateData(testDir, 12000)

                                                                                                                       