There are two types of language modeling: causal and masked. This guide covers causal language modeling. Causal language models are frequently used for text generation, in creative applications such as a choose-your-own text adventure or in intelligent coding assistants such as Copilot or CodeParrot.

Causal language modeling predicts the next token in a sequence of tokens, and the model can only attend to tokens on the left. This means the model cannot see future tokens. GPT-2 is an example of a causal language model.
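
The left-only attention described above can be pictured as a lower-triangular mask. The following is a minimal sketch in plain Python, not how GPT-2 is implemented; the helper name `causal_mask` and the example tokens are made up for illustration.

```python
def causal_mask(seq_len):
    """Return a seq_len x seq_len matrix where entry [i][j] is 1 if
    position i may attend to position j (that is, j <= i), else 0."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

# Made-up tokens, one per position
tokens = ["The", "cat", "sat", "down"]
mask = causal_mask(len(tokens))

# Each token can see itself and everything to its left, never the future
for token, row in zip(tokens, mask):
    visible = [t for t, allowed in zip(tokens, row) if allowed]
    print(f"{token!r:>7} attends to {visible}")
```

Each row gains one more visible token than the row before it, which is why the model can be trained on every position of a sequence at once without leaking future tokens.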

This guide illustrates how to:
1. Finetune DistilGPT2 on the r/askscience subset of the ELI5 dataset.
2. Use the finetuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate

In [None]:
import math
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

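# Use Apple's MPS backend when it is available; otherwise fall back to the CPU
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")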

# Load Data

In [None]:
# Load a smaller subset of the r/askscience subset of the ELI5 dataset from the 🤗 Datasets library
# Experiment and make sure everything works before spending more time training on the full dataset
eli5 = load_dataset("eli5_category", split="train[:5000]")

# Split the dataset into train and test sets
eli5 = eli5.train_test_split(test_size=0.2)

# Inspect an example
# NB: the output may look like a lot, but we’re only really interested in the text field
# This is an unsupervised task: no separate labels are required because the next token is the label.
eli5["train"][0]

# Preprocessing

In [None]:
# Load DistilGPT2 tokenizer to process the 'text' subfield
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

In [None]:
# Notice that the 'text' subfield is nested inside 'answers'.
# Extract it from its nested structure with the flatten method
eli5 = eli5.flatten()
eli5["train"][0]
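
To make the effect of flattening concrete, here is a rough sketch of the same idea on a plain dictionary. It is not the 🤗 Datasets implementation; `flatten_record` is a hypothetical helper, and the record is made up and only loosely shaped like an ELI5 example.

```python
def flatten_record(record, parent_key=""):
    """Flatten nested dicts, joining keys with '.' (lists are left as-is)."""
    flat = {}
    for key, value in record.items():
        full_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_record(value, full_key))
        else:
            flat[full_key] = value
    return flat

# Hypothetical record; the field values are invented for illustration
example = {
    "title": "Why is the sky blue?",
    "answers": {"text": ["Short answer.", "Longer answer."], "score": [5, 2]},
}
flat_example = flatten_record(example)
print(list(flat_example.keys()))
```

The nested `text` list ends up under the dotted key `answers.text`, which is the column name the preprocessing step reads from.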

In [None]:
# After flattening, text is its own field: answers.text
# Each answers.text entry is a list of strings. Instead of tokenizing each string separately,
# join the list into a single string so the answers are tokenized together.
# Columns that are no longer needed are dropped in the same step via remove_columns
def preprocess_function(examples):
    return tokenizer([" ".join(x) for x in examples["answers.text"]])

tokenized_eli5 = eli5.map(
    preprocess_function,
    batched=True,
    num_proc=24,
    remove_columns=eli5["train"].column_names,
)

In [None]:
# The dataset now contains the token sequences...
# but some of these are longer than the maximum input length for the model
# Define block_size for splitting; it should be shorter than the model's maximum input length and small enough to fit in GPU RAM
block_size = 128

# Second preprocessing function to concatenate all the sequences
# split the concatenated sequences into shorter chunks defined by block_size
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder that does not fill a whole block. Padding it instead is an
    # option if the model supports it; customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
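
To see what the concatenate-and-chunk step does, the sketch below applies the same logic to a tiny made-up batch. `demo_group_texts` and `toy_batch` are hypothetical names chosen so that the notebook's own `group_texts` and `block_size` are not overwritten; the block size is passed in as a parameter instead.

```python
def demo_group_texts(examples, block_size):
    """Concatenate each column's lists, then cut them into block_size chunks."""
    concatenated = {k: sum(v, []) for k, v in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the tail that does not fill a whole block
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = [chunk[:] for chunk in result["input_ids"]]
    return result

# Three made-up "tokenized answers" of lengths 3, 5, and 2 (10 tokens in total)
toy_batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1], [1, 1]],
}
grouped = demo_group_texts(toy_batch, block_size=4)
print(grouped["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]; tokens 9 and 10 are dropped
```

Note that chunk boundaries ignore where one answer ends and the next begins, so a single block can contain text from more than one answer.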

In [None]:
lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)

In [None]:
# Create a batch of examples with data collator
# NB: It’s more efficient to dynamically pad the sentences to the longest length in a batch during collation...
# instead of padding the whole dataset to the maximum length

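# Use the end-of-sequence token as the padding token and set mlm=False.
# With mlm=False the collator copies the inputs as labels; the model then shifts them
# by one position when computing the loss, so each token is trained to predict the next one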
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
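
The relationship between inputs and labels can be sketched without any library. The example below assumes the common convention that label positions set to -100 are skipped by the loss (the value the collator assigns to padding); the token ids are made up and `next_token_pairs` is a hypothetical helper, not part of 🤗 Transformers.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def next_token_pairs(input_ids, labels):
    """Pair each context with the label one position to its right."""
    pairs = []
    for position in range(len(input_ids) - 1):
        target = labels[position + 1]
        if target != IGNORE_INDEX:
            pairs.append((input_ids[: position + 1], target))
    return pairs

# Made-up token ids; 0 stands in for the padding token
toy_input_ids = [11, 12, 13, 0, 0]
toy_labels = [11, 12, 13, IGNORE_INDEX, IGNORE_INDEX]  # copy of the inputs, padding masked

for context, target in next_token_pairs(toy_input_ids, toy_labels):
    print(f"context {context} -> predict {target}")
```

Only the two real transitions contribute training signal here; the padded positions produce no pairs.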

# Training

In [None]:
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.to(mps_device)

In [None]:
training_args = TrainingArguments(
    output_dir="causal_language_model",
    # NB: recent transformers releases rename this argument to eval_strategy
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)

trainer.train()

# Evaluation

In [None]:
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
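
Perplexity is the exponential of the mean cross-entropy loss, which is why the cell above passes `eval_loss` to `math.exp`. As a small worked sketch with made-up probabilities (the `perplexity` helper is hypothetical), a model that assigns probability 0.25 to every correct token is as uncertain as a uniform choice among four options:

```python
import math

def perplexity(token_probs):
    """Exponential of the mean negative log-likelihood of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# Made-up probabilities that a model assigned to the correct next token
uniform_guess = [0.25, 0.25, 0.25, 0.25]
confident = [0.9, 0.8, 0.95, 0.85]

print(f"{perplexity(uniform_guess):.2f}")  # 4.00
print(perplexity(confident) < perplexity(uniform_guess))  # True
```

Lower perplexity means the model places more probability on the tokens that actually occur.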

# Inference

In [None]:
# Use trained model for inference
# Create a prompt to generate text from
prompt = "Somatic hypermutation allows the immune system to"

In [None]:
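# Inference using a pipeline object
# Pass the in-memory model and tokenizer: the Trainer writes checkpoints to subfolders of
# output_dir, and the tokenizer was never saved there, so loading by directory name would fail
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
generator(prompt)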
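
Text generation is autoregressive: the model predicts one token, appends it to the input, and repeats. The toy sketch below shows that loop with greedy selection over a hand-written lookup table. Everything in it is hypothetical: the scores are invented, the table conditions only on the last token (a real model uses the whole context), and the pipeline's actual decoding strategy depends on its generation settings.

```python
# Hypothetical next-token scores, keyed by the most recent token
TOY_NEXT_TOKEN_SCORES = {
    "the": {"immune": 0.6, "system": 0.3, "<eos>": 0.1},
    "immune": {"system": 0.9, "<eos>": 0.1},
    "system": {"adapts": 0.7, "<eos>": 0.3},
    "adapts": {"<eos>": 1.0},
}

def greedy_generate(prompt_tokens, max_new_tokens=10):
    """Append the highest-scoring next token until <eos> or the length limit."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        # Unknown tokens fall back to ending the sequence
        scores = TOY_NEXT_TOKEN_SCORES.get(tokens[-1], {"<eos>": 1.0})
        next_token = max(scores, key=scores.get)
        if next_token == "<eos>":
            break
        tokens.append(next_token)
    return tokens

print(" ".join(greedy_generate(["the"])))  # the immune system adapts
```

Generation stops either when the end-of-sequence token is chosen or when the limit on new tokens is reached, whichever comes first.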