The Vision Transformer (ViT) is a transformer encoder model (BERT-like) pretrained in a supervised fashion on a large collection of images, namely ImageNet-21k (about 14 million images and 21,843 classes), at a resolution of 224x224 pixels. The model was then fine-tuned on ImageNet (also referred to as ILSVRC2012), a dataset of roughly 1 million images and 1,000 classes, also at a resolution of 224x224.
Images are presented to the model as a sequence of fixed-size patches (resolution 16x16), which are linearly embedded.
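Reference: A. Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", arXiv:2010.11929.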

In [None]:
import os
import sys

# Keras 3 reads KERAS_BACKEND when the package is first imported; the backend
# cannot be changed afterwards. The variable must therefore be set before
# `import keras` and before any local module that may import keras itself.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_hub
import tensorflow as tf

# Make the parent directory importable so the local ImageNet helper is found.
sys.path.append(os.path.abspath("../"))
from imagenet2012_utils import ImageNetDataset

In [None]:
# Enumerate the physical GPUs TensorFlow can see and configure them before
# any of them is initialized.
gpus = tf.config.list_physical_devices(device_type='GPU')
print(gpus)
if gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front;
        # this avoids DNN library initialization errors.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Expose only the first physical GPU to TensorFlow.
        tf.config.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        # Device settings cannot be changed once the GPUs are initialized
        # (e.g. when this cell is re-run); report and continue.
        print(e)

In [None]:
# The loader returns two (images, labels) pairs; only the second pair
# (validation_*) is used by the cells below.
(
    (training_images, training_labels),
    (validation_images, validation_labels),
) = ImageNetDataset.load_validation_dataset(mode='ViT-b_16p')

In [None]:
# Pretrained ViT-B/16 (224x224) ImageNet classifier from keras_hub; the
# softmax activation makes the model output probabilities, not logits.
model = keras_hub.models.ImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet",
    activation="softmax",
)

# Integer class labels with probability outputs -> sparse categorical
# cross-entropy (the string alias expects probabilities, matching softmax).
model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary()

In [None]:
# Loss/accuracy over the entire validation split.
model.evaluate(x=validation_images, y=validation_labels)
# Loss/accuracy over only the first 2000 validation samples.
model.evaluate(x=validation_images[:2000], y=validation_labels[:2000])

In [None]:
# Directory holding every exported artifact for this model.
MODEL_PATH = "../models_data/ViT-b_16p_224/"
os.makedirs(MODEL_PATH, exist_ok=True)
# os.path.join avoids the doubled separator that MODEL_PATH + '/...' produced,
# since MODEL_PATH already ends with '/'.
model.save(os.path.join(MODEL_PATH, "ViT-b_16p_224_fp32_imagenet2012.keras"))

In [None]:
def representative_dataset(num_samples=1000):
  """Yield calibration inputs for full-integer post-training quantization.

  Each item is a one-element list holding a batch of one image taken from
  the first ``num_samples`` entries of ``validation_images``. The TFLite
  converter calls this with no arguments, so the default applies there.
  """
  # Slice before building the dataset so only the calibration samples are
  # copied into a tensor, not the whole validation set.
  calibration_images = validation_images[:num_samples]
  for images in tf.data.Dataset.from_tensor_slices(calibration_images).batch(1):
    yield [images]


def _convert_and_save(converter, io_dtype, output_path):
  """Convert with int8-only builtin ops and `io_dtype` I/O, write the flatbuffer to `output_path`, and return it."""
  # With only TFLITE_BUILTINS_INT8 allowed, conversion raises if some op
  # cannot be quantized (no float fallback).
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  converter.inference_input_type = io_dtype
  converter.inference_output_type = io_dtype
  tflite_model = converter.convert()
  with open(output_path, 'wb') as f:
    f.write(tflite_model)
  return tflite_model


# Kept local to this cell so it can run after a kernel restart without
# re-running the save cell (it reloads the model from disk).
MODEL_DIR = "../models_data/ViT-b_16p_224"

loaded_model = keras.saving.load_model(os.path.join(MODEL_DIR, "ViT-b_16p_224_fp32_imagenet2012.keras"))
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

# Same converter, two I/O variants; tflite_quant_model ends up holding the
# int8 model, as before.
tflite_quant_model = _convert_and_save(
  converter, tf.uint8, os.path.join(MODEL_DIR, "ViT-b_16p_224_uint8_imagenet2012.tflite"))
tflite_quant_model = _convert_and_save(
  converter, tf.int8, os.path.join(MODEL_DIR, "ViT-b_16p_224_int8_imagenet2012.tflite"))