Skip to content

Commit

Permalink
Add sources to subquestion engine (#6745)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich committed Jul 7, 2023
1 parent 2b2046c commit a1f29ee
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

### New Features
- Sub question query engine returns source nodes of sub questions in `response.metadata['sources']` (#6745)

### Bug Fixes / Nits
- fixed `response_mode="no_text"` response synthesizer (#6755)
- fixed error setting `num_output` and `context_window` in service context (#6766)
Expand Down
11 changes: 9 additions & 2 deletions llama_index/query_engine/sub_question_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import List, Optional, Sequence, cast
from pydantic import BaseModel

from llama_index.bridge.langchain import get_color_mapping, print_text

from llama_index.async_utils import run_async_tasks
Expand All @@ -27,6 +28,7 @@ class SubQuestionAnswerPair(BaseModel):

sub_q: SubQuestion
answer: Optional[str]
sources: Optional[List[NodeWithScore]]


class SubQuestionQueryEngine(BaseQueryEngine):
Expand Down Expand Up @@ -172,6 +174,7 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
self._aquery_subq(sub_q, color=colors[str(ind)])
for ind, sub_q in enumerate(sub_questions)
]

qa_pairs_all = await asyncio.gather(*tasks)
qa_pairs_all = cast(List[Optional[SubQuestionAnswerPair]], qa_pairs_all)

Expand Down Expand Up @@ -213,7 +216,9 @@ async def _aquery_subq(
if self._verbose:
print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

return SubQuestionAnswerPair(sub_q=sub_q, answer=response_text)
return SubQuestionAnswerPair(
sub_q=sub_q, answer=response_text, sources=response.source_nodes
)
except ValueError:
logger.warn(f"[{sub_q.tool_name}] Failed to run {question}")
return None
Expand All @@ -234,7 +239,9 @@ def _query_subq(
if self._verbose:
print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

return SubQuestionAnswerPair(sub_q=sub_q, answer=response_text)
return SubQuestionAnswerPair(
sub_q=sub_q, answer=response_text, sources=response.source_nodes
)
except ValueError:
logger.warn(f"[{sub_q.tool_name}] Failed to run {question}")
return None

0 comments on commit a1f29ee

Please sign in to comment.