In [1]:
from sentence_transformers import SentenceTransformer, util

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [2]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:10000]')

Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [3]:
corpus = []
for i in range(len(dataset)):
    corpus.extend(dataset[i]['passages']['passage_text'])
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

Batches:   0%|          | 0/3117 [00:00<?, ?it/s]

In [4]:
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").cuda()

In [5]:
# from transformers.models.bart.modeling_bart import shift_tokens_right
# from sentence_transformers import util
# import random
# import torch
# random.seed(42)

# def random_mask(query):
#     words = query.split()
#     if len(words) < 3:
#         return query
#     mask_index = random.randint(1, len(words) - 1)
#     return ' '.join(words[:mask_index])

# def convert_to_features(batch):
#     random.shuffle(batch['passages'])
#     masked_queries = list(map(random_mask, batch['query']))

#     query_embeddings = bi_encoder.encode(masked_queries, convert_to_tensor=True)
#     masked_queries = [query + ' <extra_id_0>' for query in masked_queries]
#     knn = util.semantic_search(query_embeddings, corpus_embeddings, top_k=10)
#     contexts = ['; '.join([corpus[e['corpus_id']] for e in embeddings]) for embeddings in knn]
#     inputs = [query + '# ' + context for context, query in zip(contexts, masked_queries)]

#     input_encodings = tokenizer.batch_encode_plus(inputs, pad_to_max_length=True, max_length=1024, truncation=True, return_tensors='pt')
#     label_encodings = tokenizer.batch_encode_plus(batch['query'], pad_to_max_length=True, max_length=1024, truncation=True, return_tensors='pt')
#     labels = label_encodings['input_ids']
#     # decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
#     labels[labels[:,:] == model.config.pad_token_id] = -100
    
#     encodings = {
#         'input_ids': input_encodings['input_ids'],
#         'attention_mask': input_encodings['attention_mask'],
#         # 'decoder_input_ids': decoder_input_ids,
#         'labels': labels,
#         'masked_queries': masked_queries,
#     }

#     return encodings

import random
from collections import Counter

def random_cutoff(query):
    words = query.split()
    if len(words) < 3:
        return query
    cutoff_index = random.randint(1, len(words) - 1)
    return ' '.join(words[:cutoff_index])

def prepare_input(query, word_frequencies):
    return query + ';' + ' '.join([word for word, _ in word_frequencies])

def convert_to_features(batch):
    queries = batch["query"]
    cutoffs = [random_cutoff(query) for query in queries]
    word_frequencies = []
    for i, passage in enumerate(batch["passages"]):
        encoded_query = bi_encoder.encode(cutoffs[i], convert_to_tensor=True)
        knn = util.semantic_search(encoded_query, corpus_embeddings, top_k=10)
        passages = [corpus[e['corpus_id']] for e in knn[0]]
        # passages = passage["passage_text"]
        counter = Counter([word.lower() for passage in passages for word in passage.split()])
        most_common = counter.most_common(256)
        word_frequencies.append(most_common)
    source = [prepare_input(query, word_frequencies) for query, word_frequencies in zip(cutoffs, word_frequencies)]
    labels = batch["query"]
    model_inputs = tokenizer.batch_encode_plus(source, max_length=512, truncation=True, padding="max_length")

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer.batch_encode_plus(labels, max_length=512, truncation=True, padding="max_length")

    return {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "labels": labels["input_ids"],
    }

In [6]:
dataset = dataset.map(convert_to_features, batched=True, batch_size=8, keep_in_memory=True)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]



In [7]:
dataset = dataset.train_test_split(test_size=0.1)

In [8]:
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").cuda()

In [9]:
from transformers.trainer import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='./models/t5-autocomplete',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    do_train=True,
    do_eval=True,
    warmup_steps=500,   
    weight_decay=0.01,
    logging_dir='./logs',
    learning_rate=1e-2,
    logging_steps=25,

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

In [10]:
trainer.train()

Step,Training Loss
25,26.4699
50,2.6006
75,0.0864
100,0.0673
125,0.0553
150,0.0539
175,0.0564
200,0.0553
225,0.0728
250,0.0626


In [None]:
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: query, wellFormedAnswers, answers, query_id, query_type, passages. If query, wellFormedAnswers, answers, query_id, query_type, passages are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 4


In [None]:
tokenizer.save_pretrained('/home/ubuntu/models/t5-autocomplete')
model.save_pretrained('/home/ubuntu/models/t5-autocomplete')

In [11]:
import transformers

transformers.logging.set_verbosity_error()

for i in range(0, 10):
    print('Actual:    ', dataset['test'][i]['query'])
    cutoff = random_cutoff(dataset['test'][i]['query'])

    encoded_query = bi_encoder.encode(cutoff, convert_to_tensor=True)
    knn = util.semantic_search(encoded_query, corpus_embeddings, top_k=10)
    passages = [corpus[e['corpus_id']] for e in knn[0]]
    counter = Counter([word.lower() for passage in passages for word in passage.split()])
    most_common = counter.most_common(256)
    to_encode = prepare_input(cutoff, most_common)
    print('Query: ', cutoff)
    encoded = tokenizer(
        to_encode,
        pad_to_max_length=True,
        max_length=1024,
        truncation=True,
        return_tensors='pt')
    output = model.generate(input_ids=encoded['input_ids'].to('cuda'), max_length=1024, num_beams=4, early_stopping=True)
    print('Predicted: ', tokenizer.decode(output[0], skip_special_tokens=True))
    print('---------------------')

Actual:     average annual visits per nephrologist
Query:  average
Predicted:  what
---------------------
Actual:     is head surgeon capital letters
Query:  is
Predicted:  what
---------------------
Actual:     types of health coach
Query:  types of health




Predicted:  what is
---------------------
Actual:     causes of toxemia
Query:  causes of
Predicted:  what
---------------------
Actual:     manufactured definition
Query:  manufactured definition
Predicted:  is what
---------------------
Actual:     dexterity definition in resume
Query:  dexterity
Predicted:  what
---------------------
Actual:     what amount magnesium blood level is normal
Query:  what amount
Predicted:  what what what what what what
---------------------
Actual:     what causes excessive itching in cats
Query:  what causes
Predicted:  what what
---------------------
Actual:     what county is lowville ny located in
Query:  what county is lowville ny located
Predicted:  what what what
---------------------
Actual:     what is adhd caused by
Query:  what
Predicted:  what what
---------------------
