In [1]:
from keras.models import Sequential, load_model
from keras.layers import SeparableConv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.applications.xception import preprocess_input
from keras.preprocessing.image import ImageDataGenerator

import numpy as np
import os

from sklearn.metrics import classification_report

In [16]:
# Root of the dataset: a "data" folder located in the current working directory.
# os.path.join is used instead of joining on os.path.sep by hand so that a
# working directory that is a filesystem root (or ends with a separator) does
# not yield doubled separators.
wd = os.path.join(os.getcwd(), "data")

# Sub-folders of the data root used by the generators below.
train_path = os.path.join(wd, "train")
test_path = os.path.join(wd, "test")
symblink_path = os.path.join(wd, "symblink")

Reference — feeding a Keras model with input from multiple directories: https://stackoverflow.com/questions/60787620/combine-two-data-generator-to-train-a-cnn

## Declaration of the image generators, which let the models train without loading all images into memory

In [17]:
# Single generator factory: every image is passed through the Xception
# preprocess_input function, and 15 % of each directory is set aside for the
# "validation" subset.
img_datagen = ImageDataGenerator(preprocessing_function=preprocess_input, validation_split=.15)

# Settings shared by every generator, kept in one place so they cannot drift apart.
flow_kwargs = dict(target_size=(51, 51), batch_size=32)

# Training subsets: original images only, and the symlinked directory
# (follow_links=True is required so the links inside it are traversed).
training_gen_original = img_datagen.flow_from_directory(train_path, subset="training", **flow_kwargs)
training_gen_augmented = img_datagen.flow_from_directory(symblink_path, subset="training", follow_links=True, **flow_kwargs)


# Validation subsets taken from the same two directories.
validation_gen_original = img_datagen.flow_from_directory(train_path, subset="validation", **flow_kwargs)
validation_gen_augmented = img_datagen.flow_from_directory(symblink_path, subset="validation", follow_links=True, **flow_kwargs)


# No subset is given, so the whole test directory is used.
# shuffle=False keeps the batches in the same order as test_gen.classes and
# test_gen.filenames, which is needed to compare predictions with the true
# labels (e.g. with classification_report).
test_gen = img_datagen.flow_from_directory(test_path, shuffle=False, **flow_kwargs)

Found 141722 images belonging to 2 classes.
Found 226722 images belonging to 2 classes.
Found 25008 images belonging to 2 classes.
Found 40008 images belonging to 2 classes.
Found 110794 images belonging to 2 classes.


In [13]:
epochs = 50
# Taken from the generator so the step counts computed from it always match
# the size of the batches actually produced.
batch_size = training_gen_original.batch_size


def callback(x):
    """Build the list of callbacks for one training run.

    Parameters
    ----------
    x : str
        Label of the run; the best model is written to ``<wd>/<x>_model.h5``.

    Returns
    -------
    list
        An EarlyStopping callback (monitors ``val_loss``, patience of 2
        epochs) and a ModelCheckpoint callback (monitors ``val_loss`` and
        keeps only the best model).
    """
    return [
        EarlyStopping(monitor='val_loss', patience=2, mode="min"),
        ModelCheckpoint(filepath=f"{wd}/{x}_model.h5", monitor='val_loss', save_best_only=True),
    ]


# Sample counts are read from the generators instead of being copied by hand
# from their "Found N images" output, so they stay correct if the data changes.

# original
totalTrain_original = training_gen_original.samples  # total number of images in the train set
totalVal_original = validation_gen_original.samples  # total number of images in the validation set

# original + augmented
totalTrain_augmented = training_gen_augmented.samples  # total number of images in the train set
totalVal_augmented = validation_gen_augmented.samples  # total number of images in the validation set

## I - CNN model with unbalanced classes (without data augmentation)
### First compute class weights to prioritize the under-represented class (class 1)

In [6]:
from sklearn.utils import class_weight

# Integer class label of every image in the original training subset.
train_labels = training_gen_original.classes

# 'balanced' weighting computed by scikit-learn over the labels present in the
# training subset.
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)

# Map each class index to its weight.
train_class_weights = {index: weight for index, weight in enumerate(class_weights)}
train_class_weights

{0: 0.6984751259228593, 1: 1.7596036850338954}

In [19]:
# Small CNN: three SeparableConv2D -> BatchNormalization -> MaxPooling2D
# stages, followed by a dense classification head.
model = Sequential()

# Stage 1: 32 filters
model.add(SeparableConv2D(32, (3, 3), activation="relu"))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D((2, 2)))

# Stage 2: 32 filters
model.add(SeparableConv2D(32, (3, 3), activation="relu"))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D((2, 2)))

# Stage 3: 64 filters
model.add(SeparableConv2D(64, (3, 3), activation="relu"))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())

# Classification head: one hidden dense layer with dropout, then one softmax
# output per class.
model.add(Dense(64, activation="relu"))
model.add(BatchNormalization(axis=-1))
model.add(Dropout(0.5))
model.add(Dense(2, activation="softmax"))

# The output layer is a 2-way softmax and the generators do not set
# class_mode, so Keras' default one-hot ("categorical") labels apply; the
# matching loss is categorical cross-entropy. For exactly two classes this is
# numerically the same objective as the previous binary_crossentropy, but it
# states the softmax/one-hot contract explicitly and stays correct if more
# classes are added.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

### Batch size comparison:
128 -> Val accuracy: .76 - 5min
64 -> Val accuracy: .82 - 15min
32 -> Val accuracy: .84 - +60min

## II - CNN model with classes balanced through data augmentation

In [20]:
# Train on the augmented training subset and validate on the held-out
# augmented *validation* subset. Validation previously used the training
# generator, so val_loss (which drives EarlyStopping and ModelCheckpoint) was
# measured on data the model had been trained on.
# batch_size is not passed to fit(): the generators already yield batches;
# it is only used here to derive the number of steps per epoch.
hist = model.fit(
    x=training_gen_augmented,
    validation_data=validation_gen_augmented,
    callbacks=callback("CNN_balanced"),
    epochs=epochs,
    steps_per_epoch=totalTrain_augmented // batch_size,
    validation_steps=totalVal_augmented // batch_size,
).history

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50


In [21]:
import pickle

# Persist the training history (the .history dict returned by model.fit) so it
# can be inspected later without retraining.
history_file = 'data/history.pickle'
with open(history_file, 'wb') as out:
    pickle.dump(hist, out, protocol=pickle.HIGHEST_PROTOCOL)

# To reload a previously saved history (only unpickle files you created yourself):
# with open('data/history.pickle', 'rb') as handle:
#     hist = pickle.load(handle)