In [3]:
from tensorflow.keras.models import load_model

# Load the saved model
loaded_model = load_model('art_style_recognition_final_model.h5')

In [4]:
import tensorflow as tf

TEST_DIR = "../raw_data/wikiart/wikiart-target_style-class_14-keepgenre_True-merge_style_m1-flat_False/test"

BATCH_SIZE = 128 # Hyper param, you can tune it
EPOCHS = 1000 # Large number, early stopping to stop training before this number
IMG_HEIGHT = 224 # VGG's dim
IMG_WIDTH = 224 # VGG's dim
NUM_CLASSES = 14 # Number of art styles

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=TEST_DIR,
    labels='inferred',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode='categorical',
    shuffle=True)

assert len(test_ds.class_names) == NUM_CLASSES

Found 5872 files belonging to 14 classes.


In [5]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [6]:
test_predictions = loaded_model.predict(test_ds)

top5_classes = tf.math.top_k(test_predictions, k=5).indices
top5_probilities = tf.math.top_k(test_predictions, k=5).values

print("Top 5 classes for the first image:", top5_classes[0].numpy())
print("Probabilities for the top 5 classes for the first image:", top5_probilities[0].numpy())

Top 5 classes for the first image: [ 6  1  4  3 10]
Probabilities for the top 5 classes for the first image: [0.7638249  0.10111675 0.08073607 0.03000777 0.00939403]


In [9]:
success_count = 0

for images, labels in test_ds:
    # Get the actual class numbers
    true_labels = tf.argmax(labels, axis=1).numpy()

    # Get the predicted class numbers
    predicted_labels = tf.math.top_k(loaded_model.predict(images), k=5).indices.numpy()

    # Check if the true label is in the predicted labels for each image
    for i in range(len(true_labels)):
        if true_labels[i] in predicted_labels[i]:
            success_count += 1

success_rate = success_count / (len(test_ds) * BATCH_SIZE)
print("Top-5 Accuracy:", success_rate)

Top-5 Accuracy: 0.9599184782608695


In [10]:
test_loss, test_accuracy = loaded_model.evaluate(test_ds)



In [11]:
print(test_accuracy)

0.5894073843955994
