In [1]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:10000]')

Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [2]:
from sentence_transformers import SentenceTransformer, util

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [3]:
corpus = []
for i in range(len(dataset)):
    corpus.extend(dataset[i]['passages']['passage_text'])

In [4]:
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

Batches:   0%|          | 0/3117 [00:00<?, ?it/s]

In [13]:
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

loading file vocab.json from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/vocab.json
loading file merges.txt from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/tokenizer_config.json
loading configuration file config.json from cache at /home/ubuntu/.cache/huggingface/hub/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "ar

In [None]:
from transformers.models.bart.modeling_bart import shift_tokens_right
from sentence_transformers import util
import random
import torch
random.seed(42)

def random_mask(query):
    words = query.split()
    if len(query) < 3 or len(words) < 2:
        return query
    
    mask_index = random.randint(len(words[0]), len(query) - 1)
    return query[:mask_index]

def convert_to_features6(batch):
    random.shuffle(batch['passages'])
    masked_queries = list(map(random_mask, batch['query']))

    query_embeddings = bi_encoder.encode(masked_queries, convert_to_tensor=True)
    masked_queries = [query + '<mask>' for query in masked_queries]
    knn = util.semantic_search(query_embeddings, corpus_embeddings, top_k=10)
    contexts = ['; '.join([corpus[e['corpus_id']] for e in embeddings]) for embeddings in knn]

    inputs = [query + '# ' + context for context, query in zip(contexts, masked_queries)]

    input_encodings = tokenizer.batch_encode_plus(inputs, padding='max_length', max_length=1024, truncation=True, return_tensors='pt')
    label_encodings = tokenizer.batch_encode_plus(batch['query'], padding='max_length', max_length=1024, truncation=True, return_tensors='pt')
    labels = label_encodings['input_ids']
    # decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:,:] == model.config.pad_token_id] = -100
    
    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        # 'decoder_input_ids': decoder_input_ids,
        'labels': labels,
        'masked_queries': masked_queries,
    }

    return encodings

In [7]:
dataset.cleanup_cache_files()
dataset = dataset.map(convert_to_features6, batched=True, batch_size=8, keep_in_memory=True)



  0%|          | 0/1250 [00:00<?, ?ba/s]

In [8]:
dataset = dataset.train_test_split(test_size=0.1)

In [11]:
from transformers.trainer import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='./models/bart-summarizer',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    do_train=True,
    do_eval=True,
    warmup_steps=500,   
    weight_decay=0.01,
    logging_dir='./logs',
    learning_rate=1e-4,
    logging_steps=25,

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [12]:
trainer.train()

Step,Training Loss
25,13.6619
50,10.6501
75,9.2933
100,7.7891
125,6.816
150,6.3845
175,6.3919


KeyboardInterrupt: 

In [12]:
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: wellFormedAnswers, query_id, masked_queries, query_type, passages, query, answers. If wellFormedAnswers, query_id, masked_queries, query_type, passages, query, answers are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 2


{'eval_loss': 2.0615005493164062}

In [None]:
tokenizer.save_pretrained('~/models/bart-fine-tuned-msmarco-with-context')
model.save_pretrained('~/models/bart-fine-tuned-msmarco-with-context')

tokenizer config file saved in ~/models/bart-fine-tuned-msmarco-with-context-1.65/tokenizer_config.json
Special tokens file saved in ~/models/bart-fine-tuned-msmarco-with-context-1.65/special_tokens_map.json
Configuration saved in ~/models/bart-fine-tuned-msmarco-with-context-1.65/config.json
Configuration saved in ~/models/bart-fine-tuned-msmarco-with-context-1.65/generation_config.json
Model weights saved in ~/models/bart-fine-tuned-msmarco-with-context-1.65/pytorch_model.bin


In [13]:
for i in range(0, 10):
    print('Actual:    ', dataset['test'][i]['query'])
    print('Query:     ', dataset['test'][i]['masked_queries'])
    to_encode = dataset['test'][i]['masked_queries'] + '; ' + '# '.join(dataset['test'][i]['passages']['passage_text'])
    encoded = tokenizer(
        to_encode,
        pad_to_max_length=True,
        max_length=1024,
        truncation=True,
        return_tensors='pt')
    output = model.generate(input_ids=encoded['input_ids'].to('cuda'), max_length=1024, num_beams=4, early_stopping=True)
    print('Predicted: ', tokenizer.decode(output[0], skip_special_tokens=True))
    print('---------------------')

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Actual:     what date did tennessee become a state
Query:      what date did tennessee <mask>
Predicted:  what date did tennessee become a state?
---------------------
Actual:     cyber liability definition
Query:      cyber liabili<mask>
Predicted:  cyber liabilius
---------------------
Actual:     what does psi stand for in tire business
Query:      what does psi stand for<mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Predicted:  what does psi stand for
---------------------
Actual:     what do urologist give for recurring uti
Query:      what do urolo<mask>
Predicted:  what do urolo technicians do
---------------------
Actual:     average cycles of macbook pro battery
Query:      average cycles of macbook pro batter<mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Predicted:  average cycles of macbook pro batter
---------------------
Actual:     how ripen tomatoes faster
Query:      how ripen tomato<mask>
Predicted:  how ripen tomato slices
---------------------
Actual:     longest distance between two tube stations
Query:      longest dist<mask>
Predicted:  longest distemper
---------------------
Actual:     ldl normal range
Query:      ldl normal ra<mask>


Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



Predicted:  ldl normal rates
---------------------
Actual:     all except which disease is transmitted to humans by ticks
Query:      all except which d<mask>
Predicted:  all except which battery
---------------------
Actual:     longest term life insurance policy
Query:      longest ter<mask>
Predicted:  longest terroir
---------------------


In [None]:
import torch
for item in dataset['test']:
    text = item['text']
    labels = item['labels_text']
    encoded = tokenizer(text, return_tensors='pt').to('cuda')
    output = model.generate(**encoded, max_length=512, num_beams=4, early_stopping=True)
    predicted = tokenizer.decode(output[0].to('cpu'), skip_special_tokens=True)
    print('Text     : ', text)
    print('Predicted: ', predicted)
    print('Expected : ', labels)
    print('--' * 20)

KeyError: 'text'

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'labels_text'],
        num_rows: 90
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'labels_text'],
        num_rows: 10
    })
})