Add function calling based sub question generator class, and use it by default if possible (#6949)
Disiok committed Jul 20, 2023
1 parent 991c0eb commit cb6c6ca
Showing 8 changed files with 322 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,9 @@
 
 ## Unreleased
 
+### New Features
+- Default to pydantic question generation when possible for sub-question query engine (#6979)
+
 ### Bug Fixes / Nits
 - Fix returned order of messages in large chat memory (#6979)
195 changes: 195 additions & 0 deletions docs/examples/output_parsing/openai_sub_question.ipynb
@@ -0,0 +1,195 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c58e17b3-ec09-4e07-8e2e-d19a8e24dd40",
"metadata": {
"tags": []
},
"source": [
"# OpenAI function calling for Sub-Question Query Engine"
]
},
{
"cell_type": "markdown",
"id": "d5637f97-60c3-40bb-840f-fc4e217940a7",
"metadata": {
"tags": []
},
"source": [
"In this notebook, we showcase how to use OpenAI function calling to improve the robustness of our sub-question query engine. "
]
},
{
"cell_type": "markdown",
"id": "bd3d24c8-5b2b-4acf-a9de-53134453c186",
"metadata": {},
"source": [
"The sub-question query engine is designed to accept swappable question generators that implement the `BaseQuestionGenerator` interface. \n",
"To leverage the power of openai function calling API, we implemented a new `OpenAIQuestionGenerator` (powered by our `OpenAIPydanticProgram`)"
]
},
{
"cell_type": "markdown",
"id": "afa2db97-2a46-4629-a201-d4eb99480f3d",
"metadata": {},
"source": [
"## OpenAI Question Generator"
]
},
{
"cell_type": "markdown",
"id": "3977e961-fb19-495f-89c5-6a283596b459",
"metadata": {},
"source": [
"Unlike the default `LLMQuestionGenerator` that supports generic LLMs via the completion API, `OpenAIQuestionGenerator` only works with the latest OpenAI models that supports the function calling API. \n",
"\n",
"The benefit is that these models are fine-tuned to output JSON objects, so we can worry less about output parsing issues."
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "85b9e1d3-2f60-4730-8186-7c3c30b6dae5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from llama_index.question_gen.openai_generator import OpenAIQuestionGenerator"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "0df7f8ad-c026-4bfc-9a12-52efcb24f9d5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"question_gen = OpenAIQuestionGenerator.from_defaults()"
]
},
{
"cell_type": "markdown",
"id": "04039c8c-72df-495d-915c-09d04321bb96",
"metadata": {},
"source": [
"Let's test it out!"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "1e40ac6c-6b66-4cf3-9dd6-de02416b7dd5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from llama_index.tools import ToolMetadata\n",
"from llama_index import QueryBundle"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "77106a07-bccf-471d-8d85-c6438772cf35",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tools = [\n",
" ToolMetadata(\n",
" name=\"march_22\",\n",
" description=\"Provides information about Uber quarterly financials ending March 2022\",\n",
" ),\n",
" ToolMetadata(\n",
" name=\"june_22\",\n",
" description=\"Provides information about Uber quarterly financials ending June 2022\",\n",
" ),\n",
" ToolMetadata(\n",
" name=\"sept_22\",\n",
" description=\"Provides information about Uber quarterly financials ending September 2022\",\n",
" ),\n",
" ToolMetadata(\n",
" name=\"sept_21\",\n",
" description=\"Provides information about Uber quarterly financials ending September 2022\",\n",
" ),\n",
" ToolMetadata(\n",
" name=\"june_21\",\n",
" description=\"Provides information about Uber quarterly financials ending June 2022\",\n",
" ),\n",
" ToolMetadata(\n",
" name=\"march_21\",\n",
" description=\"Provides information about Uber quarterly financials ending March 2022\",\n",
" ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "82ed271a-bd0d-4b6a-b9e3-987d75f6a4ad",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sub_questions = question_gen.generate(\n",
" tools=tools,\n",
" query=QueryBundle(\n",
" \"Compare the fastest growing sectors for Uber in the first two quarters of 2022\"\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "2740e60e-c4e6-412a-b46f-70a1f3fe1231",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"[SubQuestion(sub_question='What were the fastest growing sectors for Uber in March 2022?', tool_name='march_22'),\n",
" SubQuestion(sub_question='What were the fastest growing sectors for Uber in June 2022?', tool_name='june_22')]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub_questions"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
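Tool names and descriptions in lists like the notebook's are easy to let drift out of sync when copy-pasting entries. As a stdlib-only sketch — `check_tool_descriptions` and the `(name, description)` tuple representation are illustrative, not part of llama_index — a quick consistency check for this particular naming scheme might look like:

```python
def check_tool_descriptions(tools):
    """Return names of tools whose description disagrees with the
    month/year encoded in the tool name (e.g. 'sept_21')."""
    months = {"march": "March", "june": "June", "sept": "September"}
    bad = []
    for name, description in tools:
        month_key, yy = name.split("_")
        expected = f"{months[month_key]} 20{yy}"
        if expected not in description:
            bad.append(name)
    return bad
```

Running it over the notebook's tool list would flag any entry whose description still carries the wrong year.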
3 changes: 2 additions & 1 deletion llama_index/llms/utils.py
@@ -1,7 +1,8 @@
 from typing import Optional, Union
-from llama_index.llms.base import LLM
 
 from langchain.base_language import BaseLanguageModel
 
+from llama_index.llms.base import LLM
 from llama_index.llms.langchain import LangChainLLM
 from llama_index.llms.openai import OpenAI
+
Expand Down
4 changes: 2 additions & 2 deletions llama_index/program/openai_program.py
@@ -2,7 +2,7 @@
 
 from pydantic import BaseModel
 
-from llama_index.llms.base import ChatMessage, MessageRole
+from llama_index.llms.base import LLM, ChatMessage, MessageRole
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai_utils import is_function_calling_model, to_openai_function
 from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram
@@ -45,7 +45,7 @@ def from_defaults(
         cls,
         output_cls: Type[Model],
         prompt_template_str: str,
-        llm: Optional[OpenAI] = None,
+        llm: Optional[LLM] = None,
         verbose: bool = False,
         function_call: Optional[Union[str, Dict[str, Any]]] = None,
         **kwargs: Any,
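Widening `llm` from `Optional[OpenAI]` to `Optional[LLM]` lets any function-calling-capable LLM back the program. The core mechanism `OpenAIPydanticProgram` relies on is turning an output class into an OpenAI-style function schema; below is a rough stdlib-only approximation of that idea, with dataclasses standing in for pydantic models and `to_function_schema` as a hypothetical stand-in for the library's `to_openai_function` — a sketch, not the library implementation:

```python
from dataclasses import dataclass, fields

@dataclass
class Song:
    title: str
    length_seconds: int

# Map Python annotations to JSON-schema primitive types.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}

def to_function_schema(cls) -> dict:
    """Build an OpenAI-style function schema from a dataclass."""
    props = {f.name: {"type": _JSON_TYPES[f.type]} for f in fields(cls)}
    return {
        "name": cls.__name__,
        "parameters": {
            "type": "object",
            "properties": props,
            "required": [f.name for f in fields(cls)],
        },
    }
```

A schema like this is what gets passed to the model so that its response arrives as arguments matching the output class, rather than free text that needs parsing.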
24 changes: 19 additions & 5 deletions llama_index/query_engine/sub_question_query_engine.py
@@ -1,17 +1,18 @@
 import asyncio
 import logging
 from typing import List, Optional, Sequence, cast
-from pydantic import BaseModel
 
-from llama_index.bridge.langchain import get_color_mapping, print_text
+from pydantic import BaseModel
 
 from llama_index.async_utils import run_async_tasks
+from llama_index.bridge.langchain import get_color_mapping, print_text
 from llama_index.callbacks.base import CallbackManager
 from llama_index.callbacks.schema import CBEventType, EventPayload
 from llama_index.indices.query.base import BaseQueryEngine
 from llama_index.indices.query.schema import QueryBundle
 from llama_index.indices.service_context import ServiceContext
 from llama_index.question_gen.llm_generators import LLMQuestionGenerator
+from llama_index.question_gen.openai_generator import OpenAIQuestionGenerator
 from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion
 from llama_index.response.schema import RESPONSE_TYPE
 from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer
@@ -87,9 +88,22 @@ def from_defaults(
         elif len(query_engine_tools) > 0:
             callback_manager = query_engine_tools[0].query_engine.callback_manager
 
-        question_gen = question_gen or LLMQuestionGenerator.from_defaults(
-            service_context=service_context
-        )
+        if question_gen is None:
+            if service_context is None:
+                # use default openai model that supports function calling API
+                question_gen = OpenAIQuestionGenerator.from_defaults()
+            else:
+                # try to use OpenAI function calling based question generator.
+                # if incompatible, use general LLM question generator
+                try:
+                    question_gen = OpenAIQuestionGenerator.from_defaults(
+                        llm=service_context.llm
+                    )
+                except ValueError:
+                    question_gen = LLMQuestionGenerator.from_defaults(
+                        service_context=service_context
+                    )
+
         synth = response_synthesizer or get_response_synthesizer(
             callback_manager=callback_manager,
             service_context=service_context,
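The generator-selection logic this commit introduces can be sketched in isolation. The stub classes below only imitate the control flow — the real `OpenAIQuestionGenerator.from_defaults` raises `ValueError` when the configured model does not support the function calling API, which is what the `except ValueError` fallback relies on; everything else here is illustrative rather than llama_index's actual implementation:

```python
class OpenAIQuestionGenerator:
    @classmethod
    def from_defaults(cls, llm=None):
        # Stand-in for the real check: reject models without
        # function calling support, as the library does via ValueError.
        if llm is not None and not getattr(llm, "supports_function_calling", False):
            raise ValueError("model does not support the function calling API")
        return cls()

class LLMQuestionGenerator:
    @classmethod
    def from_defaults(cls, service_context=None):
        return cls()

def choose_question_gen(service_context=None):
    """Prefer function-calling question generation; fall back to the
    generic completion-based generator when the model is incompatible."""
    if service_context is None:
        # No context given: use the default OpenAI model, which
        # supports the function calling API.
        return OpenAIQuestionGenerator.from_defaults()
    try:
        return OpenAIQuestionGenerator.from_defaults(llm=service_context.llm)
    except ValueError:
        return LLMQuestionGenerator.from_defaults(service_context=service_context)
```

The design keeps `BaseQuestionGenerator` swappable: callers that pass their own `question_gen` are unaffected, and only the default changes.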
17 changes: 5 additions & 12 deletions llama_index/question_gen/guidance_generator.py
@@ -1,15 +1,17 @@
 from typing import TYPE_CHECKING, List, Optional, Sequence, cast
 
-from pydantic import BaseModel
-
 from llama_index.indices.query.schema import QueryBundle
 from llama_index.program.guidance_program import GuidancePydanticProgram
 from llama_index.prompts.guidance_utils import convert_to_handlebars
 from llama_index.question_gen.prompts import (
     DEFAULT_SUB_QUESTION_PROMPT_TMPL,
     build_tools_text,
 )
-from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion
+from llama_index.question_gen.types import (
+    BaseQuestionGenerator,
+    SubQuestion,
+    SubQuestionList,
+)
 from llama_index.tools.types import ToolMetadata
 
 if TYPE_CHECKING:
@@ -20,15 +22,6 @@
     )
 
 
-class SubQuestionList(BaseModel):
-    """A pydantic object wrapping a list of sub-questions.
-
-    This is mostly used to make getting a json schema easier.
-    """
-
-    items: List[SubQuestion]
-
-
 class GuidanceQuestionGenerator(BaseQuestionGenerator):
     def __init__(
         self,
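For reference, the `SubQuestionList` wrapper removed here moves into `llama_index.question_gen.types` so both generators can share it. Its shape is tiny; a stdlib approximation (dataclasses in place of pydantic's `BaseModel`, purely illustrative) looks like:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class SubQuestion:
    sub_question: str
    tool_name: str

@dataclass
class SubQuestionList:
    """Wraps a list of sub-questions; exists mainly so a single JSON
    schema can be generated for the whole list at once."""
    items: List[SubQuestion]
```

Wrapping the list in a model, rather than asking the LLM for a bare JSON array, gives the function calling API one named object schema to fill in.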