In [None]:
from utils_openai import (
    setup_openai_api, create_embeddings, create_llm,
    load_msme_data, create_vectorstore, get_baseline_prompt
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.load import dumps, loads

# Reaching this line means every import above resolved without raising.
print("[OK] Imports successful!")

In [None]:
# Tunable settings for this experiment (kept together so they are easy to find).
DATA_PATH = "msme.csv"
COLLECTION_NAME = "msme_technique2"
PERSIST_DIR = "./chroma_db_technique2"  # on-disk location handed to create_vectorstore
TOP_K = 3  # documents per retrieval call; the multi-query chain retrieves once per generated query

# Client objects built from the API key returned by setup_openai_api().
api_key = setup_openai_api()
embeddings = create_embeddings(api_key)
llm = create_llm(api_key)

documents, metadatas, ids = load_msme_data(DATA_PATH)
vectorstore = create_vectorstore(
    documents,
    metadatas,
    ids,
    embeddings,
    collection_name=COLLECTION_NAME,
    persist_directory=PERSIST_DIR,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
print("[OK] Setup complete!")

In [None]:
# Prompt asking the LLM for alternative phrasings of the user's question.
# {question} is the only template variable. The output is expected to be one
# question per line, which the query-generator parser relies on.
query_gen_template = (
    "You are an AI assistant helping to improve search results.\n"
    "Your task is to generate 4 different versions of the given user question.\n"
    "\n"
    "These variations should:\n"
    "- Rephrase using different words\n"
    "- Use different levels of specificity\n"
    "- Include relevant synonyms\n"
    "- Maintain the original intent\n"
    "\n"
    "Provide ONLY the questions, one per line, without numbering or explanation.\n"
    "\n"
    "Original question: {question}\n"
    "\n"
    "Alternative questions:"
)

query_gen_prompt = ChatPromptTemplate.from_template(query_gen_template)
print("[OK] Query generation prompt ready!")

In [None]:
# Chain that turns one user question into a list of alternative search queries.
def _parse_query_lines(text):
    """Split the LLM reply into one clean search query per non-blank line.

    The prompt asks for unnumbered lines, but models sometimes still prefix
    each line with "1." / "2)" / "-" / "*" / a bullet. Those markers would be
    sent to the retriever as part of the query, so they are stripped here.

    Args:
        text: Raw string output of the query-generation LLM call.

    Returns:
        List of non-empty query strings, in the order the LLM produced them.
    """
    import re

    # Enumerator ("3." / "3)") or bullet, followed by whitespace or end of line.
    # Requiring whitespace keeps questions that start with e.g. "3.5 crore ..." intact.
    marker = re.compile(r"^\s*(?:\d+[.)]|[-*\u2022])(?:\s+|$)")

    queries = []
    for line in text.splitlines():  # splitlines also handles "\r\n"
        cleaned = marker.sub("", line).strip()
        if cleaned:
            queries.append(cleaned)
    return queries


query_generator = (
    query_gen_prompt
    | llm
    | StrOutputParser()
    | _parse_query_lines
)

print("[OK] Query generator chain created!")

In [None]:
def get_unique_docs(documents, key=None):
    """Flatten per-query retrieval results and drop duplicate documents.

    `retriever.map()` returns one list of documents per generated query (a list
    of lists). Each inner list is flattened and every document is deduplicated
    individually, keeping the first occurrence so the output order is stable.

    Args:
        documents: Either a list of document lists (as produced by
            `retriever.map()`) or an already-flat list of documents.
        key: Optional callable mapping a document to a hashable identity.
            Defaults to `dumps`, i.e. two documents are duplicates when their
            serialized form (content and metadata) is identical.

    Returns:
        A flat list of the unique input documents, in first-seen order.
    """
    make_key = key if key is not None else dumps
    seen = set()
    unique = []
    for item in documents:
        # Accept both nested (per-query) and flat inputs.
        group = item if isinstance(item, (list, tuple)) else (item,)
        for doc in group:
            doc_key = make_key(doc)
            if doc_key not in seen:
                seen.add(doc_key)
                unique.append(doc)  # keep the original object; no loads() round-trip
    return unique

print("[OK] Deduplication function ready!")

In [None]:
# Multi-query RAG:
#   question -> LLM rephrasings -> retrieval per rephrasing
#   -> duplicate removal -> answer using the shared baseline prompt.
prompt = get_baseline_prompt()

# retriever.map() applies the retriever to every generated query;
# get_unique_docs then removes duplicate results.
multi_query_retrieval = query_generator | retriever.map() | get_unique_docs

# The incoming user input feeds the retrieval branch and is also passed
# through unchanged as the prompt's "question".
multi_query_rag_chain = (
    {
        "context": multi_query_retrieval,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

print("[OK] Multi-query RAG chain ready!")

In [None]:
# Baseline RAG (single query): the user input goes straight to the retriever.
# It reuses the same prompt, LLM and parser as the multi-query chain, so the
# two chains differ only in how the context is retrieved.
baseline_rag_chain = (
    {
        "context": retriever,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

print("[OK] Baseline RAG chain ready!")


In [None]:
# Short, loosely worded queries used to compare the baseline and multi-query chains.
test_queries = ["business money", "company rules", "get funding"]


In [None]:
# Side-by-side comparison: each test query is answered by both chains.
# Labels use ASCII markers (like the "[OK]" messages) instead of the
# mis-encoded emoji the cell previously printed.
for query in test_queries:
    print("=" * 80)
    print(f"USER QUERY: {query}\n")

    print("[*] BASELINE RAG RESPONSE:")
    baseline_response = baseline_rag_chain.invoke(query)
    print(baseline_response)

    print("\n[*] MULTI-QUERY RAG RESPONSE:")
    multi_query_response = multi_query_rag_chain.invoke(query)
    print(multi_query_response)


In [None]:
# Show the rephrasings the query generator produces for each test query.
# NOTE: this is a fresh LLM call, so unless the model is deterministic these
# rephrasings may differ from the ones used by the comparison loop above.
for original_query in test_queries:
    print("=" * 80)
    print(f"ORIGINAL QUERY: {original_query}")
    rephrasings = query_generator.invoke({"question": original_query})

    print("Generated query variations:")
    for rephrasing in rephrasings:
        print(f"- {rephrasing}")
