In [24]:
from vit_keras import vit
import tensorflow as tf
from typing import *
from tqdm import tqdm

In [25]:
CFG_SEED=71
NUM_EPOCHS = 30

In [26]:
vit_model = vit.vit_b16(
        image_size=224,
        activation='softmax',
        pretrained=True,
        include_top=False,
        pretrained_top=False,
        classes=44)

for layer in vit_model.layers:
    layer.trainable = False


In [27]:
initializer = tf.keras.initializers.GlorotNormal(seed=CFG_SEED)
early_stoppage = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3)

In [28]:
def vit_b16_model():    
    vit_b16_sequential = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(224,224,3), dtype=tf.float32, name='input_image'),
        vit_model,
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(512, activation='relu', kernel_initializer=initializer),
        tf.keras.layers.Dense(256, activation='relu', kernel_initializer=initializer),
        tf.keras.layers.Dense(44, dtype=tf.float32, activation='softmax', kernel_initializer=initializer)
    ], name='vit_b16_sequential_model')
    
    return vit_b16_sequential

In [29]:
# Generate Model
model_vit_b16 = vit_b16_model()

# Generate Summary of the Model
model_vit_b16.summary()

Model: "vit_b16_sequential_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vit-b16 (Functional)        (None, 768)               85798656  
                                                                 
 dropout_2 (Dropout)         (None, 768)               0         
                                                                 
 dense_6 (Dense)             (None, 512)               393728    
                                                                 
 dense_7 (Dense)             (None, 256)               131328    
                                                                 
 dense_8 (Dense)             (None, 44)                11308     
                                                                 
Total params: 86,335,020
Trainable params: 536,364
Non-trainable params: 85,798,656
_________________________________________________________________


In [30]:
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory='training_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(224, 224))


validation_ds = tf.keras.utils.image_dataset_from_directory(
    directory='validation_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(224, 224))

Found 3581 files belonging to 44 classes.
Found 895 files belonging to 44 classes.


In [31]:
val_batches = tf.data.experimental.cardinality(validation_ds)
print('Number of val batches: %d' % val_batches)
test_dataset = validation_ds.take(val_batches // 5)
validation_data = validation_ds.skip(val_batches // 5)

Number of val batches: 28


In [32]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
val_ds = validation_data.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
test_ds = test_dataset.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.


In [33]:
# val_batches = tf.data.experimental.cardinality(validation_ds)
# print('Number of val batches: %d' % val_batches)
# test_dataset = validation_ds.take(val_batches // 5)
# validation_data = validation_ds.skip(val_batches // 5)

In [34]:
tf.random.set_seed(CFG_SEED)

# Compile the model
model_vit_b16.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    # loss = tf.keras.losses.BinaryCrossentropy(),

    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy']
)


history = model_vit_b16.fit(train_ds,validation_data=val_ds, epochs=NUM_EPOCHS, callbacks=[early_stoppage])
print(model_vit_b16.evaluate(test_ds))

Epoch 1/30
 10/112 [=>............................] - ETA: 14:56 - loss: 3.5405 - accuracy: 0.1156

In [None]:
# ## 2 EPOCHS TRAINING: 
#Epoch 1/2
# 112/112 [==============================] - 2820s 25s/step - loss: 2.5552 - accuracy: 0.3250 - val_loss: 1.9654 - val_accuracy: 0.4395
# Epoch 2/2
# 112/112 [==============================] - 2754s 25s/step - loss: 1.6741 - accuracy: 0.5035 - val_loss: 1.5377 - val_accuracy: 0.5401
# 1/5 [=====>........................] - ETA: 1:13 - loss: 1.6646 - accuracy: 0.4375


#5/5 [==============================] - 98s 20s/step - loss: 1.7411 - accuracy: 0.4500
# [1.7410808801651, 0.44999998807907104]

In [None]:
from matplotlib import pyplot as plt

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('accuracy.png')
plt.show()


plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('loss.png')
plt.show()