Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions marklogic/rows.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from requests import Session, Response
from requests import Session

"""
Defines a RowManager class to simplify usage of the "/v1/rows" & "/v1/rows/graphql" REST
Expand All @@ -15,7 +15,9 @@ class RowManager:
def __init__(self, session: Session):
self._session = session

def graphql(self, graphql_query: str, return_response: bool = False, *args, **kwargs):
def graphql(
self, graphql_query: str, return_response: bool = False, *args, **kwargs
):
"""
Send a GraphQL query to MarkLogic via a POST to the endpoint defined at
https://docs.marklogic.com/REST/POST/v1/rows/graphql
Expand Down Expand Up @@ -48,18 +50,28 @@ def graphql(self, graphql_query: str, return_response: bool = False, *args, **kw
"xml": "application/xml",
"csv": "text/csv",
"json-seq": "application/json-seq",
"mixed": "application/xml, multipart/mixed"
"mixed": "application/xml, multipart/mixed",
}

__query_format_switch = {
"json": lambda response: response.json(),
"xml": lambda response: response.text,
"csv": lambda response: response.text,
"json-seq": lambda response: response.text,
"mixed": lambda response: response
"mixed": lambda response: response,
}

def query(self, dsl: str = None, plan: dict = None, sql: str = None, sparql: str = None, format: str = "json", return_response: bool = False, *args, **kwargs):
def query(
self,
dsl: str = None,
plan: dict = None,
sql: str = None,
sparql: str = None,
format: str = "json",
return_response: bool = False,
*args,
**kwargs
):
"""
Send a query to MarkLogic via a POST to the endpoint defined at
https://docs.marklogic.com/REST/POST/v1/rows
Expand All @@ -86,10 +98,7 @@ def query(self, dsl: str = None, plan: dict = None, sql: str = None, sparql: str
headers["Content-Type"] = request_info["content-type"]
headers["Accept"] = RowManager.__accept_switch.get(format)
response = self._session.post(
"v1/rows",
headers=headers,
data=request_info["data"],
**kwargs
"v1/rows", headers=headers, data=request_info["data"], **kwargs
)
return (
RowManager.__query_format_switch.get(format)(response)
Expand All @@ -111,22 +120,15 @@ def __get_request_info(self, dsl: str, plan: dict, sql: str, sparql: str):
if dsl is not None:
return {
"content-type": "application/vnd.marklogic.querydsl+javascript",
"data": dsl
"data": dsl,
}
if plan is not None:
return {
"content-type": "application/json",
"data": plan
}
return {"content-type": "application/json", "data": plan}
if sql is not None:
return {
"content-type": "application/sql",
"data": sql
}
return {"content-type": "application/sql", "data": sql}
if sparql is not None:
return {
"content-type": "application/sparql-query",
"data": sparql
}
return {"content-type": "application/sparql-query", "data": sparql}
else:
raise ValueError("No query found; must specify one of: dsl, plan, sql, or sparql")
raise ValueError(
"No query found; must specify one of: dsl, plan, sql, or sparql"
)
24 changes: 18 additions & 6 deletions tests/test_graphql.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
def test_graphql(client):
data = client.rows.graphql("query musicianQuery { test_musician { lastName firstName dob } }")
data = client.rows.graphql(
"query musicianQuery { test_musician { lastName firstName dob } }"
)
musicians = data["data"]["test_musician"]
assert 4 == len(musicians)
assert 1 == len([m for m in musicians if m["lastName"] == "Armstrong"])


def test_graphql_return_response(client):
response = client.rows.graphql("query musicianQuery { test_musician { lastName firstName dob } }", return_response=True)
response = client.rows.graphql(
"query musicianQuery { test_musician { lastName firstName dob } }",
return_response=True,
)
assert 200 == response.status_code
data = response.json()
musicians = data["data"]["test_musician"]
Expand All @@ -15,11 +20,18 @@ def test_graphql_return_response(client):


def test_graphql_bad_graphql(client):
response = client.rows.graphql("query musicianQuery { test_musician { lastName firstName dob } ")
assert 1 == len(response['errors'])
assert 'GRAPHQL-PARSE: Error parsing the GraphQL request string => \nquery musicianQuery { test_musician { lastName firstName dob } ' == response['errors'][0]['message']
response = client.rows.graphql(
"query musicianQuery { test_musician { lastName firstName dob } "
)
assert 1 == len(response["errors"])
assert (
"GRAPHQL-PARSE: Error parsing the GraphQL request string => \nquery musicianQuery { test_musician { lastName firstName dob } "
== response["errors"][0]["message"]
)


def test_graphql_bad_user(not_rest_user_client):
response = not_rest_user_client.rows.graphql("query musicianQuery { test_musician { lastName firstName dob } }")
response = not_rest_user_client.rows.graphql(
"query musicianQuery { test_musician { lastName firstName dob } }"
)
assert 403 == response.status_code
31 changes: 24 additions & 7 deletions tests/test_query.py → tests/test_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

dsl_query = 'op.fromView("test","musician").orderBy(op.col("lastName"))'
serialized_query = '{"$optic":{"ns":"op", "fn":"operators", "args":[{"ns":"op", "fn":"from-view", "args":["test", "musician"]}, {"ns":"op", "fn":"order-by", "args":[{"ns":"op", "fn":"col", "args":["lastName"]}]}]}}'
sql_query = 'select * from musician order by lastName'
sparql_query = 'PREFIX musician: <http://marklogic.com/column/test/musician/> SELECT * WHERE {?s musician:lastName ?lastName} ORDER BY ?lastName'
sql_query = "select * from musician order by lastName"
sparql_query = "PREFIX musician: <http://marklogic.com/column/test/musician/> SELECT * WHERE {?s musician:lastName ?lastName} ORDER BY ?lastName"


def test_dsl_default(client):
Expand All @@ -14,7 +14,9 @@ def test_dsl_default(client):
def test_dsl_default_return_response(client):
response = client.rows.query(dsl_query, return_response=True)
assert 200 == response.status_code
verify_four_musicians_are_returned_in_json(response.json(), "test.musician.lastName")
verify_four_musicians_are_returned_in_json(
response.json(), "test.musician.lastName"
)


def test_query_bad_user(not_rest_user_client):
Expand All @@ -31,17 +33,22 @@ def test_dsl_xml(client):
data = client.rows.query(dsl_query, format="xml")
verify_four_musicians_are_returned_in_xml_string(data)


def test_dsl_csv(client):
data = client.rows.query(dsl_query, format="csv")
verify_four_musicians_are_returned_in_csv(data)


def test_dsl_json_seq(client):
data = client.rows.query(dsl_query, format="json-seq")
verify_four_musicians_are_returned_in_json_seq(data)


def test_dsl_mixed(client):
response = client.rows.query(dsl_query, format="mixed")
verify_four_musicians_are_returned_in_json(response.json(), "test.musician.lastName")
verify_four_musicians_are_returned_in_json(
response.json(), "test.musician.lastName"
)


def test_serialized_default(client):
Expand All @@ -60,15 +67,20 @@ def test_sparql_default(client):


def test_no_query_parameter_provided(client):
with raises(ValueError, match="No query found; must specify one of: dsl, plan, sql, or sparql"):
with raises(
ValueError,
match="No query found; must specify one of: dsl, plan, sql, or sparql",
):
client.rows.query()


def verify_four_musicians_are_returned_in_json(data, column_name):
assert type(data) is dict
assert 4 == len(data["rows"])
for index, musician in enumerate(["Armstrong", "Byron", "Coltrane", "Davis"]):
assert {'type': 'xs:string', 'value': musician} == data["rows"][index][column_name]
assert {"type": "xs:string", "value": musician} == data["rows"][index][
column_name
]


def verify_four_musicians_are_returned_in_xml_string(data):
Expand All @@ -81,7 +93,12 @@ def verify_four_musicians_are_returned_in_xml_string(data):
def verify_four_musicians_are_returned_in_csv(data):
assert type(data) is str
assert 5 == len(data.split("\n"))
for musician in ['Armstrong,Louis,1901-08-04', 'Byron,Don,1958-11-08', 'Coltrane,John,1926-09-23', 'Davis,Miles,1926-05-26']:
for musician in [
"Armstrong,Louis,1901-08-04",
"Byron,Don,1958-11-08",
"Coltrane,John,1926-09-23",
"Davis,Miles,1926-05-26",
]:
assert musician in data


Expand Down