In [32]:
import os

from transformers import BartTokenizerFast, BartForConditionalGeneration

# from_pretrained() does not expand '~' on its own, so resolve the home
# directory up front; otherwise the string is not recognised as a local
# checkpoint directory.
BART_CHECKPOINT_DIR = os.path.expanduser('~/models/bart-fine-tuned-msmarco-with-context-1.65')

# Tokenizer and model are both loaded from the same checkpoint directory.
bart_tokenizer = BartTokenizerFast.from_pretrained(BART_CHECKPOINT_DIR)
bart_model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT_DIR)

In [33]:
from transformers import BartConfig

# Shallow BART variant: 3 encoder layers and 3 decoder layers; every other
# hyperparameter is left at whatever BartConfig defaults to.
config = BartConfig(decoder_layers=3, encoder_layers=3)

# Built directly from the config (not via from_pretrained), so no checkpoint
# weights are loaded into this model here.
distilbart = BartForConditionalGeneration(config)

In [None]:
import datasets

# Only the first 5000 rows of the 'train' split are requested.
dataset = datasets.load_dataset(
    path='ms_marco',
    name='v2.1',
    split='train[:5000]',
)

In [None]:
from sentence_transformers import SentenceTransformer, util

# Sentence-embedding model used for the passage corpus below.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

# Flatten every example's passage texts into a single list, in dataset order.
corpus = [
    passage_text
    for example in dataset
    for passage_text in example['passages']['passage_text']
]

# One embedding per corpus entry, returned as a tensor.
corpus_embeddings = bi_encoder.encode(
    corpus,
    show_progress_bar=True,
    convert_to_tensor=True,
)

In [None]:
from transformers.models.bart.modeling_bart import shift_tokens_right
from sentence_transformers import util
import random
import torch
# Seed Python's `random` module so the random cut positions chosen by
# random_mask() are reproducible from a fresh kernel.
random.seed(42)

def random_mask(query):
    """Return a random-length prefix of ``query``.

    The cut index is drawn uniformly (inclusive on both ends) between the
    length of the first whitespace-delimited word and ``len(query) - 1``,
    so at least one trailing character is always dropped.

    Queries shorter than 3 characters, or containing fewer than 2
    whitespace-delimited words, are returned unchanged.

    Args:
        query: the full query string.

    Returns:
        ``query[:cut]`` for the randomly chosen ``cut``, or ``query`` itself
        when it is too short to truncate.
    """
    tokens = query.split()
    long_enough = len(query) >= 3 and len(tokens) >= 2
    if long_enough:
        lowest_cut = len(tokens[0])
        highest_cut = len(query) - 1
        cut = random.randint(lowest_cut, highest_cut)
        return query[:cut]

    return query

def convert_to_features6(batch):
    """Build query-completion features for a batch of queries.

    Each query is cut at a random position with ``random_mask``. The
    truncated text is embedded and used to retrieve its 10 nearest entries
    from ``corpus``; the model input is then assembled as
    ``<truncated query><mask># <passage>; <passage>; ...``. The untruncated
    query is tokenized as the target.

    Relies on names defined outside this function: ``random_mask``,
    ``bi_encoder``, ``util``, ``corpus``, ``corpus_embeddings``,
    ``tokenizer`` and ``model``.

    Args:
        batch: mapping whose ``'query'`` entry is a sequence of strings
            (presumably a batch taken from ``dataset`` - confirm at the
            call site).

    Returns:
        dict with:
            'input_ids', 'attention_mask': encodings of the assembled
                inputs, padded/truncated to 1024 tokens.
            'labels': token ids of the full queries, padded to 1024 with
                pad positions replaced by -100.
            'masked_queries': the truncated queries with '<mask>' appended.
    """
    # NOTE: batch['passages'] is intentionally left untouched. Contexts come
    # from nearest-neighbour search over the global corpus, so shuffling the
    # caller's passages in place would only misalign them with their queries.
    truncated_queries = [random_mask(query) for query in batch['query']]

    # Embed the truncated text before '<mask>' is appended, so retrieval is
    # driven by the query prefix alone.
    query_embeddings = bi_encoder.encode(truncated_queries, convert_to_tensor=True)
    masked_queries = [query + '<mask>' for query in truncated_queries]

    hits_per_query = util.semantic_search(query_embeddings, corpus_embeddings, top_k=10)
    contexts = [
        '; '.join(corpus[hit['corpus_id']] for hit in hits)
        for hits in hits_per_query
    ]

    inputs = [query + '# ' + context for query, context in zip(masked_queries, contexts)]

    input_encodings = tokenizer.batch_encode_plus(inputs, padding='max_length', max_length=1024, truncation=True, return_tensors='pt')
    label_encodings = tokenizer.batch_encode_plus(batch['query'], padding='max_length', max_length=1024, truncation=True, return_tensors='pt')

    labels = label_encodings['input_ids']
    # -100 is the default ignore_index of PyTorch's cross-entropy loss, so
    # padded label positions do not contribute to the loss.
    labels[labels == model.config.pad_token_id] = -100

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels,
        'masked_queries': masked_queries,
    }

In [99]:
from transformers.models.bart.modeling_bart import shift_tokens_right

# Single forward pass on one hand-written masked example, decoding the
# arg-max token at every decoder position.
tokenized = bart_tokenizer("what is a <mask>#a lego brick is one of the most amazing bricks", return_tensors="pt")

# Define decoder_input_ids in this cell so it runs on a fresh kernel instead
# of depending on a leftover variable. The decoder is fed the input ids
# shifted one position to the right, starting from the EOS token id.
decoder_input_ids = shift_tokens_right(tokenized['input_ids'], bart_tokenizer.pad_token_id, bart_tokenizer.eos_token_id)

# NOTE(review): this overwrites the encoder input_ids with the predicted ids,
# so a later `bart_model.generate(**tokenized)` consumes the predictions
# rather than the original masked text - confirm that this is intended.
tokenized['input_ids'] = bart_model(**tokenized, decoder_input_ids=decoder_input_ids).logits.argmax(dim=-1)

print(bart_tokenizer.batch_decode(tokenized['input_ids']))

['<s>what is a lego</s> lego brick</s></s> of the most amazing bricks</s>']


In [100]:
# Run generate() on the current contents of `tokenized` and print the
# decoded sequences.
generated_ids = bart_model.generate(**tokenized)
decoded_generations = bart_tokenizer.batch_decode(generated_ids)
print(decoded_generations)

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}



['</s><s>what is a lego</s>']


In [74]:
# Inspect how the tokenizer encodes the literal string '<mask>'; the cell's
# displayed value is the resulting input_ids / attention_mask.
bart_tokenizer('<mask>')

{'input_ids': [0, 50264, 2], 'attention_mask': [1, 1, 1]}