There are two types of language modeling: causal and masked. Causal language models are frequently used for text generation. They can power creative applications, such as choose-your-own-adventure text games, or intelligent coding assistants like Copilot or CodeParrot.

Causal language modeling predicts the next token in a sequence of tokens, and the model can only attend to tokens on the left. This means the model cannot see future tokens. GPT-2 is an example of a causal language model.
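To make "only attend to tokens on the left" concrete, here is a minimal pure-Python sketch of a causal attention mask for a sequence of `n` tokens, where row `i` lists which positions token `i` may attend to. Models like GPT-2 build an equivalent lower-triangular mask as a tensor internally; the helper name `causal_mask` is hypothetical and only illustrates the idea.

```python
def causal_mask(n: int) -> list[list[int]]:
    """Return an n x n lower-triangular mask: 1 = may attend, 0 = blocked."""
    # Token i can see positions 0..i (itself and everything to its left).
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


for row in causal_mask(4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

Every entry above the diagonal is 0, so no token can look at a future position.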

This guide illustrates how to:
1. Finetune DistilGPT2 on a subset of the ELI5-Category dataset (`eli5_category`).
2. Use the finetuned model for inference.

# Libraries

In [1]:
%pip install transformers datasets evaluate accelerate torch

Note: you may need to restart the kernel to use updated packages.


In [2]:
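import math
import torch
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

# Pick an accelerator if one is available: CUDA, then Apple-silicon MPS, then CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)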

# Load Data

In [3]:
# Load the first 5,000 training examples of the ELI5-Category dataset from 🤗 Datasets
# Experiment on this small subset and make sure everything works before spending more time training on the full dataset
eli5 = load_dataset("eli5_category", split="train[:5000]")

# Split the dataset into train and test sets
eli5 = eli5.train_test_split(test_size=0.2)

# Inspect an example
# NB: the output may look like a lot, but we only need the 'text' field nested inside 'answers'
# This is a self-supervised task: no separate labels are needed because the next token is the label
eli5["train"][0]

{'q_id': '5oeiz7',
 'title': "Why is it easier to fall asleep when you're cold as opposed to it being harder when you're hot?",
 'selftext': '',
 'category': 'Other',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['dcip69u', 'dcip85b'],
  'text': ["Cells are more excited and move more when heated, when cooled they slow down. Sort of like when you microwave popcorn, if it's heated it'll pop and jump all around full of energy :) When there is an absence of heat (aka, cold) it is restful and unmoving. Your brain can rest easier when it isn't popping :)",
   "The body naturally cools down when it wants to sleep or is sleeping. A low power state, if you will. The body is better at regulating temperature when it's cooler and under blankets that when it's sweating and under blankets. Just a guess."],
  'score': [8, 4],
  'text_urls': [[], []]},
 'title_urls': ['url'],
 'selftext_urls': ['url']}
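Because the next token is the label, every tokenized text supplies its own training targets. A minimal sketch, using made-up integer token IDs rather than real DistilGPT2 IDs: each prefix `tokens[:i+1]` is paired with the target `tokens[i+1]`. The helper `next_token_pairs` is hypothetical and only illustrates the idea.

```python
def next_token_pairs(tokens):
    """Pair each prefix with the token that follows it (the causal LM target)."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]


# Hypothetical token IDs standing in for a tokenized sentence.
for context, target in next_token_pairs([464, 3290, 318, 4171]):
    print(context, "->", target)
# [464] -> 3290
# [464, 3290] -> 318
# [464, 3290, 318] -> 4171
```

In practice the model computes all of these predictions in a single forward pass, using the causal mask to keep each position from seeing its own target.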

# Preprocessing

In [4]:
# Load DistilGPT2 tokenizer to process the 'text' subfield
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

In [5]:
# Notice that the 'text' subfield is nested inside 'answers'
# Extract it from its nested structure with the flatten method
eli5 = eli5.flatten()
eli5["train"][0]

{'q_id': '5oeiz7',
 'title': "Why is it easier to fall asleep when you're cold as opposed to it being harder when you're hot?",
 'selftext': '',
 'category': 'Other',
 'subreddit': 'explainlikeimfive',
 'answers.a_id': ['dcip69u', 'dcip85b'],
 'answers.text': ["Cells are more excited and move more when heated, when cooled they slow down. Sort of like when you microwave popcorn, if it's heated it'll pop and jump all around full of energy :) When there is an absence of heat (aka, cold) it is restful and unmoving. Your brain can rest easier when it isn't popping :)",
  "The body naturally cools down when it wants to sleep or is sleeping. A low power state, if you will. The body is better at regulating temperature when it's cooler and under blankets that when it's sweating and under blankets. Just a guess."],
 'answers.score': [8, 4],
 'answers.text_urls': [[], []],
 'title_urls': ['url'],
 'selftext_urls': ['url']}
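`flatten()` lifts each nested field into a top-level column whose name joins the parent and child keys with a dot, as the `answers.*` columns above show. A pure-Python sketch of the same renaming on a single record, assuming only dict nesting; `flatten_record` is a hypothetical helper, not the 🤗 Datasets implementation.

```python
def flatten_record(record, parent=""):
    """Recursively lift nested dict fields to the top level using dotted keys."""
    flat = {}
    for key, value in record.items():
        name = f"{parent}.{key}" if parent else key
        if isinstance(value, dict):
            # Recurse into the nested dict, prefixing its keys with this name.
            flat.update(flatten_record(value, name))
        else:
            flat[name] = value
    return flat


example = {"q_id": "5oeiz7", "answers": {"a_id": ["dcip69u"], "score": [8]}}
print(flatten_record(example))
# {'q_id': '5oeiz7', 'answers.a_id': ['dcip69u'], 'answers.score': [8]}
```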

In [11]:
# After flattening, the text is its own field: answers.text
# Each example holds a list of answers; join them into one string so they are tokenized jointly rather than one answer at a time
# Columns that are no longer needed can be removed during this step
def preprocess_function(examples):
    return tokenizer([" ".join(x) for x in examples["answers.text"]])

tokenized_eli5 = eli5.map(
    preprocess_function,
    batched=True,
    num_proc=48,
    remove_columns=eli5["train"].column_names,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1049 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1058 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1924 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3080 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (7291 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (4615 > 1024). Running this sequence through the model will result in indexing errors

Token indices sequence length is longer than the specified maximum sequence length for this model (1727 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1145 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2788 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1187 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1031 > 1024). Running this sequence through the model will result in indexing errors
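`preprocess_function` joins each example's list of answers into one string before tokenizing, which is why some examples exceed DistilGPT2's 1,024-token limit. A minimal sketch of that join and a length check, assuming a whitespace split as a stand-in for the real tokenizer (GPT-2's byte-level BPE produces at least as many tokens as whitespace-separated words, usually more) and a hypothetical `max_length` of 4 so the toy data triggers it:

```python
def join_answers(batch):
    """Mirror preprocess_function's join: one string per example."""
    return [" ".join(answers) for answers in batch["answers.text"]]


def too_long(texts, max_length):
    """Flag texts whose whitespace-token count exceeds max_length."""
    # str.split() is only a stand-in for the DistilGPT2 tokenizer.
    return [len(text.split()) > max_length for text in texts]


batch = {"answers.text": [["Cells slow down.", "Just a guess."], ["Short answer."]]}
joined = join_answers(batch)
print(joined)  # ['Cells slow down. Just a guess.', 'Short answer.']
print(too_long(joined, max_length=4))  # [True, False]
```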

In [None]:
from itertools import chain

# The dataset now contains the token sequences,
# but some of them are longer than the model's maximum input length (1024 tokens for DistilGPT2)
# Define block_size for splitting: it must not exceed the maximum input length and should be small enough to fit in GPU memory
block_size = 128

# Second preprocessing function: concatenate all the sequences,
# then split the concatenated sequence into chunks of block_size tokens
def group_texts(examples):
    # Concatenate all texts (chain avoids the quadratic cost of sum(lists, []))
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder that doesn't fill a whole block; you could pad instead if the model supported it
    # Customize this part to your needs
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
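A small worked example of the chunking step on toy token IDs, assuming a block size of 4 instead of 128 so the output is easy to read. This self-contained copy of `group_texts` concatenates with `itertools.chain`:

```python
from itertools import chain

block_size = 4  # tiny block size, for illustration only


def group_texts(examples):
    # Concatenate every list in each column into one long list.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that does not fill a whole block.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split each column into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1], [1, 1]],
}
out = group_texts(batch)
print(out["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

The ten input tokens form two full blocks; tokens 9 and 10 are dropped because they do not fill a third one, and `labels` is a copy of `input_ids`.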

In [None]:
lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)

In [None]:
# Create a batch of examples with a data collator
# NB: it's more efficient to dynamically pad the sequences to the longest length in each batch during collation
# than to pad the whole dataset to the maximum length

# Use the end-of-sequence token as the padding token and set mlm=False.
# With mlm=False, the collator copies input_ids into labels (padding positions become -100, which the loss ignores);
# the model itself shifts the labels by one position when computing the next-token loss
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
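With `mlm=False`, the collator pads each batch to its longest sequence and builds `labels` from `input_ids`, setting every label position that holds the pad token ID to -100 so the loss skips it. A pure-Python sketch of that behavior, assuming right padding and a `pad_id` of 50256 (GPT-2's end-of-sequence ID, reused as the pad token above); `collate_causal_lm` is a hypothetical helper, not the library implementation.

```python
def collate_causal_lm(batch, pad_id):
    """Pad to the longest sequence; labels copy input_ids with pad positions set to -100."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch:
        padding = [pad_id] * (longest - len(ids))
        padded = ids + padding  # right padding
        input_ids.append(padded)
        attention_mask.append([1] * len(ids) + [0] * len(padding))
        # Mask every position whose token equals the pad ID; when pad_token is
        # eos_token this also hides genuine EOS tokens from the loss.
        labels.append([-100 if t == pad_id else t for t in padded])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


out = collate_causal_lm([[10, 11, 12], [20]], pad_id=50256)
print(out["labels"])  # [[10, 11, 12], [20, -100, -100]]
```

Since `group_texts` already produces equal-length blocks, padding mostly matters for a final block shorter than `block_size`.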

# Training

In [None]:
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# No manual model.to(device) call is needed: Trainer moves the model to the available device (CUDA, MPS, or CPU)

In [None]:
training_args = TrainingArguments(
    output_dir="causal_language_model",
    eval_strategy="epoch",  # older transformers releases call this evaluation_strategy
    learning_rate=2e-5,
    weight_decay=0.01
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)

trainer.train()

# Evaluation

In [None]:
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
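Perplexity is the exponential of the mean per-token cross-entropy (negative log-likelihood of the correct next token), which `eval_loss` approximates as an average over evaluation batches. A minimal pure-Python sketch, assuming made-up probabilities that a model assigned to each correct next token:

```python
import math


def perplexity(target_probs):
    """exp of the mean negative log-likelihood of the correct next tokens."""
    nll = [-math.log(p) for p in target_probs]
    return math.exp(sum(nll) / len(nll))


# Hypothetical probabilities assigned to the true next token at each position.
print(f"{perplexity([0.25, 0.25, 0.25, 0.25]):.2f}")  # 4.00
```

A perplexity of 4 means the model is, on average, as uncertain as a uniform choice among 4 tokens; lower is better, and a perfect model scores 1.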