In [1]:
import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import RandomRotation, RandomZoom, Rescaling, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Sequential
import tensorflow as tf

# Training configuration shared by the cells below.
# `epochs` is passed to fit(epochs=...), i.e. the index of the final epoch,
# not a count of additional epochs when resuming from `initial_epoch`.
epochs = 2_000

# Spatial size used both for the model's input layer and the dataset loader.
width, height = 256, 256

# Prefix for checkpoint files ("<model_name>_<epoch>.weights.h5") and log dirs.
model_name = 'v1.h5'

In [2]:
def get_new_model(width, height, num_classes=2):
    """Build and compile a small CNN classifier for single-channel images.

    Pipeline: rescale pixels by 1/255, random rotation/zoom augmentation,
    three Conv2D + MaxPooling2D stages (32/64/128 filters), then a dense head
    with dropout and a softmax output.

    Args:
        width: First spatial dimension of the input tensor.
        height: Second spatial dimension of the input tensor.
        num_classes: Number of softmax output units (default 2, the original
            hard-coded value). Must equal the number of classes in the one-hot
            ("categorical") labels the model is trained on.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer,
        categorical cross-entropy loss and an accuracy metric.
    """
    model = Sequential([
        # NOTE: Keras orders image dims as (height, width, channels). The input
        # is declared as (width, height, 1) to stay consistent with the
        # image_size=(width, height) used by the data loader and with existing
        # checkpoints; the two orders are only interchangeable while width == height.
        tf.keras.Input(shape=(width, height, 1)),
        # preprocessing layers
        Rescaling(1./255),
        RandomRotation(0.2),
        RandomZoom(0.2, 0.2),
        # convolutional layers
        Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1'),
        MaxPooling2D(pool_size=(2, 2), name='maxpool1'),
        Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2'),
        MaxPooling2D(pool_size=(2, 2), name='maxpool2'),
        Conv2D(128, (3, 3), activation='relu', padding='same', name='conv3'),
        MaxPooling2D(pool_size=(2, 2), name='maxpool3'),
        # classification head
        Flatten(name='flatten'),
        Dense(128, activation='relu', name='dense1'),
        Dropout(0.5, name='dropout'),
        Dense(num_classes, activation='softmax', name='output')
    ])

    # categorical_crossentropy expects one-hot labels (label_mode='categorical').
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

In [3]:
train_dir = "./data_training/training"
validation_dir = "./data_training/validation"


def _load_image_dataset(directory, batch_size, width, height, class_names=None, shuffle=True):
    """Load a grayscale, one-hot-labelled image dataset from a class-per-subfolder directory."""
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        # Same (width, height) order as get_new_model's input shape; see the
        # note there about Keras' (height, width) convention.
        image_size=(width, height),
        batch_size=batch_size,
        label_mode='categorical',
        color_mode='grayscale',
        class_names=class_names,
        shuffle=shuffle)


def train_model(model, batch_size, width, height, epochs, model_name, initial_epoch=0, log_dir=None,
                train_path=None, validation_path=None):
    """Fit `model` on the training directory, validating each epoch.

    Args:
        model: A compiled Keras model whose input matches (width, height, 1).
        batch_size: Batch size for both datasets.
        width, height: Image size passed to the dataset loader.
        epochs: Index of the final epoch (Keras `fit(epochs=...)` semantics).
        model_name: Prefix for checkpoint files and the default log directory.
        initial_epoch: Epoch index to resume from (0 for a fresh run).
        log_dir: Existing TensorBoard log directory to append to; when falsy,
            a new timestamped directory under ``logs/`` is created.
        train_path: Training data directory; defaults to module-level `train_dir`.
        validation_path: Validation data directory; defaults to module-level
            `validation_dir`.

    Returns:
        ``(model, history)`` where ``history`` is the object returned by ``fit``.
    """
    train_path = train_path or train_dir
    validation_path = validation_path or validation_dir

    train_ds = _load_image_dataset(train_path, batch_size, width, height)
    # Reuse the training class order so one-hot indices are guaranteed to agree;
    # a validation folder set that doesn't match now raises instead of silently
    # mislabelling. Validation order doesn't affect aggregated metrics, so skip shuffling.
    validation_ds = _load_image_dataset(
        validation_path, batch_size, width, height,
        class_names=train_ds.class_names, shuffle=False)

    # Use the existing log directory if provided, else create a new one
    if not log_dir:
        log_dir = f"logs/{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}_{model_name}"

    tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=1, write_images=True)
    # Doubled braces leave a literal "{epoch:02d}" for ModelCheckpoint to fill in;
    # Keras formats it 1-based, so "<model_name>_0005.weights.h5" holds the
    # weights after 5 completed epochs (resume with initial_epoch=5).
    checkpoint_path = f"{model_name}_{{epoch:02d}}.weights.h5"
    checkpoint = ModelCheckpoint(checkpoint_path, save_weights_only=True, save_freq='epoch')

    history = model.fit(
        train_ds,
        epochs=epochs,
        initial_epoch=initial_epoch,
        validation_data=validation_ds,
        callbacks=[tensorboard, checkpoint])

    return model, history

In [None]:
# Rebuild the architecture and, when resuming, restore the weights that
# train_model's ModelCheckpoint saved at `initial_epoch`.
# To train from scratch: set initial_epoch = 0 and existing_log_dir = None.
initial_epoch = 1528
existing_log_dir = "logs/20240502-233028_v1.h5"
batch_size = 32

model = get_new_model(width, height)

if initial_epoch > 0:
    # Same naming pattern as ModelCheckpoint in train_model; its {epoch} is
    # 1-based, so this file holds the weights after `initial_epoch` completed
    # epochs -- exactly where fit(initial_epoch=...) resumes. Deriving the name
    # keeps the checkpoint and the resume epoch from drifting apart.
    checkpoint_file = f"{model_name}_{initial_epoch:02d}.weights.h5"
    model.load_weights(checkpoint_file)

model, history = train_model(model, batch_size=batch_size, width=width, height=height, epochs=epochs,
                             model_name=model_name, initial_epoch=initial_epoch, log_dir=existing_log_dir)