# Abstractive summaries - Train T5 (t5-base) on the TWEETSUMM dataset

In [None]:
from huggingface_hub import login
import pandas as pd
import numpy as np
import os, time, datetime, shutil

from datasets import Dataset, DatasetDict

from transformers import DataCollatorForSeq2Seq, AutoTokenizer, set_seed
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

import wandb

In [None]:
# Record the installed package versions of this environment in requirements_t5.txt.
!pip freeze > requirements_t5.txt

In [None]:
def get_current_time(underscore=False):
    """Return the current local time as a compact day-month / hour-minute stamp.

    The date part (``ddmm``) and the time part (``HHMM``) are joined by ``_``
    when ``underscore`` is truthy and by ``-`` otherwise.
    """
    if underscore:
        pattern = "%d%m_%H%M"
    else:
        pattern = "%d%m-%H%M"
    now = datetime.datetime.now()
    return now.strftime(pattern)

In [None]:
# Name this run after the current time and set up the working directories
# (models and results are created if missing; the data directory is only named).
_cwd = os.getcwd()
run_name = f"t5-abs-{get_current_time()}"
models_dir = os.path.join(_cwd, 'models')
results_dir = os.path.join(_cwd, 'results', 't5')
ds_dir = os.path.join(_cwd, 'data')
for _out_dir in (models_dir, results_dir):
    os.makedirs(_out_dir, exist_ok=True)
print(run_name)

In [None]:
# Resolve the Hugging Face token (and, on Kaggle, the W&B API key) for the
# runtime the notebook is executing in.
# Default: read the token from the environment, falling back to an empty string.
HF_TOKEN = os.environ.get('HF_TOKEN', "")

if 'google.colab' in str(get_ipython()):
    # Colab: mount Google Drive and read the token from the Colab user data.
    print("Running on Colab")
    from google.colab import drive, userdata
    drive.mount('/content/drive')
    HF_TOKEN = userdata.get('HF_TOKEN')
elif os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None:
    # Kaggle: the dataset lives in a fixed input directory and the secrets
    # come from the Kaggle secrets client.
    ds_dir = '/kaggle/input/bertdata2207/'
    from kaggle_secrets import UserSecretsClient
    print("Running on Kaggle")
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY
    # makedirs creates the intermediate "results" directory as well.
    os.makedirs(os.path.join(os.getcwd(), 'results', 't5'), exist_ok=True)


In [None]:
# Seed the random number generators through transformers' set_seed helper (fixed seed 17).
set_seed(17)

In [None]:
# Configure Weights & Biases through environment variables, then start the run.
# The project name embeds run_name, and run_name is also used as the run id.
os.environ["WANDB_PROJECT"] = f"aiml-thesis-train-{run_name}"
# NOTE(review): presumably turns off gradient/parameter watching in the
# transformers W&B integration - confirm against its documentation.
os.environ["WANDB_WATCH"] = "false"
wandb.init(settings=wandb.Settings(start_method="thread"), id=run_name)

In [None]:
# Authenticate against the Hugging Face Hub with the token held in HF_TOKEN.
login(token=HF_TOKEN)

## Load data

In [None]:
# Show which directory the dataset CSV files will be read from.
print(ds_dir)

In [None]:
# Hub identifier of the pretrained checkpoint used for both the tokenizer and the model.
checkpoint_t5 = "google-t5/t5-base"

In [None]:
def csv_to_pandas(file_name, ds_dir, drop_conv_id=True):
    """Load a header-less, three-column CSV split into a DataFrame.

    Args:
        file_name: Name of the CSV file inside ``ds_dir``.
        ds_dir: Directory that contains the file.
        drop_conv_id: When true, the ``conv_id`` column is removed.

    Returns:
        A DataFrame with the string columns ``dialogue`` and ``summary``
        (preceded by ``conv_id`` unless it was dropped) and a fresh
        0-based index.
    """
    column_names = ['conv_id', 'dialogue', 'summary']
    csv_path = os.path.join(ds_dir, file_name)
    frame = pd.read_csv(
        csv_path,
        names=column_names,
        encoding='utf-8',
        dtype={column: 'string' for column in column_names},
    )
    frame = frame.convert_dtypes()
    if drop_conv_id:
        frame = frame.drop(columns=['conv_id'])
    return frame.reset_index(drop=True)

In [None]:
# Load the three dataset splits. conv_id is dropped for train/validation and
# kept for the test split.
train_df_temp = csv_to_pandas("dials_abs_2607_1312_train_spc.csv", ds_dir)
val_df_temp = csv_to_pandas("dials_abs_2607_1312_valid_spc.csv", ds_dir)
test_df = csv_to_pandas("dials_abs_2607_1312_test_spc.csv", ds_dir, drop_conv_id=False)

# Sanity check of the training data: column dtypes, first rows and row count.
print(train_df_temp.dtypes)
print(train_df_temp.head(), len(train_df_temp))

In [None]:
# Wrap the three pandas splits as datasets and bundle them under their split names.
tweetsumm_abs = DatasetDict(
    train=Dataset.from_pandas(train_df_temp),
    validation=Dataset.from_pandas(val_df_temp),
    test=Dataset.from_pandas(test_df),
)

In [None]:
# Load the tokenizer that belongs to the chosen checkpoint and print it for inspection.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_t5)
print(tokenizer)

In [None]:
# Source: https://huggingface.co/docs/transformers/en/tasks/summarization

def preprocess_function(examples):
    """Tokenize a batch of dialogues and their reference summaries.

    Every dialogue gets the ``"summarize: "`` prefix and is truncated to 512
    tokens; summaries are truncated to 80 tokens and their token ids are
    stored under ``labels``.

    Args:
        examples: Batch with the columns ``dialogue`` and ``summary``.

    Returns:
        The tokenizer output for the prefixed dialogues, extended with a
        ``labels`` entry.
    """
    prompts = ["summarize: " + str(dial) for dial in examples["dialogue"]]
    # Same length limits as the tweetsumm paper.
    encoded = tokenizer(prompts, max_length=512, truncation=True)
    targets = tokenizer(text_target=examples["summary"], max_length=80, truncation=True)
    encoded["labels"] = targets["input_ids"]
    return encoded

In [None]:
# Tokenize every split in batches; the raw 'dialogue' and 'summary' text
# columns are removed from the result.
tokenized_tweetsumm_abs = tweetsumm_abs.map(preprocess_function, batched=True, remove_columns=['dialogue','summary'])
# Print one tokenized training example as a sanity check.
print(tokenized_tweetsumm_abs["train"][1])

## Setup Training Evaluation

In [None]:
# Upgrade nltk; its 'punkt_tab' resource is downloaded further down in this notebook.
!pip install -U nltk

In [None]:
# Install the packages used for evaluation (ROUGE, BERTScore, METEOR).
# NOTE(review): confirm that the 'meteor' PyPI package is what the METEOR
# metric loaded through `evaluate` actually needs.
!pip install evaluate pyrouge rouge_score bert_score meteor

In [None]:
# Imported here rather than at the top because the packages are installed in the cells above.
import evaluate, nltk, csv
# Metric objects used by compute_metrics_abs.
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

# Download the nltk 'punkt_tab' resource.
nltk.download('punkt_tab')

In [None]:
def compute_metrics_abs(eval_pred):
    """Compute ROUGE, BERTScore, METEOR and generation length for one evaluation.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays.

    Returns:
        Dict with ``rouge/*`` and ``bertscore/bertscore-*`` entries and
        ``meteor`` (all rounded to 4 decimals) plus ``gen_len``, the mean
        number of non-padding tokens per prediction.
    """
    pad_id = tokenizer.pad_token_id
    raw_preds, raw_labels = eval_pred

    # -100 entries cannot be decoded; replace them with the pad id first.
    # For predictions this addresses an overflow, see
    # https://github.com/huggingface/transformers/issues/22634
    pred_ids = np.where(raw_preds != -100, raw_preds, pad_id)
    label_ids = np.where(raw_labels != -100, raw_labels, pad_id)
    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    # The hashcode entry is not a score, so it is removed before averaging.
    del bert_scores['hashcode']

    result = {}
    for metric_name, value in rouge_scores.items():
        result[f"rouge/{metric_name}"] = round(value, 4)
    for metric_name, values in bert_scores.items():
        result[f"bertscore/bertscore-{metric_name}"] = round(np.mean(values), 4)
    meteor_scores = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    result['meteor'] = round(meteor_scores['meteor'], 4)

    token_counts = [np.count_nonzero(row != pad_id) for row in pred_ids]
    result["gen_len"] = np.mean(token_counts)
    return result


## Train and Evaluate

In [None]:
# Load the pretrained sequence-to-sequence model for the chosen checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_t5)

In [None]:
# Collator that assembles batches for the seq2seq model, built from the tokenizer and the model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
# Baseline hyper-parameters; they form the first entry of the experiment list.
BASE_PARAMS = {'lr': 1e-4, 'batch_size': 10, 'epochs': 20}
EXPERIMENT_PARAMS = [BASE_PARAMS]

In [None]:
# Hyper-parameter grid. Every combination except the baseline (which is
# already the first entry of EXPERIMENT_PARAMS) is appended to the list.
LEARN_RATES = (1e-3, 1e-4, 1e-5)
BATCH_SIZES = (2, 5, 10)
EPOCHS = (20,)

for lr in LEARN_RATES:
    for batch_size in BATCH_SIZES:
        for epoch in EPOCHS:
            experiment = {'lr': lr, 'batch_size': batch_size, 'epochs': epoch}
            # Dict equality compares all three settings at once.
            if experiment != BASE_PARAMS:
                EXPERIMENT_PARAMS.append(experiment)

In [None]:
def run_post_training(split, test_details, test_df_temp: pd.DataFrame, tokenizer, experiment, run_name_model, epoch, results_dir):
    """Decode the predictions of one split and store predictions and metrics.

    Two CSV files are written to ``results_dir``: the input frame extended by
    a ``response`` column holding the decoded predictions (no header row), and
    a one-row metrics file combining ``experiment`` with
    ``test_details.metrics``. The metrics are also logged to W&B under
    ``run_name_model``.

    Args:
        split: Split name used as the file-name prefix (e.g. ``"test"``).
        test_details: Prediction output exposing ``predictions`` and ``metrics``.
        test_df_temp: Frame with one row per predicted example; not modified.
        tokenizer: Tokenizer used to decode the predicted token ids.
        experiment: Hyper-parameter dict of the current experiment.
        run_name_model: Run name; dashes become underscores in file names.
        epoch: Tag placed after ``_s`` in the file names.
        results_dir: Directory the CSV files are written to.
    """
    # Predictions can contain -100, which cannot be decoded; replace it with
    # the pad id first (see https://github.com/huggingface/transformers/issues/22634).
    predictions = np.where(test_details.predictions != -100, test_details.predictions, tokenizer.pad_token_id)
    preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Work on a copy so the caller's frame does not gain a 'response' column.
    preds_df = test_df_temp.copy()
    preds_df['response'] = preds
    preds_df = preds_df.convert_dtypes()

    exp_res = {**experiment, **(test_details.metrics)}
    test_metrics_df = pd.DataFrame([exp_res]).convert_dtypes()
    wandb.log({run_name_model: test_details.metrics})

    preds_name = f"{split}_preds_{run_name_model.replace('-','_')}_s{epoch}_t5.csv"
    metrics_name = f"{split}_metrics_{run_name_model.replace('-','_')}_s{epoch}_t5.csv"
    preds_df.to_csv(os.path.join(results_dir, preds_name), index=False, header=False, encoding='utf-8', quoting=csv.QUOTE_ALL)
    test_metrics_df.to_csv(os.path.join(results_dir, metrics_name), index=False, header=True, encoding='utf-8', quoting=csv.QUOTE_ALL)

In [None]:
class ExtraCallback(TrainerCallback):
    """Trainer callback that writes the training log history to a CSV file."""

    def on_train_end(self, args, state, control, **kwargs):
        """Collapse ``state.log_history`` to one row per epoch and save it.

        The file is written to ``results_dir`` as ``log_<run name>.csv``,
        with the dashes of ``args.run_name`` replaced by underscores.
        """
        super().on_train_end(args, state, control, **kwargs)
        history = pd.DataFrame(state.log_history).convert_dtypes()
        # Several log entries can share one epoch value (e.g. the training and
        # the evaluation entry). first() takes the first non-null value of each
        # column, which merges those entries; sum() would add up columns that
        # every entry carries (such as the step counter) and turn missing
        # metrics into 0.
        per_epoch = history.groupby(['epoch'], as_index=False).first()
        log_name = "log_" + args.run_name.replace('-','_') + ".csv"
        per_epoch.to_csv(os.path.join(results_dir, log_name), header=True, index=False)

In [None]:
# Run every hyper-parameter configuration in turn: fine-tune, evaluate on the
# test split, store the results, push the model and clean up the checkpoints.
for count, exp in enumerate(EXPERIMENT_PARAMS):
    run_name_model = f"{run_name}-lr-{exp['lr']}-bs-{exp['batch_size']}-maxep-{exp['epochs']}"
    print("=== Starting experiment", count, f"on {get_current_time()}:", run_name_model, "training")
    wandb.run.name = run_name_model
    wandb.run.save()

    # Reload the pretrained weights for every configuration. Reusing a single
    # model object would let each experiment continue from the weights the
    # previous experiment fine-tuned, making the configurations incomparable.
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_t5)
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=os.path.join(models_dir, run_name_model),
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_only_model=True,
        learning_rate=exp['lr'],
        per_device_train_batch_size=exp['batch_size'],
        per_device_eval_batch_size=exp['batch_size'],
        weight_decay=0.0,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        gradient_accumulation_steps=2,
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="eval_rouge/rougeL",
        greater_is_better=True,
        num_train_epochs=exp['epochs'],
        predict_with_generate=True,
        # NOTE(review): T5 checkpoints are reported to be numerically unstable
        # under fp16 - confirm that losses stay finite or consider bf16.
        fp16=True,
        generation_max_length=80,
        push_to_hub=False,
        report_to="wandb",
        run_name=run_name_model,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_tweetsumm_abs["train"],
        eval_dataset=tokenized_tweetsumm_abs["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics_abs,
    )
    trainer.add_callback(ExtraCallback)

    training_start = time.time()
    trainer.train()
    training_end = time.time()
    print(f"Finished experiment {count}: {run_name_model} - time it took for training:", str(datetime.timedelta(seconds=(training_end-training_start))))

    test_details = trainer.predict(tokenized_tweetsumm_abs['test'], metric_key_prefix='test')

    # The text after the last '-' of the best checkpoint path (presumably the
    # global step of that checkpoint) tags the result files. Fall back to the
    # final global step when no best checkpoint was recorded.
    best_checkpoint = trainer.state.best_model_checkpoint
    if best_checkpoint:
        best_step = best_checkpoint.split('-')[-1]
    else:
        best_step = str(trainer.state.global_step)
    run_post_training('test', test_details, test_df, tokenizer, exp, run_name_model, best_step, results_dir)

    trainer.push_to_hub()
    # Remove every local checkpoint of this experiment before the next one starts.
    shutil.rmtree(models_dir)
    os.makedirs(models_dir)

In [None]:
# Using wandb documentation: https://docs.wandb.ai/guides/artifacts
def log_csv_wandb(results_path, model_name):
    """Upload every file found under ``results_path`` as a single W&B artifact.

    Args:
        results_path: Directory that is walked recursively for files.
        model_name: Name of the artifact (type ``"predictions"``).

    Each file is stored in the artifact under its bare file name.
    """
    artifact = wandb.Artifact(name=model_name, type="predictions")
    for folder, _subdirs, filenames in os.walk(results_path):
        for filename in filenames:
            full_path = os.path.join(folder, filename)
            artifact.add_file(local_path=full_path, name=filename)
    wandb.log_artifact(artifact)

In [None]:
# Upload all files under results_dir as one W&B artifact named after the run.
log_csv_wandb(results_dir, run_name)

In [None]:
# Report completion and close the W&B run.
print("Finished all training and evaluation for", run_name)
wandb.finish()

In [None]:
# Final status message, printed after the W&B run has been closed.
print("Results uploaded")