In [None]:
import os
import glob
import random
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import keras_cv

In [None]:
IMAGES_DIR = "/app/data/images/"

# File extensions keras.utils.image_dataset_from_directory accepts; filtering
# here keeps the per-class counts and previews consistent with what the
# dataset will actually load (and stops plt.imread choking on e.g. a README).
IMAGE_EXTS = (".bmp", ".gif", ".jpeg", ".jpg", ".png")

# One entry per class sub-directory. Stray top-level files (e.g. .DS_Store)
# are skipped, and names are sorted to mirror the alphanumeric class order
# image_dataset_from_directory uses for label indices.
dir_names = sorted(
    name
    for name in os.listdir(IMAGES_DIR)
    if os.path.isdir(os.path.join(IMAGES_DIR, name))
)
img_paths_dict = {
    dir_name: sorted(
        path
        for path in glob.glob(os.path.join(IMAGES_DIR, dir_name, "*"))
        if path.lower().endswith(IMAGE_EXTS)
    )
    for dir_name in dir_names
}

In [None]:
# Per class: the class name with its file count, then up to five sample paths.
for class_name, paths in img_paths_dict.items():
    preview = paths[:5]
    print(class_name, len(paths))
    if preview:
        print("\n".join(preview))

In [None]:
# Show up to five sample images per class as one titled row, instead of one
# untitled full-size figure per image.
for dir_name, file_paths in img_paths_dict.items():
    print(dir_name)
    samples = file_paths[:5]
    if not samples:
        # plt.subplots(1, 0) would fail; an empty class just gets its name printed.
        continue
    fig, axes = plt.subplots(
        1, len(samples), figsize=(3 * len(samples), 3), squeeze=False
    )
    for ax, file_path in zip(axes[0], samples):
        img = plt.imread(file_path)
        # 2-D arrays are single-channel; draw them in gray rather than a colormap.
        ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
        ax.set_title(os.path.basename(file_path), fontsize=8)
        ax.axis("off")
    fig.suptitle(dir_name)
    plt.show()
    # Release the figure so looping over many classes does not accumulate memory.
    plt.close(fig)

In [None]:
# Exploratory load of the whole image folder with every other argument left at
# its Keras default (labels inferred from the sub-directory names), used only
# to inspect what the input pipeline yields.
dataset = keras.utils.image_dataset_from_directory(directory=IMAGES_DIR)

In [None]:
# Pull a single batch and report what kind of tensor the pipeline produces.
images, labels = next(iter(dataset))

for attr_name, attr_value in (
    ("type", type(images)),
    ("dtype", images.dtype),
    ("shape", images.shape),
):
    print(f"{attr_name}: {attr_value}")

# First sample of the batch: its class name, then the image cast back to
# uint8 so imshow interprets the pixel values as 0-255.
print(dataset.class_names[int(labels[0])])
plt.imshow(images[0].numpy().astype("uint8"))
plt.show()

In [None]:
batch_size = 16
# Both subsets must share validation_split and seed; Keras documents the seed
# as what guarantees the training and validation file lists do not overlap.
VALIDATION_SPLIT = 0.2
SPLIT_SEED = 0

train_ds = keras.utils.image_dataset_from_directory(
    IMAGES_DIR,
    batch_size=batch_size,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    label_mode="categorical",
    seed=SPLIT_SEED,
)
# Use the same batch_size as training; previously the validation set silently
# fell back to the Keras default batch size.
valid_ds = keras.utils.image_dataset_from_directory(
    IMAGES_DIR,
    batch_size=batch_size,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    label_mode="categorical",
    seed=SPLIT_SEED,
)

In [None]:
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset("efficientnetv2_b0_imagenet")

# Size the classification head from the data rather than hard-coding it, so
# adding or removing a class directory cannot silently mismatch the one-hot
# labels produced by label_mode="categorical".
num_classes = len(train_ds.class_names)

model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=num_classes,
    activation="softmax",
)

In [None]:
# The classifier head uses activation="softmax", so its outputs are
# probabilities rather than logits.
loss = keras.losses.CategoricalCrossentropy(from_logits=False)
metric = keras.metrics.CategoricalAccuracy(name="categorical_accuracy")

In [None]:
# `metrics` is documented as a list (or dict); Keras 3 rejects a bare Metric.
model.compile(loss=loss, metrics=[metric])

In [None]:
epochs = 4

# Keep the History object so per-epoch loss/accuracy can be plotted later;
# assigning it also stops the notebook from echoing its bare repr.
history = model.fit(train_ds, validation_data=valid_ds, epochs=epochs)

In [None]:
# Distinct names on purpose: unpacking into `loss` would overwrite the loss
# object passed to model.compile, so re-running the compile cell afterwards
# would hand Keras a float instead of a loss.
val_loss, val_acc = model.evaluate(valid_ds, verbose=False)
print(f"{val_loss=:.3}, {val_acc=:.3}")

In [None]:
# Predict on one randomly chosen image, preprocessed the way
# image_dataset_from_directory prepared the training data: decoded to 3
# channels and resized (bilinear) to its default image_size of (256, 256),
# keeping pixel values in the 0-255 range. plt.imread instead returned
# float [0, 1] (and possibly RGBA) for PNGs and never resized.
IMAGE_SIZE = (256, 256)

# Only consider file types the dataset loader itself accepts.
image_paths = [
    path
    for path in glob.glob(os.path.join(IMAGES_DIR, "*", "*"))
    if path.lower().endswith((".bmp", ".gif", ".jpeg", ".jpg", ".png"))
]
image_path = random.choice(image_paths)

image = tf.io.decode_image(
    tf.io.read_file(image_path), channels=3, expand_animations=False
)
image = tf.image.resize(image, IMAGE_SIZE)  # float32, values still 0-255

predictions = model.predict(tf.expand_dims(image, 0), verbose=False)[0]
pred_idx = int(predictions.argmax())
pred_cls = valid_ds.class_names[pred_idx]

print(pred_cls)
plt.imshow(image.numpy().astype("uint8"))
plt.title(f"{os.path.basename(image_path)} -> {pred_cls} ({predictions[pred_idx]:.2f})")
plt.axis("off")
plt.show()