In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from vit import utils, vit, our_vit

from tensorflow.keras import optimizers
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import numpy as np
import matplotlib.pyplot as plt
import gc

In [None]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# Iterating an empty list is a no-op, so CPU-only machines are unaffected.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

In [None]:
# Single seed shared by NumPy, TensorFlow and the train/validation split below.
seed = 2022
for set_seed in (np.random.seed, tf.random.set_seed):
    set_seed(seed)

In [None]:
# Load CIFAR-10 as (images, integer labels) pairs for the train and test sets.
(train_data, train_label), (test_data, test_label) = cifar10.load_data()

# One-hot encode the labels. The class count is pinned to 10 (matching the
# model's num_classes) so the encoded width never depends on which labels
# happen to be present in an array.
train_label = to_categorical(train_label, num_classes=10)
test_label = to_categorical(test_label, num_classes=10)

# Scale pixels to [0, 1]. Casting to float32 *before* dividing avoids
# materialising a temporary float64 copy of the whole image array.
train_data = train_data.astype("float32") / 255.
test_data = test_data.astype("float32") / 255.

In [None]:
# Shuffle the training set and keep 80% for training; the remainder becomes
# the validation set. The split is reproducible through `seed`.
X_train, X_valid, y_train, y_valid = train_test_split(
    train_data,
    train_label,
    train_size=0.8,
    shuffle=True,
    random_state=seed,
)

In [None]:
# Mini-batch size used by the augmenting training generator.
batch_size = 16

# Random augmentation applied on the fly to the training images only:
# rotations, horizontal/vertical shifts, zoom and horizontal flips.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.25,
    height_shift_range=0.25,
    zoom_range=0.25,
    horizontal_flip=True,
)

# Iterator yielding augmented (images, labels) batches from the training split.
train_generator = datagen.flow(x=X_train, y=y_train, batch_size=batch_size)

In [None]:
# Checkpoint callback: writes the model to disk whenever `val_loss` improves
# (save_best_only=True), logging each save (verbose=1).
# NOTE(review): the target is a hard-coded, machine-specific absolute path;
# adjust it before running on another machine.
# NOTE: the callback has no effect unless it is passed to `model.fit` through
# the `callbacks=` argument.
checkpoint_path = r'C:\Users\fano\Desktop\weights\vit4_6.h5'
checkpoint_vit = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Hyper-parameters handed to the project's ViT implementation (our_vit.ViT).
# NOTE(review): image_size is 224 while CIFAR-10 images are 32x32 — presumably
# our_vit.ViT resizes its inputs internally; confirm against its source.
vit_kwargs = dict(
    image_size=224,
    patch_size=16,
    num_classes=10,
    hidden_size=768,
    num_layers=12,
    num_heads=3,
    mlp_dim=3072,
    dropout=0.1,
    emb_dropout=0,
)
model = our_vit.ViT(**vit_kwargs)

# AdamW (Adam with decoupled weight decay) from TensorFlow Addons.
optimizer = tfa.optimizers.AdamW(weight_decay=0.00001, learning_rate=0.0001)

# Labels are one-hot encoded, hence categorical (not sparse) cross-entropy.
model.compile(
    optimizer=optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

# Force a garbage-collection pass before training starts.
gc.collect()

In [None]:
# Train on the augmented batches and validate on the held-out split after
# every epoch. `checkpoint_vit` must be registered here: without it the
# ModelCheckpoint callback never runs and no weights are written to disk.
history = model.fit(
    train_generator,
    epochs=100,
    validation_data=(X_valid, y_valid),
    callbacks=[checkpoint_vit],
)

# Test accuracy of the in-memory model as it stands after the final epoch
# (not the best checkpoint on disk): compare arg-max class indices of the
# one-hot ground truth and the predicted class scores.
test_true = np.argmax(test_label, axis=1)
test_pred = np.argmax(model.predict(test_data), axis=1)
print("\nTest Accuracy: ", accuracy_score(test_true, test_pred))

In [None]:
# TODO: save the trained model here (e.g. with model.save(...)).

In [None]:
# Plot the per-epoch training curves, one figure per metric.
# The `val_*` series comes from the validation split passed to `model.fit`,
# not from the test set, so the legend is labelled "validation".
for metric in ('accuracy', 'loss'):
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title('model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()

In [None]:
# Final evaluation on the held-out test set. With the model compiled using
# metrics=["accuracy"], `results` holds the test loss followed by the accuracy.
results = model.evaluate(
    x=test_data,
    y=test_label,
    batch_size=32,
    verbose=1,
)
print("test loss, test acc:", results)