# Transfer Learning and Graph Optimization using TensorFlow and the Intel® Transfer Learning Tool API

This notebook uses the `tlt` library to do transfer learning and graph optimization for image classification with a TensorFlow pretrained model.

## 1. Import dependencies and set up parameters

This notebook assumes that you have already followed the instructions in the [notebooks README.md](/notebooks/README.md) to set up a TensorFlow environment with all the dependencies required to run the notebook.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import PIL.Image as Image
import tensorflow as tf

# tlt imports
from tlt.datasets import dataset_factory
from tlt.models import model_factory
from tlt.utils.file_utils import download_and_extract_tar_file

# Directory where the dataset is downloaded (override with the DATASET_DIR environment variable)
dataset_dir = os.environ.get("DATASET_DIR", os.path.join(os.path.expanduser("~"), "dataset"))

# Directory for output files (override with the OUTPUT_DIR environment variable)
output_dir = os.environ.get("OUTPUT_DIR", os.path.join(os.path.expanduser("~"), "output"))

print("Dataset directory:", dataset_dir)
print("Output directory:", output_dir)

## 2. Get the model

In this step, we call the Intel Transfer Learning Tool model factory to list the supported TensorFlow image classification models. These are pretrained models from [TFHub](https://tfhub.dev) and [Keras Applications](https://keras.io/api/applications/) that we tested with our API. Optionally, add the `verbose=True` argument to the `print_supported_models` function call to get more information about each model (such as the image size, the original dataset, and the preprocessor).

In [None]:
# See a list of available models
model_factory.print_supported_models(use_case='image_classification', framework='tensorflow')

Next, use the model factory to get one of the models listed in the previous cell. The `get_model` function returns a model object that will later be used for training.

In [None]:
model = model_factory.get_model(model_name='resnet_v1_50', framework='tensorflow')

print("Model name:", model.model_name)
print("Framework:", model.framework)
print("Use case:", model.use_case)
print("Image size:", model.image_size)

## 3. Get the dataset

We call the dataset factory to load a sample image classification dataset. For demonstration purposes, we will download a flower species dataset. After downloading and extracting, you will have the following subdirectories in your dataset directory. Each species subfolder contains numerous `.jpg` files:

```
flower_photos
  └── daisy
  └── dandelion
  └── roses
  └── sunflowers
  └── tulips
```

When using your own dataset, ensure that it is similarly organized with folders for each class. Change the `custom_dataset_path` variable to point to your dataset folder.
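
Before loading a custom dataset, it can help to confirm that the folder really follows the one-subfolder-per-class layout described above. The following is a minimal sketch using only the standard library. The helper name `count_images_per_class` and the `IMAGE_EXTENSIONS` set are hypothetical and are not part of the `tlt` API.

```python
import os

# Assumption: the file extensions that count as images; adjust for your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(dataset_path):
    """Return {class_name: image_count} for each immediate subfolder of dataset_path."""
    counts = {}
    for entry in sorted(os.listdir(dataset_path)):
        class_dir = os.path.join(dataset_path, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files that sit next to the class folders
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    return counts
```

Calling `count_images_per_class(custom_dataset_path)` after the download cell should return one entry per class folder. A class with a count of zero usually points to a wrong path or an unexpected file extension.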

In [None]:
# For demonstration purposes, we download a flowers dataset. To instead use your own dataset, set the
# custom_dataset_path to point to your dataset's directory and comment out the download_and_extract_tar_file line.
custom_dataset_path = os.path.join(dataset_dir, "flower_photos")

if not os.path.exists(custom_dataset_path):
    download_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    download_and_extract_tar_file(download_url, dataset_dir)

Call the dataset factory to load the dataset from the directory.

In [None]:
# Load the dataset from the custom dataset path
dataset = dataset_factory.load_dataset(dataset_dir=custom_dataset_path,
                                       use_case='image_classification', 
                                       framework='tensorflow')

print("Class names:", str(dataset.class_names))

## 4. Prepare the dataset

Once you have your dataset, use the following cells to split and preprocess the data. We split the dataset into training and validation subsets, resize the images to match the selected model, and then batch the images.
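
To make the split and batching arithmetic concrete, here is a plain-Python sketch of a percentage-based shuffle split and of the resulting batch count. It illustrates the idea only and is not how `tlt` implements `shuffle_split` or `preprocess`. The function names and the `seed` parameter are assumptions made for this example.

```python
import math
import random


def shuffle_split(items, train_pct=0.75, val_pct=0.25, seed=None):
    """Shuffle items and split them into (train, val) lists by percentage."""
    if train_pct + val_pct > 1.0:
        raise ValueError("train_pct + val_pct must not exceed 1.0")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local RNG, leaves global state alone
    n_train = int(len(shuffled) * train_pct)
    n_val = int(len(shuffled) * val_pct)
    return shuffled[:n_train], shuffled[n_train:n_train + n_val]


def num_batches(num_examples, batch_size):
    """Number of batches when the final, partial batch is kept."""
    return math.ceil(num_examples / batch_size)
```

With 100 examples and the default percentages, `shuffle_split(range(100))` yields 75 training and 25 validation items, and `num_batches(75, 32)` gives the number of training steps per epoch at a batch size of 32.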

In [None]:
# Split the dataset into training and validation subsets
dataset.shuffle_split(train_pct=.75, val_pct=.25)

In [None]:
# Preprocess the dataset with an image size and preprocessor that match the model and a batch size of 32
batch_size = 32
dataset.preprocess(model.image_size, batch_size=batch_size, preprocessor=model.preprocessor)

## 5. Evaluate the model before training

Since we haven't trained the model yet, this evaluation uses the original ImageNet-trained model, so accuracy on the new classes will be near zero.

##### Optional Argument
-  **enable_auto_mixed_precision** (bool or None): Enable auto mixed precision for evaluation. Mixed precision uses both 16-bit and 32-bit floating point types to make evaluation run faster and use less memory. It is recommended to enable auto mixed precision when running on platforms that support bfloat16 (Intel third or fourth generation Xeon processors). If it is enabled on a platform that does not support bfloat16, it can be detrimental to the evaluation performance. If enable_auto_mixed_precision is set to None, auto mixed precision will be automatically enabled when running with Intel fourth generation Xeon processors, and disabled for other platforms.
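
If you want to decide explicitly whether to pass `True` or `False` rather than `None`, one rough heuristic on Linux is to look for bfloat16-related CPU flags in `/proc/cpuinfo`. The sketch below is an assumption-laden illustration, not the detection logic that `tlt` uses. The flag names `avx512_bf16` and `amx_bf16` are the ones the Linux kernel reports, and the helper names are hypothetical.

```python
import os

# Linux /proc/cpuinfo flag names that indicate bfloat16 instruction support.
BF16_FLAGS = {"avx512_bf16", "amx_bf16"}


def cpu_flags(cpuinfo_text):
    """Collect the CPU feature flags listed in /proc/cpuinfo-style text."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(":")
        if key.strip() == "flags":
            flags.update(value.split())
    return flags


def supports_bfloat16(cpuinfo_text):
    """Return True if any known bfloat16 flag is present."""
    return bool(BF16_FLAGS & cpu_flags(cpuinfo_text))


def read_cpuinfo(path="/proc/cpuinfo"):
    """Return the contents of /proc/cpuinfo, or an empty string if unavailable."""
    if not os.path.exists(path):
        return ""
    with open(path) as f:
        return f.read()
```

For example, `supports_bfloat16(read_cpuinfo())` returns `False` on systems without `/proc/cpuinfo`, in which case leaving the argument set to `None` is the safer choice.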

In [None]:
enable_auto_mixed_precision = None

model.evaluate(dataset, enable_auto_mixed_precision=enable_auto_mixed_precision)

## 6. Transfer Learning

This step calls the model's `train` function with the dataset that was just prepared. The training function gets the base model and adds a dense layer sized to the number of classes in the dataset. The model is then compiled and trained for the number of epochs specified in the `epochs` argument. Because the `do_eval` parameter is set to True by default, this step also evaluates the model and returns a list of metrics calculated from the dataset's validation subset.

### Arguments

#### Required
-  **dataset** (ImageClassificationDataset): Dataset to use when training the model
-  **output_dir** (str): Path to a writeable directory for checkpoint files
-  **epochs** (int): Number of epochs to train the model (default: 1)

#### Optional
-  **initial_checkpoints** (str): Path to checkpoint weights to load. If the path provided is a directory, the latest checkpoint will be used.
-  **early_stopping** (bool): Enable early stopping if convergence is reached while training at the end of each epoch. (default: False)
-  **lr_decay** (bool): If `lr_decay` is True and `do_eval` is True, the learning rate is decayed based on the validation loss at the end of each epoch.
-  **enable_auto_mixed_precision** (bool or None): Enable auto mixed precision for training. Mixed precision uses both 16-bit and 32-bit floating point types to make training run faster and use less memory. It is recommended to enable auto mixed precision training when running on platforms that support bfloat16 (Intel third or fourth generation Xeon processors). If it is enabled on a platform that does not support bfloat16, it can be detrimental to the training performance. If enable_auto_mixed_precision is set to None, auto mixed precision will be automatically enabled when running with Intel fourth generation Xeon processors, and disabled for other platforms.
-  **extra_layers** (list[int]): Optionally insert additional dense layers between the base model and output layer. This can help increase accuracy when fine-tuning a TFHub model. The input should be a list of integers representing the number and size of the layers, for example [1024, 512] will insert two dense layers, the first with 1024 neurons and the second with 512 neurons.

Note: Refer to the release documentation for an up-to-date list of `train` arguments and their current descriptions.
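
To get a feel for what `extra_layers` adds, the sketch below counts the trainable parameters in a stack of dense layers placed on top of a feature vector. It is a back-of-the-envelope illustration: the helper names are hypothetical, the feature size of 2048 for `resnet_v1_50` is an assumption, and the layers that `tlt` actually builds may differ (for example, in activations or dropout).

```python
def dense_head_layer_sizes(feature_size, num_classes, extra_layers=None):
    """Return (inputs, outputs) pairs for each dense layer in the head."""
    widths = [feature_size] + list(extra_layers or []) + [num_classes]
    return list(zip(widths[:-1], widths[1:]))


def dense_head_param_count(feature_size, num_classes, extra_layers=None):
    """Count weights plus biases across all dense layers in the head."""
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in dense_head_layer_sizes(feature_size, num_classes, extra_layers)
    )
```

Under these assumptions, the five flower classes need `dense_head_param_count(2048, 5)` parameters with a single output layer, while `extra_layers=[1024, 512]` produces the layer shapes `[(2048, 1024), (1024, 512), (512, 5)]` and a correspondingly larger head to train.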

In [None]:
enable_auto_mixed_precision = None

# Train using the pretrained model from TF Hub with the new dataset
history = model.train(dataset, output_dir=output_dir, epochs=1,
                      enable_auto_mixed_precision=enable_auto_mixed_precision)

This time, the accuracy looks much better.

## 7. Export

Next, we can call the model's `export` function to generate a `saved_model.pb`. The model is saved in a format that is ready to use with [TensorFlow Serving](https://github.com/tensorflow/serving). Each time the model is exported, a new numbered directory is created, which allows TensorFlow Serving to pick up the latest version of the model.
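
Because each export lands in its own numbered directory, finding the most recent export comes down to picking the highest number. The following sketch shows one way to do that with the standard library. The helper name `latest_version_dir` is hypothetical, and it assumes that the numbered directories sit side by side under a common parent (for example, the parent of the path returned by `export`).

```python
import os


def latest_version_dir(export_root):
    """Return the path of the highest-numbered subdirectory, or None if there is none."""
    names = [
        name for name in os.listdir(export_root)
        if name.isdigit() and os.path.isdir(os.path.join(export_root, name))
    ]
    if not names:
        return None
    # Compare numerically so that "10" sorts after "2".
    return os.path.join(export_root, max(names, key=int))
```

Comparing the names as integers matters here: a plain string sort would rank `2` above `10`.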

In [None]:
saved_model_dir = model.export(output_dir)

## 8. Graph Optimization

The `tlt` API uses [Intel® Neural Compressor (INC)](https://github.com/intel/neural-compressor) to optimize the FP32 graph for improved inference performance. Graph optimization performs the following:
* Converting variables to constants
* Removing training-only operations like checkpoint saving
* Stripping out parts of the graph that are never reached
* Removing debug operations like CheckNumerics
* Folding batch normalization ops into the pre-calculated weights
* Fusing common operations into unified versions
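
As an illustration of one of these steps, batch normalization folding works because, at inference time, batch normalization is just a fixed scale and shift that can be merged into the preceding layer's weight and bias. The scalar sketch below shows the arithmetic only. It is not Intel Neural Compressor code, and the function names and the default `eps` value are assumptions made for this example.

```python
import math


def fold_batch_norm(weight, bias, gamma, beta, mean, variance, eps=1e-3):
    """Fold inference-mode batch norm into the preceding layer's (weight, bias)."""
    scale = gamma / math.sqrt(variance + eps)
    return weight * scale, (bias - mean) * scale + beta


def layer_then_batch_norm(x, weight, bias, gamma, beta, mean, variance, eps=1e-3):
    """Reference computation: linear layer followed by inference-mode batch norm."""
    z = weight * x + bias
    return gamma * (z - mean) / math.sqrt(variance + eps) + beta
```

For any input `x`, computing `w * x + b` with the folded pair `(w, b)` gives the same result as `layer_then_batch_norm`, up to floating-point rounding, while skipping the separate normalization op.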

For benchmarking, we will use an auto-generated config. If you want more control over the configuration, you can create your own config using the Intel Neural Compressor or the `tlt.utils.inc_utils.get_inc_config()` method.

We use the Intel Neural Compressor config to benchmark the full, non-optimized model and establish a performance baseline.

> Note: There is a known issue when running Intel Neural Compressor from a notebook where you may sometimes see the error
> `zmq.error.ZMQError: Address already in use`. If you see this error, rerun the cell.

In [None]:
# Benchmark the non-optimized model to get baseline results
result = model.benchmark(dataset)

Next, we do the FP32 graph optimization.

In [None]:
# Create an output directory for optimization output with the same base name as our saved model directory
optimization_output_dir = os.path.join(output_dir, 'optimized_models', model.model_name,
                                       os.path.basename(saved_model_dir))

model.optimize_graph(optimization_output_dir, overwrite_model=True)

In [None]:
# Benchmark the optimized model for comparison with the baseline results
optimized_result = model.benchmark(dataset, saved_model_dir=optimization_output_dir)

## Dataset Citations

```
@ONLINE {tfflowers,
author = "The TensorFlow Team",
title = "Flowers",
month = "jan",
year = "2019",
url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }
```