Skip to content

Commit

Permalink
Remove usage of stop token in Prompt, SQL gen (#6782)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongyishi committed Jul 8, 2023
1 parent d164316 commit 138034b
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 84 deletions.
132 changes: 103 additions & 29 deletions docs/examples/index_structs/struct_indices/SQLIndexDemo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "119eb42b",
"metadata": {},
"outputs": [],
Expand All @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "107396a9-4aa7-49b3-9f0f-a755726c19ba",
"metadata": {},
"outputs": [],
Expand All @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
"metadata": {},
"outputs": [],
Expand All @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "ea24f794-f10b-42e6-922d-9258b7167405",
"metadata": {},
"outputs": [],
Expand All @@ -78,7 +78,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "b4154b29-7e23-4c26-a507-370a66186ae7",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "768d1581-b482-4c73-9963-5ffd68a2aafb",
"metadata": {
"tags": []
Expand All @@ -123,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "bffabba0-8e54-4f24-ad14-2c8979c582a5",
"metadata": {
"tags": []
Expand All @@ -136,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "9432787b-a8f0-4fc3-8323-e2cd9497df73",
"metadata": {},
"outputs": [],
Expand All @@ -146,10 +146,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "84d4ee54-9f00-40fd-bab0-36e5e579dc9f",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'\\nCREATE TABLE city_stats (\\n\\tcity_name VARCHAR(16) NOT NULL, \\n\\tpopulation INTEGER, \\n\\tcountry VARCHAR(16) NOT NULL, \\n\\tPRIMARY KEY (city_name)\\n)\\n\\n/*\\n3 rows from city_stats table:\\ncity_name\\tpopulation\\tcountry\\n\\n*/'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sql_database.table_info"
]
Expand All @@ -165,7 +176,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "95043e10-6cdf-4f66-96bd-ce307ea7df3e",
"metadata": {},
"outputs": [],
Expand All @@ -188,10 +199,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "b315b8ff-7dd7-4e7d-ac47-8c5a0c3e7ae9",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]\n"
]
}
],
"source": [
"# view current table\n",
"stmt = select(city_stats_table.c[\"city_name\", \"population\", \"country\"]).select_from(\n",
Expand Down Expand Up @@ -223,10 +242,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "eddd3608-31ff-4591-a02a-90987e312669",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Chicago',)\n",
"('Seoul',)\n",
"('Tokyo',)\n",
"('Toronto',)\n"
]
}
],
"source": [
"from sqlalchemy import text\n",
"\n",
Expand All @@ -253,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "5d992fb5",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -297,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "d71045c0-7a96-4e86-b38c-c378b7759aa4",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -333,21 +363,45 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "802da9ed",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/markdown": [
"<b> Tokyo has the highest population, with 13,960,000 people.</b>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"response = query_engine.query(\"Which city has the highest population?\")\n",
"display(Markdown(f\"<b>{response}</b>\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "54a99cb0-578a-40ec-a3eb-1666ac18fbed",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"[('Tokyo', 13960000)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# you can also fetch the raw result from SQLAlchemy!\n",
"response.metadata[\"result\"]"
Expand All @@ -364,7 +418,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "44a87651",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -395,7 +449,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "8e0acde4-ca61-42e9-97f8-c9cf11502157",
"metadata": {},
"outputs": [],
Expand All @@ -405,7 +459,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"id": "30860e8b-9ad0-418c-b266-753242c1f208",
"metadata": {},
"outputs": [],
Expand All @@ -415,21 +469,41 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"id": "07068a3a-30a4-4473-ba82-ab6e93e3437c",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/hongyishi/Documents/GitHub/gpt_index/.venv/lib/python3.11/site-packages/langchain/chains/sql_database/base.py:63: UserWarning: Directly instantiating an SQLDatabaseChain with an llm is deprecated. Please instantiate with llm_chain argument or using the from_llm class method.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# set Logging to DEBUG for more detailed outputs\n",
"db_chain = SQLDatabaseChain(llm=llm, database=sql_database)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"id": "a04c0a1d-f6a8-4a4a-9181-4123b09ec614",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'Tokyo has the highest population with 13960000 people.'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"Which city has the highest population?\")"
]
Expand All @@ -451,7 +525,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.0"
}
},
"nbformat": 4,
Expand Down
111 changes: 67 additions & 44 deletions docs/guides/tutorials/Airbyte_demo.ipynb

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions llama_index/indices/struct_store/sql_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Default query for SQLStructStoreIndex."""
import logging
from abc import abstractmethod
from typing import Any, List, Optional, Union

from sqlalchemy import Table

from abc import abstractmethod
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.indices.query.schema import QueryBundle
from llama_index.indices.service_context import ServiceContext
Expand All @@ -13,12 +13,12 @@
)
from llama_index.indices.struct_store.sql import SQLStructStoreIndex
from llama_index.langchain_helpers.sql_wrapper import SQLDatabase
from llama_index.objects.base import ObjectRetriever
from llama_index.objects.table_node_mapping import SQLTableSchema
from llama_index.prompts.base import Prompt
from llama_index.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.prompts.prompt_type import PromptType
from llama_index.response.schema import Response
from llama_index.objects.table_node_mapping import SQLTableSchema
from llama_index.objects.base import ObjectRetriever

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,6 +127,10 @@ def service_context(self) -> ServiceContext:

def _parse_response_to_sql(self, response: str) -> str:
"""Parse response to SQL."""
# Find and remove SQLResult part
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
result_response = response.strip()
return result_response

Expand Down Expand Up @@ -237,6 +241,10 @@ def service_context(self) -> ServiceContext:

def _parse_response_to_sql(self, response: str) -> str:
"""Parse response to SQL."""
# Find and remove SQLResult part
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
result_response = response.strip()
return result_response

Expand Down Expand Up @@ -348,7 +356,7 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str:

else:
# get all tables
table_names = self._sql_database.get_table_names()
table_names = self._sql_database.get_usable_table_names()
for table_name in table_names:
table_info = self._sql_database.get_single_table_info(table_name)
context_strs.append(table_info)
Expand Down
6 changes: 0 additions & 6 deletions llama_index/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ class Prompt:
Wrapper around langchain's prompt class. Adds ability to:
- enforce certain prompt types
- partially fill values
- define stop token
"""

def __init__(
self,
template: Optional[str] = None,
langchain_prompt: Optional[BaseLangchainPrompt] = None,
langchain_prompt_selector: Optional[PromptSelector] = None,
stop_token: Optional[str] = None,
output_parser: Optional[BaseOutputParser] = None,
prompt_type: str = PromptType.CUSTOM,
metadata: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -61,7 +58,6 @@ def __init__(

self.partial_dict: Dict[str, Any] = {}
self.prompt_kwargs = prompt_kwargs
self.stop_token = stop_token
# NOTE: this is only used for token counting and testing
self.prompt_type = prompt_type

Expand Down Expand Up @@ -174,6 +170,4 @@ def get_full_format_args(self, kwargs: Dict) -> Dict[str, Any]:
"""
kwargs.update(self.partial_dict)
if self.stop_token is not None:
kwargs["stop"] = self.stop_token
return kwargs
1 change: 0 additions & 1 deletion llama_index/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@

DEFAULT_TEXT_TO_SQL_PROMPT = Prompt(
DEFAULT_TEXT_TO_SQL_TMPL,
stop_token="\nSQLResult:",
prompt_type=PromptType.TEXT_TO_SQL,
)

Expand Down

0 comments on commit 138034b

Please sign in to comment.