In [None]:

# async def summarize_content_tool(content: List[Document]) -> str:
#     """
#     Description:
#     This function summarizes a list of documents using a state-based summarization pipeline. It extracts summaries from individual documents, merges them iteratively, and generates a final summary.
#
#     Parameters:
#     content (List[Document]): A list of Document objects to be summarized.
#
#     Returns:
#     A str containing the final summarized content.
#     """
#     async def generate_summary(state: SummaryState):
#         try:
#             response = await initialize_doc_parser_chain().ainvoke(state["content"])
#             return {"summaries": [response]}
#         except Exception as e:
#             logger.exception("Failed to generate summary")
#             raise e
#
#     async def collapse_summaries(state: OverallState):
#         try:
#             doc_lists = split_list_of_docs(
#                 state["collapsed_summaries"], length_function, config_settings.MAX_TOKENS
#             )
#             results = []
#             for doc_list in doc_lists:
#                 results.append(await acollapse_docs(doc_list, reduce_summary_chain().ainvoke))
#             return {"collapsed_summaries": results}
#         except Exception as e:
#             logger.exception("Failed to collapse summaries")
#             raise e
#
#     async def generate_final_summary(state: OverallState):
#         try:
#             response = await reduce_summary_chain().ainvoke(state["collapsed_summaries"])
#             return {"final_summary": response}
#         except Exception as e:
#             logger.exception("Failed to generate final summary")
#             raise e
#
#     try:
#         graph = StateGraph(OverallState)
#         graph.add_node("generate_summary", generate_summary)
#         graph.add_node("collect_summaries", collect_summaries)
#         graph.add_node("collapse_summaries", collapse_summaries)
#         graph.add_node("generate_final_summary", generate_final_summary)
#
#         graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
#         graph.add_edge("generate_summary", "collect_summaries")
#         graph.add_conditional_edges("collect_summaries", should_collapse)
#         graph.add_conditional_edges("collapse_summaries", should_collapse)
#         graph.add_edge("generate_final_summary", END)
#
#         app = graph.compile()
#
#         async for step in app.astream(
#                 {"contents": [doc.page_content for doc in content]},
#                 {"recursion_limit": 10},
#         ):
#             if "final_summary" in step:
#                 return step["final_summary"]
#
#         raise Exception("Failed to generate final summary")
#
#     except Exception as e:
#         logger.exception("Summarization tool failed")
#         raise e


In [None]:

# async def summarize_content_tool(content: list[Document]) -> str:
#     async def generate_summary(state: SummaryState):
#         response = await initialize_doc_parser_chain.ainvoke(state["content"])
#         return {"summaries": [response]}
#
#     async def collapse_summaries(state: OverallState):
#         doc_lists = split_list_of_docs(
#             state["collapsed_summaries"], length_function, config_settings.MAX_TOKENS
#         )
#         results = []
#         for doc_list in doc_lists:
#             results.append(await acollapse_docs(doc_list, reduce_summary_chain.ainvoke))
#
#         return {"collapsed_summaries": results}
#
#     async def generate_final_summary(state: OverallState):
#         response = await reduce_summary_chain.ainvoke(state["collapsed_summaries"])
#         return {"final_summary": response}
#
#     graph = StateGraph(OverallState)
#     graph.add_node("generate_summary", generate_summary)  # same as before
#     graph.add_node("collect_summaries", collect_summaries)
#     graph.add_node("collapse_summaries", collapse_summaries)
#     graph.add_node("generate_final_summary", generate_final_summary)
#
#     # Edges:
#     graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
#     graph.add_edge("generate_summary", "collect_summaries")
#     graph.add_conditional_edges("collect_summaries", should_collapse)
#     graph.add_conditional_edges("collapse_summaries", should_collapse)
#     graph.add_edge("generate_final_summary", END)
#
#     app = graph.compile()
#
#     async for step in app.astream(
#             {"contents": [doc.page_content for doc in content]},
#             {"recursion_limit": 10},
#     ):
#         print(list(step.keys()))
#

In [None]:
import pprint

from langgraph.prebuilt import create_react_agent
import asyncio
from typing import List
from langchain_core.documents import Document
from loguru import logger

from langgraph.checkpoint.memory import MemorySaver
from domains.utils import get_chat_model

from langchain_core.messages import HumanMessage

from langgraph.graph import END, START, StateGraph
from domains.agents.models import QueryRequest, OverallState, SummaryState
from domains.agents.tools import qna_tool, information_extraction_tool, summarize_content_tool
from domains.settings import config_settings


async def orchestrator_agent(query: str) -> str:
    """
    Orchestrator agent that uses qna_tool, information_extraction_tool, and summarize_content_tool.

    Documents are first retrieved with qna_tool; if it returns none, the
    information_extraction_tool is used instead. The retrieved documents are
    then summarized with summarize_content_tool.

    Args:
        query: The query string.

    Returns:
        A string containing the final summary.

    Raises:
        Exception: If a tool fails or the pipeline finishes without a final summary.
    """

    def _as_documents_update(result: object) -> dict:
        # LangGraph nodes must return a dict of state updates. A dict result is
        # passed through unchanged; anything else (e.g. a plain list of
        # documents) is stored under the "documents" key that the routing edge
        # and the summarizer read.
        if isinstance(result, dict):
            return result
        return {"documents": result}

    async def run_qna_tool(state: OverallState):
        try:
            request = QueryRequest(query=state["query"])
            documents = await qna_tool(request)
        except Exception:
            logger.exception("Failed to run qna_tool")
            raise
        return _as_documents_update(documents)

    async def run_information_extraction_tool(state: OverallState):
        try:
            documents = await information_extraction_tool(query=state["query"])
        except Exception:
            logger.exception("Failed to run information_extraction_tool")
            raise
        return _as_documents_update(documents)

    async def run_summarize_content_tool(state: OverallState):
        try:
            summary = await summarize_content_tool(state["documents"])
        except Exception:
            logger.exception("Failed to run summarize_content_tool")
            raise
        return {"final_summary": summary}

    def route_after_qna(state: OverallState) -> str:
        # Fall back to information extraction only when QnA found nothing.
        # .get() avoids a KeyError if no node has written "documents".
        if state.get("documents"):
            return "run_summarize_content_tool"
        return "run_information_extraction_tool"

    try:
        graph = StateGraph(OverallState)
        graph.add_node("run_qna_tool", run_qna_tool)
        graph.add_node("run_information_extraction_tool", run_information_extraction_tool)
        graph.add_node("run_summarize_content_tool", run_summarize_content_tool)

        graph.add_edge(START, "run_qna_tool")
        graph.add_conditional_edges(
            "run_qna_tool",
            route_after_qna,
            ["run_information_extraction_tool", "run_summarize_content_tool"],
        )
        graph.add_edge("run_information_extraction_tool", "run_summarize_content_tool")
        graph.add_edge("run_summarize_content_tool", END)

        app = graph.compile()

        # Seed the graph with the query (every retrieval node reads
        # state["query"]). ainvoke returns the final accumulated state, so
        # "final_summary" is a top-level key, unlike astream's default
        # per-node {node_name: update} chunks.
        final_state = await app.ainvoke({"query": query}, {"recursion_limit": 10})
        final_summary = final_state.get("final_summary") if final_state else None
        if final_summary is None:
            raise Exception("Failed to generate final summary")
        return final_summary
    except Exception:
        logger.exception("Orchestrator agent failed")
        raise


# Module-level checkpointer handed to every agent built in `new`; conversation
# state is saved here keyed by the thread_id passed in the run config.
memory: MemorySaver = MemorySaver()

async def new(query: str, id):
    """
    Run one conversational turn of a ReAct agent and print each streamed message.

    The agent is built with ``orchestrator_agent`` as its only tool and the
    module-level ``memory`` checkpointer; ``id`` is used as the LangGraph
    ``thread_id`` under which the conversation is checkpointed.

    Args:
        query: The user's message for this turn.
        id: Conversation / thread identifier.
    """
    thread_config = {"configurable": {"thread_id": id}}

    react_agent = create_react_agent(
        model=get_chat_model(model_key="OPENAI_CHAT"),
        tools=[orchestrator_agent],
        checkpointer=memory,
    )

    # "values" mode yields the full state after each step; show its newest message.
    initial_input = {"messages": [HumanMessage(content=query)]}
    stream = react_agent.astream(initial_input, thread_config, stream_mode="values")
    async for snapshot in stream:
        latest_message = snapshot["messages"][-1]
        latest_message.pretty_print()



In [None]:
# async def run_agents_2(query: str, id: str):
#     # Define custom state schema
#     class AgentState(TypedDict):
#         messages: Annotated[list[BaseMessage], add_messages]  # Chat messages
#         query: str  # Original query
#         result: Optional[str]  # Orchestrator result
#         is_last_step: IsLastStep
#         remaining_steps: RemainingSteps
#
#     # Create memory for state persistence
#     memory = MemorySaver()
#
#     # Create tool node for orchestrator
#     async def run_orchestrator(state: AgentState) -> AgentState:
#         """Execute orchestrator and update state with result"""
#         try:
#             result = await orchestrator_agent(state["query"])
#             # Handle case where result is a dictionary
#             if isinstance(result, dict) and "contents" in result:
#                 content = result["contents"][0] if isinstance(result["contents"], list) else str(result["contents"])
#             else:
#                 content = str(result)
#
#             return {
#                 "messages": [AIMessage(content=content)],
#                 "result": content
#             }
#
#         except Exception as e:
#             logger.exception("Orchestrator execution failed")
#             error_msg = f"Orchestrator failed: {str(e)}"
#             return {
#                 "messages": [AIMessage(content=error_msg)],
#                 "result": error_msg
#             }
#     # Create workflow graph
#     workflow = StateGraph(AgentState)
#
#     # Add nodes
#     workflow.add_node("orchestrator", run_orchestrator)
#
#     # Set entry point and flow
#     workflow.set_entry_point("orchestrator")
#     workflow.add_edge("orchestrator", END)
#
#     # Compile graph
#     agent_executor = workflow.compile(checkpointer=memory)
#
#     # Execute graph
#     config = {"configurable": {"thread_id": id}}
#     final_result = None
#
#     # Stream results
#     async for step in agent_executor.astream(
#             {
#                 "messages": [HumanMessage(content=query)],
#                 "query": query,
#                 "result": None
#             },
#             config
#     ):
#         if "result" in step:
#             final_result = step["result"]
#
#     logger.info(f"Agent result: {final_result}")
#     return final_result

In [None]:
# class StreamingLLMCallbackHandler(AsyncCallbackHandler):
#     """Callback handler for streaming LLM responses."""
#
#     def __init__(self, websocket_internal: WebSocket):
#         self.websocket = websocket_internal
#
#     async def on_chat_model_start(
#             self,
#             serialized: typing.Dict[str, typing.Any],
#             messages: typing.List[typing.List[BaseMessage]],
#             *,
#             run_id: uuid.UUID,
#             parent_run_id: uuid.UUID | None = None,
#             tags: typing.List[str] | None = None,
#             metadata: typing.Dict[str, typing.Any] | None = None,
#             **kwargs: typing.Any,
#     ) -> typing.Any:
#         logger.info(f"LLM chain chat model start with serialized: {serialized}\nmessages: {messages}\nkwargs: {kwargs}")
#
#     async def on_llm_start(self,
#                            serialized: typing.Dict[str, typing.Any],
#                            prompts: typing.List[str],
#                            *,
#                            run_id: uuid.UUID,
#                            parent_run_id: typing.Optional[uuid.UUID] = None,
#                            tags: typing.Optional[typing.List[str]] = None,
#                            metadata: typing.Optional[typing.Dict[str, typing.Any]] = None,
#                            **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain started with prompts: {prompts} and kwargs: {kwargs}')
#
#     async def on_llm_new_token(self, token: str, **kwargs: typing.Any) -> None:
#         timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
#         resp = {
#             "message": token,
#             "type": "stream",
#             "content_type": None
#         }
#         formatted_resp = f'{resp}\n{timestamp}'
#         await self.websocket.send_json(formatted_resp)
#
#     async def on_llm_end(self, response: LLMResult, *, run_id: uuid.UUID,
#                          parent_run_id: typing.Optional[uuid.UUID] = None,
#                          tags: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain ended with response: {response}')


# class StreamingLLMCallbackHandler(langchain.callbacks.base.AsyncCallbackHandler):
#     """Callback handler for streaming LLM responses."""
#
#     def __init__(self, websocket_internal):
#         self.websocket = websocket_internal
#
#     async def on_chat_model_start(
#             self,
#             serialized: typing.Dict[str, typing.Any],
#             messages: typing.List[typing.List[BaseMessage]],
#             *,
#             run_id: uuid.UUID,
#             parent_run_id: uuid.UUID | None = None,
#             tags: typing.List[str] | None = None,
#             metadata: typing.Dict[str, typing.Any] | None = None,
#             **kwargs: typing.Any,
#     ) -> typing.Any:
#         logger.info(f"LLM chain chat model start with serialized: {serialized}\nmessages: {messages}\nkwargs: {kwargs}")
#
#     async def on_llm_start(self,
#                            serialized: typing.Dict[str, typing.Any],
#                            prompts: typing.List[str],
#                            *,
#                            run_id: uuid.UUID,
#                            parent_run_id: typing.Optional[uuid.UUID] = None,
#                            tags: typing.Optional[typing.List[str]] = None,
#                            metadata: typing.Optional[typing.Dict[str, typing.Any]] = None,
#                            **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain started with prompts: {prompts} and kwargs: {kwargs}')
#
#     async def on_llm_new_token(self, token: str, **kwargs: typing.Any) -> None:
#         timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
#         resp = {
#             "message": token,
#             "type": "stream",
#             "content_type": None
#         }
#         formatted_resp = f'{resp}\n{timestamp}'
#         await self.websocket.send_json(formatted_resp)
#
#     async def on_llm_end(self, response: langchain.schema.output.LLMResult, *, run_id: uuid.UUID,
#                          parent_run_id: typing.Optional[uuid.UUID] = None,
#                          tags: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain ended with response: {response}')

# class StreamingLLMCallbackHandler(langchain.callbacks.base.AsyncCallbackHandler):
#     """Callback handler for streaming LLM responses."""
#
#     def __init__(self, websocket_internal):
#         self.websocket = websocket_internal
#
#     async def on_chat_model_start(
#             self,
#             serialized: typing.Dict[str, typing.Any],
#             messages: typing.List[typing.List[BaseMessage]],
#             *,
#             run_id: uuid.UUID,
#             parent_run_id: uuid.UUID | None = None,
#             tags: typing.List[str] | None = None,
#             metadata: typing.Dict[str, typing.Any] | None = None,
#             **kwargs: typing.Any,
#     ) -> typing.Any:
#         logger.info(f"LLM chain chat model start with serialized: {serialized}\nmessages: {messages}\nkwargs: {kwargs}")
#
#     async def on_llm_start(self,
#                            serialized: typing.Dict[str, typing.Any],
#                            prompts: typing.List[str],
#                            *,
#                            run_id: uuid.UUID,
#                            parent_run_id: typing.Optional[uuid.UUID] = None,
#                            tags: typing.Optional[typing.List[str]] = None,
#                            metadata: typing.Optional[typing.Dict[str, typing.Any]] = None,
#                            **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain started with prompts: {prompts} and kwargs: {kwargs}')
#
#     async def on_llm_new_token(self, token: str, **kwargs: typing.Any) -> None:
#         resp = ChatResponse(message=token, type="stream")
#         await self.websocket.send_json(resp.dict())
#
#     async def on_llm_end(self, response: langchain.schema.output.LLMResult, *, run_id: uuid.UUID,
#                          parent_run_id: typing.Optional[uuid.UUID] = None,
#                          tags: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any) -> None:
#         logger.info(f'LLM chain ended with response: {response}')

In [None]:
# import pprint
#
# import fastapi
# from pydantic import BaseModel
#
# from domains.retreival.rag_util import send_message_over_websocket
# from domains import retreival
# from domains.retreival.utils import transform_user_query_for_retreival
# from fastapi import WebSocket
# from langchain_core.documents import Document
#
# from domains.retreival.pinecone_doc_retreival.utils import get_related_docs_without_context
# from typing import Optional, List, Dict, Callable
# from domains.retreival.initialize_memory import initialise_memory_from_chat_context
# from domains.settings import config_settings
#
# from domains.retreival.utils import get_chat_model_with_streaming
# from langchain.chains import question_answering as q_a
#
# from loguru import logger
# from domains.retreival.models import RagUseCase, RAGGenerationResponse
# from domains.retreival.models import Message
# from domains.retreival.prompts import PROMPT_PREFIX_QNA, PROMPT_SUFFIX, initialise_doc_search_prompt_template
# from langchain.prompts import PromptTemplate
#
#
#
# def run_rag(
#     question: str,
#     language: str,
#     chat_context: Optional[List[Message]],
#     websocket: fastapi.WebSocket,
#     namespace: str,
# ):
#     DEFAULT_MIN_SCORE = 0.8
#     # initialise minimum score
#     minimum_score = (
#         config_settings.MINIMUM_SCORE
#         or DEFAULT_MIN_SCORE
#     )
#
#     # initialise prompt
#     prompt_qna_ask_question = initialise_doc_search_prompt_template(
#         PROMPT_PREFIX_QNA, PROMPT_SUFFIX
#     )
#
#     # initialise memory
#     memory = initialise_memory_from_chat_context(
#         chat_context
#     )
#
#     if not namespace or namespace == "":
#         namespace = config_settings.PINECONE_DEFAULT_DEV_NAMESPACE
#
#     return rag_with_streaming(
#         websocket=websocket,
#         language=language,
#         question=question,
#         minimum_score=minimum_score,
#         prompt_template_ask_question=prompt_qna_ask_question,
#         memory=memory,
#         namespace=namespace,
#     )
#
#
# async def rag_with_streaming(
#     websocket: fastapi.WebSocket,
#     question: str,
#     language: str,
#     minimum_score: float,
#     prompt_template_ask_question: PromptTemplate,
#     memory,
#     namespace: str,
#     use_case: RagUseCase = RagUseCase.DEFAULT,
#     citations_count=0,
# ):
#     try:
#         if not citations_count:
#             citations_count = config_settings.PINECONE_TOTAL_DOCS_TO_RETRIEVE
#
#         # citations_toggle = config_settings.CITATIONS_TOGGLE
#         index_name = config_settings.PINECONE_INDEX_NAME
#
#         await send_message_over_websocket(
#             websocket, "", retreival.MESSAGE_TYPE_START
#         )
#
#         await send_message_over_websocket(
#             websocket,
#             "",
#             retreival.MESSAGE_TYPE_START,
#             content_type=retreival.CONTENT_TYPE_OPTIMISED_QUESTION,
#         )
#
#         retreival_query = await transform_user_query_for_retreival(
#             question, "OPTIMIZED_QUESTION_MODEL"
#         )
#
#         logger.debug(f"Retreival query: {retreival_query}")
#
#         related_docs_with_score = []
#         if retreival_query or retreival_query != "None":
#             related_docs_with_score = await get_related_docs_without_context(
#                 index_name,
#                 namespace,
#                 retreival_query,
#             )
#
#             # logger.debug(
#             #     "\n".join(
#             #         [
#             #             f'File Name: {doc.metadata.get("file_name", "")}, Score: {score}, Page Number: {doc.metadata.get("page", 0)}'
#             #             for doc, score in related_docs_with_score
#             #         ]
#             #     )
#             # )
#         logger.info(f"Related docs without context: {related_docs_with_score}")
#         await send_message_over_websocket(
#             websocket,
#             "",
#             retreival.MESSAGE_TYPE_END,
#             content_type=retreival.CONTENT_TYPE_OPTIMISED_QUESTION,
#         )
#
#         route = RagUseCase.DEFAULT
#
#         rag_generation: RAGGenerationResponse = await generator_routing(
#             memory,
#             language,
#             retreival_query,
#             prompt_template_ask_question,
#             websocket,
#             route,
#             citations_count,
#             minimum_score,
#             related_docs_with_score,
#         )
#
#         await send_message_over_websocket(
#             websocket, "", retreival.MESSAGE_TYPE_END, content_type=retreival.CONTENT_TYPE_ANSWER
#         )
#
#
#         await send_message_over_websocket(
#             websocket, "", retreival.MESSAGE_TYPE_END
#         )
#
#     except Exception as e:
#         print(f"Error: {e}")
#         await send_message_over_websocket(
#             websocket, f"Error: {e}", retreival.MESSAGE_TYPE_ERROR
#         )
#         logger.error(f"Error: {e}")
#         return
#
#
# async def generator_routing(
#     memory,
#     language: str,
#     optimised_question: str,
#     prompt_template_ask_question: PromptTemplate,
#     websocket: WebSocket,
#     route: str,
#     citations_count: int,
#     minimum_score: float,
#     related_docs_with_score: list[tuple[Document, float]] = [],
# ) -> RAGGenerationResponse:
#
#     if route == RagUseCase.DEFAULT:
#         response = await run_doc_retrieval_flow(
#             memory,
#             optimised_question,
#             prompt_template_ask_question,
#             related_docs_with_score[:citations_count],
#             websocket,
#             minimum_score,
#             language,
#         )
#
#     return response
#
# from langchain_core.output_parsers import StrOutputParser
#
# async def run_doc_retrieval_flow(
#     memory,
#     optimised_question: str,
#     prompt_template_ask_question: PromptTemplate,
#     related_docs_with_score: list[tuple[Document, float]],
#     websocket: WebSocket,
#     minimum_score: float,
#     language: str,
# ) -> RAGGenerationResponse:
#
#     # document_count = len(
#     #     [doc for doc in related_docs_with_score if doc[1] > minimum_score]
#     # )
#
#     document_count = len(related_docs_with_score)
#
#     llm =get_chat_model_with_streaming(
#         websocket, model_key=config_settings.LLMS.get("OPENAI_CHAT")
#     )
#     if llm:
#         llm_chain = prompt_template_ask_question | llm | StrOutputParser()
#
#         response = await llm_chain.ainvoke(
#             {
#                 "question": optimised_question,
#                 "chat_history": memory.buffer_as_str,
#                 "doc_count": str(document_count),
#                 "context": related_docs_with_score,
#                 "language": language
#             },
#         )
#         logger.info(f"Response: {response}")
#
#         return RAGGenerationResponse(
#             answer=response
#         )