# Fine Tune BERT for Multi-label Classification
**Based on work from:** [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb)

Using hugging faces' BERT model, I will be training it to classify a string based on multiples classes.

In [104]:
from text_gen_model import TextGenModelManager
from transformers import AutoModelForCausalLM, GenerationConfig, PreTrainedTokenizerFast, AutoTokenizer
from datasets import load_dataset, DatasetDict
import pandas as pd
import numpy as np

## Load Dataset

In [3]:
# just to view the head of the data
class_df = pd.read_csv("./saved_data/hot_encoded_class.csv")
class_df.head()

Unnamed: 0,prompt,nta,yta,esh,nah,info
0,title: (UPDATE) AITA for telling my step-daugh...,1,0,0,0,0
1,title: AITA - I missed my daughter’s award cer...,0,1,0,0,0
2,title: AITA For Barring My Husband From The Be...,1,0,0,0,0
3,title: AITA For Firing An Employee After His P...,0,1,0,0,0
4,title: AITA For Refusing To Crochet Something ...,1,0,0,0,0


In [4]:
del class_df

In [119]:
dataset = load_dataset("csv", data_files="./saved_data/hot_encoded_class.csv")

Found cached dataset csv (/home/cstainsby/.cache/huggingface/datasets/csv/default-3272b440c9a2ece7/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)

100%|██████████| 1/1 [00:00<00:00, 364.94it/s]


In [117]:
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'nta', 'yta', 'esh', 'nah', 'info'],
        num_rows: 276
    })
})

In [115]:
example_row = dataset["train"][0]
example_row

KeyError: "Column train not in the dataset. Current columns in the dataset: ['prompt', 'nta', 'yta', 'esh', 'nah', 'info']"

In [112]:
dataset["train"]

KeyError: "Column train not in the dataset. Current columns in the dataset: ['prompt', 'nta', 'yta', 'esh', 'nah', 'info']"

In [68]:
labels = [label for label in dataset["train"].features.keys() if label not in ["prompt"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
labels

['nta', 'yta', 'esh', 'nah', 'info']

## Preprocess Data
You cant pass direct text into BERT, it takes input ids. To get these, we must tokenize the data.

In [69]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [70]:
text = example_row["prompt"]
print(type(text))
encoding = tokenizer.encode(text, padding="max_length", truncation=True, max_length=128)
encoding

<class 'str'>


[101,
 2516,
 1024,
 1006,
 10651,
 1007,
 9932,
 2696,
 2005,
 4129,
 2026,
 3357,
 1011,
 2684,
 2000,
 1523,
 2175,
 3198,
 2014,
 2613,
 3611,
 1524,
 2043,
 2016,
 2356,
 2033,
 2000,
 3477,
 2005,
 2014,
 4946,
 9735,
 1029,
 4180,
 1024,
 1031,
 2434,
 2695,
 1033,
 1006,
 16770,
 1024,
 1013,
 1013,
 7479,
 1012,
 2417,
 23194,
 1012,
 4012,
 1013,
 1054,
 1013,
 26445,
 10760,
 12054,
 11484,
 1013,
 7928,
 1013,
 20868,
 18037,
 4143,
 1013,
 9932,
 2696,
 1035,
 2005,
 1035,
 4129,
 1035,
 2026,
 1035,
 3357,
 2850,
 18533,
 2121,
 1035,
 2000,
 1035,
 2175,
 1035,
 3198,
 1035,
 2014,
 1013,
 1029,
 21183,
 2213,
 1035,
 3120,
 1027,
 3745,
 1004,
 21183,
 2213,
 1035,
 5396,
 1027,
 16380,
 1035,
 10439,
 1004,
 21183,
 2213,
 1035,
 2171,
 1027,
 16380,
 6491,
 2546,
 1007,
 4931,
 4364,
 1012,
 2009,
 1521,
 1055,
 2042,
 1037,
 2204,
 1016,
 3134,
 2144,
 1045,
 1521,
 2310,
 102]

In [71]:
tokenizer.decode(encoding)

'[CLS] title : ( update ) aita for telling my step - daughter to “ go ask her real dad ” when she asked me to pay for her plane tickets? content : [ original post ] ( https : / / www. reddit. com / r / amitheasshole / comments / irxyza / aita _ for _ telling _ my _ stepdaughter _ to _ go _ ask _ her /? utm _ source = share & utm _ medium = ios _ app & utm _ name = iossmf ) hey guys. it ’ s been a good 2 weeks since i ’ ve [SEP]'

In [72]:
labels

['nta', 'yta', 'esh', 'nah', 'info']

In [120]:
def preprocess_data(examples):
  # take a batch of texts
  text = examples["prompt"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
  # add labels
  labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
  # create numpy array of shape (batch_size, num_labels)
  labels_matrix = np.zeros((len(text), len(labels)))
  # fill numpy array
  for idx, label in enumerate(labels):
    labels_matrix[:, idx] = labels_batch[label]

  encoding["labels"] = labels_matrix.tolist()
  
  return encoding

In [121]:
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)

Loading cached processed dataset at /home/cstainsby/.cache/huggingface/datasets/csv/default-3272b440c9a2ece7/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-ea9db39876a15c6f.arrow


In [122]:
encoded_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 276
    })
})

### Split Dataset
Split the dataset into a training, validation, and testing dataset 

In [133]:
# 90% train + valid, 10% test
trainvalid_test_dataset = encoded_dataset["train"].train_test_split(test_size=0.1)

trainvalid_dataset = trainvalid_test_dataset["train"]
train_valid_dataset =  trainvalid_dataset.train_test_split(test_size=0.1)


train_test_valid_dataset = DatasetDict({
   'train': trainvalid_test_dataset["train"],
   'test': trainvalid_test_dataset['test'],
   'validation': train_valid_dataset['test']})
train_test_valid_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 248
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 28
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 25
    })
})

## Define the Model

In [134]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Train the Model

In [135]:
batch_size = 8
metric_name = "f1"

In [136]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    f"bert-finetuned-sem_eval-english",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #push_to_hub=True,
)

In [156]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [138]:
encoded_dataset['train'][0]['labels']

[1.0, 0.0, 0.0, 0.0, 0.0]

In [139]:
encoded_dataset['train']['input_ids'][0]

[101,
 2516,
 1024,
 1006,
 10651,
 1007,
 9932,
 2696,
 2005,
 4129,
 2026,
 3357,
 1011,
 2684,
 2000,
 1523,
 2175,
 3198,
 2014,
 2613,
 3611,
 1524,
 2043,
 2016,
 2356,
 2033,
 2000,
 3477,
 2005,
 2014,
 4946,
 9735,
 1029,
 4180,
 1024,
 1031,
 2434,
 2695,
 1033,
 1006,
 16770,
 1024,
 1013,
 1013,
 7479,
 1012,
 2417,
 23194,
 1012,
 4012,
 1013,
 1054,
 1013,
 26445,
 10760,
 12054,
 11484,
 1013,
 7928,
 1013,
 20868,
 18037,
 4143,
 1013,
 9932,
 2696,
 1035,
 2005,
 1035,
 4129,
 1035,
 2026,
 1035,
 3357,
 2850,
 18533,
 2121,
 1035,
 2000,
 1035,
 2175,
 1035,
 3198,
 1035,
 2014,
 1013,
 1029,
 21183,
 2213,
 1035,
 3120,
 1027,
 3745,
 1004,
 21183,
 2213,
 1035,
 5396,
 1027,
 16380,
 1035,
 10439,
 1004,
 21183,
 2213,
 1035,
 2171,
 1027,
 16380,
 6491,
 2546,
 1007,
 4931,
 4364,
 1012,
 2009,
 1521,
 1055,
 2042,
 1037,
 2204,
 1016,
 3134,
 2144,
 1045,
 1521,
 2310,
 102]

In [140]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_test_valid_dataset["train"],
    eval_dataset=train_test_valid_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [141]:
trainer.train()



Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.227275,0.92,0.95,0.92


TrainOutput(global_step=155, training_loss=0.22187984835716987, metrics={'train_runtime': 545.0264, 'train_samples_per_second': 2.275, 'train_steps_per_second': 0.284, 'total_flos': 81566624163840.0, 'train_loss': 0.22187984835716987, 'epoch': 5.0})

In [157]:
trainer.evaluate()

{'eval_loss': 0.2272748500108719,
 'eval_f1': 0.92,
 'eval_roc_auc': 0.95,
 'eval_accuracy': 0.92,
 'eval_runtime': 3.4647,
 'eval_samples_per_second': 7.216,
 'eval_steps_per_second': 1.155,
 'epoch': 5.0}

In [160]:
predictions, label_ids, metrics = trainer.predict(train_test_valid_dataset["test"])

In [183]:
import matplotlib.pyplot as plt 


In [185]:
def id_matrix_to_label_list(label_matrix):
    id_list = []
    for label_pred_row in label_matrix:
        index = list(label_pred_row).index(1)
        id = id2label[index]
        id_list.append(id)
    return id_list

def count_frequency(lst):
    freq_dict = {}
    for value in lst:
        if value in freq_dict:
            freq_dict[value] += 1
        else:
            freq_dict[value] = 1
    return freq_dict

def merge_dicts(dict1, dict2):
    merged_dict = dict1.copy()
    for key, value in dict2.items():
        if key in merged_dict:
            merged_dict[key] = (merged_dict[key], value)
        else:
            merged_dict[key] = value
    return merged_dict


In [186]:
real = id_matrix_to_label_list(train_test_valid_dataset["test"]["labels"])
predicted = id_matrix_to_label_list(label_ids)

real_freq = count_frequency(real)
predicted_freq = count_frequency(predicted)

total_freqs = merge_dicts(real_freq, predicted_freq)

{'nta': (26, 26), 'esh': (2, 2)}

In [None]:
x = np.arange(len(labels))  # the label locations
width = 0.25  # the width of the bars
multiplier = 0

fig, ax = plt.subplots(layout='constrained')

for attribute, measurement in total_freqs.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, measurement, width, label=attribute)
    ax.bar_label(rects, padding=3)
    multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Length (mm)')
ax.set_title('Penguin attributes by species')
ax.set_xticks(x + width, labels)
ax.legend(loc='upper left', ncols=3)
ax.set_ylim(0, 250)

plt.show()

## Test it Out

In [143]:
text = """AITA for telling my husband the nanny is in charge?

I want to preface this by saying that I am aware this is a very privileged issue but I’m trying to get some perspective on my opinion.

My husband and I have 3 kids that are 10 months, 3 years and 6 years old. My husband has a high profile job and it means he’s gone often. I work a regular 9-5. We originally used daycare for our oldest but my middle was born right when the pandemic began, so we hired a nanny. She originally worked when I did. But by the time baby came around, I was very overwhelmed doing bath and bedtime on my own, on top of developing postpartum depression. After a breakdown, we spoke with the nanny and she agreed to adjust her hours so she’s helping me with dinner, bath and bed.

We’ve gotten close over the past 6 months doing this. In many ways, she’s become like a third parent to the kids. She’s so good with them. We’ve created a routine that works well. I tend to the baby during bath and bed, she handles the older 2. It’s a nice rhythm and my mental health has gotten so much better.

My husband isn’t traveling all the time but most nights, he isn’t even home for dinner and bed. He will help me weekends he’s home. But because he’s gone so often, he’s reluctant to be firm with the kids.

There are times he’s come home when our nanny is there. He tries to help her with bath and bed, but allows the boys to rough house, lets them break the routine and it seriously throws them off and delays bedtime.

My nanny shared with me she feels awkward. Obviously she doesn’t want to undermine her employer but it just makes her job harder. But my husband also doesn’t want her to go home when he arrives as he says he can’t handle it alone.

I told him if that’s the case, then he needs to defer to the nanny and follow her lead. She knows our boys best and she has to deal with the aftermath when they don’t listen and give her a hard time.

My husband feels that she’s just an employee and he’s the dad. His salary does pay for her. However, I don’t feel this is fair to her.

I told him he either follows her lead for bed and bath or he doesn’t help at all. He told me I’m allowing the nanny to take over and replace him. AITA?"""

In [145]:
encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}

outputs = trainer.model(**encoding)

In [146]:
logits = outputs.logits
logits.shape

torch.Size([1, 5])

In [147]:
# apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
# turn predicted id's into actual label names
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
print(predicted_labels)

['nta']


## Save the Model For Later Use

In [171]:
trainer.save_model("./saved_models/store/AITAclassmodel/")