In [67]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
from torch import clamp, sum

In [68]:
DIMENSION = 768  # Embeddings size
MODEL = 'bert-base-uncased'  # Transformer to use for embeddings

In [69]:
connections.connect(host='127.0.0.1', port='19530')

if utility.has_collection('huggingface_db'):
    utility.drop_collection('huggingface_db')

fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='original_question', dtype=DataType.VARCHAR, max_length=1000),
    FieldSchema(name='answer', dtype=DataType.VARCHAR, max_length=1000),
    FieldSchema(name='original_question_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name='huggingface_db', schema=schema)

index_params = {
    'metric_type': 'L2',
    'index_type': "IVF_FLAT",
    'params': {"nlist": 1536}
}
collection.create_index(field_name="original_question_embedding", index_params=index_params)
collection.load()

In [70]:
data_dataset = load_dataset('squad', split='all')
data_dataset = data_dataset.train_test_split(test_size=.001, seed=42)['test']
data_dataset = data_dataset.map(lambda val: {'answer': val['answers']['text'][0]}, remove_columns=['answers'])
data_dataset

Using the latest cached version of the module from C:\Users\duanm\.cache\huggingface\modules\datasets_modules\datasets\squad\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453 (last modified on Mon Jul 10 10:42:33 2023) since it couldn't be found locally at squad., or remotely on the Hugging Face Hub.
Found cached dataset squad (C:/Users/duanm/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
Loading cached split indices for dataset at C:\Users\duanm\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-3f8f639b545735e4.arrow and C:\Users\duanm\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-8f60fd09160cab3e.arrow
Loading cached processed dataset at C:\Users\duanm\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837

Dataset({
    features: ['id', 'title', 'context', 'question', 'answer'],
    num_rows: 99
})

In [71]:
tokenizer = AutoTokenizer.from_pretrained(MODEL)


def tokenize_question(batch):
    results = tokenizer(batch['question'], add_special_tokens=True, truncation=True, padding="max_length",
                        return_attention_mask=True, return_tensors="pt")
    batch['input_ids'] = results['input_ids']
    batch['token_type_ids'] = results['token_type_ids']
    batch['attention_mask'] = results['attention_mask']
    return batch


data_dataset = data_dataset.map(tokenize_question, batch_size=1000, batched=True)
data_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask'], output_all_columns=True)
data_dataset

Loading cached processed dataset at C:\Users\duanm\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-df4165deb4a6611f.arrow


Dataset({
    features: ['id', 'title', 'context', 'question', 'answer', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 99
})

In [72]:
model = AutoModel.from_pretrained(MODEL)


def embed(batch):
    sentence_embs = model(
        input_ids=batch['input_ids'],
        token_type_ids=batch['token_type_ids'],
        attention_mask=batch['attention_mask']
    )[0]
    input_mask_expanded = batch['attention_mask'].unsqueeze(-1).expand(sentence_embs.size()).float()
    batch['question_embedding'] = sum(sentence_embs * input_mask_expanded, 1) / clamp(input_mask_expanded.sum(1),
                                                                                      min=1e-9)
    return batch


data_dataset = data_dataset.map(embed, remove_columns=['input_ids', 'token_type_ids', 'attention_mask'], batched=True,
                                batch_size=64)
data_dataset

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Loading cached processed dataset at C:\Users\duanm\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837

Dataset({
    features: ['id', 'title', 'context', 'question', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'question_embedding'],
    num_rows: 99
})

In [73]:
def insert_function(batch):
    insertable = [
        batch['question'],
        [x[:995] + '...' if len(x) > 999 else x for x in batch['answer']],
        batch['question_embedding'].tolist()
    ]
    collection.insert(insertable)


data_dataset.map(insert_function, batched=True, batch_size=64)

Map:   0%|          | 0/99 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'title', 'context', 'question', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'question_embedding'],
    num_rows: 99
})

In [74]:
questions = {'question': ['When was chemistry invented?', 'When was Eisenhower born?']}
question_dataset = Dataset.from_dict(questions)
question_dataset = question_dataset.map(tokenize_question, batched=True, batch_size=1000)
question_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask'], output_all_columns=True)
question_dataset = question_dataset.map(embed, remove_columns=['input_ids', 'token_type_ids', 'attention_mask'],
                                        batched=True, batch_size=64)
question_dataset

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'input_ids', 'token_type_ids', 'attention_mask', 'question_embedding'],
    num_rows: 2
})

In [75]:
def search(batch):
    res = collection.search(batch['question_embedding'].tolist(),
                            anns_field='original_question_embedding', param={},
                            limit=10,
                            output_fields=['answer', 'original_question'])

    overall_id = []
    overall_distance = []
    overall_answer = []
    overall_original_question = []
    for hits in res:
        ids = []
        distance = []
        answer = []
        original_question = []
        for hit in hits:
            ids.append(hit.id)
            distance.append(hit.distance)
            answer.append(hit.entity.get('answer'))
            original_question.append(hit.entity.get('original_question'))
        overall_id.append(ids)
        overall_distance.append(distance)
        overall_answer.append(answer)
        overall_original_question.append(original_question)

    return {'id': overall_id,
            'distance': overall_distance,
            'answer': overall_answer,
            'original_question': overall_original_question}


question_dataset = question_dataset.map(search, batched=True, batch_size=1)
for x in question_dataset:
    print(x['question'])
    for x1, x2, x3 in zip(x['original_question'], x['answer'], x['distance']):
        print((x1, x2, x3.item()))
    print()

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

When was chemistry invented?
('When did the Papal States exist?', 'until 1870', 34.13431167602539)
('When was the Tower constructed?', '1787', 35.67085266113281)
('When were free elections held?', 'October 1992', 38.95990753173828)
('How old did biblical scholars think the Earth was?', '6,000 years', 44.86854553222656)
('Where was Russian schooling mandatory in the 20th century?', 'Poland, Bulgaria, the Czech Republic, Slovakia, Hungary, Albania, former East Germany and Cuba', 45.79857635498047)
('In what year was the Premier League created?', '1992', 47.060733795166016)
("When was ZE's Mutant Disco released?", '1981', 48.399925231933594)
("What was the Latin of Charlemagne's era later known as?", 'Medieval Latin', 50.96128845214844)
('How did Hobson argue to rid the world of imperialism?', 'taxation', 51.08031463623047)
('What Prussian system was superior to the French example?', 'military education', 52.56195068359375)

When was Eisenhower born?
('When was the Tower constructed?', '1

In [76]:
collection.release()