Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ example, just like with `requests`, a tuple can be passed to the `auth` argument
from marklogic.client import Client
client = Client('http://localhost:8030', auth=('username', 'password'))

### MarkLogic Cloud Authentication

When connecting to a [MarkLogic Cloud instance](https://developer.marklogic.com/products/cloud/), you will need to set
the `cloud_api_key` and `base_path` arguments. You only need to specify a `host` as well, as port 443 and HTTPS will be
used by default. For example:

from marklogic.client import Client
client = Client(host='example.marklogic.cloud', cloud_api_key='some-key-value', base_path='/ml/example/manage')

You may still use a full base URL if you wish:

from marklogic.client import Client
client = Client('https://example.marklogic.cloud', cloud_api_key='some-key-value', base_path='/ml/example/manage')


## SSL

Configuring SSL connections is the same as
Expand Down
26 changes: 20 additions & 6 deletions marklogic/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import requests
from marklogic.cloud_auth import MarkLogicCloudAuth
from requests.auth import HTTPDigestAuth
from urllib.parse import urljoin

Expand All @@ -7,6 +8,7 @@ class Client(requests.Session):
def __init__(
self,
base_url: str = None,
base_path: str = None,
auth=None,
digest=None,
scheme: str = "http",
Expand All @@ -15,17 +17,25 @@ def __init__(
port: int = 0,
username: str = None,
password: str = None,
cloud_api_key: str = None,
):
if base_url:
self.base_url = base_url
else:
self.base_url = f"{scheme}://{host}:{port}"
super(Client, self).__init__()
self.verify = verify

if cloud_api_key:
port = 443 if port == 0 else port
scheme = "https"

self.base_url = base_url if base_url else f"{scheme}://{host}:{port}"
if base_path:
self.base_path = base_path if base_path.endswith("/") else base_path + "/"

if auth:
self.auth = auth
elif digest:
self.auth = HTTPDigestAuth(digest[0], digest[1])
elif cloud_api_key:
self.auth = MarkLogicCloudAuth(self.base_url, cloud_api_key, self.verify)
else:
self.auth = HTTPDigestAuth(username, password)

Expand All @@ -34,15 +44,19 @@ def request(self, method, url, *args, **kwargs):
Overrides the requests function to generate the complete URL before the request
is sent.
"""
url = urljoin(self.base_url, url)
if hasattr(self, "base_path"):
if url.startswith("/"):
url = url[1:]
url = self.base_path + url
return super(Client, self).request(method, url, *args, **kwargs)

def prepare_request(self, request, *args, **kwargs):
"""
Overrides the requests function to generate the complete URL before the
request is prepared. See
https://requests.readthedocs.io/en/latest/user/advanced/#prepared-requests for
more information on prepared requests.
more information on prepared requests. Note that this is invoked after the
'request' method is invoked.
"""
request.url = urljoin(self.base_url, request.url)
return super(Client, self).prepare_request(request, *args, **kwargs)
31 changes: 31 additions & 0 deletions marklogic/cloud_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from urllib.parse import urljoin

import requests
from requests.auth import AuthBase

# See https://requests.readthedocs.io/en/latest/user/advanced/#custom-authentication


class MarkLogicCloudAuth(AuthBase):
def __init__(self, base_url: str, api_key: str, verify):
self._base_url = base_url
self._verify = verify
self._generate_token(api_key)

def _generate_token(self, api_key: str):
response = requests.post(
urljoin(self._base_url, "/token"),
data={"grant_type": "apikey", "key": api_key},
verify=self._verify,
)

if response.status_code != 200:
message = f"Unable to generate token; status code: {response.status_code}"
message = f"{message}; cause: {response.text}"
raise ValueError(message)

self._access_token = response.json()["access_token"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need any error handling here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I forget if a 4xx gets thrown right away - I'll add a test case for it.


def __call__(self, r):
r.headers["Authorization"] = f"Bearer {self._access_token}"
return r
22 changes: 19 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,27 @@ def basic_client():

@pytest.fixture
def ssl_client():
return Client(host="localhost", scheme="https", port=8031,
digest=("python-test-user", "password"),
verify=False)
return Client(
host="localhost",
scheme="https",
port=8031,
digest=("python-test-user", "password"),
verify=False,
)


@pytest.fixture
def client_with_props():
return Client(host="localhost", port=8030, username="admin", password="admin")


@pytest.fixture
def cloud_config():
"""
To run the tests in test_cloud.py, set 'key' to a valid API key. Otherwise, each
test will be skipped.
"""
return {
"host": "support.test.marklogic.cloud",
"key": "changeme",
}
85 changes: 85 additions & 0 deletions tests/test_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest

from marklogic.client import Client

"""
This module is intended for manual testing where the cloud_config fixture
in conftest.py is modified to have a real API key and not "changeme" as a value.
"""

DEFAULT_BASE_PATH = "/ml/test/marklogic/manage"


def test_base_path_doesnt_end_with_slash(cloud_config):
if cloud_config["key"] == "changeme":
return

client = _new_client(cloud_config, DEFAULT_BASE_PATH)
_verify_client_works(client)


def test_base_path_ends_with_slash(cloud_config):
if cloud_config["key"] == "changeme":
return

client = _new_client(cloud_config, DEFAULT_BASE_PATH + "/")
_verify_client_works(client)


def test_base_url_used_instead_of_host(cloud_config):
if cloud_config["key"] == "changeme":
return

base_url = f"https://{cloud_config['host']}"
client = Client(
base_url, cloud_api_key=cloud_config["key"], base_path=DEFAULT_BASE_PATH
)
_verify_client_works(client)


def test_invalid_host():
with pytest.raises(ValueError) as err:
Client(
host="marklogic.com",
cloud_api_key="doesnt-matter-for-this-test",
base_path=DEFAULT_BASE_PATH,
)
assert str(err.value).startswith(
"Unable to generate token; status code: 403; cause: "
)


def test_invalid_api_key(cloud_config):
if cloud_config["key"] == "changeme":
return

with pytest.raises(ValueError) as err:
Client(
host=cloud_config["host"],
cloud_api_key="invalid-api-key",
base_path=DEFAULT_BASE_PATH,
)
assert (
'Unable to generate token; status code: 401; cause: {"statusCode":401,"errorMessage":"API Key is not valid."}'
== str(err.value)
)


def _new_client(cloud_config, base_path: str) -> Client:
return Client(
host=cloud_config["host"],
cloud_api_key=cloud_config["key"],
base_path=base_path,
)


def _verify_client_works(client):
# Verify that the request works regardless of whether the path starts with a slash
# or not.
_verify_search_response(client.get("v1/search?format=json"))
_verify_search_response(client.get("/v1/search?format=json"))


def _verify_search_response(response):
assert 200 == response.status_code
assert 1 == response.json()["start"]