Tutorial: https://milvus.io/docs/integrate_with_hugging-face.md

In [None]:
!pip install --upgrade pymilvus transformers datasets torch

In [7]:
import IPython

# Automatically restart the kernel (the True argument asks for a restart
# rather than a plain shutdown).
kernel_app = IPython.Application.instance()
kernel_app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

In [1]:
from datasets import load_dataset


# Name of dataset from HuggingFace Datasets
DATASET = "squad"
# Ratio of example dataset to be inserted
INSERT_RATIO = 0.001

# Load the validation split, keep only the small seeded "test" side of a
# train/test split, and reduce every row to its first answer text while
# dropping the id / answers / context columns.
data = (
    load_dataset(DATASET, split="validation")
    .train_test_split(test_size=INSERT_RATIO, seed=42)["test"]
    .map(
        lambda row: {"answer": row["answers"]["text"][0]},
        remove_columns=["id", "answers", "context"],
    )
)

print(data)


  from .autonotebook import tqdm as notebook_tqdm
Downloading readme: 100%|██████████| 7.62k/7.62k [00:00<00:00, 20.7MB/s]
Downloading data: 100%|██████████| 14.5M/14.5M [00:00<00:00, 20.5MB/s]
Downloading data: 100%|██████████| 1.82M/1.82M [00:00<00:00, 4.37MB/s]
Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 297151.02 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 288400.52 examples/s]
Map: 100%|██████████| 11/11 [00:00<00:00, 1291.64 examples/s]

Dataset({
    features: ['title', 'question', 'answer'],
    num_rows: 11
})





In [2]:
from transformers import AutoTokenizer, AutoModel
import torch

# Name of model from HuggingFace Models
MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Batch size of model inference
INFERENCE_BATCH_SIZE = 64

# Load the model and its matching tokenizer from the same checkpoint name.
model = AutoModel.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL)


def encode_text(batch):
    """Embed the questions in ``batch`` and store them as ``question_embedding``.

    Tokenizes ``batch["question"]`` with the module-level ``tokenizer``, runs
    the module-level ``model`` with gradient tracking disabled, averages the
    token embeddings over the positions selected by the attention mask, and
    L2-normalizes the pooled vectors.

    Args:
        batch: Mapping with a ``"question"`` entry holding the texts to embed.

    Returns:
        The same ``batch`` object, modified in place so that
        ``batch["question_embedding"]`` holds the normalized embeddings
        (presumably one row per question).
    """
    # Tokenize sentences
    tokens = tokenizer(
        batch["question"], padding=True, truncation=True, return_tensors="pt"
    )

    # Compute token embeddings without building an autograd graph
    with torch.no_grad():
        outputs = model(**tokens)

    # Masked mean pooling: the first element of the model output is treated as
    # the token embeddings; masked-out positions contribute nothing to the sum.
    hidden = outputs[0]
    mask = tokens["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()
    summed = torch.sum(hidden * mask, 1)
    # Clamp the token count so an all-zero mask cannot cause a division by zero.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    pooled = summed / counts

    # Normalize embeddings
    batch["question_embedding"] = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return batch


# Add a "question_embedding" column by running encode_text over the dataset
# in batches of INFERENCE_BATCH_SIZE.
data = data.map(function=encode_text, batch_size=INFERENCE_BATCH_SIZE, batched=True)
# Materialize the rows as a Python list for insertion into Milvus.
data_list = data.to_list()


Map: 100%|██████████| 11/11 [00:00<00:00, 13.80 examples/s]


In [3]:
from pymilvus import MilvusClient


# Connection URI
MILVUS_URI = "http://milvus.milvus.svc.cluster.local"
# Collection name
COLLECTION_NAME = "huggingface_test"
# Embedding dimension depending on model
DIMENSION = 384

milvus_client = MilvusClient(uri=MILVUS_URI)

# Start from a clean slate: drop any existing collection with the same name
# before creating it again.
if milvus_client.has_collection(COLLECTION_NAME):
    milvus_client.drop_collection(COLLECTION_NAME)

milvus_client.create_collection(
    collection_name=COLLECTION_NAME,
    dimension=DIMENSION,
    # Map vector field name and embedding column in dataset
    vector_field_name="question_embedding",
    # Enable auto id
    auto_id=True,
    # Enable dynamic fields
    enable_dynamic_field=True,
    # To enable search with latest data
    consistency_level="Strong",
)


In [5]:
# Insert every embedded row; the summary returned by insert() is shown as the
# cell output.
milvus_client.insert(
    data=data_list,
    collection_name=COLLECTION_NAME,
)

{'insert_count': 11,
 'ids': [451627567490287437, 451627567490287438, 451627567490287439, 451627567490287440, 451627567490287441, 451627567490287442, 451627567490287443, 451627567490287444, 451627567490287445, 451627567490287446, 451627567490287447],
 'cost': 0}

In [6]:
questions = {
    "question": [
        "What is LGM?",
        "When did Massachusetts first mandate that children be educated in schools?",
    ]
}

# Embed the queries with the same encoder used for the stored questions and
# convert the resulting tensor to nested Python lists for the search call.
# Note: encode_text also writes the embedding tensor into `questions`.
question_embeddings = encode_text(questions)["question_embedding"].tolist()

search_results = milvus_client.search(
    collection_name=COLLECTION_NAME,
    data=question_embeddings,
    # How many search results to output
    limit=3,
    # Include these fields in search results
    output_fields=["answer", "question"],
)

# Results come back in query order: one list of hits per input question.
for query, hits in zip(questions["question"], search_results):
    print("Question:", query)
    for hit in hits:
        entity = hit["entity"]
        print(
            {
                "answer": entity["answer"],
                "score": hit["distance"],
                "original question": entity["question"],
            }
        )
    print("\n")


Question: What is LGM?
{'answer': 'Last Glacial Maximum', 'score': 0.956273078918457, 'original question': 'What does LGM stands for?'}
{'answer': 'coordinate the response to the embargo', 'score': 0.21201416850090027, 'original question': 'Why was this short termed organization created?'}
{'answer': '"Reducibility Among Combinatorial Problems"', 'score': 0.1945795714855194, 'original question': 'What is the paper written by Richard Karp in 1972 that ushered in a new era of understanding between intractability and NP-complete problems?'}


Question: When did Massachusetts first mandate that children be educated in schools?
{'answer': '1852', 'score': 0.9709996581077576, 'original question': 'In what year did Massachusetts first require children to be educated in schools?'}
{'answer': 'several regional colleges and universities', 'score': 0.341647207736969, 'original question': 'In 1890, who did the university decide to team up with?'}
{'answer': '1962', 'score': 0.19310054183006287, 'o