In [1]:
import tensorflow as tf

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Applied to every visible GPU; on a CPU-only machine the list is empty
# and this is a no-op (indexing gpus[0] would raise IndexError there).
# Must run before any GPU has been initialized.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
import os
import pandas as pd
import datetime
import warnings
import network
import numpy as np
import func
import fnmatch
from params import dresden_csv, ins_train_csv, ins_test_csv, ins_train, ins_test, \
                ins_patches_db, ins_weights
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

In [7]:
img_height = 256
img_width = 256
batch_size = 64

# Training images: pixel values scaled by 1/255 and randomly flipped
# horizontally and vertically as augmentation.
train_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255, horizontal_flip=True, vertical_flip=True)

# Validation images: same scaling, no augmentation.
validation_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255)

# Keras' `target_size` is (height, width).
# NOTE(review): class indices are assigned from the sub-directory names; they
# must line up with the keys of the class-weight CSV loaded later — confirm.
train_data_gen = train_generator.flow_from_directory(
    directory=r"./instance/train/train/",
    target_size=(img_height, img_width), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True)

# Validation order does not affect the aggregated metrics, so keep it fixed
# for deterministic, reproducible evaluation.
validation_data_gen = validation_generator.flow_from_directory(
    directory=r"./instance/train/val/",
    target_size=(img_height, img_width), color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=False)

Found 12227 images belonging to 3 classes.
Found 3050 images belonging to 3 classes.


In [None]:
# Build the network and configure SGD with momentum.
print("[*] Define model")
model = network.build()

# Equivalent of the legacy `SGD(lr=0.001, decay=0.0005)`: the old `decay`
# argument scaled the rate as lr / (1 + decay * iterations), which is exactly
# InverseTimeDecay with decay_steps=1. Newer Keras optimizers reject both
# `lr` and `decay`, so express the same schedule explicitly.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.001, decay_steps=1, decay_rate=0.0005)
sgd = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'])

#           ------------ Train the Model ------------
# Directory for the per-epoch checkpoint files written below.
os.makedirs('./instance/saved_model', exist_ok=True)

# TensorBoard logs go to a per-run, timestamped directory.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Project-defined callback from `network`, constructed with the model.
# NOTE(review): presumably enforces a weight constraint during training — confirm in network.py.
ConstrainLayer = network.ConstrainLayer(model)

# Save one checkpoint at the end of every epoch. An integer `save_freq` counts
# *batches*, so the previous `save_freq=1` rewrote the file after every batch.
# `monitor` must match the compiled metric name ('accuracy'); it is only
# consulted when save_best_only=True.
callbacks = [ModelCheckpoint('./instance/saved_model/weights.{epoch:02d}.h5',
    monitor='accuracy', verbose=1, save_best_only=False,
    save_freq='epoch'), ConstrainLayer, tensorboard_callback]

# Per-class loss weights: the first row of the weights CSV, with each column
# header converted to an integer class index.
# NOTE(review): these indices must match flow_from_directory's class_indices
# (sub-directory order) — confirm when regenerating the CSV.
df = pd.read_csv(ins_weights)
class_weight = {int(label): weight
                for label, weight in df.to_dict('records')[0].items()}

# Train for 45 epochs with the class weights and callbacks defined above.
# `Model.fit` accepts the directory iterators directly; `fit_generator` is
# deprecated and removed in newer TensorFlow releases.
history = model.fit(train_data_gen, epochs=45, workers=10,
     callbacks=callbacks, validation_data=validation_data_gen, class_weight=class_weight)

# Persist the final model (architecture + weights) in HDF5 format.
model.save('./instance/model.h5')