## Imports and setup

In [None]:
from keras import layers
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow as tf
from config import *
from util import *
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import random

# Seed the TensorFlow, NumPy and Python `random` global RNGs with one value
# so that repeated runs start from the same random state.
SEED = 42
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

## CIFAR-10 dataset loading and preparation

In [None]:
# Load CIFAR-10, then hold out every training sample from index 40000 onward
# as a validation split. The validation slices are taken BEFORE x_train/y_train
# are rebound, so both splits are cut from the original arrays.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_val, y_val = x_train[40000:], y_train[40000:]
x_train, y_train = x_train[:40000], y_train[:40000]

print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

In [None]:
# Build the input pipelines with the star-imported `prepare_data` helper.
# Only the training split uses the default `is_train`; the others pass
# is_train=False. Presumably this disables shuffling/augmentation — confirm
# against prepare_data's definition.
train_ds = prepare_data(x_train, y_train)
# Fix: this previously wrapped the *training* arrays, so validation metrics
# were computed on data the model trains on and the held-out split was unused.
val_ds = prepare_data(x_val, y_val, is_train=False)
test_ds = prepare_data(x_test, y_test, is_train=False)

## Learning-rate schedule

In [None]:
from missing_modality.WarmUpCosine import WarmUpCosine

# Schedule length in optimizer steps: (samples / batch size) per epoch, times
# the number of epochs, truncated to an int.
# NOTE(review): if prepare_data keeps a final partial batch, the real step
# count per epoch is ceil(len(x_train) / BATCH_SIZE). The schedule would then
# end a few steps early. Confirm against prepare_data.
total_steps = int(EPOCHS * (len(x_train) / BATCH_SIZE))
# The first 15% of the steps are the warm-up phase.
warmup_steps = int(0.15 * total_steps)

# Warm up from 0.0 towards LEARNING_RATE. The post-warm-up shape is
# presumably cosine decay (per the class name); the implementation is not
# visible here.
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
)

## Train and evaluate model

In [None]:
# AdamW (decoupled weight decay) driven by the `scheduled_lrs` schedule.
optimizer = tfa.optimizers.AdamW(
    weight_decay=WEIGHT_DECAY,
    learning_rate=scheduled_lrs,
)

vit_model = create_vit_classifier()
# NOTE(review): the string loss "sparse_categorical_crossentropy" expects
# probabilities. If create_vit_classifier emits raw logits, use
# keras.losses.SparseCategoricalCrossentropy(from_logits=True). Confirm.
vit_model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
    optimizer=optimizer,
)
vit_model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

# With one compiled metric, evaluate() returns [loss, accuracy].
loss, accuracy = vit_model.evaluate(test_ds)
# Convert to a percentage rounded to two decimals (reused in the save path).
accuracy = round(100 * accuracy, 2)
print(f"Accuracy on the test set: {accuracy}%.")

In [None]:
# Save the trained model without optimizer state. The output path embeds the
# rounded test accuracy computed in the evaluation cell.
saved_model_path = f"classification_vit_model@acc_{accuracy}"
vit_model.save(saved_model_path, include_optimizer=False)