# TensorFlow Lite Sentiment Analysis Model Training

This notebook follows the TensorFlow Lite Model Maker text classification tutorial: https://www.tensorflow.org/lite/tutorials/model_maker_text_classification

Import the required packages.

In [1]:
import numpy as np
import os

from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')

## Get Data

Download the dataset. We train on SST-2, the Stanford Sentiment Treebank (see https://deepai.org/dataset/stanford-sentiment-treebank).

In [2]:
data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')

Load the dataset into a Pandas dataframe and replace the current label names (0 and 1) with more human-readable ones (negative and positive), which are then used for model training.

In [3]:
import pandas as pd

def replace_label(original_file, new_file):
    # Load the original file to pandas. We need to specify the separator as
    # '\t' as the training data is stored in TSV format
    df = pd.read_csv(original_file, sep='\t')

    # Define how we want to change the label name
    label_map = {0: 'negative', 1: 'positive'}

    # Execute the label change
    df.replace({'label': label_map}, inplace=True)

    # Write the updated dataset to a new file
    df.to_csv(new_file)

# Replace the label name for both the training and test dataset. Then write the
# updated CSV dataset to the current folder.
replace_label(os.path.join(data_dir, 'train.tsv'), 'train.csv')
replace_label(os.path.join(data_dir, 'dev.tsv'), 'dev.csv')
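
To make the relabeling step concrete, the sketch below applies the same 0/1 to negative/positive mapping to a tiny in-memory TSV snippet using only the standard library. The two sample sentences and the `relabel_rows` helper are hypothetical stand-ins for the real SST-2 files; the notebook itself does this with pandas.

```python
import csv
import io

# Same mapping as in the notebook cell; keys are strings because csv yields text.
LABEL_MAP = {"0": "negative", "1": "positive"}

def relabel_rows(tsv_text):
    """Parse TSV text with 'sentence' and 'label' columns and map the labels."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    return [
        {"sentence": row["sentence"], "label": LABEL_MAP[row["label"]]}
        for row in reader
    ]

# Hypothetical two-row sample laid out like train.tsv (header, then tab-separated rows).
sample_tsv = "sentence\tlabel\na charming film\t1\ndull and lifeless\t0\n"
rows = relabel_rows(sample_tsv)
print(rows)
```

Each row keeps its sentence unchanged; only the label column is rewritten.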

## Model Training

**Step 1. Choose a text classification model architecture.**
* Here we use the average word embedding model architecture, which will produce a small and fast model with decent accuracy. Other options include BERT.
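
As a rough intuition for what "average word embedding" means, the toy sketch below looks up a vector for each known word and averages them into a single sentence vector. The embedding values, the `average_word_vector` helper, and the two-dimensional vectors are all made up for illustration; in the real model the embeddings are learned during training, and this is not Model Maker's actual implementation.

```python
# Toy embedding table: each known word maps to a small vector (values are made up).
EMBEDDINGS = {
    "good": [1.0, 0.0],
    "bad": [-1.0, 0.0],
    "movie": [0.0, 0.5],
}

def average_word_vector(sentence, embeddings, dim=2):
    """Average the vectors of the known words in a sentence."""
    vectors = [embeddings[w] for w in sentence.lower().split() if w in embeddings]
    if not vectors:
        return [0.0] * dim  # no known words: fall back to a zero vector
    # zip(*vectors) groups the values by dimension so each can be averaged.
    return [sum(component) / len(vectors) for component in zip(*vectors)]

sentence_vector = average_word_vector("Good movie", EMBEDDINGS)
print(sentence_vector)
```

Because the sentence is reduced to one fixed-size vector regardless of its length, a classifier built on top of it can stay small, which fits the "small and fast" trade-off described above.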

In [4]:
spec = model_spec.get('average_word_vec')

**Step 2. Load the training and test data, then preprocess them according to a specific model_spec.**
* We will load the training and test datasets with the human-readable label names that were created earlier.
* DataLoader reads the requirement from model_spec and automatically executes the necessary preprocessing.
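
The kind of preprocessing a word-embedding model typically needs can be sketched as follows: build a vocabulary, turn each sentence into a list of integer ids, and pad or truncate to a fixed length. This is a simplified illustration under assumptions, not what `DataLoader` does internally; the reserved ids, the whitespace tokenizer, and the helpers `build_vocab` and `encode` are hypothetical.

```python
PAD_ID, UNKNOWN_ID = 0, 1  # hypothetical reserved ids for padding and unseen words

def build_vocab(sentences):
    """Assign an integer id to every distinct lowercase token."""
    vocab = {}
    for sentence in sentences:
        for token in sentence.lower().split():
            if token not in vocab:
                vocab[token] = len(vocab) + 2  # ids 0 and 1 are reserved
    return vocab

def encode(sentence, vocab, seq_len):
    """Map tokens to ids, then truncate or pad to exactly seq_len entries."""
    ids = [vocab.get(token, UNKNOWN_ID) for token in sentence.lower().split()]
    ids = ids[:seq_len]
    return ids + [PAD_ID] * (seq_len - len(ids))

vocab = build_vocab(["a charming film", "a dull film"])
encoded = encode("a charming new film", vocab, seq_len=6)
print(encoded)
```

Here the word "new" was never seen when building the vocabulary, so it maps to the unknown id, and the two trailing zeros are padding.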

In [5]:
train_data = DataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=True)

test_data = DataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=False)

**Step 3. Train the TensorFlow model with the training data.**
* The average word embedding model uses batch_size = 32 by default.
* It takes 2104 steps to go through the 67,349 sentences in the training dataset.
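
The step count above can be checked with a line of arithmetic. The figure of 2104 matches counting only full batches, so this sketch assumes the final partial batch is not counted as a step.

```python
num_train_examples = 67349  # sentences in the SST-2 training set (from the text above)
batch_size = 32             # default batch size stated above

# Floor division: only full batches are counted as steps (an assumption).
steps_per_epoch = num_train_examples // batch_size
leftover = num_train_examples % batch_size
print(steps_per_epoch, leftover)
```

With these numbers, 21 sentences are left over after the last full batch in each epoch.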

In [6]:
model = text_classifier.create(train_data, model_spec=spec, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


**Step 4. Evaluate the model with the test data.**
* With the default batch size of 32, it takes 28 steps to go through the 872 sentences in the test dataset.
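
For evaluation every sentence has to be scored, so the 28 steps correspond to rounding up rather than down. A quick check, assuming the last batch is simply smaller than the others:

```python
import math

num_test_examples = 872  # sentences in the SST-2 dev set (from the text above)
batch_size = 32

# Round up so the final, partially filled batch is included.
eval_steps = math.ceil(num_test_examples / batch_size)
last_batch_size = num_test_examples - (eval_steps - 1) * batch_size
print(eval_steps, last_batch_size)
```

The first 27 batches hold 32 sentences each, and the final batch holds the remaining 8.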

In [15]:
loss, acc = model.evaluate(test_data)



In [16]:
print('Loss = {} \nAccuracy= {}'.format(loss, acc))

Loss = 0.5162637233734131 
Accuracy= 0.8337156176567078
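
To put the accuracy in concrete terms, it can be converted back into a count of correctly classified sentences. This assumes the reported accuracy is the fraction of the 872 test sentences that were predicted correctly.

```python
accuracy = 0.8337156176567078  # value printed above
num_test_examples = 872        # size of the test set

# Recover the whole-number count; rounding absorbs floating-point noise.
correct = round(accuracy * num_test_examples)
errors = num_test_examples - correct
print(correct, errors)
```

Under that assumption, the model labels 727 of the 872 sentences correctly and misclassifies 145.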


**Step 5. Export as a TensorFlow Lite model.**

In [17]:
model.export(export_dir='average_word_vec')

* This model can be integrated into an Android app using the NLClassifier API of the TensorFlow Lite Task Library.

* See the TFLite Text Classification sample app for more details on how the model is used in a working app.

* Note 1: Android Studio Model Binding does not support text classification yet, so please use the TensorFlow Lite Task Library instead.

* Note 2: There is a model.json file in the same folder as the TFLite model. It contains the JSON representation of the metadata bundled inside the TensorFlow Lite model. Model metadata helps the TFLite Task Library know what the model does and how to pre-process and post-process data for the model. You don't need to download the model.json file, as it is only for informational purposes and its content is already inside the TFLite file.

* Note 3: If you train a text classification model using the MobileBERT or BERT-Base architecture, you will need to use the BertNLClassifier API instead to integrate the trained model into a mobile app.