In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import datetime
import PIL
import tensorflow as tf
import requests

from tensorflow import keras
from tensorflow.keras import layers, models, applications

# Note : we are using TensorFlow Core v2.5.0, in TensorFlow Core v2.6.0 all the data 
# augmentation layers are part of tf.keras.layers
from tensorflow.keras.layers.experimental.preprocessing import RandomFlip, RandomRotation, RandomZoom
from tensorflow.keras import Input, Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard, ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adamax

In [2]:
TRAIN_DIR = "../raw_data/wikiart/wikiart-target_style-class_14-keepgenre_True-merge_style_m1-flat_False/train"
VAL_DIR = "../raw_data/wikiart/wikiart-target_style-class_14-keepgenre_True-merge_style_m1-flat_False/val"
TEST_DIR = "../raw_data/wikiart/wikiart-target_style-class_14-keepgenre_True-merge_style_m1-flat_False/test"

BATCH_SIZE = 128 # Hyper param, you can tune it
EPOCHS = 1000 # Large number, early stopping to stop training before this number
IMG_HEIGHT = 224 # VGG's dim
IMG_WIDTH = 224 # VGG's dim
NUM_CLASSES = 14 # Number of art styles

In [3]:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=TRAIN_DIR,
    labels='inferred',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode='categorical',
    shuffle=True)

assert len(train_ds.class_names) == NUM_CLASSES

Found 46971 files belonging to 14 classes.


In [4]:
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=VAL_DIR,
    labels='inferred',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode='categorical',
    shuffle=True)

assert len(val_ds.class_names) == NUM_CLASSES

Found 5871 files belonging to 14 classes.


In [5]:
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=TEST_DIR,
    labels='inferred',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode='categorical',
    shuffle=True)

assert len(test_ds.class_names) == NUM_CLASSES


Found 5872 files belonging to 14 classes.


In [6]:
total_images_count = (int(len(list(train_ds)))+int(len(list(val_ds)))+int(len(list(test_ds))))*BATCH_SIZE

In [7]:
total_images_count

58752

In [8]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Optimizing the dataset by caching and prefetching the data
train_ds = train_ds.cache().shuffle(int(total_images_count)).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [9]:
layer_model = applications.VGG16(
    include_top=False, # We do not include VGG classification layers
    weights='imagenet', # We import VGG pre-trained on ImageNet
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), 
    classes=NUM_CLASSES)

In [10]:
# layer_model.trainable = False # We do not train VGG weights

# layer_model.layers[-2:] # Set the two last layers as trainable
# for layer in layer_model.layers[-4:]:
#     layer.trainable = True

trainable_layer_count = 0
for i in range(len(layer_model.layers)):
    if layer_model.layers[i].trainable:
        trainable_layer_count += 1

layer_model.trainable = True

trainable_layer_count


19

In [11]:
data_augmentation_layers = tf.keras.models.Sequential([
    RandomFlip("horizontal", input_shape=(224, 224, 3)),
    RandomRotation(0.3),
    RandomZoom(0.3)
])

In [12]:
tf.keras.backend.clear_session()

inputs = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))

x = data_augmentation_layers(inputs)

x = applications.vgg16.preprocess_input(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)

outputs = layers.Dense(NUM_CLASSES, activation='softmax', name="classification_layer")(x)

model = Model(inputs, outputs)


In [13]:
es = EarlyStopping(monitor='val_loss', patience=10, mode='min', restore_best_weights=True)

# You can add it to the callbacks if you want to save checkpoints
checkpoint_dir = "../VGG16/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + f"-unfreeze_{trainable_layer_count}"
mcp = ModelCheckpoint(
    filepath=checkpoint_dir,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_freq=10,
    save_best_only=True)


In [14]:
%load_ext tensorboard
recorded_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 
log_dir = f"../VGG16/" + \
    recorded_time + \
    f"-images_{total_images_count}" + \
    f"-unfreeze_{trainable_layer_count}" + \
    f"-batch_{BATCH_SIZE}"

tsboard = TensorBoard(log_dir=log_dir)

In [15]:
model.compile(optimizer=tf.keras.optimizers.Adamax(learning_rate=0.001), 
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), 
              metrics=['accuracy'])

history = model.fit(
    train_ds, 
    epochs=EPOCHS, 
    validation_data=val_ds, 
    callbacks=[es, tsboard], 
    use_multiprocessing=True)



Epoch 1/1000


  output, from_logits = _get_logits(
2023-10-11 10:39:38.280500: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 12 of 58752
2023-10-11 10:39:48.548006: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 32 of 58752
2023-10-11 10:39:58.434776: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 52 of 58752
2023-10-11 10:40:08.177018: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 71 of 58752
2023-10-11 10:40:18.343187: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 92 of 58752
2023-10-11 10:40:28.787610: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 112 of 58752
2023-10-11 10:40:38.154059: I tensorflow/core/kernels/data/shuffle_d

Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000


In [19]:
model.evaluate(test_ds, callbacks=tsboard)



[2.107649564743042, 0.283549040555954]

In [20]:
model.save('model.h5')

  saving_api.save_model(


: 