# Multiclass text classification using BERT models from TF Hub

This notebook demonstrates fine-tuning BERT models from [TF Hub](https://tfhub.dev) on multiclass text classification datasets.

The notebook performs the following steps:
1. [Import dependencies and setup parameters](#1.-Import-dependencies-and-setup-parameters)
2. [Prepare the dataset](#2.-Prepare-the-dataset)
3. [Build the model](#3.-Build-the-model)
4. [Fine tuning and evaluation](#4.-Fine-tuning-and-evaluation)
5. [Export the model](#5.-Export-the-model)
6. [Reload the model and make predictions](#6.-Reload-the-model-and-make-predictions)

## 1. Import dependencies and setup parameters

This notebook assumes that you have already followed the instructions in the [README.md](/notebooks/README.md) to set up a TensorFlow environment with all the dependencies required to run the notebook.

In [None]:
import os
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

from bert_utils import get_model_map
from tlt.utils.file_utils import download_and_extract_zip_file

In [None]:
# Note that tensorflow_text isn't used directly but the import is required to register ops used by the
# BERT text preprocessor
! pip3 install tensorflow-text==2.12.0 --no-deps
import tensorflow_text

This notebook will run one of the supported [BERT models from TF Hub](https://tfhub.dev/google/collections/bert/1). The table below has a list of the available models and links to their URLs in TF Hub.

In [None]:
# Load the TF Hub model map from json and print a list of the supported models
tfhub_model_map, models_df = get_model_map("tfhub_bert_model_map_classifier.json", return_data_frame=True)
models_df.style.hide(axis="index")

Specify the name of the BERT model to use. This string must match one of the models listed in the table above.

In [None]:
model_name = "small_bert/bert_en_uncased_L-2_H-128_A-2"
if model_name not in tfhub_model_map.keys():
    raise ValueError("The specified model name ({}) is not supported".format(model_name))

In [None]:
# Define a directory to download the dataset
dataset_directory = os.environ["DATASET_DIR"] if "DATASET_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "dataset")

# Define an output directory for the saved model to be exported
output_directory = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "output")

# Create the output directory if it does not already exist
if not os.path.isdir(output_directory):
    os.makedirs(output_directory)
    
tfhub_preprocess = tfhub_model_map[model_name]["preprocess"]
tfhub_bert_encoder = tfhub_model_map[model_name]["bert_encoder"]

print("Using TF Hub model:", model_name)
print("BERT encoder URL:", tfhub_bert_encoder)
print("Preprocessor URL:", tfhub_preprocess)
print("Dataset directory:", dataset_directory)
print("Output directory:", output_directory)

## 2. Prepare the dataset

This notebook gets the dataset from a text file or from the [TensorFlow Datasets catalog](https://www.tensorflow.org/datasets/catalog/overview).

The code defines a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) object for each split (train, validation, and test) and a map for translating numerical labels to string labels.

Execute the following cell to set the batch size and declare the base class used for the dataset setup.

In [None]:
# Define the dataset batch size
batch_size = 32

# Base class used for defining the multi text classification dataset being used
class MultiTextClassificationData():
    def __init__(self, batch_size, label_map):
        self.batch_size = batch_size
        self.label_map = label_map
        self.reverse_label_map = {}
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None
        self.dataset_name = ""
        
        for k, v in self.label_map.items():
            self.reverse_label_map[v] = k
        
    def get_str_label(self, numerical_value):
        if not isinstance(numerical_value, int):
            numerical_value = int(tf.math.round(numerical_value))
        
        if numerical_value in self.label_map.keys():
            return self.label_map[numerical_value]
        else:
            raise ValueError("The key {} was not found in the label map".format(numerical_value))
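
The label lookup in the base class can be tried out without TensorFlow. The following is a minimal plain-Python sketch, assuming a hypothetical four-class label map; the standalone `get_str_label` function mirrors the method above but uses the built-in `round` instead of `tf.math.round`.

```python
# Hypothetical label map, keyed by numerical class id as in the base class
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# Invert the map so that string labels can be translated back to class ids
reverse_label_map = {label: class_id for class_id, label in label_map.items()}


def get_str_label(numerical_value):
    """Return the string label for a class id, rounding non-integer input."""
    if not isinstance(numerical_value, int):
        numerical_value = int(round(numerical_value))
    if numerical_value not in label_map:
        raise ValueError("The key {} was not found in the label map".format(numerical_value))
    return label_map[numerical_value]


print(get_str_label(2))
print(get_str_label(0.9))
print(reverse_label_map["Sports"])
```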

### Option A: Use a TensorFlow dataset

[TensorFlow Datasets](https://www.tensorflow.org/datasets) has a [catalog of datasets](https://www.tensorflow.org/datasets/catalog/overview) that can be specified by name. Information about the dataset is available in the catalog (including information on the size of the dataset and the splits).

The next cell demonstrates using the [`ag_news_subset`](https://www.tensorflow.org/datasets/catalog/ag_news_subset) dataset from the TensorFlow datasets catalog to get splits for training, validation, and test.

In [None]:
class TFDSMultiTextClassificationData(MultiTextClassificationData):
    def __init__(self, dataset_dir, tfds_name, train_split, val_split, test_split, label_map, batch_size):
        """
        Initialize the TFDSMultiTextClassificationData class for a multiclass text classification dataset
        from the TensorFlow Datasets catalog.
        
        :param dataset_dir: Path to a dataset directory to read/write data
        :param tfds_name: String name of the TensorFlow dataset to load
        :param train_split: String specifying which split to load for training (e.g. "train[:80%]"). See the
                            https://www.tensorflow.org/datasets/splits documentation for more information on
                            defining splits.
        :param val_split: String specifying the split to load for validation.
        :param test_split: String specifying the split to load for test.
        :param label_map: Dictionary where the key is a numerical value and the value is the string label
        :param batch_size: Batch size
        """
        # Init base class
        MultiTextClassificationData.__init__(self, batch_size, label_map) 
        
        [self.train_ds, self.val_ds, self.test_ds], info = tfds.load(tfds_name,
                     data_dir=dataset_dir,
                     split=[train_split, val_split, test_split],
                     batch_size=batch_size,
                     as_supervised=True,
                     shuffle_files=True,
                     with_info=True)
        self.dataset_name = tfds_name
        print(info)


# Name of the TFDS to use
tfds_name="ag_news_subset"

# Location where the dataset will be downloaded
dataset_directory = os.path.join(dataset_directory, tfds_name)
if not os.path.isdir(dataset_directory):
    os.makedirs(dataset_directory)

# Label map for the AG News topic categories
label_map = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech"
}
    
# Initialize the dataset splits using a dataset from the TensorFlow datasets catalog
dataset = TFDSMultiTextClassificationData(dataset_dir=dataset_directory,
                                           tfds_name=tfds_name,
                                           train_split="train[:50%]",
                                           val_split="train[50%:70%]",  # Does not overlap the training slice
                                           test_split="test[:20%]",
                                           label_map=label_map,
                                           batch_size=batch_size)
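
When the validation data is carved out of the `train` split, the validation slice should start where the training slice ends so that the two do not share examples. The sketch below builds such split strings with a hypothetical helper, `make_split_strings` (not part of TensorFlow Datasets); it only formats strings in the slicing syntax described in the TFDS splits documentation.

```python
def make_split_strings(train_pct, val_pct, test_pct):
    """Build non-overlapping split strings (hypothetical helper, not a TFDS API)."""
    if train_pct + val_pct > 100:
        raise ValueError("train_pct + val_pct cannot be greater than 100")
    if test_pct > 100:
        raise ValueError("test_pct cannot be greater than 100")

    # Training takes the first train_pct percent of the train split
    train_split = "train[:{}%]".format(train_pct)
    # Validation starts where training ends, so the slices never overlap
    val_split = "train[{}%:{}%]".format(train_pct, train_pct + val_pct)
    # The test data comes from the separate test split
    test_split = "test[:{}%]".format(test_pct)
    return train_split, val_split, test_split


print(make_split_strings(50, 20, 20))
```

The returned strings could then be passed as the `train_split`, `val_split`, and `test_split` arguments of the class above.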

Skip to the next step [3. Build the model](#3.-Build-the-model) to continue using the TF dataset.

### Option B: Use your own dataset
Instead of using a dataset from TensorFlow datasets, another dataset from your local system or a download can be used.

In this example, we download the Conference Title dataset. This is a single comma-separated value (csv) file with a header row and two columns. The first column is the conference title and the second column is the label (VLDB, ISCAS, SIGGRAPH, INFOCOM, WWW):

```
<conference title>,<label>
<conference title>,<label>
<conference title>,<label>
...
```

If you are using a custom dataset that has a similarly formatted csv or tsv file, you can still use the class defined below. Just create your object passing in custom values for delimiter, header (whether the file has a header row), the label map, mapping function, etc.
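
To make the expected file layout concrete, here is a small sketch that parses a two-column csv and translates string labels to numerical values using only the standard library. The in-memory sample, its header names, and the `load_rows` helper are hypothetical; the notebook itself reads the file with `tf.data.experimental.CsvDataset`.

```python
import csv
import io

# String label to numerical value, as used by the map function in this notebook
label_map = {"VLDB": 0, "ISCAS": 1, "SIGGRAPH": 2, "INFOCOM": 3, "WWW": 4}

# Hypothetical in-memory file standing in for a csv file on disk
sample_csv = io.StringIO(
    "Title,Conference\n"
    "\"A title, with a comma\",VLDB\n"
    "Another title,WWW\n"
)


def load_rows(csv_file, delimiter=",", header=True):
    """Return (text, numerical_label) tuples from a two-column csv file."""
    reader = csv.reader(csv_file, delimiter=delimiter)
    if header:
        next(reader)  # Skip the header row
    return [(text, label_map[label]) for text, label in reader]


rows = load_rows(sample_csv)
print(rows)
```

Quoted fields let a title contain the delimiter itself, which is why the dataset class enables `use_quote_delim`.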

In [None]:
class CustomCsvMultiTextClassificationData(MultiTextClassificationData):
    def __init__(self, csv_file, delimiter, header, train_percent, val_percent,
                 test_percent, label_map, batch_size, dataset_name, map_function=None):
        """
        Initialize the CustomCsvMultiTextClassificationData class for a multiclass text
        classification dataset that uses a single csv file.
        
        :param csv_file: Path to the csv file
        :param delimiter: String character that separates the fields in each row
        :param header: Boolean indicating whether or not the csv file has a header line that should be skipped
        :param train_percent: Decimal value for the percentage of the dataset that should be used for training
                              (e.g. 0.8 for 80%)
        :param val_percent: Decimal value for the percentage of the dataset that should be used for validation
                            (e.g. 0.1 for 10%)
        :param test_percent: Decimal value for the percentage of the dataset that should be used for test
                             (e.g. 0.1 for 10%)
        :param label_map: Dictionary where the key is a numerical value and the value is the string label
        :param batch_size: Batch size
        :param dataset_name: Name of the dataset. This is used later in this notebook for naming the saved model
                             export folder and determining which input strings to use when testing the reloaded model
        :param map_function: (Optional) If the csv file has string labels instead of the numerical values, provide a
                             map function to apply on the dataset
        """
        # Init base class
        MultiTextClassificationData.__init__(self, batch_size, label_map)
        
        self.dataset_name = dataset_name
        
        if (train_percent + val_percent + test_percent) > 1:
            raise ValueError("The combined value of the train percentage, validation percentage, and " \
                             "test percentage cannot be greater than 1")
        
        if not os.path.exists(csv_file):
            raise FileNotFoundError("Unable to find the csv file at {}".format(csv_file))

        custom_dataset = tf.data.experimental.CsvDataset(filenames=csv_file,
                                                         record_defaults=[tf.string, tf.string],
                                                         field_delim=delimiter,
                                                         use_quote_delim=True,
                                                         header=header)

        # Count the number of lines in the csv file to get the dataset length
        with open(csv_file) as f:
            custom_dataset_len = sum(1 for _ in f)

        if header:
            custom_dataset_len -= 1

        # Optionally map the dataset labels using the map_function
        if map_function:
            custom_dataset = custom_dataset.map(map_function)
        
        # Create batches based on the specified batch size
        custom_dataset = custom_dataset.batch(batch_size)


        # Calculate sizes for the splits
        total_num_batches = int(custom_dataset_len / batch_size)
        train_size = int(train_percent * total_num_batches)
        val_size = int(val_percent * total_num_batches)
        test_size = int(test_percent * total_num_batches)

        # Create the train, validation, and test splits
        self.train_ds = custom_dataset.take(train_size)    
        self.val_ds = custom_dataset.skip(train_size).take(val_size)
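        # Limit the test split to test_size batches so that it matches the cardinality asserted below
        self.test_ds = custom_dataset.skip(train_size + val_size).take(test_size)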

        # Set the cardinality so that progress bars will work properly
        self.train_ds = self.train_ds.apply(tf.data.experimental.assert_cardinality(train_size))
        self.val_ds = self.val_ds.apply(tf.data.experimental.assert_cardinality(val_size))
        self.test_ds = self.test_ds.apply(tf.data.experimental.assert_cardinality(test_size))

# Modify the variables below to use a different dataset or a csv file on your local system.
# The csv_path variable should point to a csv file with 2 columns (the text and the label)
dataset_url = "https://raw.githubusercontent.com/susanli2016/NLP-with-Python/master/data/title_conference.csv"
dataset_directory = os.path.join(dataset_directory, "titleconference")
csv_name = "title_conference.csv"
delimiter = ","
header = True  # Set to True if the csv file has a header row
csv_path = os.path.join(dataset_directory, csv_name)

if not os.path.exists(dataset_directory):
    os.makedirs(dataset_directory)

# If we don't already have the csv file, download it and save a local copy.
if not os.path.exists(csv_path):
    df = pd.read_csv(dataset_url, header=0)
    df.to_csv(csv_path, index=False)

label_map = {
    "VLDB": 0,
    "ISCAS": 1,
    "SIGGRAPH": 2,
    "INFOCOM": 3, 
    "WWW": 4
}

int_to_label_map ={}
for k, v in label_map.items():
    int_to_label_map[v] = k

# Map function to translate labels in the csv file to numerical values when loading the dataset
def map_title(features, label):
    label = tf.py_function(lambda x: label_map[x.numpy().decode('utf-8')], [label], tf.int64)
    return features, label

# Initialize the dataset splits using the custom dataset
dataset = CustomCsvMultiTextClassificationData(csv_file=csv_path,
                                                delimiter=delimiter,
                                                header=header,
                                                train_percent=0.8,
                                                val_percent=0.1,
                                                test_percent=0.1,
                                                batch_size=batch_size,
                                                label_map=int_to_label_map,
                                                dataset_name=csv_name,
                                                map_function=map_title)
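
The split sizes in the class above are counted in batches, and each calculation truncates to a whole number. The sketch below repeats that arithmetic in plain Python for a hypothetical file with 1000 rows, and also reports how many batches are not covered by the three computed sizes. The function name `split_sizes` is an assumption made for illustration.

```python
import math


def split_sizes(num_rows, batch_size, train_percent, val_percent, test_percent):
    """Repeat the batch-count arithmetic used for the custom csv dataset."""
    total_num_batches = int(num_rows / batch_size)  # Counts full batches only
    train_size = int(train_percent * total_num_batches)
    val_size = int(val_percent * total_num_batches)
    test_size = int(test_percent * total_num_batches)
    return train_size, val_size, test_size


# Hypothetical dataset length; the batch size matches this notebook
num_rows, batch_size = 1000, 32
train_size, val_size, test_size = split_sizes(num_rows, batch_size, 0.8, 0.1, 0.1)

# Batching without dropping the remainder yields one extra partial batch
actual_batches = math.ceil(num_rows / batch_size)
leftover = actual_batches - (train_size + val_size + test_size)

print(train_size, val_size, test_size, leftover)
```

Because of the truncation, the three sizes can add up to fewer batches than the dataset yields, so a small number of rows may not be assigned to any split.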

## 3. Build the model

Create the BERT model to fine-tune using an input layer, the preprocessing layer (from TF Hub), the BERT encoder layer (from TF Hub), a fully connected dense layer, a dropout layer, and a softmax classification layer with one output per label.

In [None]:
input_layer = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input_layer')
preprocessing_layer = hub.KerasLayer(tfhub_preprocess, name='preprocessing')
encoder_inputs = preprocessing_layer(input_layer)
encoder_layer = hub.KerasLayer(tfhub_bert_encoder, trainable=True, name='encoder')
outputs = encoder_layer(encoder_inputs)
net = outputs['pooled_output']
net = tf.keras.layers.Dense(16, activation='relu', name='fully_connected_layer')(net)
# Add dropout layer for regularization
net = tf.keras.layers.Dropout(0.2)(net)
net = tf.keras.layers.Dense(len(label_map), activation='softmax', name='classifier')(net)
classifier_model = tf.keras.Model(input_layer, net)

## 4. Fine tuning and evaluation

Train the model for the specified number of epochs, then evaluate the model using the test dataset.

> Note that there is a known error during custom dataset training: `train_function (Empty logs). Please use Model.compile(..., run_eagerly=True), or tf.config.run_functions_eagerly(True) for more information of where went wrong, or file a issue/bug to tf.keras.`
> If you see this error, try training with the first dataset (Option A) at least partially (it does not have to finish). Then re-run with the custom dataset and training should work.

In [None]:
%%time

# The number of training epochs to run
num_train_epochs = 1

# Learning rate
learning_rate = 3e-5

# Maximum total input sequence length after WordPiece tokenization (longer sequences will be truncated).
# Note: this value is not passed to the preprocessing layer in this notebook, so the preprocessor's own
# default sequence length applies.
max_seq_length = 128

classifier_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'], run_eagerly=True)

history = classifier_model.fit(dataset.train_ds,
                               validation_data=dataset.val_ds,
                               epochs=num_train_epochs)

Evaluate the accuracy using the test dataset. If the accuracy does not meet your expectations, try increasing the size of the training dataset split or the number of training epochs.

In [None]:
loss, accuracy = classifier_model.evaluate(dataset.test_ds)

print(f'Loss: {loss}')
print(f'Accuracy: {accuracy}')

Predict using a single batch from the test dataset, and then display the results along with the input text and the actual label.

In [None]:
num_steps = 1
# The dataset is already batched, so batch_size is not passed to predict()
predictions = classifier_model.predict(dataset.test_ds, steps=num_steps)

prediction_list = []
row = 0  # Index into the predictions array, which spans all predicted batches

for batch in dataset.test_ds.take(num_steps):
    text_list = list(batch[0].numpy())
    label_list = list(batch[1].numpy())

    for text, actual_label in zip(text_list, label_list):
        # The classifier layer already applies softmax, so the outputs are class probabilities
        score = tf.reduce_max(predictions[row])
        prediction = tf.math.argmax(predictions[row]).numpy()
        prediction = dataset.get_str_label(prediction)
        prediction_list.append([text.decode('utf-8'),
                                tf.get_static_value(score),
                                prediction,
                                dataset.get_str_label(actual_label)])
        row += 1
    
result_df = pd.DataFrame(prediction_list, columns=["Input Text", "Score", "Predicted Label", "Actual Label"])
result_df.style.hide(axis="index")
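
The score and predicted label shown in the table come down to picking the largest class probability and its index. The following is a minimal plain-Python sketch with a hypothetical probability vector and label map. It also shows that applying softmax to values that are already probabilities keeps the same winning class but lowers the reported score.

```python
import math

# Hypothetical label map and model output for a single input
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
probabilities = [0.1, 0.1, 0.7, 0.1]


def top_prediction(values):
    """Return the label and score of the highest-scoring class."""
    index = max(range(len(values)), key=lambda i: values[i])
    return label_map[index], values[index]


def softmax(values):
    """Plain-Python softmax over a list of floats."""
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


label, score = top_prediction(probabilities)
relabel, rescore = top_prediction(softmax(probabilities))

print(label, score)
print(relabel, round(rescore, 3))
```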



## 5. Export the model

Since training has completed, export the `saved_model.pb` to the output directory in a folder with the model and dataset name.

In [None]:
model_dir = "{}_{}".format(model_name, dataset.dataset_name)
model_dir = os.path.join(output_directory, model_dir)
classifier_model.save(model_dir, include_optimizer=False)

saved_model_path = os.path.join(model_dir, "saved_model.pb")
if os.path.exists(saved_model_path):
    print("Saved model location:", saved_model_path)
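
The export folder name is built from the model name and the dataset name. Because the default model name contains a `/`, the export ends up in a nested directory. A small sketch of the path construction, using a hypothetical output directory and a hypothetical helper named `build_model_dir`:

```python
import os


def build_model_dir(output_directory, model_name, dataset_name):
    """Join the output directory with a '<model name>_<dataset name>' folder."""
    return os.path.join(output_directory, "{}_{}".format(model_name, dataset_name))


# Hypothetical output directory; the model and dataset names match this notebook's defaults
model_dir = build_model_dir("/tmp/output",
                            "small_bert/bert_en_uncased_L-2_H-128_A-2",
                            "ag_news_subset")

print(model_dir)
print(os.path.join(model_dir, "saved_model.pb"))
```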

## 6. Reload the model and make predictions

Reload from the `saved_model.pb` in the output directory.

In [None]:
reloaded_model = tf.saved_model.load(model_dir)

The next section defines a list of strings to send as input to the reloaded model. If you are using a dataset other than the [AG News dataset](https://www.tensorflow.org/datasets/catalog/ag_news_subset), you can update the snippet below with your own list of input text.

In [None]:

input_text = ["WASHINGTON - Employers stepped up hiring in August, expanding payrolls by 144,000 and lowering the unemployment rate to 5.4 percent.",
              "PRESENTACION, Philippines (Reuters) - Philippine communist rebels freed Wednesday two soldiers they had held as 'prisoners of war' for "
              "more than five months, saying they wanted to rebuild confidence in peace talks with the government.",
              "Geneva - Worldwide sales of industrial robots surged to record levels in the first half of 2004 after equipment prices fell while labour "
              "costs grew, the United Nations Economic Commission for Europe said in a report to be released today."]
    
if not input_text:
    raise ValueError("Please define the list of input_text strings.")

# Send the input text to the reloaded model. The classifier layer already applies softmax,
# so the outputs are class probabilities.
predict_results = reloaded_model(tf.constant(input_text))

# Get the results into a data frame to display
result_list = [[input_text[i],
                tf.get_static_value(tf.reduce_max(predict_results[i])),
                dataset.get_str_label(tf.math.argmax(predict_results[i]))] for i in range(len(input_text))]
result_df = pd.DataFrame(result_list, columns=["Input Text", "Score", "Predicted Label"])
result_df.style.hide(axis="index")

## Citations

```
@misc{zhang2015characterlevel,
    title={Character-level Convolutional Networks for Text Classification},
    author={Xiang Zhang and Junbo Zhao and Yann LeCun},
    year={2015},
    eprint={1509.01626},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```