# <font color="#003660">Applied Machine Learning for Text Analysis (M.184.5331)</font>


# <font color="#003660">Week 6: Generating Texts with Transformers</font>

# <font color="#003660">Notebook 2: Fine-tuning a Masked Language Model</font>

<center><br><img width=256 src="https://raw.githubusercontent.com/olivermueller/aml4ta-2021/main/resources/dag.png"/><br></center>

<p>
<center>
<div>
    <font color="#085986"><b>By the end of this lesson, you ...</b><br><br>
        ... are able to fine-tune a masked language model on your own data, which is useful for adapting an encoder model to a new domain.
    </font>
</div>
</center>
</p>

This notebook is heavily inspired by the following excellent sources:


*   Tunstall et al. (2021): Natural Language Processing with Transformers. O'Reilly. https://www.oreilly.com/library/view/natural-language-processing/9781098103231/
*   Hugging Face (2021): Transformer Models - Hugging Face Course. https://huggingface.co/course/



# How to Fine-tune a Masked Language Model?

For many NLP applications, you can simply take a pre-trained model from the Hugging Face Hub and fine-tune it directly on your data for the task at hand (e.g., sentiment analysis). This approach will usually produce good results, provided that the corpus used for pretraining is not too different from the corpus used for fine-tuning.

However, if your dataset is very different from the dataset used for pre-training, this approach might fail. In such cases, you can boost the performance of many downstream tasks by first fine-tuning *the language model* (not the model for the actual task of interest!) on in-domain data.

The figure below illustrates this process, which was first proposed by [Howard and Ruder in 2018](https://arxiv.org/abs/1801.06146).

<center><img width=600 src="https://raw.githubusercontent.com/olivermueller/aml4ta-2021/main/resources/ulmfit.png"/><br></center>

In this notebook, we go through this process by fine-tuning a [masked language model](https://youtu.be/mqElG5QJWUg).

# Import Packages

In [None]:
!pip install transformers[sentencepiece]
!pip install datasets

In [None]:
import pandas as pd
import numpy as np
import math
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from transformers import TrainingArguments
from transformers import Trainer

# Load Pre-trained Model

First, we load a model for masked language modeling and the corresponding tokenizer from the model hub.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
model_name = "distilbert-base-uncased"

In [None]:
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Testdrive the Model 🚗

Let's see what missing words the pre-trained model generates.

In [None]:
text = "This is a great [MASK]."

In [None]:
input_ids = tokenizer(text, return_tensors="pt").to(device)
input_ids

In [None]:
token_logits = model(**input_ids).logits
token_logits

In [None]:
token_logits.shape

Identify the location of the [MASK] and retrieve its logits. We then pick the [MASK] candidates with the highest logits.
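Before running this on the real model, the selection step can be sketched on toy data. The vocabulary and logits below are made up for illustration (they are not outputs of `distilbert-base-uncased`), and `softmax` and `top_k` are hypothetical helper names. The sketch only shows that ranking by logit and ranking by softmax probability give the same order.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, k):
    """Return the k (index, probability) pairs with the highest logits."""
    probs = softmax(logits)
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]

# Hypothetical toy vocabulary and logits for the [MASK] position
toy_vocab = ["movie", "idea", "banana", "film", "day"]
toy_logits = [4.0, 2.5, -1.0, 3.5, 2.0]

for index, prob in top_k(toy_logits, k=3):
    print(f"{toy_vocab[index]:>6}: {prob:.3f}")
```

The cells below do the same with `torch.topk`, applied to a logits vector with one entry per token in the model's vocabulary.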

In [None]:
mask_token_index = torch.where(input_ids["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
mask_token_logits

In [None]:
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
top_5_tokens

Replace the [MASK] with each of the top candidates.

In [None]:
for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

# Prepare a Dataset for Fine-tuning

Now let's fine-tune the model on domain-specific texts. We will use the well-known IMDB movie reviews dataset for this purpose.

In [None]:
imdb_dataset = load_dataset("imdb")
imdb_dataset

In [None]:
imdb_dataset["train"][0]

Tokenize the texts and remove unneeded columns.

In [None]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    return result

In [None]:
tokenized_datasets = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
tokenized_datasets

In [None]:
tokenized_datasets["train"][0]

For masked language modeling, a [common preprocessing step](https://youtu.be/8PmhEIXhBvI) is to concatenate all the examples and then split the whole corpus into chunks of equal size. This is quite different from our usual approach, where we simply tokenize individual examples. Why concatenate everything together? The reason is that individual examples might get truncated if they’re too long, and that would result in losing information that might be useful for the language modeling task! 

The function below, taken from https://huggingface.co/course/chapter7/3?fw=pt, does exactly this, and some other preprocessing steps.
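To see what concatenating and chunking does, here is a minimal sketch on a hand-made batch. It follows the same logic as the function in the next cell but takes the chunk size as an argument. The name `group_into_chunks`, the tiny chunk size of 4, and the token IDs are illustrative assumptions.

```python
def group_into_chunks(examples, size):
    """Concatenate every column, then cut it into chunks of `size` tokens."""
    # Join the per-example lists of each column into one long list
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the incomplete remainder at the end
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    # Labels start out as a copy of the inputs
    result["labels"] = result["input_ids"].copy()
    return result

# Three hypothetical tokenized examples with 4 + 3 + 6 = 13 tokens in total
toy_batch = {
    "input_ids": [[101, 7, 8, 102], [101, 9, 102], [101, 5, 6, 4, 3, 102]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1]],
}

grouped = group_into_chunks(toy_batch, size=4)
for chunk in grouped["input_ids"]:
    print(chunk)
```

Note that chunks can span example boundaries, and that the 13th token is dropped because it does not fill a complete chunk.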

In [None]:
chunk_size = 128

def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split into chunks of chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

During the above preprocessing, we added a new column `labels` to the dataset. The labels are simply copies of the token IDs of the input sequence. As you will see shortly, during training some IDs in the input sequences are replaced with [MASK]. After the replacement, the `labels` column still contains the "truth" for the masked positions; the data collator sets the labels of all other positions to -100 so that the loss ignores them.

In [None]:
lm_datasets["train"][1]["input_ids"][0:10]

In [None]:
lm_datasets["train"][1]["labels"][0:10]

# Fine-tune with Trainer API

To replace some input tokens with [MASK], we can use the `DataCollatorForLanguageModeling` class, which performs the replacement on the fly during training.
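The following pure-Python sketch illustrates the masking rule from the BERT paper, which the collator follows by default as far as we know: each token is selected with probability `mlm_probability`, and a selected token is replaced with [MASK] 80% of the time, replaced with a random token 10% of the time, and left unchanged 10% of the time. The function name `mask_tokens` and the toy IDs are assumptions. Unlike the real collator, this sketch does not protect special tokens such as [CLS] and [SEP].

```python
import random

MASK_ID = 103        # id of [MASK] in the BERT uncased vocabulary (illustrative)
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips

def mask_tokens(input_ids, vocab_size, mlm_probability=0.15, rng=None):
    """Return (masked_inputs, labels) following the 80/10/10 rule."""
    rng = rng or random.Random()
    inputs, labels = list(input_ids), []
    for pos, token in enumerate(input_ids):
        if rng.random() >= mlm_probability:
            labels.append(IGNORE_INDEX)  # not selected: no loss at this position
            continue
        labels.append(token)  # selected: the original token is the target
        roll = rng.random()
        if roll < 0.8:
            inputs[pos] = MASK_ID  # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[pos] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels

toy_ids = list(range(1000, 1040))  # 40 hypothetical token IDs
masked, labels = mask_tokens(toy_ids, vocab_size=30522, rng=random.Random(0))
print(masked)
print(labels)
```

Because the masking is random and happens at batch time, the same chunk is masked differently in every epoch.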

In [None]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [None]:
samples = [lm_datasets["train"][i] for i in range(2)]

for chunk in data_collator(samples)["input_ids"]:
  print(f"\n'>>> {tokenizer.decode(chunk)}'")

Let's downsample our dataset so that we don't have to wait too long.

In [None]:
train_size = 10000
test_size = int(0.1 * train_size)

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)

downsampled_dataset

Now we can finally start fine-tuning our model with the Trainer API.
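As a rough sanity check on the schedule configured in the next cell, the arithmetic below estimates how many optimizer steps to expect. It assumes a single device, no gradient accumulation, and the `TrainingArguments` default of 3 epochs. `training_schedule` is a hypothetical helper, not part of the Trainer API.

```python
import math

def training_schedule(num_examples, batch_size, num_epochs):
    """Return (steps per epoch, total steps), counting the last partial batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * num_epochs

# Values taken from this notebook: 10,000 training chunks, batch size 128
steps_per_epoch, total_steps = training_schedule(10_000, 128, 3)
logging_steps = 10_000 // 128  # same formula as in the next cell

print(f"steps per epoch: {steps_per_epoch}")
print(f"total steps:     {total_steps}")
print(f"logging every:   {logging_steps} steps")
```

With these numbers the training loss is logged roughly once per epoch.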

In [None]:
batch_size = 128
logging_steps = len(downsampled_dataset["train"]) // batch_size

training_args = TrainingArguments(
    output_dir=f"{model_name}-mlm-finetuned-imdb",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=torch.cuda.is_available(),  # mixed precision requires a GPU
    logging_steps=logging_steps,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
)

Before we start fine-tuning, we calculate the original model's (pre-trained, but not fine-tuned) [perplexity](https://youtu.be/NURcDHhYe98) as a benchmark. 
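Perplexity is the exponential of the average cross-entropy loss, which is why the cells below apply `math.exp` to `eval_loss`. The minimal sketch below uses made-up probabilities that a model assigns to the correct tokens; `perplexity` is a hypothetical helper name.

```python
import math

def perplexity(token_probs):
    """Exponential of the mean negative log-likelihood of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# Hypothetical probabilities assigned to the correct token at 8 masked positions
uniform_guess = [0.25] * 8    # model hesitates between 4 equally likely tokens
confident_guess = [0.9] * 8   # model is almost always right

print(f"uniform:   {perplexity(uniform_guess):.2f}")
print(f"confident: {perplexity(confident_guess):.2f}")
```

A model that guesses uniformly among four candidates has a perplexity of 4, so lower values mean the model is less "surprised" by the held-out text.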

In [None]:
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Perform the fine-tuning!

In [None]:
trainer.train()

Calculate perplexity again.

In [None]:
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

# Testdrive the Fine-tuned Model 🛫

Let's see which missing tokens our fine-tuned model predicts (the code below is copied from the test drive above).

In [None]:
text = "This is a great [MASK]."
input_ids = tokenizer(text, return_tensors="pt").to(device)
token_logits = model(**input_ids).logits
mask_token_index = torch.where(input_ids["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")