2 changes: 1 addition & 1 deletion apex/services/deep_research/deep_research_base.py
@@ -6,5 +6,5 @@
class DeepResearchBase(LLMBase):
async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
raise NotImplementedError
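
The base contract for invoke widens from a two-tuple to a three-tuple: the answer text, the tool-call history, and the new reasoning traces. A minimal caller-side sketch under that contract, assuming any concrete LLMBase implementation; the helper name answer is illustrative and not part of the PR:

from apex.services.llm.llm_base import LLMBase


async def answer(llm: LLMBase, question: str) -> str:
    # invoke() now returns (content, tool_history, reasoning_traces), so every
    # caller must unpack three values even if it only needs the content.
    content, _tool_history, _reasoning_traces = await llm.invoke(
        [{"role": "user", "content": question}]
    )
    return content
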
383 changes: 333 additions & 50 deletions apex/services/deep_research/deep_research_langchain.py

Large diffs are not rendered by default.
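The body of the rewritten deep_research_langchain.py is not rendered here, but the updated tests below pin down its behaviour: the PromptTemplate | ChatOpenAI | StrOutputParser chain emits JSON steps containing a "thought" plus either a tool "action" (websearch or python_repl) or a "final_answer", and invoke returns the answer together with tool_history and reasoning_traces. The following is a minimal sketch of such a loop inferred from those tests, not the PR's actual implementation; the names agent_loop, run_websearch, and run_python are illustrative stand-ins for the real chain and tools.

import json
from typing import Any, Awaitable, Callable


async def agent_loop(
    agent_chain: Callable[[dict[str, Any]], Awaitable[str]],
    run_websearch: Callable[[str, int], Awaitable[str]],
    run_python: Callable[[str], str],
    question: str,
    max_steps: int = 10,
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
    tool_history: list[dict[str, str]] = []
    reasoning_traces: list[dict[str, Any]] = []
    observations: list[str] = []

    for _ in range(max_steps):
        # The agent chain returns a JSON string with a "thought" plus either
        # an "action" (tool call) or a "final_answer", as the tests assert.
        raw = await agent_chain({"question": question, "observations": observations})
        step = json.loads(raw)
        reasoning_traces.append(step)  # keep the full thought/action record

        if "final_answer" in step:
            return step["final_answer"], tool_history, reasoning_traces

        action = step["action"]
        tool, tool_input = action["tool"], action["input"]
        if tool == "websearch":
            observation = await run_websearch(
                tool_input["query"], tool_input.get("max_results", 5)
            )
        elif tool == "python_repl":
            observation = run_python(tool_input["code"])
        else:
            observation = f"Unknown tool: {tool}"

        tool_history.append(
            {"tool": tool, "input": json.dumps(tool_input), "output": observation}
        )
        observations.append(observation)

    return "Reached step limit without a final answer.", tool_history, reasoning_traces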

5 changes: 3 additions & 2 deletions apex/services/llm/llm.py
@@ -14,7 +14,7 @@ def __init__(self, base_url: str, model: str, key: str):

async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
headers = {
"Authorization": "Bearer " + self._key,
"Content-Type": "application/json",
@@ -35,7 +35,8 @@ async def invoke(

data = await response.json()
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
return str(content), []
# This base LLM does not build multi-step chains; return empty reasoning_traces
return str(content), [], []

def __str__(self) -> str:
return f"{self.__class__.__name__}({self._base_url}, {self._model})"
2 changes: 1 addition & 1 deletion apex/services/llm/llm_base.py
@@ -4,5 +4,5 @@
class LLMBase:
async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
raise NotImplementedError
2 changes: 1 addition & 1 deletion apex/validator/generate_query.py
@@ -23,6 +23,6 @@ async def generate_query(llm: LLMBase, websearch: WebSearchBase) -> str:
search_website = random.choice(search_results)
search_content = search_website.content
query = QUERY_PROMPT_TEMPLATE.format(context=search_content)
query_response, _ = await llm.invoke([{"role": "user", "content": query}])
query_response, _, _ = await llm.invoke([{"role": "user", "content": query}])
logger.debug(f"Generated query.\nPrompt: '{query}'\nResponse: '{query_response}'")
return query_response
12 changes: 8 additions & 4 deletions apex/validator/generate_reference.py
@@ -1,9 +1,13 @@
from typing import Any

from loguru import logger

from apex.services.deep_research.deep_research_base import DeepResearchBase


async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, list[dict[str, str]]]:
async def generate_reference(
llm: DeepResearchBase, query: str
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
"""Generate a reference response for the given prompt.

Args:
@@ -25,10 +29,10 @@ async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, li
"content": (
f"Research Question: {query}\n\n"
"Please think through the answer carefully, annotate each step with citations like [1], [2], etc., "
'and conclude with a "References:" list mapping each [n] to its source URL or title.'
'and conclude with a "Sources:" list mapping each [n] to its source URL or title.'
),
}

response, tool_history = await llm.invoke([system_message, user_message])
response, tool_history, reasoning_traces = await llm.invoke([system_message, user_message])
logger.debug(f"Generated reference.\nPrompt: '{user_message}'\nResponse: '{response}'")
return response, tool_history
return response, tool_history, reasoning_traces
2 changes: 2 additions & 0 deletions apex/validator/logger_wandb.py
@@ -46,13 +46,15 @@ async def log(
reference: str | None = None,
discriminator_results: MinerDiscriminatorResults | None = None,
tool_history: list[dict[str, str]] | None = None,
reasoning_traces: list[dict[str, Any]] | None = None,
) -> None:
"""Log an event to wandb."""
if self.run:
if discriminator_results:
processed_event = self.process_event(discriminator_results.model_dump())
processed_event["reference"] = reference
processed_event["tool_history"] = tool_history
processed_event["reasoning_trace"] = reasoning_traces
self.run.log(processed_event)

def process_event(self, event: Mapping[str, Any]) -> dict[str, Any]:
14 changes: 11 additions & 3 deletions apex/validator/pipeline.py
@@ -84,12 +84,15 @@ async def run_single(self, task: QueryTask) -> str:

reference = None
tool_history: list[dict[str, str]] = []
reasoning_traces: list[dict[str, Any]] = []
if random.random() < self.reference_rate:
try:
generator_results = None
ground_truth = 0
logger.debug(f"Generating task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
reference, tool_history, reasoning_traces = await generate_reference(
llm=self.deep_research, query=query
)
except BaseException as exc:
logger.exception(f"Failed to generate reference: {exc}")

@@ -100,7 +103,9 @@ async def invoke(
if random.random() < self.redundancy_rate:
try:
logger.debug(f"Generating redundant task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
reference, tool_history, reasoning_traces = await generate_reference(
llm=self.deep_research, query=query
)
except BaseException as exc:
logger.warning(f"Failed to generate redundant reference: {exc}")

@@ -111,7 +116,10 @@

if self.logger_wandb:
await self.logger_wandb.log(
reference=reference, discriminator_results=discriminator_results, tool_history=tool_history
reference=reference,
discriminator_results=discriminator_results,
tool_history=tool_history,
reasoning_traces=reasoning_traces,
)

if self._debug:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"pytest-mock>=3.14.1",
"wandb>=0.21.1",
"ruff>=0.12.5",
"langchain-experimental>=0.3.4",
]
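
The new langchain-experimental dependency supplies PythonREPL, which the tests patch at apex.services.deep_research.deep_research_langchain.PythonREPL as the backend of the python_repl tool. A small sketch of the presumed usage; the wrapper name run_python is illustrative:

from langchain_experimental.utilities import PythonREPL

# PythonREPL executes arbitrary Python in-process; since the agent feeds it
# model-generated code via the "python_repl" tool, that code should be
# treated as untrusted.
_repl = PythonREPL()


def run_python(code: str) -> str:
    # Mirrors the behaviour mocked in the tests: run("print(1+1)") -> "2\n".
    return _repl.run(code)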


192 changes: 92 additions & 100 deletions tests/services/deep_research/test_deep_research_langchain.py
@@ -1,7 +1,6 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.documents import Document

from apex.services.deep_research.deep_research_langchain import (
DeepResearchLangchain,
@@ -33,6 +32,10 @@ def deep_research_langchain(mock_websearch, mock_llm_embed, mock_chat_openai):
with (
patch("apex.services.deep_research.deep_research_langchain.LLMEmbed", return_value=mock_llm_embed),
patch("apex.services.deep_research.deep_research_langchain.ChatOpenAI", return_value=mock_chat_openai),
patch(
"apex.services.deep_research.deep_research_langchain.PythonREPL",
return_value=MagicMock(run=MagicMock(return_value="2\n")),
),
):
return DeepResearchLangchain(
key="test_key",
@@ -68,144 +71,133 @@ async def test_custom_embeddings_aembed_query(mock_llm_embed):

@pytest.mark.asyncio
async def test_invoke_with_documents_in_body(deep_research_langchain, mock_websearch):
"""Test invoke method when documents are provided in the body."""
"""When body contains documents, agent can directly produce a final report without websearch."""
messages = [{"role": "user", "content": "test question"}]
docs = [{"page_content": "doc1"}, {"page_content": "doc2"}]
body = {"documents": docs}
body = {"documents": [{"page_content": "doc1"}, {"page_content": "doc2"}]}

with (
patch.object(
deep_research_langchain, "_create_vector_store", new_callable=AsyncMock
) as mock_create_vector_store,
patch.object(deep_research_langchain, "_create_compression_retriever") as mock_create_compression_retriever,
patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain,
patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain,
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
):
mock_compression_retriever = MagicMock()
mock_compression_retriever.ainvoke = AsyncMock(return_value=[Document(page_content="compressed_doc")])
mock_create_compression_retriever.return_value = mock_compression_retriever

mock_summary_chain = MagicMock()
mock_summary_chain.ainvoke = AsyncMock(return_value="summary")
mock_create_summary_chain.return_value = mock_summary_chain

mock_research_chain = MagicMock()
mock_research_chain.ainvoke = AsyncMock(return_value="research_report")
mock_create_research_chain.return_value = mock_research_chain

final_chain = MagicMock()
final_chain.ainvoke = AsyncMock(return_value="final_answer")
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain
agent_chain = AsyncMock()
agent_chain.ainvoke.return_value = '{"thought": "enough info", "final_answer": "final_report"}'
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain

await deep_research_langchain.invoke(messages, body)
result = await deep_research_langchain.invoke(messages, body)

mock_websearch.search.assert_not_called()
mock_create_vector_store.assert_called_once()
assert result[0] == "final_report"


@pytest.mark.asyncio
async def test_invoke_with_websearch(deep_research_langchain, mock_websearch):
"""Test invoke method when no documents are in the body, falling back to websearch."""
"""Agent chooses websearch then produces final answer."""
messages = [{"role": "user", "content": "test question"}]
mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com")]
mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com", title="A")]

with (
patch.object(
deep_research_langchain, "_create_vector_store", new_callable=AsyncMock
) as mock_create_vector_store,
patch.object(deep_research_langchain, "_create_compression_retriever") as mock_create_compression_retriever,
patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain,
patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain,
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
):
mock_compression_retriever = MagicMock()
mock_compression_retriever.ainvoke = AsyncMock(return_value=[Document(page_content="compressed_doc")])
mock_create_compression_retriever.return_value = mock_compression_retriever

mock_summary_chain = MagicMock()
mock_summary_chain.ainvoke = AsyncMock(return_value="summary")
mock_create_summary_chain.return_value = mock_summary_chain

mock_research_chain = MagicMock()
mock_research_chain.ainvoke = AsyncMock(return_value="research_report")
mock_create_research_chain.return_value = mock_research_chain

final_chain = MagicMock()
final_chain.ainvoke = AsyncMock(return_value="final_answer")
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
(
'{"thought": "need info", "action": {"tool": "websearch", '
'"input": {"query": "test query", "max_results": 3}}}'
),
'{"thought": "done", "final_answer": "final_answer"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain

await deep_research_langchain.invoke(messages)
result = await deep_research_langchain.invoke(messages)

mock_websearch.search.assert_called_once_with(query="test question", max_results=5)
mock_create_vector_store.assert_called_once()
mock_websearch.search.assert_called_once_with(query="test query", max_results=3)
assert result[0] == "final_answer"


@pytest.mark.asyncio
async def test_invoke_no_documents_found(deep_research_langchain, mock_websearch):
"""Test invoke when no documents are found from any source."""
async def test_invoke_no_websearch_needed_final_answer(deep_research_langchain, mock_websearch):
"""Agent can produce a final report without calling websearch."""
messages = [{"role": "user", "content": "test question"}]
mock_websearch.search.return_value = []

result = await deep_research_langchain.invoke(messages)
with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
):
agent_chain = AsyncMock()
agent_chain.ainvoke.return_value = '{"thought": "clear", "final_answer": "final_report"}'
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain

assert result == ("Could not find any information on the topic.", deep_research_langchain.tool_history)
result = await deep_research_langchain.invoke(messages)

mock_websearch.search.assert_not_called()
assert result[0] == "final_report"


@pytest.mark.asyncio
async def test_full_invoke_flow(deep_research_langchain, mock_websearch):
"""Test the full, successful execution flow of the invoke method."""
async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, mock_websearch):
"""Agent performs multiple websearch actions before final answer; tool_history and traces are recorded."""
messages = [{"role": "user", "content": "test question"}]
question = messages[-1]["content"]
web_docs = [MagicMock(content="web_doc", url="http://a.com")]
compressed_docs = [Document(page_content="compressed_doc")]
summary = "summary"
research_report = "research_report"
final_answer = "final_answer"

mock_websearch.search.return_value = web_docs
# Two rounds of search results
mock_websearch.search.side_effect = [
[
MagicMock(content="doc A", url="http://a.com", title="A"),
MagicMock(content="doc B", url="http://b.com", title="B"),
],
[MagicMock(content="doc C", url="http://c.com", title="C")],
]

with (
patch("apex.services.deep_research.deep_research_langchain.FAISS") as mock_faiss,
patch(
"apex.services.deep_research.deep_research_langchain.ContextualCompressionRetriever"
) as mock_compression_retriever_class,
patch("apex.services.deep_research.deep_research_langchain.LLMChainFilter") as mock_llm_chain_filter,
patch.object(deep_research_langchain, "_create_summary_chain") as mock_create_summary_chain,
patch.object(deep_research_langchain, "_create_research_chain") as mock_create_research_chain,
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
):
mock_vector_store = AsyncMock()
mock_faiss.afrom_documents = AsyncMock(return_value=mock_vector_store)
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
(
'{"thought": "need more info", "action": {"tool": "websearch", '
'"input": {"query": "Q1", "max_results": 2}}}'
),
(
'{"thought": "still need more", "action": {"tool": "websearch", '
'"input": {"query": "Q2", "max_results": 1}}}'
),
'{"thought": "complete", "final_answer": "final_report"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain

mock_compression_retriever = AsyncMock()
mock_compression_retriever.ainvoke.return_value = compressed_docs
mock_compression_retriever_class.return_value = mock_compression_retriever

mock_summary_chain = AsyncMock()
mock_summary_chain.ainvoke.return_value = summary
mock_create_summary_chain.return_value = mock_summary_chain
result = await deep_research_langchain.invoke(messages)

mock_research_chain = AsyncMock()
mock_research_chain.ainvoke.return_value = research_report
mock_create_research_chain.return_value = mock_research_chain
# Two tool uses
assert mock_websearch.search.call_count == 2
mock_websearch.search.assert_any_call(query="Q1", max_results=2)
mock_websearch.search.assert_any_call(query="Q2", max_results=1)

final_chain = AsyncMock()
final_chain.ainvoke.return_value = final_answer
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = final_chain
# Final answer returned
assert result[0] == "final_report"
# Tool history recorded
assert len(result[1]) == 2
assert result[1][0]["tool"] == "websearch"
# Reasoning traces present
assert isinstance(result[2], list)

result = await deep_research_langchain.invoke(messages)

mock_websearch.search.assert_called_once_with(query=question, max_results=5)
mock_faiss.afrom_documents.assert_called_once()
mock_llm_chain_filter.from_llm.assert_called_once()
mock_compression_retriever.ainvoke.assert_called_once_with(question)
mock_summary_chain.ainvoke.assert_called_once_with({"context": compressed_docs, "question": question})
mock_research_chain.ainvoke.assert_called_once_with({"context": compressed_docs, "question": question})
final_chain.ainvoke.assert_called_once_with(
{"summary": summary, "research_report": research_report, "question": question}
)
assert result == (final_answer, deep_research_langchain.tool_history)
@pytest.mark.asyncio
async def test_invoke_with_python_repl(deep_research_langchain):
"""Agent chooses python_repl then produces final answer."""
with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
):
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
('{"thought": "compute needed", "action": {"tool": "python_repl", "input": {"code": "print(1+1)"}}}'),
'{"thought": "done", "final_answer": "final_answer"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain

result = await deep_research_langchain.invoke([{"role": "user", "content": "q"}])

# Tool history includes python_repl usage
assert any(t["tool"] == "python_repl" for t in result[1])
assert result[0] == "final_answer"