In [2]:
from pymilvus import Collection, connections
from RAG_Functions import *
import time
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


## Load models

In [3]:
# Sentence-embedding model used by get_mixedbread_of_query to embed the user's query.
EMBEDDING_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Seq2seq chat model: tokenizer and weights are loaded from the same checkpoint.
CHAT_MODEL_NAME = "google/flan-t5-base"
chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
chat_model = AutoModelForSeq2SeqLM.from_pretrained(CHAT_MODEL_NAME)

## Connect to Milvus collection 

In [4]:
# Open the default Milvus connection and bind the collection that stores the text embeddings.
MILVUS_HOST, MILVUS_PORT = 'localhost', '19530'
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

collection_name = 'text_embeddings'
collection = Collection(name=collection_name)

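## Index comparison

Rebuild the `embedding` index with different index types (IVF_FLAT, FLAT, SCANN) and compare the retrieved sentences and search time for the same query.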

### IVF_FLAT

In [6]:
# Rebuild the vector index on the `embedding` field as IVF_FLAT (L2 metric, nlist=128)
# and load the collection so it can be searched.

# Milvus refuses to drop the index of a loaded collection, so release it first
# (this cell itself ends with load(), so re-running it would otherwise fail).
collection.release()
collection.drop_index(index_name='embedding_index')

index_params = {
    "metric_type": "L2",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128},
}
# The index name is a keyword argument of create_index. Placed inside index_params it is
# sent to the server as an ordinary index parameter, the index gets the default name,
# and drop_index(index_name='embedding_index') above would never find it.
collection.create_index(
    field_name="embedding",
    index_params=index_params,
    index_name='embedding_index',
)
collection.load()

In [7]:
# Run one retrieval query against the currently loaded index and time the search step.
# NOTE(review): input() blocks "Run All"; consider a fixed query constant for reproducible timings.
input_text = input()

# Embed the query (embedding time is deliberately excluded from the measurement below).
input_embedding = get_mixedbread_of_query(embedding_model, input_text)

# perf_counter is monotonic and high-resolution, unlike time.time(), so it is the
# right clock for measuring an elapsed interval.
start_time = time.perf_counter()

# NOTE(review): the saved output shows this returns a tuple (sentences, source files, a float)
# rather than only the top-5 sentences -- confirm what the float represents.
top5_sentences = return_top_5_sentences(collection, input_embedding)

end_time = time.perf_counter()

# The timing was previously captured but never reported.
query_seconds = end_time - start_time

print(top5_sentences)
print(f"Search time: {query_seconds:.4f} s")

(['Marvell: 5.', 'hinge: 15b.', 'BlaBlaCar: (xi) &nbsp.', 'InternetArchive: background: #333.\n}', 'InternetArchive: background: #333.\n}'], ['BlaBlaCar_TermsandConditions.txt', 'InternetArchive_Terms.txt', 'hinge_Terms.txt', 'Marvell_TermsofUse.txt'], 0.22635221481323242)


In [None]:
# Rebuild the vector index on the `embedding` field as FLAT (L2 metric) and load the
# collection so it can be searched.

# Milvus refuses to drop the index of a loaded collection (the IVF_FLAT cell loads it),
# so release it first.
collection.release()
collection.drop_index(index_name='embedding_index')

index_params = {
    "metric_type": "L2",
    "index_type": "FLAT",
}
# Pass the name as create_index's keyword argument; inside index_params it would be sent
# as an index parameter and the index would get the default name instead.
collection.create_index(
    field_name="embedding",
    index_params=index_params,
    index_name='embedding_index',
)
# Searches require a loaded collection; without this the next query would fail.
collection.load()

In [None]:
# Rebuild the vector index on the `embedding` field as SCANN (L2 metric, nlist=128) and
# load the collection so it can be searched.
# NOTE(review): SCANN is only available in newer Milvus server releases -- confirm the server version.

# Milvus refuses to drop the index of a loaded collection, so release it first.
collection.release()
collection.drop_index(index_name='embedding_index')

index_params = {
    "metric_type": "L2",
    "index_type": "SCANN",
    "params": {"nlist": 128},
}
# Pass the name as create_index's keyword argument; inside index_params it would be sent
# as an index parameter and the index would get the default name instead.
collection.create_index(
    field_name="embedding",
    index_params=index_params,
    index_name='embedding_index',
)
# Searches require a loaded collection; without this the next query would fail.
collection.load()