Skip to content

Commit

Permalink
Proper tests for generation, refs #9
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Apr 16, 2024
1 parent 537b3c8 commit 04ed9df
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 9 deletions.
4 changes: 4 additions & 0 deletions README.md
Expand Up @@ -40,3 +40,7 @@ To run the tests:
```bash
pytest
```
To re-record the API responses and update the inline snapshots in the tests with refreshed examples from the Claude 3 API:
```bash
pytest -x --record-mode=rewrite --inline-snapshot=fix
```
1 change: 0 additions & 1 deletion datasette_query_assistant/__init__.py
Expand Up @@ -139,7 +139,6 @@ async def assistant(request, datasette):
+ urllib.parse.urlencode({"sql": sql})
)

# Figure out tables
table = request.args.get("table")
schema = await get_schema(db, table)
return Response.html(
Expand Down
4 changes: 3 additions & 1 deletion datasette_query_assistant/templates/query_assistant.html
Expand Up @@ -17,7 +17,9 @@ <h1>Query assistant for {% if table %}{{ table }}{% else %}{{ database }}{% endi
<p><textarea style="width: 80%; height: 3em;" name="question" id="id_question"></textarea>
<p>
<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">
<input type="hidden" name="table" value="{{ table }}">
{% if table %}
<input type="hidden" name="table" value="{{ table }}">
{% endif %}
<input type="submit" value="Submit">
</p>
</form>
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -25,7 +25,7 @@ CI = "https://github.com/datasette/datasette-query-assistant/actions"
query_assistant = "datasette_query_assistant"

[project.optional-dependencies]
test = ["pytest", "pytest-asyncio", "pytest-recording", "sqlite-utils"]
test = ["inline-snapshot", "pytest", "pytest-asyncio", "pytest-recording", "sqlite-utils"]

[tool.pytest.ini_options]
asyncio_mode = "strict"
Expand Down
@@ -0,0 +1,81 @@
interactions:
- request:
body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "The table
schema is:\nCREATE TABLE foo (id integer primary key, name text)"}, {"role":
"assistant", "content": "Ask questions to generate SQL"}, {"role": "user", "content":
"How many rows in the sqlite_master table?"}, {"role": "assistant", "content":
"select count(*) from sqlite_master\n-- Count rows in the sqlite_master table"},
{"role": "user", "content": "Show me all the data in the foo table"}, {"role":
"assistant", "content": "select"}], "model": "claude-3-haiku-20240307", "system":
"You answer questions by generating SQL queries using SQLite schema syntax.\nAlways
start with -- SQL comments explaining what you are about to do.\nNo yapping.
Output SQL with extensive SQL comments and nothing else.\nOr output SQL in a
sql tagged fenced markdown code block.\nReturn only one SQL SELECT query."}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
anthropic-version:
- '2023-06-01'
connection:
- keep-alive
content-length:
- '870'
content-type:
- application/json
host:
- api.anthropic.com
user-agent:
- AsyncAnthropic/Python 0.25.2
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 0.25.2
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.10.10
method: POST
uri: https://api.anthropic.com/v1/messages
response:
body:
string: !!binary |
H4sIAAAAAAAAA1SOy2rDMBBFf0XcZZHBiRMC+oFCyMJtdn1gVHucOJZHrjQiDcb/Xpy2i64uHA6X
M6FrYDDEU5Wvytvh8akI27a87i/l3j6/8PHQQ0NuIy0WxWhPBI3g3QJsjF0UywKNwTfkYFA7mxrK
iuxsuz5l63y9yYt8B40ofqwifSbimmA4OaeR7o9mQsdjkkp8TxxhVputhk/yj+1mjdqzEAvM6/SX
JfS1BNzHQD2oNvhBtd6/cZapIzmqRVnnVPDXqCw3qvYuDRx/RDnTIiuxH44wv/+GBrLRMwyIm0pS
YMzfAAAA//8DANhA6bQtAQAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 875704994c7a67f9-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 16 Apr 2024 20:45:15 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
request-id:
- req_01KDY28fjzAA9d615XULxyjW
via:
- 1.1 google
x-cloud-trace-context:
- c5f2822f3aa36f6cc77d4241a215077a
status:
code: 200
message: OK
version: 1
@@ -0,0 +1,81 @@
interactions:
- request:
body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "The table
schema is:\nCREATE TABLE foo (id integer primary key, name text)"}, {"role":
"assistant", "content": "Ask questions to generate SQL"}, {"role": "user", "content":
"How many rows in the sqlite_master table?"}, {"role": "assistant", "content":
"select count(*) from sqlite_master\n-- Count rows in the sqlite_master table"},
{"role": "user", "content": "Count of rows in foo"}, {"role": "assistant", "content":
"select"}], "model": "claude-3-haiku-20240307", "system": "You answer questions
by generating SQL queries using SQLite schema syntax.\nAlways start with --
SQL comments explaining what you are about to do.\nNo yapping. Output SQL with
extensive SQL comments and nothing else.\nOr output SQL in a sql tagged fenced
markdown code block.\nReturn only one SQL SELECT query."}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
anthropic-version:
- '2023-06-01'
connection:
- keep-alive
content-length:
- '853'
content-type:
- application/json
host:
- api.anthropic.com
user-agent:
- AsyncAnthropic/Python 0.25.2
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 0.25.2
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.10.10
method: POST
uri: https://api.anthropic.com/v1/messages
response:
body:
string: !!binary |
H4sIAAAAAAAAA0zOQUvDQBCG4b+yfJeqbCBNiqV7VURQBA8eikrYJpM2NJlps7O2EvLfJUXB08DD
8PEOaCo4dGFbpPP7F6rWt8vHB34+n5669XH19Za/wkK/DzR9UQh+S7DopZ3Ah9AE9ayw6KSiFg5l
62NFSZ7sfLOPSZZmizRPl7AIKoci0DESlwTHsW0t4mXRDWj4ELVQ2RMHuPlibiFR/1uWjxalsBIr
3Pvwl6V0ngIux8GUElmvbq5N3UtnahHzwUli7iY2uiPDsdtQb6Q2vZyCafiis1pkZtRvWsL4+Vvb
kw/CcCCuCo09Y/wBAAD//wMAeHTZSzIBAAA=
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8757049e484f5c22-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 16 Apr 2024 20:45:16 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
request-id:
- req_01LtBxUFphUrt3NNjHgF6Pfm
via:
- 1.1 google
x-cloud-trace-context:
- 49ebe16a8310eedc8eea83e3dd710db8
status:
code: 200
message: OK
version: 1
6 changes: 0 additions & 6 deletions tests/conftest.py

This file was deleted.

50 changes: 50 additions & 0 deletions tests/test_query_assistant.py
@@ -1,8 +1,17 @@
from datasette.app import Datasette
from datasette_query_assistant import get_related_tables
from inline_snapshot import snapshot
import pytest_asyncio
import pytest
import sqlite_utils
import urllib

pytestmark = [pytest.mark.vcr(ignore_localhost=True)]


@pytest.fixture(scope="module")
def vcr_config():
    """Provide the VCR configuration used by the tests in this module.

    The ``x-api-key`` request header is listed under ``filter_headers`` so
    the API key is kept out of the recorded cassettes.
    """
    filtered_headers = ["x-api-key"]
    return {"filter_headers": filtered_headers}


@pytest_asyncio.fixture
Expand Down Expand Up @@ -30,6 +39,7 @@ def test_get_related_tables():


@pytest.mark.asyncio
@pytest.mark.vcr()
async def test_database_assistant_page(datasette):
response = await datasette.client.get("/test/-/assistant")
assert response.status_code == 200
Expand All @@ -38,6 +48,26 @@ async def test_database_assistant_page(datasette):
"<pre>CREATE TABLE foo (id integer primary key, name text)</pre>"
in response.text
)
# Submit the form
csrftoken = response.cookies["ds_csrftoken"]
post_response = await datasette.client.post(
"/test/-/assistant",
cookies={
"ds_csrftoken": csrftoken,
},
data={
"question": "Show me all the data in the foo table",
"csrftoken": csrftoken,
},
)
assert post_response.status_code == 302
qs = dict(urllib.parse.parse_qsl(post_response.headers["location"].split("?")[1]))
assert qs["sql"] == snapshot(
"""\
select * from foo
-- Select all rows and columns from the foo table\
"""
)


@pytest.mark.asyncio
Expand All @@ -49,3 +79,23 @@ async def test_table_assistant_page(datasette):
"<pre>CREATE TABLE foo (id integer primary key, name text)</pre>"
in response.text
)
# Submit the form
csrftoken = response.cookies["ds_csrftoken"]
post_response = await datasette.client.post(
"/test/-/assistant",
cookies={
"ds_csrftoken": csrftoken,
},
data={
"question": "Count of rows in foo",
"csrftoken": csrftoken,
},
)
assert post_response.status_code == 302
qs = dict(urllib.parse.parse_qsl(post_response.headers["location"].split("?")[1]))
assert qs["sql"] == snapshot(
"""\
select count(*) from foo
-- Count the number of rows in the 'foo' table\
"""
)

0 comments on commit 04ed9df

Please sign in to comment.