# Transfer Learning for Image Classification using the TLK API

This notebook uses the `tlk` library to perform transfer learning for image classification with a pretrained TensorFlow model.

## 1. Import dependencies and set up parameters

This notebook assumes that you have already followed the instructions in the [notebooks README.md](/notebooks/README.md) to set up a TensorFlow environment with all the dependencies required to run the notebook.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import PIL.Image as Image
import tensorflow as tf

from tlk.datasets import dataset_factory
from tlk.models import model_factory

# Specify a directory for the dataset to be downloaded (override with DATASET_DIR)
dataset_dir = os.environ.get("DATASET_DIR", os.path.join(os.path.expanduser("~"), "dataset"))

# Specify a directory for output (override with OUTPUT_DIR)
output_dir = os.environ.get("OUTPUT_DIR", os.path.join(os.path.expanduser("~"), "output"))

print("Dataset directory:", dataset_dir)
print("Output directory:", output_dir)

## 2. Get the model

In this step, we call the TLK model factory to list the supported TensorFlow image classification models. This is a list of pretrained models from [TFHub](https://tfhub.dev) that have been tested with the TLK API. Optionally, pass `verbose=True` to the `print_supported_models` function to get more information about each model (such as the link to TFHub, the image size, and the original dataset).

In [None]:
# See a list of available models
model_factory.print_supported_models(use_case='image_classification', framework='tensorflow')

Next, use the TLK model factory to get one of the models listed in the previous cell. The `get_model` function returns a TLK model object that will later be used for training.

In [None]:
model = model_factory.get_model(model_name='efficientnet_b0', framework='tensorflow')

print("Model name:", model.model_name)
print("Framework:", model.framework)
print("Use case:", model.use_case)
print("Image size:", model.image_size)

## 3. Prepare the dataset

We call the TLK dataset factory to get a sample image classification dataset. For demonstration purposes, we are using the [tf_flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset from the [TensorFlow Datasets catalog](https://www.tensorflow.org/datasets). This dataset contains images of flowers in 5 different classes.

In [None]:
flowers = dataset_factory.get_dataset(dataset_dir=dataset_dir,
                                      use_case='image_classification', 
                                      framework='tensorflow',
                                      dataset_name='tf_flowers',
                                      dataset_catalog='tf_datasets')

print(flowers.info)

In [None]:
print("Class names:", str(flowers.class_names))

In [None]:
# Preprocess the dataset with an image size that matches the model and a batch size of 32
flowers.preprocess(model.image_size, batch_size=32)

## 4. Transfer Learning

This step calls the TLK model's `train` function with the dataset that was just prepared. The training function gets the TFHub feature vector and adds a dense layer sized according to the number of classes in the dataset. The model is then compiled and trained for the number of epochs given by the `epochs` argument.

In [None]:
history = model.train(flowers, output_dir=output_dir, epochs=1)
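
As a rough illustration of the classification head described above, the following NumPy sketch applies a single dense softmax layer to a batch of feature vectors. It is not TLK's implementation: the 1280-dimensional feature size, the random stand-in features and weights, and the `dense_softmax` helper are all assumptions made for illustration.

```python
import numpy as np

NUM_CLASSES = 5      # tf_flowers has 5 classes
FEATURE_DIM = 1280   # assumed size of the pretrained feature vector
BATCH_SIZE = 32

rng = np.random.default_rng(seed=0)


def dense_softmax(features, weights, bias):
    """Apply a dense layer followed by softmax to a batch of feature vectors."""
    logits = features @ weights + bias
    # Subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Stand-in for the output of the pretrained feature extractor
features = rng.normal(size=(BATCH_SIZE, FEATURE_DIM))

# In this sketch, the head has one weight column and one bias per class
weights = rng.normal(scale=0.01, size=(FEATURE_DIM, NUM_CLASSES))
bias = np.zeros(NUM_CLASSES)

probabilities = dense_softmax(features, weights, bias)
predicted_ids = probabilities.argmax(axis=1)

print("Probabilities shape:", probabilities.shape)
print("Parameters in the head:", weights.size + bias.size)
```

Each row of `probabilities` sums to 1, and the number of parameters in the head grows linearly with the number of classes, which is why only the final layer needs to change when moving to a new dataset.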

## 5. Evaluate

The next step shows how the model can be evaluated. Pass a dataset to the model's `evaluate` function, and it returns a list of metrics.

In [None]:
metrics = model.evaluate(flowers)

In [None]:
# Print evaluation metrics
for metric_name, metric_value in zip(model._model.metrics_names, metrics):
    print("{}: {}".format(metric_name, metric_value))

## 6. Predict

To predict with a single batch, we first call `get_batch` on our TLK dataset object to get the images and labels for one batch. We then map the labels to our class names to get human-readable string labels.

In [None]:
images, labels = flowers.get_batch()
labels = [flowers.class_names[label_id] for label_id in labels]

Next, call the `predict` function of our TLK model object with the batch of images to get a list of predictions. Again, we map these predictions to our dataset's class names to get human-readable string labels.

In [None]:
predictions = model.predict(images)
predictions = [flowers.class_names[class_id] for class_id in predictions]

We create a pandas DataFrame to display the predicted labels alongside the actual labels for our batch.

In [None]:
# Create a results table to list out the class prediction vs the actual dataset label
results_table = []
for prediction, actual in zip(predictions, labels):
    results_table.append([prediction, actual])

pd.DataFrame(results_table, columns=["Prediction", "Actual Label"])
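
The comparison in the table can also be reduced to a single number. The sketch below computes the fraction of matching labels for a batch; `batch_accuracy` is a hypothetical helper (not part of TLK), and the example lists are made-up stand-ins for the `predictions` and `labels` lists built above.

```python
def batch_accuracy(predicted, actual):
    """Return the fraction of positions where the predicted label equals the actual label."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not predicted:
        raise ValueError("cannot compute accuracy of an empty batch")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(predicted)


# Hypothetical labels standing in for the `predictions` and `labels` lists
example_predictions = ["daisy", "roses", "tulips", "daisy"]
example_labels = ["daisy", "roses", "sunflowers", "daisy"]

print(batch_accuracy(example_predictions, example_labels))  # 0.75
```

In the notebook, the same helper could be called as `batch_accuracy(predictions, labels)` once both lists hold class-name strings.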

We can also visualize the results by displaying each image along with its predicted label.

In [None]:
# Display the images with their predicted labels
plt.figure(figsize=(18,14))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6,5,n+1)
    plt.imshow(images[n])
    plt.title(predictions[n].title(), fontsize=16)
    plt.axis('off')
_ = plt.suptitle("Model predictions", fontsize=20)
plt.show()

We can also predict using a single image that wasn't part of our original dataset. We download a flower image from the [Open Images Dataset](https://storage.googleapis.com/openimages/web/index.html) and then resize it to match our model.

In [None]:
# Download an image from the web and resize it to match our model
image_url = 'https://c8.staticflickr.com/8/7095/7210797228_c7fe51c3cb_z.jpg'
daisy = tf.keras.utils.get_file(origin=image_url)

image_shape = (model.image_size, model.image_size)
daisy = Image.open(daisy).resize(image_shape)
daisy

Then, we convert the image to a NumPy array, add a dimension to represent the batch, and pass the result to `predict`.

In [None]:
# Convert the image to a NumPy array scaled to [0, 1], then add a batch dimension (np.newaxis) and predict
daisy = np.array(daisy)/255.0
result = model.predict(daisy[np.newaxis, ...])

# Print the predicted class name
print(flowers.class_names[result[0]])
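
To make the shape handling in the previous cell concrete, here is a small standalone sketch using a random placeholder image instead of the downloaded one. The image size of 224 is an assumption for illustration; the notebook reads the real value from `model.image_size`.

```python
import numpy as np

IMAGE_SIZE = 224  # assumed input size; the notebook uses model.image_size

rng = np.random.default_rng(seed=0)

# Placeholder 8-bit RGB image standing in for the resized PIL image
pixels = rng.integers(0, 256, size=(IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)

# Scale pixel values from the integer range [0, 255] to floats in [0.0, 1.0]
scaled = pixels / 255.0

# Add a leading batch axis: (height, width, channels) -> (1, height, width, channels)
batch = scaled[np.newaxis, ...]

print("Single image shape:", scaled.shape)
print("Batch shape:", batch.shape)
```

The leading axis of size 1 is what lets a single image be passed to a function that expects a batch of images.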

## 7. Export

Lastly, we can call the TLK model `export` function to generate a `saved_model.pb`. The model is saved in a format that is ready to use with [TensorFlow Serving](https://github.com/tensorflow/serving). Each time the model is exported, a new numbered directory is created, which allows serving to pick up the latest model. 

In [None]:
model.export(output_dir)
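
The numbered-directory convention can be illustrated with a short sketch that finds the most recent export. The exact layout TLK creates under `output_dir` is not shown in this notebook, so the code assumes that numbered version directories sit directly under a single model directory; `latest_version_dir` is a hypothetical helper, not a TLK or TensorFlow Serving function.

```python
import os
import tempfile


def latest_version_dir(model_dir):
    """Return the path of the highest-numbered subdirectory, or None if there is none."""
    versions = [
        name for name in os.listdir(model_dir)
        if name.isdigit() and os.path.isdir(os.path.join(model_dir, name))
    ]
    if not versions:
        return None
    # Compare numerically so that "10" ranks above "9"
    return os.path.join(model_dir, max(versions, key=int))


# Demonstrate with a temporary directory laid out as assumed above
with tempfile.TemporaryDirectory() as model_dir:
    for version in ("1", "2", "10"):
        os.makedirs(os.path.join(model_dir, version))
    latest = latest_version_dir(model_dir)
    print(os.path.basename(latest))  # 10
```

Comparing the directory names as integers rather than strings matters once there are ten or more exports, since a plain string sort would place "10" before "2".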

## Dataset Citations

```
@ONLINE {tfflowers,
author = "The TensorFlow Team",
title = "Flowers",
month = "jan",
year = "2019",
url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }

@article{openimages,
  title={OpenImages: A public dataset for large-scale multi-label and multi-class image classification.},
  author={Krasin, Ivan and Duerig, Tom and Alldrin, Neil and Veit, Andreas and Abu-El-Haija, Sami
    and Belongie, Serge and Cai, David and Feng, Zheyun and Ferrari, Vittorio and Gomes, Victor
    and Gupta, Abhinav and Narayanan, Dhyanesh and Sun, Chen and Chechik, Gal and Murphy, Kevin},
  journal={Dataset available from https://github.com/openimages},
  year={2016}
}
```