From 04ed9df56d6ac7a9d032c89c26b9c013442cbce2 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 16 Apr 2024 13:47:13 -0700 Subject: [PATCH] Proper tests for generation, refs #9 --- README.md | 4 + datasette_query_assistant/__init__.py | 1 - .../templates/query_assistant.html | 4 +- pyproject.toml | 2 +- .../test_database_assistant_page.yaml | 81 +++++++++++++++++++ .../test_table_assistant_page.yaml | 81 +++++++++++++++++++ tests/conftest.py | 6 -- tests/test_query_assistant.py | 50 ++++++++++++ 8 files changed, 220 insertions(+), 9 deletions(-) create mode 100644 tests/cassettes/test_query_assistant/test_database_assistant_page.yaml create mode 100644 tests/cassettes/test_query_assistant/test_table_assistant_page.yaml delete mode 100644 tests/conftest.py diff --git a/README.md b/README.md index 24c18ac..ef822c9 100644 --- a/README.md +++ b/README.md @@ -40,3 +40,7 @@ To run the tests: ```bash pytest ``` +To re-generate the tests with refreshed examples from the Claude 3 API: +```bash +pytest -x --record-mode=rewrite --inline-snapshot=fix +``` diff --git a/datasette_query_assistant/__init__.py b/datasette_query_assistant/__init__.py index 910ee04..5228bb3 100644 --- a/datasette_query_assistant/__init__.py +++ b/datasette_query_assistant/__init__.py @@ -139,7 +139,6 @@ async def assistant(request, datasette): + urllib.parse.urlencode({"sql": sql}) ) - # Figure out tables table = request.args.get("table") schema = await get_schema(db, table) return Response.html( diff --git a/datasette_query_assistant/templates/query_assistant.html b/datasette_query_assistant/templates/query_assistant.html index e277e1f..efc6d15 100644 --- a/datasette_query_assistant/templates/query_assistant.html +++ b/datasette_query_assistant/templates/query_assistant.html @@ -17,7 +17,9 @@

Query assistant for {% if table %}{{ table }}{% else %}{{ database }}{% endi

- + {% if table %} + + {% endif %}

diff --git a/pyproject.toml b/pyproject.toml index 47db01b..a4343c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ CI = "https://github.com/datasette/datasette-query-assistant/actions" query_assistant = "datasette_query_assistant" [project.optional-dependencies] -test = ["pytest", "pytest-asyncio", "pytest-recording", "sqlite-utils"] +test = ["inline-snapshot", "pytest", "pytest-asyncio", "pytest-recording", "sqlite-utils"] [tool.pytest.ini_options] asyncio_mode = "strict" diff --git a/tests/cassettes/test_query_assistant/test_database_assistant_page.yaml b/tests/cassettes/test_query_assistant/test_database_assistant_page.yaml new file mode 100644 index 0000000..34a43b8 --- /dev/null +++ b/tests/cassettes/test_query_assistant/test_database_assistant_page.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "The table + schema is:\nCREATE TABLE foo (id integer primary key, name text)"}, {"role": + "assistant", "content": "Ask questions to generate SQL"}, {"role": "user", "content": + "How many rows in the sqlite_master table?"}, {"role": "assistant", "content": + "select count(*) from sqlite_master\n-- Count rows in the sqlite_master table"}, + {"role": "user", "content": "Show me all the data in the foo table"}, {"role": + "assistant", "content": "select"}], "model": "claude-3-haiku-20240307", "system": + "You answer questions by generating SQL queries using SQLite schema syntax.\nAlways + start with -- SQL comments explaining what you are about to do.\nNo yapping. + Output SQL with extensive SQL comments and nothing else.\nOr output SQL in a + sql tagged fenced markdown code block.\nReturn only one SQL SELECT query."}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + anthropic-version: + - '2023-06-01' + connection: + - keep-alive + content-length: + - '870' + content-type: + - application/json + host: + - api.anthropic.com + user-agent: + - AsyncAnthropic/Python 0.25.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - async:asyncio + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 0.25.2 + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.10.10 + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: !!binary | + H4sIAAAAAAAAA1SOy2rDMBBFf0XcZZHBiRMC+oFCyMJtdn1gVHucOJZHrjQiDcb/Xpy2i64uHA6X + M6FrYDDEU5Wvytvh8akI27a87i/l3j6/8PHQQ0NuIy0WxWhPBI3g3QJsjF0UywKNwTfkYFA7mxrK + iuxsuz5l63y9yYt8B40ofqwifSbimmA4OaeR7o9mQsdjkkp8TxxhVputhk/yj+1mjdqzEAvM6/SX + JfS1BNzHQD2oNvhBtd6/cZapIzmqRVnnVPDXqCw3qvYuDRx/RDnTIiuxH44wv/+GBrLRMwyIm0pS + YMzfAAAA//8DANhA6bQtAQAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 875704994c7a67f9-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 16 Apr 2024 20:45:15 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + request-id: + - req_01KDY28fjzAA9d615XULxyjW + via: + - 1.1 google + x-cloud-trace-context: + - c5f2822f3aa36f6cc77d4241a215077a + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/test_query_assistant/test_table_assistant_page.yaml b/tests/cassettes/test_query_assistant/test_table_assistant_page.yaml new file mode 100644 index 0000000..e5eef89 --- /dev/null +++ b/tests/cassettes/test_query_assistant/test_table_assistant_page.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "The table + schema is:\nCREATE TABLE foo (id integer primary key, name text)"}, {"role": + "assistant", "content": "Ask questions to generate SQL"}, {"role": "user", "content": + "How many rows in the sqlite_master table?"}, {"role": "assistant", "content": + "select count(*) from sqlite_master\n-- Count rows in the sqlite_master table"}, + {"role": "user", "content": "Count of rows in foo"}, {"role": "assistant", "content": + "select"}], "model": "claude-3-haiku-20240307", "system": "You answer questions + by generating SQL queries using SQLite schema syntax.\nAlways start with -- + SQL comments explaining what you are about to do.\nNo yapping. Output SQL with + extensive SQL comments and nothing else.\nOr output SQL in a sql tagged fenced + markdown code block.\nReturn only one SQL SELECT query."}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + anthropic-version: + - '2023-06-01' + connection: + - keep-alive + content-length: + - '853' + content-type: + - application/json + host: + - api.anthropic.com + user-agent: + - AsyncAnthropic/Python 0.25.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - async:asyncio + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 0.25.2 + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.10.10 + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: !!binary | + H4sIAAAAAAAAA0zOQUvDQBCG4b+yfJeqbCBNiqV7VURQBA8eikrYJpM2NJlps7O2EvLfJUXB08DD + 8PEOaCo4dGFbpPP7F6rWt8vHB34+n5669XH19Za/wkK/DzR9UQh+S7DopZ3Ah9AE9ayw6KSiFg5l + 62NFSZ7sfLOPSZZmizRPl7AIKoci0DESlwTHsW0t4mXRDWj4ELVQ2RMHuPlibiFR/1uWjxalsBIr + 3Pvwl6V0ngIux8GUElmvbq5N3UtnahHzwUli7iY2uiPDsdtQb6Q2vZyCafiis1pkZtRvWsL4+Vvb + kw/CcCCuCo09Y/wBAAD//wMAeHTZSzIBAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8757049e484f5c22-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 16 Apr 2024 20:45:16 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + request-id: + - req_01LtBxUFphUrt3NNjHgF6Pfm + via: + - 1.1 google + x-cloud-trace-context: + - 49ebe16a8310eedc8eea83e3dd710db8 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index af2ffcf..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import pytest - - -@pytest.fixture(autouse=True) -def patch_env(monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "mock-key") diff --git a/tests/test_query_assistant.py b/tests/test_query_assistant.py index 8918f89..7bfb388 100644 --- a/tests/test_query_assistant.py +++ b/tests/test_query_assistant.py @@ -1,8 +1,17 @@ from datasette.app import Datasette from datasette_query_assistant import get_related_tables +from inline_snapshot import snapshot import pytest_asyncio import pytest import sqlite_utils +import urllib + +pytestmark = [pytest.mark.vcr(ignore_localhost=True)] + + +@pytest.fixture(scope="module") +def vcr_config(): + return {"filter_headers": ["x-api-key"]} @pytest_asyncio.fixture @@ -30,6 +39,7 @@ def test_get_related_tables(): @pytest.mark.asyncio +@pytest.mark.vcr() async def test_database_assistant_page(datasette): response = await datasette.client.get("/test/-/assistant") assert response.status_code == 200 @@ -38,6 +48,26 @@ async def test_database_assistant_page(datasette): "
CREATE TABLE foo (id integer primary key, name text)
" in response.text ) + # Submit the form + csrftoken = response.cookies["ds_csrftoken"] + post_response = await datasette.client.post( + "/test/-/assistant", + cookies={ + "ds_csrftoken": csrftoken, + }, + data={ + "question": "Show me all the data in the foo table", + "csrftoken": csrftoken, + }, + ) + assert post_response.status_code == 302 + qs = dict(urllib.parse.parse_qsl(post_response.headers["location"].split("?")[1])) + assert qs["sql"] == snapshot( + """\ +select * from foo +-- Select all rows and columns from the foo table\ +""" + ) @pytest.mark.asyncio @@ -49,3 +79,23 @@ async def test_table_assistant_page(datasette): "
CREATE TABLE foo (id integer primary key, name text)
" in response.text ) + # Submit the form + csrftoken = response.cookies["ds_csrftoken"] + post_response = await datasette.client.post( + "/test/-/assistant", + cookies={ + "ds_csrftoken": csrftoken, + }, + data={ + "question": "Count of rows in foo", + "csrftoken": csrftoken, + }, + ) + assert post_response.status_code == 302 + qs = dict(urllib.parse.parse_qsl(post_response.headers["location"].split("?")[1])) + assert qs["sql"] == snapshot( + """\ +select count(*) from foo +-- Count the number of rows in the 'foo' table\ +""" + )