# Convert models trained using the TensorFlow Object Detection API to TensorFlow Lite


In [None]:
# Mount Google Drive at /content/drive. The model archive that is unzipped
# later in this notebook is read from /content/drive/MyDrive, so this cell
# has to run first.
from google.colab import drive
drive.mount('/content/drive')

## Preparation

### Install the TFLite Support Library

In [None]:
!pip install -q tflite_support

### Install the TensorFlow Object Detection API


In [None]:
import os
import pathlib

# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

In [None]:
%%bash
# Build and install the TensorFlow Object Detection API from the cloned repo.
# Stop at the first failing step: without this, a failed cd/protoc/cp would be
# masked by the following commands and the cell would still appear to succeed.
set -e
cd models/research/
# Generate the Python modules for the API's protobuf definitions in place.
protoc object_detection/protos/*.proto --python_out=.
# Copy the TF2 setup.py to the directory that `pip install .` is run from.
cp object_detection/packages/tf2/setup.py .
pip install -q .

### Import the necessary libraries

In [None]:
import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.utils import config_util
from object_detection.builders import model_builder

%matplotlib inline

In [None]:
# Extract data
import zipfile
def unzip(filename, dest_dir=None):
  """Extract every member of a zip archive.

  Args:
    filename: Path to the .zip file to extract.
    dest_dir: Directory to extract into. Defaults to None, which extracts
      into the current working directory (the original behaviour).

  Raises:
    FileNotFoundError: If `filename` does not exist.
    zipfile.BadZipFile: If `filename` is not a valid zip archive.
  """
  # The context manager closes the archive even if extraction fails part-way.
  with zipfile.ZipFile(filename, "r") as zip_ref:
    zip_ref.extractall(dest_dir)

In [None]:
# Extract the trained-model archive from the mounted Drive. Called without a
# destination, unzip() extracts into the current working directory.
# NOTE(review): the following cells assume the archive unpacks to a
# my_ssd_mobnet_mature_cropped_10/ directory containing checkpoint/ and
# pipeline.config -- confirm against the archive layout.
unzip("/content/drive/MyDrive/my_ssd_mobnet_mature_cropped_10.zip")

## Generate TensorFlow Lite Model

### Step 1: Export TFLite inference graph

In [None]:
# Run the Object Detection API's TF2 export script on the trained checkpoint.
# Each {'...'} argument is IPython shell interpolation of a Python string
# literal, so the plain path text is what reaches the command line.
# The conversion cell below loads the result from
# my_ssd_mobnet_mature_cropped_10/tflite/saved_model.
!python models/research/object_detection/export_tflite_graph_tf2.py \
    --trained_checkpoint_dir {'my_ssd_mobnet_mature_cropped_10/checkpoint'} \
    --output_directory {'my_ssd_mobnet_mature_cropped_10/tflite'} \
    --pipeline_config_path {'my_ssd_mobnet_mature_cropped_10/pipeline.config'}

### Step 2: Convert to TFLite

In [None]:
# Where the converted model is written; the metadata cell reads it back from
# this path.
_TFLITE_MODEL_PATH = "my_ssd_mobnet_mature_cropped_10/model.tflite"

# Build a converter from the SavedModel produced by the export step.
converter = tf.lite.TFLiteConverter.from_saved_model(
    'my_ssd_mobnet_mature_cropped_10/tflite/saved_model'
)
# Apply the converter's default optimization set.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the serialized model returned by the converter.
with open(_TFLITE_MODEL_PATH, 'wb') as model_file:
  model_file.write(tflite_model)

### Step 3: Add Metadata

This metadata helps the inference code perform the correct pre & post processing as required by the model.

In [None]:
# Convert the Object Detection API label map (.pbtxt) into the plain-text label
# file that is handed to the TFLite metadata writer: one class name per line,
# where line i (0-based) holds the name for label-map id i + 1.
_ODT_LABEL_MAP_PATH = 'label_map.pbtxt'
_TFLITE_LABEL_PATH = "my_ssd_mobnet_mature_cropped_10/tflite_label_map.txt"

category_index = label_map_util.create_category_index_from_labelmap(
    _ODT_LABEL_MAP_PATH)

# Always write at least 90 lines, as the original hard-coded range did
# (presumably sized for the 90 COCO ids -- TODO confirm), but extend the table
# when the label map contains higher ids so that none of them is dropped.
_last_class_id = max(90, max(category_index, default=0))

# The context manager guarantees the file is flushed and closed even if a
# write fails part-way through.
with open(_TFLITE_LABEL_PATH, 'w') as f:
  for class_id in range(1, _last_class_id + 1):
    # Ids absent from the label map get a placeholder line so that the names
    # of all later ids stay on the correct line.
    if class_id in category_index:
      name = category_index[class_id]['name']
    else:
      name = '???'
    f.write(name + '\n')

Then we'll add the label map and other necessary metadata (e.g. normalization config) to the TFLite model.

As the SSD MobileNet V2 FPNLite model takes input images with pixel values in the range [-1, 1], we need to set `norm_mean = 127.5` and `norm_std = 127.5`.

In [None]:
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

# Output path for the copy of the model that carries metadata.
_TFLITE_MODEL_WITH_METADATA_PATH = "my_ssd_mobnet_mature_cropped_10/model_with_metadata.tflite"

# Describe the expected input normalization and attach the label file.
# Mean and std of 127.5 correspond to the [-1, 1] input range described in the
# text above this cell.
writer = object_detector.MetadataWriter.create_for_inference(
    writer_utils.load_file(_TFLITE_MODEL_PATH),
    input_norm_mean=[127.5],
    input_norm_std=[127.5],
    label_file_paths=[_TFLITE_LABEL_PATH],
)

# populate() produces the model content with metadata; save it to the new path.
writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)

Optional: Print out the metadata added to the TFLite model.

In [None]:
from tflite_support import metadata

# Load the metadata back from the file written by the previous cell and print
# it as a sanity check.
displayer = metadata.MetadataDisplayer.with_model_file(
    _TFLITE_MODEL_WITH_METADATA_PATH)

# Each header and its payload are printed on separate lines via sep="\n".
print("Metadata populated:", displayer.get_metadata_json(), sep="\n")
print("=============================")
print("Associated file(s) populated:",
      displayer.get_packed_associated_file_list(), sep="\n")