# Text simplification model
**Credits**: This work has been adapted from the example code provided in the `Transformers` library released under the `Apache license`.

## Fine-tuning T5-base model

In [1]:
import os

parent_dir = os.path.dirname(os.getcwd())
data_dir = f'{parent_dir}/data'
results_dir = f'{parent_dir}/results'

model_name = "t5-base"
model_dir = f"{model_name.replace('/', '-')}-checkpoints"

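# Resume from the most recent checkpoint if one exists; otherwise start from the pretrained model.
# os.listdir() returns entries in arbitrary order, so sort the checkpoint-<step> folders numerically.
checkpoints = []
if os.path.isdir(model_dir):
    checkpoints = sorted(
        (item for item in os.listdir(model_dir) if item.startswith('checkpoint-')),
        key=lambda item: int(item.split('-')[-1]),
    )
model_checkpoint = os.path.join(model_dir, checkpoints[-1]) if checkpoints else model_name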

## Loading the dataset

We will use the [Datasets](https://github.com/huggingface/datasets) library to load and process our data, and the [Evaluate](https://github.com/huggingface/evaluate) library to load the metric used for evaluation.

In [2]:
from datasets import load_dataset

raw_datasets = load_dataset("csv", data_files=f'{data_dir}/data.tsv', delimiter='\t')

all_columns = raw_datasets.column_names['train']
required_columns = ['original', 'english simplified']
unrequired_columns = [col for col in all_columns if col not in required_columns]
original_col = required_columns[0]
target_col = required_columns[-1]

# Remove unrequired columns
raw_datasets['train'] = raw_datasets['train'].remove_columns(unrequired_columns)

# Filter out rows where 'english simplified' is None
raw_datasets['train'] = raw_datasets['train'].filter(lambda example: example['english simplified'] is not None)

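`raw_datasets` is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict). Since the data is loaded from a single TSV file, it initially contains only a `train` split; the validation and test splits are created from it below. Here is the first training example: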

In [3]:
raw_datasets["train"][0]

{'original': "Public trust in physicians has declined over the last 50 years. Future physicians will need to mend the patient-physician trust relationship. In conjunction with the American Medical Association's Accelerating Change in Medical Education initiative, the Mayo Clinic Alix School of Medicine implemented the Science of Health Care Delivery (SHCD) curriculum-a 4-year curriculum that emphasizes interdisciplinary training across population-centered care; person-centered care; team-based care; high-value care; leadership; and health policy, economics, and technology-in 2015. In this medical student perspective, the authors highlight how the SHCD curriculum has the potential to address issues that have eroded patient-physician trust. The curriculum reaches this aim through didactic and/or experiential teachings in health equity, cultural humility and competence, shared decision making, patient advocacy, and safety and quality of care. It is the authors' hope that novel medical edu

## Dataset train/validation/test split

We split the dataset using the following ratios:
- Training set: 99%
- Validation set: 0.5%
- Test set: 0.5%

In [3]:
# Calculate sizes for train, validation, and test sets
total_n = raw_datasets['train'].num_rows
split_n = int(0.005 * total_n)

# Define indices for train, validation, and test splits
train_indices = list(range(total_n - 2 * split_n))
validation_indices = list(range(total_n - 2 * split_n, total_n - split_n))
test_indices = list(range(total_n - split_n, total_n))

# Perform rigid train-validation-test split
raw_datasets["validation"] = raw_datasets["train"].select(indices=validation_indices).shuffle(seed=42)
raw_datasets["test"] = raw_datasets["train"].select(indices=test_indices).shuffle(seed=42)
raw_datasets["train"] = raw_datasets["train"].select(indices=train_indices).shuffle(seed=42)

# Display raw_datasets to verify the splits
print(raw_datasets)

# used later for tokenization
max_input_length = 512
max_target_length = 512

DatasetDict({
    train: Dataset({
        features: ['original', 'english simplified'],
        num_rows: 63693
    })
    validation: Dataset({
        features: ['original', 'english simplified'],
        num_rows: 321
    })
    test: Dataset({
        features: ['original', 'english simplified'],
        num_rows: 321
    })
})
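The row counts above follow from the split arithmetic. The three splits add up to 63,693 + 321 + 321 = 64,335 rows, and `int(0.005 * 64335)` truncates 321.675 down to 321. Because the indices are taken before shuffling, the validation and test sets are the last 642 rows of the filtered data. A minimal sketch of the size computation (`split_sizes` is a helper introduced here for illustration, not part of the notebook):

```python
def split_sizes(total_n, frac=0.005):
    """Return (train, validation, test) sizes for the contiguous split used above."""
    split_n = int(frac * total_n)  # int() truncates toward zero
    return total_n - 2 * split_n, split_n, split_n

print(split_sizes(64335))  # (63693, 321, 321)
```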


In [6]:
# # keep only a subsample of the datasets
# raw_datasets["train"] = raw_datasets["train"].select(range(10))
# raw_datasets["validation"] = raw_datasets["validation"].select(range(1))
# raw_datasets["test"] = raw_datasets["test"].select(range(1))

# raw_datasets

To get a sense of what the data looks like, the following function displays a few examples picked at random from the dataset.

In [10]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
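The `while` loop above re-draws until it finds an unused index. Python's `random.sample` draws distinct indices directly, so the same picks can be made in one call. A minimal sketch, where `pick_distinct_indices` is a hypothetical helper that only needs the number of rows:

```python
import random

def pick_distinct_indices(n_rows, num_examples=5, seed=None):
    """Return num_examples distinct row indices in [0, n_rows), drawn without replacement."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # a dedicated generator makes the picks reproducible when seeded
    return rng.sample(range(n_rows), num_examples)

# 63,693 is the size of the training split printed above
picks = pick_distinct_indices(63693, num_examples=5, seed=42)
print(picks)
```

Inside `show_random_elements`, `picks = pick_distinct_indices(len(dataset), num_examples)` would replace the loop, and `dataset[picks]` works unchanged.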

In [11]:
show_random_elements(raw_datasets["train"])

|   | original | english simplified |
|---|----------|--------------------|
| 0 | Keck School of Medicine of USC\nTeaching hospitals in California\nChildren's hospitals in the United States\nHospitals in Los Angeles\nHealthcare in Los Angeles\nHospitals established in 1901\nEast Hollywood, Los Angeles\nPediatric trauma centers | The Keck School of Medicine of USC is a medical school in Los Angeles, California. They have teaching hospitals, including children's hospitals, that provide healthcare to people in Los Angeles. The school was established in 1901 and is located in East Hollywood, Los Angeles. They also have a pediatric trauma center. |
| 1 | AIDS is a computer virus written in Turbo Pascal 3.01a which overwrites COM files. AIDS is the first virus known to exploit the MS-DOS "corresponding file" vulnerability. In MS-DOS, if both and exist, then will always be executed first. Thus, by creating infected files, AIDS code will always be executed before the intended code. | This computer virus, called AIDS, is designed to mess up your computer files. It works by taking over important files on your computer, making it hard for your computer to run properly. |
| 2 | The early period of the COVID-19 pandemic necessitated a rapid increase in out-of-office care. To capture the impact from COVID-19 on care for patients with hypertension, a questionnaire was disseminated to community health center clinicians. The extent, types, and causes of care delays and disruptions were assessed along with adaptations and innovations used to address them. Clinician attitudinal changes and perspectives on future hypertension care were also assessed. Of the 65 respondents, most (90.8%) reported their patients with hypertension experienced care delays or disruptions, including lack of follow-up, lack of blood pressure assessment, and missed medication refills or orders. To address care delays and disruptions for patients with hypertension, respondents indicated that their health center increased the use of telehealth or other technology, made home blood pressure devices available to patients, expanded outreach and care coordination, provided medication refills for longer periods of time, and used new care delivery options. The use of self-measured blood pressure monitoring (58.5%) and telehealth (43.1%) was identified as the top adaptations that should be sustained to increase access to and patient engagement with hypertension care; however, barriers to both remain. Policy and system level changes are needed to support value-based care models that include self-measured blood pressure and telehealth. | The COVID-19 pandemic made it difficult for people with high blood pressure to get the care they needed. We wanted to understand how this affected patients and their doctors. So we asked doctors at community health centers how the pandemic affected their patients with high blood pressure. Most doctors said their patients had trouble getting care, like missing appointments, not having their blood pressure checked, or not getting their medication refills. To solve these problems, doctors started using new ways to provide care, like telehealth, providing home blood pressure monitors, and making it easier for patients to get their medications. Many doctors believe that using home blood pressure monitoring and telehealth should continue because they can help more people get the care they need for high blood pressure. However, there are still challenges in using these new approaches. We need to make changes to how healthcare is organized to make sure these new ways of providing care work well for everyone. |
| 3 | The poor translation of research findings into routine clinical practice is common in all areas of healthcare. Having a better understanding of how researchers and clinicians experience engagement in and with research, their working relationships and expectations of each other, may be one way to help to facilitate collaborative partnerships and therefore increase successful translation of research into clinical practice. To explore the views of clinical and research staff about their experiences of working together during research projects and identify the facilitators and barriers. We conducted four focus groups with 18 participants - clinicians, researchers and those with a dual clinical-research role, recruited from one mental health Trust and one university. Data was analysed using thematic analysis. Eight themes were identified under the headings of two research questions 1) Barriers and facilitators of either engaging in or with research from the perspective of clinical staff, with themes of understanding the benefits of the research; perceived knowledge and personal qualities of researchers; lack of time and organisational support to be involved in and implement research; and lack of feedback about progress and outcome of research. 2) Barriers and facilitators for engaging with clinicians when conducting research, from the perspective of researchers, with themes of understanding what clinicians need to know and how they need to feel to engage with research; demonstrating an understanding of the clinician's world; navigating through the clinical world; and demands of the researcher role. There was agreement between clinicians and researchers about the barriers and facilitators for engaging clinicians in research. Both groups identified that it was the researcher's responsibility to form and maintain good working relationships. Better support for researchers in their role calls for training in communication skills and bespoke training to understand the local context in which research is taking place. | It's common for medical research findings to not be used in everyday patient care. To improve this, we need to understand how researchers and doctors work together. We want to know how they feel about research, how they interact, and what they expect from each other. We did a study with 18 people – doctors, researchers, and people who work in both roles – from a mental health center and a university. We talked to them in groups and found eight main themes. Both doctors and researchers agreed that researchers need to build strong relationships with doctors. They also agreed that researchers need more training to understand how to communicate with doctors and to learn about the specific challenges of the places where they are doing research. |
| 4 | Cardiothoracic surgery is the field of medicine involved in surgical treatment of organs inside the thoracic cavity — generally treatment of conditions of the heart (heart disease), lungs (lung disease), and other pleural or mediastinal structures. | Cardiothoracic surgery is a branch of medicine that involves operating on organs in the chest, like the heart and lungs. This type of surgery is used to treat various conditions related to these organs. |


## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which (as the name indicates) tokenizes the inputs, converts the tokens to their corresponding IDs in the pretrained vocabulary, puts them in the format the model expects, and generates the other inputs the model requires (such as the attention mask).

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

We can then write the function that preprocesses our samples. It prepends the task prefix `summarization: ` to each source text and feeds the texts to the `tokenizer` with `truncation=True`, so any input longer than `max_input_length` tokens (or target longer than `max_target_length` tokens) is truncated to that length. Padding is deferred to the data collator, so each batch is padded to the length of its longest example rather than to the longest example in the whole dataset.
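To make the split between truncation and padding concrete, here is a minimal pure-Python sketch of what happens to one batch. It assumes T5's pad ID of 0 and the label padding value of -100, which is `DataCollatorForSeq2Seq`'s default `label_pad_token_id` and is ignored by the loss. The token IDs are made up, `truncate` and `pad_batch` are hypothetical stand-ins rather than library code, and special tokens such as EOS (which the real tokenizer keeps when truncating) are omitted. The real collator, given `model=model`, also builds `decoder_input_ids` from the labels.

```python
PAD_ID = 0            # T5's pad token id
LABEL_PAD_ID = -100   # label positions with this value are ignored by the loss

def truncate(ids, max_length):
    """Truncation happens per example, at tokenization time."""
    return ids[:max_length]

def pad_batch(batch):
    """Pad a list of {'input_ids', 'labels'} dicts to the longest sequence in this batch."""
    max_in = max(len(ex["input_ids"]) for ex in batch)
    max_lab = max(len(ex["labels"]) for ex in batch)
    return {
        "input_ids": [ex["input_ids"] + [PAD_ID] * (max_in - len(ex["input_ids"])) for ex in batch],
        "attention_mask": [[1] * len(ex["input_ids"]) + [0] * (max_in - len(ex["input_ids"])) for ex in batch],
        "labels": [ex["labels"] + [LABEL_PAD_ID] * (max_lab - len(ex["labels"])) for ex in batch],
    }

batch = [
    {"input_ids": truncate([21, 22, 23, 24, 25, 26], max_length=4), "labels": [31, 32, 33]},
    {"input_ids": [41, 42], "labels": [51]},
]
padded = pad_batch(batch)
print(padded["input_ids"])       # [[21, 22, 23, 24], [41, 42, 0, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(padded["labels"])          # [[31, 32, 33], [51, -100, -100]]
```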

In [5]:
def preprocess_function(examples):
    inputs = [f'summarization: {original}' for original in examples[original_col]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples[target_col], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function expects a batch of examples: `examples` maps each column name to a list of values, and the tokenizer returns a list of lists for each key.

To apply this function to all the sentence pairs in our dataset, we use the `map` method of the `raw_datasets` object we created earlier with `batched=True`. This applies the function to every element of every split, so the training, validation and test data are all preprocessed with a single command. We then keep only the examples whose tokenized target has more than 5 and fewer than 256 tokens.
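With `batched=True`, `map` calls the function with a dict of column lists (1,000 rows per call by default) and expects a dict of lists back, one entry per row. The sketch below mimics that contract so it runs without downloading a model; `toy_tokenize` is a hypothetical whitespace "tokenizer" standing in for the real T5 tokenizer, and the two rows are shortened, made-up examples.

```python
vocab = {}

def toy_tokenize(texts, max_length):
    """Hypothetical stand-in tokenizer: one id per whitespace-separated word, then truncation."""
    ids = [[vocab.setdefault(word, len(vocab)) for word in text.split()][:max_length] for text in texts]
    return {"input_ids": ids}

def preprocess_batch(examples, max_length=8):
    # Same structure as preprocess_function: prefix the sources, tokenize sources and targets
    inputs = [f"summarization: {text}" for text in examples["original"]]
    model_inputs = toy_tokenize(inputs, max_length)
    model_inputs["labels"] = toy_tokenize(examples["english simplified"], max_length)["input_ids"]
    return model_inputs

batch = {
    "original": ["Cardiothoracic surgery treats organs in the thoracic cavity.", "Abortion is common worldwide."],
    "english simplified": ["Heart and lung surgery.", "Abortion is common."],
}
out = preprocess_batch(batch)
print({key: len(value) for key, value in out.items()})  # {'input_ids': 2, 'labels': 2}
print(out["input_ids"])  # each row starts with the id of the prefix token
```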

In [14]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
# tokenized_datasets = tokenized_datasets.filter(lambda example: len(example['labels']) < max_target_length)
tokenized_datasets_reduced = tokenized_datasets.filter(lambda example: len(example['labels']) > 5 and len(example['labels']) < 256)

In [17]:
# Save tokenized_datasets to disk as it is time-consuming to tokenize
tokenized_datasets_reduced.save_to_disk(f'{data_dir}/tokenized_datasets_reduced_en')

Saving the dataset (0/1 shards):   0%|          | 0/58614 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/301 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/298 [00:00<?, ? examples/s]

**Restricting targets to more than 5 and fewer than 256 tokens still covers most records:**

This filter saves memory and speeds up training. It keeps:
* 58.6k / 63.7k training examples (92.0%)
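The 92.0% figure, and the corresponding share for the other splits, follow from the row counts printed by the split cell (before filtering) and by the `save_to_disk` cell (after filtering). A quick check:

```python
# Row counts before filtering (split output) and after filtering (save_to_disk output)
before = {"train": 63693, "validation": 321, "test": 321}
after = {"train": 58614, "validation": 301, "test": 298}

for split in before:
    kept = after[split] / before[split]
    print(f"{split}: {after[split]}/{before[split]} kept ({kept:.1%})")
# train: 58614/63693 kept (92.0%)
# validation: 301/321 kept (93.8%)
# test: 298/321 kept (92.8%)
```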

In [16]:
import pandas as pd

df = pd.Series([len(label) for label in tokenized_datasets_reduced['train']['labels']])
df.describe()

count    58614.000000
mean       132.330757
std         55.514934
min          6.000000
25%         91.000000
50%        131.000000
75%        173.000000
max        255.000000
dtype: float64

## Fine-tuning the model

In [None]:
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_datasets = load_from_disk(f'{data_dir}/tokenized_datasets_reduced_en') # Load tokenized_datasets from disk
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

The last thing to define for our `Seq2SeqTrainer` is how to compute metrics from the predictions. The function below loads ROUGE from 🤗 Evaluate and does a bit of preprocessing to decode the predicted and reference token IDs into text:

In [50]:
import numpy as np
import evaluate

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
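The `np.where` lines matter because labels are padded with -100 by the data collator, and `Seq2SeqTrainer` also uses -100 when it concatenates predictions of different lengths across batches; -100 is not a vocabulary ID, so it must be swapped for the pad ID before `batch_decode`. A minimal NumPy sketch of that step and of the `gen_len` statistic, using made-up token IDs and assuming T5's pad ID of 0. The leading 0 in each prediction stands for T5's decoder start token, which is the pad token and is therefore not counted.

```python
import numpy as np

PAD_ID = 0  # plays the role of tokenizer.pad_token_id for T5

predictions = np.array([[0, 37, 38, 1, -100], [0, 45, 1, 0, 0]])
labels = np.array([[37, 38, 1, -100, -100], [45, 46, 47, 1, -100]])

# Replace -100 with the pad id so every value can be decoded
predictions = np.where(predictions != -100, predictions, PAD_ID)
labels = np.where(labels != -100, labels, PAD_ID)

# gen_len: non-pad tokens per generated sequence, averaged over the batch
prediction_lens = [int(np.count_nonzero(pred != PAD_ID)) for pred in predictions]
print(predictions.tolist())      # [[0, 37, 38, 1, 0], [0, 45, 1, 0, 0]]
print(prediction_lens)           # [3, 2]
print(np.mean(prediction_lens))  # 2.5
```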

Then we pass all of this, along with our datasets, to the `Seq2SeqTrainer` and fine-tune the model by calling its `train` method:

In [None]:
batch_size = 16

training_args = Seq2SeqTrainingArguments(
    model_dir,
    eval_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=1,
    save_strategy="steps",
    num_train_epochs=30, # 10 done
    predict_with_generate=True,
    fp16=True,
    generation_max_length=max_target_length, 
    # push_to_hub=True
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
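For a sense of scale, the schedule implied by these arguments can be worked out by hand, assuming a single GPU, no gradient accumulation, and the 58,614 training examples kept by the length filter. The last partial batch is kept (`dataloader_drop_last` defaults to `False`), and because `save_strategy="steps"` leaves `save_steps` at its default of 500, checkpoints are written every 500 steps rather than at epoch boundaries, with `save_total_limit=1` keeping only the most recent one.

```python
import math

def steps_per_epoch(n_train, batch_size):
    # The final, partial batch still counts as a step (drop_last=False)
    return math.ceil(n_train / batch_size)

n_train, batch_size, epochs, save_steps = 58614, 16, 30, 500
per_epoch = steps_per_epoch(n_train, batch_size)
total = per_epoch * epochs
print(per_epoch, total, total // save_steps)  # 3664 109920 219
```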

## Model evaluation

We load the latest saved checkpoint of the model (the trainer keeps only one because of `save_total_limit=1`) and compare its performance with published medical text simplification models.

In [None]:
import os
import nltk
from transformers import AutoModelForSeq2SeqLM

model_name = "t5-base"
model_dir = f"{model_name.replace('/', '-')}-checkpoints"

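# Load the most recent checkpoint if one exists; otherwise fall back to the pretrained model.
# os.listdir() can also return non-checkpoint entries (e.g. a "runs" logging folder) in arbitrary order.
checkpoints = []
if os.path.isdir(model_dir):
    checkpoints = sorted(
        (item for item in os.listdir(model_dir) if item.startswith('checkpoint-')),
        key=lambda item: int(item.split('-')[-1]),
    )
model_checkpoint = os.path.join(model_dir, checkpoints[-1]) if checkpoints else model_name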

model_t = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [13]:
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_datasets = load_from_disk(f'{data_dir}/tokenized_datasets_reduced_en') # Load tokenized_datasets from disk

test_subset = tokenized_datasets['test'].shuffle().select(range(10))
test_sources = [r[original_col] for r in test_subset]
test_references = [r[target_col] for r in test_subset]
print(f"Sample Input: {test_sources[0]}")

Sample Input: The frequency of skin ulceration makes an important contributor to the morbidity burden in people with sickle cell disease. Many treatment options are available to the healthcare professional, although it is uncertain which treatments have been assessed for effectiveness in people with sickle cell disease. This is an update of a previously published Cochrane Review.


In [14]:
import torch
from tqdm import tqdm

model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t.to(model_device)

batch_size = 16
total_samples = len(test_sources)
prediction = []

# Process inputs in batches
for start_idx in tqdm(range(0, total_samples, batch_size)):
    end_idx = min(start_idx + batch_size, total_samples)
    # tokenized_batch = test_subset[start_idx:end_idx]
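    # Prepend the same task prefix that preprocess_function used during fine-tuning
    batch_sources = [f'summarization: {source}' for source in test_sources[start_idx:end_idx]]

    tokenized_batch = tokenizer(batch_sources, max_length=max_input_length, truncation=True, padding=True, return_tensors="pt").to(model_device)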

    with torch.no_grad():
        output = model_t.generate(**tokenized_batch, num_beams=8, do_sample=True, min_length=0, max_length=max_target_length)

    # Decode the generated output and add to prediction
    batch_prediction = tokenizer.batch_decode(output, skip_special_tokens=True)
    prediction.extend(batch_prediction)
# prediction = [nltk.sent_tokenize(o.strip())[0] for o in decoded_output]

100%|██████████| 1/1 [00:07<00:00,  7.45s/it]


In [8]:
import pandas as pd
import numpy as np

predictions_df = pd.DataFrame(np.array([test_sources, test_references, prediction]).T, columns=['Original sentence', 'Simplified sentence', 'Predicted sentence'])
predictions_df.to_csv(f'{results_dir}/predictions-{model_name.replace("/", "-")}.tsv', sep='\t')

In [9]:
predictions_df

|     | Original sentence | Simplified sentence | Predicted sentence |
|-----|-------------------|---------------------|--------------------|
| 0   | Ovarian cancer is the third most common gynaec... | Ovarian cancer is the third most common type o... | Ovarian cancer is the third most common gyneco... |
| 1   | No studies compared liberal versus conservativ... | There were no studies comparing giving lots of... | This review looked at whether giving fluids to... |
| 2   | In children with urinary tract infection (UTI)... | Kids with urinary tract infections (UTIs) can ... | This study looked at whether a test called pro... |
| 3   | We searched the Cochrane Depression, Anxiety a... | We looked for studies that tested different tr... | We searched a database called the Cochrane Dep... |
| 4   | Two review authors independently assessed reco... | Researchers looked at studies comparing botuli... | This study looked at whether a medicine called... |
| ... | ... | ... | ... |
| 293 | We searched the Cochrane Kidney and Transplant... | We looked for studies relevant to our review i... | We searched a database called the Cochrane Kid... |
| 294 | Ovulatory disturbance is a key diagnostic feat... | Having irregular ovulation is a key sign of po... | Polycystic ovarian syndrome (PCOS) is a seriou... |
| 295 | We identified 18 RCTs examining a range of com... | This study looked at 18 different ways to help... | We looked at 18 studies that tested different ... |
| 296 | Abortion is common worldwide and increasingly ... | Abortion is common around the world. More and... | Abortion is common around the world, and more ... |


In [15]:
from metrics import fk, ari, bleu, rouge, meteor, sari, bertscore

# Calculate metrics
fk_score = fk(prediction)
ari_score = ari(prediction)
bleu_score = bleu(test_references, prediction)
rouge_score = rouge(test_references, prediction)
meteor_score = meteor(test_references, prediction)
sari_score = sari(test_sources, test_references, prediction)
bertscore_score = bertscore(test_references, prediction)

print("FK index:", fk_score)
print("ARI index:", ari_score)
print("BLEU Score:", bleu_score)
print("ROUGE Score:", rouge_score)
print("METEOR Score:", meteor_score)
print("SARI Score:", sari_score)
print("BERTScore:", bertscore_score)

# NOTE: Results vary slightly between runs (the test subset is shuffled without a seed and decoding uses sampling), so averaging over several runs is recommended.

FK index: 10.690000000000001
ARI index: 13.05
BLEU Score: 0.22555821117555325
ROUGE Score: {'rouge1': np.float64(0.5491531509821976), 'rouge2': np.float64(0.29510241772094437), 'rougeL': np.float64(0.4504893575357102), 'rougeLsum': np.float64(0.44968715561786)}
METEOR Score: 0.4742659218879215
SARI Score: 0.5067694622718653
BERTScore: 0.744801664352417
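FK and ARI are grade-level readability formulas, so lower values indicate easier text. A minimal sketch of their standard definitions computed from raw counts; we assume, without having checked the local `metrics` module, that `fk` and `ari` apply these formulas to the predicted texts. The counts below are illustrative and not taken from the data, and syllable counting, the hard part in practice, is not shown.

```python
def flesch_kincaid_grade(words, sentences, syllables):
    """Flesch-Kincaid grade level from raw counts."""
    return 0.39 * (words / sentences) + 11.8 * (syllables / words) - 15.59

def automated_readability_index(characters, words, sentences):
    """Automated Readability Index from raw counts (characters = letters and digits)."""
    return 4.71 * (characters / words) + 0.5 * (words / sentences) - 21.43

# Illustrative passage: 100 words, 5 sentences, 150 syllables, 480 characters
print(round(flesch_kincaid_grade(words=100, sentences=5, syllables=150), 2))         # 9.91
print(round(automated_readability_index(characters=480, words=100, sentences=5), 2))  # 11.18
```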


FK and ARI measure readability as grade levels (lower is better); ROUGE-1, ROUGE-2, ROUGE-L and BLEU measure lexical overlap with the reference, SARI measures simplification quality and BERTScore measures semantic similarity (higher is better for all of these). The best value in each column is in bold.

| Model                    | FK        | ARI       | ROUGE-1   | ROUGE-2   | ROUGE-L   | BLEU      | SARI      | BERTScore |
|--------------------------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| UL-BART (Devaraj et al.) | 11.97     | **13.73** | 38.00     | 14.00     | 36.00     | 39.0      | 40.00     | N/A       |
| NapSS (Lu et al.)        | **10.97** | 14.27     | 48.05     | 19.94     | **44.76** | 12.3      | 40.37     | 25.73     |
| T5-base-finetuned (ours) | 11.46     | 14.41     | **54.88** | **28.85** | 42.51     | **51.00** | **73.70** | **74.83** |

## The End
Thank you!!