In [246]:
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

loading file vocab.json from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-base/snapshots/aadd2ab0ae0c8268c7c9693540e9904811f36177/vocab.json
loading file merges.txt from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-base/snapshots/aadd2ab0ae0c8268c7c9693540e9904811f36177/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading configuration file config.json from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-base/snapshots/aadd2ab0ae0c8268c7c9693540e9904811f36177/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0

In [242]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:5000]')

Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [243]:
from transformers.models.bart.modeling_bart import shift_tokens_right
import random
import torch
random.seed(42)

def random_mask(query):
    words = query.split()
    if len(words) < 3:
        return query
    mask_index = random.randint(1, len(words) - 1)
    return ' '.join(words[:mask_index] + ['<mask>'])

def convert_to_features(batch):
    random.shuffle(batch['passages'])
    contexts = ['# '.join(passages['passage_text']) for passages in batch['passages']]
    masked_queries = list(map(random_mask, batch['query']))
    inputs = [query + '; ' + context for context, query in zip(contexts, masked_queries)]

    input_encodings = tokenizer.batch_encode_plus(inputs, pad_to_max_length=True, max_length=1024, truncation=True, return_tensors='pt')
    label_encodings = tokenizer.batch_encode_plus(batch['query'], pad_to_max_length=True, max_length=1024, truncation=True, return_tensors='pt')
    labels = label_encodings['input_ids']
    # decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:,:] == model.config.pad_token_id] = -100
    
    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        # 'decoder_input_ids': decoder_input_ids,
        'labels': labels,
        'masked_queries': masked_queries,
    }

    return encodings

In [244]:
dataset = dataset.map(convert_to_features, batched=True, batch_size=8)

Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-1c80317fa3b1799d.arrow


In [245]:
dataset = dataset.train_test_split(test_size=0.1)

Loading cached split indices for dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-8633ab0fd822d949.arrow and /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-a78d5d6770fccf2e.arrow


In [247]:
from transformers.trainer import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='./models/bart-summarizer',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    do_train=True,
    do_eval=True,
    warmup_steps=500,   
    weight_decay=0.01,
    logging_dir='./logs',
    learning_rate=1e-5,
    logging_steps=25,

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [248]:
trainer.train()

Step,Training Loss
25,3.035
50,2.7094
75,2.5383
100,2.02
125,1.8889
150,1.4597
175,1.5649
200,1.677
225,1.4994
250,1.4209


KeyboardInterrupt: 

: 

In [58]:
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: answers, passages, wellFormedAnswers, query_id, masked_queries, query, query_type. If answers, passages, wellFormedAnswers, query_id, masked_queries, query, query_type are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 2


{'eval_loss': 1.0784064531326294}

In [59]:
tokenizer.save_pretrained('~/models/bart-fine-tuned-msmarco-with-context')
model.save_pretrained('~/models/bart-fine-tuned-msmarco-with-context')

tokenizer config file saved in ~/models/bart-fine-tuned-msmarco-with-context/tokenizer_config.json
Special tokens file saved in ~/models/bart-fine-tuned-msmarco-with-context/special_tokens_map.json
Configuration saved in ~/models/bart-fine-tuned-msmarco-with-context/config.json
Configuration saved in ~/models/bart-fine-tuned-msmarco-with-context/generation_config.json
Model weights saved in ~/models/bart-fine-tuned-msmarco-with-context/pytorch_model.bin


In [60]:
for i in range(0, 10):
    print('Actual:    ', dataset['test'][i]['query'])
    print('Query:     ', dataset['test'][i]['masked_queries'])
    to_encode = dataset['test'][i]['masked_queries'] + '; ' + '# '.join(dataset['test'][i]['passages']['passage_text'])
    encoded = tokenizer(
        to_encode,
        pad_to_max_length=True,
        max_length=1024,
        truncation=True,
        return_tensors='pt')
    output = model.generate(input_ids=encoded['input_ids'].to('cuda'), max_length=1024, num_beams=4, early_stopping=True)
    print('Predicted: ', tokenizer.decode(output[0], skip_special_tokens=True))
    print('---------------------')

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Actual:     how to make the color rose gold
Query:      how to make <mask>
Predicted:  how to make rose gold hair
---------------------
Actual:     names moby dick is called
Query:      names moby <mask>
Predicted:  names moby and dick
---------------------
Actual:     us bank payoff mailing address
Query:      us bank payoff <mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,

Predicted:  us bank payoff number
---------------------
Actual:     what is cwr means
Query:      what <mask>
Predicted:  what is a congestion indication
---------------------
Actual:     how to open 1972 vw beetle hood
Query:      how <mask>
Predicted:  how to open the hood on a beetle
---------------------
Actual:     what does body habitus mean
Query:      what does body <mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Predicted:  what does body habitus mean
---------------------
Actual:     what causes learning disabilities?
Query:      what causes learning <mask>
Predicted:  what causes learning disabilities
---------------------
Actual:     is there a way to merge to cells with writing in them together
Query:      is there a way to merge to <mask>
Predicted:  is there a way to merge to combine Excel
---------------------
Actual:     the pointer sisters neutron dance
Query:      the pointer sisters <mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Predicted:  the pointer sisters neutron dance lyrics
---------------------
Actual:     does adenosine triphosphate give energy
Query:      does adenosine triphosphate <mask>
Predicted:  does adenosine triphosphate store energy
---------------------


In [None]:
import torch
for item in dataset['test']:
    text = item['text']
    labels = item['labels_text']
    encoded = tokenizer(text, return_tensors='pt').to('cuda')
    output = model.generate(**encoded, max_length=512, num_beams=4, early_stopping=True)
    predicted = tokenizer.decode(output[0].to('cpu'), skip_special_tokens=True)
    print('Text     : ', text)
    print('Predicted: ', predicted)
    print('Expected : ', labels)
    print('--' * 20)

KeyError: 'text'

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'labels_text'],
        num_rows: 90
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'labels_text'],
        num_rows: 10
    })
})