diff --git a/tests/test_client.py b/tests/test_client.py index 0c3b479..b45dde7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,22 +1,45 @@ """Unit tests for TRS client.""" import pytest +import requests from trs_cli.client import TRSClient from trs_cli.errors import ( - InvalidURI, + ContentTypeUnavailable, InvalidResponseError, InvalidURI, InvalidResourceIdentifier, ) MOCK_DOMAIN = "x.y.z" MOCK_HOST = f"https://{MOCK_DOMAIN}" +MOCK_PORT = 4434 +MOCK_BASE_PATH = "a/b/c" +MOCK_API = f"{MOCK_HOST}:{MOCK_PORT}/{MOCK_BASE_PATH}" MOCK_ID = "123456" MOCK_ID_INVALID = "N0T VAL!D" MOCK_TRS_URI = f"trs://{MOCK_DOMAIN}/{MOCK_ID}" MOCK_TRS_URI_VERSIONED = f"trs://{MOCK_DOMAIN}/{MOCK_ID}/versions/{MOCK_ID}" -MOCK_PORT = 4434 -MOCK_BASE_PATH = "a/b/c" MOCK_TOKEN = "MyT0k3n" +MOCK_DESCRIPTOR = "CWL" +MOCK_RESPONSE_INVALID = {"not": "valid"} +MOCK_ERROR = { + "code": 400, + "message": "BadRequest", +} +MOCK_FILE_WRAPPER = { + "checksum": [ + { + "checksum": "checksum", + "type": "sha1", + } + ], + "content": "content", + "url": "url", +} + + +def _raise(exception) -> None: + """General purpose exception raiser.""" + raise exception class TestTRSClientConstructor: @@ -31,7 +54,7 @@ def test_invidiual_parts(self): use_http=True, token=MOCK_TOKEN, ) - assert cli.uri == f"https://{MOCK_DOMAIN}:{MOCK_PORT}/{MOCK_BASE_PATH}" + assert cli.uri == MOCK_API def test_trs_uri(self): """Provide TRS URI.""" @@ -47,6 +70,88 @@ def test_trs_uri_http(self): assert cli.uri == f"http://{MOCK_DOMAIN}:80/ga4gh/trs/v2" +class TestGetDescriptor: + """Test getter for primary descriptor.""" + + cli = TRSClient( + uri=MOCK_TRS_URI, + token=MOCK_TOKEN, + ) + endpoint = ( + f"{cli.uri}/tools/{MOCK_ID}/versions/{MOCK_ID}/{MOCK_DESCRIPTOR}" + "/descriptor" + ) + + def test_ConnectionError(self, monkeypatch): + """Connection error occurs.""" + monkeypatch.setattr( + 'requests.get', + lambda *args, **kwargs: _raise(requests.exceptions.ConnectionError) + ) + with pytest.raises(requests.exceptions.ConnectionError): + self.cli.get_descriptor( + id=MOCK_TRS_URI, + type='CWL', + ) + + def test_success(self, monkeypatch, requests_mock): + """Returns 200 response.""" + requests_mock.get(self.endpoint, json=MOCK_FILE_WRAPPER) + r = self.cli.get_descriptor( + type=MOCK_DESCRIPTOR, + id=MOCK_ID, + version_id=MOCK_ID, + ) + assert r.dict() == MOCK_FILE_WRAPPER + + def test_success_trs_uri(self, monkeypatch, requests_mock): + """Returns 200 response with TRS URI.""" + requests_mock.get(self.endpoint, json=MOCK_FILE_WRAPPER) + r = self.cli.get_descriptor( + type=MOCK_DESCRIPTOR, + id=MOCK_TRS_URI_VERSIONED, + ) + assert r.dict() == MOCK_FILE_WRAPPER + + def test_success_InvalidResponseError(self, requests_mock): + """Returns 200 response but schema validation fails.""" + requests_mock.get(self.endpoint, json={"not": "correct"}) + with pytest.raises(InvalidResponseError): + self.cli.get_descriptor( + type=MOCK_DESCRIPTOR, + id=MOCK_ID, + version_id=MOCK_ID, + ) + + def test_no_success_valid_error_response(self, requests_mock): + """Returns no 200 but valid error response.""" + requests_mock.get( + self.endpoint, + json=MOCK_ERROR, + status_code=400, + ) + r = self.cli.get_descriptor( + type=MOCK_DESCRIPTOR, + id=MOCK_ID, + version_id=MOCK_ID, + ) + assert r.dict() == MOCK_ERROR + + def test_no_success_InvalidResponseError(self, requests_mock): + """Returns no 200 and error schema validation fails.""" + requests_mock.get( + self.endpoint, + json=MOCK_RESPONSE_INVALID, + status_code=400, + ) + with pytest.raises(InvalidResponseError): + self.cli.get_descriptor( + type=MOCK_DESCRIPTOR, + id=MOCK_ID, + version_id=MOCK_ID, + ) + + class TestGetHost: """Test domain/schema parser.""" @@ -111,7 +216,7 @@ def test_trs_uri_ip(self): class TestGetToolIdVersionId: - """Test too/version ID parser.""" + """Test tool/version ID parser.""" cli = TRSClient(uri=MOCK_TRS_URI) @@ -179,3 +284,51 @@ def test_trs_uri_versioned_override(self): version_id=MOCK_ID + MOCK_ID ) assert res == (MOCK_ID, MOCK_ID + MOCK_ID) + + +class TestGetHeaders: + """Test headers getter.""" + + cli = TRSClient(uri=MOCK_TRS_URI) + + def test_no_args(self): + """No arguments passed.""" + self.cli._get_headers() + assert self.cli.headers['Accept'] == 'application/json' + assert 'Content-Type' not in self.cli.headers + assert 'Authorization' not in self.cli.headers + + def test_all_args(self): + """All arguments passed.""" + self.cli.token = MOCK_TOKEN + self.cli._get_headers( + content_accept='text/plain', + content_type='application/json' + ) + assert self.cli.headers['Authorization'] == f"Bearer {MOCK_TOKEN}" + assert self.cli.headers['Accept'] == 'text/plain' + assert self.cli.headers['Content-Type'] == 'application/json' + + +class TestValidateContentType: + """Test content type validation.""" + + cli = TRSClient(uri=MOCK_TRS_URI) + + def test_available_content_type(self): + """Requested content type provided by service.""" + available_types = ['application/json'] + self.cli._validate_content_type( + requested_type='application/json', + available_types=available_types, + ) + assert 'application/json' in available_types + + def test_unavailable_content_type(self): + """Requested content type not provided by service.""" + available_types = ['application/json'] + with pytest.raises(ContentTypeUnavailable): + self.cli._validate_content_type( + requested_type='not/available', + available_types=available_types, + ) diff --git a/trs_cli/client.py b/trs_cli/client.py index 6b7babc..03e00d1 100644 --- a/trs_cli/client.py +++ b/trs_cli/client.py @@ -7,7 +7,7 @@ import requests import socket import sys -from typing import (Dict, Optional, Tuple, Union) +from typing import (Dict, List, Optional, Tuple, Union) import urllib3 from urllib.parse import quote @@ -15,7 +15,8 @@ exception_handler, InvalidURI, InvalidResourceIdentifier, - InvalidResponseError + InvalidResponseError, + ContentTypeUnavailable, ) from trs_cli.models import Error, FileWrapper # noqa: F401 @@ -84,7 +85,7 @@ def __init__( port = 80 if schema == 'http' else 443 self.uri = f"{schema}://{host}:{port}/{base_path}" self.token = token - self.headers = self._get_headers() + self.headers = {} logger.info(f"Instantiated client for: {self.uri}") # TODO: implement methods to connect to various endpoints below, e.g.,: @@ -96,26 +97,29 @@ def __init__( def get_descriptor( self, - id: str, - version_id: Optional[str], type: str, + id: str, + version_id: Optional[str] = None, + accept: str = 'application/json', token: Optional[str] = None ) -> Union[FileWrapper, Error]: """Get the tool descriptor for the specified tool. Arguments: - id: A unique identifier of the tool, scoped to this registry OR - a hostname-based TRS URI. If TRS URIs include the version - information, passing a `version_id` is optional. - version_id: An optional identifier of the tool version, scoped - to this registry. It is optional if version info is included - in the TRS URI. If passed, then the existing `version_id` - retreived from the TRS URI is overridden. type: The output type of the descriptor. Plain types return the bare descriptor while the "non-plain" types return a descriptor wrapped with metadata. Allowable values include "CWL", "WDL", "NFL", "GALAXY", "PLAIN_CWL", "PLAIN_WDL", "PLAIN_NFL", "PLAIN_GALAXY". + id: A unique identifier of the tool, scoped to this registry OR + a TRS URI. If a TRS URI is passed and includes the version + identifier, passing a `version_id` is optional. For more + information on TRS URIs, cf. + https://ga4gh.github.io/tool-registry-service-schemas/DataModel/#trs_uris + version_id: Identifier of the tool version, scoped to this + registry. It is optional if a TRS URI is passed and includes + version information. If provided nevertheless, then the + `version_id` retrieved from the TRS URI is overridden. token: Bearer token for authentication. Set if required by TRS implementation and if not provided when instatiating client or if expired. @@ -131,15 +135,29 @@ def get_descriptor( trs_cli.errors.InvalidResponseError: The response could not be validated against the API schema. """ - id, version_id = self._get_tool_id_version_id( - tool_id=id, - version_id=version_id + # validate requested content type, set token and get request headers + self._validate_content_type( + requested_type=accept, + available_types=['application/json', 'text/plain'], ) - url = f"{self.uri}/tools/{id}/versions/{version_id}/{type}/descriptor" if token: self.token = token + self._get_headers(content_accept=accept) + + # get/sanitize tool and version identifiers + _id, _version_id = self._get_tool_id_version_id( + tool_id=id, + version_id=version_id, + ) + + # build request URL + url = ( + f"{self.uri}/tools/{_id}/versions/{_version_id}/{type}/" + "descriptor" + ) + logger.info(f"Connecting to '{url}'...") - self._get_headers() + # send request and handle exceptions and error responses try: response = requests.get( url=url, @@ -173,24 +191,11 @@ def get_descriptor( ) logger.info( f"Retrieved descriptor of type '{type}' " - f"for tool '{id}', version '{version_id}'." + f"for tool '{_id}', version '{_version_id}'." ) return response_val - def _get_headers(self) -> Dict: - """Build dictionary of request headers. - - Returns: - A dictionary of request headers - """ - headers: Dict = { - 'Accept': 'application/json', - } - if self.token: - headers['Authorization'] = 'Bearer ' + self.token - return headers - def _get_host( self, uri: str, @@ -228,9 +233,9 @@ def _get_host( def _get_tool_id_version_id( self, - tool_id: str, + tool_id: Optional[str] = None, version_id: Optional[str] = None, - ) -> Tuple[str, str]: + ) -> Tuple[Optional[str], Optional[str]]: """ Return sanitized tool and/or version identifiers or extract them from a TRS URI. @@ -287,3 +292,43 @@ def _get_tool_id_version_id( ret_version_id = quote(ret_version_id, safe='') return (ret_tool_id, ret_version_id) + + def _get_headers( + self, + content_accept: str = 'application/json', + content_type: Optional[str] = None, + ) -> None: + """Build dictionary of request headers. + + Arguments: + content_accept: Requested MIME/content type. + content_type: Type of content sent with the request. + """ + self.headers['Accept'] = content_accept + if content_type: + self.headers['Content-Type'] = content_type + if self.token: + self.headers['Authorization'] = f"Bearer {self.token}" + + def _validate_content_type( + self, + requested_type: str, + available_types: List[str] = ['application/json'], + ) -> None: + """Ensure that content type is among content types provided by the + service. + + Arguments: + requested_type: Requested MIME/content type. + available_types: Content types provided by the service for a given + endpoint. + + Raises: + ContentTypeUnavailable: The service does not provide the requested + content type. + """ + if requested_type not in available_types: + logger.error( + "Requested content type not provided by the service." + ) + raise ContentTypeUnavailable diff --git a/trs_cli/errors.py b/trs_cli/errors.py index 56ddc6e..2ce090c 100644 --- a/trs_cli/errors.py +++ b/trs_cli/errors.py @@ -27,3 +27,7 @@ class InvalidResourceIdentifier(Exception): class InvalidResponseError(Exception): """Exception raised when an invalid API response is encountered.""" + + +class ContentTypeUnavailable(Exception): + """Exception raised when an unavailable content type is requested.""" diff --git a/trs_cli/models.py b/trs_cli/models.py index d4552af..9f3fd62 100644 --- a/trs_cli/models.py +++ b/trs_cli/models.py @@ -9,7 +9,16 @@ from pydantic import AnyUrl, BaseModel, Field -class Checksum(BaseModel): +class CustomBaseModel(BaseModel): + """Settings subclass.""" + + class Config: + """Configuration for `pydantic` model class.""" + extra = 'forbid' + arbitrary_types_allowed = False + + +class Checksum(CustomBaseModel): checksum: str = Field( ..., description='The hex-string encoded checksum for the data. ' ) @@ -19,7 +28,7 @@ class Checksum(BaseModel): ) -class ChecksumRegister(BaseModel): +class ChecksumRegister(CustomBaseModel): checksum: str = Field( ..., description='The hex-string encoded checksum for the data. ' ) @@ -36,12 +45,12 @@ class DescriptorType(Enum): GALAXY = 'GALAXY' -class Error(BaseModel): +class Error(CustomBaseModel): code: int message: Optional[str] = 'Internal Server Error' -class FileWrapper(BaseModel): +class FileWrapper(CustomBaseModel): checksum: Optional[List[Checksum]] = Field( None, description='A production (immutable) tool version is required to have a hashcode. Not required otherwise, but might be useful to detect changes. ', @@ -59,7 +68,7 @@ class FileWrapper(BaseModel): ) -class FileWrapperRegister(BaseModel): +class FileWrapperRegister(CustomBaseModel): checksum: Optional[List[ChecksumRegister]] = Field( None, description='A production (immutable) tool version is required to have a hashcode. Not required otherwise, but might be useful to detect changes. ', @@ -88,7 +97,7 @@ class OtherType(Enum): OTHER = 'OTHER' -class Organization(BaseModel): +class Organization(CustomBaseModel): name: str = Field( ..., description='Name of the organization responsible for the service', @@ -101,7 +110,7 @@ class Organization(BaseModel): ) -class Organization1(BaseModel): +class Organization1(CustomBaseModel): name: str = Field( ..., description='Name of the organization responsible for the service', @@ -114,7 +123,7 @@ class Organization1(BaseModel): ) -class ServiceType(BaseModel): +class ServiceType(CustomBaseModel): artifact: str = Field( ..., description='Name of the API or GA4GH specification implemented. Official GA4GH types should be assigned as part of standards approval process. Custom artifacts are supported.', @@ -132,7 +141,7 @@ class ServiceType(BaseModel): ) -class ServiceTypeRegister(BaseModel): +class ServiceTypeRegister(CustomBaseModel): artifact: str = Field( ..., description='Name of the API or GA4GH specification implemented. Official GA4GH types should be assigned as part of standards approval process. Custom artifacts are supported.', @@ -150,7 +159,7 @@ class ServiceTypeRegister(BaseModel): ) -class ToolClass(BaseModel): +class ToolClass(CustomBaseModel): description: Optional[str] = Field( None, description='A longer explanation of what this class is and what it can accomplish.', @@ -161,7 +170,7 @@ class ToolClass(BaseModel): ) -class ToolClassRegister(BaseModel): +class ToolClassRegister(CustomBaseModel): description: Optional[str] = Field( None, description='A longer explanation of what this class is and what it can accomplish.', @@ -171,7 +180,7 @@ class ToolClassRegister(BaseModel): ) -class ToolClassRegisterId(BaseModel): +class ToolClassRegisterId(CustomBaseModel): description: Optional[str] = Field( None, description='A longer explanation of what this class is and what it can accomplish.', @@ -190,7 +199,7 @@ class FileType(Enum): OTHER = 'OTHER' -class ToolFile(BaseModel): +class ToolFile(CustomBaseModel): file_type: Optional[FileType] = None path: Optional[str] = Field( None, @@ -206,7 +215,7 @@ class FileType1(Enum): OTHER = 'OTHER' -class ToolFileRegister(BaseModel): +class ToolFileRegister(CustomBaseModel): file_type: Optional[FileType1] = None path: Optional[str] = Field( None, @@ -214,17 +223,17 @@ class ToolFileRegister(BaseModel): ) -class TypeRegister(BaseModel): +class TypeRegister(CustomBaseModel): __root__: str -class FilesRegister(BaseModel): +class FilesRegister(CustomBaseModel): file_wrapper: Optional[FileWrapperRegister] = None tool_file: Optional[ToolFileRegister] = None type: Optional[TypeRegister] = None -class ImageData(BaseModel): +class ImageData(CustomBaseModel): checksum: Optional[List[Checksum]] = Field( None, description='A production (immutable) tool version is required to have a hashcode. Not required otherwise, but might be useful to detect changes. This exposes the hashcode for specific image versions to verify that the container version pulled is actually the version that was indexed by the registry.', @@ -250,7 +259,7 @@ class ImageData(BaseModel): ) -class ImageDataRegister(BaseModel): +class ImageDataRegister(CustomBaseModel): checksum: Optional[List[ChecksumRegister]] = Field( None, description='A production (immutable) tool version is required to have a hashcode. Not required otherwise, but might be useful to detect changes. This exposes the hashcode for specific image versions to verify that the container version pulled is actually the version that was indexed by the registry.', @@ -276,7 +285,7 @@ class ImageDataRegister(BaseModel): ) -class Service(BaseModel): +class Service(CustomBaseModel): contactUrl: Optional[AnyUrl] = Field( None, description='URL of the contact for the provider of this service, e.g. a link to a contact form (RFC 3986 format), or an email (RFC 2368 format).', @@ -328,7 +337,7 @@ class Service(BaseModel): ) -class ServiceRegister(BaseModel): +class ServiceRegister(CustomBaseModel): contactUrl: Optional[AnyUrl] = Field( None, description='URL of the contact for the provider of this service, e.g. a link to a contact form (RFC 3986 format), or an email (RFC 2368 format).', @@ -380,7 +389,7 @@ class ServiceRegister(BaseModel): ) -class ToolVersion(BaseModel): +class ToolVersion(CustomBaseModel): author: Optional[List[str]] = Field( None, description='Contact information for the author of this version of the tool in the registry. (More complex authorship information is handled by the descriptor).', @@ -436,7 +445,7 @@ class ToolVersion(BaseModel): ) -class ToolVersionRegister(BaseModel): +class ToolVersionRegister(CustomBaseModel): author: Optional[List[str]] = Field( None, description='Contact information for the author of this version of the tool in the registry. (More complex authorship information is handled by the descriptor).', @@ -478,7 +487,7 @@ class ToolVersionRegister(BaseModel): ) -class ToolVersionRegisterId(BaseModel): +class ToolVersionRegisterId(CustomBaseModel): author: Optional[List[str]] = Field( None, description='Contact information for the author of this version of the tool in the registry. (More complex authorship information is handled by the descriptor).', @@ -525,7 +534,7 @@ class ToolVersionRegisterId(BaseModel): ) -class Tool(BaseModel): +class Tool(CustomBaseModel): aliases: Optional[List[str]] = Field( None, description='Support for this parameter is optional for tool registries that support aliases.\nA list of strings that can be used to identify this tool which could be straight up URLs. \nThis can be used to expose alternative ids (such as GUIDs) for a tool\nfor registries. Can be used to match tools across registries.', @@ -562,7 +571,7 @@ class Tool(BaseModel): ) -class ToolRegister(BaseModel): +class ToolRegister(CustomBaseModel): aliases: Optional[List[str]] = Field( None, description='Support for this parameter is optional for tool registries that support aliases. A list of strings that can be used to identify this tool which could be straight up URLs. This can be used to expose alternative ids (such as GUIDs) for a tool for registries. Can be used to match tools across registries.',