# transformers: LM finetuning for sequence classification

In [None]:
# IPython magics: automatically reload imported modules (e.g. after editing
# hf_utils) before running each cell, and render matplotlib figures inline
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
# make modules in the parent directory importable
# (presumably where hf_utils lives — verify against the repository layout)
sys.path.append('..')

In [None]:
import numpy as np
import torch
import evaluate
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)

from hf_utils import (
    load_yelp,
    load_imdb,
    DistilBertSeqClassif,
    DistilGPT2SeqClassif
)

In [None]:
# fix the seed of the Python, NumPy and PyTorch RNGs for reproducible runs
RANDOM_SEED = 123
set_seed(RANDOM_SEED)

## Load data

In [None]:
# load the tiny IMDb subset (use `ds = load_yelp(tiny=True)` for Yelp instead)
ds = load_imdb(tiny=True)

print(ds)

In [None]:
# class names of the 'label' column; position i holds the name of class id i
label_feature = ds['train'].features['label']
label_names = label_feature.names

print(label_names)

## Create model

In [None]:
# Hugging Face Hub checkpoint to finetune
# (alternative: model_name = 'google-bert/bert-base-cased')
model_name = 'distilbert/distilbert-base-uncased'

In [None]:
# load the tokenizer that belongs to the chosen checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

print(tokenizer)

In [None]:
# create seq. classifier
#
# NOTE: the weights are kept in float32 on purpose. Loading them directly in
# bfloat16 means the optimizer updates bf16 parameters in place, and with
# bf16's ~8 bits of significand precision the small AdamW steps of typical
# finetuning learning rates (e.g. 2e-5) are largely rounded away, so training
# barely moves the weights. To use bf16, enable mixed precision instead
# (TrainingArguments(bf16=True)), which computes in bf16 under autocast while
# the parameters stay in float32.
id2label = dict(enumerate(label_names))  # class id -> class name
label2id = {label: idx for idx, label in id2label.items()}  # inverse map

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    device_map='cpu',
    torch_dtype=torch.float32,  # full-precision weights for stable training
    num_labels=len(label_names),  # set number of target labels
    id2label=id2label,
    label2id=label2id,
)

print(model)

In [None]:
# create BERT-like seq. classifier
# model = DistilBertSeqClassif(
#     num_labels=len(label_names),
#     num_hidden=None,
#     activation='leaky_relu',
#     drop_rate=None
# )

# print(model)

In [None]:
# create GPT-like seq. classifier
# model = DistilGPT2SeqClassif(
#     num_labels=len(label_names),
#     num_hidden=None,
#     activation='leaky_relu',
#     drop_rate=None
# )

# print(model)

## Train

In [None]:
# preprocess data
def preprocess(examples):
    """Tokenize the ``'text'`` field of a batch of examples.

    Special tokens (e.g. the leading CLS token) are added and sequences are
    truncated to the tokenizer's maximum length. No padding is applied here;
    padding is deferred until batches are assembled. Plain Python lists are
    returned (no tensors) so ``Dataset.map`` can store the result.
    """
    texts = examples['text']
    encoding = tokenizer(
        texts,
        truncation=True,
        padding=False,
        add_special_tokens=True,
        return_tensors=None,
    )
    return encoding

# tokenize every split, feeding `preprocess` whole batches of examples at once
ds_preprocessed = ds.map(function=preprocess, batched=True)

print(ds_preprocessed)

In [None]:
# create data collator: pads each batch to its longest sequence when the batch
# is assembled and returns PyTorch tensors
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    return_tensors='pt',
    padding='longest',  # same strategy as padding=True
)

In [None]:
# create evaluation function
accuracy_metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: an ``EvalPrediction`` (or a ``(predictions, labels)``
            sequence) whose first element holds the model predictions and
            whose second element holds the reference label ids.

    Returns:
        The result dict of the loaded ``accuracy`` metric.
    """
    # Index instead of 2-way unpacking: EvalPrediction may carry additional
    # elements (e.g. inputs or losses) beyond predictions and labels.
    logits, labels = eval_pred[0], eval_pred[1]

    # When the model returns several outputs, Trainer hands over a tuple of
    # arrays; the classification logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    # predicted class = index of the largest logit along the class axis
    preds = np.argmax(logits, axis=-1)

    return accuracy_metric.compute(
        predictions=preds,
        references=labels
    )

In [None]:
# set training args
# Evaluation and checkpointing both happen once per epoch; the two strategies
# must match for load_best_model_at_end to pick the best saved checkpoint.
training_args = TrainingArguments(
    output_dir='your-model',
    push_to_hub=False,
    # optimization
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # evaluation / checkpointing
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
)

In [None]:
# create trainer
# NOTE(review): the 'test' split doubles as the evaluation set, so with
# load_best_model_at_end it also drives checkpoint selection; consider a
# separate validation split if test accuracy should stay unbiased.
train_split = ds_preprocessed['train']
eval_split = ds_preprocessed['test']

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=train_split,
    eval_dataset=eval_split,
    compute_metrics=compute_metrics,
)

In [None]:
# start training
# (the TrainOutput returned by train() is left as the cell's last expression
# so the notebook displays it)
trainer.train()