# Гибридный поиск на основе Milvus

In [1]:
import asyncio
import json

import nltk
from langchain_huggingface import HuggingFaceEmbeddings
from nltk.corpus import stopwords
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    WeightedRanker,
    RRFRanker
)

#### Загружаем стоп-слова

In [10]:
# Make sure the NLTK "stopwords" corpus is available locally (it is downloaded
# on the first run; afterwards NLTK only reports that it is up to date) and
# load the Russian word list. The list is passed to the Milvus BM25 analyzer.
nltk.download("stopwords")
russian_stopwords = stopwords.words(fileids="russian")

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\mariya.kuznetsova\AppData\Roaming\nltk_data..
[nltk_data]     .
[nltk_data]   Package stopwords is already up-to-date!


### Класс для гибридного поиска в Milvus

In [153]:
class HybridRetriever:
    """Full-text (BM25) + dense-vector retrieval over one Milvus collection.

    Every row stores the raw ``message`` text, a dense embedding of it produced
    by ``dense_embedding_function``, and a sparse BM25 vector that Milvus
    computes from ``message`` itself through a BM25 function declared in the
    schema. ``search`` can query either signal alone or fuse both.
    """

    def __init__(self, uri, collection_name="Messages", dense_embedding_function=None):
        """Create a Milvus client for ``uri``.

        Args:
            uri: Milvus endpoint, e.g. ``"http://localhost:19530"``.
            collection_name: collection created, written and searched by this object.
            dense_embedding_function: object with ``embed_query(text)`` and
                ``embed_documents(texts)`` methods (e.g. a LangChain embeddings
                model). Needed by ``build_collection``, ``insert_data`` and the
                ``"dense"``/``"hybrid"`` search modes.
        """
        self.uri = uri
        self.collection_name = collection_name
        self.embedding_function = dense_embedding_function
        self.client = MilvusClient(uri=uri)

    def build_collection(self):
        """Drop and recreate the collection with its schema, BM25 function and indexes.

        Any existing collection named ``self.collection_name`` is deleted
        together with all of its data.
        """
        # The dense dimension is not configured anywhere; probe the model once.
        dense_dim = len(self.embedding_function.embed_query("test"))

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

        # Analyzer that tokenizes ``message`` for BM25 (at insert and at query
        # time). Filters are applied in list order:
        #   1. lowercase - "Ветер" and "ветер" become the same term, and
        #      capitalized words can match the (presumably lowercase) stop list;
        #   2. stop      - drop stop words *before* stemming, because the NLTK
        #      list holds full word forms, not stems;
        #   3. stemmer   - reduce the remaining tokens to Russian stems.
        analyzer_params = {
            "tokenizer": "standard",
            "filter": [
                "lowercase",
                {"type": "stop", "stop_words": russian_stopwords},
                {"type": "stemmer", "language": "russian"},
            ],
        }

        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="id",
            datatype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )
        schema.add_field(
            field_name="message",
            datatype=DataType.VARCHAR,
            # NOTE(review): Milvus appears to count VARCHAR max_length in bytes;
            # Cyrillic takes 2 bytes per char in UTF-8 - confirm the limit.
            max_length=20000,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        # Filled server-side by the BM25 function below; never inserted by hand.
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(
            field_name="chat_message_id",
            datatype=DataType.INT64,
        )
        schema.add_field(
            field_name="chat_name", datatype=DataType.VARCHAR, max_length=200
        )

        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["message"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(bm25_function)

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )

        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
        )

    def insert_data(self, metadata):
        """Embed ``metadata["message"]`` and insert it as one row.

        Args:
            metadata: row fields. The schema requires ``message``,
                ``chat_name`` and ``chat_message_id``. ``id`` is auto-generated
                and ``sparse_vector`` is computed by the BM25 function.
        """
        dense_vec = self.embedding_function.embed_documents([metadata["message"]])[0]
        self.client.insert(
            self.collection_name, {"dense_vector": dense_vec, **metadata}
        )

    @staticmethod
    def _build_filter_expression(chat_names):
        """Return a Milvus expression restricting ``chat_name`` to ``chat_names``.

        Each name becomes a double-quoted string literal with ``"`` and ``\\``
        escaped. A name containing quotes therefore cannot break the
        expression or inject extra conditions into it.
        """
        literals = ", ".join(
            json.dumps(str(name), ensure_ascii=False) for name in chat_names
        )
        return f"chat_name in [{literals}]"

    @staticmethod
    def _make_ranker(ranker, weights, k_rerank):
        """Build the fusion ranker for hybrid search.

        ``weights`` follow the request order used in hybrid search:
        ``(bm25_weight, dense_weight)``.
        """
        if ranker == "rrf":
            return RRFRanker(k=k_rerank)
        if ranker == "weighted":
            return WeightedRanker(*weights)
        raise ValueError(f"Invalid ranker: {ranker!r}")

    async def search(
        self,
        query: str,
        chat_names: list,
        k: int = 20,
        mode="hybrid",
        weights=(0.5, 0.5),
        k_rerank=100,
        ranker="rrf",
    ):
        """Return the top-``k`` messages for ``query`` from the given chats.

        Args:
            query: free-text query.
            chat_names: only rows whose ``chat_name`` is in this list are searched.
            k: number of results. In hybrid mode it is also the candidate count
                requested from each signal.
            mode: ``"sparse"`` (BM25), ``"dense"`` (embedding, IP metric) or
                ``"hybrid"`` (both, fused by ``ranker``).
            weights: ``(bm25_weight, dense_weight)`` used when
                ``ranker="weighted"``; ignored by RRF.
            k_rerank: constant ``k`` of the RRF ranker.
            ranker: hybrid fusion method, ``"rrf"`` (default) or ``"weighted"``.

        Returns:
            List of ``{"message", "chat_name", "score"}`` dicts in Milvus result
            order. ``score`` is the raw BM25/IP score in single-signal modes and
            the fused ranker score in hybrid mode.

        Raises:
            ValueError: unknown ``mode``, or unknown ``ranker`` in hybrid mode.
        """
        # The Milvus client and the embedding model are blocking. Run them in a
        # worker thread so awaiting this coroutine does not stall the event loop.
        return await asyncio.to_thread(
            self._search_sync, query, chat_names, k, mode, weights, k_rerank, ranker
        )

    def _search_sync(self, query, chat_names, k, mode, weights, k_rerank, ranker):
        """Blocking implementation of ``search``; same arguments and result."""
        output_fields = ["message", "chat_name"]
        filter_expression = self._build_filter_expression(chat_names)

        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],  # raw text: Milvus applies the BM25 analyzer itself
                anns_field="sparse_vector",
                limit=k,
                filter=filter_expression,
                output_fields=output_fields,
            )
        elif mode == "dense":
            dense_vec = self.embedding_function.embed_query(query)
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                filter=filter_expression,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            dense_vec = self.embedding_function.embed_query(query)
            # The chat filter is applied to each sub-request via ``expr``.
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", {"metric_type": "BM25"},
                limit=k, expr=filter_expression,
            )
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", {"metric_type": "IP"},
                limit=k, expr=filter_expression,
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=self._make_ranker(ranker, weights, k_rerank),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")

        # A single query vector was sent, so only the first hit list is relevant.
        return [
            {
                "message": doc["entity"]["message"],
                "chat_name": doc["entity"]["chat_name"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]

### Выбираем модель эмбеддингов и создаем ретривер

In [114]:
# Dense embedding model used for the "dense_vector" field of the retriever.
# NOTE(review): all-MiniLM-L6-v2 appears to be trained on English data while the
# messages here are Russian; a multilingual model (e.g.
# "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") may give better
# dense recall - evaluate before switching.
embeddings = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L6-v2",
)

In [154]:
# Connect to the local Milvus server and (re)create the "Messages" collection.
# build_collection() drops any existing collection with this name first.
retriever = HybridRetriever(
    uri="http://localhost:19530",
    collection_name="Messages",
    dense_embedding_function=embeddings,
)
retriever.build_collection()

### Загружаем данные

In [156]:
# Sample corpus: three messages per chat, where chat_message_id is the message's
# position inside its chat. "Афиша" and "Концерты" intentionally share the same
# texts, so the chat_name filter in search can be checked.
sample_chats = [
    ("Погода", [
        'Сегодня холодно',
        'Завтра 10 градусов и сильный ветер',
        '10 декабря сильный ветер и -10 градусов',
    ]),
    ("Афиша", [
        'Сегодня концер Егора крида',
        '10 ноября стендап Абрамова ветер',
        '11 декабря стендап Усовича',
    ]),
    ("Концерты", [
        'Сегодня концер Егора крида',
        '10 ноября стендап Абрамова ветер',
        '11 декабря стендап Усовича',
    ]),
]

for chat_name, chat_messages in sample_chats:
    for chat_message_id, message in enumerate(chat_messages):
        retriever.insert_data({
            "message": message,
            "chat_name": chat_name,
            "chat_message_id": chat_message_id,
        })

### Пробный поиск

In [158]:
await retriever.search('Когда сильный ветер?', chat_names=["Погода"], k = 20, mode="hybrid", weights=[0.5, 0.5], k_rerank=10)

[{'message': 'Завтра 10 градусов и сильный ветер',
  'chat_name': 'Погода',
  'score': 0.1818181872367859},
 {'message': '10 декабря сильный ветер и -10 градусов',
  'chat_name': 'Погода',
  'score': 0.16025641560554504},
 {'message': 'Сегодня холодно',
  'chat_name': 'Погода',
  'score': 0.0833333358168602}]