In [None]:
import json
from pathlib import Path

import tensorflow as tf


In [None]:
# Spatial resolution images are resized to, as (height, width).
image_size = (512, 512)
# Network input shape: image_size followed by 3 (RGB) channels, channels-last.
input_shape = (*image_size, 3)
# Number of output units of the classification head.
num_classes = 2


In [None]:
# Pretrained ImageNet backbone used as a frozen feature extractor; the
# classification head is built separately on top of it.
#
# include_top=False drops the ImageNet classifier, so `classes` and
# `classifier_activation` would be ignored here — they are intentionally not
# passed (the real head below outputs logits for `num_classes` classes).
#
# NOTE: MobileNetV3 rescales inputs internally (include_preprocessing=True by
# default), so raw 0-255 pixels can be fed directly. Swapping in ResNet50V2
# would additionally require tf.keras.applications.resnet_v2.preprocess_input.
# NOTE(review): ImageNet weights were trained at 224x224; with a different
# input_shape Keras loads them anyway (with a warning) — confirm intended.
base_model = tf.keras.applications.MobileNetV3Large(
    input_shape=input_shape,
    include_top=False,
    weights='imagenet',
    pooling=None,  # spatial pooling is applied in the head instead
)

# Freeze every backbone weight so only the new head is trained.
base_model.trainable = False


In [None]:
# Classification head on top of the frozen backbone:
# backbone features -> global average pool -> dropout -> linear logits.
inputs = tf.keras.Input(shape=input_shape)

# training=False keeps the backbone's BatchNormalization layers in inference
# mode even while the head is being fit.
feature_maps = base_model(inputs, training=False)
pooled = tf.keras.layers.GlobalAveragePooling2D()(feature_maps)
regularized = tf.keras.layers.Dropout(0.2)(pooled)

# No activation: the layer emits raw logits (the loss uses from_logits=True).
outputs = tf.keras.layers.Dense(num_classes)(regularized)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()


In [None]:
# Adam with default hyper-parameters; integer labels paired with the
# head's raw logits, hence the sparse losses/metrics and from_logits=True.
adam_optimizer = tf.keras.optimizers.Adam()
sparse_ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
sparse_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

model.compile(
    optimizer=adam_optimizer,
    loss=sparse_ce_loss,
    metrics=[sparse_accuracy],
)


In [None]:
# Load labelled images from one sub-directory per class. With
# label_mode='int', each label is an index into `dataset.class_names`.
dataset_directory = './dataset'

# Fixed shuffle seed so the training order is reproducible across
# Restart & Run All (seed=None draws a fresh random seed on every run).
DATASET_SEED = 42
BATCH_SIZE = 32

# NOTE: keep `dataset` as the object returned by image_dataset_from_directory.
# Chaining .prefetch()/.map() yields a plain tf.data.Dataset that no longer has
# the `class_names` attribute, which is needed when exporting the label order.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_directory,
    labels='inferred',
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=image_size,
    shuffle=True,
    seed=DATASET_SEED,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True
)

# The model head is sized from the hard-coded `num_classes`; fail fast here
# rather than training a mis-sized head (out-of-range labels otherwise surface
# as an opaque TF error or a NaN loss).
assert len(dataset.class_names) == num_classes, (
    f'expected {num_classes} classes, found {len(dataset.class_names)}: '
    f'{dataset.class_names}'
)


In [None]:
# Number of passes over the training data. Only the new head is updated,
# because base_model.trainable was set to False.
EPOCHS = 5

# Keep the History object so per-epoch loss/accuracy (history.history) can be
# inspected or plotted later instead of being discarded.
# NOTE(review): no validation data is passed, so overfitting cannot be
# detected — consider validation_split/subset when loading the dataset.
history = model.fit(dataset, epochs=EPOCHS)


In [None]:
# Create the output directory first: writing the .keras archive into a
# directory that does not exist yet (e.g. on a fresh checkout) can fail.
Path('./saved_model').mkdir(parents=True, exist_ok=True)
model.save('./saved_model/model.keras')


In [None]:
# Persist the class order used for the integer labels so that prediction
# index i can be mapped back to dataset.class_names[i] at inference time.
Path('./saved_model').mkdir(parents=True, exist_ok=True)

# Explicit UTF-8: class names come from directory names and may be non-ASCII,
# which the platform-default encoding (e.g. cp1252 on Windows) cannot encode.
with open('./saved_model/class_names.txt', 'w', encoding='utf-8') as file:
    # JSON instead of str(list): readable from any language, and still a valid
    # Python literal, so readers using ast.literal_eval keep working.
    json.dump(list(dataset.class_names), file, ensure_ascii=False)


In [None]:
# Reload the .keras archive that was written to ./saved_model.
saved_model_path = './saved_model/model.keras'
loaded_model = tf.keras.models.load_model(saved_model_path)


In [None]:
# Print the reloaded architecture; expand_nested=False (the default) shows the
# backbone as a single row rather than listing every inner layer.
loaded_model.summary(expand_nested=False)
