# Classification - Train

In [None]:
import yaml
import os
import shutil
import subprocess
import datetime
import numpy
import random
import sklearn.metrics
import sklearn.utils.class_weight
import tensorflow

# Report the TensorFlow build in use, then how many GPUs it can see
tfVersion = tensorflow.__version__
print("INFO> TensorFlow version : %s" % tfVersion)
gpuList = tensorflow.config.experimental.list_physical_devices('GPU')
print("INFO> # of GPUs available: %d" % len(gpuList))

In [None]:
# Read the training parameters from the local config.yaml file into the
# yamlData dictionary; later cells look settings up by key (e.g. yamlData['batchSize']).
currentDir = os.getcwd()
print(f"INFO> Current folder : {currentDir}")

# The context manager closes the file handle as soon as parsing is done.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, which is
# only acceptable for a trusted, locally-authored config.yaml. If the file
# contains plain scalars/lists/maps only, yaml.safe_load would be the safer choice.
with open('config.yaml', 'r') as yamlFile:
    yamlData = yaml.load(yamlFile, Loader=yaml.Loader)

# Echo every parameter (sorted by name) so the run log records the exact configuration
for key in sorted(yamlData):
    print("YAML> %-15s: %s" % (key, yamlData[key]))

In [None]:
# Alternative transfer-learning model (pre-trained VGG16 base), disabled and kept for reference:
#myModelInput = tensorflow.keras.layers.Input(shape=(yamlData['imageHeight'],yamlData['imageWidth'],3))
#convBase = tensorflow.keras.applications.VGG16(weights='imagenet',include_top=False, input_shape=(yamlData['imageHeight'],yamlData['imageWidth'],3))(myModelInput)
#x = tensorflow.keras.layers.Flatten()(convBase)
#x = tensorflow.keras.layers.Dense(256,activation="relu")(x)
#x = tensorflow.keras.layers.Dropout(0.5)(x)
#myModelOutput = tensorflow.keras.layers.Dense(12,activation="softmax")(x)
#
#model = tensorflow.keras.models.Model(inputs=myModelInput,outputs=myModelOutput)
#convBase.trainable = False

# Build the CNN classifier from scratch: 3-channel images of the configured size go
# through six convolutional stages, then a dense head ending in a 12-way softmax.
myModelInput = tensorflow.keras.layers.Input(shape=(yamlData['imageHeight'],yamlData['imageWidth'],3))

# Every stage is Conv2D(3x3, relu) -> BatchNormalization -> MaxPooling2D(2x2);
# only the number of filters changes from one stage to the next.
# NOTE(review): no padding is requested, so each stage shrinks the feature map;
# confirm that imageHeight/imageWidth in config.yaml are large enough for six stages.
x = myModelInput
for nFilters in (64, 128, 192, 192, 192, 128):
    x = tensorflow.keras.layers.Conv2D(nFilters, (3,3), activation="relu")(x)
    x = tensorflow.keras.layers.BatchNormalization()(x)
    x = tensorflow.keras.layers.MaxPooling2D((2,2))(x)

# Classification head
x = tensorflow.keras.layers.Flatten()(x)
x = tensorflow.keras.layers.Dense(128, activation="relu")(x)
x = tensorflow.keras.layers.BatchNormalization()(x)

x = tensorflow.keras.layers.Dropout(0.5)(x)
# The output width (12) must equal the number of class sub-folders found in the training directory
myModelOutput = tensorflow.keras.layers.Dense(12, activation="softmax")(x)

model = tensorflow.keras.models.Model(inputs=myModelInput, outputs=myModelOutput)
model.summary()

# Keep a reference to the configured optimizer and hand that instance to compile().
# Passing the string 'rmsprop' instead would build a fresh RMSprop with the default
# learning rate and silently ignore yamlData['learningRate'].
# float() also accepts values that YAML parsed as strings (e.g. "1e-4").
optimizer = tensorflow.keras.optimizers.RMSprop(learning_rate=float(yamlData['learningRate']))

model.compile(loss=tensorflow.keras.losses.CategoricalCrossentropy(),
              optimizer=optimizer,
              metrics=['accuracy'])

In [None]:
# Training images are rescaled to [0,1] and randomly augmented (shear, zoom, horizontal flip);
# validation images are only rescaled so that the validation metrics stay comparable across epochs.
trnDataGen = tensorflow.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

valDataGen = tensorflow.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)

# Both generators read one sub-folder per class and yield one-hot ('categorical') labels,
# matching the CategoricalCrossentropy loss the model was compiled with.
trnGenerator = trnDataGen.flow_from_directory(
    yamlData['trnDir'],
    target_size=(yamlData['imageHeight'],yamlData['imageWidth']),
    batch_size=yamlData['batchSize'],
    class_mode='categorical')

valGenerator = valDataGen.flow_from_directory(
    yamlData['valDir'],
    target_size=(yamlData['imageHeight'],yamlData['imageWidth']),
    batch_size=yamlData['batchSize'],
    class_mode='categorical')

# One timestamp per run, shared by the checkpoint and the TensorBoard folders
timeNow = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
fullCheckpointDir = os.path.join(yamlData['checkpointDir'],timeNow)
# makedirs also creates the parent checkpoint folder when it does not exist yet
os.makedirs(fullCheckpointDir, exist_ok=True)

# A full model (not just weights) is saved after every epoch; the file name embeds the
# epoch number and the loss/accuracy/val_loss/val_accuracy values logged by Keras,
# so the placeholder names must match the metric names used in model.compile().
filePath = os.path.join(fullCheckpointDir,"{epoch:05d}_{loss:.6f}_{accuracy:.6f}_{val_loss:.6f}_{val_accuracy:.6f}.h5")
checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(filePath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', save_freq='epoch')

# profile_batch=0 required to solve a bug w/ Tensorboard according to
#   https://github.com/tensorflow/tensorboard/issues/2412
fullLogDir = os.path.join(yamlData['logDir'], timeNow)
tensorboardCallback = tensorflow.keras.callbacks.TensorBoard(log_dir=fullLogDir,profile_batch=0)

# Balance the classes: under-represented classes get a proportionally larger loss weight.
# compute_class_weight takes keyword arguments and returns an array ordered like classLabels,
# whereas Keras expects a {class index: weight} dictionary, hence the conversion below.
classLabels = numpy.unique(trnGenerator.classes)
classWeightValues = sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced',
    classes=classLabels,
    y=trnGenerator.classes)
classWeight = dict(zip(classLabels.tolist(), classWeightValues.tolist()))

# Whole batches only (the trailing partial batch is skipped each epoch), but never
# fewer than one step so that a dataset smaller than one batch is still used.
history = model.fit(
    trnGenerator,
    steps_per_epoch=max(1, trnGenerator.n // trnGenerator.batch_size),
    epochs=yamlData['nEpochs'],
    validation_data=valGenerator,
    validation_steps=max(1, valGenerator.n // valGenerator.batch_size),
    class_weight=classWeight,
    callbacks=[tensorboardCallback,checkpoint])