In [None]:
# TODO: replace the hard-coded config paths with a config object
# TODO: make the auth keys part of the config object
# TODO: add support for model selection
# TODO: work out how to add file filters

In [None]:
from apps.base.api_aggregation.loggers import Loggers
import asyncio
import json
from typing import AsyncGenerator, Dict, Any
from gpt_researcher import GPTResearcher
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from langchain_aws import BedrockEmbeddings
from apps.langflow.models import DeepResearchConversationMessages
from asgiref.sync import sync_to_async

class SSELogStreamer:
    """
    Queue-backed stand-in for the websocket object GPTResearcher reports to.

    An instance is handed to GPTResearcher as its ``websocket``. Every JSON
    payload GPTResearcher sends is buffered on ``queue`` in arrival order,
    so an async consumer can read the payloads back one at a time.
    """

    def __init__(self):
        # Unbounded FIFO of payloads waiting to be picked up by the consumer.
        self.queue: asyncio.Queue = asyncio.Queue()

    async def send_json(self, data: Dict[str, Any]) -> None:
        """
        Buffer one payload emitted by GPTResearcher during research
        (intermediate agent steps, log updates, etc.).
        """
        pending = self.queue
        await pending.put(data)

    async def get_next(self) -> Dict[str, Any]:
        """
        Wait until a payload is available, then return the oldest one.
        """
        payload = await self.queue.get()
        return payload

def get_researcher(
    mode: str,
    query: str,
    report_type: str,
    tone: str,
    websocket: SSELogStreamer
) -> GPTResearcher:
    match mode:
        case "web":
            researcher = GPTResearcher(
                query,
                report_type,
                "markdown",
                "web",
                tone,
                config_path='./configs/web-config.json',
                websocket=websocket
            )

        case "internal":
            client = QdrantClient(
                url="https://<domain>:<port>",
                api_key="<your-api-key>",
            )
            vector_store = QdrantVectorStore(
                client=client,
                collection_name="<collection_name>",
                embedding=BedrockEmbeddings(region_name="<region>", model_id="<BEDROCK_MODEL_ID>"),
            )
            researcher = GPTResearcher(
                query,
                report_type,
                "markdown",
                "langchain_vectorstore",
                tone,
                vector_store=vector_store,
                config_path="./configs/local-config.json",
                websocket=websocket
            )

        case "hybrid":
            client = QdrantClient(
                url="https://<domain>:<port>",
                api_key="<your-api-key>",
            )
            # TODO how to add file filters?
            vector_store = QdrantVectorStore(
                client=client,
                collection_name="<collection_name>",
                embedding=BedrockEmbeddings(region_name="<region>", model_id="<BEDROCK_MODEL_ID>"),
            )
            researcher = GPTResearcher(
                query,
                report_type,
                "markdown",
                "hybrid",
                tone,
                vector_store=vector_store,
                config_path="./configs/hybrid-config.json",
                websocket=websocket
            )
    
    return researcher

def _format_sse_event(event: str, data: Any) -> str:
    """Serialize ``{'event': event, 'data': data}`` as one SSE ``data:`` frame."""
    return f"data: {json.dumps({'event': event, 'data': data})}\n\n"


def _drain_logs(log_streamer: "SSELogStreamer", accumulated_logs: list[str]) -> list[str]:
    """Pop every message currently buffered in ``log_streamer`` without waiting.

    Messages whose ``type`` is ``"logs"`` also have their ``output`` text
    appended to ``accumulated_logs`` so it can be persisted with the report.
    Returns one SSE ``step`` frame per popped message, in queue order.
    """
    frames = []
    while not log_streamer.queue.empty():
        log = log_streamer.queue.get_nowait()
        if log.get("type") == "logs":
            accumulated_logs.append(log.get("output", ""))
        frames.append(_format_sse_event("step", log))
    return frames


async def run_gpt_researcher_streaming(
    conversation_id: int,
    mode: str,
    query: str,
    report_type: str,
    report_source: str,
    tone: str,
) -> AsyncGenerator[str, None]:
    """Run a GPTResearcher job and stream its progress as Server-Sent Events.

    Yields ``data:`` frames whose JSON payload is ``{"event": ..., "data": ...}``:
    a ``step`` event for every message GPTResearcher sends to its websocket,
    then a single ``final_report`` event carrying the report. After streaming
    the report, it is saved together with the collected ``"logs"`` output as
    an ``"AI"`` DeepResearchConversationMessages row for ``conversation_id``.

    Args:
        conversation_id: Conversation the saved message is attached to.
        mode: Research mode forwarded to ``get_researcher``.
        query: Research question forwarded to ``get_researcher``.
        report_type: Forwarded to ``get_researcher``.
        report_source: Currently unused; the source is derived from ``mode``.
        tone: Forwarded to ``get_researcher``.

    Raises:
        asyncio.CancelledError: Re-raised (after an ``error`` event is yielded)
            when the stream is cancelled.
        Any exception raised while researching, writing the report or saving
        it propagates to the consumer.
    """
    log_streamer = SSELogStreamer()
    accumulated_logs: list[str] = []

    researcher = get_researcher(mode, query, report_type, tone, log_streamer)

    # The phase currently running in the background; tracked so it can be
    # cancelled if the consumer stops iterating before it finishes.
    phase_task = None
    try:
        phase_result = None
        # Phase 1 gathers research, phase 2 writes the report. While each one
        # runs, poll the log queue and relay buffered messages as SSE frames.
        for start_phase in (researcher.conduct_research, researcher.write_report):
            phase_task = asyncio.create_task(start_phase())
            while True:
                for frame in _drain_logs(log_streamer, accumulated_logs):
                    yield frame
                if phase_task.done():
                    break
                await asyncio.sleep(0.1)
            # Re-raises any exception the phase hit.
            phase_result = await phase_task

        final_report = phase_result

        # Flush any messages emitted after the last poll.
        for frame in _drain_logs(log_streamer, accumulated_logs):
            yield frame

        yield _format_sse_event("final_report", final_report)

        Loggers.log_statement(message="Saving the logs and final report to DB")

        # Persist to DB using sync_to_async (the ORM call is synchronous).
        await sync_to_async(DeepResearchConversationMessages.objects.create)(
            sender="AI",
            message={"message": final_report, "logs": "\n".join(accumulated_logs)},
            conversation_id=conversation_id,
        )

    except asyncio.CancelledError:
        yield _format_sse_event("error", "Streaming interrupted")
        # Propagate the cancellation instead of silently swallowing it.
        raise

    finally:
        # On disconnect (aclose) or cancellation, stop the background work
        # rather than letting it keep running with nobody listening.
        if phase_task is not None and not phase_task.done():
            phase_task.cancel()