# Fine-Tune a Model with Hugging Face's `Trainer` API

* Loosely derived from: https://huggingface.co/docs/transformers/tasks/sequence_classification
* Train, evaluate, output metrics

### `pip install` necessary packages, then restart runtime.

In [None]:
# uncomment lines below, run this cell, then restart the runtime before continuing
# !pip install transformers > out1
# !pip install datasets > out2
# !pip install numpy==1.23.4 > out3
# !pip install sentencepiece > out4
# !pip install wandb > out5

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy
import time
import torch

from datasets import load_dataset, Dataset
from sklearn import metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay
from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoConfig

In [None]:
# Log in to your Weights & Biases (W&B) account so that later cells can
# create and log a run.
# NOTE(review): presumably prompts for an API key when none is configured
# in this environment — confirm.
import wandb
wandb.login()

### Verify if a GPU is available; print some details

In [None]:
# Report whether CUDA is usable and list every visible GPU.
# `torch.cuda.current_device()` initialises CUDA and raises when no GPU (or
# no CUDA build of torch) is present, so it is only queried when a device
# is actually available; the cell then still runs on CPU-only machines.
gpu_available = torch.cuda.is_available()
print(f'Is a GPU available? {gpu_available}')
device_count = torch.cuda.device_count()
print(f'Number of GPUs available: {device_count}')
if gpu_available:
    print(f'Current GPU index: {torch.cuda.current_device()}')
else:
    print('Current GPU index: n/a (no CUDA device detected)')
for i in range(device_count):
    print(f'Device {i}:')
    # torch.device gives a readable identifier such as "cuda:0"; the
    # torch.cuda.device context manager only prints an object address.
    print(f'\t{torch.device("cuda", i)}')
    print(f'\t{torch.cuda.get_device_name(i)}')

In [None]:
!nvidia-smi

### Load model and dataset

In [None]:
# id2label: integer class index -> class name.
# label2id: class name -> integer class index (the inverse of id2label).
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {name: idx for idx, name in id2label.items()}

# Load the pre-trained BERT weights for sequence classification, attaching
# the label mappings to the model, then the matching fast tokenizer.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    id2label=id2label,
    label2id=label2id,
)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [None]:
# Load the `rotten_tomatoes` dataset and display the returned object so its
# splits and columns can be inspected in the notebook output.
rt = load_dataset('rotten_tomatoes')
rt

Examine a positive and a negative example from the `rotten_tomatoes` dataset

In [None]:
# positive review (first example of the training split):
print(rt['train'][0])
print()
# negative review (seventh example from the end of the training split):
print(rt['train'][-7])

#### Tokenize dataset

In [None]:
def tokenize(batch):
    """Tokenize the ``text`` column of a batch with the notebook's tokenizer.

    Padding and truncation are both enabled, and the tokenizer's output is
    returned unchanged so it can be used as a ``Dataset.map`` function.
    """
    texts = batch['text']
    return tokenizer(texts, truncation=True, padding=True)

# Load the three splits, tokenize each one as a single batch (batch size ==
# split length, so padding is computed once over the whole split), and then
# expose only the model inputs and the label as torch tensors.
train_dataset, val_dataset, test_dataset = [
    raw_split.map(tokenize, batched=True, batch_size=len(raw_split))
    for raw_split in load_dataset('rotten_tomatoes', split=['train', 'validation', 'test'])
]
for tokenized_split in (train_dataset, val_dataset, test_dataset):
    tokenized_split.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
# model._get_name()
# model.name_or_path

### Train your model
* Set up a function to compute metrics: accuracy, precision, recall, f1
* Define training arguments and instantiate `Trainer`

In [None]:
# For wandb, declare some global variables.
dtg_run = time.strftime(f'%d%H%M%b%y').upper() #ex. '112036OCT21' ... add underscores for readability if desired

PROJ = 'basic-demo'
TAGS = ['rotten_tomatoes','dsp']
TITLE = f"rot_tom_finetuned_model_{dtg_run}"
BASE_MODEL = model._get_name()
NUM_EPOCHS = 3
BATCH_SIZE = 8
RNDM_SEED = 42
STEPS = 200

In [None]:
# Determine parameters to track for each run
wandb_config_dict = dict(base_model=BASE_MODEL,
                         epochs=NUM_EPOCHS,
                         batch_size=BATCH_SIZE,
                         seed=RNDM_SEED,
                         steps=STEPS)

# Initialize
wandb_init_dict = dict(name=TITLE,
                       project=PROJ,
                       tags=TAGS,
                       notes="this is an example run",
                       config=wandb_config_dict)

wandb.init(**wandb_init_dict)


In [None]:
def compute_metrics(pred):
    """Compute accuracy, F1, precision and recall for a prediction object.

    ``pred`` must expose ``label_ids`` (the true labels) and ``predictions``
    (per-class scores); the predicted class is the argmax over the last
    axis. Precision, recall and F1 use binary averaging.

    Returns a dict with the keys ``accuracy``, ``f1``, ``precision`` and
    ``recall``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # The fourth element of the returned tuple (support) is not needed.
    prf_scores = precision_recall_fscore_support(y_true, y_pred, average='binary')
    precision, recall, f1 = prf_scores[:3]
    return dict(
        accuracy=accuracy_score(y_true, y_pred),
        f1=f1,
        precision=precision,
        recall=recall,
    )

# MODIFIED
# Argument reference:
# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.set_logging
# NOTE(review): the TrainingArguments docs describe `do_eval` as not used
# directly by Trainer; with no evaluation strategy set, evaluation presumably
# does not run during training — confirm whether that is intended.
training_args = TrainingArguments(
    # where checkpoints and logs are written
    output_dir='./results',
    logging_dir='./logs',
    # logging cadence and destination
    logging_steps=STEPS,
    report_to="wandb",  # NEW ARGUMENT FOR THIS LESSON
    # reproducibility and schedule
    seed=RNDM_SEED,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    do_eval=True,
)

# Wire the model, the datasets and the metric function into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# trainer.args

#### Training Loop

In [None]:
# 8530 training examples
# batch size 16: ceil(8530 / 16) = 534 steps per epoch -> 534 * 3 epochs = 1602 steps total.
# batch size 8:  ceil(8530 / 8) = 1067 steps per epoch -> 1067 * 3 epochs = 3201 steps total.
trainer.train()

### Test your model

In [None]:
# Run the trained model on the tokenized test split and keep the result
# for the inspection cells below.
test_out = trainer.predict(test_dataset=test_dataset)

### In a notebook, don't forget to finish logging to W&B!

In [None]:
# Close the W&B run that was opened with wandb.init().
wandb.finish()

**Examine the elements of your test output**

In [None]:
# Display everything trainer.predict() returned for the test split.
test_out

**Verify the number of positive/negative predictions**

In [None]:
# Predicted class index per test example: argmax over the last axis of
# the raw predictions (0 = NEGATIVE, 1 = POSITIVE per id2label).
test_out.predictions.argmax(-1)

In [None]:
# Count all _predicted_ "Positive" classifications: the predicted class
# index is 0 or 1, so the number of non-zero entries is the number of 1s.
np.count_nonzero(test_out.predictions.argmax(-1))

**Make a confusion matrix**

In [None]:
# Background reading:
#   https://en.wikipedia.org/wiki/Confusion_matrix
#   https://www.w3schools.com/python/python_ml_confusion_matrix.asp
#   https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html#sklearn.metrics.confusion_matrix
#
# Layout of the plotted matrix (rows = actual, columns = predicted):
#   True Negative,  False Positive
#   False Negative, True Positive

# True labels and predicted classes for the test split.
actual = test_out.label_ids
preds = test_out.predictions.argmax(-1)

# Flattened, the matrix reads: tn, fp, fn, tp
results_confusion_matrix = metrics.confusion_matrix(actual, preds)

cm_display = metrics.ConfusionMatrixDisplay(
    confusion_matrix=results_confusion_matrix,
    display_labels=["Negative", "Positive"],
)

# Other colormaps worth trying: 'inferno', 'gray', 'Reds', 'binary', 'flag'.
# Full list: https://matplotlib.org/stable/tutorials/colors/colormaps.html
cm_display.plot(cmap='hot')
plt.show()

**Examine misclassified examples**

In [None]:
# Indices in the test split where the prediction disagrees with the label.
misclassified = np.flatnonzero(actual != preds)

# Split the errors by type using the labels and predictions themselves.
# Taking "the first four" and "the last four" misclassified indices only
# matches the headings below when the test split happens to be ordered
# positives-then-negatives and each error type has at least four examples.
false_negatives = np.flatnonzero((actual == 1) & (preds == 0))
false_positives = np.flatnonzero((actual == 0) & (preds == 1))

# Show up to four of each: the earliest false negatives and the latest
# false positives (latest first).
fn_sample = false_negatives[:4]
fp_sample = false_positives[-1:-5:-1]
print(fn_sample, fp_sample)

print('Positive misclassified as negative:')
print('===================================')
for ex in fn_sample:
    print(rt['test'][int(ex)]) # cast np.int64 as int

print('\nNegative misclassified as positive:')
print('===================================\n')
for ex in fp_sample:
    print(rt['test'][int(ex)]) # cast np.int64 as int