Skip to content

Commit

Permalink
fix and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
uniqueg committed Sep 27, 2020
1 parent 925d7cd commit 934aa50
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 63 deletions.
163 changes: 158 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
"""Unit tests for TRS client."""

import pytest
import requests

from trs_cli.client import TRSClient
from trs_cli.errors import (
InvalidURI,
ContentTypeUnavailable, InvalidResponseError, InvalidURI,
InvalidResourceIdentifier,
)

MOCK_DOMAIN = "x.y.z"
MOCK_HOST = f"https://{MOCK_DOMAIN}"
MOCK_PORT = 4434
MOCK_BASE_PATH = "a/b/c"
MOCK_API = f"{MOCK_HOST}:{MOCK_PORT}/{MOCK_BASE_PATH}"
MOCK_ID = "123456"
MOCK_ID_INVALID = "N0T VAL!D"
MOCK_TRS_URI = f"trs://{MOCK_DOMAIN}/{MOCK_ID}"
MOCK_TRS_URI_VERSIONED = f"trs://{MOCK_DOMAIN}/{MOCK_ID}/versions/{MOCK_ID}"
MOCK_PORT = 4434
MOCK_BASE_PATH = "a/b/c"
MOCK_TOKEN = "MyT0k3n"
MOCK_DESCRIPTOR = "CWL"
MOCK_RESPONSE_INVALID = {"not": "valid"}
MOCK_ERROR = {
"code": 400,
"message": "BadRequest",
}
MOCK_FILE_WRAPPER = {
"checksum": [
{
"checksum": "checksum",
"type": "sha1",
}
],
"content": "content",
"url": "url",
}


def _raise(exception) -> None:
"""General purpose exception raiser."""
raise exception


class TestTRSClientConstructor:
Expand All @@ -31,7 +54,7 @@ def test_invidiual_parts(self):
use_http=True,
token=MOCK_TOKEN,
)
assert cli.uri == f"https://{MOCK_DOMAIN}:{MOCK_PORT}/{MOCK_BASE_PATH}"
assert cli.uri == MOCK_API

def test_trs_uri(self):
"""Provide TRS URI."""
Expand All @@ -47,6 +70,88 @@ def test_trs_uri_http(self):
assert cli.uri == f"http://{MOCK_DOMAIN}:80/ga4gh/trs/v2"


class TestGetDescriptor:
"""Test getter for primary descriptor."""

cli = TRSClient(
uri=MOCK_TRS_URI,
token=MOCK_TOKEN,
)
endpoint = (
f"{cli.uri}/tools/{MOCK_ID}/versions/{MOCK_ID}/{MOCK_DESCRIPTOR}"
"/descriptor"
)

def test_ConnectionError(self, monkeypatch):
"""Connection error occurs."""
monkeypatch.setattr(
'requests.get',
lambda *args, **kwargs: _raise(requests.exceptions.ConnectionError)
)
with pytest.raises(requests.exceptions.ConnectionError):
self.cli.get_descriptor(
id=MOCK_TRS_URI,
type='CWL',
)

def test_success(self, monkeypatch, requests_mock):
"""Returns 200 response."""
requests_mock.get(self.endpoint, json=MOCK_FILE_WRAPPER)
r = self.cli.get_descriptor(
type=MOCK_DESCRIPTOR,
id=MOCK_ID,
version_id=MOCK_ID,
)
assert r.dict() == MOCK_FILE_WRAPPER

def test_success_trs_uri(self, monkeypatch, requests_mock):
"""Returns 200 response with TRS URI."""
requests_mock.get(self.endpoint, json=MOCK_FILE_WRAPPER)
r = self.cli.get_descriptor(
type=MOCK_DESCRIPTOR,
id=MOCK_TRS_URI_VERSIONED,
)
assert r.dict() == MOCK_FILE_WRAPPER

def test_success_InvalidResponseError(self, requests_mock):
"""Returns 200 response but schema validation fails."""
requests_mock.get(self.endpoint, json={"not": "correct"})
with pytest.raises(InvalidResponseError):
self.cli.get_descriptor(
type=MOCK_DESCRIPTOR,
id=MOCK_ID,
version_id=MOCK_ID,
)

def test_no_success_valid_error_response(self, requests_mock):
"""Returns no 200 but valid error response."""
requests_mock.get(
self.endpoint,
json=MOCK_ERROR,
status_code=400,
)
r = self.cli.get_descriptor(
type=MOCK_DESCRIPTOR,
id=MOCK_ID,
version_id=MOCK_ID,
)
assert r.dict() == MOCK_ERROR

def test_no_success_InvalidResponseError(self, requests_mock):
"""Returns no 200 and error schema validation fails."""
requests_mock.get(
self.endpoint,
json=MOCK_RESPONSE_INVALID,
status_code=400,
)
with pytest.raises(InvalidResponseError):
self.cli.get_descriptor(
type=MOCK_DESCRIPTOR,
id=MOCK_ID,
version_id=MOCK_ID,
)


class TestGetHost:
"""Test domain/schema parser."""

Expand Down Expand Up @@ -111,7 +216,7 @@ def test_trs_uri_ip(self):


class TestGetToolIdVersionId:
"""Test too/version ID parser."""
"""Test tool/version ID parser."""

cli = TRSClient(uri=MOCK_TRS_URI)

Expand Down Expand Up @@ -179,3 +284,51 @@ def test_trs_uri_versioned_override(self):
version_id=MOCK_ID + MOCK_ID
)
assert res == (MOCK_ID, MOCK_ID + MOCK_ID)


class TestGetHeaders:
"""Test headers getter."""

cli = TRSClient(uri=MOCK_TRS_URI)

def test_no_args(self):
"""No arguments passed."""
self.cli._get_headers()
assert self.cli.headers['Accept'] == 'application/json'
assert 'Content-Type' not in self.cli.headers
assert 'Authorization' not in self.cli.headers

def test_all_args(self):
"""All arguments passed."""
self.cli.token = MOCK_TOKEN
self.cli._get_headers(
content_accept='text/plain',
content_type='application/json'
)
assert self.cli.headers['Authorization'] == f"Bearer {MOCK_TOKEN}"
assert self.cli.headers['Accept'] == 'text/plain'
assert self.cli.headers['Content-Type'] == 'application/json'


class TestValidateContentType:
"""Test content type validation."""

cli = TRSClient(uri=MOCK_TRS_URI)

def test_available_content_type(self):
"""Requested content type provided by service."""
available_types = ['application/json']
self.cli._validate_content_type(
requested_type='application/json',
available_types=available_types,
)
assert 'application/json' in available_types

def test_unavailable_content_type(self):
"""Requested content type not provided by service."""
available_types = ['application/json']
with pytest.raises(ContentTypeUnavailable):
self.cli._validate_content_type(
requested_type='not/available',
available_types=available_types,
)
111 changes: 78 additions & 33 deletions trs_cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
import requests
import socket
import sys
from typing import (Dict, Optional, Tuple, Union)
from typing import (Dict, List, Optional, Tuple, Union)
import urllib3
from urllib.parse import quote

from trs_cli.errors import (
exception_handler,
InvalidURI,
InvalidResourceIdentifier,
InvalidResponseError
InvalidResponseError,
ContentTypeUnavailable,
)
from trs_cli.models import Error, FileWrapper # noqa: F401

Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(
port = 80 if schema == 'http' else 443
self.uri = f"{schema}://{host}:{port}/{base_path}"
self.token = token
self.headers = self._get_headers()
self.headers = {}
logger.info(f"Instantiated client for: {self.uri}")

# TODO: implement methods to connect to various endpoints below, e.g.,:
Expand All @@ -96,26 +97,29 @@ def __init__(

def get_descriptor(
self,
id: str,
version_id: Optional[str],
type: str,
id: str,
version_id: Optional[str] = None,
accept: str = 'application/json',
token: Optional[str] = None
) -> Union[FileWrapper, Error]:
"""Get the tool descriptor for the specified tool.
Arguments:
id: A unique identifier of the tool, scoped to this registry OR
a hostname-based TRS URI. If TRS URIs include the version
information, passing a `version_id` is optional.
version_id: An optional identifier of the tool version, scoped
to this registry. It is optional if version info is included
in the TRS URI. If passed, then the existing `version_id`
retreived from the TRS URI is overridden.
type: The output type of the descriptor. Plain types return
the bare descriptor while the "non-plain" types return a
descriptor wrapped with metadata. Allowable values include
"CWL", "WDL", "NFL", "GALAXY", "PLAIN_CWL", "PLAIN_WDL",
"PLAIN_NFL", "PLAIN_GALAXY".
id: A unique identifier of the tool, scoped to this registry OR
a TRS URI. If a TRS URI is passed and includes the version
identifier, passing a `version_id` is optional. For more
information on TRS URIs, cf.
https://ga4gh.github.io/tool-registry-service-schemas/DataModel/#trs_uris
version_id: Identifier of the tool version, scoped to this
registry. It is optional if a TRS URI is passed and includes
version information. If provided nevertheless, then the
`version_id` retrieved from the TRS URI is overridden.
token: Bearer token for authentication. Set if required by TRS
implementation and if not provided when instatiating client or
if expired.
Expand All @@ -131,15 +135,29 @@ def get_descriptor(
trs_cli.errors.InvalidResponseError: The response could not be
validated against the API schema.
"""
id, version_id = self._get_tool_id_version_id(
tool_id=id,
version_id=version_id
# validate requested content type, set token and get request headers
self._validate_content_type(
requested_type=accept,
available_types=['application/json', 'text/plain'],
)
url = f"{self.uri}/tools/{id}/versions/{version_id}/{type}/descriptor"
if token:
self.token = token
self._get_headers(content_accept=accept)

# get/sanitize tool and version identifiers
_id, _version_id = self._get_tool_id_version_id(
tool_id=id,
version_id=version_id,
)

# build request URL
url = (
f"{self.uri}/tools/{_id}/versions/{_version_id}/{type}/"
"descriptor"
)
logger.info(f"Connecting to '{url}'...")

self._get_headers()
# send request and handle exceptions and error responses
try:
response = requests.get(
url=url,
Expand Down Expand Up @@ -173,24 +191,11 @@ def get_descriptor(
)
logger.info(
f"Retrieved descriptor of type '{type}' "
f"for tool '{id}', version '{version_id}'."
f"for tool '{_id}', version '{_version_id}'."
)

return response_val

def _get_headers(self) -> Dict:
"""Build dictionary of request headers.
Returns:
A dictionary of request headers
"""
headers: Dict = {
'Accept': 'application/json',
}
if self.token:
headers['Authorization'] = 'Bearer ' + self.token
return headers

def _get_host(
self,
uri: str,
Expand Down Expand Up @@ -228,9 +233,9 @@ def _get_host(

def _get_tool_id_version_id(
self,
tool_id: str,
tool_id: Optional[str] = None,
version_id: Optional[str] = None,
) -> Tuple[str, str]:
) -> Tuple[Optional[str], Optional[str]]:
"""
Return sanitized tool and/or version identifiers or extract them from
a TRS URI.
Expand Down Expand Up @@ -287,3 +292,43 @@ def _get_tool_id_version_id(
ret_version_id = quote(ret_version_id, safe='')

return (ret_tool_id, ret_version_id)

def _get_headers(
self,
content_accept: str = 'application/json',
content_type: Optional[str] = None,
) -> None:
"""Build dictionary of request headers.
Arguments:
content_accept: Requested MIME/content type.
content_type: Type of content sent with the request.
"""
self.headers['Accept'] = content_accept
if content_type:
self.headers['Content-Type'] = content_type
if self.token:
self.headers['Authorization'] = f"Bearer {self.token}"

def _validate_content_type(
self,
requested_type: str,
available_types: List[str] = ['application/json'],
) -> None:
"""Ensure that content type is among content types provided by the
service.
Arguments:
requested_type: Requested MIME/content type.
available_types: Content types provided by the service for a given
endpoint.
Raises:
ContentTypeUnavailable: The service does not provide the requested
content type.
"""
if requested_type not in available_types:
logger.error(
"Requested content type not provided by the service."
)
raise ContentTypeUnavailable
4 changes: 4 additions & 0 deletions trs_cli/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class InvalidResourceIdentifier(Exception):

class InvalidResponseError(Exception):
"""Exception raised when an invalid API response is encountered."""


class ContentTypeUnavailable(Exception):
"""Exception raised when an unavailable content type is requested."""
Loading

0 comments on commit 934aa50

Please sign in to comment.