## Causal Language Modeling (CLM) - Preprocess, Training and Inference

In this lab, we will explore **Causal Language Modeling (CLM)**, a core task in training autoregressive language models like GPT-2. CLM is the task of predicting the next token in a sequence given the previous tokens. This type of modeling forms the backbone of text generation, where the model learns to generate coherent text by conditioning only on the preceding tokens in the sequence.
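To make the objective concrete, here is a minimal sketch that turns one token sequence into the (context, next token) pairs a causal model is trained on. The integers are arbitrary placeholders for token ids, not real GPT-2 ids, and `next_token_pairs` is a hypothetical helper written only for illustration.

```python
# Minimal sketch of the CLM objective: every prefix of the sequence is a
# context, and the token that follows it is the prediction target.
# The ids below are arbitrary placeholders, not real GPT-2 token ids.

def next_token_pairs(token_ids):
    """Return (context, target) pairs for next-token prediction."""
    pairs = []
    for t in range(1, len(token_ids)):
        context = token_ids[:t]   # tokens the model is allowed to see
        target = token_ids[t]     # token it must predict
        pairs.append((context, target))
    return pairs

toy_ids = [464, 5827, 373, 257]
for context, target in next_token_pairs(toy_ids):
    print(context, "->", target)
```

A sequence of n tokens yields n - 1 training targets. In practice the model scores all of them in a single forward pass, because the causal attention mask prevents each position from seeing the tokens after it.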

The lab is divided into three major sections:
1. **Preprocessing**: Loading and tokenizing a dataset, and preparing the labels for the CLM task.
2. **Training**: Fine-tuning a pre-trained language model like GPT-2 on a specific dataset using the CLM task.
3. **Inference**: Evaluating the model’s performance by generating text based on input prompts.

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

### 1. Preprocessing

1. **Load the Domain-Specific Dataset**:
   - The first step in fine-tuning GPT-2 is to load a dataset that is specific to the domain of interest. In this case, we are using a publicly available **medical dataset** from the **PubMed** collection. PubMed contains a vast number of medical articles, and fine-tuning on such a dataset can help GPT-2 generate more accurate and context-specific medical text.
   - The dataset we're using is `"japhba/pubmed_simple"`, which is a simplified version of PubMed data. This dataset can be easily accessed using the `datasets` library from Hugging Face.


In [2]:
ds = load_dataset("japhba/pubmed_simple", split="train")
shuffled = ds.shuffle(seed=42)  # shuffle once, then take two disjoint slices
train_dataset = shuffled.select(range(1000))
eval_dataset = shuffled.select(range(1000, 1500))

We can check the contents of any one of the examples in the dataset to understand the structure of the data.

In [None]:
train_dataset[0]

We see that each entry is a dictionary with two keys:
- "abstract": The abstract of the medical article
- "country": The country where the article was published (we will not use this information in this lab)

2. **Tokenize the Dataset**:
   - The next step is to **tokenize** the data so that it can be processed by GPT-2.
   - The GPT-2 tokenizer does not define a padding token, because GPT-2 was pretrained without padded inputs. For this reason, we reuse the EOS token as the padding token, a common workaround with the transformers library. Since we also pass an attention mask, real tokens never attend to the padding positions, so the value used for padding does not affect them. (The attention mask alone does not, however, exclude padded positions from the loss.)
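The sketch below mimics what `truncation=True, padding="max_length"` produces for a small batch, assuming right-side padding (the usual default for the GPT-2 tokenizer). `50256` is GPT-2's EOS id; the other ids and the `pad_batch` helper are hypothetical, for illustration only.

```python
# Minimal sketch of truncation + right padding with the EOS id as pad value.
# 50256 is GPT-2's EOS id; the other token ids are arbitrary placeholders.
EOS_ID = 50256

def pad_batch(sequences, pad_id, max_length):
    """Truncate/pad every sequence to max_length and build attention masks."""
    input_ids, attention_mask = [], []
    for seq in sequences:
        seq = seq[:max_length]                 # truncation
        n_pad = max_length - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 0 marks padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = pad_batch([[10, 11, 12], [20]], pad_id=EOS_ID, max_length=4)
print(batch["input_ids"])       # [[10, 11, 12, 50256], [20, 50256, 50256, 50256]]
print(batch["attention_mask"])  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```

The padding value itself is never attended to. The attention mask is what tells the model which positions are real, which is why reusing EOS as the pad token is harmless for the forward pass.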

In [None]:
# TODO: Tokenize the Dataset using the GPT-2 tokenizer
# Hint: Use the `map` method of the dataset object
model_name = "gpt2"
tokenizer = ...
tokenizer.pad_token = ...  # Set padding token to EOS token

def tokenize_function(samples):
    # Tokenize the text column `abstract` in the dataset (set `truncation=True` and `padding="max_length"`)
    # NOTE: You can either use the tokenizer as a global variable, or
    # use a closure to capture the tokenizer variable
    return ...

# NOTE: map the dataset using the tokenize function.
# (You can optionally pass `batched=True` for better performance,
# as long as tokenize_function can handle a batch of samples; the tokenizer can!)
tokenized_dataset = ...
tokenized_dataset_eval = ...


3. **Add labels to the Dataset for Next Token Prediction**:
   - In this step, we will add the **labels** that will be used during the next token prediction task. 
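   - In autoregressive language modeling, the **labels** are the input sequence itself: the target at each position is the token that follows it, so the predictions are compared against the inputs shifted one position to the left. This is because the model is trained to predict the next token in the sequence given the previous tokens.
   - We do not need to shift the tokens ourselves: when a `labels` argument is passed to the forward pass of a causal LM such as GPT-2, the model shifts the labels internally and returns the loss. We only add an extra column named `labels` to the dataset, and the `Trainer` passes it to the model. Label positions set to -100 are ignored by the loss; all other positions, including padding copied from `input_ids`, are counted.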
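The alignment performed inside the model can be sketched in plain Python: the output at position t is scored against `labels[t + 1]`, and labels equal to -100 are skipped. The `mask_padding` helper is a hypothetical optional variant (not what the TODO above asks for) that shows how padded positions could be excluded from the loss. The token ids are placeholders.

```python
# Minimal sketch of how a causal LM aligns its predictions with `labels`.
# The output at position t is scored against labels[t + 1]; label positions
# equal to -100 are skipped (PyTorch's default ignore_index).
IGNORE_INDEX = -100

def shifted_targets(labels):
    """Pair each prediction position t with the label it is scored against."""
    return [(t, labels[t + 1]) for t in range(len(labels) - 1)
            if labels[t + 1] != IGNORE_INDEX]

def mask_padding(input_ids, attention_mask):
    """Hypothetical helper: copy input_ids into labels, padding -> -100."""
    return [tok if m == 1 else IGNORE_INDEX
            for tok, m in zip(input_ids, attention_mask)]

input_ids = [10, 11, 12, 50256, 50256]   # placeholder ids, EOS used as padding
attention_mask = [1, 1, 1, 0, 0]

# labels = input_ids: padding positions are scored too
print(shifted_targets(input_ids))  # [(0, 11), (1, 12), (2, 50256), (3, 50256)]
# padded labels masked with -100: only real next tokens are scored
print(shifted_targets(mask_padding(input_ids, attention_mask)))  # [(0, 11), (1, 12)]
```

Copying `input_ids` verbatim still trains a working model. It also teaches the model to predict the run of EOS padding, and that affects the reported loss values.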

In [5]:
# TODO: Add labels to the dataset
def add_labels(samples):
    # NOTE: The labels, in causal modeling, should be the same as the input_ids
    samples["labels"] = ...
    return samples

tokenized_dataset = ...
tokenized_dataset_eval = ...

### 2. Training

1. **Fine-Tune the GPT-2 Model**:
   - Set up the model and fine-tune it on the medical dataset.
   - The pipeline is the same one we saw in the previous lab (`lab03 - 01-bert`).

In [None]:
# Set Training Parameters
training_args = TrainingArguments(
    output_dir="./gpt2-medical-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=6,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=100,
    eval_steps=10,
    eval_strategy="steps",
)

# TODO: Initialize GPT-2 Model
model = ...

# TODO: Fine-tune the model with the `Trainer` class, passing the eval_dataset as well
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset_eval,
)

trainer.train()

# Save the Fine-Tuned Model
model.save_pretrained("./gpt2-medical-finetuned")
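The following back-of-the-envelope calculation estimates how often the `TrainingArguments` above trigger training steps, evaluations and loss logging. It assumes a single device, no gradient accumulation, and the default `per_device_eval_batch_size` of 8, since none of these are set in the cell.

```python
import math

# Step counts for the TrainingArguments above, assuming one device, no
# gradient accumulation and the default per_device_eval_batch_size (8).
num_train, num_eval = 1000, 500
train_batch = 6              # per_device_train_batch_size
eval_batch = 8               # assumed default per_device_eval_batch_size
eval_steps, logging_steps = 10, 100

steps_per_epoch = math.ceil(num_train / train_batch)   # optimizer steps
num_evaluations = steps_per_epoch // eval_steps        # evals at steps 10, 20, ...
eval_batches_each = math.ceil(num_eval / eval_batch)   # forward passes per eval
num_train_logs = steps_per_epoch // logging_steps      # training-loss logs

print(steps_per_epoch, num_evaluations, eval_batches_each, num_train_logs)
```

Under these assumptions, one epoch is 167 steps, with an evaluation every 10 steps (16 in total, 63 batches each) but only one training-loss log. Evaluation therefore accounts for a noticeable share of the run time. Raising `eval_steps` reduces that overhead at the cost of a coarser validation curve.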


### 3. Inference

1. **Compare Text Generation Before and After Fine-Tuning**:
   - Generate text using both the original pre-trained GPT-2 model and the fine-tuned model.
   - Provide the same input prompt and observe the differences in the outputs.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# TODO: Load the pre-trained and fine-tuned GPT-2 models
pretrained_model = ...
finetuned_model = ...

# TODO: Tokenize the prompt
tokenizer = ...
tokenizer.pad_token = ...

prompt = "The patient presents with chest pain and shortness of breath."

inputs = ...
input_ids = inputs['input_ids']

In [None]:
# TODO: Generate the output of the model before fine-tuning (use the `generate` method of the model, then the `decode` method of the tokenizer)
output_pretrained = ...
generated_pretrained = ...

print("**Before Fine-Tuning (Pre-Trained GPT-2)**:")
print(generated_pretrained)

In [None]:
# TODO: Generate Output of the Model After Fine-Tuning
output_finetuned = ...
generated_finetuned = ...

print("**After Fine-Tuning (Fine-Tuned on Medical Dataset)**:")
print(generated_finetuned)

<span style="color:red">Extra stuff!</span>

The abstracts in the dataset have very different lengths, yet with `padding="max_length"` every one of them is padded to the same fixed maximum length. This is inefficient, because much of the computation in each batch is spent on padding tokens.

To avoid this, we can use a technique called **Dynamic Padding**: instead of padding every sequence to a fixed maximum, each batch is padded only to the length of the longest sequence *in that batch*. The savings grow further when sequences of similar length are grouped into the same batch, since the longest sequence in each batch is then close in length to the others, which can significantly reduce the amount of padding required.

As a first exercise, quantify the number of pad tokens being used in various situations:
1. You pad all batches to the maximum allowed sequence length (1024 tokens for GPT-2; this is what we have used so far)
2. You pad the entire batch to the length of the longest sequence in the batch (generate the batches by randomly sampling sentences)
3. You pad the entire batch to the length of the longest sequence in the batch (generate the batches by placing sentences of similar lengths together)
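One possible way to count pad tokens under the three strategies is sketched below on **synthetic** sequence lengths. On the real data you would replace `lengths` with the lengths of the tokenized abstracts. The helpers `pad_count` and `make_batches` are hypothetical names chosen for this example.

```python
import random

# Count pad tokens under the three strategies, using synthetic lengths.
MAX_LEN = 1024  # GPT-2's maximum sequence length

def pad_count(batches, target_len=None):
    """Pad tokens added when each batch is padded to target_len,
    or to its own longest sequence when target_len is None."""
    total = 0
    for batch in batches:
        width = target_len if target_len is not None else max(batch)
        total += sum(width - n for n in batch)
    return total

def make_batches(lengths, batch_size):
    """Split a list of lengths into consecutive batches."""
    return [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]

rng = random.Random(0)
lengths = [rng.randint(50, 600) for _ in range(1000)]  # synthetic, not real data
batch_size = 6

random_order = lengths[:]
rng.shuffle(random_order)
fixed = pad_count(make_batches(random_order, batch_size), MAX_LEN)      # strategy 1
per_batch_random = pad_count(make_batches(random_order, batch_size))   # strategy 2
per_batch_sorted = pad_count(make_batches(sorted(lengths), batch_size))  # strategy 3

print(fixed, per_batch_random, per_batch_sorted)
```

Sorting is the simplest way to place similar lengths together. A fully sorted order removes randomness from the batches, so in practice length grouping is usually applied only within larger randomly chosen chunks of the data.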

Next, introduce dynamic padding and compare the training time of the previous run with that of the run using dynamic padding.

You can use the following resources to help you with this exercise:
- `group_by_length` parameter ([TrainingArguments](https://huggingface.co/docs/transformers/v4.46.0/en/main_classes/trainer#transformers.TrainingArguments.group_by_length)) (parameter to group together samples with similar lengths)
- `DataCollatorForSeq2Seq` ([DataCollatorForSeq2Seq](https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForSeq2Seq)) (collator function that aggregates samples into batches and pads them to the maximum length of the batch)

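Note: the validation losses you observe may differ from the previous ones. The cross-entropy loss is averaged over the tokens whose label is not -100, so its value depends on which positions are counted. When the labels are copied from padded `input_ids`, the padding positions are included in the average. `DataCollatorForSeq2Seq`, in contrast, pads the labels with -100 (its default `label_pad_token_id`), so padding is excluded.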
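A toy calculation shows how much the reported loss can move with the padding strategy alone. The per-token losses below are **made-up numbers**, already aligned with their target labels. They assume the real tokens are hard to predict, while the repeated EOS padding is easy to predict.

```python
# Mean token loss with and without padding positions counted.
# The per-token losses are made-up numbers for illustration only.
IGNORE_INDEX = -100

def mean_token_loss(token_losses, labels):
    """Average the loss over positions whose label is not -100."""
    kept = [loss for loss, lab in zip(token_losses, labels) if lab != IGNORE_INDEX]
    return sum(kept) / len(kept)

token_losses = [3.0, 2.5, 3.5, 3.0, 0.25, 0.25, 0.25, 0.25]
labels_unmasked = [11, 12, 13, 14, 50256, 50256, 50256, 50256]  # copied from input_ids
labels_masked = [11, 12, 13, 14, -100, -100, -100, -100]        # padding set to -100

print(mean_token_loss(token_losses, labels_unmasked))  # 1.625
print(mean_token_loss(token_losses, labels_masked))    # 3.0
```

The model's predictions are identical in both cases. Only the set of averaged positions changes, so the losses obtained with different padding strategies are not directly comparable.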