In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import app

In [None]:
# Images per mini-batch; shared by both data generators and the per-epoch step counts.
BATCH_SIZE: int = 32

In [None]:
# Spatial resolution every image is resized to; must agree with the model's
# input shape (256, 256, 1) used in build_model.
IMAGE_SIZE = (256, 256)
# Fixed seed so training-set shuffling and random augmentation are reproducible
# across "Restart Kernel -> Run All".
SEED = 42

# Training images: scale pixel values by 1/255 and apply light geometric augmentation.
train_data_gen = ImageDataGenerator(
    rescale=1.0 / 255,
    zoom_range=0.2,
    rotation_range=15,
    width_shift_range=0.05,
    height_shift_range=0.05,
)

print("\nLoading training data...")
train_generator = train_data_gen.flow_from_directory(
    'augmented-data/train',
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    seed=SEED,
)

# Validation images: same 1/255 scaling, but no augmentation, so metrics reflect the stored data.
test_data_gen = ImageDataGenerator(rescale=1.0 / 255)

print("\nLoading validation data...")
# shuffle=False keeps a stable sample order, so predictions can be lined up with
# test_generator.classes; validation metrics do not depend on order.
test_generator = test_data_gen.flow_from_directory(
    'augmented-data/test',
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
def build_model(input_shape=(256, 256, 1), num_classes=3):
    """Build a small, uncompiled CNN classifier.

    The architecture is two Conv2D + MaxPooling2D stages, then Flatten and a softmax Dense head.

    Args:
        input_shape: (height, width, channels) of each input image. Defaults to
            (256, 256, 1), matching the 256x256 grayscale data generators.
        num_classes: Number of softmax output units. Defaults to 3. Presumably
            this must equal the number of class sub-directories found by
            flow_from_directory; confirm against the dataset.

    Returns:
        A ``tf.keras.Sequential`` model (not yet compiled).
    """
    model = Sequential()
    model.add(tf.keras.Input(shape=input_shape))
    # Aggressive early downsampling ('valid' padding): for a 256x256 input the
    # stride-3 conv gives 84x84, and the 5x5/stride-5 pool then gives 17x17.
    model.add(tf.keras.layers.Conv2D(2, 5, strides=3, activation="relu"))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(5, 5), strides=(5, 5)))
    model.add(tf.keras.layers.Conv2D(4, 3, strides=1, activation="relu"))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(num_classes, activation="softmax"))
    return model

In [None]:
# Announce the step before doing it, so the log message precedes the work it describes.
print("\nBuilding model...")
model = build_model()
model.summary()

In [None]:
# Training configuration: Adam optimizer and categorical cross-entropy,
# which pairs with the one-hot labels from class_mode='categorical'.
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=1e-3)
LOSS = tf.keras.losses.CategoricalCrossentropy()

print("\nCompiling model...")
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[
    tf.keras.metrics.CategoricalAccuracy(),
    tf.keras.metrics.AUC(),
])

In [None]:
print("\nTraining model...")
# Keras documents steps_per_epoch / validation_steps as integers, but
# `samples / BATCH_SIZE` is a float. A DirectoryIterator's len() is the number
# of batches per epoch, ceil(samples / batch_size), so every image is covered
# exactly once per epoch.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=5,
    validation_data=test_generator,
    validation_steps=len(test_generator),
)