# Transfer Learning for Image Classification

This notebook takes the [ViT](https://huggingface.co/google/vit-base-patch16-224-in21k) image classification model from the 🤗 model hub, which was pretrained on [ImageNet](https://image-net.org) (the ImageNet-21k variant), and fine-tunes it on the [Food101](https://huggingface.co/datasets/food101) dataset from 🤗 Datasets.
The notebook performs the following steps:
1. [Import dependencies and setup parameters](#1.-Import-dependencies-and-setup-parameters)
2. [Load the Food101 dataset](#2.-Load-the-Food101-dataset)
3. [Preprocess the dataset](#3.-Preprocess-the-dataset)
4. [Transfer Learning](#4.-Transfer-Learning)
5. [Predict on test subset](#5.-Predict-on-test-subset)

## 1. Import dependencies and setup parameters

This notebook assumes that you have already followed the instructions in the [README.md](/notebooks/README.md) to set up a 🤗 transformers environment with all the dependencies required to run the notebook.

In [None]:
# General
# -------
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
from PIL import Image

from tlt.utils.platform_util import PlatformUtil, OptimizedPlatformUtil


# Huggingface
# -----------
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    DefaultDataCollator,
    AutoModelForImageClassification,
    TFAutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    create_optimizer,
    pipeline
)
from transformers.keras_callbacks import KerasMetricCallback
import evaluate


# PyTorch
# -------
import torch
import torchvision.transforms as T


# TensorFlow
# ----------
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Specify the parent directory where the Hugging Face dataset will be cached
dataset_directory = os.environ.get("DATASET_DIR", os.path.join(os.path.expanduser("~"), "dataset"))

# Specify a directory for output
output_directory = os.environ.get("OUTPUT_DIR", os.path.join(os.path.expanduser("~"), "output"))

print("Dataset directory:", dataset_directory)
print("Output directory:", output_directory)

### (Optional) Optimized CPU platform

This step uses TLT's `OptimizedPlatformUtil` to set certain environment variables for TensorFlow and PyTorch to recommended values for optimized model training.

In [None]:
pl = PlatformUtil()
OptimizedPlatformUtil(
    omp_num_threads=pl.cores,
    kmp_blocktime=0,
    kmp_affinity="granularity=fine,compact,1,0",
    tf_num_intraop_threads=pl.cores_per_socket,
    tf_num_interop_threads=pl.sockets,
    force_reset_env_vars=True
)


## 2. Load the Food101 dataset

**Note:** This notebook loads only the first 5000 training samples of the Food101 dataset. Modify the `split` argument of `load_dataset()` to load a different subset.

In [None]:
# Load a subset of Food101 dataset
dataset = load_dataset('food101', split='train[:5000]', cache_dir=dataset_directory)

# Split the dataset for training and evaluation
dataset = dataset.train_test_split(test_size=0.2)

# Define variables to hold labels and their mappings
labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# Assign different variables for different frameworks
dataset_for_pyt = dataset
dataset_for_tf = dataset

print(dataset)
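
The two mappings built in the cell above are inverses of each other, with class ids stored as strings. A minimal sketch of the round trip, using a hypothetical three-class label list (`example_labels` and the other `example_*` names are illustrative and are not read from the dataset):

```python
# Hypothetical label list standing in for dataset["train"].features["label"].names
example_labels = ["apple_pie", "baby_back_ribs", "baklava"]

# Same construction as in the notebook, written as dict comprehensions
example_label2id = {name: str(i) for i, name in enumerate(example_labels)}
example_id2label = {str(i): name for i, name in enumerate(example_labels)}

# The dataset stores each label as an integer, so convert it before the lookup
class_id = 2
print(example_id2label[str(class_id)])
print(example_label2id["baklava"])
```

Because the keys of `example_id2label` are strings, looking up an integer id directly would raise a `KeyError`; the `str()` conversion is what makes the lookup work.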

### Inspect the dataset

Select a random image from the dataset and see how it is represented in the dataset object.

In [None]:
import random

# Pick an index over the training samples (not over the number of labels)
select_num = random.randint(0, len(dataset['train']) - 1)

img_dict = dataset['train'][select_num]

print(img_dict)
print(id2label[str(img_dict['label'])])
img_dict['image']

## 3. Preprocess the dataset

Run the cell below to get the image preprocessor that matches the chosen model. This step is common to both the PyTorch and TensorFlow frameworks.

In [None]:
model_name = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(model_name)

image_processor

### Option A: PyTorch

If using PyTorch as a backend, use torchvision transforms to apply preprocessing to the datasets.

In [None]:
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)

# Define your transforms
_transforms = T.Compose([
    T.RandomResizedCrop(size), 
    T.ToTensor(),
    T.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
])

# Apply the transforms to the dataset
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

dataset_for_pyt = dataset_for_pyt.with_transform(transforms)

Skip to the next step, [4. Transfer Learning](#4.-Transfer-Learning), to continue using PyTorch.

### Option B: TensorFlow

If using TensorFlow as a backend, use Keras layers to apply preprocessing.

In [None]:
size = (image_processor.size["height"], image_processor.size["width"])

# Define your keras layers for preprocessing
train_data_augmentation = keras.Sequential(
    [
        layers.RandomCrop(size[0], size[1]),
        layers.Rescaling(scale=1.0 / 127.5, offset=-1),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="train_data_augmentation",
)

val_data_augmentation = keras.Sequential(
    [
        layers.CenterCrop(size[0], size[1]),
        layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ],
    name="val_data_augmentation",
)

# Define helper functions to apply preprocessing layers
def convert_to_tf_tensor(image: Image):
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    # `expand_dims()` is used to add a batch dimension since
    # the TF augmentation layers operate on batched inputs.
    return tf.expand_dims(tf_image, 0)


def preprocess_train(example_batch):
    """Apply train_data_augmentation across a batch."""
    images = [
        train_data_augmentation(convert_to_tf_tensor(image.convert("RGB"))) for image in example_batch["image"]
    ]
    example_batch["pixel_values"] = [tf.transpose(tf.squeeze(image)) for image in images]
    return example_batch


def preprocess_val(example_batch):
    """Apply val_data_augmentation across a batch."""
    images = [
        val_data_augmentation(convert_to_tf_tensor(image.convert("RGB"))) for image in example_batch["image"]
    ]
    example_batch["pixel_values"] = [tf.transpose(tf.squeeze(image)) for image in images]
    return example_batch

# Set the helper methods to the dataset(s)
dataset_for_tf["train"].set_transform(preprocess_train)
dataset_for_tf["test"].set_transform(preprocess_val)
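
Both options end up scaling pixel values in the same way, provided the image processor reports a mean and standard deviation of 0.5 per channel. That value is an assumption here; check `image_processor.image_mean` and `image_processor.image_std` for your model. Under that assumption, the sketch below compares the arithmetic of `T.ToTensor()` followed by `T.Normalize` with that of `layers.Rescaling(scale=1.0 / 127.5, offset=-1)` on a single pixel value. The helper names `torch_style` and `keras_style` are illustrative only.

```python
def torch_style(pixel, mean=0.5, std=0.5):
    """ToTensor scales to [0, 1]; Normalize then subtracts the mean and divides by std."""
    return (pixel / 255.0 - mean) / std


def keras_style(pixel, scale=1.0 / 127.5, offset=-1.0):
    """Rescaling computes pixel * scale + offset."""
    return pixel * scale + offset


# Compare the two formulas on a few 8-bit pixel values
for pixel in (0, 64, 255):
    print(pixel, torch_style(pixel), keras_style(pixel))
```

Both formulas reduce to `pixel / 127.5 - 1`, which maps the 8-bit range onto the interval from -1 to 1. If the processor reports a different mean or standard deviation, the `Rescaling` arguments would need to change to match.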

## 4. Transfer Learning

### Option A: PyTorch

If using PyTorch as a backend, load the model with the 🤗 `AutoModelForImageClassification` class and train it with the 🤗 `Trainer`.

In [None]:
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)


# Define a function to calculate accuracy
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

# Define training args for the Trainer
training_args = TrainingArguments(
    output_dir=output_directory,
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)

# Define the Trainer class
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DefaultDataCollator(),
    train_dataset=dataset_for_pyt["train"],
    eval_dataset=dataset_for_pyt["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

# Train the model
model.train()  # Puts the model in training mode
trainer.train()
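
With `per_device_train_batch_size=16` and `gradient_accumulation_steps=4`, gradients from four batches are accumulated before each optimizer update. The sketch below works out the resulting effective batch size and a rough number of updates per epoch, assuming a single device and the 4000 training samples left after the 80/20 split. The helper names are illustrative, and the exact step count reported by `Trainer` may differ slightly because of rounding.

```python
import math


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def approx_updates_per_epoch(num_samples, batch_size):
    """Rough count of optimizer updates needed to see every sample once."""
    return math.ceil(num_samples / batch_size)


train_samples = 4000  # 80% of the 5000-sample subset
batch = effective_batch_size(16, 4)
updates = approx_updates_per_epoch(train_samples, batch)
print("Effective batch size:", batch)
print("Approximate updates per epoch:", updates)
```

Accumulating gradients gives the optimizer a larger effective batch without increasing the memory needed for a single forward and backward pass.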

### Option B: TensorFlow

If using TensorFlow as a backend, load the model with the 🤗 `TFAutoModelForImageClassification` class and train it with the Keras `fit()` method.

In [None]:
model = TFAutoModelForImageClassification.from_pretrained(
    model_name,
    id2label=id2label,
    label2id=label2id,
)


# Convert 🤗 Dataset to tf.data.Dataset
batch_size = 16
tf_train_dataset = dataset_for_tf["train"].to_tf_dataset(
    columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size,
    collate_fn=DefaultDataCollator(return_tensors="tf")
)

# converting our test dataset to tf.data.Dataset
tf_eval_dataset = dataset_for_tf["test"].to_tf_dataset(
    columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size,
    collate_fn=DefaultDataCollator(return_tensors="tf")
)


# Define a function to calculate accuracy
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


# Create optimizer
num_epochs = 2
# The optimizer takes one step per batch, so count batches rather than samples
num_train_steps = (len(dataset["train"]) // batch_size) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01

optimizer, lr_schedule = create_optimizer(
    init_lr=learning_rate,
    num_train_steps=num_train_steps,
    weight_decay_rate=weight_decay_rate,
    num_warmup_steps=0,
)

# Define loss and callbacks
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
callbacks = [KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)]


# Compile the model
model.compile(optimizer=optimizer, loss=loss)

# Train the model
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)
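
The value passed as `num_train_steps` controls how quickly the learning rate decays. As a sketch, assume the schedule decays linearly from `init_lr` to zero over `num_train_steps` when `num_warmup_steps=0`. That is an assumption about the defaults of `create_optimizer`, not something this notebook verifies, and `linear_decay_lr` is a hypothetical helper written only for illustration.

```python
def linear_decay_lr(step, init_lr, decay_steps, end_lr=0.0):
    """Learning rate after `step` optimizer updates under a linear decay schedule."""
    # Clamp the step so the rate stays at end_lr once the schedule is finished
    step = min(max(step, 0), decay_steps)
    fraction_done = step / decay_steps
    return end_lr + (init_lr - end_lr) * (1.0 - fraction_done)


init_lr = 3e-5
decay_steps = 500  # e.g. 250 batches per epoch for 2 epochs

for step in (0, 250, 500):
    print(step, linear_decay_lr(step, init_lr, decay_steps))
```

If `decay_steps` is much larger than the number of batches actually processed, training ends while the learning rate is still close to its initial value.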

## 5. Predict on test subset

In [None]:
from tqdm import tqdm
import random
import matplotlib.pyplot as plt

def get_predictions(image, model, framework):
    model_input = image_processor(
        image,
        return_tensors = 'tf' if framework == 'tensorflow' else 'pt'
    )
    if framework == 'tensorflow':
        logits = model(**model_input).logits
        predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    if framework == 'pytorch':
        with torch.no_grad():
            logits = model(**model_input).logits
        predicted_class_id = logits.argmax(-1).item()
    return predicted_class_id

selected_image_indices = random.sample(range(0, dataset['test'].num_rows), 30)
selected_images = []
true_labels = []
predicted_labels = []

for s in tqdm(selected_image_indices):
    img = dataset['test'][s]['image']
    label = dataset['test'][s]['label']
    selected_images.append(img)
    true_labels.append(label)
    predicted_labels.append(get_predictions(img, model, 'pytorch' if isinstance(model, torch.nn.Module) else 'tensorflow'))
    
# Visualize the predictions
plt.figure(figsize=(16,16))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6,5,n+1)
    plt.imshow(selected_images[n])
    correct_prediction = true_labels[n] == predicted_labels[n]
    color = "darkgreen" if correct_prediction else "crimson"
    # Use the local id2label mapping, whose keys are strings
    true_label_name = id2label[str(true_labels[n])]
    predicted_label_name = id2label[str(predicted_labels[n])]
    title = predicted_label_name if correct_prediction else "{}\n({})".format(predicted_label_name, true_label_name) 
    plt.title(title, fontsize=14, color=color)
    plt.axis('off')
_ = plt.suptitle("Food101 predictions", fontsize=16)
plt.show()
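
The plot colors each title by whether the prediction matches the true label. A single number can summarize the same comparison. The sketch below computes the fraction of correct predictions from two lists of class ids; `sample_accuracy` is a hypothetical helper, and the ids in the example are made up for illustration.

```python
def sample_accuracy(true_ids, predicted_ids):
    """Fraction of positions where the predicted class id equals the true class id."""
    if len(true_ids) != len(predicted_ids):
        raise ValueError("true_ids and predicted_ids must have the same length")
    if not true_ids:
        return 0.0
    correct = sum(t == p for t, p in zip(true_ids, predicted_ids))
    return correct / len(true_ids)


# Made-up ids; in the notebook these would be true_labels and predicted_labels
print(sample_accuracy([3, 1, 4, 1], [3, 1, 4, 2]))
```

With only 30 sampled images, this figure is a rough spot check rather than a replacement for the evaluation metrics computed during training.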