In [None]:
import os
from typing import Any, Dict, List, Optional

import pinecone
from cerebras.cloud.sdk import Cerebras
from groq import Groq
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseRetriever, Document
from langchain_community.vectorstores import Pinecone

# =======================
# Configuration
# =======================
# Service credentials come from the process environment; a variable that is
# not set yields None.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
CEREBRAS_API_KEY = os.environ.get("CEREBRAS_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# One entry per searchable index, keyed by the label used for that index.
# "name" is the Pinecone index name and "top_k" the per-query result count.
INDEX_CONFIG = {
    "legal-index": dict(name="legal-docs", top_k=2),
    "tech-index": dict(name="technical-docs", top_k=3),
}

# =======================
# Client Initialization
# =======================
# Shared service clients, created once when this cell/module runs.
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
groq_client = Groq(api_key=GROQ_API_KEY)
# No key is passed explicitly here.
# NOTE(review): presumably OpenAIEmbeddings reads OPENAI_API_KEY from the
# environment on its own — confirm.
embeddings = OpenAIEmbeddings()

# Stop early with a clear message when the Cerebras credential is missing.
if not CEREBRAS_API_KEY:
    raise EnvironmentError("CEREBRAS_API_KEY environment variable not set")

# Any constructor error propagates to the caller unchanged; wrapping this in
# a try/except that only re-raises would add nothing.
cerebras_client = Cerebras(api_key=CEREBRAS_API_KEY)

# =======================
# Multi-Index Retriever
# =======================
class MultiIndexRetriever(BaseRetriever):
    """Retriever that sends one query to several Pinecone indexes.

    Each entry of ``index_config`` becomes one vector-store retriever. A
    query is run against every retriever in configuration order and the
    results are concatenated, with each document tagged by the label of the
    index it came from.

    Args:
        index_config: Mapping of label -> settings dict. Each settings dict
            must provide "name" (the Pinecone index name) and "top_k" (the
            number of documents to fetch from that index).
    """

    # BaseRetriever is a pydantic model, so instance attributes have to be
    # declared as fields; assigning an undeclared attribute in __init__ fails.
    retrievers: Dict[str, Any]

    def __init__(self, index_config: Dict[str, Any]):
        built = {
            name: Pinecone(
                pc.Index(config["name"]),
                embeddings.embed_query,
                "text",
            ).as_retriever(search_kwargs={"k": config["top_k"]})
            for name, config in index_config.items()
        }
        # Hand the field to the pydantic constructor rather than setting it
        # afterwards, so the model is fully initialised in one step.
        super().__init__(retrievers=built)

    def _get_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """Return the documents matching ``query`` from every index.

        This is the hook that BaseRetriever's public entry points call.

        Args:
            query: The search text.
            run_manager: Callback manager supplied by the base class; unused.

        Returns:
            All retrieved documents, grouped by index in configuration
            order. Each document's metadata gains a "source_index" entry
            holding the label of the index that produced it.
        """
        combined_docs: List[Document] = []
        for name, retriever in self.retrievers.items():
            docs = retriever.invoke(query)
            for doc in docs:
                doc.metadata["source_index"] = name  # Add source identifier
            combined_docs.extend(docs)
        return combined_docs

# =======================
# Groq LLM Wrapper
# =======================
# class GroqLLMWrapper:
#     def __init__(self, model_name: str = "llama3-70b-8192"):
#         self.client = Groq(api_key=GROQ_API_KEY)
#         self.model = model_name

#     def __call__(self, prompt: str) -> str:
#         response = self.client.chat.completions.create(
#             messages=[{"role": "user", "content": prompt}],
#             model=self.model,
#             temperature=0.3
#         )
#         return response.choices[0].message.content

# =======================
# Cerebras LLM Wrapper
# =======================
class CerebrasLLMWrapper(LLM):
    """LangChain LLM backed by the Cerebras chat-completions API.

    Subclassing ``LLM`` makes the wrapper a real LangChain language model,
    which is what chain constructors such as
    ``ConversationalRetrievalChain.from_llm`` require; a plain callable
    object is not accepted there.

    Args:
        model_name: Cerebras model identifier to request.
        **kwargs: Extra field values forwarded to the ``LLM`` base class.

    Attributes:
        client: The Cerebras SDK client used for requests.
        model: The model identifier sent with every request.
        temperature: Sampling temperature sent with every request.
    """

    # Declared as fields because the LLM base class is a pydantic model.
    client: Any = None
    model: str = "llama3.1-8b"
    temperature: float = 0.3

    def __init__(self, model_name: str = "llama3.1-8b", **kwargs: Any):
        super().__init__(
            client=Cerebras(api_key=CEREBRAS_API_KEY),
            model=model_name,
            **kwargs,
        )

    @property
    def _llm_type(self) -> str:
        """Short identifier LangChain uses to label this model type."""
        return "cerebras"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The fully rendered prompt.
            stop: Optional stop sequences, forwarded only when provided.
            run_manager: Callback manager supplied by LangChain; unused.
            **kwargs: Additional call options from LangChain; unused.

        Returns:
            The content of the first completion choice.
        """
        request: Dict[str, Any] = {
            "messages": [{"role": "user", "content": prompt}],
            "model": self.model,
            "temperature": self.temperature,
        }
        if stop:
            request["stop"] = stop
        response = self.client.chat.completions.create(**request)
        return response.choices[0].message.content

# =======================
# RAG Pipeline Setup
# =======================
# Conversation history, stored under "chat_history" as message objects.
# The chain is asked to return source documents alongside the answer, so
# output_key names "answer" as the output the memory should use.
memory = ConversationBufferMemory(
    output_key="answer",
    return_messages=True,
    memory_key="chat_history",
)

# Single retriever covering every index listed in INDEX_CONFIG.
multi_retriever = MultiIndexRetriever(INDEX_CONFIG)

# Conversational RAG chain: Cerebras-backed model + multi-index retrieval +
# the memory above.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=CerebrasLLMWrapper(),
    memory=memory,
    retriever=multi_retriever,
    verbose=True,
    return_source_documents=True,
)

# =======================
# Chat Interface
# =======================
def format_response(result: Dict) -> str:
    """Render a chain result as the answer followed by its source indexes.

    Args:
        result: Mapping with an "answer" entry and, optionally, a
            "source_documents" list whose items expose a ``metadata`` dict.

    Returns:
        "Answer: <answer>", a blank line, then a "Sources:" header followed
        by one "- <source_index>" line per distinct index. Indexes are
        listed in sorted order so the output is stable between runs; a
        document without a "source_index" entry is listed as "unknown".

    Raises:
        KeyError: If ``result`` has no "answer" entry.
    """
    source_docs = result.get("source_documents") or []
    # De-duplicate with a set, then sort: set iteration order is arbitrary.
    sources = sorted(
        {str(doc.metadata.get("source_index", "unknown")) for doc in source_docs}
    )
    lines = [f"Answer: {result['answer']}", "", "Sources:"]
    lines.extend(f"- {source}" for source in sources)
    return "\n".join(lines)

# Interactive console loop: read a question, run the chain, print the
# formatted answer, until the user quits.
if __name__ == "__main__":
    print("Multi-Index RAG Chat (type '/exit' to quit)")
    while True:
        try:
            # Strip so that "/exit " or stray whitespace is handled cleanly.
            query = input("\nYou: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session the same way "/exit" does.
            print()
            break

        if query.lower() == "/exit":
            break
        if not query:
            # Nothing was asked; prompt again instead of querying the chain.
            continue

        result = qa_chain.invoke({"question": query})
        print(f"\nAssistant: {format_response(result)}")

In [None]:
# Start a session on the multi-RAG chat service, supplying a system prompt
!curl --request POST "http://localhost:8000/start-session" \
     --header "Content-Type: application/json" \
     --data '{
           "system_prompt": "You are an AI assistant that provides legal advice."
         }'

In [None]:
# Delete a session by its session_id
# The leading "!" runs curl through the shell, as in the other request cells;
# without it the cell is parsed as Python and fails.
!curl -X DELETE "http://localhost:8000/delete-session" \
     -H "Content-Type: application/json" \
     -d '{
           "session_id": "a1b2c3d4-e5f6-7890-abcd-1234567890ef"
         }'

In [None]:
# Send a question to the chat endpoint, selecting the vector stores to search
# and supplying a system prompt in the request body
!curl --request POST "http://localhost:8000/chat" \
     --header "Content-Type: application/json" \
     --data '{
           "session_id": "a1b2c3d4-e5f6-7890-abcd-1234567890ef",
           "vector_stores": ["legal-index", "tech-index"],
           "message": "What are the latest updates in data privacy laws?",
           "system_prompt": "Please provide concise and formal responses."
         }'

In [None]:
#Sample response from the AI
{
  "session_id": "a1b2c3d4-e5f6-7890-abcd-1234567890ef",
  "vector_stores": ["legal-index", "tech-index"],
  "input": "What are the latest updates in data privacy laws?",
  "response": "Answer: [AI-generated answer]\n\nSources:\n- legal-index\n- tech-index"
}

In [None]:
# Prompt the AI with a question without a system prompt
# The leading "!" runs curl through the shell, as in the other request cells;
# without it the cell is parsed as Python and fails.
!curl -X POST "http://localhost:8000/chat" \
     -H "Content-Type: application/json" \
     -d '{
           "session_id": "a1b2c3d4-e5f6-7890-abcd-1234567890ef",
           "vector_stores": ["legal-index", "tech-index"],
           "message": "How can I improve my productivity?"
         }'