# Abstractive summaries - Train DistilBART on TWEETSUMM dataset

In [28]:
import json, re
from huggingface_hub import notebook_login
import pandas as pd
import numpy as np
import os, time, datetime

try:
    from datasets import load_dataset, Dataset, DatasetDict
except:
    !pip install datasets
    from datasets import load_dataset, Dataset, DatasetDict

try:
    import accelerate
except:
    !pip install -U 'accelerate==0.27.2'
    import accelerate


import transformers
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, pipeline, set_seed, BartTokenizer
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, BartForConditionalGeneration
from transformers import GenerationConfig

try:
    import wandb
except:
    !pip install wandb

print(transformers.__version__, accelerate.__version__)


4.44.0 0.33.0


In [29]:
ds_dir = ""
try:
    HF_TOKEN =  os.environ['HF_TOKEN']
except:
    HF_TOKEN = ""

if 'google.colab' in str(get_ipython()):
    print("Running on Colab")
    from google.colab import drive, userdata
    drive.mount('/content/drive')
    HF_TOKEN = userdata.get('HF_TOKEN')
elif os.environ['KAGGLE_KERNEL_RUN_TYPE']:
    from kaggle_secrets import UserSecretsClient
    print("Running on Kaggle")
    ds_dir = "/kaggle/input/tweet-data-2106-1512/"
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY


Running on Kaggle


In [30]:
set_seed(17)

os.environ["WANDB_PROJECT"] = "aiml-thesis-train"
wandb.init(settings=wandb.Settings(start_method="thread"))

VBox(children=(Label(value='0.032 MB of 0.032 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [31]:
from huggingface_hub import login
login(token=HF_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Load data

In [32]:
ds_dir="/kaggle/input/bertdata2207/"
checkpoint_bart = "sshleifer/distilbart-xsum-12-6"

In [33]:
train_df_temp = pd.read_csv(ds_dir + f"dials_abs_2607_1312_train_spc.csv", names=['conv_id','dialogue','summary'], encoding='utf-8', dtype={'conv_id':'string', 'dialogue':'string', 'summary': 'string'})
train_df_temp.convert_dtypes()
train_df_temp.drop(columns=['conv_id'], inplace=True)
train_df_temp.reset_index(drop=True, inplace=True)

val_df_temp = pd.read_csv(ds_dir + "dials_abs_2607_1312_valid_spc.csv", names=['conv_id','dialogue','summary'], encoding='utf-8', dtype={'conv_id':'string', 'dialogue':'string', 'summary': 'string'})
val_df_temp.convert_dtypes()
val_df_temp.drop(columns=['conv_id'], inplace=True)
val_df_temp.reset_index(drop=True, inplace=True)

test_df_temp = pd.read_csv(ds_dir + "dials_abs_2607_1312_test_spc.csv", names=['conv_id','dialogue','summary'], encoding='utf-8', dtype={'conv_id':'string', 'dialogue':'string', 'summary': 'string'})
test_df_temp.convert_dtypes()
# test_df_temp.drop(columns=['conv_id'], inplace=True)
test_df_temp.reset_index(drop=True, inplace=True)

print(train_df_temp.dtypes)
print(train_df_temp.head())

dialogue    string[python]
summary     string[python]
dtype: object
                                            dialogue  \
0  Customer: So neither my iPhone nor my Apple Wa...   
1  Customer: @115850 hi team! i m planning to get...   
2  Customer: @AskAmex Where do I write to address...   
3  Customer: @AmazonHelp @115821 Wow, expected 4 ...   
4  Customer: @GWRHelp I'd rather you spent some t...   

                                             summary  
0  Customer enquired about his Iphone and Apple w...  
1  Customer is eager to know about the replacemen...  
2  Signed up for an AmexCard with Delta but it di...  
3  The customer have a problem. The agent is very...  
4  Customer cannot purchase a train ticket on the...  


In [34]:
tweetsumm_abs = DatasetDict(
    {
        'train': Dataset.from_pandas(train_df_temp),
        'validation': Dataset.from_pandas(val_df_temp),
        'test': Dataset.from_pandas(test_df_temp)
    }
)

In [35]:
bart_tokenizer = BartTokenizer.from_pretrained(checkpoint_bart)
tokenizer = bart_tokenizer



In [36]:
# arrr = [0,1,2,3,4,5,6,7]
# valsss = ['a','b','c','d','e','f','g','h']

# kwkwk = {f"id-{x}": vall for x, vall in enumerate(valsss)}
# origindict = {'alpha':5, **kwkwk}
# print(origindict)

In [37]:
# Source: https://huggingface.co/docs/transformers/en/tasks/summarization

def preprocess_function(examples):
    prefix = "summarize: "
    inputs = [str(prefix) + str(dial) for dial in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True) # same params as tweetsumm paper
    labels = tokenizer(text_target=examples["summary"], max_length=80, truncation=True)
    # print(inputs, model_inputs['input_ids'])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [38]:
# encc = bart_tokenizer.encode_plus("qwerty Q.\nQwerty Q. \n Qwerty x.") # train_df_temp.iloc[5,0][:320])
# print(encc)
# print(bart_tokenizer.decode(encc['input_ids'], skip_special_tokens=False))
# for i in sorted(set(encc['input_ids'])):
#     print(i, repr(bart_tokenizer.decode(i, skip_special_tokens=False)))
# tokenizer = bart_tokenizer
# tokenized_tweetsumm_abs = tweetsum_train_val_abs.map(preprocess_function, batched=True)

In [39]:
tokenized_tweetsumm_abs = tweetsumm_abs.map(preprocess_function, batched=True)
print(tokenized_tweetsumm_abs["train"][0])

Map:   0%|          | 0/867 [00:00<?, ? examples/s]

Map:   0%|          | 0/110 [00:00<?, ? examples/s]

Map:   0%|          | 0/109 [00:00<?, ? examples/s]

{'dialogue': 'Customer: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn’t recognise either source anymore for some reason. Any ideas? https://t.co/m9DPQbkftD\nCustomer: @AppleSupport please read the above.\nAgent: @135060 Let’s investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?\nCustomer: @AppleSupport My iPhone is on 11.1.2, and my watch is on 4.1.\nAgent: @135060 Thank you. Have you tried restarting both devices since this started happening?\nCustomer: @AppleSupport I’ve restarted both, also un-paired then re-paired the watch.\nAgent: @135060 Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?\nCustomer: @AppleSupport Yes, everything seems fine, it’s just Health and activity.\nAgent: @135060 Let’s move to DM and look into this a bit more. When reaching out in DM, let 

In [40]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint_bart)

## Setup Training Evaluation

In [41]:
try:
    import evaluate
    rouge = evaluate.load("rouge")
    meteor = evaluate.load("meteor")
    bertscore = evaluate.load("bertscore")
except:
    !pip install evaluate nltk rouge_score bert_score
    import evaluate
    rouge = evaluate.load("rouge")
    meteor = evaluate.load("meteor")
    bertscore = evaluate.load("bertscore")

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [43]:
def compute_metrics_abs(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    bert_scores.pop('hashcode')
    result = {
      **{f"rouge/{k}": round(v, 4) for k,v in rouge_scores.items()},
      **{f"bertscore/bertscore-{k}": round(np.mean(v), 4) for k,v in bert_scores.items()},
      'meteor': round(meteor.compute(predictions=decoded_preds, references=decoded_labels)['meteor'], 4),
    }
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return result


In [68]:
test_eval_data = {'input_ids': [0, 18581, 3916, 2072, 35, 19458, 35, 407, 5063, 127, 2733, 3486, 127, 1257, 3075, 32, 5492, 127, 2402, 73, 30280, 6, 8, 1309, 630, 17, 27, 90, 11865, 1169, 1300, 5988, 13, 103, 1219, 4, 5053, 2956, 116, 1205, 640, 90, 4, 876, 73, 119, 466, 5174, 1864, 428, 330, 2543, 495, 50118, 44799, 35, 787, 20770, 38873, 2540, 1166, 5, 1065, 4, 50118, 45443, 35, 787, 1558, 1096, 2466, 2780, 17, 27, 29, 4830, 42, 561, 4, 598, 386, 6, 64, 47, 1137, 201, 5, 2257, 7952, 110, 2733, 8, 1257, 3075, 32, 878, 855, 116, 50118, 44799, 35, 787, 20770, 38873, 1308, 2733, 16, 15, 365, 4, 134, 4, 176, 6, 8, 127, 1183, 16, 15, 204, 4, 134, 4, 50118, 45443, 35, 787, 1558, 1096, 2466, 3837, 47, 4, 6319, 47, 1381, 12721, 154, 258, 2110, 187, 42, 554, 2909, 116, 50118, 44799, 35, 787, 20770, 38873, 38, 17, 27, 548, 12721, 196, 258, 6, 67, 542, 12, 6709, 7651, 172, 769, 12, 6709, 7651, 5, 1183, 4, 50118, 45443, 35, 787, 1558, 1096, 2466, 8432, 24, 4, 520, 222, 47, 78, 3120, 14, 5, 80, 2110, 58, 45, 1686, 7, 349, 97, 4, 1832, 5, 80, 2110, 8469, 149, 97, 3798, 215, 25, 34692, 116, 50118, 44799, 35, 787, 20770, 38873, 3216, 6, 960, 1302, 2051, 6, 24, 17, 27, 29, 95, 1309, 8, 1940, 4, 50118, 45443, 35, 787, 1558, 1096, 2466, 2780, 17, 27, 29, 517, 7, 18695, 8, 356, 88, 42, 10, 828, 55, 4, 520, 3970, 66, 11, 18695, 6, 905, 201, 216, 77, 42, 78, 554, 2909, 2540, 4, 286, 1246, 6, 222, 24, 386, 71, 41, 2935, 50, 71, 15602, 10, 1402, 1553, 116, 1205, 640, 90, 4, 876, 73, 534, 14043, 1343, 791, 2036, 975, 642, 565, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [0, 44799, 17730, 7651, 59, 39, 38, 17283, 8, 1257, 1183, 61, 16, 45, 2018, 39, 143, 2402, 73, 30280, 8, 474, 1713, 4, 18497, 16, 1996, 7, 517, 7, 18695, 8, 356, 88, 24, 4, 2]}


In [54]:
data_collator

DataCollatorForSeq2Seq(tokenizer=BartTokenizer(name_or_path='sshleifer/distilbart-xsum-12-6', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}, model='sshleifer/distilbart-xsum-12-6', padding=True, max

In [None]:
data_collator([te_tokenised_batch])

In [64]:
for sample in tweetsumm_abs['train'].select(range(17,54))['dialogue']:
    print(dir(sample))
    break

['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mod__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmod__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'capitalize', 'casefold', 'center', 'count', 'encode', 'endswith', 'expandtabs', 'find', 'format', 'format_map', 'index', 'isalnum', 'isalpha', 'isascii', 'isdecimal', 'isdigit', 'isidentifier', 'islower', 'isnumeric', 'isprintable', 'isspace', 'istitle', 'isupper', 'join', 'ljust', 'lower', 'lstrip', 'maketrans', 'partition', 'removeprefix', 'removesuffix', 'replace', 'rfind', 'rindex', 'rjust', 'rpartition', 'rsplit', 'rstrip', 'split', 'splitlines', 'startswith', 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill']


In [65]:
# test_eval_tok_batch = tokenizer(tweetsumm_abs['train'].select(range(17,21))['dialogue'])

In [70]:
test_eval_coll_batch = data_collator([test_eval_data])

In [71]:
compute_metrics_abs(test_eval_coll_batch)

ValueError: too many values to unpack (expected 2)

In [42]:
# import numpy as np


# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#     # result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#     result = {
#       'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True),
#       'bertscore': bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"),
#       'meteor': meteor.compute(predictions=decoded_preds, references=decoded_labels),
#     }
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
#     result["gen_len"] = np.mean(prediction_lens)
#     print(json.dumps(result, indent=2))
#     return {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

In [44]:
def compute_test_metrics_abs(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    bert_scores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    bert_scores.pop('hashcode')
    result = {
      **{f"test/rouge/{k}": round(v, 4) for k,v in rouge_scores.items()},
      **{f"test/bertscore/bertscore-{k}": round(np.mean(v), 4) for k,v in bert_scores.items()},
      'test/meteor': round(meteor.compute(predictions=decoded_preds, references=decoded_labels)['meteor'], 4),
    }
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["test/gen_len"] = np.mean(prediction_lens)
    return result

In [45]:
# print(json.dumps(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"), indent=2))
# bertscores = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
# np.mean(bertscores)
# 'rouge': rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True,use_aggregator=False),
# wandb.log({f"losses/loss-{ii}": loss for ii, loss in enumerate(losses)})
# rouge_scores = {f"rouge/rougerouge-id-{i}": score for i, score in enumerate(rouge.compute(predictions=decoded_preds,
#                                                                                references=decoded_labels,
#                                                                                use_stemmer=True,
#                                                                                use_aggregator=True))}
# bert_scores = {f"bertscore/bert-id-{i}": score for i, score in enumerate(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en"))},

#   for k,v in result.items():
#     print(k, type(v), v)
# Bug fix source: https://discuss.huggingface.co/t/bug-in-summarization-tutorial/60566/2
# {k: round(v, 4) if type(v) != list else v for k, v in result.items()}

#       'rouge1': round(rouge_scores['rouge1'], 4),
#       'rouge2': round(rouge_scores['rouge2'], 4),
#       'rougeL': round(rouge_scores['rougeL'], 4),
#       'rougeLsum': round(rouge_scores['rougeLsum'], 4),
#       'bertscore/bertscore-precision': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['precision']),
#       'bertscore/bertscore-recall': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['recall']),
#       'bertscore/bertscore-f1': np.mean(bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")['f1']),

## Train

In [46]:
# print(json.dumps(), indent=2)
# blah = bertscore.compute(predictions=['a', 'blue', 'car'], references=['a', 'black', 'car'], lang="en")
# for b,c in blah.items():
#     print(c)
#     print(np.round(sum(c)/len(c), 4))

In [47]:
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained(checkpoint_bart)
# config.max_length = 80
# config.min_length = 10
# print(config)

In [48]:
model = BartForConditionalGeneration.from_pretrained(checkpoint_bart) #, config=config)

In [49]:
# ppp = pipeline("summarization", checkpoint_bart, device='cuda')
# out_t = ppp(tokenized_tweetsumm_abs["test"][0]['dialogue'])

In [50]:
# print(out_t)

In [51]:
# out = model.generate(return_tensors='pt')
# print(out, type(out), type(out[0]))
# print(tokenizer.decode(out, skip_special_tokens=False))

In [52]:
# Print config
print("Tokenizer config:", tokenizer.init_kwargs)
print("Model config:", str(model.config).replace('\n',''))

Tokenizer config: {'errors': 'replace', 'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=False), 'add_prefix_space': False, 'model_max_length': 1024, 'tokenizer_file': None, 'name_or_path': 'sshleifer/distilbart-xsum-12-6'}
Model config: BartConfig {  "_name_or_path": "s

In [53]:
current_time = datetime.datetime.now().strftime("%d%m-%H%M")
print(current_time)
run_name_model = f"distilbart-abs-{current_time}"
wandb.run.name = run_name_model
wandb.run.save()

gen_config = GenerationConfig(max_source_length=512,bos_token_id=0)
gen_config.save_pretrained("roequitz/distilbart-abs-tweetsumm-2908a", push_to_hub=True)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"trained-distilbart-abs-{current_time[0:4]}",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,
    learning_rate=3e-3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    generation_max_length=80,
    generation_config=gen_config,
    push_to_hub=False,
    report_to="wandb",
    run_name=run_name_model
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_tweetsumm_abs["train"],
    eval_dataset=tokenized_tweetsumm_abs["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_abs,
)

training_start = time.time()
trainer.train()
training_end = time.time()
print("Time it took for training:", str(datetime.timedelta(seconds=(training_end-training_start))))



2908-1423


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LookupError: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/opt/conda/nltk_data'
    - '/opt/conda/share/nltk_data'
    - '/opt/conda/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************


In [None]:
!ls -lha & df -h & du -h

In [None]:
trainer.push_to_hub()

In [None]:
test_predictions = trainer.predict(tokenized_tweetsumm_abs["test"])
print(type(test_predictions), test_predictions)
#print(test_df_temp.head())

In [None]:
test_df_temp['predictions'] = test_predictions['predictions']
test_df_temp['metrics'] = test_predictions['metrics']

In [None]:
print(tokenized_tweetsumm_abs["test"][0]["input_ids"])
preds = trainer.predict(tokenized_tweetsumm_abs["test"])

In [None]:
print(preds.predictions)

In [None]:
import csv
test_name = ds_dir + f"test_preds_metrics_{current_time[0:2]}_{current_time[2:4]}_bart.csv"
test_df_temp.to_csv(test_name, index=False, header=False, quoting=csv.QUOTE_ALL)
wandb.log_artifact(test_name, results)
wandb.log(test_predictions['metrics'])
wandb.finish()