# Convert to TFLite

<div class="alert alert-info">

This tutorial is available as an IPython notebook at [malaya-speech/example/convert-asr-to-tflite](https://github.com/huseinzol05/malaya-speech/tree/master/example/convert-asr-to-tflite).
    
</div>

### Purpose

We want to deploy our ASR model on smartphones and embedded devices.

### Download model

In this example, I am going to convert a QuartzNet ASR model to TFLite.

**Right now, the TFLite export only supports the CTC greedy decoder**.
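
For readers unfamiliar with the term, greedy CTC decoding picks the most likely symbol at each frame, merges consecutive repeats, and drops the blank symbol. The sketch below illustrates the idea with NumPy on toy data. The function name `ctc_greedy_decode`, the toy logits, and the choice of the last index as the blank are assumptions for illustration; this is not the code that runs inside the exported graph.

```python
import numpy as np

def ctc_greedy_decode(frame_logits, blank_id):
    """Greedy CTC decoding for a single utterance.

    frame_logits: array of shape [time, vocab_size].
    blank_id: index of the CTC blank symbol (an assumption of this sketch).
    """
    best = np.argmax(frame_logits, axis=-1)
    decoded = []
    previous = -1  # no valid symbol has this id
    for token in best:
        # keep a token only when it starts a new run and is not the blank
        if token != previous and token != blank_id:
            decoded.append(int(token))
        previous = token
    return decoded

# toy example: 6 frames, 4 symbols, blank is the last index (3)
toy_logits = np.array([
    [0.1, 0.7, 0.1, 0.1],  # symbol 1
    [0.1, 0.7, 0.1, 0.1],  # symbol 1 again (merged with the previous frame)
    [0.1, 0.1, 0.1, 0.7],  # blank (dropped, but separates the runs)
    [0.1, 0.7, 0.1, 0.1],  # symbol 1 (kept, because a blank came before it)
    [0.7, 0.1, 0.1, 0.1],  # symbol 0
    [0.1, 0.1, 0.7, 0.1],  # symbol 2
])
toy_ids = ctc_greedy_decode(toy_logits, blank_id=3)
print(toy_ids)
```

Because each frame is decided independently, this decoder needs no beam state or language model, which makes it a simple fit for a static exported graph.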

In [20]:
# !wget https://f000.backblazeb2.com/file/malaya-speech-model/pretrained/asr-quartznet-ctc-output-75k.tar.gz
# !tar -zxf asr-quartznet-ctc-output-75k.tar.gz
!ls -lh asr-quartznet-ctc-output

total 457424
-rw-r--r--  1 huseinzolkepli  staff    77B Nov 14 22:08 checkpoint
-rw-r--r--  1 huseinzolkepli  staff    74M Nov 14 22:08 frozen_model.pb
-rw-r--r--  1 huseinzolkepli  staff    74M Nov 14 22:25 frozen_model.pb.quantized
-rw-r--r--  1 huseinzolkepli  staff    74M Nov 14 22:08 model.ckpt.data-00000-of-00001
-rw-r--r--  1 huseinzolkepli  staff    23K Nov 14 22:08 model.ckpt.index
-rw-r--r--  1 huseinzolkepli  staff   1.7M Nov 14 22:08 model.ckpt.meta


In [4]:
# !wget https://raw.githubusercontent.com/huseinzol05/malaya-speech/master/session/prepare-asr/malaya-speech-sst-vocab.json

In [5]:
import tensorflow as tf
import malaya_speech
import json

In [6]:
with open('malaya-speech-sst-vocab.json') as fopen:
    unique_vocab = json.load(fopen)

### Load frozen graph

First, load the frozen graph and run it once to make sure it is not corrupted.

In [7]:
def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph

In [8]:
g = load_graph('asr-quartznet-ctc-output/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.InteractiveSession(graph = g)

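1. `import/Placeholder`

Accepts a float32 array of length N. This is the placeholder used to feed the audio signal.

2. `import/logits`

Returns an int32 array of character IDs (shape `[1, T]` in the output below). The IDs represent the transcribed string and need to be mapped back to characters using the provided `unique_vocab`.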
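
As a rough illustration of that reverse mapping, the sketch below turns a list of IDs into text by indexing into a vocabulary. The helper `ids_to_text` and the six-entry `toy_vocab` are hypothetical; the real `unique_vocab` has different contents, and `malaya_speech.char.decode` may apply extra handling that is not shown here.

```python
def ids_to_text(ids, lookup):
    """Map integer IDs to characters and join them into one string."""
    return ''.join(lookup[i] for i in ids)

# hypothetical toy vocabulary, NOT the contents of malaya-speech-sst-vocab.json
toy_vocab = ['<PAD>', ' ', 'a', 'd', 'i', 'j']

toy_text = ids_to_text([5, 2, 3, 4, 1, 2, 3, 2], toy_vocab)
print(toy_text)
```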

In [10]:
y, sr = malaya_speech.load('speech/khutbah/wadi-annuar.wav')
len(y) / sr

10.0

In [11]:
result = test_sess.run(logits, feed_dict = {x: y})
result

array([[24,  3, 15,  6,  2, 15,  3, 13,  3, 16,  2, 20,  4, 12, 24,  3,
        13,  3,  5,  3,  5,  2,  6,  5,  6,  2, 15,  9,  5,  6,  3,  2,
        18,  3,  5, 17,  2, 10,  4, 10,  3, 14,  2,  6,  5,  6,  2, 11,
         4,  7,  6, 11,  3,  2,  5,  3, 19,  6,  2, 16,  4,  5, 17,  3,
        24,  3, 12,  2, 16,  9,  3, 10, 19,  6,  5,  2, 24,  3, 19,  3,
         7,  2,  7,  3, 15,  6,  5,  6,  2,  3, 13, 13,  3, 14,  2, 16,
         3, 11,  6, 13,  5,  6]], dtype=int32)

In [13]:
malaya_speech.char.decode(result[0], lookup = unique_vocab)

'jadi dalam perjalanan ini dunia yang sesah ini ketika nabi mengajar muasbin jabat tadini allah makilni'

Looks good!

### Optimize frozen graph

Before we convert the frozen graph to TFLite, we need to optimize it:

1. TFLite cannot initialize random values (created by dropout and similar ops), so we need to convert them to constants.
2. Dropout is not needed at inference time, so we remove it.
3. Batch normalization is folded into constants.
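
To make the third point concrete, the sketch below shows the arithmetic behind folding batch normalization into the preceding layer: at inference time the batch norm statistics are fixed, so the scale and shift can be merged into the layer's weights and bias. It uses a plain matrix multiply instead of a convolution, and all names and random values are assumptions for illustration; the actual `fold_batch_norms` transform operates on graph nodes.

```python
import numpy as np

def fold_batch_norm(weights, bias, gamma, beta, mean, variance, eps=1e-3):
    """Fold gamma * (x @ W + b - mean) / sqrt(var + eps) + beta into x @ W' + b'."""
    scale = gamma / np.sqrt(variance + eps)
    folded_weights = weights * scale            # scales each output channel
    folded_bias = (bias - mean) * scale + beta  # absorbs the shift
    return folded_weights, folded_bias

rng = np.random.default_rng(0)
inputs = rng.normal(size=(5, 8))
weights = rng.normal(size=(8, 4))
bias = rng.normal(size=4)
gamma = rng.normal(size=4)
beta = rng.normal(size=4)
mean = rng.normal(size=4)
variance = rng.uniform(0.5, 2.0, size=4)
eps = 1e-3

# layer followed by an explicit batch norm
reference = gamma * (inputs @ weights + bias - mean) / np.sqrt(variance + eps) + beta

# the same computation with batch norm folded into the layer
folded_w, folded_b = fold_batch_norm(weights, bias, gamma, beta, mean, variance, eps)
folded = inputs @ folded_w + folded_b

print(np.allclose(reference, folded))
```

After folding, the graph contains one fewer operation per normalized layer and only constant tensors, which is what the converter needs.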

In [15]:
from tensorflow.tools.graph_transforms import TransformGraph

In [16]:
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_constants(ignore_errors=true)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'strip_unused_nodes',
             'sort_by_execution_order']

input_graph_def = tf.GraphDef()
# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile
with tf.gfile.GFile('asr-quartznet-ctc-output/frozen_model.pb', 'rb') as f:
    input_graph_def.ParseFromString(f.read())
    
transformed_graph_def = TransformGraph(input_graph_def, 
                                       ['Placeholder'], 
                                       ['logits'], 
                                       transforms)

# despite the '.quantized' suffix, these transforms only clean up the graph;
# no weight quantization happens at this step
with tf.gfile.GFile('asr-quartznet-ctc-output/frozen_model.pb.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

### Convert to TFLite

Right now, TFLite does not support dynamic-length input here, so in this example I fix the input of the TFLite model to 10 seconds of audio at a 16 kHz sample rate (160,000 samples).

In [17]:
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='asr-quartznet-ctc-output/frozen_model.pb.quantized',
    input_arrays=['Placeholder'],
    input_shapes={'Placeholder' : [16000 * 10]},
    output_arrays=['logits']
)
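
Because the input shape is fixed at `16000 * 10` samples, any clip fed to the converted model has to be brought to exactly that length first. One simple way to do this, sketched below, is to trim longer clips and zero-pad shorter ones. The helper `pad_or_trim` and the constants are assumptions for illustration and are not part of malaya-speech; how well the model handles trailing silence is also not verified here.

```python
import numpy as np

SAMPLE_RATE = 16000
FIXED_SECONDS = 10
FIXED_LENGTH = SAMPLE_RATE * FIXED_SECONDS  # 160000 samples

def pad_or_trim(signal, length=FIXED_LENGTH):
    """Return a float32 signal of exactly `length` samples."""
    signal = np.asarray(signal, dtype=np.float32)
    if len(signal) >= length:
        # drop everything after the first `length` samples
        return signal[:length]
    # append zeros (silence) at the end
    return np.pad(signal, (0, length - len(signal)), mode='constant')

short_clip = pad_or_trim(np.ones(SAMPLE_RATE * 3))   # 3 s, padded to 10 s
long_clip = pad_or_trim(np.ones(SAMPLE_RATE * 12))   # 12 s, trimmed to 10 s
print(short_clip.shape, long_clip.shape)
```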

In [18]:
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.experimental_new_converter = True

In [19]:
tflite_model = converter.convert()
open('asr-quartznet-ctc.tflite', 'wb').write(tflite_model)

19738936

The converted model is about 19.7 MB, compared with 74 MB for the original frozen graph.
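
To check the converted file, it can be loaded with the TFLite interpreter and run on a 10-second signal. The helper below is a minimal sketch written against the standard `tf.lite.Interpreter` methods (`allocate_tensors`, `get_input_details`, `set_tensor`, `invoke`, `get_tensor`); the name `run_tflite` is hypothetical. Since the model was converted with `SELECT_TF_OPS`, the runtime must also provide the TensorFlow select ops, and whether your installed interpreter does so is an assumption to verify.

```python
import numpy as np

def run_tflite(interpreter, signal):
    """Run one fixed-length signal through a TFLite interpreter.

    `interpreter` is expected to behave like tf.lite.Interpreter.
    """
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    # cast to the dtype the model expects (float32 for this model)
    signal = np.asarray(signal, dtype=input_detail['dtype'])
    interpreter.set_tensor(input_detail['index'], signal)
    interpreter.invoke()
    return interpreter.get_tensor(output_detail['index'])

# Usage, assuming TensorFlow is installed and the file written above exists:
# import tensorflow as tf
# interpreter = tf.lite.Interpreter(model_path='asr-quartznet-ctc.tflite')
# ids = run_tflite(interpreter, y)  # y must hold exactly 160000 samples
```

The returned IDs can then be decoded with `malaya_speech.char.decode` in the same way as the frozen graph output earlier.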