# Train Models using keras

## Dataset

If you don't have the dataset ready yet, no worries, uncomment the bash command bellow and run it so the data download and prepare to use Keras Data Preprocessor.

In [1]:
import os

from tensorflow import keras
from tensorflow.keras.preprocessing import image_dataset_from_directory

In [2]:
# Download and prepare data (it may take a while, ~250mb)
#os.system("poetry run python data_preprocessing.py")

In [3]:
train = image_dataset_from_directory(
    directory='data/train/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256)
)

Found 11490 files belonging to 8 classes.


In [4]:
val = image_dataset_from_directory(
    directory='data/val/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256)
)

Found 1272 files belonging to 8 classes.


In [15]:
def build_model(n_classes):
    """Create the model using Pretreined Xception on Imagenet"""
    inputs = keras.layers.Input(shape=(256, 256, 3))

    model = keras.applications.Xception( 
        weights='imagenet',
        input_tensor=inputs, 
        classes=n_classes,
        include_top=False
    )

    # Freeze the pretrained weights
    model.trainable = False

    # Rebuild top
    x = keras.layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = keras.layers.BatchNormalization()(x)

    top_dropout_rate = 0.2
    x = keras.layers.Dropout(top_dropout_rate, name="top_dropout")(x)
    outputs = keras.layers.Dense(n_classes, activation="softmax", name="pred")(x)

    model = keras.Model(inputs, outputs, name="Xception")

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

    return model

In [16]:
model = build_model(8)

model.fit(train, epochs=1, validation_data=val)

 18/360 [>.............................] - ETA: 1:21:19 - loss: 2.1524

KeyboardInterrupt: 

In [None]:
model.save("models/xception")