In [1]:
import os
import re
from PIL import Image
from sklearn.model_selection import train_test_split

In [2]:
# Root folder of the dataset. Set the DATA_PATH environment variable to run the
# notebook on another machine; the author's original location is the fallback.
data_path = os.environ.get('DATA_PATH', '/Users/lauren/Desktop/CodingInterview')
frame_path = os.path.join(data_path, 'imgs_small')  # source input frames
mask_path = os.path.join(data_path, 'masks_small')  # source masks for those frames


In [None]:
# Create one output folder per split for frames and masks. exist_ok=True keeps
# this cell re-runnable (Restart & Run All) once the folders already exist.
# NOTE(review): files left over from a previous run are not removed.
folders = ['train_frames', 'train_masks', 'val_frames', 'val_masks', 'test_frames', 'test_masks']

for folder in folders:
    os.makedirs(os.path.join(data_path, folder), exist_ok=True)

In [None]:
# Collect every frame and mask, sort both lists "naturally" (img2 before img10)
# and split them into train / val / test sets. Frames and masks are paired purely
# by their position after sorting, so both folders must hold matching files.

def natural_key(name):
    """Return a sort key that compares embedded ASCII digit runs numerically.

    ``name`` is tokenised into single non-digit characters and runs of [0-9].
    Digit runs become ``(0, int)`` and other characters ``(1, char)``, so a
    number and a letter at the same position compare without a TypeError.
    """
    return [(0, int(tok)) if tok[0] in '0123456789' else (1, tok)
            for tok in re.findall(r'[^0-9]|[0-9]+', name)]


# Skip hidden entries (e.g. macOS .DS_Store): they are not images and would
# shift the positional frame/mask pairing.
all_frames = sorted((f for f in os.listdir(frame_path) if not f.startswith('.')), key=natural_key)
all_masks = sorted((f for f in os.listdir(mask_path) if not f.startswith('.')), key=natural_key)

assert len(all_frames) == len(all_masks), (
    f'{len(all_frames)} frames but {len(all_masks)} masks; they must pair one-to-one')

# Generate train, val and test sets: 10% held out for test, then 20% of the
# remainder for validation. Passing frames and masks together keeps pairs aligned.
img_train, img_test, mask_train, mask_test = train_test_split(all_frames, all_masks, test_size=0.1, random_state=230)

img_train, img_val, mask_train, mask_val = train_test_split(img_train, mask_train, test_size=0.2, random_state=230)

In [None]:
# Copy the train, val and test images and masks into their corresponding folders

In [None]:
def _copy_image(src_dir, dir_name, image):
    """Re-save ``src_dir/image`` as ``data_path/dir_name/image``.

    Shared implementation of add_frames / add_masks. PIL decodes the file and
    writes it out again (the output format follows the file extension) rather
    than copying its bytes.
    NOTE(review): for JPEG files this re-encode is lossy; use shutil.copyfile
    if a byte-exact copy is intended.
    """
    # The context manager releases the source file handle even if save() fails.
    with Image.open(src_dir + '/' + image) as img:
        img.save(data_path + '/' + dir_name + '/' + image)


def add_frames(dir_name, image):
    """Copy frame ``image`` from ``frame_path`` into ``data_path/dir_name``."""
    _copy_image(frame_path, dir_name, image)


def add_masks(dir_name, image):
    """Copy mask ``image`` from ``mask_path`` into ``data_path/dir_name``."""
    _copy_image(mask_path, dir_name, image)

In [None]:
# Pair each split's list of file names with the folder it is copied into.
frame_folders = list(zip((img_train, img_val, img_test),
                         ('train_frames', 'val_frames', 'test_frames')))

mask_folders = list(zip((mask_train, mask_val, mask_test),
                        ('train_masks', 'val_masks', 'test_masks')))

In [None]:
# Copy every frame into the folder of the split it was assigned to.
# A plain loop replaces list(map(...)), which built throwaway lists only for side effects.
for split_frames, split_folder in frame_folders:
    for frame_name in split_frames:
        add_frames(split_folder, frame_name)

In [None]:
# Copy every mask into the folder of the split it was assigned to.
# A plain loop replaces list(map(...)), which built throwaway lists only for side effects.
for split_masks, split_folder in mask_folders:
    for mask_name in split_masks:
        add_masks(split_folder, mask_name)

In [None]:
# Sorting the split lists (below) was used to check that the images and masks ended up in the correct folders

# img_train.sort(key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

# img_test.sort(key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

# img_val.sort(key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

# Data Augmentation

In [None]:
from keras.preprocessing.image import ImageDataGenerator

In [None]:
# Images and masks come from separate iterators that are zipped together, so each
# image/mask pair of iterators must use the same batch size and the same seed.
# The shared seed makes them shuffle identically and draw the same random
# augmentation per batch; without it, batches of images and masks do not match.
aug_seed = 230
gen_batch_size = 4

train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1./255)

# NOTE(review): flow_from_directory only scans subdirectories of the given folder;
# the copy cells write images directly into e.g. train_frames/, while the training
# cell counts files in train_frames/train/ — confirm the images sit in a subfolder.
train_image_generator = train_datagen.flow_from_directory(
    f'{data_path}/train_frames/', batch_size=gen_batch_size, color_mode="grayscale",
    class_mode=None, seed=aug_seed)
train_mask_generator = train_datagen.flow_from_directory(
    f'{data_path}/train_masks/', batch_size=gen_batch_size, color_mode="grayscale",
    class_mode=None, seed=aug_seed)

val_image_generator = val_datagen.flow_from_directory(
    f'{data_path}/val_frames/', batch_size=gen_batch_size, color_mode="grayscale",
    class_mode=None, seed=aug_seed)
val_mask_generator = val_datagen.flow_from_directory(
    f'{data_path}/val_masks/', batch_size=gen_batch_size, color_mode="grayscale",
    class_mode=None, seed=aug_seed)

# Each step yields an (images, masks) batch pair.
train_generator = zip(train_image_generator, train_mask_generator)
val_generator = zip(val_image_generator, val_mask_generator)



In [None]:
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.optimizers import Adam

In [None]:
import model

In [None]:
# Number of images in each split.
# NOTE(review): these count files in train/ and val/ subfolders that no cell in this
# notebook creates (the copy cells write straight into train_frames/ and val_frames/);
# confirm the images were moved there, otherwise this raises FileNotFoundError.
no_of_training_imgs = len(os.listdir(data_path + '/train_frames/train/'))
no_of_val_imgs = len(os.listdir(data_path + '/val_frames/val/'))

no_epochs = 5

# NOTE(review): not passed to the generators, which batch 4 images at a time.
batch_size = 32

weights_path = data_path + '/weights/weights.h5'
# Make sure the checkpoint directory exists before training starts.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

m = model.unet()

# Checkpointing and early stopping must watch the same metric. Previously one used
# 'val_accuracy' and the other 'val_acc', so one of them was watching a metric that
# does not exist. 'val_accuracy' is the name newer Keras logs for metrics=['accuracy'];
# older Keras logs 'val_acc' — TODO confirm against the installed version.
monitor_metric = 'val_accuracy'

check_point = ModelCheckpoint(weights_path, monitor=monitor_metric, verbose=1, save_best_only=True, mode='max')

csv_logger = CSVLogger('./log.out', append=True, separator=';')

earlystopping = EarlyStopping(verbose=1, monitor=monitor_metric, min_delta=.01, patience=3, mode='max')

callbacks_list = [check_point, csv_logger, earlystopping]

In [None]:
# Train for the configured number of epochs. Each epoch makes one full pass over
# each split: len() of a Keras directory iterator is its number of batches per pass,
# so the step counts follow the generators' real batch size. The old literals
# (360//32, 90//32) assumed batch 32 while the generators batch 4 images.
# NOTE: newer (TF 2.x) Keras deprecates fit_generator in favour of fit().
results = m.fit_generator(
    train_generator,
    epochs=no_epochs,
    steps_per_epoch=len(train_image_generator),
    validation_data=val_generator,
    validation_steps=len(val_image_generator),
    callbacks=callbacks_list,
)

In [None]:
# Print the model's layer-by-layer architecture and parameter counts.
m.summary()