# Assignment 3: Fine tuning a multiclass classification BERT model

**Description:** This assignment covers fine-tuning a multiclass classification model. You will compare two different types of solutions using BERT-based models. You should also be able to develop an intuition for:


* Working with BERT
* Using multiple models to focus on different sub-tasks
* Different metrics to measure the effectiveness of your model
* Modifying your models to deal with class imbalance



The assignment notebook closely follows the lesson notebooks. We will use the 20 newsgroups dataset and will leverage some of the models, or part of the code, for our current investigation.

**You are strongly encouraged to read through the ENTIRE notebook before answering any questions or writing any code.**

The initial part of the notebook is purely setup. We will then generate our BERT model and see if and how we can improve it.

Fine-tuning a BERT model requires a GPU to work in a timely fashion. This notebook should be run in Google Colab with a GPU. By default, when you open the notebook in Colab it will try to use a GPU. Total runtime of the entire notebook (with solutions and a Colab GPU) should be about one hour.


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/datasci-w266/2024-fall-main/blob/master/assignment/a3/Multiclass_text_classification.ipynb)

The overall assignment structure is as follows:

1. Setup

  1.1 Libraries & Helper Functions

  1.2 Data Acquisition

  1.3 Training/Test/Validation Sets for BERT-based models

2. Classification with a fine tuned BERT model

  2.1 Create the specified BERT model

  2.2 Fine tune the BERT model as directed

  2.3 Examine the predictions with various metrics

3. Classification using two stages

  3.1 Relabel the data to group the often confused classes

  3.2 Train the first stage model on the relabeled data

  3.3 Separate the data for just the confused classes

  3.4 Train the second stage model on the two classes

  3.5 Combine and evaluate the predictions from the two stages

4. Look at examples of misclassifications, see what might have changed



**INSTRUCTIONS:**

* Questions are always indicated as **QUESTION:**, so you can search for this string to make sure you answered all of the questions. You are expected to fill out, run, and submit this notebook, as well as to answer the questions in the **answers** file as you did in a1 and a2.

* **### YOUR CODE HERE** indicates that you are supposed to write code.

* If you want to, you can run all of the cells in section 1 in bulk. This is setup work and no questions are in there. At the end of section 1 we will state all of the relevant variables that were defined and created in section 1.

* **IMPORTANT NOTE:** Because the data we're using is downloaded each time we run section 1, a different split of train, validation, and test records is created.  This means that the accuracy, precision, recall, and F1 scores will change, although the delta will be small.  Please enter the values from your final run so that the answer values in your answers file correspond to the answer values in the outputs in your notebook.


### 1. Setup

Let's get all our libraries and download and process our data.

In [1]:
!pip install -q transformers==4.17

In [2]:
#Download Keras 2 versions of software
!pip install tensorflow==2.15.0 --quiet
!pip install tf_keras==2.15.0 --quiet
!pip install tensorflow-text==2.15.0 --quiet

In [3]:
!pip install pyarrow==14.0.2 --quiet
!pip install -q datasets

In [4]:
!pip install pydot

In [5]:
!pip install -U scikit-learn

In [6]:
from sklearn.metrics import classification_report

In [7]:
from datasets import load_dataset

In [8]:
from collections import Counter
import numpy as np
import tensorflow as tf
from tensorflow import keras

import seaborn as sns
import matplotlib.pyplot as plt
from pprint import pprint

In [9]:
from transformers import BertTokenizer, TFBertModel

We're going to use the 20 newsgroups dataset as it is ideal for exploring multiclass classification.  It includes posts from 20 different newsgroups.  Our task will be to correctly label a post with its group.  We'll download a version available from Hugging Face.

In [10]:
ds = load_dataset("TopicNet/20-Newsgroups")

In [11]:
ds

Take a look at the records.  We basically have a long string of text and an associated id label.  That id is the Usenet group where the posting occurred. The raw text of each record is in the raw_text field.  The records vary significantly in size.  Each record also contains a number of other fields that we will ignore for now.

In [12]:
ds["train"][:2]

Notice that the "targets" are just integers. Each one is an index into the list of target names, i.e. the group to which the record belongs.

In [13]:
ds["train"]["target"][:5]


And counterintuitively, the "ids" are the actual names of the groups.

In [14]:
ds["train"]["id"][:5]

Now we need to assemble the training data. We need to create parallel lists of normalized content.
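As a rough illustration of the two string normalizations used in the cells below, here is a minimal sketch. The id strings are hypothetical and only assume the shapes that the code implies (an underscore followed by trailing digits for training ids, dot-separated parts for test ids). The real ids in the dataset may look different.

```python
# Hypothetical id strings; the real ids in the dataset may look different.
train_style_id = "comp_graphics_38464"
test_style_id = "comp.graphics.38464"

# Training cell: drop everything after the last underscore.
train_tag = train_style_id.rsplit("_", 1)[0]

# Test cell (transform_input): drop the last dot-separated part,
# then join the remaining parts with underscores.
test_tag = "_".join(test_style_id.split(".")[:-1])

print(train_tag)  # comp_graphics
print(test_tag)   # comp_graphics
```

Under these assumptions both forms reduce to the same group name, which is what lets the training and test tags be compared.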

In [15]:
train_tags = []
train_labels = []
train_text = []

for example in ds['train']:  # Assuming 'ds' is your Hugging Face dataset object
    # Extract the 'id' and remove the underscore and trailing digits
    id_string = example['id']
    cleaned_id = id_string.rsplit('_', 1)[0]
    train_text.append(example['raw_text'])
    train_tags.append(cleaned_id)

    # Extract the 'target' value
    train_labels.append(example['target'])

# Print the extracted lists (optional)
print(train_text[:5])
print(train_tags[:5])  # Print first 5 elements for verification
print(train_labels[:5])  # Print first 5 elements for verification

We also need to create a dictionary of the 20 group names to which the posts belong and map those names to their integer labels.
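The idea behind the next two cells can be sketched on toy data: record the label that goes with each group name, then sort the names by label so that position i in the list holds the name of label i. The names and label numbers below are made up for illustration, and this sketch keeps the labels as integers, whereas the notebook stores them as strings.

```python
# Toy parallel lists standing in for train_tags and train_labels.
toy_tags = ["sci_space", "rec_autos", "sci_space", "alt_atheism"]
toy_labels = [2, 1, 2, 0]

# Keep the first label seen for each tag.
toy_tag_to_label = {}
for tag, label in zip(toy_tags, toy_labels):
    toy_tag_to_label.setdefault(tag, label)

# Sort the names by integer label so that names[i] is the name of label i.
toy_target_names = sorted(toy_tag_to_label, key=toy_tag_to_label.get)

print(toy_tag_to_label)  # {'sci_space': 2, 'rec_autos': 1, 'alt_atheism': 0}
print(toy_target_names)  # ['alt_atheism', 'rec_autos', 'sci_space']
```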

In [16]:
# Create an empty dictionary to store the tag-label mappings
tag_label_dict = {}

# Iterate over the lists using zip
for tag, label in zip(train_tags, train_labels):
    # If the tag is not already in the dictionary, add it with its label
    if tag not in tag_label_dict:
        tag_label_dict[tag] = str(label)  # Convert label to string

# Print the resulting dictionary (optional)
print(tag_label_dict)

# Get the keys as a list
tag_list = list(tag_label_dict.keys())

# Get the values as a list
label_list = list(tag_label_dict.values())

# Print the lists (optional)
print("Tags:", tag_list)
print("Labels:", label_list)

In [17]:
def sort_dict_by_values(input_dict):

    # Create separate sorted lists of the integer tags and the associated labels
    # We'll need this for doing analysis of the results of our classifier

    items = list(input_dict.items())

    items.sort(key=lambda item: int(item[1]))  # Convert values to integers for sorting

    sorted_values = [item[1] for item in items]
    sorted_keys = [item[0] for item in items]

    return sorted_values, sorted_keys  # Return as a tuple


sorted_values, sorted_keys = sort_dict_by_values(tag_label_dict)
print("Sorted Values:", sorted_values)
print("Sorted Keys:", sorted_keys)

In [18]:
def transform_input(input_string):
  # normalize the input
  # Split the string by '.', drop the last part, and join the rest with '_'
  parts = input_string.split('.')[:-1]
  output_string = '_'.join(parts)

  return output_string

test_tags = []
test_labels = []
test_texts = []

for example in ds['test']:
    # Extract the 'id' and normalize it with transform_input
    id_string = example['id']
    cleaned_id = transform_input(id_string)
    test_texts.append(example['raw_text'])
    test_tags.append(cleaned_id)

    # Extract the 'target' value
    test_labels.append(example['target'])

# Print the extracted lists (optional)
print(test_texts[:5])
print(test_labels[:5])  # Print first 5 elements for verification

In [19]:
len(train_tags)

The variable "target_names" stores all of the names of the labels.

In [20]:
# The group names, sorted by their integer label
target_names = sorted_keys


In [21]:
print(target_names)

We already have a set aside test set and a train set.  Let's explicitly set aside part of our training set for validation purposes.

In [22]:
#len(train_texts)
valid_texts = train_text[10000:]
valid_labels = train_labels[10000:]
train_texts = train_text[:10000]
train_labels = train_labels[:10000]

In [23]:
len(train_texts)

The training set will always have 10000 records and the validation set will always have 1301 records.

In [24]:
len(valid_texts)

In [25]:
#get the labels in a needed data format for validation (needs to be label ids)
npvalid_labels = np.asarray(valid_labels)
# Convert train_labels to a NumPy array
train_labels = np.array(train_labels)
# Convert test_labels to a NumPy array
test_labels = np.array(test_labels)

In [26]:
train_labels[:50]

Here are the variables we've already defined for the data:

* train_texts - an array of text strings for training
* test_texts - an array of text strings for testing
* valid_texts - an array of text strings for validation
* train_labels - an array of integers representing the labels associated with train_texts
* test_labels - an array of integers representing the labels associated with test_texts
* valid_labels - an array of integers representing the labels associated with valid_texts
* target_names - an array of label strings that correspond to the integers in the *_labels arrays

### 2. Classification with a fine tuned BERT model

Let's pick our BERT model.  We'll start with the base BERT model and we'll use the cased version since our data has capital and lower case letters.

In [27]:
#make it easier to use a variety of BERT subword models
model_checkpoint = 'bert-base-cased'

In [28]:
bert_tokenizer = BertTokenizer.from_pretrained(model_checkpoint)
bert_model = TFBertModel.from_pretrained(model_checkpoint)

We're setting our maximum training record length to 200.  BERT models can handle more, and after you've completed the assignment you're welcome to try larger and smaller record lengths.

In [29]:
max_length = 200

Now we'll tokenize our three data slices.  This will take a minute or two.

In [30]:
train_texts[:2]
# Check the type of elements in train_texts
print(type(train_texts[0]))

In [31]:
# If train_texts contains elements that are not strings, convert them to strings.
# For example, if train_texts contains integers, use:
train_texts = [str(text) for text in train_texts]
valid_texts = [str(text) for text in valid_texts]
test_texts = [str(text) for text in test_texts]

In [32]:
# tokenize the dataset, truncate at `max_length`,
# pad shorter records with 0's (padding=True pads to the longest record in the batch),
# and return a tf Tensor
train_encodings = bert_tokenizer(train_texts, truncation=True, padding=True, max_length=max_length, return_tensors='tf')
valid_encodings = bert_tokenizer(valid_texts, truncation=True, padding=True, max_length=max_length, return_tensors='tf')
test_encodings = bert_tokenizer(test_texts, truncation=True, padding=True, max_length=max_length, return_tensors='tf')

Notice our input_ids for the first training record and their padding. The train_encodings also includes an array of token_type_ids and an attention_mask array.
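To make the relationship between input_ids, padding, and the attention_mask concrete, here is a minimal sketch in plain Python. It is not the tokenizer's implementation: the helper `pad_and_mask` and the token ids are made up, and the real tokenizer also keeps the closing special token when it truncates, which this toy version does not.

```python
def pad_and_mask(token_ids, max_length, pad_id=0):
    """Truncate to max_length, pad with pad_id, and build the attention mask."""
    kept = token_ids[:max_length]
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding that the model should ignore.
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask

# Made-up token ids, not real BERT vocabulary ids.
short_ids, short_mask = pad_and_mask([5, 7, 8, 9], max_length=6)
long_ids, long_mask = pad_and_mask([11, 12, 13, 14, 15, 16, 17, 18], max_length=6)

print(short_ids, short_mask)  # [5, 7, 8, 9, 0, 0] [1, 1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [11, 12, 13, 14, 15, 16] [1, 1, 1, 1, 1, 1]
```

For single-segment inputs like ours, the token_type_ids are all zeros, so they carry no extra information here.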

In [33]:
train_encodings

In [34]:
train_encodings.input_ids[:1]

Write a function to create this multiclass BERT model.

Keep in mind the following:
* Each record can have one of n labels where n = the size of target_names.
* We'll still want a hidden layer of size 201
* We'll want our hidden layer to make use of the **pooler output** from BERT
* We'll also want to use dropout
* Our classification layer will need to be appropriately sized and use the correct non-linearity for a multi-class problem.
* Since we have multiple labels we can no longer use binary cross entropy.  Instead we need to change our loss function to a categorical cross entropy.  Which of the two categorical cross entropy losses will work best here?
* Make sure that training affects **all** of the layers in BERT.


In [35]:
def create_bert_multiclass_model(checkpoint = model_checkpoint,
                                 num_classes = 20,
                                 hidden_size = 201,
                                 dropout=0.3,
                                 learning_rate=0.00005):
    """
    Build a simple classification model with BERT. Use the Pooler Output for classification purposes.
    """
    ### YOUR CODE HERE











    ### END YOUR CODE
    return classification_model

In [36]:
pooler_bert_model = create_bert_multiclass_model(checkpoint=model_checkpoint, num_classes=20)

In [37]:
pooler_bert_model.summary()

**QUESTION:** 2.1 How many trainable parameters are in your dense hidden layer?

**QUESTION:** 2.2 How many trainable parameters are in your classification layer?

In [38]:
keras.utils.plot_model(pooler_bert_model, show_shapes=False, show_dtype=False, show_layer_names=True, dpi=90)

In [39]:
#It takes 10 to 14 minutes to complete an epoch when using a GPU
pooler_bert_model_history = pooler_bert_model.fit([train_encodings.input_ids, train_encodings.token_type_ids, train_encodings.attention_mask],
                                                  train_labels,
                                                  validation_data=([valid_encodings.input_ids, valid_encodings.token_type_ids, valid_encodings.attention_mask],
                                                  npvalid_labels),
                                                  batch_size=8,
                                                  epochs=1)

Now we need to evaluate our fine-tuned model against the test set.  This will give us an overall accuracy.

In [40]:
#eval b=8 e=1 dim=201
score = pooler_bert_model.evaluate([test_encodings.input_ids, test_encodings.token_type_ids, test_encodings.attention_mask],
                                                  test_labels)

print('Test loss:', score[0])
print('Test accuracy:', score[1])

**QUESTION:** 2.3 What is the Test accuracy score you get from your model? (Just copy and paste the value into the answers sheet and round to five significant digits.)

In [41]:
#run predict for the first three elements in the test data set
predictions = pooler_bert_model.predict([test_encodings.input_ids[:3], test_encodings.token_type_ids[:3], test_encodings.attention_mask[:3]])

In [42]:
predictions

In [43]:
#run and capture all predictions from our test set using model.predict
### YOUR CODE HERE
### END YOUR CODE

#now we need to get the index of the highest probability in the distribution for each prediction
#(i.e. the predicted class) and store that in a tf.Tensor
predictions_model1 = tf.argmax(predictions_model1, axis=-1)
predictions_model1

There are two ways to see what's going on with our classifier.  Overall accuracy is interesting but it can be misleading.  We need to check how well each individual category is predicted, since a good overall number can hide categories that perform much worse.

Here we'll use the classification report from scikit-learn.  It expects two inputs as arrays.  One is the ground truth (y_true) and the other is the associated prediction (y_pred).  This is based on gathering all the predictions from our test set.
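To see why overall accuracy can hide a weak class, the following sketch computes accuracy alongside per-class precision, recall, and F1 on toy data. It is a plain Python illustration of the standard definitions, not scikit-learn's implementation, and the helper name `per_class_report` is made up. Undefined ratios are set to 0.0 here.

```python
def per_class_report(y_true, y_pred, num_classes):
    """Return accuracy and a list of (precision, recall, f1) per class."""
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    rows = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append((precision, recall, f1))
    return accuracy, rows

# Toy, imbalanced data: 8 examples of class 0 and 2 of class 1.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # a model that always predicts class 0

accuracy, rows = per_class_report(y_true, y_pred, num_classes=2)
print(accuracy)  # 0.8
print(rows[1])   # (0.0, 0.0, 0.0): class 1 is never predicted
```

An accuracy of 0.8 looks reasonable on its own, yet the per-class numbers show that the second class is missed entirely.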

In [44]:
print(classification_report(test_labels, predictions_model1.numpy(), target_names=target_names))

**QUESTION:** 2.4 What is the key difference between the macro average F1 score and the weighted average F1 score?

**QUESTION:** 2.5 What is the macro average F1 score you get from the classification report?

Now we'll generate another very valuable visualization of what's happening with our classifier -- a confusion matrix.
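The next cell builds the matrix with tf.math.confusion_matrix and divides each row by its total. As a rough sketch of what that produces, here is the same computation in plain Python on toy labels. The helper name `row_normalized_confusion` is made up. Rows are true labels and columns are predicted labels, so each normalized row sums to 1 and the diagonal holds the per-class recall.

```python
def row_normalized_confusion(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs, then divide each row by its total."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    normalized = []
    for row in counts:
        total = sum(row)
        normalized.append([n / total if total else 0.0 for n in row])
    return counts, normalized

y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
counts, normalized = row_normalized_confusion(y_true, y_pred, 3)

print(counts)      # [[3, 1, 0], [0, 2, 0], [1, 1, 2]]
print(normalized)  # [[0.75, 0.25, 0.0], [0.0, 1.0, 0.0], [0.25, 0.25, 0.5]]
```

In this toy case the last row shows that class 2 is recognized only half of the time and is split evenly between the other two classes when it is missed.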

In [45]:
cm = tf.math.confusion_matrix(test_labels, predictions_model1)
cm = cm/cm.numpy().sum(axis=1)[:, tf.newaxis]

And now we'll display it!

In [46]:
plt.figure(figsize=(20,7))
sns.heatmap(
    cm, annot=True,
    xticklabels=target_names,
    yticklabels=target_names)
plt.xlabel("Predicted")
plt.ylabel("True")

### 3. Classification using two stages

Okay, not bad.  As you can see, some categories are easier to distinguish than others. Look for the class with the lowest F1 score (it should be the one at the bottom of the list). In the confusion matrix, which other class is that one being mistaken for most often?

You might notice that the categories in this dataset are somewhat hierarchical. There are more obvious differences between groups of news categories (e.g. computers vs recreation) and then subtler differences between categories within those groups (e.g. PC vs Mac, within computers).

When this happens, one idea is to train a series of models, to first separate out the more obvious groups of classes, and then use more specialized sub-models to classify only a subset of the classes. Let's try that here.

#### Step 1: New model with 19 classes

For simplicity, we'll just combine two categories in our first step. We'll replace the label of the last class with the label of the class it's most often mistaken for. (That way, we'll have labels from 0 to 18 instead of 0 to 19, and don't have to renumber everything, though you would have to if you group them more.)
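The relabeling itself can be sketched on a toy problem with four classes. The two label values below are hypothetical and are not the notebook's answer. The notebook does the same thing with a NumPy boolean mask; plain Python is used here only to keep the sketch dependency-free.

```python
# Toy 4-class labels; in this sketch class 3 is merged into class 1.
# Both values are hypothetical and are not the notebook's answer.
toy_label_to_replace = 3
toy_label_to_replace_with = 1

toy_labels = [0, 3, 1, 2, 3, 0]
toy_labels_merged = [
    toy_label_to_replace_with if label == toy_label_to_replace else label
    for label in toy_labels
]

print(toy_labels_merged)               # [0, 1, 1, 2, 1, 0]
print(sorted(set(toy_labels_merged)))  # [0, 1, 2]
```

Because the label being replaced is the highest one, the remaining labels stay contiguous (0 to 2 here), which is why no renumbering is needed.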

In [47]:
label_to_replace = 19

# label_to_replace_with = ...
### YOUR CODE HERE


### END YOUR CODE

train_labels_19class = train_labels.copy()
train_labels_19class[train_labels_19class == label_to_replace] = label_to_replace_with

valid_labels_19class = npvalid_labels.copy()
valid_labels_19class[valid_labels_19class == label_to_replace] = label_to_replace_with

test_labels_19class = test_labels.copy()
test_labels_19class[test_labels_19class == label_to_replace] = label_to_replace_with

Now let's create a new model with the same architecture, but to predict probabilities for 19 classes instead of 20. We're using all of the data in this first step, so we'll use the encodings we already preprocessed as inputs, but use the new labels that only have 19 classes.

In [48]:
bert_model_19class = create_bert_multiclass_model(checkpoint = model_checkpoint, num_classes=19)

In [49]:
bert_model_19class_history = bert_model_19class.fit([train_encodings.input_ids, train_encodings.token_type_ids, train_encodings.attention_mask],
                                                  train_labels_19class,
                                                  validation_data=([valid_encodings.input_ids, valid_encodings.token_type_ids, valid_encodings.attention_mask],
                                                                   valid_labels_19class),
                                                  batch_size=8,
                                                  epochs=1)

In [50]:
#Evaluate the fine tuned 19-class model against the test data with 19-class labels
### YOUR CODE HERE
### END YOUR CODE
print('Test loss:', score[0])
print('Test accuracy:', score[1])

**QUESTION:**

3.1 What is the test accuracy you get when you run the new first stage model with only 19 classes?


In [51]:
#run and capture all the predictions from the 19 class data
### YOUR CODE HERE
### END YOUR CODE

predictions_19class

In [52]:
target_names_19class = target_names[:label_to_replace_with] \
                     + ['** COMBINED CLASS **'] \
                     + target_names[label_to_replace_with+1:19]

print(classification_report(test_labels_19class, predictions_19class.numpy(),
                            target_names=target_names_19class))

**QUESTION:**

3.2 What is the F1 score you get for the combined class when you run the new first stage model with only 19 classes?


#### Step 2: New model with only the two classes combined in step one

Now, our first stage model is able to determine which text is one of the two often confused classes, but we need to train a more specific model to distinguish between just these two classes. Ideally, this model will only focus on the more subtle differences between these two news categories, since it doesn't have to learn everything else about the other categories.

For this model, we're only going to train using the text examples that are one of the two confused categories. We'll keep the encodings we already tokenized, so we need to separate out the input_ids, token_type_ids, and attention_mask for just the rows that have one of these two labels.
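The selection and relabeling in the next cell can be sketched on toy data as follows. All values are made up: short strings stand in for the encoded rows, and class 3 is assumed to have been merged into class 1. The notebook applies the same boolean mask to the input_ids, token_type_ids, and attention_mask tensors.

```python
toy_label_to_replace = 3
toy_label_to_replace_with = 1

toy_texts = ["a", "b", "c", "d", "e", "f"]
toy_labels = [0, 3, 1, 2, 3, 0]         # original labels
toy_labels_merged = [0, 1, 1, 2, 1, 0]  # after merging class 3 into class 1

# Keep only the rows that belong to the combined class.
keep = [label == toy_label_to_replace_with for label in toy_labels_merged]
texts_2class = [text for text, k in zip(toy_texts, keep) if k]
original_2class = [label for label, k in zip(toy_labels, keep) if k]

# Binary target: 1 for the class we merged into, 0 for the class we merged away.
labels_2class = [int(label == toy_label_to_replace_with) for label in original_2class]

print(texts_2class)   # ['b', 'c', 'e']
print(labels_2class)  # [0, 1, 0]
```

Note that the mask comes from the merged labels, while the binary target comes from the original labels, since only the original labels still distinguish the two classes.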

In [53]:
train_mask_2class = (train_labels_19class == label_to_replace_with)
train_encodings_2class = {'input_ids': train_encodings.input_ids[train_mask_2class],
                          'token_type_ids': train_encodings.token_type_ids[train_mask_2class],
                          'attention_mask': train_encodings.attention_mask[train_mask_2class]}
train_labels_2class = train_labels.copy()[train_mask_2class]
train_labels_2class = (train_labels_2class == label_to_replace_with).astype(int)

valid_mask_2class = (valid_labels_19class == label_to_replace_with)
valid_encodings_2class = {'input_ids': valid_encodings.input_ids[valid_mask_2class],
                          'token_type_ids': valid_encodings.token_type_ids[valid_mask_2class],
                          'attention_mask': valid_encodings.attention_mask[valid_mask_2class]}
valid_labels_2class = npvalid_labels.copy()[valid_mask_2class]
valid_labels_2class = (valid_labels_2class == label_to_replace_with).astype(int)

test_mask_2class = (test_labels_19class == label_to_replace_with)
test_encodings_2class = {'input_ids': test_encodings.input_ids[test_mask_2class],
                          'token_type_ids': test_encodings.token_type_ids[test_mask_2class],
                          'attention_mask': test_encodings.attention_mask[test_mask_2class]}
test_labels_2class = test_labels.copy()[test_mask_2class]
test_labels_2class = (test_labels_2class == label_to_replace_with).astype(int)

In [54]:
train_labels_2class.shape

In [55]:
train_labels_2class

Create and train a new model with the same architecture as before, except that it only predicts two classes. (Note that we could change this to a binary prediction model, but we'll keep it multiclass for consistency here.)

In [56]:
bert_model_2class = create_bert_multiclass_model(checkpoint=model_checkpoint, num_classes=2)

In [57]:
bert_model_2class_history = bert_model_2class.fit([train_encodings_2class['input_ids'],
                                                   train_encodings_2class['token_type_ids'],
                                                   train_encodings_2class['attention_mask']],
                                                  train_labels_2class,
                                                  validation_data=([valid_encodings_2class['input_ids'],
                                                                    valid_encodings_2class['token_type_ids'],
                                                                    valid_encodings_2class['attention_mask']],
                                                                   valid_labels_2class),
                                                  batch_size=8,
                                                  epochs=1)

In [58]:
#Evaluate the two-class model against the two-class test set.
### YOUR CODE HERE
### END YOUR CODE
print('Test loss:', score[0])
print('Test accuracy:', score[1])

In [59]:
#run and capture all the predictions from the 2-class test data
### YOUR CODE HERE
### END YOUR CODE
predictions_2class

In [60]:
# Run the sklearn classification_report again with the 2-class predictions
### YOUR CODE HERE
### END YOUR CODE

**QUESTION:**

3.3 What is the macro average F1 score you get when you run the new second stage model with only 2 classes?

#### Step 3: Combine the predicted labels from the two steps

To combine our two models into a two-step pipeline, start with the predictions from the first step. Keep all predicted labels except the ones with a predicted value of label_to_replace_with (the label we gave to both of the confused classes in the first step).

Wherever the first model predicted the combined category, we'll replace the predictions with the label from the second model. If we used these models in inference, we'd only send an example to the second model if the first model predicted that it was from the combined class.
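The routing logic can be sketched in plain Python on toy predictions. Everything here is hypothetical: `toy_stage2_predict` is a stand-in that returns canned 0/1 outputs instead of calling a real model, and class 3 is assumed to have been merged into class 1.

```python
toy_label_to_replace = 3
toy_label_to_replace_with = 1

# Stage 1 predictions over the merged label set (class 1 is the combined class).
stage1_predictions = [0, 1, 2, 1, 1, 0]

def toy_stage2_predict(num_rows):
    """Stand-in for the second model: canned 0/1 predictions for the routed rows."""
    canned = [0, 1, 0]
    return canned[:num_rows]

# Route only the rows predicted as the combined class to stage 2.
routed = [i for i, p in enumerate(stage1_predictions) if p == toy_label_to_replace_with]
stage2_predictions = toy_stage2_predict(len(routed))

# Map the binary outputs back to the original labels and write them in place.
binary_to_original = {0: toy_label_to_replace, 1: toy_label_to_replace_with}
combined = list(stage1_predictions)
for i, p in zip(routed, stage2_predictions):
    combined[i] = binary_to_original[p]

print(routed)    # [1, 3, 4]
print(combined)  # [0, 3, 2, 1, 3, 0]
```

Rows that stage 1 assigned to any other class are never seen by stage 2, so a stage 1 mistake on those rows cannot be corrected later.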

In [61]:
# Now get the examples that the first model predicted as in the combined class
test_mask_2class = (predictions_19class.numpy() == label_to_replace_with)
test_encodings_2class = {'input_ids': test_encodings.input_ids[test_mask_2class],
                         'token_type_ids': test_encodings.token_type_ids[test_mask_2class],
                         'attention_mask': test_encodings.attention_mask[test_mask_2class]}

# Run those examples through the step 2 model and save the predictions
predictions_2class = bert_model_2class.predict([test_encodings_2class['input_ids'],
                                                test_encodings_2class['token_type_ids'],
                                                test_encodings_2class['attention_mask']],)
predictions_2class = tf.argmax(predictions_2class, axis=-1)

# Replace the step 2 model's predicted labels with the original values from the 20-class dataset
predictions_2class = predictions_2class.numpy()
predictions_2class[predictions_2class == 0] = label_to_replace
predictions_2class[predictions_2class == 1] = label_to_replace_with

# Combine the labels from both steps for the full test dataset
predictions_2steps = predictions_19class.numpy()
predictions_2steps[test_mask_2class] = predictions_2class

predictions_2steps

Now let's look at the classification report and confusion matrix, using the combined predictions from our two step model (compared to the original labels). Did the overall results get better?

In [62]:
# Run the sklearn classification_report with all 20 classes from the 2-step predictions
### YOUR CODE HERE
### END YOUR CODE

In [63]:
cm = tf.math.confusion_matrix(test_labels, predictions_2steps)
cm = cm/cm.numpy().sum(axis=1)[:, tf.newaxis]

In [64]:
plt.figure(figsize=(20,7))
sns.heatmap(
    cm, annot=True,
    xticklabels=target_names,
    yticklabels=target_names)
plt.xlabel("Predicted")
plt.ylabel("True")

**QUESTION:**

3.4 What is the macro average F1 score you get from the combined two-step model?

3.5 What is the difference in points between the macro weighted F1 score for the original model and the combined two-step model?

3.6 What is the new F1 score for the last category (i.e. label_to_replace, the one that had the lowest F1 score in the original model)?

3.7 What is the new F1 score for the other category that you combined with the last category in the two-step model (i.e. label_to_replace_with)?

3.8 Which metric (precision or recall) is now lower for the other category (i.e. label_to_replace_with)?

### 4. Look at examples of misclassifications

What happened in the two-step model? Did everything improve, or did something get worse? We were concerned about the last news category, which had a very low F1 score in the original model. In the two-step model, the F1 score for that category should have gone up.

But for the other category that the original model often confused with the last category, the F1 score might have gone down. In particular, one of the two component metrics, precision or recall, probably went down. (We ask you which one went down in question 3.8 above.)

We might be able to tell what happened from the confusion matrix, but it's also always a good idea to look at actual examples that were misclassified, to see if we can spot any patterns. We can also isolate more specific examples, like test examples that the original model got right, but the two-step model got wrong. Let's do that below.
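The three filtering conditions applied in the cells below can be sketched on toy predictions. All values are made up for illustration, with class 1 standing in for label_to_replace_with.

```python
toy_label_to_replace_with = 1

y_true      = [0, 3, 2, 1, 2, 0]
model1_pred = [0, 3, 2, 1, 1, 0]  # original single model
two_step    = [0, 3, 1, 1, 1, 1]  # combined two-step predictions

selected = [
    i
    for i, (t, p1, p2) in enumerate(zip(y_true, model1_pred, two_step))
    if p2 == toy_label_to_replace_with  # the two-step model predicted this class
    and t != toy_label_to_replace_with  # but that was not the correct label
    and p1 == t                         # and the original model was right
]

print(selected)  # [2, 5]
```

Index 4 is excluded because the original model was also wrong there, and index 3 is excluded because the two-step prediction was correct.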

**CRITICAL NOTE:**  If nothing prints out when you run the code below, there are two possibilities.  The first is that there is some error in the code or variable names you have created in earlier cells.  The second possibility is that given your current train, validation, and test split, the second model predicted the "label_to_replace_with" class and the first model did so too.  This is unlikely but it is possible. In either case, you must go back and re-run the *ENTIRE* notebook to make sure you get a new train, validation, and test split which will allow you to observe the first and second models disagreeing. Please make sure you enter the metric values from this new run into your answers file.

In [65]:
# Make a vector the length of our test set, with 1 if the second model predicted the
# "label_to_replace_with" class, and 0s otherwise
select_predictions = (predictions_2steps == label_to_replace_with)


In [66]:
# Now only keep a 1 if that was not the correct label, i.e. it was a false positive
select_predictions = select_predictions * (test_labels != label_to_replace_with)

In [67]:
# And now only keep a 1 if the original model predicted the correct label instead
select_predictions = select_predictions * (test_labels == predictions_model1.numpy())

In [68]:
# Print out the original and clean text of the examples that met the above conditions
for i in np.where(select_predictions)[0]:

    print('Prediction: model1 = %s, model2 = %s\nText: %s\n\n' %
          (target_names[predictions_model1[i]],
           target_names[predictions_2steps[i]],
           test_texts[i][:1000].replace('\n', ' ')))

**QUESTION:**

4.1 Why do you think the two-step model got these examples wrong, when the original model got them right?

- A. The two-step model saw fewer examples of the "label_to_replace" class, because we replaced them with the "label_to_replace_with" examples. So it didn't learn the kind of text in that class as well as the original model.

- B. In the two-step process, the step 1 model overpredicted the combined class, and the step 2 model overpredicted the "label_to_replace_with" class. A third class is now getting mistaken more often for the "label_to_replace_with" class, than in the original model.

- C. It's probably just random that the original model got these specific examples right and the two-step model got them wrong.



4.2 Is there anything you might try next, to try to make the two-step model better?

- A. Try to balance the training data across classes at each step, or add class weights when calling model.fit.

- B. Try to combine another similar category with the two easily confused ones, for a step 1 model with 18 classes and the step 2 model with 3 classes.

- C. Try both A and B
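
For reference, the class weights mentioned in option A of question 4.2 can be computed in several ways. The sketch below uses one common heuristic, weighting each class by n_samples / (n_classes * class_count), so that rarer classes get larger weights. The helper name `balanced_class_weights` and the toy labels are made up, and this is an illustration of the mechanism rather than a recommended answer.

```python
from collections import Counter

def balanced_class_weights(labels):
    """weight[c] = n_samples / (n_classes * count[c])"""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Toy, imbalanced labels: six examples of class 0 and two of class 1.
toy_labels = [0, 0, 0, 0, 0, 0, 1, 1]
weights = balanced_class_weights(toy_labels)
print(weights)  # class 0 gets about 0.67, class 1 gets 2.0

# Hypothetical usage with a compiled Keras model (not run here):
# model.fit(inputs, labels, class_weight=weights, batch_size=8, epochs=1)
```

Keras accepts such a dictionary through the class_weight argument of model.fit, where it scales each example's contribution to the loss according to its class.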