diff --git a/apex/services/deep_research/deep_research_base.py b/apex/services/deep_research/deep_research_base.py index 498fe9e36..0d0a32db7 100644 --- a/apex/services/deep_research/deep_research_base.py +++ b/apex/services/deep_research/deep_research_base.py @@ -6,5 +6,5 @@ class DeepResearchBase(LLMBase): async def invoke( self, messages: list[dict[str, str]], body: dict[str, Any] | None = None - ) -> tuple[str, list[dict[str, str]]]: + ) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]: raise NotImplementedError diff --git a/apex/services/deep_research/deep_research_langchain.py b/apex/services/deep_research/deep_research_langchain.py index 62f23b8ce..e4a98842b 100644 --- a/apex/services/deep_research/deep_research_langchain.py +++ b/apex/services/deep_research/deep_research_langchain.py @@ -1,5 +1,7 @@ +import asyncio from typing import Any +import tenacity from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import LLMChainFilter @@ -9,9 +11,11 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnableSerializable +from langchain_experimental.utilities import PythonREPL from langchain_openai import ChatOpenAI from langchain_text_splitters import RecursiveCharacterTextSplitter from loguru import logger +from openai import RateLimitError from apex.common.config import Config from apex.services.deep_research.deep_research_base import DeepResearchBase @@ -53,6 +57,7 @@ def __init__( openai_api_base=summary_base_url if summary_base_url is not None else base_url, max_retries=3, temperature=0.01, + max_tokens=800, ) self.research_model = ChatOpenAI( model_name=research_model, @@ -60,6 +65,7 @@ def __init__( openai_api_base=research_base_url if research_base_url is not None else base_url, max_retries=3, temperature=0.01, + max_tokens=1200, ) self.compression_model = ChatOpenAI( model_name=compression_model, @@ -67,6 +73,7 @@ def __init__( openai_api_base=compression_base_url if compression_base_url is not None else base_url, max_retries=3, temperature=0.01, + max_tokens=600, ) self.final_model = ChatOpenAI( model_name=final_model, @@ -74,7 +81,11 @@ def __init__( openai_api_base=final_base_url if final_base_url is not None else base_url, max_retries=3, temperature=0.01, + max_tokens=2000, ) + # Caution: PythonREPL can execute arbitrary code on the host machine. + # Use with caution and consider sandboxing for untrusted inputs. + self.python_repl = PythonREPL() async def _create_vector_store(self, documents: list[Document]) -> BaseRetriever: text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64) @@ -107,6 +118,15 @@ def _create_research_chain(self) -> RunnableSerializable[dict[str, Any], str]: prompt = PromptTemplate( input_variables=["context", "question"], template="""Generate a comprehensive research report based on the provided context. +The report should be long-form (800-1200 words) and include the sections: +- Executive Summary +- Key Findings +- Evidence (quote or paraphrase context with attributions) +- Limitations +- Conclusion + +Explain reasoning explicitly in prose. Prefer depth over breadth. 
+ Context: {context} Question: {question} Research Report: """, ) return prompt | self.research_model | StrOutputParser() async def invoke( self, messages: list[dict[str, str]], body: dict[str, Any] | None = None - ) -> tuple[str, list[dict[str, str]]]: # type: ignore[override] - # Clear tool history for each new invocation + ) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]: # type: ignore[override] + # Agentic, iterative deep research with websearch and python_repl tools. self.tool_history = [] - question = messages[-1]["content"] - documents: list[Document] = [] - if body and "documents" in body and body["documents"]: - documents = [ - Document(page_content=doc["page_content"]) - for doc in body["documents"] - if doc and "page_content" in doc and doc["page_content"] is not None - ] - - if not documents: - # Track websearch in tool history - self.tool_history.append({"tool": "websearch", "args": question}) - websites = await self.websearch.search(query=question, max_results=5) - for website in websites: - if website.content: - documents.append(Document(page_content=str(website.content), metadata={"url": website.url})) - - if not documents: - return "Could not find any information on the topic.", self.tool_history - - retriever = await self._create_vector_store(documents) - if not retriever: - return "Could not create a vector store from the documents.", self.tool_history - - compression_retriever = self._create_compression_retriever(retriever) - - summary_chain = self._create_summary_chain() - research_chain = self._create_research_chain() - - compressed_docs: list[Document] = await compression_retriever.ainvoke(question) + reasoning_traces: list[dict[str, Any]] = [] - summary: str = await summary_chain.ainvoke({"context": compressed_docs, "question": question}) - - research_report: str = await research_chain.ainvoke({"context": compressed_docs, "question": question}) + question = messages[-1]["content"] + # Seed notes with any provided documents + notes: list[str] = [] + if body is not None and body.get("documents"): + for doc in body["documents"]: + if doc and doc.get("page_content") is not None: + content = str(doc["page_content"])[:1000] + notes.append(f"Provided document snippet: {content}") + + # Iterative loop using research model to choose actions + max_iterations = 20 + step_index = 0 + + # Track discovered sources from websearch for citations + collected_sources: list[dict[str, str]] = [] + seen_urls: set[str] = set() + + agent_chain = self._build_agent_chain() + + while step_index < max_iterations: + logger.debug(f"Deep research step {step_index + 1}/{max_iterations}") + step_index += 1 + agent_output: str = await self._try_invoke( + agent_chain, + { + "question": question, + "notes": self._render_notes(notes=notes), + "sources": self._render_sources(collected_sources=collected_sources), + }, + ) + parsed = self._safe_parse_json(agent_output) + if parsed is None: + reasoning_traces.append( + { + "step": f"iteration-{step_index}", + "model": getattr(self.research_model, "model_name", "unknown"), + "output": agent_output, + "error": "Failed to parse JSON from agent output", + } + ) + # Add a note to steer next iteration toward valid JSON + notes.append("Agent output was not valid JSON. Please respond with valid JSON per schema.") + continue
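+ # Each parsed step carries a 'thought'; the branches below handle final_answer, websearch, python_repl, or an unsupported action.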
Please respond with valid JSON per schema.") + continue + + thought = str(parsed.get("thought", "")) + + # Final answer branch + if "final_answer" in parsed: + logger.debug("Early-stopping deep research due to the final answer") + final_answer = str(parsed.get("final_answer", "")) + reasoning_traces.append( + { + "step": f"iteration-{step_index}", + "model": getattr(self.research_model, "model_name", "unknown"), + "thought": thought, + "final_answer": final_answer, + } + ) + return final_answer, self.tool_history, reasoning_traces + + # Action branch (only websearch supported) + action = parsed.get("action") or {} + if action.get("tool") == "websearch": + action_input = action.get("input") or {} + query = str(action_input.get("query", question)) + try: + max_results = int(action_input.get("max_results", 5)) + except Exception: + max_results = 5 + max_results = max(1, min(10, max_results)) + + self.tool_history.append({"tool": "websearch", "args": query}) + websites = await self.websearch.search(query=query, max_results=max_results) + + observations: list[str] = [] + for idx, website in enumerate(websites[:max_results]): + if website.content: + snippet = str(website.content)[:1000] + observations.append( + f"Result {idx + 1}: {website.title or website.url or 'untitled'}\n{snippet}" + ) + # Track source metadata for citations + url = getattr(website, "url", "") or "" + if url and url not in seen_urls: + seen_urls.add(url) + collected_sources.append( + { + "url": url, + "title": website.title or "", + } + ) + + observation_text = "\n\n".join(observations) if observations else "No results returned." + notes.append(f"Thought: {thought}") + notes.append(f'Observation from websearch (q="{query}"):\n{observation_text}') + reasoning_traces.append( + { + "step": f"iteration-{step_index}", + "model": getattr(self.research_model, "model_name", "unknown"), + "thought": thought, + "action": {"tool": "websearch", "query": query, "max_results": max_results}, + "observation": observation_text[:1000], + } + ) + continue + + if action.get("tool") == "python_repl": + logger.debug(f"Applying tool: {action.get('tool')}") + action_input = action.get("input") or {} + code = str(action_input.get("code", "")).strip() + logger.debug(f"Code to be executed:\n{code}") + # Record the tool use (truncate long code for history) + self.tool_history.append({"tool": "python_repl", "args": code[:200]}) + + if not code: + observation_text = "python_repl received empty code." 
+ else: + try: + # PythonREPL returns only printed output (may include trailing newline) + repl_output = self.python_repl.run(code) + observation_text = repl_output if repl_output else "(no output)" + logger.debug(f"Code execution result:\n{observation_text}") + except Exception as e: # noqa: BLE001 + observation_text = f"Error while executing code: {e}" + + notes.append(f"Thought: {thought}") + logger.debug(f"Thought: {thought}") + notes.append(f"Observation from python_repl:\n{observation_text}") + reasoning_traces.append( + { + "step": f"iteration-{step_index}", + "model": getattr(self.research_model, "model_name", "unknown"), + "thought": thought, + "action": {"tool": "python_repl", "code": code[:1000]}, + "observation": observation_text[:1000], + } + ) + continue + + # Unknown action or schema + reasoning_traces.append( + { + "step": f"iteration-{step_index}", + "model": getattr(self.research_model, "model_name", "unknown"), + "thought": thought, + "error": f"Unsupported action or schema: {action}", + } + ) + notes.append("Agent returned an unsupported action. Use one of the supported tools (websearch, python_repl) or provide final_answer.") + + # Fallback: if loop ends without final answer, ask final model to synthesize from notes + logger.debug("Generating final answer") final_prompt = PromptTemplate( - input_variables=["summary", "research_report", "question"], - template="""Based on the following summary and research report, provide a final answer to the question. -Summary: {summary} -Research Report: {research_report} -Question: {question} -Final Answer: -""", + input_variables=["question", "notes", "sources"], + template=( + "You are a senior interdisciplinary researcher with expertise across " + "science, technology, humanities, and social sciences.\n" + "Provide the report in plain text only, using natural language.\n" + "Write your response in the form of a well-structured research report with sections:\n" + "Executive Summary, Key Findings, Evidence, Limitations, Conclusion.\n" + "Use inline numeric citations like [1], [2] that refer to Sources.\n" + "At the end, include a 'Sources' section listing the numbered citations.\n\n" + "Do NOT use JSON or any other structured data format.\n" + "Question:\n{question}\n\n" + "Notes:\n{notes}\n\n" + "Sources:\n{sources}\n\n" + "Research Report:" + ), ) final_chain = final_prompt | self.final_model | StrOutputParser() - final_answer: str = await final_chain.ainvoke( - {"summary": summary, "research_report": research_report, "question": question} + final_report: str = await self._try_invoke( + final_chain, + { + "question": question, + "notes": self._render_notes(notes=notes, max_items=12), + "sources": self._render_sources(collected_sources=collected_sources, max_items=20), + }, + ) + reasoning_traces.append( + { + "step": "final-fallback", + "model": getattr(self.final_model, "model_name", "unknown"), + "output": final_report, + } ) - return final_answer, self.tool_history + return final_report, self.tool_history, reasoning_traces + + def _render_sources(self, collected_sources: list[dict[str, str]], max_items: int = 12) -> str: + if not collected_sources: + return "(none)" + lines: list[str] = [] + for i, src in enumerate(collected_sources[:max_items], start=1): + title = src.get("title") or "untitled" + url = src.get("url") or "" + lines.append(f"[{i}] {title} - {url}") + return "\n".join(lines) + + def _render_notes(self, notes: list[str], max_items: int = 8) -> str: + if not notes: + return "(none yet)" + clipped = notes[-max_items:] + return "\n".join(f"- {item}" for item in clipped) + + def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]:
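+ """Build the strict-JSON think-act-observe agent chain (prompt | research_model | StrOutputParser)."""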
+ prompt = PromptTemplate( + input_variables=["question", "notes", "sources"], + template=( + "You are DeepResearcher, a meticulous, tool-using research agent.\n" + "You can use exactly these tools: websearch, python_repl.\n\n" + "Tool: websearch\n" + "- description: Search the web for relevant information.\n" + "- args: keys: 'query' (string), 'max_results' (integer <= 10)\n\n" + "Tool: python_repl\n" + "- description: A Python shell for executing Python commands.\n" + "- note: Print values to see output, e.g., `print(...)`.\n" + "- args: keys: 'code' (string: valid Python code).\n\n" + "Follow an iterative think-act-observe loop. " + "Prefer rich internal reasoning over issuing many tool calls.\n" + "Spend time thinking: produce substantial, explicit reasoning in each 'thought'.\n" + "Avoid giving a final answer too early. Aim for at least 6 detailed thoughts before finalizing,\n" + "unless the question is truly trivial. " + "If no tool use is needed in a step, still provide a reflective 'thought'\n" + "that evaluates evidence, identifies gaps, and plans the next step.\n\n" + "Always respond in strict JSON. Use one of the two schemas:\n\n" + "1) Action step (JSON keys shown with dot-paths):\n" + "- thought: string\n" + "- action.tool: 'websearch' | 'python_repl'\n" + "- action.input: for websearch -> {{query: string, max_results: integer}}\n" + "- action.input: for python_repl -> {{code: string}}\n\n" + "2) Final answer step:\n" + "- thought: string\n" + "- final_answer: string (use plain text for the final answer, not JSON)\n\n" + "In every step, make 'thought' a detailed paragraph (120-200 words) that:\n" + "- Summarizes what is known and unknown so far\n" + "- Justifies the chosen next action or decision not to act\n" + "- Evaluates evidence quality and cites source numbers when applicable\n" + "- Identifies risks, uncertainties, and alternative hypotheses\n\n" + "Structure final_answer as a research report with sections:\n" + "Executive Summary, Key Findings, Evidence, Limitations, Conclusion.\n" + "Use inline numeric citations like [1], [2] that refer to Sources.\n" + "Include a final section titled 'Sources' listing the numbered citations.\n\n" + "Question:\n{question}\n\n" + "Notes and observations so far:\n{notes}\n\n" + "Sources (use these for citations):\n{sources}\n\n" + "Always respond with JSON; only final_answer itself is plain text." + ), + ) + return prompt | self.research_model | StrOutputParser() + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(RateLimitError), + stop=tenacity.stop_after_attempt(5), + wait=tenacity.wait_fixed(10), + reraise=True, + ) + async def _try_invoke(self, chain: RunnableSerializable[dict[str, Any], str], inputs: dict[str, Any]) -> str: + return await chain.ainvoke(inputs) + + def _safe_parse_json(self, text: str) -> dict[str, Any] | None: + """Attempt to parse a JSON object from model output. + + Tries full parse, fenced code extraction, and best-effort substring extraction.
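+ Returns the parsed dict on success, otherwise None.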
+ """ + import json + import re + + # Direct parse + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + return None + except Exception: + pass + + # Extract first JSON code fence + fence_match = re.search(r"```(?:json)?\s*({[\s\S]*?})\s*```", text) + if fence_match: + candidate = fence_match.group(1) + try: + obj2 = json.loads(candidate) + if isinstance(obj2, dict): + return obj2 + return None + except Exception: + pass + + # Heuristic: find first '{' and last '}' and try parse + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + candidate2 = text[start : end + 1] + try: + obj3 = json.loads(candidate2) + if isinstance(obj3, dict): + return obj3 + return None + except Exception: + return None + return None class _CustomEmbeddings(Embeddings): # type: ignore @@ -201,16 +474,26 @@ async def aembed_query(self, text: str) -> list[float]: deep_researcher = DeepResearchLangchain(**config.deep_research.kwargs, websearch=websearch) # Create a dummy request. - dummy_messages = [{"role": "user", "content": "What is the purpose of subnet 1 in Bittensor?"}] + dummy_messages = [ + { + "role": "user", + "content": """In the study of convex sets, why might two closed convex sets fail to have a strictly + separating hyperplane, even if they are disjoint? What geometric or topological properties could + prevent strict separation, and how does this contrast + with the case where strict separation is possible? Can you provide an intuitive example where such + a scenario occurs, and explain the underlying reasoning?""", + } + ] dummy_body: dict[str, Any] = {} # Run the invoke method. async def main() -> None: timer_start = time.perf_counter() - result, tool_history = await deep_researcher.invoke(dummy_messages, dummy_body) - logger.debug("Answer:", result) - logger.debug("Tool History:", tool_history) + result, tool_history, reasoning_traces = await deep_researcher.invoke(dummy_messages, dummy_body) + print(f"Answer: {result}") + print(f"Tool History: {tool_history}") + print(f"Reasoning Traces: {reasoning_traces}") timer_end = time.perf_counter() - logger.debug(f"Time elapsed: {timer_end - timer_start:.2f}s") + print(f"Time elapsed: {timer_end - timer_start:.2f}s") asyncio.run(main()) diff --git a/apex/services/llm/llm.py b/apex/services/llm/llm.py index b7bbe3b2e..d3665cbdd 100644 --- a/apex/services/llm/llm.py +++ b/apex/services/llm/llm.py @@ -14,7 +14,7 @@ def __init__(self, base_url: str, model: str, key: str): async def invoke( self, messages: list[dict[str, str]], body: dict[str, Any] | None = None - ) -> tuple[str, list[dict[str, str]]]: + ) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]: headers = { "Authorization": "Bearer " + self._key, "Content-Type": "application/json", @@ -35,7 +35,8 @@ async def invoke( data = await response.json() content = data.get("choices", [{}])[0].get("message", {}).get("content", "") - return str(content), [] + # This base LLM does not build multi-step chains; return empty reasoning_traces + return str(content), [], [] def __str__(self) -> str: return f"{self.__class__.__name__}({self._base_url}, {self._model})" diff --git a/apex/services/llm/llm_base.py b/apex/services/llm/llm_base.py index ccb84b761..bbdff173f 100644 --- a/apex/services/llm/llm_base.py +++ b/apex/services/llm/llm_base.py @@ -4,5 +4,5 @@ class LLMBase: async def invoke( self, messages: list[dict[str, str]], body: dict[str, Any] | None = None - ) -> tuple[str, list[dict[str, str]]]: + ) -> tuple[str, list[dict[str, str]], 
list[dict[str, Any]]]: raise NotImplementedError diff --git a/apex/validator/generate_query.py b/apex/validator/generate_query.py index 41d72fd30..4b59190fb 100644 --- a/apex/validator/generate_query.py +++ b/apex/validator/generate_query.py @@ -23,6 +23,6 @@ async def generate_query(llm: LLMBase, websearch: WebSearchBase) -> str: search_website = random.choice(search_results) search_content = search_website.content query = QUERY_PROMPT_TEMPLATE.format(context=search_content) - query_response, _ = await llm.invoke([{"role": "user", "content": query}]) + query_response, _, _ = await llm.invoke([{"role": "user", "content": query}]) logger.debug(f"Generated query.\nPrompt: '{query}'\nResponse: '{query_response}'") return query_response diff --git a/apex/validator/generate_reference.py b/apex/validator/generate_reference.py index 39adb41d6..61d7f57be 100644 --- a/apex/validator/generate_reference.py +++ b/apex/validator/generate_reference.py @@ -1,9 +1,13 @@ +from typing import Any + from loguru import logger from apex.services.deep_research.deep_research_base import DeepResearchBase -async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, list[dict[str, str]]]: +async def generate_reference( + llm: DeepResearchBase, query: str +) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]: """Generate a reference response for the given prompt. Args: @@ -25,10 +29,10 @@ async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, li "content": ( f"Research Question: {query}\n\n" "Please think through the answer carefully, annotate each step with citations like [1], [2], etc., " - 'and conclude with a "References:" list mapping each [n] to its source URL or title.' + 'and conclude with a "Sources:" list mapping each [n] to its source URL or title.' 
), } - response, tool_history = await llm.invoke([system_message, user_message]) logger.debug(f"Generated reference.\nPrompt: '{user_message}'\nResponse: '{response}'") - return response, tool_history + response, tool_history, reasoning_traces = await llm.invoke([system_message, user_message]) + return response, tool_history, reasoning_traces diff --git a/apex/validator/logger_wandb.py b/apex/validator/logger_wandb.py index dcfdd2754..36e9be74f 100644 --- a/apex/validator/logger_wandb.py +++ b/apex/validator/logger_wandb.py @@ -46,6 +46,7 @@ async def log( reference: str | None = None, discriminator_results: MinerDiscriminatorResults | None = None, tool_history: list[dict[str, str]] | None = None, + reasoning_traces: list[dict[str, Any]] | None = None, ) -> None: """Log an event to wandb.""" if self.run: @@ -53,6 +54,7 @@ async def log( processed_event = self.process_event(discriminator_results.model_dump()) processed_event["reference"] = reference processed_event["tool_history"] = tool_history + processed_event["reasoning_traces"] = reasoning_traces self.run.log(processed_event) def process_event(self, event: Mapping[str, Any]) -> dict[str, Any]: diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index cc1e6e045..2e1a243e2 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -84,12 +84,15 @@ async def run_single(self, task: QueryTask) -> str: reference = None tool_history: list[dict[str, str]] = [] + reasoning_traces: list[dict[str, Any]] = [] if random.random() < self.reference_rate: try: generator_results = None ground_truth = 0 logger.debug(f"Generating task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + reference, tool_history, reasoning_traces = await generate_reference( + llm=self.deep_research, query=query + ) except BaseException as exc: logger.exception(f"Failed to generate reference: {exc}") @@ -100,7 +103,9 @@ async def run_single(self, task: QueryTask) -> str: if random.random() < self.redundancy_rate: try: logger.debug(f"Generating redundant task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + reference, tool_history, reasoning_traces = await generate_reference( + llm=self.deep_research, query=query + ) except BaseException as exc: logger.warning(f"Failed to generate redundant reference: {exc}") @@ -111,7 +116,10 @@ async def run_single(self, task: QueryTask) -> str: if self.logger_wandb: await self.logger_wandb.log( - reference=reference, discriminator_results=discriminator_results, tool_history=tool_history + reference=reference, + discriminator_results=discriminator_results, + tool_history=tool_history, + reasoning_traces=reasoning_traces, ) if self._debug: diff --git a/pyproject.toml b/pyproject.toml index 28cc016b2..f65321880 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "pytest-mock>=3.14.1", "wandb>=0.21.1", "ruff>=0.12.5", + "langchain-experimental>=0.3.4", ] diff --git a/tests/services/deep_research/test_deep_research_langchain.py b/tests/services/deep_research/test_deep_research_langchain.py index 1a05ea259..7e6391076 100644 --- a/tests/services/deep_research/test_deep_research_langchain.py +++ b/tests/services/deep_research/test_deep_research_langchain.py @@ -1,7 +1,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from langchain_core.documents import Document from
apex.services.deep_research.deep_research_langchain import ( DeepResearchLangchain, @@ -33,6 +32,10 @@ def deep_research_langchain(mock_websearch, mock_llm_embed, mock_chat_openai): with ( patch("apex.services.deep_research.deep_research_langchain.LLMEmbed", return_value=mock_llm_embed), patch("apex.services.deep_research.deep_research_langchain.ChatOpenAI", return_value=mock_chat_openai), + patch( + "apex.services.deep_research.deep_research_langchain.PythonREPL", + return_value=MagicMock(run=MagicMock(return_value="2\n")), + ), ): return DeepResearchLangchain( key="test_key", @@ -68,144 +71,133 @@ async def test_custom_embeddings_aembed_query(mock_llm_embed): @pytest.mark.asyncio async def test_invoke_with_documents_in_body(deep_research_langchain, mock_websearch): - """Test invoke method when documents are provided in the body.""" + """When body contains documents, agent can directly produce a final report without websearch.""" messages = [{"role": "user", "content": "test question"}] - docs = [{"page_content": "doc1"}, {"page_content": "doc2"}] - body = {"documents": docs} + body = {"documents": [{"page_content": "doc1"}, {"page_content": "doc2"}]} with ( - patch.object( - deep_research_langchain, "_create_vector_store", new_callable=AsyncMock - ) as mock_create_vector_store, - patch.object(deep_research_langchain, "_create_compression_retriever") as mock_create_compression_retriever, - patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain, - patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain, patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), ): - mock_compression_retriever = MagicMock() - mock_compression_retriever.ainvoke = AsyncMock(return_value=[Document(page_content="compressed_doc")]) - mock_create_compression_retriever.return_value = mock_compression_retriever - - mock_summary_chain = MagicMock() - mock_summary_chain.ainvoke = AsyncMock(return_value="summary") - mock_create_summary_chain.return_value = mock_summary_chain - - mock_research_chain = MagicMock() - mock_research_chain.ainvoke = AsyncMock(return_value="research_report") - mock_create_research_chain.return_value = mock_research_chain - - final_chain = MagicMock() - final_chain.ainvoke = AsyncMock(return_value="final_answer") - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain + agent_chain = AsyncMock() + agent_chain.ainvoke.return_value = '{"thought": "enough info", "final_answer": "final_report"}' + mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain - await deep_research_langchain.invoke(messages, body) + result = await deep_research_langchain.invoke(messages, body) mock_websearch.search.assert_not_called() - mock_create_vector_store.assert_called_once() + assert result[0] == "final_report" @pytest.mark.asyncio async def test_invoke_with_websearch(deep_research_langchain, mock_websearch): - """Test invoke method when no documents are in the body, falling back to websearch.""" + """Agent chooses websearch then produces final answer.""" messages = [{"role": "user", "content": "test question"}] - mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com")] + mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com", title="A")] with ( - patch.object( - deep_research_langchain, 
"_create_vector_store", new_callable=AsyncMock - ) as mock_create_vector_store, - patch.object(deep_research_langchain, "_create_compression_retriever") as mock_create_compression_retriever, - patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain, - patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain, patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), ): - mock_compression_retriever = MagicMock() - mock_compression_retriever.ainvoke = AsyncMock(return_value=[Document(page_content="compressed_doc")]) - mock_create_compression_retriever.return_value = mock_compression_retriever - - mock_summary_chain = MagicMock() - mock_summary_chain.ainvoke = AsyncMock(return_value="summary") - mock_create_summary_chain.return_value = mock_summary_chain - - mock_research_chain = MagicMock() - mock_research_chain.ainvoke = AsyncMock(return_value="research_report") - mock_create_research_chain.return_value = mock_research_chain - - final_chain = MagicMock() - final_chain.ainvoke = AsyncMock(return_value="final_answer") - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain + agent_chain = AsyncMock() + agent_chain.ainvoke.side_effect = [ + ( + '{"thought": "need info", "action": {"tool": "websearch", ' + '"input": {"query": "test query", "max_results": 3}}}' + ), + '{"thought": "done", "final_answer": "final_answer"}', + ] + mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain - await deep_research_langchain.invoke(messages) + result = await deep_research_langchain.invoke(messages) - mock_websearch.search.assert_called_once_with(query="test question", max_results=5) - mock_create_vector_store.assert_called_once() + mock_websearch.search.assert_called_once_with(query="test query", max_results=3) + assert result[0] == "final_answer" @pytest.mark.asyncio -async def test_invoke_no_documents_found(deep_research_langchain, mock_websearch): - """Test invoke when no documents are found from any source.""" +async def test_invoke_no_websearch_needed_final_answer(deep_research_langchain, mock_websearch): + """Agent can produce a final report without calling websearch.""" messages = [{"role": "user", "content": "test question"}] - mock_websearch.search.return_value = [] - result = await deep_research_langchain.invoke(messages) + with ( + patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, + patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + ): + agent_chain = AsyncMock() + agent_chain.ainvoke.return_value = '{"thought": "clear", "final_answer": "final_report"}' + mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain - assert result == ("Could not find any information on the topic.", deep_research_langchain.tool_history) + result = await deep_research_langchain.invoke(messages) + + mock_websearch.search.assert_not_called() + assert result[0] == "final_report" @pytest.mark.asyncio -async def test_full_invoke_flow(deep_research_langchain, mock_websearch): - """Test the full, successful execution flow of the invoke method.""" +async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, mock_websearch): + """Agent performs multiple websearch actions before final answer; tool_history and traces are recorded.""" messages = 
[{"role": "user", "content": "test question"}] - question = messages[-1]["content"] - web_docs = [MagicMock(content="web_doc", url="http://a.com")] - compressed_docs = [Document(page_content="compressed_doc")] - summary = "summary" - research_report = "research_report" - final_answer = "final_answer" - mock_websearch.search.return_value = web_docs + # Two rounds of search results + mock_websearch.search.side_effect = [ + [ + MagicMock(content="doc A", url="http://a.com", title="A"), + MagicMock(content="doc B", url="http://b.com", title="B"), + ], + [MagicMock(content="doc C", url="http://c.com", title="C")], + ] with ( - patch("apex.services.deep_research.deep_research_langchain.FAISS") as mock_faiss, - patch( - "apex.services.deep_research.deep_research_langchain.ContextualCompressionRetriever" - ) as mock_compression_retriever_class, - patch("apex.services.deep_research.deep_research_langchain.LLMChainFilter") as mock_llm_chain_filter, - patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain, - patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain, patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), ): - mock_vector_store = AsyncMock() - mock_faiss.afrom_documents = AsyncMock(return_value=mock_vector_store) + agent_chain = AsyncMock() + agent_chain.ainvoke.side_effect = [ + ( + '{"thought": "need more info", "action": {"tool": "websearch", ' + '"input": {"query": "Q1", "max_results": 2}}}' + ), + ( + '{"thought": "still need more", "action": {"tool": "websearch", ' + '"input": {"query": "Q2", "max_results": 1}}}' + ), + '{"thought": "complete", "final_answer": "final_report"}', + ] + mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain - mock_compression_retriever = AsyncMock() - mock_compression_retriever.ainvoke.return_value = compressed_docs - mock_compression_retriever_class.return_value = mock_compression_retriever - - mock_summary_chain = AsyncMock() - mock_summary_chain.ainvoke.return_value = summary - mock_create_summary_chain.return_value = mock_summary_chain + result = await deep_research_langchain.invoke(messages) - mock_research_chain = AsyncMock() - mock_research_chain.ainvoke.return_value = research_report - mock_create_research_chain.return_value = mock_research_chain + # Two tool uses + assert mock_websearch.search.call_count == 2 + mock_websearch.search.assert_any_call(query="Q1", max_results=2) + mock_websearch.search.assert_any_call(query="Q2", max_results=1) - final_chain = AsyncMock() - final_chain.ainvoke.return_value = final_answer - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain + # Final answer returned + assert result[0] == "final_report" + # Tool history recorded + assert len(result[1]) == 2 + assert result[1][0]["tool"] == "websearch" + # Reasoning traces present + assert isinstance(result[2], list) - result = await deep_research_langchain.invoke(messages) - mock_websearch.search.assert_called_once_with(query=question, max_results=5) - mock_faiss.afrom_documents.assert_called_once() - mock_llm_chain_filter.from_llm.assert_called_once() - mock_compression_retriever.ainvoke.assert_called_once_with(question) - mock_summary_chain.ainvoke.assert_called_once_with({"context": compressed_docs, "question": question}) - mock_research_chain.ainvoke.assert_called_once_with({"context": 
compressed_docs, "question": question}) - final_chain.ainvoke.assert_called_once_with( - {"summary": summary, "research_report": research_report, "question": question} - ) - assert result == (final_answer, deep_research_langchain.tool_history) +@pytest.mark.asyncio +async def test_invoke_with_python_repl(deep_research_langchain): + """Agent chooses python_repl then produces final answer.""" + with ( + patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, + patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + ): + agent_chain = AsyncMock() + agent_chain.ainvoke.side_effect = [ + ('{"thought": "compute needed", "action": {"tool": "python_repl", "input": {"code": "print(1+1)"}}}'), + '{"thought": "done", "final_answer": "final_answer"}', + ] + mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + + result = await deep_research_langchain.invoke([{"role": "user", "content": "q"}]) + + # Tool history includes python_repl usage + assert any(t["tool"] == "python_repl" for t in result[1]) + assert result[0] == "final_answer" diff --git a/uv.lock b/uv.lock index cde84a241..a5dc6b49c 100644 --- a/uv.lock +++ b/uv.lock @@ -154,6 +154,7 @@ dependencies = [ { name = "langchain" }, { name = "langchain-community" }, { name = "langchain-core" }, + { name = "langchain-experimental" }, { name = "langchain-openai" }, { name = "langchain-sandbox" }, { name = "loguru" }, @@ -210,6 +211,7 @@ requires-dist = [ { name = "langchain", specifier = ">=0.3.26" }, { name = "langchain-community", specifier = ">=0.0.59" }, { name = "langchain-core", specifier = ">=0.3.68" }, + { name = "langchain-experimental", specifier = ">=0.3.4" }, { name = "langchain-openai", specifier = ">=0.3.28" }, { name = "langchain-sandbox", specifier = ">=0.0.6" }, { name = "loguru", specifier = ">=0.7.3" }, @@ -1329,6 +1331,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/42/0d0221cce6f168f644d7d96cb6c87c4e42fc55d2941da7a36e970e3ab8ab/langchain_core-0.3.75-py3-none-any.whl", hash = "sha256:03ca1fadf955ee3c7d5806a841f4b3a37b816acea5e61a7e6ba1298c05eea7f5", size = 443986, upload-time = "2025-08-26T15:24:10.883Z" }, ] +[[package]] +name = "langchain-experimental" +version = "0.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-community" }, + { name = "langchain-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/56/a8acbb08a03383c28875b3b151e4cefea5612266917fbd6fc3c14c21e172/langchain_experimental-0.3.4.tar.gz", hash = "sha256:937c4259ee4a639c618d19acf0e2c5c2898ef127050346edc5655259aa281a21", size = 140532, upload-time = "2024-12-20T15:16:09.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/27/fe8caa4884611286b1f7d6c5cfd76e1fef188faaa946db4fde6daa1cd2cd/langchain_experimental-0.3.4-py3-none-any.whl", hash = "sha256:2e587306aea36b60fa5e5fc05dc7281bee9f60a806f0bf9d30916e0ee096af80", size = 209154, upload-time = "2024-12-20T15:16:07.006Z" }, +] + [[package]] name = "langchain-openai" version = "0.3.32" @@ -2142,6 +2157,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/82/13e2ca5e43f2c7219fafa069341d52ee1867c2d929a019574ec90f6d0114/py_bip39_bindings-0.2.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:e6064feb105ed5c7278d19e8c4e371710ce56adcdd48e0c5e6b77f9b005201b9", size = 699630, upload-time = "2024-12-09T15:04:51.103Z" }, { url = 
"https://files.pythonhosted.org/packages/78/d7/f4c33dbc311cd07946994481dea7f09e417a29dc46e3c6de81d367a87194/py_bip39_bindings-0.2.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:0cb7a39bd65455c4cc3feb3f8d5766644410f32ac137aeea88b119c6ebe2d58b", size = 627655, upload-time = "2024-12-09T15:05:03.212Z" }, { url = "https://files.pythonhosted.org/packages/31/68/7dbb7f20d64222bd64e56d976e30614fb10d8b9dd60547028cf12b7d23a4/py_bip39_bindings-0.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:58743e95cedee157a545d060ee401639308f995badb91477d195f1d571664b65", size = 603340, upload-time = "2024-12-09T15:05:17.062Z" }, + { url = "https://files.pythonhosted.org/packages/2d/ba/428b20399740ee40756df785fd0412b8afff35eaf49a9d93fff5d115e931/py_bip39_bindings-0.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ade3ed8be37dcce01674e7f1974e4b53340abd423c940781898d33b096189581", size = 365266, upload-time = "2025-09-02T14:03:38.92Z" }, ] [[package]] @@ -2194,6 +2210,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/7a/0d5073188f94fd3b22a836e867e32fae0e26f3b39f734314e3eff5b530f6/py_ed25519_zebra_bindings-1.2.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:4fd00c8686b17e31ec29d8e4e7ce97f465fe26227f12c9e111e012b9d0dff4b9", size = 585681, upload-time = "2024-12-09T15:32:19.506Z" }, { url = "https://files.pythonhosted.org/packages/ab/99/add86df518d799a17c91763eebf756de68b1a858a5c7977de1b335e886cc/py_ed25519_zebra_bindings-1.2.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:e4e55fc5be4ba0c723d424cefdbb8d863e74d2ff25fbeadca9539ca60d78cc0f", size = 514835, upload-time = "2024-12-09T15:32:30.13Z" }, { url = "https://files.pythonhosted.org/packages/e7/fc/bf32dc80a597501fc7ef8b18638f78e5ee672b0b43cc02373075f9b1f8d4/py_ed25519_zebra_bindings-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:91816ed4cef90d4d08fa9f55fa0c5687c5eba601dc1a44f211adcf1c20d96cc3", size = 488524, upload-time = "2024-12-09T15:32:41.819Z" }, + { url = "https://files.pythonhosted.org/packages/33/66/22746aea9669128918e93c1cb7d381269f2d18d3385612071950eb713011/py_ed25519_zebra_bindings-1.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8e078ec4b72ed8492d6048e64c1979fcdc239104e3c23c9fb2cb7b9eee4bb098", size = 270358, upload-time = "2025-09-02T14:29:52.752Z" }, ] [[package]]