# Playground for RAG


In [3]:
import sys
import os

# Add the parent directory of 'src' to sys.path
# "../" is resolved by os.path.abspath against the current working directory,
# so this depends on where the notebook/kernel was started from.
sys.path.append(os.path.abspath("../"))

In [4]:
from dotenv import load_dotenv
# Load environment variables (presumably from a local .env file) before the
# os.getenv(...) lookups further down read the API keys and connection strings.
load_dotenv()



True

In [5]:
from langchain_cohere import ChatCohere
from langchain_openai import OpenAI
from langchain_cohere import CohereEmbeddings
from langchain_community.embeddings import OpenAIEmbeddings
from langgraph.graph import END, StateGraph, MessagesState
from langchain import hub
from langchain_core.documents import Document
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from astrapy.constants import VectorMetric
from astrapy import DataAPIClient, Collection
import psycopg_pool
from psycopg_pool import AsyncConnectionPool
from contextlib import asynccontextmanager

# --- Cohere settings -------------------------------------------------------
# os.getenv returns None when the variable is not set; nothing here validates it.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# Chat model name handed to ChatCohere in get_llm().
MODEL = "command-r-plus-08-2024"
# Embedding model name handed to CohereEmbeddings in get_embedding_model().
EMBEDDING_MODEL = "embed-english-v3.0"

# --- Astra DB settings (consumed by similarity_search) ---------------------
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_COLLECTION = os.getenv("ASTRA_DB_COLLECTION")

# System prompt template used by generate(); the {docs_content} placeholder is
# filled with the content of the most recent tool messages via str.format.
PROMPT = """
        You are an AI assistant specialized in question-answering tasks.
        Your responses must be strictly based on the provided retrieved context. 
        If the context does not contain sufficient information to answer the question, respond with:
        "I apologize, but I don't have enough relevant information in my knowledge base to provide an accurate answer to your question. Please feel free to rephrase your question or ask about a different topic.".

        Do not include information or assumptions outside the provided context.
        Provide answers that are accurate, concise, and professional.
        Context for this task:
        {docs_content}
    """

# Shared asynchronous Postgres connection pool.
# The connection string comes from the environment (None when unset).
DATABASE_URL = os.getenv("POSTGRES_CONNECTION_STRING")
MAX_POOL_SIZE = 20
AUTOCOMMIT = True
PREPARE_THRESHOLD = 0

# The `kwargs` mapping is forwarded to every connection the pool creates.
# NOTE(review): the pool is opened implicitly by its constructor here; recent
# psycopg_pool releases appear to warn about this for async pools and
# recommend `open=False` plus an explicit `await pool.open()` — confirm
# against the installed version before changing it.
pool: AsyncConnectionPool = AsyncConnectionPool(
    conninfo=DATABASE_URL,
    max_size=MAX_POOL_SIZE,
    kwargs=dict(
        autocommit=AUTOCOMMIT,
        prepare_threshold=PREPARE_THRESHOLD,
    ),
)

# Pulled from the LangChain hub when this cell runs.
# NOTE(review): nothing in the visible code reads this module-level `prompt`
# (generate() binds its own local of the same name) — confirm whether it is
# still needed.
prompt = hub.pull("rlm/rag-prompt")


def get_llm():
    """Build the Cohere chat model configured by COHERE_API_KEY and MODEL."""
    settings = {
        "cohere_api_key": COHERE_API_KEY,
        "model": MODEL,
    }
    return ChatCohere(**settings)


# Single chat-model instance shared by the graph nodes below.
llm = get_llm()

# Embedding model initialisation
def get_embedding_model():
    """Build the Cohere embeddings client configured by COHERE_API_KEY and EMBEDDING_MODEL."""
    settings = {
        "cohere_api_key": COHERE_API_KEY,
        "model": EMBEDDING_MODEL,
    }
    return CohereEmbeddings(**settings)


# Single embeddings instance; retrieve() uses it to embed each query.
embedding_model = get_embedding_model()

def similarity_search(embedding, limit=5):
    """Run a vector similarity search against the configured Astra DB collection.

    Args:
        embedding: Query vector (a sequence of floats) to rank documents by.
        limit: Maximum number of documents to return.

    Returns:
        The result of ``collection.find(...)`` sorted by vector similarity to
        ``embedding``, with similarity scores requested via
        ``include_similarity=True``.
    """
    client = DataAPIClient()
    db = client.get_database(
        ASTRA_DB_API_ENDPOINT,
        token=ASTRA_DB_APPLICATION_TOKEN,
    )

    # get_collection() only builds a local handle and does not tell us whether
    # the collection exists (it never returns None), so existence is checked
    # explicitly instead of comparing the handle against None.
    if ASTRA_DB_COLLECTION in db.list_collection_names():
        collection: Collection = db.get_collection(ASTRA_DB_COLLECTION)
    else:
        # Size the vector field from the query embedding itself; a hard-coded
        # dimension would not match the vectors produced by the embedding model.
        collection = db.create_collection(
            ASTRA_DB_COLLECTION,
            dimension=len(embedding),
            metric=VectorMetric.COSINE,
        )

    return collection.find(
        {},
        sort={"$vector": embedding},
        limit=limit,
        include_similarity=True,
    )

@asynccontextmanager
async def get_db_connection():
    """
    Async context manager that checks a connection out of the module-level
    pool, yields it to the caller, and leaves the pool's own context manager
    to handle it when the caller's block exits.
    """
    checkout = pool.connection()
    async with checkout as connection:
        yield connection

@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """
    Perform a similarity search to retrieve relevant information.
    """
    # The docstring above is intentionally left as-is: the @tool decorator
    # exposes it to the model as the tool description, so it is runtime text.
    #
    # Returns a (serialized_text, documents) pair, matching the
    # "content_and_artifact" response format declared on the decorator.
    query_vector = embedding_model.embed_query(query)

    documents = []
    for hit in similarity_search(query_vector, limit=10):
        text = hit["text"]
        details = {
            "source_key": hit["source_key"],
            "source_label": hit["source_label"],
            "similarity": hit["$similarity"],
        }
        documents.append(Document(page_content=text, metadata=details))

    # One delimited text block per document, separated by blank lines.
    blocks = [
        f"RAG_SOURCE_METADATA: {doc.metadata}\nRAG_SOURCE_CONTENT: {doc.page_content}\nEND_RAG_SOURCE_CONTENT\n"
        for doc in documents
    ]
    return "\n\n".join(blocks), documents

# Query or Respond
def query_or_respond(state: MessagesState):
    """Invoke the tool-enabled model on the conversation so far.

    The model is bound to the `retrieve` tool, so its reply is either a direct
    answer or a tool call for retrieval. The reply is returned as a
    one-element "messages" update for the graph state.
    """
    tool_aware_llm = llm.bind_tools([retrieve])
    reply = tool_aware_llm.invoke(state["messages"])
    return {"messages": [reply]}


# Prebuilt graph node wrapping the `retrieve` tool.
tools = ToolNode([retrieve])

def generate(state: MessagesState):
    """Answer the user from the most recently retrieved context.

    The trailing run of tool messages in the state supplies the context that
    is substituted into PROMPT. The model is then invoked with that system
    message followed by the human/system messages and the AI messages that
    carry no tool calls. The reply is returned as a one-element "messages"
    update.
    """
    history = state["messages"]

    # Find where the trailing run of tool messages starts; slicing from there
    # keeps those messages in their original (chronological) order.
    start = len(history)
    while start > 0 and history[start - 1].type == "tool":
        start -= 1
    tool_messages = history[start:]

    # Build the system message from the retrieved content.
    retrieved_text = "\n\n".join(entry.content for entry in tool_messages)
    system_text = PROMPT.format(docs_content=retrieved_text)

    def _belongs_to_conversation(msg):
        # Keep human/system turns, and AI turns that are plain answers
        # (AI messages that only request tool calls are dropped).
        if msg.type in ("human", "system"):
            return True
        return msg.type == "ai" and not msg.tool_calls

    dialogue = [msg for msg in history if _belongs_to_conversation(msg)]
    model_input = [SystemMessage(system_text)] + dialogue

    answer = llm.invoke(model_input)
    return {"messages": [answer]}

async def ask_question(question: str, thread_id: str):
    """Stream the graph's output for one user question.

    Builds the query_or_respond -> tools -> generate graph, compiles it with a
    Postgres-backed checkpointer keyed by `thread_id`, and yields the
    `.content` of every message chunk produced by the "messages" stream mode.

    Args:
        question: The user's question, sent as a single "user" message.
        thread_id: Identifier placed under config["configurable"]["thread_id"].

    Yields:
        The content of each streamed message chunk.
    """
    builder = StateGraph(MessagesState)
    for node in (query_or_respond, tools, generate):
        builder.add_node(node)

    builder.set_entry_point("query_or_respond")
    # tools_condition decides between ending the run and running the tools.
    builder.add_conditional_edges(
        "query_or_respond",
        tools_condition,
        {END: END, "tools": "tools"},
    )
    builder.add_edge("tools", "generate")
    builder.add_edge("generate", END)

    run_config = {"configurable": {"thread_id": thread_id}}
    user_turn = {"messages": [{"role": "user", "content": question}]}

    async with get_db_connection() as connection:
        saver = AsyncPostgresSaver(connection)
        await saver.setup()
        compiled = builder.compile(checkpointer=saver)
        stream = compiled.astream(
            user_turn,
            stream_mode="messages",
            config=run_config,
        )
        async for chunk, _metadata in stream:
            yield chunk.content




In [None]:
# Stream the answer to a sample question on thread "0987", printing each
# chunk as soon as ask_question yields it.
# NOTE(review): a top-level `async for` only works where top-level await is
# enabled (e.g. a notebook cell) — wrap it in a coroutine before moving this
# into a plain script.
async for response in ask_question("Hey, what is sorting?", "0987"):
    print(response)




I
 will
 search
 for
 '
what
 is
 sorting
?'











RAG_SOURCE_METADATA: {'source_key': 'c85d9e04221f6198d1bf4825c24fd39310dd85550a4899b07c189476a7a7fdbf', 'source_label': 'https://www.geeksforgeeks.org/dsa-tutorial-learn-data-structures-and-algorithms/', 'similarity': 0.7891314}
RAG_SOURCE_CONTENT: in various applications such as databases, we 2 min read Sorting Algorithms A Sorting Algorithm is used to rearrange a given array or list of elements in an order. Sorting is provided in library implementation of most of the programming languages. Basics of Sorting Algorithms:Introduction to Sorting Applications of Sorting Sorting Algorithms:Comparison Based : Selection Sor 3 min read Recursive Algorithms Recursion is technique used in computer science to solve big problems by breaking them into smaller,
END_RAG_SOURCE_CONTENT


RAG_SOURCE_METADATA: {'source_key': 'c85d9e04221f6198d1bf4825c24fd39310dd85550a4899b07c189476a7a7fdbf', 'source_label': 'https://www.geeksforgeeks.org/dsa-tutor