## Finetuning Pre-trained Language Models for Biomedical Lay Summarization

First, let's check the GPU specifications.

In [None]:
# kill all processes to crash and restart the Colab runtime (hoping for a better GPU / more RAM)
!kill -9 -1

To check which GPU (and how much GPU memory) we have been allocated, we can run the following command.
If the randomly allocated GPU is too small, the cell above can be run to crash the notebook in the hope of getting a better GPU.

In [None]:
!nvidia-smi

Sun Apr 30 19:20:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    50W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               

In [None]:
%%capture
! pip install datasets transformers rouge-score nltk

Let's start by loading and preprocessing the dataset.



In [None]:
import transformers
print(transformers.__version__)

4.28.1


In [None]:
from datasets import load_dataset, load_metric  # load_metric is deprecated; newer code uses the separate `evaluate` library

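Next, we download the eLife train and validation splits. The `tomasg25/scientific_lay_summarisation` dataset also provides a PLOS subset, which can be loaded by changing the configuration name.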

In [None]:
train_dataset = load_dataset("tomasg25/scientific_lay_summarisation","elife",split="train")
val_dataset = load_dataset("tomasg25/scientific_lay_summarisation","elife", split="validation")

Downloading and preparing dataset scientific_lay_summarisation/elife (download: 425.22 MiB, generated: 275.99 MiB, post-processed: Unknown size, total: 701.22 MiB) to /root/.cache/huggingface/datasets/tomasg25___scientific_lay_summarisation/elife/1.0.0/bf538a761aabe5d3389d3e4aa028b094528c5d7a6225147b843a63a3803e79f4...


Dataset scientific_lay_summarisation downloaded and prepared to /root/.cache/huggingface/datasets/tomasg25___scientific_lay_summarisation/elife/1.0.0/bf538a761aabe5d3389d3e4aa028b094528c5d7a6225147b843a63a3803e79f4. Subsequent calls will reuse this data.


It's always a good idea to take a look at some data samples. Let's do that here.

In [None]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=4):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # sample distinct row indices without replacement
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    # map integer class labels (if any) back to their names
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

show_random_elements(train_dataset, num_examples=2)

We can see that the input is the `article`, a scientific report, and the target is the `summary`, a lay summary of that report.

In [None]:
from transformers import AutoTokenizer

In [None]:
!pip install sentencepiece  # only relevant for SentencePiece tokenizers (e.g. T5, mT5, Pegasus); BART does not use it

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.98-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 61.6 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.98


In [None]:
# alternatives: flax-community/t5-base-cnn-dm, google/mt5-small, allenai/led-base-16384, sshleifer/distill-pegasus-xsum-16-4
tokenizer = AutoTokenizer.from_pretrained("ainize/bart-base-cnn")


In [None]:
max_input_length = 1024  # BART's position embeddings cover 1024 tokens; LED checkpoints accept much longer inputs (e.g. 8192)
max_output_length = 512
batch_size = 2

Now, let's write the input processing function that maps each data sample to the format the model expects.
As explained earlier, `article` is our input data and `summary` is the target data.

In [None]:
def process_data_to_model_inputs(batch):
    # tokenize the inputs and labels
    inputs = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    outputs = tokenizer(
        batch["summary"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # global attention mask of zeros; only LED models use it, for BART the Trainer
    # drops this column because the model's forward() has no such argument
    batch["global_attention_mask"] = len(batch["input_ids"]) * [
        [0 for _ in range(len(batch["input_ids"][0]))]
    ]

    # `n * [row]` repeats a reference to ONE list, so this single assignment
    # puts global attention on the first token of every sample in the batch
    batch["global_attention_mask"][0][0] = 1
    batch["labels"] = outputs.input_ids

    # replace PAD tokens in the labels with -100 so the loss ignores them
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch
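The `global_attention_mask` lines in `process_data_to_model_inputs` rely on Python list aliasing, and the labels rely on `-100` being ignored by the loss. Both behaviors can be checked in plain Python. This is a minimal sketch with toy token IDs; `PAD_ID`, `make_global_attention_mask` and `mask_label_padding` are hypothetical names used only for illustration.

```python
# Toy token IDs; PAD_ID is a made-up value for this illustration.
PAD_ID = 1


def make_global_attention_mask(num_samples, seq_len):
    """Build the mask the same way as process_data_to_model_inputs."""
    mask = num_samples * [[0 for _ in range(seq_len)]]  # num_samples references to ONE list
    mask[0][0] = 1  # mutates that single shared row
    return mask


def mask_label_padding(label_batch, pad_id=PAD_ID):
    """Replace pad IDs with -100, the index PyTorch's cross-entropy loss ignores by default."""
    return [[-100 if tok == pad_id else tok for tok in labels] for labels in label_batch]


mask = make_global_attention_mask(num_samples=3, seq_len=4)
print(mask)                # [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
print(mask[0] is mask[2])  # True: every row is the same object

labels = [[0, 713, 2, 1, 1], [0, 42, 17, 2, 1]]
print(mask_label_padding(labels))  # [[0, 713, 2, -100, -100], [0, 42, 17, 2, -100]]
```

Because every row is the same object, any later per-row edit would hit every sample at once. That is harmless here, since `datasets.map` stores each row separately once the batch is written out.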

Great, having defined the mapping function, let's preprocess the training data.

In [None]:
train_dataset = train_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "summary"],
)


Then we do the same for the validation data.

In [None]:
val_dataset = val_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "summary"],
)


Finally, the datasets should be converted into the PyTorch format as follows.

In [None]:
train_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)
val_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)

Let's load the model via the `AutoModelForSeq2SeqLM` class.

In [None]:
from transformers import AutoModelForSeq2SeqLM

In [None]:
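# `led` is only a variable name: the checkpoint loaded here is BART
# alternatives: sshleifer/distill-pegasus-xsum-16-4, allenai/led-base-16384
led = AutoModelForSeq2SeqLM.from_pretrained("ainize/bart-base-cnn", use_cache=False)
# trade extra compute for lower memory; the KV cache is disabled above because it is incompatible with checkpointing
led.gradient_checkpointing_enable()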

During training, we want to evaluate the model with ROUGE, the most common metric in summarization, to make sure the model is indeed improving. Since ROUGE is computed on generated summaries, we first set the generation hyperparameters.

In [None]:
# generation hyperparameters used by generate() when evaluating with predict_with_generate=True
led.config.num_beams = 2
led.config.max_length = 512
led.config.min_length = 100
led.config.length_penalty = 2.0
led.config.early_stopping = True
led.config.no_repeat_ngram_size = 3
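`no_repeat_ngram_size` and `length_penalty` are the least obvious of these settings. The sketch below approximates what they do in pure Python; `banned_next_tokens` and `beam_score` are hypothetical helpers, not Transformers functions, and the real beam search additionally enforces `min_length` by blocking the end-of-sequence token until the minimum length is reached.

```python
from collections import defaultdict


def banned_next_tokens(generated, n):
    """Tokens that would complete an n-gram already present in `generated` (no_repeat_ngram_size=n)."""
    if len(generated) + 1 < n:
        return set()  # too short to complete any n-gram yet
    seen = defaultdict(set)  # (n-1)-token prefix -> tokens that followed it
    for i in range(len(generated) - n + 1):
        prefix = tuple(generated[i : i + n - 1])
        seen[prefix].add(generated[i + n - 1])
    current_prefix = tuple(generated[len(generated) - n + 1 :])
    return seen.get(current_prefix, set())


def beam_score(sum_logprobs, length, length_penalty):
    """Length-normalised score of a finished beam: sum of log-probs / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)


# The last two tokens (5, 7) already occurred followed by 9, so 9 is banned next.
print(banned_next_tokens([5, 7, 9, 5, 7], n=3))  # {9}

# Log-probs are negative, so a larger length_penalty favours the longer hypothesis.
short, long_ = (-10.0, 50), (-16.0, 100)
print(beam_score(*short, 2.0) < beam_score(*long_, 2.0))  # True
print(beam_score(*short, 0.0) < beam_score(*long_, 0.0))  # False
```

With `length_penalty = 2.0`, the longer hypothesis wins even though its raw log-probability is lower, which pushes generation toward longer summaries.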

Next, we define the function that will compute the `"rouge"` score during evaluation.

In [None]:
rouge = load_metric("rouge")


In [None]:
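import numpy as np

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # -100 is not a valid token ID: map it back to the pad token before decoding
    # (labels contain -100 from preprocessing; predictions may if padded when gathered)
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    labels_ids = np.where(labels_ids != -100, labels_ids, tokenizer.pad_token_id)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # `.mid` is the median estimate from bootstrap aggregation over examples
    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid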

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
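ROUGE-2 counts overlapping word bigrams between a generated summary and its reference. A minimal per-example sketch follows; the regex tokenizer only approximates the one in `rouge_score` (which lowercases and keeps alphanumeric runs, with stemming off by default), and the corpus-level `.mid` value used above additionally comes from bootstrap aggregation, which this sketch omits. `bigrams` and `rouge2` are hypothetical helper names.

```python
import re
from collections import Counter


def bigrams(text):
    """Lowercase, keep alphanumeric runs (approximating rouge_score's tokenizer), count bigrams."""
    tokens = re.findall(r"[a-z0-9]+", text.lower())
    return Counter(zip(tokens, tokens[1:]))


def rouge2(prediction, reference):
    """Per-example ROUGE-2 precision, recall and F-measure from clipped bigram overlap."""
    pred, ref = bigrams(prediction), bigrams(reference)
    overlap = sum((pred & ref).values())  # min count per shared bigram
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    fmeasure = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, fmeasure


scores = rouge2("The cells divide quickly.", "the cells divide slowly over time")
print([round(s, 4) for s in scores])  # [0.6667, 0.4, 0.5]
```

Two of the three predicted bigrams appear in the five reference bigrams, giving precision 2/3 and recall 2/5.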

Now, we're ready to start training. Let's import the `Seq2SeqTrainer` and `Seq2SeqTrainingArguments`.

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
# fp16=True enables mixed-precision training (native PyTorch AMP by default, not apex)
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    output_dir="./",
    logging_steps=5,
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    gradient_accumulation_steps=4,  # effective batch size: 2 x 4 = 8 examples per optimizer update (single GPU)
    num_train_epochs=10,
)
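With these arguments, each optimizer update sees `per_device_train_batch_size x gradient_accumulation_steps` examples. The back-of-the-envelope sketch below estimates the resulting schedule; `training_schedule` is a hypothetical helper that assumes a single GPU, no dropped last batch, and that the Trainer counts update steps per epoch as batches floor-divided by the accumulation steps.

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs, eval_steps, num_devices=1):
    """Estimate optimizer updates and evaluation count for a Trainer-style loop."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_updates = math.ceil(epochs * updates_per_epoch)
    return {
        "effective_batch": per_device_batch * grad_accum * num_devices,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "evaluations": total_updates // eval_steps,
    }


# 4346 eLife training examples, batch_size=2, gradient_accumulation_steps=4, 10 epochs, eval every 10 steps
print(training_schedule(4346, 2, 4, 10, 10))
# {'effective_batch': 8, 'updates_per_epoch': 543, 'total_updates': 5430, 'evaluations': 543}
```

Under these assumptions, evaluating every 10 steps means about 543 generation passes over the 241 validation examples, and the 100 steps logged further below cover under 2% of the planned run. A larger `eval_steps` would cut that evaluation overhead considerably.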

The training arguments, along with the model, tokenizer, datasets, and the `compute_metrics` function, can then be passed to the `Seq2SeqTrainer`.

In [None]:
trainer = Seq2SeqTrainer(
    model=led,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


| Step | Training Loss | Validation Loss | ROUGE-2 Precision | ROUGE-2 Recall | ROUGE-2 F-measure |
|---:|---:|---:|---:|---:|---:|
| 10 | 3.4728 | 3.178955 | 0.1034 | 0.0747 | 0.0823 |
| 20 | 3.4026 | 3.0512 | 0.1186 | 0.0752 | 0.0883 |
| 30 | 3.246 | 2.990401 | 0.1298 | 0.0732 | 0.0909 |
| 40 | 3.2144 | 2.949623 | 0.1371 | 0.0822 | 0.0998 |
| 50 | 3.1847 | 2.928597 | 0.1256 | 0.0806 | 0.0958 |
| 60 | 3.0682 | 2.897598 | 0.132 | 0.0948 | 0.108 |
| 70 | 3.1057 | 2.877398 | 0.1294 | 0.0882 | 0.1027 |
| 80 | 2.9894 | 2.860825 | 0.1362 | 0.0794 | 0.0982 |
| 90 | 3.0591 | 2.83892 | 0.1287 | 0.0918 | 0.105 |
| 100 | 3.0211 | 2.830138 | 0.1291 | 0.0961 | 0.108 |




KeyboardInterrupt: ignored
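Training was interrupted at step 100. To decide which checkpoint to keep from a log like this, one option is to parse the metrics and compare steps. The values below are copied from the training log above, with shortened column names. Inside the Trainer itself, `load_best_model_at_end=True` together with `metric_for_best_model="rouge2_fmeasure"` should keep the best checkpoint even though `save_total_limit=2` rotates older ones out; that configuration is a suggestion, not part of the run above.

```python
import csv
import io

# Metrics copied from the training log above; column names shortened.
LOG = """step,train_loss,val_loss,r2_precision,r2_recall,r2_f
10,3.4728,3.178955,0.1034,0.0747,0.0823
20,3.4026,3.0512,0.1186,0.0752,0.0883
30,3.246,2.990401,0.1298,0.0732,0.0909
40,3.2144,2.949623,0.1371,0.0822,0.0998
50,3.1847,2.928597,0.1256,0.0806,0.0958
60,3.0682,2.897598,0.132,0.0948,0.108
70,3.1057,2.877398,0.1294,0.0882,0.1027
80,2.9894,2.860825,0.1362,0.0794,0.0982
90,3.0591,2.83892,0.1287,0.0918,0.105
100,3.0211,2.830138,0.1291,0.0961,0.108
"""

rows = [
    {key: float(value) for key, value in row.items()}
    for row in csv.DictReader(io.StringIO(LOG))
]

# Is validation loss still dropping at every evaluation?
val_losses = [r["val_loss"] for r in rows]
still_improving = all(a > b for a, b in zip(val_losses, val_losses[1:]))

# Which steps reach the best ROUGE-2 F-measure?
best_f = max(r["r2_f"] for r in rows)
best_f_steps = [int(r["step"]) for r in rows if r["r2_f"] == best_f]

print(still_improving)       # True
print(best_f, best_f_steps)  # 0.108 [60, 100]
```

Validation loss falls at every evaluation, while ROUGE-2 F fluctuates between roughly 0.08 and 0.11, which suggests the run was stopped well before convergence.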