## RAG example with LangChain, Milvus, and vLLM

Requirements:
- A Milvus instance, either standalone or cluster.
- Connection credentials to Milvus must be available as environment variables: MILVUS_USERNAME and MILVUS_PASSWORD.
- A vLLM inference endpoint. In this example we use its OpenAI-compatible API.

### Needed packages and imports

In [None]:
!pip install -q einops==0.7.0 langchain==0.1.9 pymilvus==2.3.6 sentence-transformers==2.4.0 openai==1.13.3

In [None]:
import os
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks import StdOutCallbackHandler
from langchain_community.llms import VLLMOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Milvus

#### Base parameters, inference server and Milvus info

In [None]:
# Placeholder API key: stored in a module-level name (handed to the LLM client
# further down) and exported to the process environment under the same name.
# NOTE(review): presumably the vLLM endpoint does not validate the key - confirm.
OPENAI_API_KEY = "EMPTY"
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

In [None]:
# Replace values according to your inference server and Milvus deployment

### local / in-cluster
# INFERENCE_SERVER_URL = "http://vllm.rag-with-llama2-model-deployment.svc.cluster.local:8000/v1"

### ocp route
INFERENCE_SERVER_URL = "http://vllm:8000/v1"
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# Generation parameters forwarded to the LLM client.
MAX_TOKENS = 2048
TOP_P = 0.95
TEMPERATURE = 0.01
PRESENCE_PENALTY = 1.03

# Milvus connection settings.
MILVUS_HOST = "vectordb-milvus"
MILVUS_PORT = 19530
# Credentials are read from the environment, as stated in the Requirements
# section. The previous hard-coded values remain as fallbacks so the notebook
# still runs when the variables are not set.
MILVUS_USERNAME = os.environ.get("MILVUS_USERNAME", "root")
MILVUS_PASSWORD = os.environ.get("MILVUS_PASSWORD", "Milvus")
MILVUS_COLLECTION = "redhat_notes"

# Number of documents to retrieve per query. Kept as a string; it is converted
# with int() where the retriever is built.
RAG_NEAREST_NEIGHBOURS = "5"

#### Initialize the connection

In [None]:
# Embedding model wrapper handed to the vector store below.
# NOTE(review): trust_remote_code=True presumably lets the model repository's
# own Python code run when the model is loaded - only use with a trusted model.
model_kwargs = dict(trust_remote_code=True)
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs=model_kwargs,
    show_progress=False,
)

# Connection details for the Milvus instance, taken from the configuration cell.
milvus_connection = {
    "host": MILVUS_HOST,
    "port": MILVUS_PORT,
    "user": MILVUS_USERNAME,
    "password": MILVUS_PASSWORD,
}

# Vector store bound to the configured collection; drop_old=False is passed so
# the collection is not requested to be dropped.
store = Milvus(
    embedding_function=embeddings,
    connection_args=milvus_connection,
    collection_name=MILVUS_COLLECTION,
    metadata_field="metadata",
    text_field="page_content",
    drop_old=False,
)

#### Initialize query chain

In [None]:
# Basic prompt: system instructions followed by the retrieved {context} and the
# user's {question}. It has no chat-history placeholder.
# NOTE: the next cell assigns `template` again (a variant that also takes
# {history}), so this value is overwritten when the cells are run in order.
template="""<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant named SnemeisBot answering questions.
You will be given a question you need to answer, and a context about everything a sales team wrote down about its customers to provide you with information. You must answer the question based as much as possible on this context.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, don't share false information.
<</SYS>>

Context: 
{context}

Question: {question} [/INST]
"""

In [None]:
# Prompt with chat history. Placeholders: {context} (retrieved documents),
# {history} (previous turns) and {question} (the current user question).
# The instruction block opened by [INST] is now closed with [/INST] right after
# the question, matching the basic template; previously it was never closed.
# NOTE(review): the <<SYS>> markers follow the Llama-2 chat convention, while
# MODEL_NAME points at a Mistral instruct model - confirm they are wanted.
template = """<s>[INST] <<SYS>>
            You are a helpful, respectful and honest assistant named SnemeisBot. You are answering a question.
            You will be given a question you need to answer, and a context about everything a sales team wrote down about its customers to provide you with information. You must answer the question based as much as possible on this context.

            If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, don't share false information.
            <</SYS>>
            
            Use the following context (delimited by <ctx></ctx>) and the chat history (delimited by <hs></hs>) to answer the question in a helpful, respectful manner:
            ------
            <ctx>
            {context}
            </ctx>
            ------
            <hs>
            {history}
            </hs>
            ------
            {question} [/INST]
            Answer:
            """

# Set for the tokenizers library before any query is issued from this notebook.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Prompt object wrapping the template string; it takes the chat history, the
# retrieved context and the user question.
prompt_template = PromptTemplate(
    template=template,
    input_variables=["history", "context", "question"],
)

# Client for the OpenAI-compatible inference endpoint; streaming is enabled and
# a streaming stdout callback handler is attached.
llm = VLLMOpenAI(
    model_name=MODEL_NAME,
    openai_api_base=INFERENCE_SERVER_URL,
    openai_api_key=OPENAI_API_KEY,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_tokens=MAX_TOKENS,
    presence_penalty=PRESENCE_PENALTY,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
    verbose=False,
)

handler = StdOutCallbackHandler()

# Similarity search over the Milvus store, returning the configured number of
# neighbours (the setting is stored as a string, hence the int()).
retriever = store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": int(RAG_NEAREST_NEIGHBOURS)},
)

# Conversation memory: reads the "question" input and fills the prompt's
# "history" placeholder.
memory = ConversationBufferMemory(memory_key="history", input_key="question")

# Retrieval QA chain that also returns the source documents it used.
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=retriever,
    chain_type_kwargs={
        "prompt": prompt_template,
        "verbose": True,
        "memory": memory,
    },
    return_source_documents=True,
)

#### Query example

In [None]:
# Sample question; the chain is invoked with it under the "query" key.
question = "In that context - who is the customer?"
payload = {"query": question}
result = qa_chain.invoke(payload)

In [None]:
# Display the raw chain output; the cells below read its 'source_documents' entry.
result

#### Retrieve source

In [None]:
def remove_duplicates(input_list):
    """Keep only the first document seen for each ``metadata['source']``.

    Args:
        input_list: iterable of objects exposing ``metadata`` (a mapping with
            a ``'source'`` key) and ``page_content``.

    Returns:
        A ``(sources, contents)`` pair of parallel lists, in first-seen order:
        the distinct sources and the ``page_content`` of the first document
        carrying each source.
    """
    sources = []
    contents = []
    for doc in input_list:
        source = doc.metadata['source']
        if source in sources:
            # A document from this source was already recorded.
            continue
        sources.append(source)
        contents.append(doc.page_content)
    return sources, contents

# Deduplicate the documents returned by the chain, then print the two
# resulting lists (sources, then contents), one per line.
results = remove_duplicates(result['source_documents'])
print(*results, sep="\n")

In [None]:
def remove_duplicates(input_list):
    """Return the distinct ``metadata['source']`` values, in first-seen order.

    For the first document of each source, its ``page_content`` is printed.

    The membership test compares the document's *source* against the sources
    collected so far. It previously compared the document object itself
    against that list of sources, which never matched, so duplicates were
    neither skipped nor removed.

    Args:
        input_list: iterable of objects exposing ``metadata`` (a mapping with
            a ``'source'`` key) and ``page_content``.

    Returns:
        List of unique sources.
    """
    unique_list = []
    for item in input_list:
        source = item.metadata['source']
        if source not in unique_list:
            print(item.page_content)
            unique_list.append(source)
    return unique_list

# Collect the sources of the documents returned by the chain and print them,
# one per line (nothing is printed when there are none).
results = remove_duplicates(result['source_documents'])
if results:
    print(*results, sep="\n")