In [1]:
from dlkp.models import KeyphraseGenerator
from dlkp.generation import KGTrainingArguments, KGModelArguments, KGDataArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_args = KGDataArguments(
    dataset_name="midas/inspec",
    dataset_config_name="generation",
    text_column_name="document",
    keyphrases_column_name="extractive_keyphrases",
    n_best_size=5,
    num_beams=3,
    cat_sequence=True,
)

In [4]:
training_args = KGTrainingArguments(
    output_dir="../outputs/generation",
    predict_with_generate=True,
    learning_rate=4e-5,
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    do_train=True,
    do_eval=True,
    do_predict=False,
    eval_steps=1000,
    logging_steps=1000
)

In [5]:
model_args = KGModelArguments(model_name_or_path="bloomberg/KeyBART")

In [6]:
KeyphraseGenerator.train_and_eval(model_args, data_args, training_args)

04/09/2022 18:55:01 - INFO - dlkp.generation.train_eval_generator - Training/evaluation parameters KGTrainingArguments(
_n_gpu=2,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
bf16=False,
bf16_full_eval=False,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_steps=1000,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
generation_max_length=None,
generation_num_beams=None,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_strategy=HubStrategy.EVERY_SAVE,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=4e-05,
length_colu

[INFO|configuration_utils.py:648] 2022-04-09 18:55:01,522 >> loading configuration file https://huggingface.co/bloomberg/KeyBART/resolve/main/config.json from cache at /home/debanjan/.cache/huggingface/transformers/d3c4f4b89efa42c978f3ea87e7a63b0a1bdfb0184bf17c014ef60c59c93b7da3.9458ab4fca85981d74f760077eca0099f9dac8306a47aaf172020a5397a4c223
[INFO|configuration_utils.py:684] 2022-04-09 18:55:01,525 >> Model config BartConfig {
  "_name_or_path": "bloomberg/KeyBART",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
 

04/09/2022 18:55:03 - INFO - datasets.builder - Overwrite dataset info from restored data version.
04/09/2022 18:55:03 - INFO - datasets.info - Loading Dataset info from /home/debanjan/.cache/huggingface/datasets/midas___inspec/generation/0.0.1/debd18641afb7048a36cee2b7bb8dfbf2cd1a68899118653a42fd760cf84284e
04/09/2022 18:55:03 - INFO - datasets.info - Loading Dataset info from /home/debanjan/.cache/huggingface/datasets/midas___inspec/generation/0.0.1/debd18641afb7048a36cee2b7bb8dfbf2cd1a68899118653a42fd760cf84284e


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 414.33it/s]
  0%|                                                                                                                                                                                                    | 0/1000 [00:00<?, ?ex/s]

04/09/2022 18:55:03 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/debanjan/.cache/huggingface/datasets/midas___inspec/generation/0.0.1/debd18641afb7048a36cee2b7bb8dfbf2cd1a68899118653a42fd760cf84284e/cache-5ad159c1e36ee61e.arrow


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 516.27ex/s]
  0%|                                                                                                                                                                                                     | 0/500 [00:00<?, ?ex/s]

04/09/2022 18:55:05 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/debanjan/.cache/huggingface/datasets/midas___inspec/generation/0.0.1/debd18641afb7048a36cee2b7bb8dfbf2cd1a68899118653a42fd760cf84284e/cache-ec74da9311830ddd.arrow


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 560.48ex/s]
  0%|                                                                                                                                                                                                     | 0/500 [00:00<?, ?ex/s]

04/09/2022 18:55:06 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/debanjan/.cache/huggingface/datasets/midas___inspec/generation/0.0.1/debd18641afb7048a36cee2b7bb8dfbf2cd1a68899118653a42fd760cf84284e/cache-dfcf29c4ee2ab17b.arrow


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 560.19ex/s]
[INFO|modeling_utils.py:1431] 2022-04-09 18:55:07,803 >> loading weights file https://huggingface.co/bloomberg/KeyBART/resolve/main/pytorch_model.bin from cache at /home/debanjan/.cache/huggingface/transformers/277d84cd7ca8e4aea34cc45faf3976559f510c572eadb4c06d6428fa4bdf10d2.5fd0da0c4c8b79f9789db1511826fa743a04bf996360f94176176478d0a57b5d
[INFO|modeling_utils.py:1702] 2022-04-09 18:55:10,769 >> All model checkpoint weights were used when initializing BartForConditionalGeneration.

[INFO|modeling_utils.py:1710] 2022-04-09 18:55:10,770 >> All the weights of BartForConditionalGeneration were initialized from the model checkpoint at bloomberg/KeyBART.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditiona

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py", line 1329, in forward
    outputs = self.model(
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py", line 1198, in forward
    encoder_outputs = self.encoder(
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py", line 824, in forward
    layer_outputs = encoder_layer(
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py", line 307, in forward
    hidden_states, attn_weights, _ = self.self_attn(
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/debanjan/code/research/dlkp/venv/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py", line 215, in forward
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 23.69 GiB total capacity; 6.76 GiB already allocated; 75.56 MiB free; 6.81 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [None]:
generator = KeyphraseGenerator.load(
    "../outputs/generation"
)

input_text = (
    "In this work, we explore how to learn task-specific language models aimed towards learning rich "
    "representation of keyphrases from text documents. We experiment with different masking strategies for "
    "pre-training transformer language models (LMs) in discriminative as well as generative settings. In the "
    "discriminative setting, we introduce a new pre-training objective - Keyphrase Boundary Infilling with "
    "Replacement (KBIR), showing large gains in performance (upto 9.26 points in F1) over SOTA, when LM "
    "pre-trained using KBIR is fine-tuned for the task of keyphrase extraction. In the generative setting, we "
    "introduce a new pre-training setup for BART - KeyBART, that reproduces the keyphrases related to the "
    "input text in the CatSeq format, instead of the denoised original input. This also led to gains in "
    "performance (upto 4.33 points in F1@M) over SOTA for keyphrase generation. Additionally, we also "
    "fine-tune the pre-trained language models on named entity recognition (NER), question answering (QA), "
    "relation extraction (RE), abstractive summarization and achieve comparable performance with that of the "
    "SOTA, showing that learning rich representation of keyphrases is indeed beneficial for many other "
    "fundamental NLP tasks."
)

generator_out = generator.generate(input_text)
print(generator_out)