# Fine-Tuning GPT-2 for text summarisation using Hugging Face *Transformers*

source: https://github.com/jwhogg

### Check that a GPU is available to train on:

In [None]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Get Dataset:
- We will be using the CNN/DailyMail dataset, which pairs news articles with their corresponding summaries.
- `raw_datasets` is a `DatasetDict` with 'train', 'validation', and 'test' splits.
- For CNN/DailyMail, the version must be specified, as it comes in versions 1.0.0, 2.0.0, and 3.0.0.
```
DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})
```
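
For summarisation, each `article`/`highlights` pair has to be turned into a single text sequence before GPT-2 can be trained on it. The following is a minimal sketch of one way to do that, assuming a `TL;DR:` separator (the prompt the GPT-2 paper used to induce summaries). The helper name `build_training_text` and the sample record are hypothetical and not part of this notebook.

```python
# Hypothetical helper: joins an article and its summary into one training string.
# The " TL;DR: " separator is an assumption, not something the dataset provides.
SEPARATOR = " TL;DR: "

def build_training_text(record, eos_token="<|endoftext|>"):
    """Return 'article TL;DR: highlights<eos>' for one CNN/DailyMail-style record."""
    article = record["article"].strip()
    summary = record["highlights"].strip()
    return f"{article}{SEPARATOR}{summary}{eos_token}"

# Made-up record with the same keys as the dataset's features
sample = {
    "article": "The city council approved a new cycle lane on Monday. ",
    "highlights": "Council approves cycle lane.",
    "id": "example-0",
}
print(build_training_text(sample))
```

Ending each example with GPT-2's end-of-text token marks where the summary stops, which helps when generating summaries later.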

In [None]:
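# !pip install datasets
from datasets import load_dataset

# The summarisation dataset described above:
# raw_datasets = load_dataset("cnn_dailymail", "3.0.0")

# NOTE: this cell currently loads IMDb instead, which has 'text' and 'label'
# columns rather than 'article' and 'highlights'.
raw_datasets = load_dataset("imdb")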

### Get the model tokeniser:

In [None]:
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
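# GPT-2 has no pad token by default; add one so inputs can be padded to a fixed length.
# Adding a token grows the vocabulary, so the model's embeddings must later be resized to len(tokenizer).
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})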

### Create encodings for train/test using tokeniser:
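- `map()` runs the tokenizer function over each split in the `DatasetDict`. The function below tokenizes the `text` column, which is what the IMDb data loaded above provides; for CNN/DailyMail the columns to tokenize are `article` and `highlights`.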

In [None]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
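
To make concrete what `padding="max_length"` and `truncation=True` do to each example, here is a dependency-free sketch. The function `pad_or_truncate` and the token ids are made up for illustration; the real tokenizer does this internally and also returns an `attention_mask`.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Mimic padding='max_length' + truncation=True for one sequence (illustrative only)."""
    kept = token_ids[:max_length]                   # truncation: drop tokens past max_length
    n_pad = max_length - len(kept)                  # padding: number of pad tokens to append
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Toy ids, not real GPT-2 token ids
short = pad_or_truncate([11, 22, 33], max_length=5, pad_id=0)
long = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_length=5, pad_id=0)
print(short)
print(long)
```

With the GPT-2 tokenizer the maximum length is 1024 tokens, so every example ends up as 1024 ids regardless of how long the original text was.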

##### Create small datasets for development

In [None]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

### Import Model

In [None]:
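from transformers import GPT2ForSequenceClassification
# GPT2Model has no task head and returns no loss, so Trainer cannot train it. The labelled data and accuracy
# metric used in this notebook call for a classification head (GPT2LMHeadModel would be the choice for summarisation).
model = GPT2ForSequenceClassification.from_pretrained('gpt2-medium', num_labels=2)
model.resize_token_embeddings(len(tokenizer))  # make room for the added [PAD] token
model.config.pad_token_id = tokenizer.pad_token_id  # lets the model find the last real token
model = model.to(device)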

### Training
Transformers provides a `Trainer` class that takes care of the training loop and does a lot of the work for us.
`Trainer` is configured with a `TrainingArguments` object and, optionally, a `compute_metrics` function, so we need to define these first.

Training args:
- start with just the default arguments
- add `evaluation_strategy="epoch"` to report metrics at the end of every epoch (recent Transformers releases name this argument `eval_strategy`)

In [None]:
from transformers import TrainingArguments
# If this raises an 'accelerate'-related error, re-install transformers with its torch dependencies, then restart the notebook:
# !pip install transformers[torch]
training_args = TrainingArguments("test_trainer")  # "test_trainer" is the output directory

### Configure training metrics
`Trainer` accepts a `compute_metrics()` function, which receives the predictions and labels (as a tuple) and returns a dict mapping metric names to values.
Common metrics such as 'accuracy' are available through the Hugging Face `evaluate` library (older releases of `datasets` exposed them via `load_metric`, which has since been removed).

In [None]:
# !pip install evaluate
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
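
As a sanity check on what `compute_metrics` calculates, the same argmax-then-accuracy logic can be written without any libraries. The function names, logits, and labels below are invented for illustration and are not model output.

```python
def argmax(row):
    """Index of the largest value in a list of class scores."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

# Four toy examples with two classes each
toy_logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 1.2], [-0.5, 0.4]]
toy_labels = [0, 1, 1, 1]
result = accuracy_from_logits(toy_logits, toy_labels)
print(result)  # three of the four predictions match, so accuracy is 0.75
```

Accuracy is a classification metric; generated summaries are more commonly scored with an overlap metric such as ROUGE.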

### Define Trainer

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

### Train and Evaluate

In [None]:
trainer.train()
trainer.evaluate()

We are now done! The [training args](https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/trainer#transformers.TrainingArguments) or the dataset can be tweaked to try to improve performance.

**Remember to save your model!**
```python
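model.save_pretrained("path/to/model_dir")  # save_pretrained() writes a directory, not a single .pt file
tokenizer.save_pretrained("path/to/model_dir")  # keep the tokenizer (with its pad token) alongside the model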
```
- For Google Colab, you will need to download the model to your local machine, as Colab's files are wiped when the runtime ends.
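
On Colab, one way to get the saved model off the runtime is to bundle the output directory into a single archive first. The following is a minimal sketch using only the standard library; the function name `archive_model_dir` and the directory names are placeholders, and the demo zips a throwaway folder rather than a real model.

```python
import shutil
import tempfile
from pathlib import Path

def archive_model_dir(model_dir, archive_base):
    """Zip everything under model_dir and return the path of the created .zip file."""
    return shutil.make_archive(str(archive_base), "zip", root_dir=str(model_dir))

# Demo on a temporary directory standing in for the save_pretrained() output
workdir = Path(tempfile.mkdtemp())
fake_model_dir = workdir / "my_model"
fake_model_dir.mkdir()
(fake_model_dir / "config.json").write_text("{}")

zip_path = archive_model_dir(fake_model_dir, workdir / "my_model_archive")
print(zip_path)

# In a Colab notebook, the archive could then be downloaded with:
# from google.colab import files
# files.download(zip_path)
```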