# Text Simplification Fine-Tuning with WikiLarge
This notebook fine-tunes sequence-to-sequence models for text simplification on WikiLarge.
It was originally run in Google Colab, mostly on an L4 GPU.

In [None]:
# Install dependencies (Colab). Versions are unpinned; pin them for reproducible runs.
!pip install -q sentencepiece huggingface_hub
!pip install -q --upgrade datasets fsspec transformers

In [None]:
# Log in to the Hugging Face Hub. This may not be needed for public models/datasets.
import os

from huggingface_hub import login

# Read the token from the environment rather than hard-coding it in the notebook
# (export HF_TOKEN before starting the kernel), or paste it here -- never commit it.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

if HF_TOKEN:
    login(HF_TOKEN)
else:
    # An empty token would make `login` fail; skip instead.
    print("No HF_TOKEN set; skipping Hugging Face login.")

# Detect the GPU so per-device settings can be tuned below.
import torch

device_idx = 0
if torch.cuda.is_available():
    name = torch.cuda.get_device_name(device_idx)
else:
    # get_device_name raises without a CUDA device; keep the cell runnable on CPU.
    name = "No CUDA device"

upper_name = name.upper()
gpumodel = "Unknown"
if "T4" in upper_name:
    gpumodel = "T4"
elif "L4" in upper_name:
    gpumodel = "L4"
print(f"GPU name: {name}")
print(f"Detected model: {gpumodel}")
if gpumodel == "L4":
    # Let float32 matmuls use faster reduced-precision kernels (e.g. TF32).
    torch.set_float32_matmul_precision("high")


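Using `dataset_cleaning.ipynb`, we built `wikilarge_dataset` and `wikilarge_dataset_clean`.

These folders contain WikiLarge saved in Hugging Face `datasets` format. The next cell loads the clean version if it exists. Otherwise it falls back to the raw version, and if neither is on disk it downloads our cleaned copy from the Hugging Face Hub.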

In [None]:
from datasets import load_from_disk, load_dataset
import os

DATASET_PATH       = os.path.join(os.getcwd(), "datasets/wikilarge_dataset")
CLEAN_DATASET_PATH = os.path.join(os.getcwd(), "datasets/wikilarge_dataset_clean")

_CLEANING_HINT = "See `dataset_cleaning.ipynb` for more info."


def _load_wikilarge():
    """Load WikiLarge, trying in order: clean local copy, raw local copy, HF Hub copy."""
    if os.path.exists(CLEAN_DATASET_PATH):
        ds = load_from_disk(CLEAN_DATASET_PATH)
        print("Loaded clean dataset from disk")
        return ds

    if os.path.exists(DATASET_PATH):
        ds = load_from_disk(DATASET_PATH)
        print("Warning: Loaded non-clean dataset from disk.")
        print(_CLEANING_HINT)
        return ds

    # Nothing on disk: report it, then fall back to the published cleaned dataset.
    print("Error: No dataset found on disk.")
    print(_CLEANING_HINT)
    print(" * New fallback: using our cleaned dataset from HF hub:")
    ds = load_dataset("eilamc14/wikilarge-clean")
    print("Loaded clean dataset from HF")
    return ds


dataset = _load_wikilarge()

Loaded clean dataset from disk


Standard seq2seq models (T5, BART, etc.).

In [None]:
# Load model and tokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Gradient checkpointing saves memory during training; the KV cache conflicts with it,
# so use_cache is turned off here (re-enable it before saving for inference).
model.config.use_cache = False
model.gradient_checkpointing_enable()


def _config_value(config, *names):
    """Return the first attribute in `names` that `config` defines, else "n/a".

    Dropout fields are named differently per architecture (e.g. T5 uses
    `dropout_rate`, others use `dropout`), so plain attribute access can raise
    AttributeError depending on `model_name`.
    """
    for attr in names:
        if hasattr(config, attr):
            return getattr(config, attr)
    return "n/a"


print("Some model info for testing...")
print("pad:", tokenizer.pad_token_id, "eos:", tokenizer.eos_token_id)
print("cfg pad:", model.config.pad_token_id, "cfg eos:", model.config.eos_token_id)
print("start:", model.config.decoder_start_token_id)

# To change regularization, set e.g. model.config.dropout = 0.15 (field name varies by model).
print("dropout: ", _config_value(model.config, "dropout", "dropout_rate"))
print("atten_dropout: ", _config_value(model.config, "attention_dropout"))
print("active_dropout: ", _config_value(model.config, "activation_dropout"))

print("gradient checkpointing:", model.is_gradient_checkpointing)
print("use_cache:", model.config.use_cache)


In [None]:
# Preprocess data: tokenize every split into model inputs + labels.

def preprocess_function(examples):
    """Tokenize a batch of (source, target) pairs for seq2seq training.

    `examples` is a batched mapping with "source" and "target" lists. Each source
    gets the "Simplify: " task prefix; sources and targets are truncated to 256
    tokens. Returns the tokenizer encoding of the prefixed sources, plus
    "labels" (target token ids) and "src_text" (the raw, unprefixed sources).
    """
    sources = examples["source"]
    prompts = ["Simplify: " + text for text in sources]

    encode_opts = dict(max_length=256, truncation=True, add_special_tokens=True)

    model_inputs = tokenizer(prompts, **encode_opts)
    target_encoding = tokenizer(text_target=examples["target"], **encode_opts)

    model_inputs["labels"] = target_encoding["input_ids"]
    # Keep the untouched source text so later cells can display it next to predictions.
    model_inputs["src_text"] = sources
    return model_inputs


tokenized_datasets = dataset.map(preprocess_function, batched=True)

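Add SARI and identical-ratio metrics to the evaluation runs during training. The identical ratio is the fraction of predictions that are unchanged copies of their source sentence.

Note: we use Hugging Face's `evaluate` implementation of SARI here because it is quicker to install and run. Its scores may differ slightly from EASSE, which is used for the final results.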

In [None]:
# Metric dependencies: `evaluate` provides the SARI metric used below; sacremoses and
# sacrebleu are presumably its tokenization/scoring backends (versions unpinned).
!pip install -q evaluate sacremoses sacrebleu

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m59.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import numpy as np, evaluate

# HF `evaluate` SARI metric; `compute_metrics` uses it to score predictions against sources and references.
sari_metric = evaluate.load("sari")

import re

# Matches the "Simplify: " task prefix added during preprocessing. It also tolerates the
# lower-casing / re-spacing that uncased tokenizers produce on decode ("simplify : ...").
_TASK_PREFIX_RE = re.compile(r"^\s*simplify\s*:\s*", re.IGNORECASE)


def compute_metrics(eval_preds, tok=None, metric=None):
    """Compute SARI and the identical-output ratio for a generation eval pass.

    Args:
        eval_preds: an ``EvalPrediction`` or a ``(predictions, label_ids, inputs)``
            tuple. Predictions must be generated token ids
            (``predict_with_generate=True``). Inputs must be present
            (``include_for_metrics=["inputs"]``) because SARI needs the sources.
        tok: tokenizer used for decoding; defaults to the notebook's ``tokenizer``.
        metric: object with ``compute(sources=, predictions=, references=)``
            returning ``{"sari": ...}``; defaults to the notebook's ``sari_metric``.

    Returns:
        ``{"sari": float, "identical_ratio": float}``. ``identical_ratio`` is the
        fraction of predictions equal (after whitespace strip) to their source;
        it is 0.0 for an empty batch.

    Raises:
        ValueError: if the source inputs are missing.
    """
    tok = tok if tok is not None else tokenizer
    metric = metric if metric is not None else sari_metric

    if hasattr(eval_preds, "predictions"):
        preds = eval_preds.predictions
        labels = eval_preds.label_ids
        inputs = getattr(eval_preds, "inputs", None)
    else:
        preds, labels = eval_preds[0], eval_preds[1]
        inputs = eval_preds[2] if len(eval_preds) > 2 else None

    # Some models return a tuple whose first element holds the generated ids.
    if isinstance(preds, tuple):
        preds = preds[0]
    if inputs is None:
        raise ValueError(
            "compute_metrics needs the source inputs for SARI; "
            "set include_for_metrics=['inputs'] in the training arguments."
        )

    pad_id = tok.pad_token_id
    if pad_id is None:
        # Fall back to EOS (or 0) so -100 placeholders can still be decoded.
        eos_id = getattr(tok, "eos_token_id", None)
        pad_id = eos_id if eos_id is not None else 0

    def _decode(ids):
        # -100 marks ignored/padded positions, which are not valid token ids.
        ids = np.asarray(ids)
        ids = np.where(ids != -100, ids, pad_id)
        return tok.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    decoded_preds = _decode(preds)
    decoded_refs = _decode(labels)
    decoded_srcs = [_TASK_PREFIX_RE.sub("", s, count=1) for s in _decode(inputs)]

    sari_score = metric.compute(
        sources=decoded_srcs,
        predictions=decoded_preds,
        references=[[r] for r in decoded_refs],
    )["sari"]

    n_preds = len(decoded_preds)
    n_identical = sum(p.strip() == s.strip() for p, s in zip(decoded_preds, decoded_srcs))
    identical_ratio = n_identical / n_preds if n_preds else 0.0

    return {"sari": float(sari_score), "identical_ratio": float(identical_ratio)}

Downloading builder script: 0.00B [00:00, ?B/s]

Training arguments and Trainer setup (change the hyperparameters here for different runs):

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, EarlyStoppingCallback

def get_training_args(model_name_or_path: str, output_dir: "str | None" = None, **overrides):
    """Return Seq2SeqTrainingArguments suited to the given model family.

    Learning rate and batch size are chosen by case-insensitive substring match on
    the model name: "t5", "bart", or a default for anything else.

    Args:
        model_name_or_path: Hub id or local path of the model. '/' is replaced by '_'
            to name the run's output sub-directory.
        output_dir: root folder for checkpoints and logs. Defaults to `<cwd>/models`,
            resolved at call time rather than when the function is defined.
        **overrides: any Seq2SeqTrainingArguments field; applied on top of the defaults.

    Returns:
        A configured `Seq2SeqTrainingArguments`.
    """
    import os

    import torch

    if output_dir is None:
        output_dir = os.path.join(os.getcwd(), "models")

    steps = 300  # eval/save interval; use a higher value for smaller batch sizes
    family = model_name_or_path.lower()

    if "t5" in family:
        learning_rate = 5e-5
        batch_size = 64  # ~8 without gradient checkpointing
    elif "bart" in family:
        learning_rate = 3e-5
        batch_size = 16
    else:
        learning_rate = 2e-5
        batch_size = 96

    # Native bf16 needs compute capability >= 8.0: L4 (8.9) has it, T4 (7.5) does not.
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    # 2 workers suit a T4 runtime and 8-11 an L4 one; cap at the CPU count so small
    # runtimes are not oversubscribed.
    num_workers = min(8, os.cpu_count() or 1)

    arguments = dict(
        output_dir=f"{output_dir}/{model_name_or_path.replace('/', '_')}",
        eval_strategy="steps",
        eval_steps=steps,

        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,

        save_total_limit=3,
        num_train_epochs=5,
        save_strategy="steps",
        save_steps=steps,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,  # lower eval_loss is better
        logging_dir=f"{output_dir}/logs",
        logging_steps=steps - 100,
        report_to="none",

        predict_with_generate=True,
        generation_max_length=128,
        generation_num_beams=4,

        # compute_metrics needs the source inputs to compute SARI.
        include_for_metrics=["inputs"],
        # fp16 caused trouble with T5, so only bf16 is used (when supported).
        bf16=use_bf16,
        group_by_length=True,

        weight_decay=0.01,
        label_smoothing_factor=0.1,
        warmup_ratio=0.1,
        max_grad_norm=0.5,
        optim="adamw_torch_fused",  # "adafactor" is an alternative for T5
        lr_scheduler_type="linear",

        dataloader_num_workers=num_workers,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=True,
    )
    arguments.update(overrides)
    return Seq2SeqTrainingArguments(**arguments)


In [None]:
# Build the training arguments for the selected checkpoint.
args = get_training_args(model_name)

# Pad each batch dynamically (to a multiple of 8); labels are padded with -100.
collator_options = dict(padding=True, label_pad_token_id=-100, pad_to_multiple_of=8)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, **collator_options)

train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["validation"]
# Stop once the best-model metric has not improved for 5 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    processing_class=tokenizer,
    data_collator=data_collator,
    callbacks=[early_stopping],
    compute_metrics=compute_metrics,
)

Resume training from the latest checkpoint if one exists.

To restart with a fresh run, manually delete or move the checkpoints in the output folder (`models/<model_name>` by default).

In [None]:
# Resume from the newest checkpoint in the output dir if there is one.
# The checkpoint is looked up explicitly rather than catching ValueError around
# trainer.train(): that would also swallow unrelated ValueErrors raised mid-training
# and silently restart from scratch.
import os

from transformers.trainer_utils import get_last_checkpoint

_output_dir = trainer.args.output_dir
# get_last_checkpoint lists the directory, so only call it when the directory exists.
last_checkpoint = get_last_checkpoint(_output_dir) if os.path.isdir(_output_dir) else None

if last_checkpoint is not None:
    print(f"Resuming from checkpoint: {last_checkpoint}")
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    print("No checkpoint found. Training from scratch:")
    trainer.train()

Quick sanity check on the `test` split of WikiLarge-clean, using extra generation arguments (4-beam search, no repeated trigrams):

In [None]:
import numpy as np

ds_name = "test"
N_EXAMPLES = 5  # number of (source, reference, prediction) triples to print

# Pad id used to replace -100 placeholders before decoding; fall back to EOS/0.
_pad_id = tokenizer.pad_token_id
if _pad_id is None:
    _pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

# Disable length grouping for this run so predictions stay in dataset order (they are
# paired with `srcs` below). Always restore the setting, even if predict() fails.
_prev = trainer.args.group_by_length
trainer.args.group_by_length = False
try:
    pred_output = trainer.predict(
        tokenized_datasets[ds_name],
        num_beams=4,
        no_repeat_ngram_size=3,
        length_penalty=1.0,
    )
finally:
    trainer.args.group_by_length = _prev


def _decode_ids(ids):
    """Decode token ids (with -100 placeholders) the same way compute_metrics does."""
    ids = np.where(ids != -100, ids, _pad_id)
    return tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)


decoded_preds = _decode_ids(pred_output.predictions)
decoded_labels = _decode_ids(pred_output.label_ids)

srcs = tokenized_datasets[ds_name]["src_text"]
print(pred_output.metrics)
print(f"SARI:{pred_output.metrics[f'{ds_name}_sari']:.2f}")
print(f"identical ratio:{pred_output.metrics[f'{ds_name}_identical_ratio']:.2f}")

# Print the first few examples (fewer if the split is small).
for i in range(min(N_EXAMPLES, len(decoded_preds))):
    print(f"SRC : {srcs[i]}")
    print(f"REF : {decoded_labels[i]}")
    print(f"PRED: {decoded_preds[i]}")
    print("---")

Step,Training Loss,Validation Loss,Model Preparation Time,Sari,Identical Ratio
2400,2.4292,2.518215,0.0057,48.910321,0.0
2600,2.4543,2.457514,0.0057,49.873072,0.0
2800,2.4169,2.492606,0.0057,47.872178,0.0


{'test_loss': 2.2495431900024414, 'test_model_preparation_time': 0.0057, 'test_sari': 48.78379847979586, 'test_identical_ratio': 0.0, 'test_runtime': 14.616, 'test_samples_per_second': 8.279, 'test_steps_per_second': 0.137}
SARI:48.78
identical ratio:0.00
SRC : One side of the armed conflicts is composed mainly of the Sudanese military and the Janjaweed , a Sudanese militia group recruited mostly from the Afro-Arab Abbala tribes of the northern Rizeigat region in Sudan .
REF : one side of the armed conflicts is made of sudanese military and the janjaweed, a sudanese militia recruited from the afro - arab abbala tribes of the northern rizeigat region in sudan.
PRED: the janjaweed , a sudanese militia group recruited mostly from the afro - arab abbala tribes of the northern rizeigat region in sudan .
---
SRC : His next work , Saturday , follows an especially eventful day in the life of a successful neurosurgeon .
REF : his next work at saturday will be a successful neurosurgeon.
PRED: hi

Save the model and tokenizer to `models/trained/<model_name>`, under the Google Drive project folder `/content/drive/MyDrive/NLP_Project` by default.

In [None]:
# Save model and tokenizer.
# use_cache was turned off for gradient-checkpointed training. Re-enable it so the saved
# config does not disable the KV cache (slow generation) when reloaded for inference.
model.config.use_cache = True

SAVE_DIR = f"/content/drive/MyDrive/NLP_Project/models/trained/{model_name.replace('/', '_')}"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

('/content/drive/MyDrive/NLP_Project/models/trained/microsoft_prophetnet-large-uncased-cnndm/tokenizer_config.json',
 '/content/drive/MyDrive/NLP_Project/models/trained/microsoft_prophetnet-large-uncased-cnndm/special_tokens_map.json',
 '/content/drive/MyDrive/NLP_Project/models/trained/microsoft_prophetnet-large-uncased-cnndm/prophetnet.tokenizer',
 '/content/drive/MyDrive/NLP_Project/models/trained/microsoft_prophetnet-large-uncased-cnndm/added_tokens.json')