```
Copyright 2021 The IREE Authors

Licensed under the Apache License v2.0 with LLVM Exceptions.
See https://llvm.org/LICENSE.txt for license information.
SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```

# TFLite text classification sample with IREE

This notebook demonstrates how to download, compile, and run a TFLite model with IREE.  It looks at the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model, and shows how to run it with both TFLite and IREE.  The model predicts if a sentence's sentiment is positive or negative, and is trained on a database of IMDB movie reviews.


## Setup

In [1]:
%%capture
!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases/latest
!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime
!pip3 install --upgrade pyyaml

In [2]:
import numpy as np
import urllib.request
import pathlib
import tempfile
import re
import tflite_runtime.interpreter as tflite

from iree import runtime as iree_rt
from iree.compiler import compile_str
from iree.tools import tflite as iree_tflite

ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), "iree", "colab_artifacts")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

%env IREE_SAVE_CALLS=$ARTIFACTS_DIR/traces

env: IREE_SAVE_CALLS=/tmp/iree/colab_artifacts/traces


### Load the TFLite model

1.   Download files for the pretrained model
2.   Extract model metadata used for input pre-processing and output post-processing
3.   Define helper functions for pre- and post-processing

These steps will differ from model to model.  Consult the model source or reference documentation for details.


In [3]:
#@title Download pretrained text classification model
MODEL_URL = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite"
urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath("text_classification.tflite"))

(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),
 <http.client.HTTPMessage at 0x7fb8f43a9c90>)

In [4]:
#@title Extract model vocab and label metadata
!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite

# Load the vocab file into a dictionary.  It contains the most common 1,000
# words in the English language, mapped to an integer.
vocab = {}
with open(ARTIFACTS_DIR.joinpath("vocab.txt")) as vocab_file:
  for line in vocab_file:
    (key, val) = line.split()
    vocab[key] = int(val)

# Text will be labeled as either 'Positive' or 'Negative'.
with open(ARTIFACTS_DIR.joinpath("labels.txt")) as label_file:
  labels = label_file.read().splitlines()

Archive:  /tmp/iree/colab_artifacts/text_classification.tflite
 extracting: /tmp/iree/colab_artifacts/labels.txt  
 extracting: /tmp/iree/colab_artifacts/vocab.txt  


In [5]:
#@title Input and output processing

# Input text will be encoded as an integer array of fixed length 256.  The 
# input sentence will be mapped to integers from the vocab dictionary, and the 
# empty array spaces are filled with padding.

SENTENCE_LEN = 256
START = "<START>"
PAD = "<PAD>"
UNKNOWN = "<UNKNOWN>"

def tokenize_input(text):
  output = np.empty([1, SENTENCE_LEN], dtype=np.int32)
  output.fill(vocab[PAD])

  # Remove capitalization and punctuation from the input text.
  text_split = text.split()
  text_split = [text.lower() for text in text_split]
  text_split = [re.sub(r"[^\w\s']", '', text) for text in text_split]

  # Prepend <START>.
  index = 0
  output[0][index] = vocab[START]
  index += 1

  for word in text_split:
    output[0][index] = vocab[word] if word in vocab else vocab[UNKNOWN]
    index += 1

  return output


def interpret_output(output):
  if output[0] >= output[1]:
    label = labels[0]
    confidence = output[0]
  else:
    label = labels[1]
    confidence = output[1]

  print("Label: " + label + "\nConfidence: " + str(confidence))

In [6]:
#@title Text samples
positive_text = "This is the best movie I've seen in recent years. Strongly recommend it!"
negative_text = "What a waste of my time."

print(positive_text)
print(tokenize_input(positive_text))

This is the best movie I've seen in recent years. Strongly recommend it!
[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0  

## Run using TFLite

Overview:

1.  Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
2.   Allocate tensors and get the input and output shape information
3.   Invoke the TFLite Interpreter to test the text classification function

In [7]:
interpreter = tflite.Interpreter(
      model_path=str(ARTIFACTS_DIR.joinpath("text_classification.tflite")))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def classify_text_tflite(text):
  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))
  interpreter.invoke()
  output_data = interpreter.get_tensor(output_details[0]['index'])
  interpret_output(output_data[0])

In [8]:
print("Invoking text classification with TFLite\n")
print(positive_text)
classify_text_tflite(positive_text)
print()
print(negative_text)
classify_text_tflite(negative_text)

Invoking text classification with TFLite

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997293

What a waste of my time.
Label: Negative
Confidence: 0.6275043


## Run using IREE

Overview:

1.   Import the TFLite model to TOSA MLIR 
2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module
3.   Run the VM module through IREE's runtime to test the text classification function

Both runtimes should generate the same output.


In [9]:
# Convert TFLite model to TOSA MLIR with IREE's import tool.
IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')
!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir

with open(ARTIFACTS_DIR.joinpath("text_classification.mlir")) as mlir_file:
  tosa_mlir = mlir_file.read()

# The generated .mlir file could now be saved and used outside of Python, with
# IREE native tools or in apps, etc.

In [10]:
# The model contains very large constants, so recompile a truncated version to print.
!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification_truncated.mlir -mlir-elide-elementsattrs-if-larger=50

with open(ARTIFACTS_DIR.joinpath("text_classification_truncated.mlir")) as truncated_mlir_file:
  truncated_tosa_mlir = truncated_mlir_file.read()
  print(truncated_tosa_mlir)

module  {
  func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = "input_5", outputs = "Identity"}} {
    %0 = "tosa.const"() {value = opaque<"_", "0xDEADBEEF"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>
    %1 = "tosa.const"() {value = opaque<"_", "0xDEADBEEF"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
    %2 = "tosa.const"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>
    %3 = "tosa.const"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722

In [11]:
# Compile the TOSA MLIR into a VM module.
compiled_flatbuffer = compile_str(tosa_mlir, input_type="tosa", target_backends=["vmvx"])
vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)

# Register the module with a runtime context.
config = iree_rt.Config("vmvx")
ctx = iree_rt.SystemContext(config=config)
ctx.add_vm_module(vm_module)
invoke_text_classification = ctx.modules.module["main"]

def classify_text_iree(text):
  result = invoke_text_classification(tokenize_input(text))
  interpret_output(result[0])

In [12]:
print("Invoking text classification with IREE\n")
print(positive_text)
classify_text_iree(positive_text)
print()
print(negative_text)
classify_text_iree(negative_text)

Invoking text classification with IREE

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997293

What a waste of my time.
Label: Negative
Confidence: 0.6275043
