In [1]:
#train.py


# Importing Libraries
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.utils.vis_utils import plot_model
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
from keras.preprocessing.image import ImageDataGenerator

In [2]:
# Creating and initializing the CNN model

classifier = Sequential()

In [3]:
# Adding first convolution layer and pooling
classifier.add(Convolution2D(32, (3, 3), input_shape=(64, 64, 1), activation='relu'))
classifier.add(MaxPooling2D(pool_size=(2, 2)))
# Adding second convolution layer and pooling
classifier.add(Convolution2D(32, (3, 3), activation='relu'))
# The input shape is set to the pooled feature maps from the previous convolution layer
# Flattening the layers
classifier.add(Flatten())

In [4]:
# Adding a fully connected layer
classifier.add(Dense(units=128, activation='relu'))
classifier.add(Dense(units=4, activation='softmax'))

In [5]:
# Compiling the model
#categorical_crossentropy for multi-class classification
classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


In [6]:
# Preparing the train/test data and training the model

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

training_set = train_datagen.flow_from_directory('data/train',
                                                 target_size=(64, 64),
                                                 batch_size=32,
                                                 color_mode='grayscale',
                                                 class_mode='categorical')

test_set = test_datagen.flow_from_directory('data/test',
                                            target_size=(64, 64),
                                            batch_size=32,
                                            color_mode='grayscale',
                                            class_mode='categorical')

Found 4162 images belonging to 4 classes.
Found 441 images belonging to 4 classes.


In [7]:
# Adding early stopping callback
#earlystop = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

#checkpoint = ModelCheckpoint('model.h5', monitor='val_loss', save_best_only=True)


classifier.fit(
        training_set,
        steps_per_epoch=1000/32, # No of images in training set
        epochs=10,
        validation_data=test_set,
        validation_steps=100/32)
        #callbacks=[checkpoint])
        #callbacks=[earlystop])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x2ab0f541a30>

In [8]:
# Saving the models
model_json = classifier.to_json()
with open("model-bw.json", "w") as json_file:
    json_file.write(model_json)
classifier.save_weights('model-bw.h5')

classifier.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 62, 62, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 31, 31, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 29, 29, 32)        9248      
                                                                 
 flatten (Flatten)           (None, 26912)             0         
                                                                 
 dense (Dense)               (None, 128)               3444864   
                                                                 
 dense_1 (Dense)             (None, 4)                 516       
                                                        