# Exercise: Convert to TF-TRT INT8

In this notebook you will convert a TensorFlow saved model into a TF-TRT optimized graph using INT8 precision. You will use the optimized graph to make predictions and will benchmark its performance.

## Objectives

By the time you complete this notebook you will be able to:

- Use TF-TRT to optimize a saved model with INT8 precision

## Imports

In [None]:
from tensorflow.python.compiler.tensorrt import trt_convert as trt

In [None]:
from lab_helpers import (
    get_images, batch_input, load_tf_saved_model,
    predict_and_benchmark_throughput_from_saved, display_prediction_info
)

## Create Batched Input

In [None]:
batch_size = 32
images = get_images(batch_size)

In [None]:
batched_input = batch_input(images)

## Converting to TF-TRT INT8

To perform INT8 optimization, we simply need to:

- Set `precision_mode` to `trt.TrtPrecisionMode.INT8`
- Pass a `calibration_input_fn` to `converter.convert`

### Calibration Input Function

`calibration_input_fn` should be a generator function that yields input data as a list or tuple.

Make sure the calibration dataset covers all the scenarios you expect at inference time, for example clear weather, rain, and night scenes. When working with your own data, create a separate calibration dataset that does not overlap with the training, validation, or test datasets.
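
To illustrate why coverage matters, here is a minimal sketch of how a value range taken from calibration data affects INT8 rounding error. It uses a simple max-based symmetric quantizer, which is an assumption made for illustration and not the calibrator TensorRT actually uses. The function `quantize_int8` and the synthetic arrays are hypothetical stand-ins for real activations.

```python
import numpy as np

def quantize_int8(values, max_abs):
    """Simulate symmetric INT8 quantization with a fixed range (illustrative only)."""
    scale = max_abs / 127.0
    quantized = np.clip(np.round(values / scale), -127, 127).astype(np.int8)
    # Map back to floats so the rounding and clipping error can be measured
    return quantized.astype(np.float32) * scale

rng = np.random.default_rng(0)
# Hypothetical activations: calibration saw only a narrow range of scenarios
narrow_calibration = rng.normal(0.0, 1.0, 1000).astype(np.float32)
# Inference data spans a wider range than the narrow calibration set
wide_inference = rng.normal(0.0, 4.0, 1000).astype(np.float32)

range_from_narrow = np.abs(narrow_calibration).max()
range_from_wide = np.abs(wide_inference).max()

# Mean absolute error on the inference data for each choice of range
error_narrow = np.abs(quantize_int8(wide_inference, range_from_narrow) - wide_inference).mean()
error_wide = np.abs(quantize_int8(wide_inference, range_from_wide) - wide_inference).mean()

print("error with narrow calibration range:", error_narrow)
print("error with representative calibration range:", error_wide)
```

In this toy setup the range taken from the narrow calibration set clips many inference values, so its error is larger than the error obtained with a representative range.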

For our simple example here, we will not take these extra steps and will simply pass in our `batched_input` as calibration data.
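
For a dedicated calibration dataset, the generator would typically yield several batches instead of one. The sketch below shows one way to do that, assuming a single-input model and a NumPy array of preprocessed images. The factory `make_calibration_input_fn` and the random stand-in data are hypothetical and are not part of `lab_helpers`. For simplicity it drops a partial final batch so that every batch has the same shape.

```python
import numpy as np

def make_calibration_input_fn(calibration_images, batch_size=32):
    """Return a generator function that yields one-element tuples of image batches."""
    def calibration_input_fn():
        num_full_batches = len(calibration_images) // batch_size
        for i in range(num_full_batches):
            start = i * batch_size
            # Each yielded item is a tuple with one entry per model input
            yield (calibration_images[start:start + batch_size],)
    return calibration_input_fn

# Stand-in data: 100 small random "images" (real data would be preprocessed photos)
rng = np.random.default_rng(0)
fake_images = rng.random((100, 8, 8, 3), dtype=np.float32)

calibration_input_fn = make_calibration_input_fn(fake_images, batch_size=32)
batches = list(calibration_input_fn())
print(len(batches), batches[0][0].shape)
```

The resulting `calibration_input_fn` has the same shape as the single-batch generator used in this exercise: a function with no arguments that yields tuples of input data.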

## Convert to TF-TRT INT8

Address the `TODO`s and make this function capable of performing conversion for INT8 precision.

In [None]:
def convert_to_trt_graph_and_save(precision_mode='float32', input_saved_model_dir='resnet_v2_152_saved_model', calibration_data=batched_input):
    
    if precision_mode == 'float32':
        precision_mode = trt.TrtPrecisionMode.FP32
        converted_save_suffix = '_TFTRT_FP32'
        
    if precision_mode == 'float16':
        precision_mode = trt.TrtPrecisionMode.FP16
        converted_save_suffix = '_TFTRT_FP16'

    # TODO: correctly set precision_mode
    if precision_mode == 'int8':
        precision_mode = None
        converted_save_suffix = '_TFTRT_INT8'
        
    output_saved_model_dir = input_saved_model_dir + converted_save_suffix
    
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=precision_mode, 
        max_workspace_size_bytes=8000000000
    )

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        conversion_params=conversion_params
    )

    print('Converting {} to TF-TRT graph precision mode {}...'.format(input_saved_model_dir, precision_mode))
    
    if precision_mode == trt.TrtPrecisionMode.INT8:
        
        # Here we define a simple generator to yield calibration data
        def calibration_input_fn():
            yield (calibration_data, )

        # TODO: When performing INT8 optimization, we must pass the named argument calibration_input_fn to converter.convert.
        # Use the `calibration_input_fn` defined a few lines above for this.
        converter.convert()
    
    else:
        converter.convert()

    print('Saving converted model to {}...'.format(output_saved_model_dir))
    converter.save(output_saved_model_dir=output_saved_model_dir)
    print('Complete')

In [None]:
# Run to check your work. Takes a couple of minutes.
convert_to_trt_graph_and_save(precision_mode='int8', input_saved_model_dir='resnet_v2_152_saved_model', calibration_data=batched_input)

### Solution

Expand the next cell to see the solution if you get stuck.

In [None]:
def convert_to_trt_graph_and_save(precision_mode='float32', input_saved_model_dir='resnet_v2_152_saved_model', calibration_data=batched_input):
    
    if precision_mode == 'float32':
        precision_mode = trt.TrtPrecisionMode.FP32
        converted_save_suffix = '_TFTRT_FP32'
        
    if precision_mode == 'float16':
        precision_mode = trt.TrtPrecisionMode.FP16
        converted_save_suffix = '_TFTRT_FP16'

    # As with FP32 and FP16, map the string to the matching TF-TRT precision mode
    if precision_mode == 'int8':
        precision_mode = trt.TrtPrecisionMode.INT8
        converted_save_suffix = '_TFTRT_INT8'
        
    output_saved_model_dir = input_saved_model_dir + converted_save_suffix
    
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=precision_mode, 
        max_workspace_size_bytes=8000000000
    )

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        conversion_params=conversion_params
    )

    print('Converting {} to TF-TRT graph precision mode {}...'.format(input_saved_model_dir, precision_mode))
    
    if precision_mode == trt.TrtPrecisionMode.INT8:
        
        # Here we define a simple generator to yield calibration data
        def calibration_input_fn():
            yield (calibration_data, )

        # When performing INT8 optimization, we must pass a calibration function to convert
        converter.convert(calibration_input_fn=calibration_input_fn)
    
    else:
        converter.convert()

    print('Saving converted model to {}...'.format(output_saved_model_dir))
    converter.save(output_saved_model_dir=output_saved_model_dir)
    print('Complete')

## Benchmark TF-TRT INT8

Load the optimized TF model.

In [None]:
infer = load_tf_saved_model('resnet_v2_152_saved_model_TFTRT_INT8')

Run inference with the optimized graph. After a warmup, time the runs and calculate throughput.

In [None]:
all_preds = predict_and_benchmark_throughput_from_saved(batched_input, infer, N_warmup_run=50, N_run=150)
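
The helper `predict_and_benchmark_throughput_from_saved` comes from `lab_helpers`, and its implementation is not shown in this notebook. As a rough sketch of the warmup-then-time pattern its arguments suggest, the hypothetical `benchmark_throughput` below runs untimed warmup calls, then times a fixed number of runs and reports images per second. The `fake_predict` function is a stand-in for a real model call.

```python
import time

def benchmark_throughput(predict_fn, batch, batch_size, n_warmup=50, n_run=150):
    """Time repeated calls to predict_fn after a warmup (illustrative sketch only)."""
    # Warmup runs are not timed, so one-time setup costs do not skew the result
    for _ in range(n_warmup):
        predict_fn(batch)

    start = time.perf_counter()
    for _ in range(n_run):
        preds = predict_fn(batch)
    elapsed = time.perf_counter() - start

    # Throughput: total images processed divided by total elapsed seconds
    images_per_second = n_run * batch_size / elapsed
    return preds, images_per_second

def fake_predict(batch):
    # Stand-in for a model: doubles every value in the batch
    return [x * 2 for x in batch]

last_preds, images_per_second = benchmark_throughput(
    fake_predict, list(range(32)), batch_size=32, n_warmup=5, n_run=20
)
print("throughput (images/s):", images_per_second)
```

Throughput here counts every image in every timed batch, so larger batches can raise the figure even when per-call latency grows.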

Run this cell to view predictions, which you can use for comparison.

In [None]:
last_run_preds = all_preds[0]
display_prediction_info(last_run_preds, images)

## Restart Kernel

Please execute the cell below to restart the kernel and clear GPU memory.

In [None]:
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

## Next

In the next notebook you will optimize additional models and experiment with the impact of changing the `minimum_segment_size` conversion parameter.