##### Copyright 2021 The IREE Authors

In [None]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# TFLite text classification sample with IREE

This notebook demonstrates how to download, compile, and run a TFLite model with IREE.  It looks at the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model, and shows how to run it with both TFLite and IREE.  The model predicts if a sentence's sentiment is positive or negative, and is trained on a database of IMDB movie reviews.


## Setup

In [1]:
%%capture
!python -m pip install --upgrade tf-nightly  # Needed for experimental_tflite_to_tosa_bytecode in TF>=2.14
!python -m pip install iree-compiler iree-runtime iree-tools-tflite -f https://iree.dev/pip-release-links.html
!python -m pip install tflite-runtime-nightly

In [2]:
from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode

In [3]:
import numpy as np
import urllib.request
import pathlib
import tempfile
import re
import tflite_runtime.interpreter as tflite

from iree import runtime as iree_rt
from iree.compiler import compile_file, compile_str
from iree.tools import tflite as iree_tflite

ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), "iree", "colab_artifacts")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Print version information for future notebook users to reference.
!iree-compile --version

IREE (https://iree.dev):
  IREE compiler version 20230831.630 @ 9ed3dab7ac4fcda959f5b8ebbcd7732aeb4b0c8d
  LLVM version 18.0.0git
  Optimized build


### Load the TFLite model

1.   Download files for the pretrained model
2.   Extract model metadata used for input pre-processing and output post-processing
3.   Define helper functions for pre- and post-processing

These steps will differ from model to model.  Consult the model source or reference documentation for details.


In [4]:
#@title Download pretrained text classification model
MODEL_URL = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite"
urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath("text_classification.tflite"))

(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),
 <http.client.HTTPMessage at 0x7d22f6441630>)

In [5]:
#@title Extract model vocab and label metadata
!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite

# Load the vocab file into a dictionary.  It contains the most common 1,000
# words in the English language, mapped to an integer.
vocab = {}
with open(ARTIFACTS_DIR.joinpath("vocab.txt")) as vocab_file:
  for line in vocab_file:
    (key, val) = line.split()
    vocab[key] = int(val)

# Text will be labeled as either 'Positive' or 'Negative'.
with open(ARTIFACTS_DIR.joinpath("labels.txt")) as label_file:
  labels = label_file.read().splitlines()

Archive:  /tmp/iree/colab_artifacts/text_classification.tflite
 extracting: /tmp/iree/colab_artifacts/labels.txt  
 extracting: /tmp/iree/colab_artifacts/vocab.txt  


In [6]:
#@title Input and output processing

# Input text will be encoded as an integer array of fixed length 256.  The
# input sentence will be mapped to integers from the vocab dictionary, and the
# empty array spaces are filled with padding.

SENTENCE_LEN = 256
START = "<START>"
PAD = "<PAD>"
UNKNOWN = "<UNKNOWN>"

def tokenize_input(text):
  output = np.empty([1, SENTENCE_LEN], dtype=np.int32)
  output.fill(vocab[PAD])

  # Remove capitalization and punctuation from the input text.
  text_split = text.split()
  text_split = [text.lower() for text in text_split]
  text_split = [re.sub(r"[^\w\s']", '', text) for text in text_split]

  # Prepend <START>.
  index = 0
  output[0][index] = vocab[START]
  index += 1

  for word in text_split:
    output[0][index] = vocab[word] if word in vocab else vocab[UNKNOWN]
    index += 1

  return output


def interpret_output(output):
  if output[0] >= output[1]:
    label = labels[0]
    confidence = output[0]
  else:
    label = labels[1]
    confidence = output[1]

  print("Label: " + label + "\nConfidence: " + str(confidence))

In [7]:
#@title Text samples
positive_text = "This is the best movie I've seen in recent years. Strongly recommend it!"
negative_text = "What a waste of my time."

print(positive_text)
print(tokenize_input(positive_text))

This is the best movie I've seen in recent years. Strongly recommend it!
[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0  

## Run using TFLite

Overview:

1.  Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
2.   Allocate tensors and get the input and output shape information
3.   Invoke the TFLite Interpreter to test the text classification function

In [8]:
interpreter = tflite.Interpreter(
      model_path=str(ARTIFACTS_DIR.joinpath("text_classification.tflite")))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def classify_text_tflite(text):
  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))
  interpreter.invoke()
  output_data = interpreter.get_tensor(output_details[0]['index'])
  interpret_output(output_data[0])

In [9]:
print("Invoking text classification with TFLite\n")
positive_text = "This is the best movie I've seen in recent years. Strongly recommend it!"
print(positive_text)
classify_text_tflite(positive_text)
print()
negative_text = "What a waste of my time."
print(negative_text)
classify_text_tflite(negative_text)

Invoking text classification with TFLite

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997294

What a waste of my time.
Label: Negative
Confidence: 0.6275043


## Run using IREE

Overview:

1.   Import the TFLite model to TOSA MLIR
2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module
3.   Run the VM module through IREE's runtime to test the text classification function

Both runtimes should generate the same output.


In [10]:
# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.
tosa_mlirbc_file = ARTIFACTS_DIR.joinpath("text_classification.mlirbc")
!iree-import-tflite {ARTIFACTS_DIR}/text_classification.tflite -o={tosa_mlirbc_file}

# The generated .mlirbc file could now be saved and used outside of Python, with
# IREE native tools or in apps, etc.

2023-08-31 21:32:50.137814: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9511] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-08-31 21:32:50.137865: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-08-31 21:32:50.137895: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [11]:
# Compile the TOSA MLIR into a VM module.
compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type="tosa", target_backends=["vmvx"])

# Register the module with a runtime context.
config = iree_rt.Config("local-task")
ctx = iree_rt.SystemContext(config=config)
vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, compiled_flatbuffer)
ctx.add_vm_module(vm_module)
invoke_text_classification = ctx.modules.module["main"]

def classify_text_iree(text):
  result = invoke_text_classification(tokenize_input(text)).to_host()[0]
  interpret_output(result)

In [12]:
print("Invoking text classification with IREE\n")
print(positive_text)
classify_text_iree(positive_text)
print()
print(negative_text)
classify_text_iree(negative_text)

Invoking text classification with IREE

This is the best movie I've seen in recent years. Strongly recommend it!
Label: Positive
Confidence: 0.8997294

What a waste of my time.
Label: Negative
Confidence: 0.6275043
