# Abstractive summaries - Train DistilBART on TWEETSUMM dataset

In [41]:
import json, re
from huggingface_hub import notebook_login
import pandas as pd
import numpy as np
import os, time, datetime

try:
    from datasets import load_dataset
except:
    !pip install datasets
    from datasets import load_dataset

try:
    import accelerate
except:
    !pip install -U 'accelerate==0.27.2'
    import accelerate


import transformers
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, pipeline, set_seed
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

try:
    import wandb
except:
    !pip install wandb

print(transformers.__version__, accelerate.__version__)


4.41.2 0.30.1


In [42]:
# try:
#   import transformers
# except:
#   !pip install -U transformers[torch]
#   import transformers

In [43]:
# for x in dir(transformers):
#     if "torch" in x:
#         print(x)
        
# print(transformers.is_tf_available())

In [44]:
ds_dir = ""
try:
    HF_TOKEN =  os.environ['HF_TOKEN']
except:
    HF_TOKEN = ""

if 'google.colab' in str(get_ipython()):
  print("In Colab")
  from google.colab import drive, userdata
  drive.mount('/content/drive')
  HF_TOKEN = userdata.get('HF_TOKEN')
elif os.environ['KAGGLE_KERNEL_RUN_TYPE']:
  from kaggle_secrets import UserSecretsClient
  print("In Kaggle")
  ds_dir = "/kaggle/input/tweet-data-2106-1512/"
  user_secrets = UserSecretsClient()
  HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
  WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
  os.environ['WANDB_API_KEY'] = WANDB_API_KEY


In Kaggle


In [45]:
set_seed(17)

os.environ["WANDB_PROJECT"] = "aiml-thesis-train"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

wandb.init(settings=wandb.Settings(start_method="thread"))

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [46]:
from huggingface_hub import login
login(token=HF_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Load data

In [47]:
train_df_temp = pd.read_feather(ds_dir + "data/train_dial_abs_noex_noco_2006.feather")
train_df_temp.drop(columns=['index', 'company'], inplace=True)

In [48]:
lengths_train_summaries = [len(row) for row in tokenized_tweetsumm_abs['train']['input_ids']]
print("Abstractive summaries training lengths[mean,max,min]:", np.mean(lengths_train_summaries), np.max(lengths_train_summaries), np.min(lengths_train_summaries))

Abstractive summaries training lengths[mean,max,min]: 303.1355353075171 512 119


In [49]:
val_df_temp = pd.read_feather(ds_dir + "data/val_dial_abs_noex_noco_2006.feather")
val_df_temp.drop(columns=['index', 'company'], inplace=True)

In [50]:
from datasets import Dataset, DatasetDict

tweetsum_train_val_abs = DatasetDict(
 {
 'train': Dataset.from_pandas(train_df_temp),
 'validation': Dataset.from_pandas(val_df_temp)
 }
)

In [51]:
tweetsum_train_val_abs['train'][10]

{'dialogue': "<USER> Bought these biscuits a couple of weeks ago, only just opened and looked at the best before date... @49975 <SYSTEM> @393926 We can't see a picture of the best before date or biscuits attached to your tweets here. Please can you tweet or DM it to us again? Thanks. <URL> <USER> @marksandspencer  <URL> @marksandspencer  <URL> @marksandspencer  <URL> <SYSTEM> @393926 We'd certainly like to take a closer look into this. Please DM us a picture of your full receipt. <URL> <USER> @marksandspencer Don’t have a receipt as they were bought for my grandparents 3 weeks ago and have only just realised the date on them @marksandspencer The biscuits were purchased at the Marks and Spencer’s Store at the Ricoh Arena, Coventry <SYSTEM> @393926 Did yo use a sparks card on your transaction, Cian? <USER> @marksandspencer Yes <SYSTEM> @393926 Hi Cian. I'm really sorry to see this, especially when it was such a lovely gesture too! No worries though - we got your back ;) 1/3 @393926 Is th

In [52]:
# Source: https://huggingface.co/docs/transformers/en/tasks/summarization

def preprocess_function(examples):
  prefix = "summarize: "
  inputs = [prefix + dial for dial in examples["dialogue"]]
  model_inputs = tokenizer(inputs, max_length=512, truncation=True) # same params as tweetsumm paper
  labels = tokenizer(text_target=examples["summary"], max_length=80, truncation=True)
  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [53]:
checkpoint_bart = "sshleifer/distilbart-xsum-12-6"

In [54]:
t5_tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint_bart)

In [55]:
tokenizer = bart_tokenizer
tokenized_tweetsumm_abs = tweetsum_train_val_abs.map(preprocess_function, batched=True)

Map:   0%|          | 0/878 [00:00<?, ? examples/s]

Map:   0%|          | 0/110 [00:00<?, ? examples/s]

In [56]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint_bart)

In [57]:
print(json.dumps(tokenized_tweetsumm_abs['train'][5], indent=2))

{
  "dialogue": "<USER> @115802 @AirAsiaSupport doesn\u2019t seem like a customer\u2019s time/money is of value to you! just because you are a low-cost carrier!? Nvr again! ! <SYSTEM> @366867 Sorry for the delay, Saim. Flight Change is subject to Change Fee (per person/per flight) at <URL> plus &gt;&gt; @366867 &gt;&gt;Fare Difference, so you only have to pay for those. If you think it's cheaper to make a new booking, you may consider so.-Floi <USER> @AirAsiaSupport bt how can the rescheduling charges be more than the ticket charges. <SYSTEM> @366867 Really sorry as flight change is subject to change fee +  fare difference accordingly.Thanks - Ed <USER> @AirAsiaSupport i will go ahead with fresh booking bt will u move my baggage n food to the new pnr? <SYSTEM> @366867 Hi Saim , We are sorry as the add ons cannot be transferred to another booking. Thanks - Khairul <USER> @AirAsiaSupport @115802 @121276 never seen somebody sooo least bothered about customer... feel cheated.. and feel rob

End of sentence token for starting generating summaries with BART:  0


## Setup Training Evaluation

In [58]:
!pip install evaluate nltk rouge_score bert_score

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [59]:
!pip install -U nltk

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [60]:
import evaluate

rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [61]:
# import numpy as np


# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#     # result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#     result = {
#       'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True),
#       'bertscore': bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"),
#       'meteor': meteor.compute(predictions=decoded_preds, references=decoded_labels),
#     }
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
#     result["gen_len"] = np.mean(prediction_lens)
#     print(json.dumps(result, indent=2))
#     return {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

In [62]:
def compute_metrics_abs(eval_pred):
  predictions, labels = eval_pred
  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  
#   bertscores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
#     np.mean(bertscores)
  result = {
    'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True),
    'bertscore': bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"),
    'meteor': meteor.compute(predictions=decoded_preds, references=decoded_labels),
  }
  for k,v in result.items():
    print(k, type(v), v)
  prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
  result["gen_len"] = np.mean(prediction_lens)
  print(json.dumps(result, indent=2))
  return result # {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

## Train

In [63]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_bart)

In [64]:
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [71]:
from transformers import GenerationConfig

In [80]:
training_start = time.time()
gen_config = GenerationConfig(max_source_length=512,min_new_tokens=80,bos_token_id=0)
training_args = Seq2SeqTrainingArguments(
    output_dir="trained-distilbart-abs-2106",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    generation_config=gen_config,
    push_to_hub=False,
    report_to="wandb",
    run_name="distilbart-abs-2106_2341_tesststats"
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_tweetsumm_abs["train"],
    eval_dataset=tokenized_tweetsumm_abs["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_abs,
)

trainer.train()
training_end = time.time()
print("Time it took for training:", str(datetime.timedelta(seconds=(training_end-training_start))))



Epoch,Training Loss,Validation Loss,Rouge,Bertscore,Meteor,Gen Len
1,2.1191,2.082642,"{'rouge1': 0.21574205572428873, 'rouge2': 0.08510464895182768, 'rougeL': 0.18588121236892674, 'rougeLsum': 0.18904663379694484}","{'precision': [0.8948967456817627, 0.9196454286575317, 0.0, 0.0, 0.9016120433807373, 0.0, 0.8733774423599243, 0.0, 0.9184833765029907, 0.8747596740722656, 0.9072914123535156, 0.89311283826828, 0.9119383692741394, 0.938646137714386, 0.0, 0.0, 0.9016316533088684, 0.0, 0.8718732595443726, 0.0, 0.9207643866539001, 0.8882859945297241, 0.8723347187042236, 0.8945499658584595, 0.0, 0.0, 0.0, 0.8980494737625122, 0.8923925757408142, 0.0, 0.9517074823379517, 0.8911606073379517, 0.0, 0.0, 0.9031713008880615, 0.8742817640304565, 0.0, 0.8885214328765869, 0.0, 0.8802607655525208, 0.9295825958251953, 0.0, 0.8896499276161194, 0.8769556283950806, 0.0, 0.0, 0.0, 0.9439671039581299, 0.881846010684967, 0.0, 0.9132688045501709, 0.9014869332313538, 0.9005262851715088, 0.9194005131721497, 0.0, 0.8754273653030396, 0.0, 0.9017542600631714, 0.8894200921058655, 0.0, 0.9178104400634766, 0.8756179213523865, 0.9284126162528992, 0.0, 0.903694748878479, 0.0, 0.0, 0.0, 0.9338507652282715, 0.0, 0.0, 0.0, 0.9028917551040649, 0.0, 0.9218497276306152, 0.9058266878128052, 0.0, 0.0, 0.8886988162994385, 0.0, 0.8952760696411133, 0.0, 0.9324182868003845, 0.9428428411483765, 0.8749940395355225, 0.9303081631660461, 0.8916181325912476, 0.0, 0.0, 0.9383559226989746, 0.8649318814277649, 0.8875401616096497, 0.0, 0.8904289603233337, 0.0, 0.0, 0.9058986306190491, 0.0, 0.0, 0.0, 0.9238861203193665, 0.844862163066864, 0.0, 0.9049930572509766, 0.0, 0.0, 0.0, 0.8884814977645874, 0.0, 0.950259268283844], 'recall': [0.8527296781539917, 0.9074962139129639, 0.0, 0.0, 0.8695670962333679, 0.0, 0.8765427470207214, 0.0, 0.8609092831611633, 0.8616518378257751, 0.8861052393913269, 0.8705592155456543, 0.893672525882721, 0.8637343645095825, 0.0, 0.0, 0.8663536310195923, 0.0, 0.8676398992538452, 0.0, 0.9015766978263855, 0.8690415024757385, 0.8647665977478027, 0.8757278323173523, 0.0, 0.0, 0.0, 0.876862108707428, 0.8512600064277649, 0.0, 0.8827943801879883, 0.8636014461517334, 0.0, 0.0, 0.8927298784255981, 0.8633140325546265, 0.0, 0.8432349562644958, 0.0, 0.8488699197769165, 0.9072322845458984, 0.0, 0.8460506200790405, 0.8666802644729614, 0.0, 0.0, 0.0, 0.8947888612747192, 0.87111896276474, 0.0, 0.8794275522232056, 0.8699379563331604, 0.8704975247383118, 0.8579683303833008, 0.0, 0.8470050692558289, 0.0, 0.8723961114883423, 0.8828433752059937, 0.0, 0.8810412883758545, 0.8320557475090027, 0.8725265264511108, 0.0, 0.8714341521263123, 0.0, 0.0, 0.0, 0.8988475203514099, 0.0, 0.0, 0.0, 0.8686721920967102, 0.0, 0.8484339118003845, 0.877432107925415, 0.0, 0.0, 0.8437725901603699, 0.0, 0.8868192434310913, 0.0, 0.8919705152511597, 0.865478515625, 0.8780840635299683, 0.8753795623779297, 0.8415331244468689, 0.0, 0.0, 0.9014556407928467, 0.8395147323608398, 0.8648468852043152, 0.0, 0.8463851809501648, 0.0, 0.0, 0.8711459040641785, 0.0, 0.0, 0.0, 0.8653390407562256, 0.8389106392860413, 0.0, 0.8580697774887085, 0.0, 0.0, 0.0, 0.8635730743408203, 0.0, 0.8831084966659546], 'f1': [0.873304545879364, 0.9135304689407349, 0.0, 0.0, 0.8852996230125427, 0.0, 0.874957263469696, 0.0, 0.8887649178504944, 0.8681562542915344, 0.8965731263160706, 0.8816917538642883, 0.9027130603790283, 0.8996334671974182, 0.0, 0.0, 0.8836406469345093, 0.0, 0.8697514533996582, 0.0, 0.9110695123672485, 0.878558337688446, 0.8685341477394104, 0.8850388526916504, 0.0, 0.0, 0.0, 0.8873293399810791, 0.8713411092758179, 0.0, 0.9159566164016724, 0.8771646022796631, 0.0, 0.0, 0.8979202508926392, 0.868763267993927, 0.0, 0.8652860522270203, 0.0, 0.8642804026603699, 0.9182714819908142, 0.0, 0.867302656173706, 0.8717877268791199, 0.0, 0.0, 0.0, 0.9187203049659729, 0.8764496445655823, 0.0, 0.8960287570953369, 0.8854314684867859, 0.885257363319397, 0.8876227736473083, 0.0, 0.8609817624092102, 0.0, 0.8868322372436523, 0.886119544506073, 0.0, 0.8990501165390015, 0.8532812595367432, 0.899602472782135, 0.0, 0.8872713446617126, 0.0, 0.0, 0.0, 0.9160148501396179, 0.0, 0.0, 0.0, 0.8854514360427856, 0.0, 0.883619487285614, 0.891403317451477, 0.0, 0.0, 0.8656531572341919, 0.0, 0.891027569770813, 0.0, 0.9117460250854492, 0.9025057554244995, 0.8765363097190857, 0.9020084738731384, 0.8658519983291626, 0.0, 0.0, 0.9195356965065002, 0.852033793926239, 0.8760465979576111, 0.0, 0.8678486347198486, 0.0, 0.0, 0.8881824612617493, 0.0, 0.0, 0.0, 0.8936547040939331, 0.8418758511543274, 0.0, 0.8809069991111755, 0.0, 0.0, 0.0, 0.875850260257721, 0.0, 0.915454089641571], 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.41.2)'}",{'meteor': 0.13590265544642885},20.0




tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


rouge <class 'dict'> {'rouge1': 0.21574205572428873, 'rouge2': 0.08510464895182768, 'rougeL': 0.18588121236892674, 'rougeLsum': 0.18904663379694484}
bertscore <class 'dict'> {'precision': [0.8948967456817627, 0.9196454286575317, 0.0, 0.0, 0.9016120433807373, 0.0, 0.8733774423599243, 0.0, 0.9184833765029907, 0.8747596740722656, 0.9072914123535156, 0.89311283826828, 0.9119383692741394, 0.938646137714386, 0.0, 0.0, 0.9016316533088684, 0.0, 0.8718732595443726, 0.0, 0.9207643866539001, 0.8882859945297241, 0.8723347187042236, 0.8945499658584595, 0.0, 0.0, 0.0, 0.8980494737625122, 0.8923925757408142, 0.0, 0.9517074823379517, 0.8911606073379517, 0.0, 0.0, 0.9031713008880615, 0.8742817640304565, 0.0, 0.8885214328765869, 0.0, 0.8802607655525208, 0.9295825958251953, 0.0, 0.8896499276161194, 0.8769556283950806, 0.0, 0.0, 0.0, 0.9439671039581299, 0.881846010684967, 0.0, 0.9132688045501709, 0.9014869332313538, 0.9005262851715088, 0.9194005131721497, 0.0, 0.8754273653030396, 0.0, 0.9017542600631714, 

Non-default generation parameters: {'max_length': 62, 'min_length': 11, 'early_stopping': True, 'num_beams': 6, 'length_penalty': 0.5, 'no_repeat_ngram_size': 3, 'forced_eos_token_id': 2}
[34m[1mwandb[0m: Adding directory to artifact (./trained-distilbart-abs-2106/checkpoint-110)... Done. 38.9s
Non-default generation parameters: {'max_length': 62, 'min_length': 11, 'early_stopping': True, 'num_beams': 6, 'length_penalty': 0.5, 'no_repeat_ngram_size': 3, 'forced_eos_token_id': 2}


Time it took for training: 0:03:13.691468


In [None]:
# trainer.push_to_hub()

In [79]:
tokenizer.decode(tokenized_tweetsumm_abs['train'][5]['input_ids'], skip_special_tokens=False)

"<s>summarize: <USER> @115802 @AirAsiaSupport doesn’t seem like a customer’s time/money is of value to you! just because you are a low-cost carrier!? Nvr again!! <SYSTEM> @366867 Sorry for the delay, Saim. Flight Change is subject to Change Fee (per person/per flight) at <URL> plus &gt;&gt; @366867 &gt;&gt;Fare Difference, so you only have to pay for those. If you think it's cheaper to make a new booking, you may consider so.-Floi <USER> @AirAsiaSupport bt how can the rescheduling charges be more than the ticket charges. <SYSTEM> @366867 Really sorry as flight change is subject to change fee +  fare difference accordingly.Thanks - Ed <USER> @AirAsiaSupport i will go ahead with fresh booking bt will u move my baggage n food to the new pnr? <SYSTEM> @366867 Hi Saim, We are sorry as the add ons cannot be transferred to another booking. Thanks - Khairul <USER> @AirAsiaSupport @115802 @121276 never seen somebody sooo least bothered about customer... feel cheated.. and feel robbed... @AirAsi

In [None]:
# eos_start_token_bart = tokenized_tweetsumm_abs['train'][5]['input_ids'][0]
# print("End of sentence token for starting generating summaries with BART: ", eos_start_token_bart)