In [1]:
import tensorflow as tf
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all
# device memory up front. Applied to every visible GPU (TensorFlow requires
# the memory-growth setting to be consistent across GPUs); on a CPU-only
# machine `gpus` is empty and this is a no-op instead of an IndexError.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # Must run before the GPUs are initialized, otherwise TensorFlow raises RuntimeError.
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
import datetime
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import network
from func import split_info
from params import patches_root, patches_db_path, weights_path
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

In [3]:
# Image geometry and batch size shared by both generators.
img_height = 256
img_width = 256
batch_size = 64
SEED = 42  # fixes the training shuffle order so runs are reproducible

# Load the training and validation datasets.
# Training images are rescaled to [0, 1] and randomly flipped (augmentation);
# validation images are only rescaled so evaluation is not perturbed.
train_generator = ImageDataGenerator(
    rescale=1./255, horizontal_flip=True, vertical_flip=True)

validation_generator = ImageDataGenerator(rescale=1./255)

# NOTE: Keras expects target_size as (height, width).
train_data_gen = train_generator.flow_from_directory(
    directory=r"./train/train/",
    target_size=(img_height, img_width), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True, seed=SEED)

# Validation order does not change aggregate loss/accuracy; keeping it fixed
# lets per-sample predictions line up with validation_data_gen.classes / .filenames.
validation_data_gen = validation_generator.flow_from_directory(
    directory=r"./train/val/",
    target_size=(img_height, img_width), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=False)

Found 37775 images belonging to 3 classes.
Found 9425 images belonging to 3 classes.


In [4]:
print("[*] Define model")
model = network.build()

# SGD with momentum and inverse-time learning-rate decay.
# The legacy `lr=`/`decay=` kwargs are deprecated (`decay` raises ValueError in
# the Keras optimizers shipped with TF >= 2.11). InverseTimeDecay with
# decay_steps=1 reproduces the legacy schedule: lr / (1 + decay * iterations).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.001, decay_steps=1, decay_rate=0.0005)
sgd = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'])

#           ------------ Train the Model ------------
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('./saved_model', exist_ok=True)

# One TensorBoard run directory per launch, timestamped.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Project callback defined in network.py; presumably re-applies a weight
# constraint to a layer of `model` during training — TODO confirm.
ConstrainLayer = network.ConstrainLayer(model)

# Write one checkpoint file per epoch. `monitor` is only consulted when
# save_best_only=True; in TF2 the metric is logged as 'val_accuracy' (not 'acc').
callbacks = [ModelCheckpoint('./saved_model/weights.{epoch:02d}.h5',
    monitor='val_accuracy', verbose=1, save_best_only=False,
    save_freq='epoch'), ConstrainLayer, tensorboard_callback]

# Class weights: the first CSV row maps column header -> weight. Headers are
# cast to int, so they are assumed to be class indices matching
# train_data_gen.class_indices — verify against the CSV producer.
df = pd.read_csv(weights_path)
class_weight = df.to_dict('records')[0]
class_weight = {int(k): v for k, v in class_weight.items()}

[*] Define model
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 256, 256, 3)       78        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 96)      14208     
_________________________________________________________________
batch_normalization (BatchNo (None, 128, 128, 96)      384       
_________________________________________________________________
activation (Activation)      (None, 128, 128, 96)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 64, 64, 96)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 64)        153664    
_________________________________________________________________
batch_normalization_1 (Batch (None, 64,

In [None]:
# Train the model. Model.fit accepts generators / keras Sequences directly;
# fit_generator is deprecated since TF 2.1 and removed in later releases.
# `workers` only applies to generator/Sequence input in tf.keras 2.x.
EPOCHS = 45
history = model.fit(train_data_gen, epochs=EPOCHS, workers=10,
                    callbacks=callbacks, validation_data=validation_data_gen,
                    class_weight=class_weight)

Epoch 1/45

Epoch 00001: saving model to ./saved_model/weights.01.h5
  1/591 [..............................] - ETA: 1:44:54 - loss: 1.8856 - accuracy: 0.4375
Epoch 00001: saving model to ./saved_model/weights.01.h5
  2/591 [..............................] - ETA: 54:44 - loss: 2.1276 - accuracy: 0.3438  
Epoch 00001: saving model to ./saved_model/weights.01.h5
  3/591 [..............................] - ETA: 37:35 - loss: 2.0370 - accuracy: 0.3542
Epoch 00001: saving model to ./saved_model/weights.01.h5
  4/591 [..............................] - ETA: 29:04 - loss: 2.1640 - accuracy: 0.3203
Epoch 00001: saving model to ./saved_model/weights.01.h5
  5/591 [..............................] - ETA: 23:59 - loss: 2.1947 - accuracy: 0.3156
Epoch 00001: saving model to ./saved_model/weights.01.h5
  6/591 [..............................] - ETA: 20:35 - loss: 2.2682 - accuracy: 0.2917
Epoch 00001: saving model to ./saved_model/weights.01.h5
  7/591 [..............................] - ETA: 18:07 - l

In [None]:
# Persist the model in whatever state training left it; the `.h5` suffix
# selects the HDF5 format.
model.save(filepath='model.h5')