In [None]:
# Got CUDA out-of-memory errors due to memory fragmentation; setting the
# allocator's max_split_size_mb option (here 512) fixed the issue.
# NOTE(review): presumably this has to be set before the first CUDA allocation to take effect - confirm.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

Read in data

The dataset archive (`sentence-aligned.v2.tar.gz`) can be downloaded from: https://cs.pomona.edu/~dkauchak/simplification/

In [None]:
import tarfile

# Unpack the downloaded archive into ./data. The context manager closes the
# archive even when extraction fails part-way through.
# NOTE(review): extractall() trusts the member paths stored in the archive; only run this on an archive from a trusted source.
with tarfile.open('sentence-aligned.v2.tar.gz') as archive:
  archive.extractall('./data')

In [None]:
def get_dataset(path, encoding='utf-8'):
  """Read a text file and return its lines.

  Args:
    path: Path of the file to read.
    encoding: Text encoding used to decode the file. Defaults to UTF-8 so the
      result does not depend on the platform's locale settings.

  Returns:
    A list with one string per line; each line keeps its trailing newline.
  """
  # `with` guarantees the handle is closed, including when reading raises.
  with open(path, 'r', encoding=encoding) as f:
    return f.readlines()

In [None]:
# Using ver. 2.0 of the dataset, where each datapoint is of the format:
#   TOPIC \t NUMBER \t SENTENCE
# x holds the raw lines of the "normal" file and y the raw lines of the
# "simple" file; the two lists are paired up by position further down.
x = get_dataset('./data/sentence-aligned.v2/normal.aligned')
y = get_dataset('./data/sentence-aligned.v2/simple.aligned')

In [None]:
# Extract sentences: keep only the third tab-separated field of every line.
# The lines come from readlines(), so each sentence still ends with its newline.
# NOTE(review): this cell overwrites x and y in place, so running it a second
# time fails (the sentences no longer have three tab-separated fields).
x = [s.split('\t')[2] for s in x]
y = [s.split('\t')[2] for s in y]

In [None]:
from datasets import Dataset, DatasetDict
import pandas as pd
import regex as re

# Pair the sentences by position: row i holds x[i] as the model input and
# y[i] as the target, then wrap the frame in a Hugging Face Dataset.
data = {'text_inputs':x, 'text_labels':y}
dataset = Dataset.from_pandas(pd.DataFrame(data=data))

In [None]:
# keep only the pairs whose target sentence differs from the input sentence
dataset = dataset.filter(lambda pair: pair['text_inputs'] != pair['text_labels'])

In [None]:
# split into trn (80%), dev (10%), and test (10%)
# A fixed seed keeps the three splits identical across kernel restarts, so the
# reported dev/test numbers stay comparable between runs (the same seed value
# is used for the later shuffle of each split).
dataset = dataset.train_test_split(test_size=0.2, seed=42)
dev_and_test = dataset['test'].train_test_split(test_size=0.5, seed=42)

In [None]:
# build final dataset dict
# 'train' is the 80% portion of the first split; 'dev' and 'test' are the two
# halves of the 20% hold-out produced by the second split.
dataset = DatasetDict({
  'train': dataset['train'],
  'dev': dev_and_test['train'],
  'test': dev_and_test['test']
  })

In [None]:
# BLEU score increases by 2pts (0.55 to 0.57) when we clean the data of all non-English characters (mainly affects names and places)
def clean_data(data):
  """Remove every non-ASCII character from a batch of sentence pairs.

  Args:
    data: Batch mapping with the keys 'text_inputs' and 'text_labels', each
      holding a list of strings.

  Returns:
    Dict with the same two keys holding the cleaned strings.
  """
  def strip_non_ascii(sentences):
    # Anything outside the 7-bit ASCII range is deleted, not transliterated.
    return [re.sub(r'[^\x00-\x7f]',r'',sentence) for sentence in sentences]

  return {
    'text_inputs': strip_non_ascii(data['text_inputs']),
    'text_labels': strip_non_ascii(data['text_labels']),
  }

In [None]:
# Apply the cleaning to every split; batched=True hands clean_data lists of sentences.
dataset = dataset.map(clean_data,batched=True)

Load in model and relevant hyperparameters

In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Using T5 small model
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# manually add tokens for parentheses (LRB and RRB) to the tokenizer and model
# NOTE(review): these are registered as *special* tokens, so decoding with
# skip_special_tokens=True (as the metrics function does) presumably drops
# them from the decoded text - confirm this is intended.
special_tokens_dict = {'additional_special_tokens': ['-LRB-', '-RRB-']}
tokenizer.add_special_tokens(special_tokens_dict)

st_model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
# The vocabulary grew by the added tokens, so the embedding matrix is resized
# to the new tokenizer length; this must run after the tokens were added.
st_model.resize_token_embeddings(len(tokenizer))

Preprocessing 

The preprocessing follows the Hugging Face T5 documentation: https://huggingface.co/docs/transformers/model_doc/t5

In [None]:
# Task prefix that is prepended to every input sentence before tokenization.
# The prefix we chose follows the format and style shown in the original T5 paper: https://arxiv.org/pdf/1910.10683.pdf
# (see their Appendix D for examples)
prefix = "simplify: "

In [None]:
# define max input,target lengths (in token ids); the preprocessing functions
# pad or truncate every sequence to exactly this length
max_input_length = 128
max_target_length = 128

In [None]:
def preprocess_inputs(data):
  """Tokenize a batch of source sentences with the task prefix prepended.

  Args:
    data: Batch mapping whose 'text_inputs' entry is a list of sentences.

  Returns:
    The tokenizer output for the prefixed sentences, padded/truncated to
    max_input_length.
  """
  prefixed_sentences = [prefix + sentence for sentence in data['text_inputs']]
  return tokenizer(
    prefixed_sentences,
    padding='max_length',
    max_length=max_input_length,
    truncation=True,
    return_tensors='pt',
  )

def preprocess_labels(data):
  """Tokenize a batch of target sentences into label ids.

  Args:
    data: Batch mapping whose 'text_labels' entry is a list of sentences.

  Returns:
    Dict with a single 'labels' entry holding the token ids, padded/truncated
    to max_target_length.
  """
  # NOTE(review): padding positions keep the pad token id instead of -100, so
  # they presumably count toward the loss. Switching to -100 would also require
  # my_compute_metrics to mask the predictions at those positions - confirm
  # before changing.
  target_encoding = tokenizer(
    data['text_labels'],
    padding='max_length',
    max_length=max_target_length,
    truncation=True,
    return_tensors='pt',
  )
  label_ids = target_encoding['input_ids']
  return {'labels': label_ids}

In [None]:
# Tokenize every split: the first pass adds the encoder inputs, the second
# adds the 'labels' column built from the target sentences.
dataset = dataset.map(preprocess_inputs,batched=True)
dataset = dataset.map(preprocess_labels,batched=True)

In [None]:
# split into trn, dev, and test datasets. Note that we're not using all available data for the sake of time
def _tokenized_subset(split, n_examples):
  """Drop the raw-text columns, shuffle with a fixed seed, and keep the first n_examples rows."""
  without_text = split.remove_columns(['text_inputs','text_labels'])
  return without_text.shuffle(seed=42).select(range(n_examples))

train_dataset = _tokenized_subset(dataset['train'], 5000)
dev_dataset = _tokenized_subset(dataset['dev'], 625)
test_dataset = _tokenized_subset(dataset['test'], 625)

Finetune model

In [None]:
from transformers import DataCollatorForSeq2Seq
# Batches examples for the training Trainer; only the tokenizer is passed in.
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [None]:
def postprocess_text(preds, labels):
    """Normalize decoded text into the shape the metrics expect.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings, one per prediction.

    Returns:
        A tuple (predictions, references): predictions is a list of stripped
        strings, and references is a list of one-element lists, each holding
        the stripped reference.
    """
    stripped_preds = [text.strip() for text in preds]
    reference_lists = [[text.strip()] for text in labels]
    return stripped_preds, reference_lists

In [None]:
import regex as re
# converts parentheses tags into parentheses for outputting to website
def postprocess_text_for_output(text):
  """Replace every '-LRB-' tag with '(' and every '-RRB-' tag with ')'.

  Args:
    text: String that may contain the bracket tags.

  Returns:
    The string with all tags swapped for literal parentheses.
  """
  # Both tags are plain literals, so a literal replace gives the same result
  # as the regex substitution.
  return text.replace('-LRB-', '(').replace('-RRB-', ')')

In [None]:
import evaluate
import numpy as np
import textstat

# compute rouge, bleu, sari, and readability
def my_compute_metrics(eval_pred):
  """Decode an evaluation run and score it with ROUGE, BLEU, FKGL and SARI.

  Args:
    eval_pred: Prediction object handed over by the Trainer. `predictions`
      (logits, or a tuple whose first element is the logits) and `label_ids`
      are always read; `inputs` is read when it is present.

  Returns:
    Dict of metric name -> value: the ROUGE scores plus 'bleu' and
    'readability' (mean Flesch-Kincaid grade of the predictions). 'sari' is
    added only when the source inputs are available, because SARI is computed
    against the source sentences.
  """
  # The inputs are only supplied when the Trainer is configured with
  # include_inputs_for_metrics=True; otherwise they are missing/None and
  # decoding them would raise.
  inputs = getattr(eval_pred, 'inputs', None)
  labels = eval_pred.label_ids
  pred_ids = eval_pred.predictions
  if isinstance(pred_ids, tuple):
    pred_ids = pred_ids[0]

  # Pick the highest-scoring token id at every position.
  preds = np.argmax(pred_ids, axis=-1)

  # Decode predictions to compute metrics
  preds_text = tokenizer.batch_decode(preds, skip_special_tokens=True)

  # -100 is not a valid token id, so swap it for the pad id before decoding.
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  labels_text = tokenizer.batch_decode(labels, skip_special_tokens=True)

  preds_text, labels_text = postprocess_text(preds_text, labels_text)

  # init metrics
  metric1 = evaluate.load('rouge')
  metric2 = evaluate.load('bleu')

  # compute metrics
  metrics = metric1.compute(predictions=preds_text, references=labels_text)
  metrics['bleu'] = metric2.compute(predictions=preds_text, references=labels_text)['bleu']
  metrics['readability'] = sum([textstat.flesch_kincaid_grade(x) for x in preds_text]) / len(preds_text) #FKGL score

  if inputs is not None:
    # Decode inputs to compute metrics (specifically, SARI requires inputs).
    # Same -100 guard as for the labels, in case the gathered inputs were
    # padded with -100.
    # NOTE(review): the decoded sources presumably still start with the task
    # prefix - confirm whether it should be stripped before scoring.
    inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
    inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=True)
    metric3 = evaluate.load('sari')
    metrics['sari'] = metric3.compute(sources=inputs_text, predictions=preds_text, references=labels_text)['sari']

  return metrics

In [None]:
# Per-device batch sizes passed to the training TrainingArguments.
train_batch_size = 8
eval_batch_size = 8

In [None]:
from transformers import TrainingArguments, Trainer
# Training configuration: 5 epochs, with evaluation, logging and checkpointing
# all on the same 1000-step interval, and the best checkpoint (by eval loss)
# restored when training ends.
training_args = TrainingArguments(
    output_dir='style_transfer_chkpts',
    include_inputs_for_metrics=True,   # allows us to compute SARI in compute_metrics fn
    num_train_epochs=5,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    eval_accumulation_steps=1,  # presumably set to limit GPU memory during evaluation - confirm
    prediction_loss_only=False,  # predictions are needed by the compute_metrics function
    learning_rate=0.0005,
    evaluation_strategy='steps',
    save_steps=1000,  # same interval as eval_steps
    save_total_limit=3,
    remove_unused_columns=True,
    run_name='run_final', # Wandb run name
    logging_steps=1000, 
    eval_steps=1000, 
    logging_first_step=False, 
    load_best_model_at_end=True, 
    metric_for_best_model="loss", # loss to eval models
    greater_is_better=False  # a lower loss is the better model
)

In [None]:
# Trainer for fine-tuning: trains on the training subset, evaluates on the dev
# subset, and scores every evaluation with my_compute_metrics.
trainer = Trainer(
  model=st_model,
  args=training_args,
  train_dataset=train_dataset,
  eval_dataset=dev_dataset,
  data_collator=data_collator,
  compute_metrics=my_compute_metrics,
)

In [None]:
# Finetune! Checkpoints are written to the output_dir set in training_args.
trainer.train()

In [None]:
# save finetuned model and tokenizer into the same directory
trainer.save_model('best_model_simplify')
tokenizer.save_pretrained('best_model_simplify')

# Keep a handle on the trained model for the evaluation cells below.
model = trainer.model

Evaluate model

In [None]:
# Eval finetuned model in inference mode
test_args = TrainingArguments(
  output_dir="finetune_simp_eval",
  do_train=False,
  do_predict=True,
  # my_compute_metrics decodes eval_pred.inputs for SARI; the Trainer only
  # passes the inputs along when this flag is set (same as in training_args).
  include_inputs_for_metrics=True,
  # Same setting as during training, so the prediction run does not keep every
  # batch of logits on the accelerator at once.
  eval_accumulation_steps=1,
  # Half precision needs a GPU; use full precision on the CPU fallback.
  fp16=torch.cuda.is_available(),
  per_device_eval_batch_size=eval_batch_size,
)

In [None]:
# init test trainer: wraps the fine-tuned model for prediction only (no
# train/eval datasets are attached) and reuses the training metrics function.
# NOTE(review): unlike the training Trainer, no data_collator is passed here,
# so the Trainer's default collator is used - confirm this is intended.
test_trainer = Trainer(
            model=model, 
            args=test_args, 
            tokenizer=tokenizer,
            compute_metrics=my_compute_metrics)

In [None]:
# Run the fine-tuned model over the held-out test subset and print the
# object returned by predict().
test_results = test_trainer.predict(test_dataset)
print(test_results)