# Abstractive summaries - Train DistilBART on TWEETSUMM dataset

In [None]:
import json, re
from huggingface_hub import notebook_login
import pandas as pd
import numpy as np
import os, time, datetime

try:
    from datasets import load_dataset, Dataset, DatasetDict
except:
    !pip install datasets
    from datasets import load_dataset, Dataset, DatasetDict

try:
    import accelerate
except:
    !pip install -U 'accelerate==0.27.2'
    import accelerate


import transformers
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, pipeline, set_seed
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import GenerationConfig

try:
    import wandb
except:
    !pip install wandb

print(transformers.__version__, accelerate.__version__)


In [None]:
# Detect the runtime (Colab / Kaggle / anything else) and collect secrets.
# Outside Colab and Kaggle, ds_dir stays "" and HF_TOKEN comes from the
# process environment (empty string when it is not set).
ds_dir = ""
HF_TOKEN = os.environ.get('HF_TOKEN', "")

if 'google.colab' in str(get_ipython()):
    print("Running on Colab")
    from google.colab import drive, userdata
    drive.mount('/content/drive')
    HF_TOKEN = userdata.get('HF_TOKEN')
# .get() instead of [...]: the variable is absent when running locally, and a
# KeyError here would abort the cell instead of falling through.
elif os.environ.get('KAGGLE_KERNEL_RUN_TYPE'):
    from kaggle_secrets import UserSecretsClient
    print("Running on Kaggle")
    ds_dir = "/kaggle/input/tweet-data-2106-1512/"
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
    # wandb reads its API key from this environment variable.
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY


In [None]:
# Name the W&B project through the environment before the run is created.
os.environ["WANDB_PROJECT"] = "aiml-thesis-train"

# Fixed seed (17) passed to transformers' set_seed.
set_seed(17)

wandb.init(
    settings=wandb.Settings(start_method="thread"),
)

In [None]:
# Log in to the Hugging Face Hub with the token collected in the
# environment-detection cell (HF_TOKEN may be "" when no secret was found).
from huggingface_hub import login
login(token=HF_TOKEN)

## Load data

In [None]:
# NOTE(review): this unconditionally overwrites the ds_dir chosen by the
# environment-detection cell above, so the Kaggle path is used everywhere.
ds_dir="/kaggle/input/bertdata2207/"

In [None]:

# Column layout of the header-less split CSVs.
_SPLIT_COLUMNS = ['conv_id', 'dialogue', 'summary']


def _load_split(file_name, keep_conv_id=False):
    """Read one header-less split CSV from ``ds_dir`` into a DataFrame.

    Args:
        file_name: CSV file name, appended to the global ``ds_dir``.
        keep_conv_id: when False the ``conv_id`` column is dropped.

    Returns:
        DataFrame with pandas ``string`` columns and a fresh 0..n-1 index.
    """
    frame = pd.read_csv(
        ds_dir + file_name,
        names=_SPLIT_COLUMNS,
        encoding='utf-8',
        dtype={column: 'string' for column in _SPLIT_COLUMNS},
    )
    # convert_dtypes() returns a new frame; the original discarded the result.
    frame = frame.convert_dtypes()
    if not keep_conv_id:
        frame = frame.drop(columns=['conv_id'])
    return frame.reset_index(drop=True)


train_df_temp = _load_split("dials_abs_2607_1312_train_spc.csv")
val_df_temp = _load_split("dials_abs_2607_1312_valid_spc.csv")
# NOTE(review): the original read the *train* CSV here, so the "test" split was
# training data. The file name below assumes the test file follows the same
# naming pattern as train/valid -- confirm it exists in the dataset.
# conv_id is kept for the test split (the drop was commented out originally).
test_df_temp = _load_split("dials_abs_2607_1312_test_spc.csv", keep_conv_id=True)

print(train_df_temp.dtypes)
print(train_df_temp.head())

In [None]:
# Build the three splits from the first 30 / 10 / 10 rows of each frame.
tweetsum_train_val_abs = DatasetDict({
    split_name: Dataset.from_pandas(frame)
    for split_name, frame in (
        ('train', train_df_temp.iloc[:30]),
        ('validation', val_df_temp.iloc[:10]),
        ('test', test_df_temp.iloc[:10]),
    )
})

In [None]:
# Source: https://huggingface.co/docs/transformers/en/tasks/summarization

def preprocess_function(examples):
    """Tokenize a batch of dialogues and their reference summaries.

    Each dialogue is prefixed with "summarize: " and truncated to 512 tokens;
    each summary is tokenized as the target text, truncated to 80 tokens, and
    its input ids are stored under "labels". Uses the global ``tokenizer``.

    Args:
        examples: batch mapping with "dialogue" and "summary" sequences.

    Returns:
        The tokenizer output for the dialogues with an added "labels" entry.
    """
    prompts = []
    for dial in examples["dialogue"]:
        prompts.append("summarize: " + dial)

    # Length limits: same params as tweetsumm paper (per the original author).
    encoded = tokenizer(prompts, max_length=512, truncation=True)
    target_encoding = tokenizer(
        text_target=examples["summary"],
        max_length=80,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [None]:
# Hub id of the pretrained checkpoint; the tokenizer, config and model cells
# below all load from this name.
checkpoint_bart = "sshleifer/distilbart-xsum-12-6"

In [None]:
# `tokenizer` is the name read by preprocess_function and the metric
# functions; `bart_tokenizer` refers to the same object.
tokenizer = bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint_bart)

In [None]:
# Scratch code for inspecting tokenizer output (kept disabled):
# encc = bart_tokenizer.encode_plus("qwerty Q.\nQwerty Q. \n Qwerty x.") # train_df_temp.iloc[5,0][:320])
# print(encc)
# print(bart_tokenizer.decode(encc['input_ids'], skip_special_tokens=False))
# for i in sorted(set(encc['input_ids'])):
#     print(i, repr(bart_tokenizer.decode(i, skip_special_tokens=False)))
# tokenizer = bart_tokenizer

# Tokenize every split. This line was commented out, which left
# `tokenized_tweetsumm_abs` undefined for the Trainer and predict cells below.
tokenized_tweetsumm_abs = tweetsum_train_val_abs.map(preprocess_function, batched=True)

In [None]:
# NOTE(review): `model` receives the checkpoint *name* (a string), not a model
# object -- confirm this is what the collator should be given.
data_collator = DataCollatorForSeq2Seq(
    model=checkpoint_bart,
    tokenizer=tokenizer,
)

## Setup Training Evaluation

In [None]:
try:
    import evaluate
    rouge = evaluate.load("rouge")
    meteor = evaluate.load("meteor")
    bertscore = evaluate.load("bertscore")
except:
    !pip install evaluate nltk rouge_score bert_score
    !pip install -U nltk
    import evaluate
    rouge = evaluate.load("rouge")
    meteor = evaluate.load("meteor")
    bertscore = evaluate.load("bertscore")

In [None]:
# import numpy as np


# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#     # result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#     result = {
#       'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True),
#       'bertscore': bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"),
#       'meteor': meteor.compute(predictions=decoded_preds, references=decoded_labels),
#     }
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
#     result["gen_len"] = np.mean(prediction_lens)
#     print(json.dumps(result, indent=2))
#     return {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

In [None]:
# arrr = [0,1,2,3,4,5,6,7]
# valsss = ['a','b','c','d','e','f','g','h']

# kwkwk = {f"id-{x}": vall for x, vall in enumerate(valsss)}
# origindict = {'alpha':5, **kwkwk}
# print(origindict)

In [None]:
def compute_metrics_abs(eval_pred):
    """Compute ROUGE, BERTScore, METEOR and mean generated length.

    Relies on the globals ``tokenizer``, ``rouge``, ``bertscore`` and
    ``meteor``.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays, where
            -100 marks positions to ignore.

    Returns:
        Dict with ``rouge/<name>``, ``bertscore/bertscore-<name>`` and
        ``meteor`` values rounded to 4 decimals, plus ``gen_len``.
    """
    predictions, labels = eval_pred
    # Defensive: if predictions arrive as a tuple, use its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # Replace the -100 ignore index BEFORE decoding: the tokenizer cannot
    # decode negative ids (the original did this only after batch_decode).
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    # 'hashcode' is not a numeric score, so it must not reach np.mean below.
    bert_scores.pop('hashcode', None)
    meteor_score = meteor.compute(predictions=decoded_preds, references=decoded_labels)['meteor']

    result = {
      **{f"rouge/{k}": round(v, 4) for k,v in rouge_scores.items()},
      **{f"bertscore/bertscore-{k}": round(np.mean(v), 4) for k,v in bert_scores.items()},
      'meteor': round(meteor_score, 4),
    }
    # Generated length = number of non-pad tokens per prediction.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return result


In [None]:
def compute_test_metrics_abs(eval_pred):
    """Compute the same metrics as the validation metric function, under ``test/`` keys.

    Relies on the globals ``tokenizer``, ``rouge``, ``bertscore`` and
    ``meteor``.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays, where
            -100 marks positions to ignore.

    Returns:
        Dict with ``test/rouge/<name>``, ``test/bertscore/bertscore-<name>``
        and ``test/meteor`` values rounded to 4 decimals, plus
        ``test/gen_len``.
    """
    predictions, labels = eval_pred
    # Defensive: if predictions arrive as a tuple, use its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # Replace the -100 ignore index BEFORE decoding: the tokenizer cannot
    # decode negative ids (the original did this only after batch_decode).
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    # 'hashcode' is not a numeric score, so it must not reach np.mean below.
    bert_scores.pop('hashcode', None)
    meteor_score = meteor.compute(predictions=decoded_preds, references=decoded_labels)['meteor']

    result = {
      **{f"test/rouge/{k}": round(v, 4) for k,v in rouge_scores.items()},
      **{f"test/bertscore/bertscore-{k}": round(np.mean(v), 4) for k,v in bert_scores.items()},
      'test/meteor': round(meteor_score, 4),
    }
    # Generated length = number of non-pad tokens per prediction.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["test/gen_len"] = np.mean(prediction_lens)
    return result

In [None]:
# print(json.dumps(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"), indent=2))
# bertscores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
# np.mean(bertscores)
# 'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True,use_aggregator=False),
# wandb.log({f"losses/loss-{ii}": loss for ii, loss in enumerate(losses)})
# rouge_scores = {f"rouge/rougerouge-id-{i}": score for i, score in enumerate(rouge.compute(predictions=decoded_preds,
#                                                                                references=decoded_labels,
#                                                                                use_stemmer=True,
#                                                                                use_aggregator=True))}
# bert_scores = {f"bertscore/bert-id-{i}": score for i, score in enumerate(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"))},

#   for k,v in result.items():
#     print(k, type(v), v)
# Bug fix source: https://discuss.huggingface.co/t/bug-in-summarization-tutorial/60566/2
# {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

#       'rouge1': round(rouge_scores['rouge1'], 4),
#       'rouge2': round(rouge_scores['rouge2'], 4),
#       'rougeL': round(rouge_scores['rougeL'], 4),
#       'rougeLsum': round(rouge_scores['rougeLsum'], 4),
#       'bertscore/bertscore-precision': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['precision']),
#       'bertscore/bertscore-recall': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['recall']),
#       'bertscore/bertscore-f1': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['f1']),

## Train

In [None]:
# print(json.dumps(), indent=2)
# blah = bertscore.compute(predictions=['a', 'blue', 'car'], references=['a', 'black', 'car'], lang="en")
# for b,c in blah.items():
#     print(c)
#     print(np.round(sum(c)/len(c), 4))

In [None]:
# Load the checkpoint's configuration and override the length bounds stored on
# it (max 80 matches the label truncation length used in preprocess_function).
from transformers import AutoConfig
config = AutoConfig.from_pretrained(checkpoint_bart)
config.max_length = 80
config.min_length = 10
print(config)

In [None]:
# Instantiate the seq2seq model from the checkpoint using the modified config.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_bart, config=config)

In [None]:
# Print the tokenizer's init kwargs and the model config (newlines stripped so
# the config prints on a single line).
print("Tokenizer config:", tokenizer.init_kwargs)
print("Model config:", str(model.config).replace('\n',''))

In [None]:
import copy

# Timestamp (day, month, hour, minute) used to tag the W&B run and output dir.
current_time = datetime.datetime.now().strftime("%d%m-%H%M")
print(current_time)
run_name_model = f"distilbart-abs-{current_time}"
wandb.run.name = run_name_model
wandb.run.save()

# Seq2SeqTrainer installs `args.generation_config` as the model's generation
# config. The original built a GenerationConfig from scratch with only
# `bos_token_id`, which drops the checkpoint's own generation settings
# (EOS/pad/decoder-start token ids, beam search parameters, ...). Start from
# the model's own generation config and override only the length bounds.
# `max_source_length` was removed: it is not a generation parameter, and source
# truncation is already done in preprocess_function (max_length=512).
gen_config = copy.deepcopy(model.generation_config)
gen_config.max_length = 80
gen_config.min_length = 10
gen_config.save_pretrained("roequitz/distilbart-abs-tweetsumm", push_to_hub=True)

training_args = Seq2SeqTrainingArguments(
    # Only the "ddmm" part of the timestamp goes into the directory name.
    output_dir=f"trained-distilbart-abs-{current_time[0:4]}",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=1,
    # Evaluate on generated sequences so the text metrics can be computed.
    predict_with_generate=True,
    fp16=True,
    generation_max_length=80,
    generation_config=gen_config,
    # No automatic upload; the hub push is done explicitly in a later cell.
    push_to_hub=False,
    report_to="wandb",
    run_name=run_name_model
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_tweetsumm_abs["train"],
    eval_dataset=tokenized_tweetsumm_abs["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_abs,
)

# Wall-clock timing of the training run.
training_start = time.time()
trainer.train()
training_end = time.time()
print("Time it took for training:", str(datetime.timedelta(seconds=(training_end-training_start))))

In [None]:
# Explicit upload: the training arguments set push_to_hub=False, so nothing is
# pushed automatically during training.
trainer.push_to_hub()

In [None]:
test_predictions = trainer.predict(tokenized_tweetsumm_abs["test"])

# trainer.predict returns a PredictionOutput named tuple, so fields are read
# by attribute; string indexing (test_predictions['predictions']) raises.
pred_token_ids = test_predictions.predictions
if isinstance(pred_token_ids, tuple):
    pred_token_ids = pred_token_ids[0]
# Replace the -100 ignore index before decoding, then turn ids into text.
pred_token_ids = np.where(pred_token_ids != -100, pred_token_ids, tokenizer.pad_token_id)
decoded_test_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)

# The test Dataset was built from a leading slice of test_df_temp, so keep only
# the rows that actually have a prediction (same order, same length).
test_df_temp = test_df_temp.iloc[:len(decoded_test_preds)].copy()
test_df_temp['predictions'] = decoded_test_preds
# Metrics are one dict for the whole split; store it as a JSON string repeated
# on every row (assigning the dict directly would be index-aligned to NaN).
test_df_temp['metrics'] = json.dumps(test_predictions.metrics, default=float)

In [None]:
import csv  # needed for csv.QUOTE_ALL; not imported anywhere else in the file

# Output name carries the day and month of the run ("dd_mm").
# Fixed: the original referenced the undefined name `currenttime`.
# NOTE(review): ds_dir points under /kaggle/input/ in this notebook -- confirm
# that location is writable, or write to a working directory instead.
test_name = ds_dir + f"test_preds_metrics_{current_time[0:2]}_{current_time[2:4]}_bart.csv"
test_df_temp.to_csv(test_name, index=False, header=False, quoting=csv.QUOTE_ALL)
# Fixed: the original passed an undefined variable `results` in the artifact
# name position. NOTE(review): "results" as the artifact name is an assumption.
wandb.log_artifact(test_name, name="results")
# PredictionOutput is a named tuple: read metrics by attribute, not by key.
wandb.log(test_predictions.metrics)
wandb.finish()