##### Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [1]:
#@title License header
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TFLite text classification sample with IREE

This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and shows how to run it with both TFLite and IREE.


# Setup

In [2]:
%%capture
!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases
!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime

In [3]:
import numpy as np
import os
import urllib.request
import tempfile
import re
import tflite_runtime.interpreter as tflite

from iree import runtime as ireert
from iree.compiler import compile_str
from iree.tools import tflite as iree_tflite

# Scratch directory under the system temp dir where this notebook keeps the
# files it downloads and generates. Created up front; `exist_ok=True` makes
# re-running the cell a no-op.
ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), "iree", "colab_artifacts")
os.makedirs(ARTIFACTS_DIR, exist_ok=True)

## Model prep

In [4]:
#@title Download pretrained TFLite text classification model
# Fetch the model from MODEL_URL and save it in ARTIFACTS_DIR under a fixed
# file name. urlretrieve returns a (local_path, headers) tuple, which the cell
# displays as its result.
MODEL_URL = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite"
urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR + "/text_classification.tflite")

('/tmp/iree/colab_artifacts/text_classification.tflite',
 <http.client.HTTPMessage at 0x7fbec89b4f10>)

In [5]:
#@title Extract model vocab and label metadata
!unzip -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite

# Load the vocab file into a dictionary.
vocab = {}
with open(ARTIFACTS_DIR + "/vocab.txt") as vocab_file:
  for line in vocab_file:
    (key, val) = line.split()
    vocab[key] = int(val)

with open(ARTIFACTS_DIR + "/labels.txt") as label_file:
  labels = label_file.read().splitlines()

Archive:  /tmp/iree/colab_artifacts/text_classification.tflite
 extracting: /tmp/iree/colab_artifacts/labels.txt  
 extracting: /tmp/iree/colab_artifacts/vocab.txt  


In [6]:
#@title Input and output processing
# Number of token ids in one tokenized input row (width of the array that
# tokenize_input builds).
SENTENCE_LEN = 256
# Special vocabulary keys that tokenize_input looks up in `vocab`.
START = "<START>"
PAD = "<PAD>"
UNKNOWN = "<UNKNOWN>"

def tokenize_input(text):
  """Converts raw text into a fixed-size row of vocabulary ids.

  The text is split on whitespace, lower-cased, and stripped of every
  character that is not a word character, whitespace, or an apostrophe. The
  result is `<START>` followed by one id per word (`<UNKNOWN>` for words that
  are not in `vocab`), padded with `<PAD>` up to SENTENCE_LEN entries.

  Inputs with more than SENTENCE_LEN - 1 words are truncated so that they fit
  the fixed-size buffer instead of raising an IndexError.

  Args:
    text: The sentence(s) to tokenize, as a single string.

  Returns:
    An np.int32 array of shape [1, SENTENCE_LEN].
  """
  # Remove capitalization and punctuation from the input text.
  words = [re.sub(r"[^\w\s']", '', word.lower()) for word in text.split()]

  # Slot 0 is reserved for <START>; drop whatever does not fit after it.
  words = words[:SENTENCE_LEN - 1]

  # Prepend <START>, then map each word to its id.
  token_ids = [vocab[START]]
  token_ids.extend(
      vocab[word] if word in vocab else vocab[UNKNOWN] for word in words)

  # Start from an all-<PAD> row and overwrite the leading entries.
  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)
  output[0, :len(token_ids)] = token_ids
  return output

def interpret_output(output):
  """Prints the winning label and its score for a two-element model output.

  Args:
    output: Indexable pair of scores; entry i corresponds to labels[i]. On a
      tie the first label wins.
  """
  winner = 0 if output[0] >= output[1] else 1
  label, confidence = labels[winner], output[winner]

  print("Label: " + label + "\nConfidence: " + str(confidence))

In [7]:
#@title Text samples
# Two sample reviews that the TFLite and IREE sections both classify.
positive_text = "This is the best movie I've seen in recent years. Strongly recommend it!"
negative_text = "What a waste of my time."

# Show the padded token-id row produced for the positive sample.
print(tokenize_input(positive_text))

[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0   

## TFLite

In [8]:
# Load the downloaded model into a TFLite interpreter and allocate its tensors.
interpreter = tflite.Interpreter(
      model_path=ARTIFACTS_DIR + "/text_classification.tflite", num_threads=None)
interpreter.allocate_tensors()
# Keep the input/output tensor descriptions; classify_text_tflite reads the
# 'index' field of the first entry of each to address the tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def classify_text_tflite(text):
  """Classifies `text` with the TFLite interpreter and prints the result.

  Args:
    text: The sentence(s) to classify, as a single string.
  """
  input_index = input_details[0]['index']
  interpreter.set_tensor(input_index, tokenize_input(text))
  interpreter.invoke()

  # The first row of the output tensor holds the scores for this one input.
  output_index = output_details[0]['index']
  scores = interpreter.get_tensor(output_index)[0]
  interpret_output(scores)

In [9]:
# Classify both sample reviews with TFLite; a blank line separates the
# results of consecutive samples.
print("Invoking text classification with TFLite\n")
for sample_number, sample_text in enumerate((positive_text, negative_text)):
  if sample_number > 0:
    print()
  print(sample_text)
  classify_text_tflite(sample_text)

Invoking text classification with TFLite

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997293

What a waste of my time.
Label: Negative
Confidence: 0.6275043


## IREE

Overview:

1.   Import the TFLite model to TOSA MLIR 
2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module
3.   Run the VM module through IREE's runtime to test the text classification function




In [10]:
# Convert TFLite model to TOSA MLIR with IREE's import tool.
IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')
!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir

with open(ARTIFACTS_DIR + "/text_classification.mlir") as mlir_file:
  tosa_mlir = mlir_file.read()

# Manually insert "iree.module.export" attribute until it is removed. 
# https://github.com/google/iree/issues/3968
index = 142
modified_tosa_mlir = tosa_mlir[:index] + ", iree.module.export" + tosa_mlir[index:]
print(modified_tosa_mlir)

module  {
  func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = "input_5", outputs = "Identity"}, iree.module.export} {
    %0 = "tosa.const"() {value = dense<"0x4FA3373DA6259D3D233E10BDFB5A523C2DF4B4BC669058BD5C2362BC53F91DBDDC920ABDE43D1ABB9785F1BBEC5CB53AA80D823C67D63ABD005096BC3BEE4DBDF7171D3DBCD58CBAB66C39BD4FE1443D8C75D1BD4A9469BC93D59A3D6D87F4BCA5D7353DB4D9D5BCD21F9F3A604D95BC9E3CB3BD958635BB3D4C69BD28B3D83CCD5E973D8ECE73BDF94AB7BCC5625ABC5124DDBD0AF2983D1AEC6D3D95CDBBBCA2FBCE3C7E9A82BDF477FBBCEBD50ABDB67E17BDF5E3D13B212751BD23144D3D252DB5BC6D04863DF7ABD93C666AE7BBC18BB3BDA32369BD396396BC2F458FBDA7BE273C43EC143DD29A8DBD0B7CCCBDED8620BD6460C3BD9B8E5BBD1131B83C695F963CABB2093DC8A5C0BC5FD067BDC874783C330C62BD8D98083D854D65BD6FE83C3C8E91A93DEC8E98BD8A93D6BD6C2D9D3C7FB56BBD0F553F3D287B33BCEE07433D6594623D100F913CDAF69FBCDBA937BDBDE3B23C789725BCC0C82B3A0727803C1CDC99BCC26EE9BC78369EBD2000BEBCEA79CEBCE7059BBCEDA02EBD61DA9E3D04158D3BB4BD74BD

In [11]:
# Compile the TOSA MLIR into a VM module.
# "vmvx" is the only target backend requested from the compiler.
compiled_flatbuffer = compile_str(modified_tosa_mlir, target_backends=["vmvx"])
vm_module = ireert.VmModule.from_flatbuffer(compiled_flatbuffer)

# Register the module with a runtime context.
# The runtime config names the same "vmvx" driver the module was compiled for.
config = ireert.Config("vmvx")
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
# Look up the module's "main" function (the `func @main` entry point of the
# imported model); classify_text_iree calls it like a regular Python function.
invoke_text_classification = ctx.modules.module["main"]

def classify_text_iree(text):
  """Classifies `text` with the compiled IREE module and prints the result.

  Args:
    text: The sentence(s) to classify, as a single string.
  """
  token_ids = tokenize_input(text)
  # The first row of the result holds the scores for this one input.
  scores = invoke_text_classification(token_ids)[0]
  interpret_output(scores)

Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7fbec83862f0>
SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7fbec83862f0>


In [12]:
# Classify both sample reviews with IREE; a blank line separates the results
# of consecutive samples.
print("Invoking text classification with IREE\n")
for sample_number, sample_text in enumerate((positive_text, negative_text)):
  if sample_number > 0:
    print()
  print(sample_text)
  classify_text_iree(sample_text)

Invoking text classification with IREE

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997293

What a waste of my time.
Label: Negative
Confidence: 0.62750435
