From 089303e6dfcc3ffc453ca33b4f8672209ddc62bf Mon Sep 17 00:00:00 2001 From: James Clarke Date: Thu, 13 Jul 2023 11:54:40 +0100 Subject: [PATCH 1/5] Add json params to notebook protocol --- edb/server/compiler/compiler.py | 2 +- edb/server/protocol/edgeql_ext.pyx | 2 +- edb/server/protocol/execute.pyx | 21 ++++++++++++++++----- edb/server/protocol/notebook_ext.pyx | 26 +++++++++++++++++++------- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 9f231667d08..8e3e07b5118 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -481,7 +481,7 @@ def compile_notebook( implicit_limit=implicit_limit, inline_typeids=False, inline_typenames=True, - json_parameters=False, + json_parameters=True, source=source, protocol_version=protocol_version, notebook=True, diff --git a/edb/server/protocol/edgeql_ext.pyx b/edb/server/protocol/edgeql_ext.pyx index d6128864f5a..0c831df5278 100644 --- a/edb/server/protocol/edgeql_ext.pyx +++ b/edb/server/protocol/edgeql_ext.pyx @@ -61,7 +61,7 @@ async def handle_request( try: if request.method == b'POST': if request.content_type and b'json' in request.content_type: - body = json.loads(request.body) + body = json.loads(request.body, parse_float=decimal.Decimal) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 5f9b7b40e1f..67f7972a642 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -533,16 +533,27 @@ async def execute_json( return None +class DecimalEncoder(json.JSONEncoder): + def encode(self, obj): + if isinstance(obj, dict): + return '{' + ', '.join( + f'{self.encode(k)}: {self.encode(v)}' + for (k, v) in obj.items() + ) + '}' + if isinstance(obj, list): + return '[' + ', '.join(map(self.encode, obj)) + ']' + if isinstance(obj, decimal.Decimal): + return f'{obj:f}' + return super().encode(obj) + + cdef bytes _encode_json_value(object val): - if isinstance(val, decimal.Decimal): - jarg = str(val) - else: - jarg = json.dumps(val) + jarg = json.dumps(val, cls=DecimalEncoder) return b'\x01' + jarg.encode('utf-8') -cdef bytes _encode_args(list args): +def _encode_args(list args) -> bytes: cdef: WriteBuffer out_buf = WriteBuffer.new() diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 67ee37b8692..6dfc3d282c5 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -21,6 +21,7 @@ import base64 import http import json import urllib.parse +import decimal import immutables @@ -46,7 +47,7 @@ from edb.server.pgproto.pgproto cimport ( include "./consts.pxi" -cdef tuple CURRENT_PROTOCOL = edbdef.CURRENT_PROTOCOL +cdef tuple CURRENT_PROTOCOL = (1, 0) ALLOWED_CAPABILITIES = ( enums.Capability.MODIFICATIONS | @@ -101,14 +102,17 @@ async def handle_request( return handle_error(request, response, ex) queries = None + params = None try: if request.method == b'POST': - body = json.loads(request.body) + body = json.loads(request.body, parse_float=decimal.Decimal) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') queries = body.get('queries') + allow_params = body.get('allow_params') + params = body.get('params') else: raise TypeError('expected a POST request') @@ -122,7 +126,7 @@ async def handle_request( response.status = http.HTTPStatus.OK try: - result = await execute(db, server, queries) + result = await execute(db, server, queries, allow_params, params or {}) except Exception as ex: return handle_error(request, response, ex) else: @@ -153,11 +157,11 @@ cdef class NotebookConnection(frontend.AbstractFrontendConnection): pass -async def execute(db, server, queries: list): +async def execute(db, server, queries: list, allow_params, params): dbv: dbview.DatabaseConnectionView = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, + protocol_version=CURRENT_PROTOCOL, ) compiler_pool = server.get_compiler_pool() units = await compiler_pool.compile_notebook( @@ -195,10 +199,18 @@ async def execute(db, server, queries: list): "disallowed in notebook", ) try: - if query_unit.in_type_args: + if not allow_params and query_unit.in_type_args: raise errors.QueryError( 'cannot use query parameters in tutorial') + args = [] + if query_unit.in_type_args: + for param in query_unit.in_type_args: + value = params.get(param.name) + args.append(value) + + bind_args = p_execute._encode_args(args) + fe_conn = NotebookConnection() dbv.start_implicit(query_unit) @@ -206,7 +218,7 @@ async def execute(db, server, queries: list): compiled = dbview.CompiledQuery( query_unit_group=query_unit_group) await p_execute.execute( - pgcon, dbv, compiled, b'', fe_conn=fe_conn, + pgcon, dbv, compiled, bind_args, fe_conn=fe_conn, skip_start=True, ) From b073b51c410d319c2b75a5aa4505e09c34b37a33 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Mon, 24 Jul 2023 16:29:37 +0100 Subject: [PATCH 2/5] Add parse endpoint + update to use binary encoded query params --- edb/server/compiler/compiler.py | 10 ++-- edb/server/protocol/notebook_ext.pyx | 88 +++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 8e3e07b5118..af291c70855 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -443,13 +443,14 @@ def compile_notebook( queries: List[str], protocol_version: defines.ProtocolVersion, implicit_limit: int = 0, + inject_implicit_typenames: bool = True, + output_format: enums.OutputFormat = enums.OutputFormat.BINARY, ) -> List[ Tuple[ bool, Union[dbstate.QueryUnit, Tuple[str, str, Dict[int, str]]] ] ]: - state = dbstate.CompilerConnectionState( user_schema=user_schema, global_schema=global_schema, @@ -476,12 +477,13 @@ def compile_notebook( ctx = CompileContext( compiler_state=self.state, state=state, - output_format=enums.OutputFormat.BINARY, + output_format=output_format, expected_cardinality_one=False, implicit_limit=implicit_limit, inline_typeids=False, - inline_typenames=True, - json_parameters=True, + inline_typenames=inject_implicit_typenames, + inline_objectids=inject_implicit_typenames, + json_parameters=False, source=source, protocol_version=protocol_version, notebook=True, diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 6dfc3d282c5..0df28d2be7d 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -21,7 +21,6 @@ import base64 import http import json import urllib.parse -import decimal import immutables @@ -97,6 +96,57 @@ async def handle_request( response.body = b'{"kind": "status", "status": "OK"}' return + if args == ['parse'] and request.method == b'POST': + try: + body = json.loads(request.body) + if not isinstance(body, dict): + raise TypeError( + 'the body of the request must be a JSON object') + query = body.get('query') + inject_typenames = body.get('inject_typenames') + json_output = body.get('json_output') + + if not query: + raise TypeError( + 'invalid notebook parse request: "query" is missing') + + _dbv, units = await parse( + db, server, [query], inject_typenames, json_output) + dbv: dbview.DatabaseConnectionView = _dbv + is_error, unit_or_error = units[0] + if is_error: + response.status = http.HTTPStatus.OK + response.body = json.dumps({ + 'kind': 'error', + 'error': { + 'message': unit_or_error[1], + 'type': unit_or_error[0], + } + }).encode() + return + else: + query_unit = unit_or_error + dbv.check_capabilities( + query_unit.capabilities, + ALLOWED_CAPABILITIES, + errors.UnsupportedCapabilityError, + "disallowed in notebook", + ) + in_type_id = base64.b64encode(query_unit.in_type_id) + in_type = base64.b64encode(query_unit.in_type_data) + + except Exception as ex: + return handle_error(request, response, ex) + else: + response.status = http.HTTPStatus.OK + response.custom_headers['EdgeDB-Protocol-Version'] = \ + f'{CURRENT_PROTOCOL[0]}.{CURRENT_PROTOCOL[1]}' + response.body = ( + b'{"kind": "parse_result", "in_type_id": "' + in_type_id + + b'", "in_type": "' + in_type + b'"}') + return + + if args != []: ex = Exception(f'Unknown path') return handle_error(request, response, ex) @@ -106,13 +156,14 @@ async def handle_request( try: if request.method == b'POST': - body = json.loads(request.body, parse_float=decimal.Decimal) + body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') queries = body.get('queries') - allow_params = body.get('allow_params') params = body.get('params') + inject_typenames = body.get('inject_typenames') + json_output = body.get('json_output') else: raise TypeError('expected a POST request') @@ -126,7 +177,10 @@ async def handle_request( response.status = http.HTTPStatus.OK try: - result = await execute(db, server, queries, allow_params, params or {}) + result = await execute( + db, server, queries, + base64.b64decode(params) if params else None, + inject_typenames, json_output) except Exception as ex: return handle_error(request, response, ex) else: @@ -157,7 +211,7 @@ cdef class NotebookConnection(frontend.AbstractFrontendConnection): pass -async def execute(db, server, queries: list, allow_params, params): +async def parse(db, server, queries, inject_typenames, json_output): dbv: dbview.DatabaseConnectionView = await server.new_dbview( dbname=db.name, query_cache=False, @@ -174,9 +228,19 @@ async def execute(db, server, queries: list, allow_params, params): queries, CURRENT_PROTOCOL, 50, # implicit limit + inject_typenames, + enums.OutputFormat.JSON if json_output else enums.OutputFormat.BINARY, ) + return dbv, units + + +async def execute( + db, server, queries: list, params, inject_typenames, json_output +): + _dbv, units = await parse( + db, server, queries, inject_typenames, json_output) + dbv: dbview.DatabaseConnectionView = _dbv result = [] - bind_data = None pgcon = await server.acquire_pgcon(db.name) try: await pgcon.sql_execute(b'START TRANSACTION;') @@ -199,18 +263,10 @@ async def execute(db, server, queries: list, allow_params, params): "disallowed in notebook", ) try: - if not allow_params and query_unit.in_type_args: + if not params and query_unit.in_type_args: raise errors.QueryError( 'cannot use query parameters in tutorial') - args = [] - if query_unit.in_type_args: - for param in query_unit.in_type_args: - value = params.get(param.name) - args.append(value) - - bind_args = p_execute._encode_args(args) - fe_conn = NotebookConnection() dbv.start_implicit(query_unit) @@ -218,7 +274,7 @@ async def execute(db, server, queries: list, allow_params, params): compiled = dbview.CompiledQuery( query_unit_group=query_unit_group) await p_execute.execute( - pgcon, dbv, compiled, bind_args, fe_conn=fe_conn, + pgcon, dbv, compiled, params or b'', fe_conn=fe_conn, skip_start=True, ) From 5c89787a0512b137b873e7d8a01c978c3c7965f0 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Wed, 26 Jul 2023 14:02:39 +0100 Subject: [PATCH 3/5] Add tests + fix bug --- edb/server/protocol/notebook_ext.pyx | 2 +- tests/test_http_notebook.py | 137 ++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 0df28d2be7d..ff0676c1c0c 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -228,7 +228,7 @@ async def parse(db, server, queries, inject_typenames, json_output): queries, CURRENT_PROTOCOL, 50, # implicit limit - inject_typenames, + inject_typenames if inject_typenames is not None else True, enums.OutputFormat.JSON if json_output else enums.OutputFormat.BINARY, ) return dbv, units diff --git a/tests/test_http_notebook.py b/tests/test_http_notebook.py index e21008c8d6d..f088105096c 100644 --- a/tests/test_http_notebook.py +++ b/tests/test_http_notebook.py @@ -21,6 +21,7 @@ import json import urllib +import base64 from edb.testbase import http as tb @@ -34,11 +35,25 @@ class TestHttpNotebook(tb.BaseHttpExtensionTest): def get_extension_name(cls): return 'notebook' - def run_queries(self, queries: List[str]): - req_data = { + def run_queries( + self, + queries: List[str], + params: Optional[str] = None, + *, + inject_typenames: Optional[bool] = None, + json_output: Optional[bool] = None + ): + req_data: dict[str, Any] = { 'queries': queries } + if params is not None: + req_data['params'] = params + if inject_typenames is not None: + req_data['inject_typenames'] = inject_typenames + if json_output is not None: + req_data['json_output'] = json_output + req = urllib.request.Request( self.http_addr, method='POST') # type: ignore req.add_header('Content-Type', 'application/json') @@ -51,6 +66,22 @@ def run_queries(self, queries: List[str]): resp_data = json.loads(response.read()) return resp_data + def parse_query(self, query: str): + req = urllib.request.Request( + self.http_addr + '/parse', method='POST') # type: ignore + req.add_header('Content-Type', 'application/json') + response = urllib.request.urlopen( + req, json.dumps({'query': query}).encode(), + context=self.tls_context + ) + + resp_data = json.loads(response.read()) + + if resp_data['kind'] != 'error': + self.assertIsNotNone(response.headers['EdgeDB-Protocol-Version']) + + return resp_data + def test_http_notebook_01(self): results = self.run_queries([ 'SELECT 1', @@ -273,3 +304,105 @@ def test_http_notebook_09(self): ]) self.assertNotIn('error', results['results'][0]) + + def test_http_notebook_10(self): + # Check that if no 'params' field is sent an error is still thrown + # when query contains params, to maintain behaviour of edgeql tutorial + results = self.run_queries(['select $test']) + + self.assert_data_shape(results, { + 'kind': 'results', + 'results': [{ + 'kind': 'error', + 'error': [ + 'QueryError', + 'cannot use query parameters in tutorial', + {} + ] + }] + }) + + def test_http_notebook_11(self): + results = self.run_queries(['select $test'], + 'AAAAAQAAAAAAAAAIdGVzdCBzdHI=') + + self.assert_data_shape(results, { + 'kind': 'results', + 'results': [{ + 'kind': 'data', + 'data': [ + str, + str, + 'RAAAABIAAQAAAAh0ZXN0IHN0cg==', + str + ] + }] + }) + + def test_http_notebook_12(self): + result = self.parse_query('select >$test_param') + + self.assert_data_shape(result, { + 'kind': 'parse_result', + 'in_type_id': str, + 'in_type': str + }); + + error_result = self.parse_query('select $invalid') + + self.assert_data_shape(error_result, { + 'kind': 'error', + 'error': { + 'type': 'QueryError', + 'message': 'missing a type cast before the parameter' + } + }) + + def test_http_notebook_13(self): + results = [] + for inject_typenames, json_output in [ + (None, None), + (True, None), + (False, None), + (None, True), + (None, False), + (True, True), + (True, False), + (False, True), + (False, False) + ]: + result = self.run_queries([''' + select { + some := 'free shape' + } + '''], inject_typenames=inject_typenames, json_output=json_output) + + self.assert_data_shape(result, { + 'kind': 'results', + 'results': [{ + 'kind': 'data', + 'data': [str, str, str, str] + }] + }) + + results.append(result) + + # Ideally we'd check the decoded data has/hasn't got the injected + # typeids/names but currently we can't decode the result, so just + # check the expected number of bytes were returned + self.assertEqual( + [ + len(base64.b64decode(result['results'][0]['data'][2])) + for result in results + ], + [80, 80, 33, 36, 80, 36, 80, 36, 33] + ) + + # JSON results should be encoded as str which has a stable type id + self.assertEqual( + [ + result['results'][0]['data'][0] == 'AAAAAAAAAAAAAAAAAAABAQ==' + for result in results + ], + [False, False, False, True, False, True, False, True, False] + ) From 268b882ce730c6158bde9de2378ea498b77f2d61 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Wed, 26 Jul 2023 14:10:28 +0100 Subject: [PATCH 4/5] Clean up old json params handling --- edb/server/protocol/edgeql_ext.pyx | 2 +- edb/server/protocol/execute.pyx | 21 +++++---------------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/edb/server/protocol/edgeql_ext.pyx b/edb/server/protocol/edgeql_ext.pyx index 0c831df5278..d6128864f5a 100644 --- a/edb/server/protocol/edgeql_ext.pyx +++ b/edb/server/protocol/edgeql_ext.pyx @@ -61,7 +61,7 @@ async def handle_request( try: if request.method == b'POST': if request.content_type and b'json' in request.content_type: - body = json.loads(request.body, parse_float=decimal.Decimal) + body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 67f7972a642..5f9b7b40e1f 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -533,27 +533,16 @@ async def execute_json( return None -class DecimalEncoder(json.JSONEncoder): - def encode(self, obj): - if isinstance(obj, dict): - return '{' + ', '.join( - f'{self.encode(k)}: {self.encode(v)}' - for (k, v) in obj.items() - ) + '}' - if isinstance(obj, list): - return '[' + ', '.join(map(self.encode, obj)) + ']' - if isinstance(obj, decimal.Decimal): - return f'{obj:f}' - return super().encode(obj) - - cdef bytes _encode_json_value(object val): - jarg = json.dumps(val, cls=DecimalEncoder) + if isinstance(val, decimal.Decimal): + jarg = str(val) + else: + jarg = json.dumps(val) return b'\x01' + jarg.encode('utf-8') -def _encode_args(list args) -> bytes: +cdef bytes _encode_args(list args): cdef: WriteBuffer out_buf = WriteBuffer.new() From 61aa7665982815e6adebdbe97ad4e0996b434e55 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Wed, 26 Jul 2023 15:59:38 +0100 Subject: [PATCH 5/5] Fix lint --- tests/test_http_notebook.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_http_notebook.py b/tests/test_http_notebook.py index f088105096c..07bb20406dc 100644 --- a/tests/test_http_notebook.py +++ b/tests/test_http_notebook.py @@ -36,13 +36,13 @@ def get_extension_name(cls): return 'notebook' def run_queries( - self, - queries: List[str], - params: Optional[str] = None, - *, - inject_typenames: Optional[bool] = None, - json_output: Optional[bool] = None - ): + self, + queries: List[str], + params: Optional[str] = None, + *, + inject_typenames: Optional[bool] = None, + json_output: Optional[bool] = None + ): req_data: dict[str, Any] = { 'queries': queries } @@ -346,7 +346,7 @@ def test_http_notebook_12(self): 'kind': 'parse_result', 'in_type_id': str, 'in_type': str - }); + }) error_result = self.parse_query('select $invalid')