# Training and Fine-Tuning BERT for Classification
## Classfying Goodreads Reviews By Book Genre

By Maria Antoniak, Melanie Walsh, and the [AI for Humanists](https://aiforhumanists.com/) Team

Updated: 2024-11-05
<br></br>

This notebook will demonstrate how users can train and fine-tune a BERT model for classification with the popular HuggingFace `transformers` Python library.

We will fine-tune a BERT model on Goodreads reviews from the [UCSD Book Graph](https://mengtingwan.github.io/data/goodreads.html) with the goal of predicting the genre of the book being reviewed. The genres include:
- poetry
- comics & graphic
- fantasy & paranormal
- history & biography
- mystery, thriller, & crime
- romance
- young adult  

**Basic steps involved in using BERT and HuggingFace:**
1. Divide your data into training and test sets.
2. Encode your data into a format BERT will understand.
3. Combine your data and labels into datset objects.
4. Load the pre-trained BERT model.
5. Fine-tune the model using your training data.
6. Predict new labels and evaluate performance on your test data.



<br><br>

## **Import necessary Python libraries and modules**

First, we will import necessary Python libraries and modules. These include as `gdown`, for downloading large files from Google Drive (where we will get our UCSD Goodreads reviews), as well as scikit-learn (`sklearn`) and PyTorch (`torch`), for various machine learning tools.

In [None]:
# Basic Python modules
from collections import defaultdict
import random
import pickle

# For working with JSON files
import json

# For data manipulation and analysis
import pandas as pd
import numpy as np

# For machine learning tools and evaluation
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# For deep learning
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
import torch

# For plotting and data visualization
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker
sns.set(style='ticks', font_scale=1.2)

The HuggingFace [`transformers` Python library](https://huggingface.co/transformers/installation.html) is included in Colab by default now, so we do not need to install it (but this is how you would install it with `pip`).

In [None]:
#!pip3 install transformers

From `transformers`, we will import modules for `DistilBert`, a *distilled* or smaller version of a BERT model that runs more quickly and uses less computing power. This makes it ideal for those just getting started with BERT.

In [None]:
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments

<br><br>

## **Set parameters and file paths**

In [4]:
# This is the name of the BERT model that we want to use.
# We're using DistilBERT to save space (it's a distilled version of the full BERT model),
# and we're going to use the cased (vs uncased) version.
model_name = 'distilbert-base-cased'

# This is the name of the program management system for NVIDIA GPUs. We're going to send our code here.
device_name = 'cuda'

# This is the maximum number of tokens in any document sent to BERT.
max_length = 512

# This is the name of the directory where we'll save our model. You can name it whatever you want.
pt_model_path = './models/pt-model'
tflite_model_path = './models/tflite-model/model.tflite'
tf_model_path = "./models/tf-model"
LABELS = [
    'romance', 'fantasy_paranormal', 'poetry', 'children',
    'young_adult', 'mystery_thriller_crime', 'comics_graphic', 'history_biography'
]
# ==== Example ====
tests = [
    "A romantic story between two star-crossed lovers.",       # romance
    "A tale of vampires and magical adventures.",             # fantasy_paranormal
    "Roses are red, violets are blue, this poem is for you.",# poetry
    "Fun stories and illustrations for children.",           # children
    "A coming-of-age story about a young adult finding themselves.", # young_adult
    "Detectives chase a cunning criminal through the city.", # mystery_thriller_crime
    "A comic book adventure with superheroes and villains.", # comics_graphic
    "A detailed biography of Abraham Lincoln and his presidency.", # history_biography
    "Two lovers navigate a complicated relationship in Paris.",     # romance
    "A wizard battles dark forces in a haunted castle.",            # fantasy_paranormal
    "Ode to the stars shining bright in the night sky.",            # poetry
    "Colorful tales of talking animals and magical forests.",       # children
    "A young adult embarks on a journey of self-discovery.",        # young_adult
    "A thrilling chase as the detective hunts a serial thief.",     # mystery_thriller_crime
    "Superheroes team up to save the city from a giant robot.",    # comics_graphic
    "An in-depth account of Marie Curie's life and discoveries."    # history_biography
]

In [None]:
# Alias 
import numpy as np
import tensorflow as tf
from transformers import DistilBertTokenizerFast

# Load tokenizer to use throughout the code
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)


<br><br>

## **Load and sample Goodreads data**

In this cell, we create a Python dictionary with each genre and the link to the corresponding UCSD Goodreads review data for that genre.

*If you manually click on any of the URLs, you will be able to download the data for that genre. For example, here's the link for poetry: https://datarepo.eng.ucsd.edu/mcauley_group/gdrive/goodreads/byGenre/goodreads_reviews_poetry.json.gz*

In [None]:
# This is where our target data is hosted on the web. You only need these paths for the book review dataset.

# Source: https://mengtingwan.github.io/data/goodreads.html#datasets

genre_url_dict = {'poetry':                 'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_poetry.json.gz',
                  'children':               'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_children.json.gz',
                  'comics_graphic':         'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_comics_graphic.json.gz',
                  'fantasy_paranormal':     'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_fantasy_paranormal.json.gz',
                  'history_biography':      'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_history_biography.json.gz',
                  'mystery_thriller_crime': 'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_mystery_thriller_crime.json.gz',
                  'romance':                'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_romance.json.gz',
                  'young_adult':            'https://mcauleylab.ucsd.edu/public_datasets/gdrive/goodreads/byGenre/goodreads_reviews_young_adult.json.gz'}

In [None]:
print(len(genre_url_dict))

Next we loop through this dictionary and use `gdown` to download the Goodreads review data for each genre from Google Drive.

Now we will load the first 100,000 reviews from each link and randomly sample 2,000 reviews.

In [None]:
import requests
# Stream reviews from URL and collect a subset
def load_reviews(url, head=10000, sample_size=2000):
    reviews = []
    count = 0

    response = requests.get(url, stream=True)
    with gzip.open(response.raw, 'rt', encoding='utf-8') as file:
        for line in file:
            d = json.loads(line)
            reviews.append(d['review_text'])
            count += 1

            # Stop if we have reached the 100,000 limit
            if head is not None and count >= head:
                break

    # Return random sample of reviews
    return random.sample(reviews, min(sample_size, len(reviews)))

# Reviews by genre
genre_reviews_dict = {}

# Load reviews for each genre
for genre, url in genre_url_dict.items():
    print(f'Loading reviews for genre: {genre}')
    genre_reviews_dict[genre] = load_reviews(url, head=10000, sample_size=2000)


Let's preview a couple of the key-value pairs in `genre_reviews_dict`

In [None]:
 for _genre, _reviews in genre_reviews_dict.items():
    print(_genre)
    print(random.sample(_reviews, 1)[0])

Here we use `pickle` to save this Python dictionary to a `.pickle` file so we can easily load it later.

*The `pickle` module allows you to save and load Python objects like lists and dictionaries.*

In [None]:
pickle.dump(genre_reviews_dict, open('genre_reviews_dict.pickle', 'wb'))
# genre_reviews_dict = pickle.load(open('genre_reviews_dict.pickle', 'rb'))

<br><br>

## **Split the data into training and test sets**

When training a machine learning model, it is necessary to split your training data into two parts: a "training" set and a "test" set.

We will train our BERT model on the "training" set of Goodreads reviews and then we will evaluate how well it is performing by running it on the "test" set of Goodreads reviews that the model has never seen before.

Normally, to tune the hyperparameters, you should also create a "validation" set for tuning, and only use the "test" set once, at the end of all tuning. For simplicity, in this tutorial, we will only using a training and test set.

In [None]:
train_texts = []
train_labels = []

test_texts = []
test_labels = []

for _genre, _reviews in genre_reviews_dict.items():

  _reviews = random.sample(_reviews, 1000) # Use a very small set as an example.

  for _review in _reviews[:800]:
    train_texts.append(_review)
    train_labels.append(_genre)
  for _review in _reviews[800:]:
    test_texts.append(_review)
    test_labels.append(_genre)

Show how many Goodreads reviews and labels we have in each category: 6400 training reviews, 6400 training labels (genres), 1600 test reviews, 1600 test labels (genre)

In [None]:
len(train_texts), len(train_labels), len(test_texts), len(test_labels)

Here's an example of a training label and review:

In [None]:
train_labels[0], train_texts[0]

<br><br>

## **Run a baseline model (logistic regression)**

Here we train and evaluate a simple TF-IDF baseline model using logistic regression.

We find better-than-random performance, even for a very small dataset. We'll see whether BERT can beat this good baseline!

In [None]:
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(train_texts)
X_test = vectorizer.transform(test_texts)

We train a logistic regression model from scikit-learn on the Goodreads training data, and then we use the trained model to make predictions on our Goodreads review test set.

In [None]:
model = LogisticRegression(max_iter=1000).fit(X_train, train_labels)
predictions = model.predict(X_test)

We can use scikit-learn's `classification_report` function to evaluate how well the logistic regression model's predictions match up with the true labels for the Goodreads reviews.

Importantly, we can see that our average scores are above random performance (we have 8 classes, so random performance would be ~0.2).

In [None]:
print(classification_report(test_labels, predictions))

<br><br>

## **Encode data for BERT**

We're going to transform our texts and labels into a format that BERT (via Huggingface and PyTorch) will understand. This is called *encoding* the data.

Here are the steps we need to follow:

1. The labels&mdash;in this case, Goodreads genres&mdash;need to be turned into integers rather than strings.

2. The texts&mdash;in this case, Goodreads reviews&mdash;need to be truncated if they're more than 512 tokens or padded if they're fewer than 512 tokens. The tokens, or words in the texts, also need to be separated into "word pieces" and matched to their embedding vectors.

3. We need to add special tokens to help BERT:

| BERT special token | Explanation |
| --------------| ---------|
| [CLS] | Start token of every document. |
| [SEP] | Separator between each sentence |
| [PAD] | Padding at the end of the document as many times as necessary, up to 512 tokens |
|  &#35;&#35; | Start of a "word piece" |




Here we will load `DistilBertTokenizerFast` from the HuggingFace library, which will do all the work of encoding the texts for us. The `tokenizer()` will break word tokens into word pieces, truncate to 512 tokens, and add padding and special BERT tokens.

In [None]:
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name) # The model_name needs to match our pre-trained model.

Here we will create a map of our labels, or Goodreads genres, to integer keys. We take the unique labels, and then we make a dictionary that associates each label/tag with an integer.

**Note:** HuggingFace documentation sometimes refers to "labels" as "tags" but these are the same thing. We use "labels" throughout this notebook for clarity.

In [None]:
unique_labels = set(label for label in train_labels)
label2id = {label: id for id, label in enumerate(unique_labels)}
id2label = {id: label for label, id in label2id.items()}

In [None]:
label2id.keys()

In [None]:
id2label.keys()

Now let's encode our texts and labels!

In [None]:
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)
test_encodings  = tokenizer(test_texts, truncation=True, padding=True, max_length=max_length)

train_labels_encoded = [label2id[y] for y in train_labels]
test_labels_encoded  = [label2id[y] for y in test_labels]

**Examine a Goodreads review in the training set after encoding**

In [None]:
' '.join(train_encodings[0].tokens[0:500])

**Examine a Goodreads review in the test set after encoding**

In [None]:
' '.join(test_encodings[0].tokens[0:100])

**Examine the training labels after encoding**

In [None]:
set(train_labels_encoded)

**Examine the test labels after encoding**

In [None]:
set(test_labels_encoded)

<br><br>

## **Make a custom Torch dataset**

Here we combine the encoded labels and texts into dataset objects. We use the custom Torch `MyDataSet` class to make a `train_dataset` object from  the `train_encodings` and `train_labels_encoded`. We also make a `test_dataset` object from `test_encodings`, and `test_labels_encoded`.

In [None]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
train_dataset = MyDataset(train_encodings, train_labels_encoded)
test_dataset = MyDataset(test_encodings, test_labels_encoded)

**Examine a Goodreads review in the Torch `training_dataset` after encoding**

In [None]:
' '.join(train_dataset.encodings[0].tokens[0:100])

**Examine a Goodreads review in the Torch `test_dataset` after encoding**

In [None]:
' '.join(test_dataset.encodings[1].tokens[0:100])

<br><br>

## **Load pre-trained BERT model**

Here we load a pre-trained DistilBERT model and send it to CUDA.

**Note:** If you decide to repeat fine-tuning after already running the following cells, make sure that you re-run this cell to re-load the original pre-trained model before fine-tuning again.

In [None]:
# The model_name needs to match the name used for the tokenizer above.
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=len(id2label)).to(device_name)

<br><br>

## **Set the BERT fine-tuning parameters**

These are the arguments we'll set in the HuggingFace TrainingArguments objects, which we'll then pass to the HuggingFace Trainer object. There are many more possible arguments, but here we highlight the basics and some common gotchas.

When training your own model, you should search over these parameters to find the best settings for your particular dataset. You should use a held-out set of validation data for this step.

In [None]:
import transformers
print(transformers.__version__)


| Parameter | Explanation |
|-----------| ------------|
| num_train_epochs | total number of training epochs (how many times to pass through the entire dataset; too much can cause overfitting) |
| per_device_train_batch_size | batch size per device during training |
| per_device_eval_batch_size |  batch size for evaluation |
|  warmup_steps |  number of warmup steps for learning rate scheduler (set lower because of small dataset size) |
| weight_decay | strength of weight decay (reduces size of weights, like regularization) |
| output_dir | output directory for the fine-tuned model and configuration files |
| logging_dir | directory for storing logs |
| logging_steps | how often to print logging output (so that we can stop training early if the loss isn't going down) |
| evaluation_strategy | evaluate while training so that we can see the accuracy going up |

In [None]:
training_args = TrainingArguments(
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=20,   # batch size for evaluation
    learning_rate=5e-5,              # initial learning rate for Adam optimizer
    warmup_steps=100,                # number of warmup steps for learning rate scheduler (set lower because of small dataset size)
    weight_decay=0.01,               # strength of weight decay
    output_dir='./models/results',          # output directory
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,               # number of steps to output logging (set lower because of small dataset size)
    eval_strategy='steps',     # evaluate during fine-tuning so that we can see progress
    report_to=[],  # Disables wandb logging
)

<br><br>

## **Fine-tune the BERT model**

First, we define a custom evaluation function that returns the accuracy. You could modify this function to return precision, recall, F1, and/or other metrics.

In [None]:
def compute_metrics(pred):
  labels = pred.label_ids
  preds = pred.predictions.argmax(-1)
  acc = accuracy_score(labels, preds)
  return {
      'accuracy': acc,
  }

Then we create a HuggingFace `Trainer` object using the `TrainingArguments` object that we created above. We also send our `compute_metrics` function to the `Trainer` object, along with our test and train datasets.

**Note:** This is what we've been aiming for this whole time! All the work of tokenizing, creating datasets, and setting the training arguments was for this cell.

In [None]:
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=test_dataset,           # evaluation dataset (usually a validation set; here we just send our test set)
    compute_metrics=compute_metrics      # our custom evaluation function
)

Time to finally fine-tune!

Be patient; if you've set everything in Colab to use GPUs, then it should only take a minute or two to run, but if you're running on CPU, it can take hours.

After every 10 steps (as we specified in the TrainingArguments object), the trainer will output the current state of the model, including the training loss, validation ("test") loss, and accuracy (from our `compute_metrics` function).

You should see the loss going down and the accuracy going up. If instead they are staying the same or oscillating, you probably need to change the fine-tuning parameters.

In [None]:
# Turn off weights and biases logging, which requires an API key

import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
trainer.train()

<br><br>

## **Save fine-tuned model**

The following cell will save the model and its configuration files to a directory in Colab. To preserve this model for future use, you should download the model to your computer.

In [None]:
trainer.save_model(pt_model_path)

(Optional) If you've already fine-tuned and saved the model, you can reload it using the following line. You don't have to run fine-tuning every time you want to evaluate.

In [None]:
# trainer = DistilBertForSequenceClassification.from_pretrained(cached_model_directory_name)

<br><br>

## **Evaluate fine-tuned model**

The following function of the `Trainer` object will run the built-in evaluation, including our `compute_metrics` function.

In [None]:
trainer.evaluate()

But we might want to do more fine-grained analysis of the model, so we extract the predicted labels.

In [None]:
predicted_results = trainer.predict(test_dataset)

In [None]:
predicted_results.predictions.shape

In [None]:
predicted_labels = predicted_results.predictions.argmax(-1) # Get the highest probability prediction
predicted_labels = predicted_labels.flatten().tolist()      # Flatten the predictions into a 1D list
predicted_labels = [id2label[l] for l in predicted_labels]  # Convert from integers back to strings for readability

In [None]:
len(predicted_labels)

In [None]:
print(classification_report(test_labels,
                            predicted_labels))

<br><br>

## **Pull out correct and incorrect classifications for examination**

Let's use our predicted labels for some analysis!

Now that we've fine-tuned and pulled out our predicted labels, the BERT part of this tutorial is done. You can now use the predicted labels in the same way you would use any set of predicted labels from any classification model. We'll show some examples here.

First, let's print out some example predictions that were correct.

In [None]:
for _true_label, _predicted_label, _text in random.sample(list(zip(test_labels, predicted_labels, test_texts)), 20):
  if _true_label == _predicted_label:
    print('LABEL:', _true_label)
    print('REVIEW TEXT:', _text[:100], '...')
    print()

Now let's print out some misclassifications.

In [None]:
for _true_label, _predicted_label, _text in random.sample(list(zip(test_labels, predicted_labels, test_texts)), 20):
  if _true_label != _predicted_label:
    print('TRUE LABEL:', _true_label)
    print('PREDICTED LABEL:', _predicted_label)
    print('REVIEW TEXT:', _text[:100], '...')
    print()

Finally, let's create some heatmaps to examine misclassification patterns. We could use these patterns to think about similarities and differences between genres, according to book reviewers.

In [None]:
genre_classifications_dict = defaultdict(int)
for _true_label, _predicted_label in zip(test_labels, predicted_labels):
  genre_classifications_dict[(_true_label, _predicted_label)] += 1

dicts_to_plot = []
for (_true_genre, _predicted_genre), _count in genre_classifications_dict.items():
  dicts_to_plot.append({'True Genre': _true_genre,
                        'Predicted Genre': _predicted_genre,
                        'Number of Classifications': _count})

df_to_plot = pd.DataFrame(dicts_to_plot)
df_wide = df_to_plot.pivot_table(index='True Genre',
                                 columns='Predicted Genre',
                                 values='Number of Classifications')

In [None]:
plt.figure(figsize=(9,7))
sns.set(style='ticks', font_scale=1.2)
sns.heatmap(df_wide, linewidths=1, cmap='Purples')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

Looks good! We can see that overall, our model is assigning the correct labels for each genre.

Now, let's remove the diagonal from the plot to highlight the misclassifications.

In [None]:
genre_classifications_dict = defaultdict(int)
for _true_label, _predicted_label in zip(test_labels, predicted_labels):
  if _true_label != _predicted_label: # Remove the diagonal to highlight misclassifications
    genre_classifications_dict[(_true_label, _predicted_label)] += 1

dicts_to_plot = []
for (_true_genre, _predicted_genre), _count in genre_classifications_dict.items():
  dicts_to_plot.append({'True Genre': _true_genre,
                        'Predicted Genre': _predicted_genre,
                        'Number of Classifications': _count})

df_to_plot = pd.DataFrame(dicts_to_plot)
df_wide = df_to_plot.pivot_table(index='True Genre',
                                 columns='Predicted Genre',
                                 values='Number of Classifications')

In [None]:
plt.figure(figsize=(9,7))
sns.set(style='ticks', font_scale=1.2)
sns.heatmap(df_wide, linewidths=1, cmap='Purples')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

There's much more you can do with your own dataset and labels! Classification can be used to apply a small set of labels across a big dataset; to explore misclassifications to better understand users; and much more! We hope you'll use this tutorial in all kinds of creative ways.

# PT to TFLite using Optimum-cli

In [None]:
!pip install optimum[exporters] --upgrade
!optimum-cli --help
# python3 alias need to setup


In [None]:
!optimum-cli export tflite --model ./distilbert-reviews-genres --task text-classification --sequence_length 512 "models/tflite-model-optimum"

In [None]:

# ==== Load TFLite model ====
interpreter = tf.lite.Interpreter(model_path="models/tflite-model-optimum")


interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
for detail in input_details:
    print(f"Expected dtype: {detail['dtype']}")

output_details = interpreter.get_output_details()
def predict_tflite_op(texts):
    # texts: list of strings
    inputs = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    # TFLite model yêu cầu int64
    input_ids = inputs["input_ids"].astype(np.int64)         # shape (batch, seq_len)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Set inputs
    for detail in input_details:
        if "input_ids" in detail["name"]:
            interpreter.set_tensor(detail["index"], input_ids)
        elif "attention_mask" in detail["name"]:
            interpreter.set_tensor(detail["index"], attention_mask)
        else:
            interpreter.set_tensor(detail["index"], np.zeros(input_ids.shape, dtype=np.int64))

    # Run inference
    interpreter.invoke()

    # Get output
    output_data = interpreter.get_tensor(output_details[0]["index"])
    preds = np.argmax(output_data, axis=1)

    return [LABELS[i] for i in preds]

# Example
for text in tests:
    label = predict_tflite_op(text)
    print(f"Text: {text}\nPredicted label: {label}\n")


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


Expected dtype: <class 'numpy.int64'>
Expected dtype: <class 'numpy.int64'>
Text: A romantic story between two star-crossed lovers.
Predicted label: ['fantasy_paranormal']

Text: A tale of vampires and magical adventures.
Predicted label: ['comics_graphic']

Text: Roses are red, violets are blue, this poem is for you.
Predicted label: ['children']

Text: Fun stories and illustrations for children.
Predicted label: ['history_biography']

Text: A coming-of-age story about a young adult finding themselves.
Predicted label: ['history_biography']

Text: Detectives chase a cunning criminal through the city.
Predicted label: ['romance']

Text: A comic book adventure with superheroes and villains.
Predicted label: ['poetry']

Text: A detailed biography of Abraham Lincoln and his presidency.
Predicted label: ['mystery_thriller_crime']

Text: Two lovers navigate a complicated relationship in Paris.
Predicted label: ['fantasy_paranormal']

Text: A wizard battles dark forces in a haunted castle.
P

In [16]:
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

# ==== Load PyTorch model ====
pt_model = DistilBertForSequenceClassification.from_pretrained(pt_model_path)
pt_model.eval()
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

# ==== Inference function PyTorch ====
def predict_pt(texts):
    inputs = tokenizer(
        texts,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        outputs = pt_model(**inputs)
        preds = torch.argmax(outputs.logits, dim=1)
    return [LABELS[i] for i in preds.tolist()]

# Transforming PT model to TF and to TFLite

## Pytorch to Tensorflow

In [10]:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification


# ==== Load và convert PyTorch model sang TensorFlow ====
tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_model_path, from_pt=True)

# ==== Optional: tạo serving signature ====
@tf.function(input_signature=[
    tf.TensorSpec([None, 512], tf.int32, name="input_ids"),
    tf.TensorSpec([None, 512], tf.int32, name="attention_mask")
])
def serving_fn(input_ids, attention_mask):
    return tf_model(input_ids=input_ids, attention_mask=attention_mask)

# ==== Lưu TF SavedModel với signature ====
tf.saved_model.save(tf_model, tf_model_path, signatures={"serving_default": serving_fn})

print(f"✅ PyTorch model đã convert sang TensorFlow SavedModel tại {tf_model_path}")





TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.
All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.

All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.
































INFO:tensorflow:Assets written to: ./models/tf-model\assets


INFO:tensorflow:Assets written to: ./models/tf-model\assets


✅ PyTorch model đã convert sang TensorFlow SavedModel tại ./models/tf-model


## Test TF Model

In [11]:
import tensorflow as tf
from transformers import DistilBertTokenizerFast

# ==== Load TF model (SavedModel) ====
tf_model = tf.saved_model.load(tf_model_path)

# ==== Load tokenizer ====
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

# ==== Inference function TensorFlow ====
def predict_tf(texts):
    inputs = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="tf"
    )
    # Call serving_default signature
    outputs = tf_model.signatures["serving_default"](
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"]
    )
    # outputs trả về dict với key như 'logits'
    logits = outputs["logits"].numpy()
    preds = logits.argmax(axis=1)
    return [LABELS[i] for i in preds]

## Tensorflow to TFLite

In [12]:
import tensorflow as tf
import numpy as np

# ==== Convert ====
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# ==== Save ====
with open(tflite_model_path, "wb") as f:
    f.write(tflite_model)

print(f"✅ TF SavedModel đã convert sang TFLite tại {tflite_model_path}")

✅ TF SavedModel đã convert sang TFLite tại ./models/tflite-model/model.tflite


## Inference TFLite model

In [22]:

# ==== Load TFLite model ====
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

for detail in input_details:
    print(f"Expected dtype: {detail['dtype']}")


# ==== Inference function TFLite ====
def predict_tflite(texts):
    inputs = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    input_ids = inputs["input_ids"].astype(np.int32)
    attention_mask = inputs["attention_mask"].astype(np.int32)

    # Set inputs
    for detail in input_details:
        if "input_ids" in detail["name"]:
            interpreter.set_tensor(detail["index"], input_ids)
        elif "attention_mask" in detail["name"]:
            interpreter.set_tensor(detail["index"], attention_mask)

    # Run inference
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]["index"])
    preds = np.argmax(output_data, axis=1)
    return [LABELS[i] for i in preds]


Expected dtype: <class 'numpy.int32'>
Expected dtype: <class 'numpy.int32'>


## Compare model performance

In [23]:
for text in tests:
    label_pt =     predict_pt([text])[0]   # PyTorch predict
    # label_tf =     predict_tf([text])[0]
    label_tflite = predict_tflite([text])[0]  # TFLite predict trả list, lấy phần tử đầu
    print(f"Text: {text} {"✅" if label_tflite == label_pt else ""}\nTFLite label: {label_tflite}\nPyTorch label: {label_pt}")
    # print(f"Tensor label: {label_tf}")

Text: A romantic story between two star-crossed lovers. ✅
TFLite label: fantasy_paranormal
PyTorch label: fantasy_paranormal
Text: A tale of vampires and magical adventures. ✅
TFLite label: comics_graphic
PyTorch label: comics_graphic
Text: Roses are red, violets are blue, this poem is for you. ✅
TFLite label: children
PyTorch label: children
Text: Fun stories and illustrations for children. ✅
TFLite label: history_biography
PyTorch label: history_biography
Text: A coming-of-age story about a young adult finding themselves. ✅
TFLite label: history_biography
PyTorch label: history_biography
Text: Detectives chase a cunning criminal through the city. ✅
TFLite label: romance
PyTorch label: romance
Text: A comic book adventure with superheroes and villains. ✅
TFLite label: poetry
PyTorch label: poetry
Text: A detailed biography of Abraham Lincoln and his presidency. ✅
TFLite label: mystery_thriller_crime
PyTorch label: mystery_thriller_crime
Text: Two lovers navigate a complicated relation