diff --git a/marklogic/client.py b/marklogic/client.py
index 6c6f912..61e23ab 100644
--- a/marklogic/client.py
+++ b/marklogic/client.py
@@ -3,6 +3,7 @@
 from marklogic.documents import DocumentManager
 from marklogic.rows import RowManager
 from marklogic.transactions import TransactionManager
+from marklogic.eval import EvalManager
 from requests.auth import HTTPDigestAuth
 from urllib.parse import urljoin
 
@@ -84,3 +85,9 @@ def transactions(self):
         if not hasattr(self, "_transactions"):
             self._transactions = TransactionManager(self)
         return self._transactions
+
+    @property
+    def eval(self):
+        if not hasattr(self, "_eval"):
+            self._eval = EvalManager(self)
+        return self._eval
diff --git a/marklogic/eval.py b/marklogic/eval.py
new file mode 100644
index 0000000..07f741a
--- /dev/null
+++ b/marklogic/eval.py
@@ -0,0 +1,153 @@
+import json
+
+from decimal import Decimal
+from marklogic.documents import Document
+from requests import Session
+from requests_toolbelt.multipart.decoder import MultipartDecoder
+
+"""
+Defines an EvalManager class to simplify usage of the "/v1/eval" REST
+endpoint defined at https://docs.marklogic.com/REST/POST/v1/eval.
+"""
+
+
+class EvalManager:
+    """
+    Provides a method to simplify sending an XQuery or
+    JavaScript eval request to the eval endpoint.
+    """
+
+    def __init__(self, session: Session):
+        self._session = session
+
+    def xquery(
+        self, xquery: str, vars: dict = None, return_response: bool = False, **kwargs
+    ):
+        """
+        Send an XQuery script to MarkLogic via a POST to the endpoint
+        defined at https://docs.marklogic.com/REST/POST/v1/eval.
+
+        :param xquery: an XQuery string
+        :param vars: a dict containing variables to include
+        :param return_response: boolean specifying if the entire original response
+        object should be returned (True) or if only the data should be returned (False)
+        upon a success (2xx) response. Note that if the status code of the response is
+        not 2xx, then the entire response is always returned.
+        :raises ValueError: if xquery is None.
+        """
+        if xquery is None:
+            raise ValueError("No script found; must specify a xquery")
+        return self.__send_request({"xquery": xquery}, vars, return_response, **kwargs)
+
+    def javascript(
+        self,
+        javascript: str,
+        vars: dict = None,
+        return_response: bool = False,
+        **kwargs
+    ):
+        """
+        Send a JavaScript script to MarkLogic via a POST to the endpoint
+        defined at https://docs.marklogic.com/REST/POST/v1/eval.
+
+        :param javascript: a JavaScript string
+        :param vars: a dict containing variables to include
+        :param return_response: boolean specifying if the entire original response
+        object should be returned (True) or if only the data should be returned (False)
+        upon a success (2xx) response. Note that if the status code of the response is
+        not 2xx, then the entire response is always returned.
+        :raises ValueError: if javascript is None.
+        """
+        if javascript is None:
+            raise ValueError("No script found; must specify a javascript")
+        return self.__send_request(
+            {"javascript": javascript}, vars, return_response, **kwargs
+        )
+
+    def __send_request(
+        self, data: dict, vars: dict = None, return_response: bool = False, **kwargs
+    ):
+        """
+        Send a script (XQuery or javascript) and possibly a dict of vars
+        to MarkLogic via a POST to the endpoint defined at
+        https://docs.marklogic.com/REST/POST/v1/eval.
+        """
+        if vars is not None:
+            data["vars"] = json.dumps(vars)
+        response = self._session.post("v1/eval", data=data, **kwargs)
+        return (
+            self.__process_response(response)
+            if response.status_code == 200 and not return_response
+            else response
+        )
+
+    def __process_response(self, response):
+        """
+        Process a multipart REST response by putting them in a list and
+        transforming each part based on the "X-Primitive" header.
+
+        :return: a list with one converted value per part, or None when the
+        response carries a "Content-Length" header.
+        """
+        # A response with a "Content-Length" header is not decoded as multipart;
+        # presumably this is the empty-result case (see test_xquery_empty_sequence).
+        if "Content-Length" in response.headers:
+            return None
+
+        parts = MultipartDecoder.from_response(response).parts
+        transformed_parts = []
+        for part in parts:
+            encoding = part.encoding
+            primitive_header = part.headers["X-Primitive".encode(encoding)].decode(
+                encoding
+            )
+            primitive_function = EvalManager.__primitive_value_converters.get(
+                primitive_header
+            )
+            if primitive_function is not None:
+                transformed_parts.append(primitive_function(part))
+            else:
+                # Primitive types without a converter fall back to the part's text.
+                transformed_parts.append(part.text)
+        return transformed_parts
+
+    # Maps an "X-Primitive" header value to a function that converts the
+    # corresponding multipart part into a Python value.
+    __primitive_value_converters = {
+        "integer": lambda part: int(part.text),
+        "decimal": lambda part: Decimal(part.text),
+        # Compare case-insensitively against "true" so that a true value is
+        # not reported as False (the previous '"False" == text' never matched).
+        "boolean": lambda part: part.text.strip().lower() == "true",
+        "string": lambda part: part.text,
+        "map": lambda part: json.loads(part.text),
+        "element()": lambda part: part.text,
+        "array": lambda part: json.loads(part.text),
+        "array-node()": lambda part: json.loads(part.text),
+        "object-node()": lambda part: EvalManager.__process_object_node_part(part),
+        "document-node()": lambda part: EvalManager.__process_document_node_part(part),
+        "binary()": lambda part: Document(
+            EvalManager.__get_decoded_uri_from_part(part), part.content
+        ),
+    }
+
+    def __get_decoded_uri_from_part(part):
+        """Return the part's "X-URI" header decoded with the part's encoding."""
+        encoding = part.encoding
+        return part.headers["X-URI".encode(encoding)].decode(encoding)
+
+    def __process_object_node_part(part):
+        """Parse the part as JSON, wrapping it in a Document when it has a URI."""
+        if b"X-URI" in part.headers:
+            return Document(
+                EvalManager.__get_decoded_uri_from_part(part), json.loads(part.text)
+            )
+        else:
+            return json.loads(part.text)
+
+    def __process_document_node_part(part):
+        """Return the part's text, wrapped in a Document when it has a URI."""
+        if b"X-URI" in part.headers:
+            return Document(EvalManager.__get_decoded_uri_from_part(part), part.text)
+        else:
+            return part.text
diff --git a/marklogic/rows.py b/marklogic/rows.py
index c4fd7fa..1b3abb8 100644
--- a/marklogic/rows.py
+++ b/marklogic/rows.py
@@ -3,24 +3,23 @@
 
 """
 Defines a RowManager class to simplify usage of the "/v1/rows" & "/v1/rows/graphql" REST
-endpoints defined at https://docs.marklogic.com/REST/POST/v1/rows/graphql
+endpoints defined at https://docs.marklogic.com/REST/POST/v1/rows/graphql.
 """
 
 
 class RowManager:
     """
-    Provides a method to simplify sending a GraphQL request to the GraphQL rows endpoint.
+    Provides a method to simplify sending a GraphQL
+    request to the GraphQL rows endpoint.
     """
 
     def __init__(self, session: Session):
         self._session = session
 
-    def graphql(
-        self, graphql_query: str, return_response: bool = False, *args, **kwargs
-    ):
+    def graphql(self, graphql_query: str, return_response: bool = False, **kwargs):
         """
         Send a GraphQL query to MarkLogic via a POST to the endpoint defined at
-        https://docs.marklogic.com/REST/POST/v1/rows/graphql
+        https://docs.marklogic.com/REST/POST/v1/rows/graphql.
 
         :param graphql_query: a GraphQL query string. Note - this is the query string
         only, not the entire query JSON object. See the following for more information:
@@ -69,18 +68,17 @@ def query(
         sparql: str = None,
         format: str = "json",
         return_response: bool = False,
-        *args,
         **kwargs
     ):
         """
         Send a query to MarkLogic via a POST to the endpoint defined at
-        https://docs.marklogic.com/REST/POST/v1/rows
+        https://docs.marklogic.com/REST/POST/v1/rows.
         Just like that endpoint, this function can be used for four different types of
         queries: Optic DSL, Serialized Optic, SQL, and SPARQL. The type of query
         processed by the function is dependent upon the parameter used in the call to
         the function.
         For more information about Optic and using the Optic DSL, SQL, and SPARQL,
-        see https://docs.marklogic.com/guide/app-dev/OpticAPI
+        see https://docs.marklogic.com/guide/app-dev/OpticAPI.
 
         If multiple query parameters are passed into the call, the function uses the
         query parameter that is first in the list: dsl, plan, sql, sparql.
diff --git a/test-app/src/main/ml-data/musicians/logo.png b/test-app/src/main/ml-data/musicians/logo.png
new file mode 100644
index 0000000..c2a0d50
Binary files /dev/null and b/test-app/src/main/ml-data/musicians/logo.png differ
diff --git a/tests/test_eval.py b/tests/test_eval.py
index ba2c634..bc27f10 100644
--- a/tests/test_eval.py
+++ b/tests/test_eval.py
@@ -1,24 +1,133 @@
+import decimal
+
+from marklogic.documents import Document
 from requests_toolbelt.multipart.decoder import MultipartDecoder
+from pytest import raises
 
 
-def test_eval(client):
-    """
-    This shows how a user would do an eval today. It's a good example of how a multipart/mixed
-    response is a little annoying to deal with, as it requires using the requests_toolbelt
-    library and a class called MultipartDecoder.
+def test_xquery_common_primitives(client):
+    parts = client.eval.xquery(
+        """(
+        'A', 1, 1.1, fn:false(), fn:doc('/musicians/logo.png'))
+        """
+    )
+    __verify_common_primitives(parts)
 
-    Client support for this might look like this:
-    response = client.eval.xquery("world")
 
-    And then it's debatable whether we want to do anything beyond what MultipartDecoder
-    is doing for handling the response.
-    """
-    response = client.post(
-        "v1/eval",
-        headers={"Content-type": "application/x-www-form-urlencoded"},
-        data={"xquery": "world"},
+def test_javascript_common_primitives(client):
+    parts = client.eval.javascript(
+        """xdmp.arrayValues([
+        'A', 1, 1.1, false, fn.doc('/musicians/logo.png')
+        ])"""
+    )
+    __verify_common_primitives(parts)
+
+
+def test_xquery_specific_primitives(client):
+    parts = client.eval.xquery(
+        """(
+        world,
+        object-node {'A': 'a'},
+        fn:doc('/doc2.xml'),
+        document {},
+        array-node {1, "23", 4}
+        )"""
     )
+    assert type(parts[0]) is str
+    assert "world" == parts[0]  # NOTE(review): the script passes `world` unquoted - confirm the intended literal
+    assert type(parts[1]) is dict
+    assert {"A": "a"} == parts[1]
+    assert type(parts[2]) is Document  # fn:doc result is expected as a Document carrying its URI
+    assert "/doc2.xml" == parts[2].uri
+    assert "world" in parts[2].content
+    assert type(parts[3]) is str  # the constructed document node is expected as plain text
+    assert '\n' == parts[3]
+    assert type(parts[4]) is list
+    assert "23" == parts[4][1]
+    assert 3 == len(parts[4])
+
+
+def test_javascript_specific_primitives(client):
+    parts = client.eval.javascript(
+        """xdmp.arrayValues([
+        {'A': 'a'},
+        ['Z', 'Y', 1],
+        fn.head(cts.search('Armstrong'))
+        ])"""
+    )
+    assert type(parts[0]) is dict
+    assert {"A": "a"} == parts[0]
+    assert type(parts[1]) is list
+    assert "Z" == parts[1][0]
+    assert 3 == len(parts[1])
+    assert type(parts[2]) is Document  # the search hit is expected as a Document with parsed JSON content
+    assert "/musicians/musician1.json" == parts[2].uri
+    assert {
+        "musician": {
+            "lastName": "Armstrong",
+            "firstName": "Louis",
+            "dob": "1901-08-04",
+            "instrument": ["trumpet", "vocal"],
+        }
+    } == parts[2].content
+
+
+def test_javascript_noquery(client):
+    with raises(ValueError, match="No script found; must specify a javascript"):
+        client.eval.javascript(None)
+
+
+def test_xquery_noquery(client):
+    with raises(ValueError, match="No script found; must specify a xquery"):
+        client.eval.xquery(None)
+
+
+def test_xquery_with_return_response(client):
+    response = client.eval.xquery("('A', 1, 1.1, fn:false())", return_response=True)
+    assert 200 == response.status_code
+    parts = MultipartDecoder.from_response(response).parts  # return_response=True gives the raw response, so decode it here
+    assert 4 == len(parts)
+
+
+def test_xquery_vars(client):
+    vars = {"word1": "hello", "word2": "world"}  # NOTE(review): local name shadows the builtin vars()
+    script = """
+    xquery version "1.0-ml";
+    declare variable $word1 as xs:string external;
+    declare variable $word2 as xs:string external;
+    fn:concat($word1, " ", $word2)
+    """
+    parts = client.eval.xquery(script, vars)
+    assert type(parts[0]) is str
+    assert "hello world" == parts[0]
+
+
+def test_javascript_vars(client):
+    vars = {"word1": "hello", "word2": "world"}  # NOTE(review): local name shadows the builtin vars()
+    parts = client.eval.javascript("xdmp.arrayValues([word1, word2])", vars)
+    assert type(parts[0]) is str
+    assert "hello" == parts[0]  # only the first returned value is checked
+
+
+def test_xquery_empty_sequence(client):
+    parts = client.eval.xquery("()")
+    assert parts is None  # an empty sequence is expected as None, not an empty list
+
+
+def test_javascript_script(client):
+    parts = client.eval.javascript("[]")
+    assert [[]] == parts  # the empty array is expected back as one value that is itself an empty list
+
 
-    decoder = MultipartDecoder.from_response(response)
-    content = decoder.parts[0].text
-    assert "world" == content
+def __verify_common_primitives(parts):  # shared assertions for the XQuery and JavaScript common-primitive tests
+    assert type(parts[0]) is str
+    assert "A" == parts[0]
+    assert type(parts[1]) is int
+    assert 1 == parts[1]
+    assert type(parts[2]) is decimal.Decimal
+    assert decimal.Decimal("1.1") == parts[2]
+    assert type(parts[3]) is bool  # exact type check; bool is a subclass of int
+    assert parts[3] is False  # NOTE(review): only the false value is exercised; a true value is never checked
+    assert type(parts[4]) is Document
+    assert "/musicians/logo.png" == parts[4].uri
+    assert b"PNG" in parts[4].content  # binary document content is expected as bytes