# Generate train, validation and test splits

In [1]:
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.vgg16 import preprocess_input
import glob
import numpy as np
from PIL import Image
import random
random.seed(123)
import os
import shutil

Using TensorFlow backend.


### Data dirs

In [2]:
# Input folders (one per class) and output folders for the three splits.
# Set these environment variables (or edit the defaults) before running.
pso_dir = os.environ.get('PSO_DIR', '')    # Directory containing psoriasis images
ecz_dir = os.environ.get('ECZ_DIR', '')    # Directory containing eczema images

# glob order depends on the filesystem; sort so the seeded shuffle is reproducible.
pso = sorted(glob.glob(os.path.join(pso_dir, '*.jpg')))
ecz = sorted(glob.glob(os.path.join(ecz_dir, '*.jpg')))

# Fail fast instead of silently building empty splits from a wrong path.
if not (pso_dir and ecz_dir and pso and ecz):
    raise FileNotFoundError(
        'No .jpg images found; check PSO_DIR={!r} and ECZ_DIR={!r}'.format(pso_dir, ecz_dir))

train_dir = os.environ.get('TRAIN_DIR', '')             # Training directory
validation_dir = os.environ.get('VALIDATION_DIR', '')   # Validation directory
test_dir = os.environ.get('TEST_DIR', '')               # Test directory

### Shuffle

In [None]:
# Shuffle each class list with a dedicated RNG seeded with 123 (the same seed
# the imports cell passes to random.seed). Sorting first removes any dependence
# on file-listing order, and the fresh RNG makes this cell idempotent: re-running
# it reproduces the same order instead of reshuffling an already-shuffled list.
split_rng = random.Random(123)
pso = sorted(pso)
ecz = sorted(ecz)
split_rng.shuffle(pso)
split_rng.shuffle(ecz)

### Split 80 / 10 / 10 into train, validation and test

In [5]:
def split_sizes(n_total, train_frac=0.8, val_frac=0.1):
    """Return (n_train, n_val, n_test) for a list of n_total items.

    Train and validation sizes are floored; the test split takes the remainder,
    so the three sizes always sum to n_total.
    """
    n_train = int(train_frac * n_total)
    n_val = int(val_frac * n_total)
    return n_train, n_val, n_total - n_train - n_val


# Counts come straight from the file lists. `np` is the numpy module and must
# not be reused as a count variable.
n_pso = len(pso)
n_ecz = len(ecz)

n_pso_train, n_pso_val, n_pso_test = split_sizes(n_pso)
print(n_pso_train, n_pso_val, n_pso_test, n_pso)

n_ecz_train, n_ecz_val, n_ecz_test = split_sizes(n_ecz)
print(n_ecz_train, n_ecz_val, n_ecz_test, n_ecz)

630 78 80 788
4284 535 537 5356


### Copy files into the split directories

In [6]:
# Carve each shuffled list into consecutive train / validation / test slices.
pso_val_start = n_pso_train
pso_test_start = n_pso_train + n_pso_val
pso_train, pso_val, pso_test = (
    pso[:pso_val_start], pso[pso_val_start:pso_test_start], pso[pso_test_start:])

ecz_val_start = n_ecz_train
ecz_test_start = n_ecz_train + n_ecz_val
ecz_train, ecz_val, ecz_test = (
    ecz[:ecz_val_start], ecz[ecz_val_start:ecz_test_start], ecz[ecz_test_start:])

# Show the resulting slice sizes, one line per class.
for split_parts in ((pso_train, pso_val, pso_test), (ecz_train, ecz_val, ecz_test)):
    print(*[len(part) for part in split_parts])

def copy_files(source, target_dir):
    """Copy every file in `source` into `target_dir`, keeping its basename.

    `target_dir` is created if it does not exist, and may be given with or
    without a trailing path separator.

    Args:
        source: iterable of file paths to copy.
        target_dir: destination directory.
    """
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    for file_path in source:
        # os.path.join instead of string concatenation: without a trailing '/'
        # the old form produced paths like '<dir>pso_001.jpg'.
        shutil.copyfile(file_path, os.path.join(target_dir, os.path.basename(file_path)))
        

    

630 78 80
4284 535 537


In [7]:
# Writing the split folders is opt-in. copy_files adds files to the target
# folders without clearing them, so regenerating on top of an existing split
# mixes old and new files. Only set this to True for empty target folders.
GENERATE_SPLITS = False

if GENERATE_SPLITS:
    for split_root, files_by_class in (
            (test_dir, {'pso': pso_test, 'ecz': ecz_test}),
            (train_dir, {'pso': pso_train, 'ecz': ecz_train}),
            (validation_dir, {'pso': pso_val, 'ecz': ecz_val})):
        for class_name, files in files_by_class.items():
            # The trailing '' keeps the trailing separator the target path has
            # always been given, e.g. '<split_root>/pso/'.
            copy_files(files, os.path.join(split_root, class_name, ''))

## Augment data and balance classes

In [8]:
# Output folders for the augmented images; pso/ and ecz/ subfolders are written
# inside each. Set these environment variables (or edit the defaults).
train_aug = os.environ.get('TRAIN_AUG_DIR', '')              # Directory for augmented training data
validation_aug = os.environ.get('VALIDATION_AUG_DIR', '')    # Directory for augmented validation data

DIM = 224                       # square image side length, in pixels
BATCH_SIZE = 100
img_width = img_height = DIM    # derive both from DIM so they cannot drift apart
nchannels = 3                   # RGB

In [9]:
# Random on-the-fly augmentation settings. The same generator is used for the
# training and the validation split so both can be oversampled.
augmentation_params = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.4,
    zoom_range=0.4,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode="nearest",
)
datagen = ImageDataGenerator(**augmentation_params)


In [10]:
# Each generator reads the class subfolders (pso/, ecz/) of its split directory
# and yields an endless stream of single-image augmented batches with one-hot
# ('categorical') labels. A fixed seed makes shuffling and transforms repeatable.
flow_params = dict(
    target_size=(img_height, img_width),
    batch_size=1,
    class_mode='categorical',
    seed=123,
)
train_generator = datagen.flow_from_directory(train_dir, **flow_params)
validation_generator = datagen.flow_from_directory(validation_dir, **flow_params)

Found 4914 images belonging to 2 classes.
Found 613 images belonging to 2 classes.


In [11]:
# The augmentation loop decodes one-hot labels the same way for both splits,
# so both generators must use the same class -> index mapping.
assert validation_generator.class_indices == train_generator.class_indices
train_generator.class_indices

{'ecz': 0, 'pso': 1}

In [None]:
# Oversample every class to a fixed count per split so the augmented data is
# class-balanced (the raw training split has far fewer pso than ecz images).
N_TRAIN_PER_CLASS = 20000
N_VAL_PER_CLASS = 2000

assert train_aug and validation_aug, 'set the augmented-data output directories first'

# Label index -> class folder name, e.g. {0: 'ecz', 1: 'pso'}.
index_to_class = {idx: name for name, idx in train_generator.class_indices.items()}

# A class with no images in a split could never reach its quota (endless loop).
for gen, split_dir in ((train_generator, train_dir), (validation_generator, validation_dir)):
    missing = set(index_to_class) - set(gen.classes)
    assert not missing, 'no images for classes {} in {}'.format(
        sorted(index_to_class[i] for i in missing), split_dir)


def save_augmented(image, out_dir, class_name, index):
    """Write one augmented image to <out_dir>/<class_name>/<class_name>_<index>.png."""
    class_dir = os.path.join(out_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
    # Generator output is float; round and clip instead of truncating to uint8.
    pixels = image.round().clip(0, 255).astype('uint8')
    Image.fromarray(pixels, 'RGB').save(
        os.path.join(class_dir, '{}_{}.png'.format(class_name, index)))


def draw_and_save(generator, out_dir, counts, quota):
    """Pull one augmented image and save it if its class is still below `quota`.

    `counts` maps class name -> images saved so far; it is updated in place.
    """
    x, y = next(generator)
    class_name = index_to_class[int(y[0].argmax())]   # decode the one-hot label
    if counts[class_name] < quota:
        save_augmented(x[0], out_dir, class_name, counts[class_name])
        counts[class_name] += 1


train_counts = dict.fromkeys(index_to_class.values(), 0)
val_counts = dict.fromkeys(index_to_class.values(), 0)

n = 0
# Keep going until every class in both splits has reached its quota.
while (min(train_counts.values()) < N_TRAIN_PER_CLASS
       or min(val_counts.values()) < N_VAL_PER_CLASS):
    # Stop drawing from a split once all of its classes are full.
    if min(train_counts.values()) < N_TRAIN_PER_CLASS:
        draw_and_save(train_generator, train_aug, train_counts, N_TRAIN_PER_CLASS)
    if min(val_counts.values()) < N_VAL_PER_CLASS:
        draw_and_save(validation_generator, validation_aug, val_counts, N_VAL_PER_CLASS)

    n += 1
    if n % 100 == 0:
        print("psotr: {0}, ecztr: {1}, psoval: {2}, eczval: {3}".format(
            train_counts['pso'], train_counts['ecz'], val_counts['pso'], val_counts['ecz']))
    