# Abstractive summaries - Train DistilBART on TWEETSUMM dataset

In [1]:
import json, re
from huggingface_hub import notebook_login
import pandas as pd
import numpy as np
import os, time, datetime

try:
    from datasets import load_dataset
except:
    !pip install datasets
    from datasets import load_dataset

try:
    import accelerate
except:
    !pip install -U 'accelerate==0.27.2'
    import accelerate


import transformers
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, pipeline, set_seed
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import GenerationConfig

try:
    import wandb
except:
    !pip install wandb

print(transformers.__version__, accelerate.__version__)


2024-07-01 17:40:11.221979: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-01 17:40:11.222091: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-01 17:40:11.364817: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


4.41.2 0.30.1


In [2]:
# try:
#   import transformers
# except:
#   !pip install -U transformers[torch]
#   import transformers

In [3]:
# for x in dir(transformers):
#     if "torch" in x:
#         print(x)
        
# print(transformers.is_tf_available())

In [4]:
ds_dir = ""
try:
    HF_TOKEN =  os.environ['HF_TOKEN']
except:
    HF_TOKEN = ""

if 'google.colab' in str(get_ipython()):
  print("In Colab")
  from google.colab import drive, userdata
  drive.mount('/content/drive')
  HF_TOKEN = userdata.get('HF_TOKEN')
elif os.environ['KAGGLE_KERNEL_RUN_TYPE']:
  from kaggle_secrets import UserSecretsClient
  print("In Kaggle")
  ds_dir = "/kaggle/input/tweet-data-2106-1512/"
  user_secrets = UserSecretsClient()
  HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
  WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
  os.environ['WANDB_API_KEY'] = WANDB_API_KEY


In Kaggle


In [5]:
set_seed(17)

os.environ["WANDB_PROJECT"] = "aiml-thesis-train"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

wandb.init(settings=wandb.Settings(start_method="thread"))

[34m[1mwandb[0m: Currently logged in as: [33mdawidk5[0m ([33mdawidk5ul[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
from huggingface_hub import login
login(token=HF_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Load data

In [7]:
train_df_temp = pd.read_feather(ds_dir + "data/train_dial_abs_noex_noco_2006.feather")
train_df_temp.drop(columns=['index', 'company'], inplace=True)

In [8]:
# lengths_train_summaries = [len(row) for row in tokenized_tweetsumm_abs['train']['input_ids']]
# print("Abstractive summaries training lengths[mean,max,min]:", np.mean(lengths_train_summaries), np.max(lengths_train_summaries), np.min(lengths_train_summaries))

In [9]:
val_df_temp = pd.read_feather(ds_dir + "data/val_dial_abs_noex_noco_2006.feather")
val_df_temp.drop(columns=['index', 'company'], inplace=True)

In [10]:
# Only for testing statistics display
# train_df_temp[0:100]

Unnamed: 0,dialogue,summary
0,<USER> So neither my iPhone nor my Apple Watch...,Customer enquired about his Iphone and Apple w...
1,<USER> @115850 hi team! i m planning to get Ap...,Customer is eager to know about the replacemen...
2,<USER> @AskAmex Where do I write to address a ...,Signed up for an AmexCard with Delta but it di...
3,"<USER> @AmazonHelp @115821 Wow, expected 4 pac...",The customer have a problem. The agent is very...
4,<USER> @GWRHelp I'd rather you spent some time...,Customer cannot purchase a train ticket on the...
...,...,...
95,<USER> @SW_Help the 2120 Salisbury train from ...,Customer is asking that salisbury train from w...
96,<USER> . @Tesco used to do a box of chocolate ...,Customer is complaining that they can't find a...
97,<USER> @Delta no response for me regarding my ...,Customer is complaining that his luggage has b...
98,<USER> @117153 I've downloaded the latest app ...,Customer has some issue with logging into the ...


In [11]:
from datasets import Dataset, DatasetDict

tweetsum_train_val_abs = DatasetDict(
 {
 'train': Dataset.from_pandas(train_df_temp),
 'validation': Dataset.from_pandas(val_df_temp)
 }
)

In [12]:
tweetsum_train_val_abs['train'][10]

{'dialogue': "<USER> Bought these biscuits a couple of weeks ago, only just opened and looked at the best before date... @49975 <SYSTEM> @393926 We can't see a picture of the best before date or biscuits attached to your tweets here. Please can you tweet or DM it to us again? Thanks. <URL> <USER> @marksandspencer  <URL> @marksandspencer  <URL> @marksandspencer  <URL> <SYSTEM> @393926 We'd certainly like to take a closer look into this. Please DM us a picture of your full receipt. <URL> <USER> @marksandspencer Don’t have a receipt as they were bought for my grandparents 3 weeks ago and have only just realised the date on them @marksandspencer The biscuits were purchased at the Marks and Spencer’s Store at the Ricoh Arena, Coventry <SYSTEM> @393926 Did yo use a sparks card on your transaction, Cian? <USER> @marksandspencer Yes <SYSTEM> @393926 Hi Cian. I'm really sorry to see this, especially when it was such a lovely gesture too! No worries though - we got your back ;) 1/3 @393926 Is th

In [13]:
# Source: https://huggingface.co/docs/transformers/en/tasks/summarization

def preprocess_function(examples):
  prefix = "summarize: "
  inputs = [prefix + dial for dial in examples["dialogue"]]
  model_inputs = tokenizer(inputs, max_length=512, truncation=True) # same params as tweetsumm paper
  labels = tokenizer(text_target=examples["summary"], max_length=80, truncation=True)
  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [14]:
checkpoint_bart = "sshleifer/distilbart-xsum-12-6"

In [15]:
# t5_tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint_bart)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.59k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

In [16]:
tokenizer = bart_tokenizer
tokenized_tweetsumm_abs = tweetsum_train_val_abs.map(preprocess_function, batched=True)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [17]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint_bart)

In [18]:
print(json.dumps(tokenized_tweetsumm_abs['train'][5], indent=2))

{
  "dialogue": "<USER> @115802 @AirAsiaSupport doesn\u2019t seem like a customer\u2019s time/money is of value to you! just because you are a low-cost carrier!? Nvr again! ! <SYSTEM> @366867 Sorry for the delay, Saim. Flight Change is subject to Change Fee (per person/per flight) at <URL> plus &gt;&gt; @366867 &gt;&gt;Fare Difference, so you only have to pay for those. If you think it's cheaper to make a new booking, you may consider so.-Floi <USER> @AirAsiaSupport bt how can the rescheduling charges be more than the ticket charges. <SYSTEM> @366867 Really sorry as flight change is subject to change fee +  fare difference accordingly.Thanks - Ed <USER> @AirAsiaSupport i will go ahead with fresh booking bt will u move my baggage n food to the new pnr? <SYSTEM> @366867 Hi Saim , We are sorry as the add ons cannot be transferred to another booking. Thanks - Khairul <USER> @AirAsiaSupport @115802 @121276 never seen somebody sooo least bothered about customer... feel cheated.. and feel rob

## Setup Training Evaluation

In [19]:
!pip install evaluate nltk rouge_score bert_score

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=c958f5e56b2af46eb1a0a132877f5b421b946c5795341b64730022c7b3a970bf
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532

In [20]:
!pip install -U nltk

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl.metadata (2.8 kB)
Downloading nltk-3.8.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: nltk
  Attempting uninstall: nltk
    Found existing installation: nltk 3.2.4
    Uninstalling nltk-3.2.4:
      Successfully uninstalled nltk-3.2.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
preprocessing 0.1.13 requires nltk==3.2.4, but you have nltk 3.8.1 which is incompatible.[0m[31m
[0mSuccessfully installed nltk-3.8.1


In [21]:
import evaluate

rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.93k [00:00<?, ?B/s]

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...


Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

In [22]:
# import numpy as np


# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#     # result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#     result = {
#       'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True),
#       'bertscore': bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"),
#       'meteor': meteor.compute(predictions=decoded_preds, references=decoded_labels),
#     }
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
#     result["gen_len"] = np.mean(prediction_lens)
#     print(json.dumps(result, indent=2))
#     return {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

In [23]:
# arrr = [0,1,2,3,4,5,6,7]
# valsss = ['a','b','c','d','e','f','g','h']

# kwkwk = {f"id-{x}": vall for x, vall in enumerate(valsss)}
# origindict = {'alpha':5, **kwkwk}
# print(origindict)

In [46]:
def compute_metrics_abs(eval_pred):
  predictions, labels = eval_pred
  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  
  # bertscores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
  # np.mean(bertscores)
  # 'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True,use_aggregator=False),
  # wandb.log({f"losses/loss-{ii}": loss for ii, loss in enumerate(losses)})
  # rouge_scores = {f"rouge/rougerouge-id-{i}": score for i, score in enumerate(rouge.compute(predictions=decoded_preds,
  #                                                                                references=decoded_labels,
  #                                                                                use_stemmer=True,
  #                                                                                use_aggregator=True))}
  # bert_scores = {f"bertscore/bert-id-{i}": score for i, score in enumerate(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"))},
  rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
  bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
  bert_scores.pop('hashcode')
  result = {
      **{f"rouge/{k}": round(v, 4) for k,v in rouge_scores.items()},
      **{f"bertscore/bertscore-{k}": round(np.mean(v), 4) for k,v in bert_scores.items()},
      'meteor': round(meteor.compute(predictions=decoded_preds, references=decoded_labels)['meteor'], 4),
  }
  #   for k,v in result.items():
  #     print(k, type(v), v)
  prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
  result["gen_len"] = np.mean(prediction_lens)
  # print(json.dumps(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"), indent=2))
  return result # {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

#       'rouge1': round(rouge_scores['rouge1'], 4),
#       'rouge2': round(rouge_scores['rouge2'], 4),
#       'rougeL': round(rouge_scores['rougeL'], 4),
#       'rougeLsum': round(rouge_scores['rougeLsum'], 4),
#       'bertscore/bertscore-precision': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['precision']),
#       'bertscore/bertscore-recall': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['recall']),
#       'bertscore/bertscore-f1': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['f1']),

## Train

In [47]:
# print(json.dumps(), indent=2)
# blah = bertscore.compute(predictions=['a', 'blue', 'car'], references=['a', 'black', 'car'], lang="en")
# for b,c in blah.items():
#     print(c)
#     print(np.round(sum(c)/len(c), 4))

In [43]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_bart)

In [44]:
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [48]:
current_time = datetime.datetime.now().strftime("%d%m-%H%M")
print(current_time)
run_name_model = f"distilbart-abs-{current_time}-tesststats"
wandb.run.name = run_name_model
wandb.run.save()

gen_config = GenerationConfig(max_source_length=512,min_new_tokens=80,bos_token_id=0)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"trained-distilbart-abs-{current_time[0:4]}",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    generation_config=gen_config,
    push_to_hub=False,
    report_to="wandb",
    run_name=run_name_model
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_tweetsumm_abs["train"],
    eval_dataset=tokenized_tweetsumm_abs["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_abs,
)

training_start = time.time()
trainer.train()
training_end = time.time()
print("Time it took for training:", str(datetime.timedelta(seconds=(training_end-training_start))))

0107-1806


Epoch,Training Loss,Validation Loss,Rouge/rouge1,Rouge/rouge2,Rouge/rougel,Rouge/rougelsum,Bertscore/bertscore-precision,Bertscore/bertscore-recall,Bertscore/bertscore-f1,Meteor,Gen Len
1,2.7343,2.713014,0.1566,0.0574,0.1347,0.1323,0.4053,0.3885,0.3967,0.0979,20.0


Non-default generation parameters: {'max_length': 62, 'min_length': 11, 'early_stopping': True, 'num_beams': 6, 'length_penalty': 0.5, 'no_repeat_ngram_size': 3, 'forced_eos_token_id': 2}
[34m[1mwandb[0m: Adding directory to artifact (./trained-distilbart-abs-0107/checkpoint-25)... Done. 25.5s
Non-default generation parameters: {'max_length': 62, 'min_length': 11, 'early_stopping': True, 'num_beams': 6, 'length_penalty': 0.5, 'no_repeat_ngram_size': 3, 'forced_eos_token_id': 2}


Time it took for training: 0:01:10.464312


In [None]:
# trainer.push_to_hub()

In [None]:
tokenizer.decode(tokenized_tweetsumm_abs['train'][5]['input_ids'], skip_special_tokens=False)

In [None]:
# eos_start_token_bart = tokenized_tweetsumm_abs['train'][5]['input_ids'][0]
# print("End of sentence token for starting generating summaries with BART: ", eos_start_token_bart)