In [None]:
import os
import pickle
import re
import tqdm
import uuid

from FlagEmbedding import FlagModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pymilvus import (
    utility,
    CollectionSchema, DataType, FieldSchema, model,
    connections, Collection, AnnSearchRequest, RRFRanker,
)
from typing import List, Optional, Tuple

In [None]:
# Connection settings come from the environment so that no credentials are
# stored in the notebook itself. Either value is None when its variable is unset.
ENDPOINT = os.environ.get('ZILLIS_ENDPOINT')
TOKEN = os.environ.get('ZILLIS_TOKEN')

# No alias is passed here; create_collection() below refers to the connection
# as using='default'.
connections.connect(token=TOKEN, uri=ENDPOINT)

## Create Milvus collection
1. Drop the existing collection (if one exists)
2. Define the schema, i.e. how documents will be ingested
3. Create the collection with the schema defined in step 2

In [5]:
# Name of the Milvus collection used throughout this notebook.
collection_name = "odprt_index"


def drop_collection(collection_name):
    """Release and drop the named Milvus collection if it exists.

    Prints a status message in both cases and returns None.

    Args:
        collection_name: Name of the collection to remove.
    """
    # Guard clause: nothing to release or drop when the collection is absent.
    if not utility.has_collection(collection_name):
        print(f"Collection '{collection_name}' does not exist")
        return

    # Release the collection first, then drop it.
    existing = Collection(name=collection_name)
    existing.release()
    utility.drop_collection(collection_name)
    print(f"Collection '{collection_name}' has been dropped")

# drop_collection(collection_name)

In [6]:
# Primary key: INT64 with auto_id enabled, so no pk column is supplied on insert.
auto_id = FieldSchema(name="pk", dtype=DataType.INT64, auto_id=True, is_primary=True)

# Identifier of the document a row belongs to.
doc_id = FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=500)

# Where the document came from; "NA" is used when no value is provided.
doc_source = FieldSchema(
    name="doc_source", dtype=DataType.VARCHAR, default_value="NA", max_length=1000
)

# Raw text of the row; defaults to the empty string.
doc_content = FieldSchema(
    name="text", dtype=DataType.VARCHAR, default_value="", max_length=50000
)

# Dense vector column (1024 dimensions).
vec_embeddings = FieldSchema(name="dense_embeddings", dtype=DataType.FLOAT_VECTOR, dim=1024)

# Sparse vector column; no dimension is declared for it.
keyword_embeddings = FieldSchema(name="sparse_embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR)

In [8]:
# Field order matters: batch_ingestion() inserts column-wise as
# [doc_id, text, doc_source, dense, sparse], matching the non-pk fields below.
schema = CollectionSchema(
    description="odprt_schema",
    enable_dynamic_field=True,
    fields=[
        auto_id,
        doc_id,
        doc_content,
        doc_source,
        vec_embeddings,
        keyword_embeddings,
    ],
)

In [9]:
def create_collection(collection_name, schema):
    """Return a handle to the named collection, creating it when absent.

    Args:
        collection_name: Name of the Milvus collection.
        schema: CollectionSchema used only when the collection must be created.

    Returns:
        A Collection object. When the collection already exists a message is
        printed and `schema` is not applied.
    """
    if not utility.has_collection(collection_name):
        # New collection on the 'default' connection with two shards.
        return Collection(name=collection_name, schema=schema, using='default', shards_num=2)

    print(f"Collection '{collection_name}' already exists")
    return Collection(name=collection_name)

In [None]:
# Module-level collection handle: reuses the existing collection or creates it
# from `schema`. hybrid_search() reads this global.
collection = create_collection(collection_name, schema)

In [None]:
# Dense encoder. The schema declares dense_embeddings with dim=1024, which
# presumably matches this model's output size (TODO confirm).
bge_embed_model = FlagModel('BAAI/bge-large-en-v1.5')

# Sparse encoder, run on CPU.
splade_embed_model = model.sparse.SpladeEmbeddingFunction(
    device="cpu",
    model_name="naver/splade-cocondenser-ensembledistil",
)

## Preprocessing code here

In [None]:
# Placeholder: the preprocessing code is elided here.
# NOTE(review): a later cell passes `all_texts` to get_dense_and_sparse_embeddings,
# but nothing visible in this notebook defines it; presumably it is built here.
...

## Feeding texts (and tables) into respective Embedding Models
1. Dense Embeddings with BGE
2. Sparse Embeddings with SPLADE

In [None]:
def get_dense_and_sparse_embeddings(all_texts: List[str]):
    """Encode the given texts with both embedding models.

    Args:
        all_texts: Texts to embed.

    Returns:
        A 2-tuple ``(dense, sparse)``: the output of ``bge_embed_model.encode``
        followed by the output of ``splade_embed_model.encode_documents``,
        both called on the full list.
    """
    # Tuple elements are evaluated left to right: dense first, then sparse.
    return (
        bge_embed_model.encode(all_texts),
        splade_embed_model.encode_documents(all_texts),
    )

In [None]:
# Embed the whole corpus; `all_texts` is expected to exist already (it is not
# defined in any visible cell).
# NOTE(review): batch_ingestion() calls get_dense_and_sparse_embeddings itself,
# so this cell looks redundant with it; confirm before running both.
dense_embeddings_list, sparse_embeddings_list = get_dense_and_sparse_embeddings(all_texts)

## Batch ingestion

In [5]:
def batch_ingestion(collection, final_docs, batch_size: int = 100):
    """Embed the documents and insert them into `collection` in batches.

    Data is inserted column-wise in the order
    ``[ids, texts, sources, dense embeddings, sparse embeddings]``, which must
    match the order of the non-auto-id fields in the collection schema.

    Args:
        collection: Target collection; only its ``insert`` method is used.
        final_docs: Preprocessed documents. Not referenced by the visible code;
            presumably the elided extraction step below derives the id/text/source
            lists from it (TODO confirm).
        batch_size: Maximum number of rows per ``insert`` call. Defaults to 100.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    # The notebook does `import tqdm`, which binds the *module*; calling the
    # module raises TypeError. Import the progress-bar callable explicitly.
    from tqdm import tqdm as progress_bar

    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # NOTE(review): placeholder left by the author. Unpacking `...` fails at
    # runtime, so the real extraction from `final_docs` still has to be filled in.
    all_ids, all_texts, all_sources = ...
    dense_embeddings_list, sparse_embeddings_list = get_dense_and_sparse_embeddings(all_texts)

    data = [
        all_ids,
        all_texts,
        all_sources,
        dense_embeddings_list,
        sparse_embeddings_list
    ]
    total_elements = len(data[0])
    # Ceiling division so a final partial batch is counted by the progress bar.
    total_batches = (total_elements + batch_size - 1) // batch_size

    for start in progress_bar(range(0, total_elements, batch_size),
                              total=total_batches,
                              desc="Ingesting batches"):
        end = min(start + batch_size, total_elements)
        # Take the same row window from every column.
        batch = [column[start:end] for column in data]
        collection.insert(batch)

## Create index
1. Delete existing index
2. Create new index

In [None]:
def drop_indexes(collection: Collection, index_names: List[str]) -> None:
    """Release the collection, then drop each named index in order.

    Args:
        collection: Collection whose indexes are removed.
        index_names: Names passed one by one to ``collection.drop_index``.
    """
    # Release once up front, before any index is dropped.
    collection.release()

    for name in index_names:
        collection.drop_index(index_name=name)
        # One confirmation line per dropped index.
        print(f"Index '{name}' has been dropped")

In [None]:
# drop_indexes(collection, index_names=["sparse_embeddings_index", "dense_embeddings_index"])

In [None]:
def create_all_indexes(collection: Collection) -> None:
    """Create the dense and sparse vector indexes, then load the collection.

    The dense index is created first, then the sparse index, then
    ``collection.load()`` is called. A progress message is printed after each step.

    Args:
        collection: Collection to index and load.
    """
    # (field name, index name, index parameters, message printed once created)
    index_specs = (
        (
            "dense_embeddings",
            "dense_embeddings_index",
            {
                "metric_type": "COSINE",
                "index_type": "HNSW",
                "params": {"M": 5, "efConstruction": 512},
            },
            "Dense embeddings index created",
        ),
        (
            "sparse_embeddings",
            "sparse_embeddings_index",
            {
                "metric_type": "IP",
                "index_type": "SPARSE_INVERTED_INDEX",
                "params": {"drop_ratio_build": 0.2},
            },
            "Sparse embeddings index created",
        ),
    )

    for field_name, index_name, index_params, done_message in index_specs:
        collection.create_index(
            field_name=field_name,
            index_params=index_params,
            index_name=index_name,
        )
        print(done_message)

    # Load only after both indexes exist.
    collection.load()
    print("Collection loaded")

## Hybrid search
1. Load in collection
2. conduct hybrid search

In [3]:
def hybrid_search(query: str, top_k: int = 3) -> str:
    """Run a dense + sparse hybrid search and format the hits as context text.

    The query is embedded with both module-level models, searched against the
    ``dense_embeddings`` and ``sparse_embeddings`` fields of the module-level
    ``collection``, and the two result lists are fused with ``RRFRanker``.

    Args:
        query: Natural-language query string.
        top_k: Number of candidates requested from each vector field and the
            number of fused results returned. Defaults to 3.

    Returns:
        One ``"Source: ...\\nContext: ..."`` entry per hit, joined by blank
        lines. The string is empty when there are no hits.

    Raises:
        ValueError: If ``top_k`` is not a positive integer.
    """
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")

    dense_embedding = list(bge_embed_model.encode_queries([query], normalize_embeddings=True)[0])
    sparse_embedding = list(splade_embed_model.encode_queries([query]))

    # "M" / "efConstruction" are HNSW *build* parameters; the search-time knob
    # is "ef", which has to be at least the requested limit.
    dense_request = AnnSearchRequest(
        data=[dense_embedding],  # content vector embedding
        anns_field='dense_embeddings',  # content vector field
        param={"metric_type": "COSINE", "params": {"ef": max(64, top_k)}},
        limit=top_k,
    )
    # "drop_ratio_build" only applies when the index is built; the search-time
    # counterpart is "drop_ratio_search".
    sparse_request = AnnSearchRequest(
        data=sparse_embedding,  # keyword vector embedding (already a list)
        anns_field='sparse_embeddings',  # keyword vector field
        param={"metric_type": "IP", "params": {"drop_ratio_search": 0.2}},
        limit=top_k,
    )

    search_results = collection.hybrid_search(
        reqs=[dense_request, sparse_request],
        output_fields=['doc_id', 'text', 'doc_source'],
        # Reciprocal-rank fusion of the dense and sparse result lists.
        rerank=RRFRanker(),
        limit=top_k,
    )

    # A single query was submitted, so only the first result set is relevant.
    hits = search_results[0]

    context = []
    for hit in hits:
        # Read output fields through the documented entity accessor.
        text = hit.entity.get('text')
        source = hit.entity.get('doc_source')
        context.append(f"Source: {source}\nContext: {text}")

    return "\n\n".join(context)