Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for query params + JSON output to notebook protocol #5831

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,14 @@ def compile_notebook(
queries: List[str],
protocol_version: defines.ProtocolVersion,
implicit_limit: int = 0,
inject_implicit_typenames: bool = True,
output_format: enums.OutputFormat = enums.OutputFormat.BINARY,
) -> List[
Tuple[
bool,
Union[dbstate.QueryUnit, Tuple[str, str, Dict[int, str]]]
]
]:

state = dbstate.CompilerConnectionState(
user_schema=user_schema,
global_schema=global_schema,
Expand All @@ -476,11 +477,12 @@ def compile_notebook(
ctx = CompileContext(
compiler_state=self.state,
state=state,
output_format=enums.OutputFormat.BINARY,
output_format=output_format,
expected_cardinality_one=False,
implicit_limit=implicit_limit,
inline_typeids=False,
inline_typenames=True,
inline_typenames=inject_implicit_typenames,
inline_objectids=inject_implicit_typenames,
json_parameters=False,
source=source,
protocol_version=protocol_version,
Expand Down
82 changes: 75 additions & 7 deletions edb/server/protocol/notebook_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ from edb.server.pgproto.pgproto cimport (

include "./consts.pxi"

cdef tuple CURRENT_PROTOCOL = edbdef.CURRENT_PROTOCOL
cdef tuple CURRENT_PROTOCOL = (1, 0)

ALLOWED_CAPABILITIES = (
enums.Capability.MODIFICATIONS |
Expand Down Expand Up @@ -96,11 +96,63 @@ async def handle_request(
response.body = b'{"kind": "status", "status": "OK"}'
return

if args == ['parse'] and request.method == b'POST':
try:
body = json.loads(request.body)
if not isinstance(body, dict):
raise TypeError(
'the body of the request must be a JSON object')
query = body.get('query')
inject_typenames = body.get('inject_typenames')
json_output = body.get('json_output')

if not query:
raise TypeError(
'invalid notebook parse request: "query" is missing')

_dbv, units = await parse(
db, server, [query], inject_typenames, json_output)
dbv: dbview.DatabaseConnectionView = _dbv
is_error, unit_or_error = units[0]
if is_error:
response.status = http.HTTPStatus.OK
response.body = json.dumps({
'kind': 'error',
'error': {
'message': unit_or_error[1],
'type': unit_or_error[0],
}
}).encode()
return
else:
query_unit = unit_or_error
dbv.check_capabilities(
query_unit.capabilities,
ALLOWED_CAPABILITIES,
errors.UnsupportedCapabilityError,
"disallowed in notebook",
)
in_type_id = base64.b64encode(query_unit.in_type_id)
in_type = base64.b64encode(query_unit.in_type_data)

except Exception as ex:
return handle_error(request, response, ex)
else:
response.status = http.HTTPStatus.OK
response.custom_headers['EdgeDB-Protocol-Version'] = \
f'{CURRENT_PROTOCOL[0]}.{CURRENT_PROTOCOL[1]}'
response.body = (
b'{"kind": "parse_result", "in_type_id": "' + in_type_id +
b'", "in_type": "' + in_type + b'"}')
return


if args != []:
ex = Exception(f'Unknown path')
return handle_error(request, response, ex)

queries = None
params = None

try:
if request.method == b'POST':
Expand All @@ -109,6 +161,9 @@ async def handle_request(
raise TypeError(
'the body of the request must be a JSON object')
queries = body.get('queries')
params = body.get('params')
inject_typenames = body.get('inject_typenames')
json_output = body.get('json_output')

else:
raise TypeError('expected a POST request')
Expand All @@ -122,7 +177,10 @@ async def handle_request(

response.status = http.HTTPStatus.OK
try:
result = await execute(db, server, queries)
result = await execute(
db, server, queries,
base64.b64decode(params) if params else None,
inject_typenames, json_output)
except Exception as ex:
return handle_error(request, response, ex)
else:
Expand Down Expand Up @@ -153,11 +211,11 @@ cdef class NotebookConnection(frontend.AbstractFrontendConnection):
pass


async def execute(db, server, queries: list):
async def parse(db, server, queries, inject_typenames, json_output):
dbv: dbview.DatabaseConnectionView = await server.new_dbview(
dbname=db.name,
query_cache=False,
protocol_version=edbdef.CURRENT_PROTOCOL,
protocol_version=CURRENT_PROTOCOL,
)
compiler_pool = server.get_compiler_pool()
units = await compiler_pool.compile_notebook(
Expand All @@ -170,9 +228,19 @@ async def execute(db, server, queries: list):
queries,
CURRENT_PROTOCOL,
50, # implicit limit
inject_typenames if inject_typenames is not None else True,
enums.OutputFormat.JSON if json_output else enums.OutputFormat.BINARY,
)
return dbv, units


async def execute(
db, server, queries: list, params, inject_typenames, json_output
):
_dbv, units = await parse(
db, server, queries, inject_typenames, json_output)
dbv: dbview.DatabaseConnectionView = _dbv
result = []
bind_data = None
pgcon = await server.acquire_pgcon(db.name)
try:
await pgcon.sql_execute(b'START TRANSACTION;')
Expand All @@ -195,7 +263,7 @@ async def execute(db, server, queries: list):
"disallowed in notebook",
)
try:
if query_unit.in_type_args:
if not params and query_unit.in_type_args:
raise errors.QueryError(
'cannot use query parameters in tutorial')

Expand All @@ -206,7 +274,7 @@ async def execute(db, server, queries: list):
compiled = dbview.CompiledQuery(
query_unit_group=query_unit_group)
await p_execute.execute(
pgcon, dbv, compiled, b'', fe_conn=fe_conn,
pgcon, dbv, compiled, params or b'', fe_conn=fe_conn,
skip_start=True,
)

Expand Down
137 changes: 135 additions & 2 deletions tests/test_http_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import json
import urllib
import base64

from edb.testbase import http as tb

Expand All @@ -34,11 +35,25 @@ class TestHttpNotebook(tb.BaseHttpExtensionTest):
def get_extension_name(cls):
return 'notebook'

def run_queries(self, queries: List[str]):
req_data = {
def run_queries(
self,
queries: List[str],
params: Optional[str] = None,
*,
inject_typenames: Optional[bool] = None,
json_output: Optional[bool] = None
):
req_data: dict[str, Any] = {
'queries': queries
}

if params is not None:
req_data['params'] = params
if inject_typenames is not None:
req_data['inject_typenames'] = inject_typenames
if json_output is not None:
req_data['json_output'] = json_output

req = urllib.request.Request(
self.http_addr, method='POST') # type: ignore
req.add_header('Content-Type', 'application/json')
Expand All @@ -51,6 +66,22 @@ def run_queries(self, queries: List[str]):
resp_data = json.loads(response.read())
return resp_data

def parse_query(self, query: str):
req = urllib.request.Request(
self.http_addr + '/parse', method='POST') # type: ignore
req.add_header('Content-Type', 'application/json')
response = urllib.request.urlopen(
req, json.dumps({'query': query}).encode(),
context=self.tls_context
)

resp_data = json.loads(response.read())

if resp_data['kind'] != 'error':
self.assertIsNotNone(response.headers['EdgeDB-Protocol-Version'])

return resp_data

def test_http_notebook_01(self):
results = self.run_queries([
'SELECT 1',
Expand Down Expand Up @@ -273,3 +304,105 @@ def test_http_notebook_09(self):
])

self.assertNotIn('error', results['results'][0])

def test_http_notebook_10(self):
# Check that if no 'params' field is sent an error is still thrown
# when query contains params, to maintain behaviour of edgeql tutorial
results = self.run_queries(['select <str>$test'])

self.assert_data_shape(results, {
'kind': 'results',
'results': [{
'kind': 'error',
'error': [
'QueryError',
'cannot use query parameters in tutorial',
{}
]
}]
})

def test_http_notebook_11(self):
results = self.run_queries(['select <str>$test'],
'AAAAAQAAAAAAAAAIdGVzdCBzdHI=')

self.assert_data_shape(results, {
'kind': 'results',
'results': [{
'kind': 'data',
'data': [
str,
str,
'RAAAABIAAQAAAAh0ZXN0IHN0cg==',
str
]
}]
})

def test_http_notebook_12(self):
result = self.parse_query('select <array<int32>>$test_param')

self.assert_data_shape(result, {
'kind': 'parse_result',
'in_type_id': str,
'in_type': str
})

error_result = self.parse_query('select $invalid')

self.assert_data_shape(error_result, {
'kind': 'error',
'error': {
'type': 'QueryError',
'message': 'missing a type cast before the parameter'
}
})

def test_http_notebook_13(self):
results = []
for inject_typenames, json_output in [
(None, None),
(True, None),
(False, None),
(None, True),
(None, False),
(True, True),
(True, False),
(False, True),
(False, False)
]:
result = self.run_queries(['''
select {
some := 'free shape'
}
'''], inject_typenames=inject_typenames, json_output=json_output)

self.assert_data_shape(result, {
'kind': 'results',
'results': [{
'kind': 'data',
'data': [str, str, str, str]
}]
})

results.append(result)

# Ideally we'd check the decoded data has/hasn't got the injected
# typeids/names but currently we can't decode the result, so just
# check the expected number of bytes were returned
self.assertEqual(
[
len(base64.b64decode(result['results'][0]['data'][2]))
for result in results
],
[80, 80, 33, 36, 80, 36, 80, 36, 33]
)

# JSON results should be encoded as str which has a stable type id
self.assertEqual(
[
result['results'][0]['data'][0] == 'AAAAAAAAAAAAAAAAAAABAQ=='
for result in results
],
[False, False, False, True, False, True, False, True, False]
)