In [None]:
import warnings

# Silence every Python warning for the rest of the kernel session. This runs
# before the TensorFlow imports below so that warnings emitted while importing
# are suppressed as well.
# NOTE(review): a blanket "ignore" also hides genuinely useful diagnostics
# (e.g. deprecations); consider narrowing it with `category=` once the noisy
# warnings are identified.
warnings.simplefilter("ignore")

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.xception import preprocess_input

# Model factories from the (presumably project-local) `models` module.
from models import create_fixed_input_shape_model, create_variable_input_shape_model

In [None]:
# Spatial size every image is resized to; passed as flow_from_directory's
# `target_size`, i.e. (height, width).
IMAGE_SIZE = (160, 160)
# Width of the model's output layer; should equal the number of class
# sub-directories discovered under train_dir.
NUM_CLASSES = 10

# Imagenette2 splits. The "val" split is what this notebook calls test_dir.
DATASET_ROOT = "datasets/imagenette2"
train_dir = DATASET_ROOT + "/train"
test_dir = DATASET_ROOT + "/val"

# __Training a Fixed Input Shape Model__

In [None]:
# Fixed seed so training-set shuffling and random augmentation repeat across
# runs. NOTE(review): model weight initialisation is not seeded here; full
# reproducibility would also need the framework's global seed.
AUGMENTATION_SEED = 42

# Training pipeline: random geometric/photometric augmentation, followed by
# Xception's `preprocess_input` as the per-image preprocessing function.
data_augment_generator = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    # NOTE(review): a brightness factor of 1.0 leaves an image unchanged, so
    # [0.2, 1.0] only ever darkens training images while evaluation images are
    # left as-is. Confirm this skew is intended.
    brightness_range=[0.2, 1.0],
    horizontal_flip=True,
    preprocessing_function=preprocess_input,
    fill_mode="nearest",
)

train_data_generator = data_augment_generator.flow_from_directory(
    train_dir,
    batch_size=32,
    class_mode="categorical",
    target_size=IMAGE_SIZE,
    seed=AUGMENTATION_SEED,
)

# Evaluation pipeline: same preprocessing, no augmentation.
data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)

test_data_generator = data_generator.flow_from_directory(
    test_dir,
    batch_size=64,
    class_mode="categorical",
    target_size=IMAGE_SIZE,
    # Deterministic order. Aggregate validation metrics are unaffected, and
    # predictions stay aligned with `test_data_generator.classes`.
    shuffle=False,
)

# With class_mode="categorical", one-hot labels are derived from each split's
# sub-directory names. Both splits must therefore discover the same classes in
# the same order, and that set must match the model's output width.
assert train_data_generator.class_indices == test_data_generator.class_indices, (
    "train and validation directories contain different class folders"
)
assert train_data_generator.num_classes == NUM_CLASSES, (
    f"found {train_data_generator.num_classes} classes, expected NUM_CLASSES={NUM_CLASSES}"
)

In [None]:
# Build the fixed-input-shape model for IMAGE_SIZE inputs and NUM_CLASSES
# outputs, then configure it for one-hot multi-class training.
# NOTE(review): "categorical_crossentropy" with its default from_logits=False
# expects probabilities. Confirm that create_fixed_input_shape_model ends in a
# softmax.
model = create_fixed_input_shape_model(IMAGE_SIZE, NUM_CLASSES)
model.compile(
    loss="categorical_crossentropy",
    optimizer="RMSProp",
    metrics=["accuracy"],
)

In [None]:
# Train the fixed-shape model for 20 epochs on the augmented stream. Loss and
# accuracy on the "val" split are reported after every epoch and are recorded
# in `history.history`.
# NOTE(review): `history` is reassigned by the variable-shape run later in the
# notebook, so these curves are lost unless saved first.
history = model.fit(
    train_data_generator,
    epochs=20,
    validation_data=test_data_generator,
    verbose=1,
)

In [None]:
# Persist only the trained weights; the architecture must be rebuilt with
# create_fixed_input_shape_model before loading them.
# NOTE(review): under Keras 3 (TF >= 2.16) save_weights requires a filename
# ending in ".weights.h5". Confirm the installed version.
model.save_weights(filepath="fixed-imagenette2.h5")

# __Training a Variable Input Shape Model__

In [None]:
# Build the variable-input-shape model (only the class count is fixed) and
# configure it exactly like the fixed-shape model, for a like-for-like
# comparison.
# NOTE(review): both generators resize to IMAGE_SIZE, so this model is never
# trained or validated on any other input size.
model = create_variable_input_shape_model(NUM_CLASSES)
model.compile(
    loss="categorical_crossentropy",
    optimizer="RMSProp",
    metrics=["accuracy"],
)

In [None]:
# Train the variable-shape model with the same data streams and epoch budget
# as the fixed-shape run. This reassigns `history`, replacing the fixed-shape
# run's record.
history = model.fit(
    train_data_generator,
    epochs=20,
    validation_data=test_data_generator,
    verbose=1,
)

In [None]:
# Persist only the trained weights; the architecture must be rebuilt with
# create_variable_input_shape_model before loading them.
# NOTE(review): under Keras 3 (TF >= 2.16) save_weights requires a filename
# ending in ".weights.h5". Confirm the installed version.
model.save_weights(filepath="variable-imagenette2.h5")