# Exercise: Convert to TF-TRT Float16

In this notebook you'll update the `convert_to_trt_graph_and_save` function from the previous notebook so that it can also convert a model to Float16 precision.

## Objectives

By the time you complete this notebook you will be able to:

- Optimize a saved model with TF-TRT using Float16 precision

## Imports

In [None]:
from tensorflow.python.compiler.tensorrt import trt_convert as trt

In [None]:
from lab_helpers import (
    get_images, batch_input, load_tf_saved_model,
    predict_and_benchmark_throughput_from_saved, display_prediction_info
)

## Create Batched Input

Run these cells to create batched input. You don't need to modify the cells.

In [None]:
number_of_images = 32
images = get_images(number_of_images)

In [None]:
batched_input = batch_input(images)

## Make Conversion

Address the `TODO` so that this function can also perform conversion for Float16 precision.

In [None]:
def convert_to_trt_graph_and_save(precision_mode='float32', input_saved_model_dir='resnet_v2_152_saved_model'):
    
    if precision_mode == 'float32':
        precision_mode = trt.TrtPrecisionMode.FP32
        converted_save_suffix = '_TFTRT_FP32'
        
    if precision_mode == 'float16':
        # TODO: Correctly set precision_mode
        precision_mode = None
        converted_save_suffix = '_TFTRT_FP16'
        
    output_saved_model_dir = input_saved_model_dir + converted_save_suffix
    
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=precision_mode, 
        max_workspace_size_bytes=8000000000
    )

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        conversion_params=conversion_params
    )

    print('Converting {} to TF-TRT graph precision mode {}...'.format(input_saved_model_dir, precision_mode))

    converter.convert()

    print('Saving converted model to {}...'.format(output_saved_model_dir))
    converter.save(output_saved_model_dir=output_saved_model_dir)
    print('Complete')

In [None]:
convert_to_trt_graph_and_save(precision_mode='float16', input_saved_model_dir='resnet_v2_152_saved_model')

### Solution

Expand the next cell to see the solution if you get stuck.

```python
def convert_to_trt_graph_and_save(precision_mode='float32', input_saved_model_dir='resnet_v2_152_saved_model'):
    
    if precision_mode == 'float32':
        precision_mode = trt.TrtPrecisionMode.FP32
        converted_save_suffix = '_TFTRT_FP32'
        
    if precision_mode == 'float16':
        # Solution: use the TF-TRT FP16 precision mode
        precision_mode = trt.TrtPrecisionMode.FP16
        converted_save_suffix = '_TFTRT_FP16'
        
    output_saved_model_dir = input_saved_model_dir + converted_save_suffix
    
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=precision_mode, 
        max_workspace_size_bytes=8000000000
    )

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        conversion_params=conversion_params
    )

    print('Converting {} to TF-TRT graph precision mode {}...'.format(input_saved_model_dir, precision_mode))

    converter.convert()

    print('Saving converted model to {}...'.format(output_saved_model_dir))
    converter.save(output_saved_model_dir=output_saved_model_dir)
    print('Complete')
```
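
The function above leaves `converted_save_suffix` undefined when it receives a precision string other than `'float32'` or `'float16'`. One way to make that failure explicit is a lookup table. The following is a minimal sketch and not part of the lab: the helper name `resolve_precision` is hypothetical, and the mode values are written as plain strings on the assumption that they mirror the `trt.TrtPrecisionMode` constants. In the notebook you would reference `trt.TrtPrecisionMode.FP32` and `trt.TrtPrecisionMode.FP16` directly.

```python
# Hypothetical helper: maps the notebook's precision strings to a
# (TF-TRT precision mode, save-directory suffix) pair.
# Assumption: the string values mirror trt.TrtPrecisionMode.FP32 / FP16.
_PRECISION_TABLE = {
    'float32': ('FP32', '_TFTRT_FP32'),
    'float16': ('FP16', '_TFTRT_FP16'),
}


def resolve_precision(precision_mode):
    """Return (mode, suffix) or raise ValueError for unsupported input."""
    try:
        return _PRECISION_TABLE[precision_mode]
    except KeyError:
        supported = ', '.join(sorted(_PRECISION_TABLE))
        raise ValueError(
            'Unsupported precision_mode {!r}; expected one of: {}'.format(
                precision_mode, supported)
        ) from None


mode, suffix = resolve_precision('float16')
output_saved_model_dir = 'resnet_v2_152_saved_model' + suffix
print(mode, output_saved_model_dir)
```

With a table like this, an unsupported value fails immediately with a descriptive `ValueError` rather than with a `NameError` later in the function.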

## Benchmark TF-TRT Float16

Load the optimized TF model.

In [None]:
infer = load_tf_saved_model('resnet_v2_152_saved_model_TFTRT_FP16')

Perform inference with the optimized graph. After a warmup phase, the helper times the inference runs and calculates throughput.

In [None]:
all_preds = predict_and_benchmark_throughput_from_saved(batched_input, infer, N_warmup_run=50, N_run=150)
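
The internals of `predict_and_benchmark_throughput_from_saved` are not shown in this notebook. As a rough illustration of the warmup-then-time pattern, the sketch below uses a hypothetical `benchmark_throughput` function and a stand-in `fake_infer` callable. The real helper may differ.

```python
import time


def benchmark_throughput(infer_fn, batch, n_warmup=50, n_run=150):
    """Hypothetical sketch: time infer_fn on batch and return
    (predictions from every timed run, images per second)."""
    # Warmup runs are executed but not timed, so one-off costs such as
    # engine building or memory allocation do not skew the measurement.
    for _ in range(n_warmup):
        infer_fn(batch)

    all_preds = []
    start = time.perf_counter()
    for _ in range(n_run):
        all_preds.append(infer_fn(batch))
    elapsed = time.perf_counter() - start

    # Throughput = total images processed / total elapsed seconds.
    # Guard against a zero reading on very coarse timers.
    throughput = n_run * len(batch) / max(elapsed, 1e-9)
    return all_preds, throughput


def fake_infer(batch):
    """Stand-in for a real model: returns one dummy score per image."""
    return [0.0 for _ in batch]


preds, images_per_second = benchmark_throughput(
    fake_infer, batch=list(range(32)), n_warmup=5, n_run=20)
print('Throughput: {:.0f} images/s'.format(images_per_second))
```

Excluding the warmup runs matters for TF-TRT in particular, since the first calls can include one-time setup work that would otherwise lower the measured throughput.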

Run this cell to view predictions, which you can use for comparison.

In [None]:
last_run_preds = all_preds[0]
display_prediction_info(last_run_preds, images)
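
One simple way to compare precision modes is to check how often their top-1 classes agree. The sketch below is illustrative only: `top1_agreement` is a hypothetical helper, the scores are made up, and it assumes each prediction is a plain list of per-class scores, which may not match the format returned by the lab helpers.

```python
def top1_agreement(preds_a, preds_b):
    """Fraction of images whose highest-scoring class index matches."""
    if len(preds_a) != len(preds_b):
        raise ValueError('Both prediction lists must have the same length')
    if not preds_a:
        raise ValueError('Prediction lists must not be empty')

    def argmax(scores):
        # Index of the largest score in a single prediction.
        return max(range(len(scores)), key=scores.__getitem__)

    matches = sum(argmax(a) == argmax(b) for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)


# Made-up scores for three images and three classes (not real model output).
fp32_scores = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6]]
fp16_scores = [[0.1, 0.69, 0.21], [0.3, 0.6, 0.1], [0.2, 0.21, 0.59]]
agreement = top1_agreement(fp32_scores, fp16_scores)
print('Top-1 agreement: {:.2f}'.format(agreement))
```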

## Restart Kernel

Please execute the cell below to restart the kernel and clear GPU memory.

In [None]:
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

## Next

Next you'll learn about the additional steps required to optimize TF-TRT models with Int8 precision.