In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub

# Fetch the ViT preset files through kagglehub using their Kaggle Models handle.
# NOTE(review): the returned path is not referenced anywhere else in this notebook;
# the model is later loaded by preset name via `from_preset` instead.
vit_model_handle = 'keras/vit/Keras/vit_base_patch16_224_imagenet/2'
keras_vit_keras_vit_base_patch16_224_imagenet_2_path = kagglehub.model_download(vit_model_handle)

print('Data source import complete.')


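# Vision Transformer (ViT) Quickstart with KerasHub

This notebook fine-tunes the pretrained `vit_base_patch16_224_imagenet` preset on the TFDS `cats_vs_dogs` dataset (two classes) and visualizes its predictions on a held-out test batch.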

## Install dependencies

In [None]:
!pip install -U -q keras

In [None]:
! pip install -U -q git+https://github.com/keras-team/keras-hub.git

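## Set a backend

Keras 3 reads `KERAS_BACKEND` when it is first imported, so this cell must run before `import keras`.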

In [None]:
import os

# Keras picks its backend from this environment variable at import time,
# so it has to be set before the `import keras` in the next cell.
KERAS_BACKEND = "jax"  # @param ["tensorflow", "jax", "torch"]
os.environ["KERAS_BACKEND"] = KERAS_BACKEND

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
import keras_hub

from keras_hub.layers import ViTImageConverter
from keras_hub.models import ViTImageClassifierPreprocessor

In [None]:
# Model input size as (height, width, channels). The 224 presumably matches the "224"
# in the ViT preset name used below; confirm this if switching presets.
img_shape = (224, 224, 3)
batch_size = 16

# Fine-tuning is stochastic (randomly initialized classification head, dropout, data
# shuffling). Seeding Python, NumPy and the backend RNGs via Keras makes a
# Restart & Run All reproducible.
SEED = 42
keras.utils.set_random_seed(SEED)

## Load the Cats vs. Dogs dataset

The dataset's `train` split is sliced into 70% training, 20% validation, and 10% test data.

In [None]:
# Slice the dataset's "train" split into 70% train / 20% validation / 10% test
# using TFDS percentage-slicing syntax.
split_spec = ["train[:70%]", "train[70%:90%]", "train[90%:]"]
train_data, valid_data, test_data = tfds.load(
    "cats_vs_dogs",
    split=split_spec,
    with_info=False,
    download=True,
)

# Preprocessing layer that resizes every image to (height, width) taken from img_shape.
resize = keras.layers.Resizing(height=img_shape[0], width=img_shape[1])

In [None]:
# Training pipeline:
#   1. shuffle examples (reshuffled every epoch by default);
#   2. turn each example dict into a resized (image, label) pair in a single map;
#   3. batch;
#   4. prefetch so input preparation overlaps with training steps.
# NOTE: this rebinds `train_data`. Re-running the cell without first re-running the load
# cell would map over already-batched (image, label) tuples and fail.
SHUFFLE_BUFFER = 1000
train_data = (
    train_data
    .shuffle(SHUFFLE_BUFFER)
    .map(
        lambda example: (resize(example["image"]), example["label"]),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Validation pipeline: a single map turns each example dict into a resized
# (image, label) pair, then batch and prefetch. There is no shuffling, so evaluation
# order is stable.
# NOTE: this rebinds `valid_data`. Re-run the load cell before re-running this one.
valid_data = (
    valid_data
    .map(
        lambda example: (resize(example["image"]), example["label"]),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Test pipeline: a single map turns each example dict into a resized
# (image, label) pair, then batch and prefetch. There is no shuffling, so the
# first batch is stable.
# NOTE: this rebinds `test_data`. Re-run the load cell before re-running this one.
test_data = (
    test_data
    .map(
        lambda example: (resize(example["image"]), example["label"]),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

## Instantiate a model

In [None]:
# Pretrained ViT backbone and its matching preprocessor both come from the same preset.
# A new classification head with NUM_CLASSES outputs is then attached on top.
VIT_PRESET = "vit_base_patch16_224_imagenet"
NUM_CLASSES = 2  # cat vs. dog

backbone = keras_hub.models.Backbone.from_preset(VIT_PRESET)
preprocessor = keras_hub.models.ViTImageClassifierPreprocessor.from_preset(VIT_PRESET)

vit = keras_hub.models.ViTImageClassifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=NUM_CLASSES,
)

In [None]:
# Print a summary of the assembled classifier's layers and parameter counts.
vit.summary()

## Fine-tune the model

In [None]:
# Fine-tune on the training split and report validation metrics after each epoch.
# `vit` is never explicitly compiled in this notebook, so this relies on the task's
# default compile settings. NOTE(review): confirm the default optimizer and loss.
EPOCHS = 1  # number of full passes over train_data
history = vit.fit(train_data, validation_data=valid_data, epochs=EPOCHS)

In [None]:
# Pull one batch from the held-out split and score it with the fine-tuned model.
# `y_pred` has one row of per-class scores for each image in the batch.
test_batch = next(iter(test_data))
images, y_true = test_batch
y_pred = vit.predict(images)

In [None]:
# Integer class id -> human-readable name.
# Assumes TFDS cats_vs_dogs encodes 0 = cat and 1 = dog; confirm against the dataset info.
label_map = dict(enumerate(("cat", "dog")))

## Visualize predictions on test data

In [None]:
# Show up to 9 test images in a 3x3 grid, each titled with its true and predicted label.
# A batch can hold fewer than 9 images (e.g. the last batch of a split); the original
# loop would raise IndexError then, so unused panels are simply left blank.
n_shown = min(9, len(images))
pred_labels = y_pred.argmax(axis=-1)  # predicted class id per image, computed once

fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= n_shown:
        continue
    # Resized images are float32; cast back to uint8 so imshow treats them as 0-255 pixels.
    ax.imshow(images[i].numpy().astype("uint8"))
    ax.set_title(
        f"True: {label_map[int(y_true[i])]}, Pred: {label_map[int(pred_labels[i])]}"
    )
plt.show()