In [1]:
import tensorflow as tf

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)

In [3]:
import datetime
import os
import matplotlib.pyplot as plt
import numpy as np
import network
from params import patches_root, train_db_path, test_db_path
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

In [4]:
img_height = 256
img_width = 256
batch_size = 64

In [5]:
# Load the training and validation datasets

train_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255, horizontal_flip=True, vertical_flip=True)

validation_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255)

train_data_gen  = train_generator.flow_from_directory(
    directory=r"./patches/train/",
    target_size=(img_width, img_height), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True)

validation_data_gen = validation_generator.flow_from_directory(
    directory=r"./patches/val/",
    target_size=(img_width, img_height), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True)

Found 15700 images belonging to 3 classes.
Found 3900 images belonging to 3 classes.


In [7]:
print("[*] Define model")
model = network.build()

[*] Define model
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 256, 256, 3)       78        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 96)      14208     
_________________________________________________________________
batch_normalization (BatchNo (None, 128, 128, 96)      384       
_________________________________________________________________
activation (Activation)      (None, 128, 128, 96)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 64, 64, 96)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 64)        153664    
_________________________________________________________________
batch_normalization_1 (Batch (None, 64,

In [8]:
sgd = tf.optimizers.SGD(lr=0.001, momentum=0.9, decay=0.0005)
model.compile(
    optimizer=sgd, 
    loss='categorical_crossentropy', 
    metrics=['accuracy'])

In [9]:
#           ------------ Train the Model ------------
if not os.path.exists('./saved_model'):
    os.makedirs('./saved_model')
    
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

ConstrainLayer = network.ConstrainLayer(model)
callbacks = [ModelCheckpoint('./saved_model/weights.{epoch:02d}.h5',
    monitor='acc',verbose=1, save_best_only=False,
    save_freq=1), ConstrainLayer, tensorboard_callback]

history = model.fit(x=train_data_gen, validation_data=validation_data_gen, epochs=45, workers=10,
     callbacks=callbacks)  #removed validation data

Epoch 1/45

Epoch 00001: saving model to ./saved_model/weights.01.h5
  1/246 [..............................] - ETA: 21:04 - loss: 1.6401 - accuracy: 0.2188
Epoch 00001: saving model to ./saved_model/weights.01.h5
  2/246 [..............................] - ETA: 11:20 - loss: 1.4875 - accuracy: 0.2422
Epoch 00001: saving model to ./saved_model/weights.01.h5
  3/246 [..............................] - ETA: 8:03 - loss: 1.3588 - accuracy: 0.2865 
Epoch 00001: saving model to ./saved_model/weights.01.h5
  4/246 [..............................] - ETA: 6:24 - loss: 1.3588 - accuracy: 0.3320
Epoch 00001: saving model to ./saved_model/weights.01.h5
  5/246 [..............................] - ETA: 5:23 - loss: 1.3795 - accuracy: 0.3688
Epoch 00001: saving model to ./saved_model/weights.01.h5
  6/246 [..............................] - ETA: 4:43 - loss: 1.3273 - accuracy: 0.4115
Epoch 00001: saving model to ./saved_model/weights.01.h5
  7/246 [..............................] - ETA: 4:15 - loss: 1.3

In [10]:
model.save('model.h5')