diff --git a/.appveyor.yml b/.appveyor.yml index e1c040776..a8daff2f5 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -117,9 +117,7 @@ before_test: - "flake8 cis_interface" test_script: - # - "nosetests --nologcapture --with-coverage --cover-package=cis_interface -svx cis_interface/tests/test_runner.py" - - "nosetests -svx --nologcapture --with-coverage --cover-package=cis_interface cis_interface" - + - "cistest --nologcapture --with-coverage --cover-package=cis_interface -svx" after_test: - "codecov" diff --git a/.travis.yml b/.travis.yml index dc0bdd1af..aa763af0d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -159,7 +159,7 @@ before_script: script: - | - nosetests --nologcapture --with-coverage --cover-package=cis_interface -svx + cistest --nologcapture --with-coverage --cover-package=cis_interface -svx after_success: - coveralls diff --git a/MANIFEST.in b/MANIFEST.in index 2a4e1b2c9..c411fc505 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ +include VERSION include *.md include *.png include *.svg diff --git a/VERSION b/VERSION new file mode 100644 index 000000000..a2268e2de --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.3.1 \ No newline at end of file diff --git a/cis_interface/.cis_schema.yml b/cis_interface/.cis_schema.yml index 20d151528..4a4905688 100644 --- a/cis_interface/.cis_schema.yml +++ b/cis_interface/.cis_schema.yml @@ -2,28 +2,28 @@ comm: as_array: &id001 required: false type: boolean - field_names: &id002 + dtype: &id002 + required: false + type: string + field_names: &id003 required: false schema: type: string type: list - field_units: &id003 + field_units: &id004 required: false schema: type: string type: list - format_str: &id004 + format_str: &id005 required: false type: string - name: &id005 + name: &id006 required: true type: string - stype: &id006 + stype: &id007 required: false type: integer - type: &id007 - required: false - type: string units: &id008 required: false type: string @@ -45,66 +45,7 @@ connection: input_file: 
excludes: input required: true - schema: - append: &id009 - required: false - type: boolean - as_array: *id001 - comment: &id010 - dependencies: - filetype: - - ascii - - table - required: false - type: string - delimiter: &id011 - dependencies: - filetype: - - pandas - - table - required: false - type: string - field_names: *id002 - field_units: *id003 - filetype: &id012 - allowed: - - ascii - - binary - - map - - obj - - pandas - - pickle - - ply - - table - default: binary - required: false - type: string - format_str: *id004 - in_temp: &id013 - required: false - type: boolean - is_series: &id014 - required: false - type: boolean - name: *id005 - newline: &id015 - default: ' - - ' - required: false - type: string - stype: *id006 - type: *id007 - units: *id008 - use_astropy: &id016 - dependencies: - filetype: - - table - required: false - type: boolean - working_dir: &id017 - required: true - type: string + schema: file type: dict onexit: required: false @@ -120,24 +61,7 @@ connection: output_file: excludes: output required: true - schema: - append: *id009 - as_array: *id001 - comment: *id010 - delimiter: *id011 - field_names: *id002 - field_units: *id003 - filetype: *id012 - format_str: *id004 - in_temp: *id013 - is_series: *id014 - name: *id005 - newline: *id015 - stype: *id006 - type: *id007 - units: *id008 - use_astropy: *id016 - working_dir: *id017 + schema: file type: dict translator: required: false @@ -253,23 +177,65 @@ connection_schema_subtypes: - RMQComm - output file: - append: *id009 + append: + required: false + type: boolean as_array: *id001 - comment: *id010 - delimiter: *id011 - field_names: *id002 - field_units: *id003 - filetype: *id012 - format_str: *id004 - in_temp: *id013 - is_series: *id014 - name: *id005 - newline: *id015 - stype: *id006 - type: *id007 + comment: + dependencies: + filetype: + - ascii + - table + required: false + type: string + delimiter: + dependencies: + filetype: + - table + - pandas + required: false + type: string + 
dtype: *id002 + field_names: *id003 + field_units: *id004 + filetype: + allowed: + - binary + - ply + - table + - pickle + - pandas + - map + - ascii + - obj + default: binary + required: false + type: string + format_str: *id005 + in_temp: + required: false + type: boolean + is_series: + required: false + type: boolean + name: *id006 + newline: + default: ' + + ' + required: false + type: string + stype: *id007 units: *id008 - use_astropy: *id016 - working_dir: *id017 + use_astropy: + dependencies: + filetype: + - table + required: false + type: boolean + working_dir: + required: true + type: string file_schema_subtypes: AsciiFileComm: - ascii @@ -326,15 +292,7 @@ model: inputs: required: false schema: - schema: - as_array: *id001 - field_names: *id002 - field_units: *id003 - format_str: *id004 - name: *id005 - stype: *id006 - type: *id007 - units: *id008 + schema: comm type: dict type: list is_server: @@ -377,15 +335,7 @@ model: outputs: required: false schema: - schema: - as_array: *id001 - field_names: *id002 - field_units: *id003 - format_str: *id004 - name: *id005 - stype: *id006 - type: *id007 - units: *id008 + schema: comm type: dict type: list sourcedir: diff --git a/cis_interface/__init__.py b/cis_interface/__init__.py index 342ee0635..8f303f0dd 100644 --- a/cis_interface/__init__.py +++ b/cis_interface/__init__.py @@ -2,6 +2,8 @@ such that they can be run simultaneously, passing input back and forth.""" from cis_interface import platform import os +import sys +import nose if platform._is_win: # pragma: windows @@ -10,20 +12,46 @@ os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = 'T' -# from cis_interface import backwards -# from cis_interface import platform -# from cis_interface import config -# from cis_interface import tools -# from cis_interface import interface -# from cis_interface import drivers -# from cis_interface import dataio -# from cis_interface import tests -# from cis_interface import examples -# from cis_interface import runner +def 
run_nose(verbose=False, nocapture=False, stop=False, + nologcapture=False, withcoverage=False): # pragma: debug + r"""Run nose tests for the package. + Args: + verbose (bool, optional): If True, set nose option '-v' which + increases the verbosity. Defaults to False. + nocapture (bool, optional): If True, set nose option '--nocapture' + which allows messages to be printed to stdout. Defaults to False. + stop (bool, optional): If True, set nose option '--stop' which + stops tests at the first failure. Defaults to False. + nologcapture (bool, optional): If True, set nose option '--nologcapture' + which allows logged messages to be printed. Defaults to False. + withcoverage (bool, optional): If True, set nose option '--with-coverage' + which invokes coverage. Defaults to False. + + """ + error_code = 0 + nose_argv = sys.argv + nose_argv += ['--detailed-errors', '--exe'] + if verbose: + nose_argv.append('-v') + if nocapture: + nose_argv.append('--nocapture') + if stop: + nose_argv.append('--stop') + if nologcapture: + nose_argv.append('--nologcapture') + if withcoverage: + nose_argv.append('--with-coverage') + initial_dir = os.getcwd() + package_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(package_dir) + try: + nose.run(argv=nose_argv) + except BaseException: + error_code = -1 + finally: + os.chdir(initial_dir) + return error_code -# __all__ = ['backwards', 'platform', 'config', 'tools', -# 'interface', 'drivers', 'dataio', -# 'tests', 'examples', 'runner'] __all__ = [] diff --git a/cis_interface/backwards.py b/cis_interface/backwards.py index b2ea28580..e9159d4cb 100644 --- a/cis_interface/backwards.py +++ b/cis_interface/backwards.py @@ -15,6 +15,7 @@ file_type = types.FileType bytes_type = str unicode_type = str + string_type = str np_dtype_str = 'S' string_types = (str, unicode) else: # pragma: Python 3 @@ -26,6 +27,7 @@ file_type = sio.IOBase bytes_type = bytes unicode_type = str + string_type = str unicode = None np_dtype_str = 'S' string_types = 
(bytes, str) diff --git a/cis_interface/communication/CommBase.py b/cis_interface/communication/CommBase.py index eb190d266..3ef02f6f3 100755 --- a/cis_interface/communication/CommBase.py +++ b/cis_interface/communication/CommBase.py @@ -294,7 +294,7 @@ class CommBase(tools.CisClass): _commtype = 'default' _schema_type = 'comm' _schema = {'name': {'type': 'string', 'required': True}, - 'type': {'type': 'string', 'required': False}, # TODO: add values + 'dtype': {'type': 'string', 'required': False}, # TODO: add values 'units': {'type': 'string', 'required': False}, # TODO: add values 'format_str': {'type': 'string', 'required': False}, 'as_array': {'type': 'boolean', 'required': False}, diff --git a/cis_interface/datatypes/CisArrayType.py b/cis_interface/datatypes/CisArrayType.py new file mode 100644 index 000000000..d6f3a7a3b --- /dev/null +++ b/cis_interface/datatypes/CisArrayType.py @@ -0,0 +1,318 @@ +import numpy as np +import copy +from cis_interface import units +from cis_interface.datatypes import register_type +from cis_interface.datatypes.CisBaseType import CisBaseType + + +def data2dtype(data): + r"""Get numpy data type for an object. + + Args: + data (object): Python object. + + Returns: + np.dtype: Numpy data type. + + """ + data_nounits = units.get_data(data) + if isinstance(data_nounits, np.ndarray): + dtype = data_nounits.dtype + else: + dtype = np.array([data_nounits]).dtype + # dtype = np.dtype(type(data_nounits)) + return dtype + + +def dtype2definition(dtype): + r"""Get type definition from numpy data type. + + Args: + dtype (np.dtype): Numpy data type. + + Returns: + dict: Type definition. 
+ + """ + out = {} + if np.issubdtype(dtype, np.dtype('float')): + out['type'] = 'float' + elif np.issubdtype(dtype, np.dtype('uint')): + out['type'] = 'uint' + elif np.issubdtype(dtype, np.dtype('int')): + out['type'] = 'int' + elif np.issubdtype(dtype, np.dtype('complex')): + out['type'] = 'complex' + elif np.issubdtype(dtype, np.dtype('S')): + out['type'] = 'string' + else: + raise TypeError('Cannot find type string for dtype %s' % dtype) + out['precision'] = dtype.itemsize * 8 + return out + + +def definition2dtype(props): + r"""Get numpy data type for a type definition. + + Args: + props (dict): Type definition properties. + + Returns: + np.dtype: Numpy data type. + + """ + if props['type'] == 'string': + dtype_str = 'S%d' % (props['precision'] / 8) + else: + dtype_str = props['type'] + dtype_str += str(int(props['precision'])) + return np.dtype(dtype_str) + + +def data2definition(data): + r"""Get the full definition from the data. + + Args: + data (object): Python object. + + Returns: + dict: Type definition. + + """ + out = dtype2definition(data2dtype(data)) + out['units'] = units.get_units(data) + return out + + +@register_type +class CisScalarType(CisBaseType): + r"""Type associated with a scalar.""" + + name = 'scalar' + description = 'A scalar value with or without units.' + properties = {'type': { + 'description': 'The base type for each item.', + 'type': 'string', + 'enum': ['int', 'uint', 'float', 'complex', 'string']}, + 'precision': { + 'description': 'The size (in bits) of each item.', + 'type': 'number', + 'minimum': 1}, + 'units': { + 'description': 'Physical units.', + 'type': 'string'} + } + definition_properties = ['type'] + metadata_properties = ['type', 'precision', 'units'] + + @classmethod + def check_meta_compat(cls, k, v1, v2): + r"""Check that two metadata values are compatible. + + Args: + k (str): Key for the entry. + v1 (object): Value 1. + v2 (object): Value 2. 
+ + Returns: + bool: True if the two entries are compatible, False otherwise. + + """ + if k == 'units': + out = units.are_compatible(v1, v2) + elif k == 'precision': + out = (v1 <= v2) + else: + out = super(CisScalarType, cls).check_meta_compat(k, v1, v2) + return out + + @classmethod + def check_data(cls, data, typedef): + r"""Checks if data matches the provided type definition. + + Args: + obj (object): Object to be tested. + typedef (dict): Type properties that object should be tested + against. + + Returns: + bool: Truth of if the input object is of this type. + + """ + try: + datadef = data2definition(data) + except TypeError: + return False + datadef['typename'] = cls.name + for k, v in typedef.items(): + if not cls.check_meta_compat(k, datadef.get(k, None), v): + return False + return True + + @classmethod + def encode_type(cls, obj): + r"""Encode an object's type definition. + + Args: + obj (object): Object to encode. + + Returns: + dict: Encoded type definition. + + """ + out = data2definition(obj) + return out + + @classmethod + def encode_data(cls, obj, typedef): + r"""Encode an object's data. + + Args: + obj (object): Object to encode. + typedef (dict): Type definition that should be used to encode the + object. + + Returns: + string: Encoded object. + + """ + arr = cls.to_array(obj) + bytes = arr.tobytes() + return bytes + + @classmethod + def decode_data(cls, obj, typedef): + r"""Decode an object. + + Args: + obj (string): Encoded object to decode. + typedef (dict): Type definition that should be used to decode the + object. + + Returns: + object: Decoded object. + + """ + dtype = definition2dtype(typedef) + arr = np.fromstring(obj, dtype=dtype) + if 'shape' in typedef: + arr = arr.reshape(typedef['shape']) + return cls.from_array(arr, typedef['units']) + + @classmethod + def transform_type(cls, obj, typedef=None): + r"""Transform an object based on type info. + + Args: + obj (object): Object to transform. 
+ typedef (dict): Type definition that should be used to transform the + object. + + Returns: + object: Transformed object. + + """ + if typedef is None: + return obj + typedef0 = data2definition(obj) + typedef1 = copy.deepcopy(typedef0) + typedef1.update(**typedef) + dtype = definition2dtype(typedef1) + arr = cls.to_array(obj).astype(dtype) + out = cls.from_array(arr, typedef0['units']) + return units.convert_to(out, typedef1['units']) + + @classmethod + def to_array(cls, obj): + r"""Get np.array representation of the data. + + Args: + obj (object): Object to get array for. + + Returns: + np.ndarray: Array representation of object. + + """ + obj_nounits = units.get_data(obj) + if isinstance(obj_nounits, np.ndarray): + arr = obj_nounits + else: + arr = np.array([obj_nounits], dtype=data2dtype(obj_nounits)) + return arr + + @classmethod + def from_array(cls, arr, unit_str=None): + r"""Get object representation of the data. + + Args: + + Returns: + + """ + if (cls == CisScalarType) and (len(arr.shape) > 0): + out = arr[0] + else: + out = arr + if unit_str is not None: + out = units.add_units(arr, unit_str) + return out + + +@register_type +class Cis1DArrayType(CisScalarType): + r"""Type associated with a scalar.""" + + name = '1darray' + description = 'A 1D array with or without units.' + properties = dict(CisScalarType.properties, + length={ + 'description': 'Number of elements in the 1D array.', + 'type': 'number', + 'minimum': 1}) + metadata_properties = CisScalarType.metadata_properties + ['length'] + + @classmethod + def encode_type(cls, obj): + r"""Encode an object's type definition. + + Args: + obj (object): Object to encode. + + Returns: + dict: Encoded type definition. + + """ + out = super(Cis1DArrayType, cls).encode_type(obj) + out['length'] = len(obj) + return out + + +@register_type +class CisNDArrayType(CisScalarType): + r"""Type associated with a scalar.""" + + name = 'ndarray' + description = 'An ND array with or without units.' 
+ properties = dict(CisScalarType.properties, + shape={ + 'description': 'Shape of the ND array in each dimension.', + 'type': 'array', + 'items': { + 'type': 'integer', + 'minimum': 1}}) + metadata_properties = CisScalarType.metadata_properties + ['shape'] + + @classmethod + def encode_type(cls, obj): + r"""Encode an object's type definition. + + Args: + obj (object): Object to encode. + + Returns: + dict: Encoded type definition. + + """ + out = super(CisNDArrayType, cls).encode_type(obj) + out['shape'] = list(obj.shape) + return out diff --git a/cis_interface/datatypes/CisBaseType.py b/cis_interface/datatypes/CisBaseType.py new file mode 100644 index 000000000..04a28d21c --- /dev/null +++ b/cis_interface/datatypes/CisBaseType.py @@ -0,0 +1,325 @@ +import copy +import json +import jsonschema +from cis_interface import backwards + + +class CisBaseType(object): + r"""Base type that should be subclassed by user defined types. Attributes + should be overwritten to match the type. + + Arguments: + **kwargs: All keyword arguments are assumed to be type definition + properties which will be used to validate serialized/deserialized + messages. + + Attributes: + name (str): Name of the type for use in YAML files & form options. + description (str): A short description of the type. + properties (dict): JSON schema definitions for properties of the + type. + definition_properties (list): Type properties that are required for YAML + or form entries specifying the type. These will also be used to + validate type definitions. + metadata_properties (list): Type properties that are required for + deserializing instances of the type that have been serialized. + data_schema (dict): JSON schema for validating a JSON friendly + representation of the type. + + """ + + name = 'base' + description = 'A generic base type for users to build on.' 
+ properties = {} + definition_properties = [] + metadata_properties = [] + data_schema = {'description': 'JSON friendly version of type instance.', + 'type': 'string'} + _empty_msg = {} + sep = backwards.unicode2bytes(':CIS_TAG:') + + def __init__(self, **typedef): + typedef.setdefault('typename', self.name) + self.__class__.validate_definition(typedef) + self._typedef = typedef + + # Methods to be overridden by subclasses + @classmethod + def check_data(cls, data, typedef): + r"""Checks if data matches the provided type definition. + + Args: + obj (object): Object to be tested. + typedef (dict): Type properties that object should be tested + against. + + Returns: + bool: Truth of if the input object is of this type. + + """ + raise NotImplementedError("Method must be overridden by the subclass.") + + @classmethod + def encode_type(cls, obj): + r"""Encode an object's type definition. + + Args: + obj (object): Object to encode. + + Returns: + dict: Encoded type definition. + + """ + raise NotImplementedError("Method must be overridden by the subclass.") + + @classmethod + def encode_data(cls, obj, typedef): + r"""Encode an object's data. + + Args: + obj (object): Object to encode. + typedef (dict): Type definition that should be used to encode the + object. + + Returns: + string: Encoded object. + + """ + raise NotImplementedError("Method must be overridden by the subclass.") + + @classmethod + def decode_data(cls, obj, typedef): + r"""Decode an object. + + Args: + obj (string): Encoded object to decode. + typedef (dict): Type definition that should be used to decode the + object. + + Returns: + object: Decoded object. + + """ + raise NotImplementedError("Method must be overridden by the subclass.") + + @classmethod + def transform_type(cls, obj, typedef=None): + r"""Transform an object based on type info. + + Args: + obj (object): Object to transform. + typedef (dict): Type definition that should be used to transform the + object. 
+ + Returns: + object: Transformed object. + + """ + raise NotImplementedError("Method must be overridden by the subclass.") + + # Methods not to be modified by subclasses + @classmethod + def definition_schema(cls): + r"""JSON schema for validating a type definition.""" + out = {"$schema": "http://json-schema.org/draft-07/schema#", + 'title': cls.name, + 'description': cls.description, + 'type': 'object', + 'required': copy.deepcopy(cls.definition_properties), + 'properties': copy.deepcopy(cls.properties)} + out['required'] += ['typename'] + out['properties']['typename'] = { + 'description': 'Name of the type encoded.', + 'type': 'string', + 'enum': [cls.name]} + return out + + @classmethod + def metadata_schema(cls): + r"""JSON schema for validating a JSON serialization of the type.""" + out = cls.definition_schema() + out['required'] = copy.deepcopy(cls.metadata_properties) + out['required'] += ['typename'] + # out['required'] += ['data'] + # out['properties']['data'] = copy.deepcopy(cls.data_schema) + return out + + @classmethod + def validate_metadata(cls, obj): + r"""Validates an encoded object. + + Args: + obj (string): Encoded object to validate. + + """ + jsonschema.validate(obj, cls.metadata_schema()) + + @classmethod + def validate_definition(cls, obj): + r"""Validates a type definition. + + Args: + obj (object): Type definition to validate. + + """ + jsonschema.validate(obj, cls.definition_schema()) + + @classmethod + def check_meta_compat(cls, k, v1, v2): + r"""Check that two metadata values are compatible. + + Args: + k (str): Key for the entry. + v1 (object): Value 1. + v2 (object): Value 2. + + Returns: + bool: True if the two entries are compatible going from v1 to v2, + False otherwise. + + """ + return (v1 == v2) + + @classmethod + def check_encoded(cls, metadata, typedef=None): + r"""Checks if the metadata for an encoded object matches the type + definition. + + Args: + metadata (dict): Meta data to be tested. 
+ typedef (dict, optional): Type properties that object should + be tested against. Defaults to None and object may have + any values for the type properties (so long as they match + the schema. + + Returns: + bool: True if the metadata matches the type definition, False + otherwise. + + """ + try: + cls.validate_metadata(metadata) + except jsonschema.exceptions.ValidationError: + return False + if typedef is not None: + try: + cls.validate_definition(typedef) + except jsonschema.exceptions.ValidationError: + return False + for k, v in typedef.items(): + if not cls.check_meta_compat(k, metadata.get(k, None), v): + return False + return True + + @classmethod + def check_decoded(cls, obj, typedef=None): + r"""Checks if an object is of the this type. + + Args: + obj (object): Object to be tested. + typedef (dict): Type properties that object should be tested + against. If None, this will always return True. + + Returns: + bool: Truth of if the input object is of this type. + + """ + if typedef is None: + return True + try: + cls.validate_definition(typedef) + except jsonschema.exceptions.ValidationError: + return False + return cls.check_data(obj, typedef) + + @classmethod + def encode(cls, obj, typedef=None): + r"""Encode an object. + + Args: + obj (object): Object to encode. + typedef (dict, optional): Type properties that object should + be tested against. Defaults to None and object may have + any values for the type properties (so long as they match + the schema. + + Returns: + tuple(dict, bytes): Encoded object with type definition and data + serialized to bytes. 
+ + """ + if not cls.check_decoded(obj, typedef): + raise ValueError("Object is not correct type for encoding.") + obj_t = cls.transform_type(obj, typedef) + metadata = cls.encode_type(obj_t) + metadata['typename'] = cls.name + data = cls.encode_data(obj_t, metadata) + if not cls.check_encoded(metadata, typedef): + raise ValueError("Object was not encoded correctly.") + if not isinstance(data, backwards.bytes_type): + raise TypeError("Encoded data must be of type %s, not %s" % ( + backwards.bytes_type, type(data))) + return metadata, data + + @classmethod + def decode(cls, metadata, data, typedef=None): + r"""Decode an object. + + Args: + metadata (dict): Meta data describing the data. + data (bytes): Encoded data. + typedef (dict, optional): Type properties that decoded object should + be tested against. Defaults to None and object may have any + values for the type properties (so long as they match the schema. + + Returns: + object: Decoded object. + + """ + if not cls.check_encoded(metadata, typedef): + raise ValueError("Metadata does not match type definition.") + out = cls.decode_data(data, metadata) + if not cls.check_decoded(out, typedef): + raise ValueError("Object was not decoded correctly.") + out = cls.transform_type(out, typedef) + return out + + def serialize(self, obj, **kwargs): + r"""Serialize a message. + + Args: + obj (object): Python object to be formatted. + + Returns: + bytes, str: Serialized message. + + """ + metadata, data = self.__class__.encode(obj, self._typedef) + metadata.update(**kwargs) + msg = backwards.unicode2bytes(json.dumps(metadata, sort_keys=True)) + msg += self.sep + msg += data + return msg + + def deserialize(self, msg): + r"""Deserialize a message. + + Args: + msg (str, bytes): Message to be deserialized. + + Returns: + tuple(obj, dict): Deserialized message and header information. + + Raises: + TypeError: If msg is not bytes type (str on Python 2). 
+ + """ + if not isinstance(msg, backwards.bytes_type): + raise TypeError("Message to be deserialized is not bytes type.") + if len(msg) == 0: + obj = self._empty_msg + else: + metadata, data = msg.split(self.sep) + metadata = json.loads(backwards.bytes2unicode(metadata)) + obj = self.__class__.decode(metadata, data, self._typedef) + return obj diff --git a/cis_interface/datatypes/__init__.py b/cis_interface/datatypes/__init__.py new file mode 100644 index 000000000..92ccbc304 --- /dev/null +++ b/cis_interface/datatypes/__init__.py @@ -0,0 +1,53 @@ +import os +import glob +import importlib + + +_type_registry = {} + + +def register_type(type_class): + r"""Register a type class, recording methods for encoding/decoding. + + Args: + type_class (class): Class to be registered. + + """ + global _type_registry + type_name = type_class.name + if type_name in _type_registry: + raise ValueError("Type %s already registered." % type_name) + type_class._datatype = type_name + type_class._schema_type = 'type' + type_class._schema = type_class.definition_schema() + # TODO: Enable schema tracking once ported to jsonschema + # from cis_interface.schema import register_component + # register_component(type_class) + _type_registry[type_name] = type_class + return type_class + + +def import_all_types(): + r"""Import all types to ensure they are registered.""" + for x in glob.glob(os.path.join(os.path.dirname(__file__), '*.py')): + if not x.startswith('__'): + type_mod = os.path.basename(x)[:-3] + importlib.import_module('cis_interface.datatypes.%s' % type_mod) + + +import_all_types() + + +def get_type_class(type_name): + r"""Return a type class given it's name. + + Args: + type_name (str): Name of type class. + + Returns: + class: Type class. + + """ + if type_name not in _type_registry: + raise ValueError("Class for type %s could not be found." 
% type_name) + return _type_registry[type_name] diff --git a/cis_interface/datatypes/tests/__init__.py b/cis_interface/datatypes/tests/__init__.py new file mode 100644 index 000000000..35336cca7 --- /dev/null +++ b/cis_interface/datatypes/tests/__init__.py @@ -0,0 +1,16 @@ +import nose.tools as nt +from cis_interface import datatypes +from cis_interface.datatypes.CisArrayType import CisScalarType + + +def test_get_type_class(): + r"""Test get_type_class.""" + valid_types = ['scalar', '1darray', 'ndarray'] + for v in valid_types: + datatypes.get_type_class(v) + nt.assert_raises(ValueError, datatypes.get_type_class, 'invalid') + + +def test_error_duplicate(): + r"""Test error in register_type for duplicate.""" + nt.assert_raises(ValueError, datatypes.register_type, CisScalarType) diff --git a/cis_interface/datatypes/tests/test_CisArrayType.py b/cis_interface/datatypes/tests/test_CisArrayType.py new file mode 100644 index 000000000..37e1c7861 --- /dev/null +++ b/cis_interface/datatypes/tests/test_CisArrayType.py @@ -0,0 +1,131 @@ +import copy +import numpy as np +from cis_interface import units +from cis_interface.datatypes.tests import test_CisBaseType as parent + + +class TestCisScalarType(parent.TestCisBaseType): + r"""Test class for CisScalarType class with float.""" + _mod = 'CisArrayType' + _cls = 'CisScalarType' + _prec = 32 + _type = 'float' + _shape = 1 + _array_contents = None + + def __init__(self, *args, **kwargs): + super(TestCisScalarType, self).__init__(*args, **kwargs) + self._typedef = {'type': self._type} + if self._type == 'string': + dtype = 'S%d' % (self._prec / 8) + else: + dtype = '%s%d' % (self._type, self._prec) + if self._array_contents is None: + self._array = np.ones(self._shape, dtype) + else: + self._array = np.array(self._array_contents, dtype) + if self._cls == 'CisScalarType': + self._value = self._array[0] + else: + self._value = self._array + self._valid_encoded = [{'typename': self.import_cls.name, + 'type': self._type, + 
'precision': self._prec, + 'units': '', + 'data': self._value.tobytes()}] + self._valid_decoded = [self._value] + if self._type == 'string': + new_dtype = 'S%d' % (self._prec * 2 / 8) + else: + new_dtype = '%s%d' % (self._type, self._prec * 2) + prec_array = self._array.astype(new_dtype) + if self._cls == 'CisScalarType': + self._prec_value = prec_array[0] + else: + self._prec_value = prec_array + self._compatible_objects = [ + (self._value, self._value, None), + (self._value, self._prec_value, {'type': self._type, + 'precision': self._prec * 2})] + + +class TestCisScalarType_int(TestCisScalarType): + r"""Test class for CisScalarType class with int.""" + _type = 'int' + + +class TestCisScalarType_uint(TestCisScalarType): + r"""Test class for CisScalarType class with uint.""" + _type = 'uint' + + +class TestCisScalarType_complex(TestCisScalarType): + r"""Test class for CisScalarType class with complex.""" + _type = 'complex' + _prec = 64 + + +class TestCisScalarType_string(TestCisScalarType): + r"""Test class for CisScalarType class with string.""" + _type = 'string' + _array_contents = ['one', 'two', 'three'] + + +class TestCisScalarType_prec(TestCisScalarType): + r"""Test class for CisScalarType class with precision.""" + + def __init__(self, *args, **kwargs): + super(TestCisScalarType_prec, self).__init__(*args, **kwargs) + self._typedef['precision'] = self._prec + self._valid_encoded.append(copy.deepcopy(self._valid_encoded[0])) + self._invalid_encoded[-1]['precision'] = self._prec / 2 + # Version with incorrect precision + self._invalid_encoded.append(copy.deepcopy(self._valid_encoded[0])) + self._invalid_encoded[-1]['precision'] = self._prec * 2 + self._invalid_decoded.append(self._prec_value) + + +class TestCisScalarType_units(TestCisScalarType): + r"""Test class for CisScalarType class with units.""" + + def __init__(self, *args, **kwargs): + super(TestCisScalarType_units, self).__init__(*args, **kwargs) + self._typedef['units'] = 'cm' + # 
self._valid_encoded[-1]['units'] = 'cm' + self._valid_encoded.append(copy.deepcopy(self._valid_encoded[0])) + self._valid_encoded[-1]['units'] = 'cm' + self._valid_encoded.append(copy.deepcopy(self._valid_encoded[0])) + self._valid_encoded[-1]['units'] = 'm' + self._valid_decoded.append(copy.deepcopy(self._valid_decoded[0])) + self._valid_decoded[-1] = units.add_units(self._valid_decoded[-1], 'm') + # Version with incorrect units + self._invalid_encoded.append(copy.deepcopy(self._valid_encoded[0])) + self._invalid_encoded[-1]['units'] = 's' + + +class TestCis1DArrayType(TestCisScalarType): + r"""Test class for CisArrayType class.""" + _cls = 'Cis1DArrayType' + _shape = 10 + + def __init__(self, *args, **kwargs): + super(TestCis1DArrayType, self).__init__(*args, **kwargs) + self._valid_encoded[0]['length'] = len(self._array) + + def assert_result_equal(self, x, y): + r"""Assert that serialized/deserialized objects equal.""" + np.testing.assert_array_equal(x, y) + + +class TestCisNDArrayType(TestCisScalarType): + r"""Test class for CisArrayType class with 2D array.""" + _cls = 'CisNDArrayType' + _shape = (4, 5) + + def __init__(self, *args, **kwargs): + super(TestCisNDArrayType, self).__init__(*args, **kwargs) + self._valid_encoded[0]['shape'] = list(self._array.shape) + + def assert_result_equal(self, x, y): + r"""Assert that serialized/deserialized objects equal.""" + np.testing.assert_array_equal(x, y) diff --git a/cis_interface/datatypes/tests/test_CisBaseType.py b/cis_interface/datatypes/tests/test_CisBaseType.py new file mode 100644 index 000000000..c210074cb --- /dev/null +++ b/cis_interface/datatypes/tests/test_CisBaseType.py @@ -0,0 +1,211 @@ +import nose.tools as nt +import jsonschema +from cis_interface import backwards +from cis_interface.datatypes import CisBaseType +from cis_interface.tests import CisTestClassInfo + + +class TestCisBaseType(CisTestClassInfo): + r"""Test class for CisBaseType class.""" + + _mod = 'CisBaseType' + _cls = 'CisBaseType' + + 
def __init__(self, *args, **kwargs): + super(TestCisBaseType, self).__init__(*args, **kwargs) + self._empty_msg = backwards.unicode2bytes('') + self._typedef = {} + self._valid_encoded = [{'typename': self.import_cls.name, + 'data': 'nothing'}] + self._invalid_encoded = [{}] + self._valid_decoded = ['nothing'] + self._invalid_decoded = [None] + self._compatible_objects = [] + + @property + def mod(self): + r"""str: Absolute name of module containing class to be tested.""" + return 'cis_interface.datatypes.%s' % self._mod + + @property + def typedef(self): + r"""dict: Type definition.""" + out = self._typedef + out['typename'] = self.import_cls.name + return out + + @property + def inst_kwargs(self): + r"""dict: Keyword arguments for creating a class instance.""" + return self._typedef + + def assert_result_equal(self, x, y): + r"""Assert that serialized/deserialized objects equal.""" + nt.assert_equal(x, y) + + def test_definition_schema(self): + r"""Test definition schema.""" + s = self.import_cls.definition_schema() + # jsonschema.Draft3Validator.check_schema(s) + jsonschema.Draft4Validator.check_schema(s) + + def test_metadata_schema(self): + r"""Test metadata schema.""" + s = self.import_cls.metadata_schema() + # jsonschema.Draft3Validator.check_schema(s) + jsonschema.Draft4Validator.check_schema(s) + + def test_encode_data(self): + r"""Test encode/decode data & type.""" + if self._cls == 'CisBaseType': + for x in self._valid_decoded: + nt.assert_raises(NotImplementedError, self.import_cls.encode_type, x) + nt.assert_raises(NotImplementedError, self.import_cls.encode_data, + x, self.typedef) + for x in self._valid_encoded: + nt.assert_raises(NotImplementedError, self.import_cls.decode_data, + x['data'], self.typedef) + else: + for x in self._valid_decoded: + y = self.import_cls.encode_type(x) + z = self.import_cls.encode_data(x, y) + x2 = self.import_cls.decode_data(z, y) + self.assert_result_equal(x2, x) + + def test_check_encoded(self): + r"""Test 
check_encoded.""" + # Test invalid for incorrect typedef + nt.assert_equal(self.import_cls.check_encoded(self._valid_encoded[0], + {}), False) + # Test valid + for x in self._valid_encoded: + nt.assert_equal(self.import_cls.check_encoded(x, self.typedef), True) + # Test invalid + for x in self._invalid_encoded: + nt.assert_equal(self.import_cls.check_encoded(x, self.typedef), False) + + def test_check_decoded(self): + r"""Test check_decoded.""" + # Test always valid without typedef + nt.assert_equal(self.import_cls.check_decoded(None), True) + # Test always invalid with incorrect typedef + nt.assert_equal(self.import_cls.check_decoded(None, {}), False) + # Not implemented for base class + if self._cls == 'CisBaseType': + for x in self._valid_decoded: + nt.assert_raises(NotImplementedError, self.import_cls.check_decoded, + x, self.typedef) + else: + # Test valid + for x in self._valid_decoded: + nt.assert_equal(self.import_cls.check_decoded(x, self.typedef), True) + # Test invalid + for x in self._invalid_decoded: + nt.assert_equal(self.import_cls.check_decoded(x, self.typedef), False) + + def test_encode_errors(self): + r"""Test error on encode.""" + if self._cls == 'CisBaseType': + nt.assert_raises(NotImplementedError, self.import_cls.encode, + self._invalid_decoded[0], self.typedef) + else: + nt.assert_raises(ValueError, self.import_cls.encode, + self._invalid_decoded[0], self.typedef) + + def test_decode_errors(self): + r"""Test error on decode.""" + nt.assert_raises(ValueError, self.import_cls.decode, + self._invalid_encoded[0], self.typedef) + + def test_transform_type(self): + r"""Test transform_type.""" + if self._cls == 'CisBaseType': + nt.assert_raises(NotImplementedError, self.import_cls.transform_type, + None, None) + else: + for x, y, typedef in self._compatible_objects: + z = self.import_cls.transform_type(x, typedef) + self.assert_result_equal(z, y) + + def test_serialize(self): + r"""Test serialize/deserialize.""" + if self._cls == 'CisBaseType': + 
for x in self._valid_decoded: + nt.assert_raises(NotImplementedError, self.instance.serialize, x) + else: + for x in self._valid_decoded: + msg = self.instance.serialize(x) + y = self.instance.deserialize(msg) + self.assert_result_equal(y, x) + + def test_deserialize_error(self): + r"""Test error when deserializing message that is not bytes.""" + nt.assert_raises(TypeError, self.instance.deserialize, self) + + def test_deserialize_empty(self): + r"""Test call for empty string.""" + out = self.instance.deserialize(self._empty_msg) + self.assert_result_equal(out, self.instance._empty_msg) + # nt.assert_equal(out, self.instance._empty_msg) + + +class CisErrorType(CisBaseType.CisBaseType): + r"""Class with improper user defined methods.""" + + _check_encoded = True + _check_decoded = True + + @classmethod + def check_encoded(cls, metadata, typedef=None): + r"""Return constant.""" + return cls._check_encoded + + @classmethod + def check_decoded(cls, obj, typedef=None): + r"""Return constant.""" + return cls._check_decoded + + @classmethod + def encode_type(cls, obj): + r"""Encode type.""" + return {} + + @classmethod + def encode_data(cls, obj, typedef): + r"""Encode data.""" + return obj + + @classmethod + def decode_data(cls, obj, typedef): + r"""Decode data.""" + return obj + + @classmethod + def transform_type(cls, obj, typedef=None): + r"""Transform an object based on type info.""" + return obj + + +class CisErrorType_encode(CisErrorType): + _check_encoded = False + + +class CisErrorType_decode(CisErrorType): + _check_decoded = False + + +def test_encode_error_encoded(): + r"""Test error in encode for failed encode_data.""" + nt.assert_raises(ValueError, CisErrorType_encode.encode, + backwards.unicode2bytes('')) + + +def test_decode_error_decoded(): + r"""Test error in decode for failed decode_data.""" + nt.assert_raises(ValueError, CisErrorType_decode.decode, + {}, backwards.unicode2bytes('')) + + +def test_encode_error_bytes(): + r"""Test error in encode for
encode that doesn't produce bytes.""" + nt.assert_raises(TypeError, CisErrorType.encode, None) diff --git a/cis_interface/drivers/ConnectionDriver.py b/cis_interface/drivers/ConnectionDriver.py index 0206e9fc9..ae2d6d071 100755 --- a/cis_interface/drivers/ConnectionDriver.py +++ b/cis_interface/drivers/ConnectionDriver.py @@ -59,13 +59,13 @@ class ConnectionDriver(Driver): _schema = {'input': {'type': ['string', 'list'], 'required': True, 'schema': {'type': 'string'}, 'excludes': 'input_file'}, - 'input_file': {'type': 'dict', 'required': True, - 'excludes': 'input'}, + 'input_file': {'required': True, 'type': 'dict', + 'excludes': 'input', 'schema': 'file'}, 'output': {'type': ['string', 'list'], 'required': True, 'schema': {'type': 'string'}, 'excludes': 'output_file'}, - 'output_file': {'type': 'dict', 'required': True, - 'excludes': 'output'}, + 'output_file': {'required': True, 'type': 'dict', + 'excludes': 'output', 'schema': 'file'}, 'translator': {'type': ['function', 'list'], 'schema': {'type': 'function'}, 'required': False}, diff --git a/cis_interface/drivers/ModelDriver.py b/cis_interface/drivers/ModelDriver.py index 4c85da8ae..632d9da79 100755 --- a/cis_interface/drivers/ModelDriver.py +++ b/cis_interface/drivers/ModelDriver.py @@ -61,9 +61,15 @@ class ModelDriver(Driver): _schema_type = 'model' _schema = {'name': {'type': 'string', 'required': True}, 'language': {'type': 'string', 'required': True}, - 'working_dir': {'type': 'string', 'required': True}, 'args': {'type': ['list', 'string'], 'required': True, 'schema': {'type': 'string'}}, + 'inputs': {'type': 'list', 'required': False, + 'schema': {'type': 'dict', + 'schema': 'comm'}}, + 'outputs': {'type': 'list', 'required': False, + 'schema': {'type': 'dict', + 'schema': 'comm'}}, + 'working_dir': {'type': 'string', 'required': True}, 'is_server': {'type': 'boolean', 'required': False}, 'client_of': {'type': 'list', 'required': False, 'schema': {'type': 'string'}}, diff --git
a/cis_interface/schema.py b/cis_interface/schema.py index 92c2f26c1..ea77be61c 100644 --- a/cis_interface/schema.py +++ b/cis_interface/schema.py @@ -8,6 +8,7 @@ import collections from cis_interface.drivers import import_all_drivers from cis_interface.communication import import_all_comms +from cis_interface.datatypes import import_all_types _schema_fname = os.path.abspath(os.path.join( @@ -71,6 +72,7 @@ def init_registry(): if not _registry_complete: import_all_drivers() import_all_comms() + import_all_types() _registry_complete = True @@ -179,7 +181,7 @@ def str_to_function(value): return out -class SchemaValidator(cerberus.Validator): +class CisSchemaValidator(cerberus.Validator): r"""Class for validating the schema.""" types_mapping = cerberus.Validator.types_mapping.copy() @@ -187,7 +189,7 @@ class SchemaValidator(cerberus.Validator): cis_type_order = ['list', 'string', 'integer', 'boolean', 'function'] def _resolve_rules_set(self, *args, **kwargs): - rules = super(SchemaValidator, self)._resolve_rules_set(*args, **kwargs) + rules = super(CisSchemaValidator, self)._resolve_rules_set(*args, **kwargs) if isinstance(rules, collections.Mapping): rules = self._add_coerce(rules) return rules @@ -241,19 +243,45 @@ class ComponentSchema(dict): Args: schema_type (str): The name of the component. - subtype_attr (str, optional): The attribute that should be used to - log subtypes. Defaults to None. + schema_registry (SchemaRegistry, optional): Registry of schemas + that this schema is dependent on. **kwargs: Additional keyword arguments are entries in the component schema. 
""" + _subtype_keys = {'model': 'language', 'comm': 'commtype', + 'file': 'filetype'} # , 'type': 'datatype'} - def __init__(self, schema_type, subtype_attr=None, **kwargs): + def __init__(self, schema_type, schema_registry=None, **kwargs): + self.schema_registry = schema_registry self._schema_type = schema_type - self._subtype_attr = subtype_attr + self._subtype_key = self._subtype_keys.get(schema_type, None) + if self._subtype_key is not None: + self._subtype_attr = '_' + self._subtype_key + else: + self._subtype_attr = None self._schema_subtypes = {} super(ComponentSchema, self).__init__(**kwargs) + @classmethod + def from_registry(cls, schema_type, schema_classes, **kwargs): + r"""Construct a ComponentSchema from a registry entry. + + Args: + schema_type (str): Name of component type to build. + schema_classes (list): List of classes for the component type. + **kwargs: Additional keyword arguments are passed to the class + __init__ method. + + Returns: + ComponentSchema: Schema with information from classes. 
+ + """ + out = cls(schema_type, **kwargs) + for x in schema_classes: + out.append(x) + return out + @property def class2subtype(self): r"""dict: Mapping from class to list of subtypes.""" @@ -294,14 +322,24 @@ def append(self, comp_cls, subtype=None): assert(comp_cls._schema_type == self._schema_type) name = comp_cls.__name__ rule = comp_cls._schema - if (subtype is None) and (self._subtype_attr is not None): + # Append subtype + if self._schema_type == 'connection': + subtype = (comp_cls._icomm_type, comp_cls._ocomm_type, comp_cls.direction()) + elif (subtype is None) and (self._subtype_attr is not None): subtype = getattr(comp_cls, self._subtype_attr) if subtype is not None: if not isinstance(subtype, list): - self._schema_subtypes[name] = [subtype] + subtype_list = [subtype] else: - self._schema_subtypes[name] = subtype + subtype_list = subtype + self._schema_subtypes[name] = subtype_list + # Add rules self.append_rules(rule) + # Add allowed subtypes + if (self._subtype_key is not None) and (self._subtype_key in self): + self[self._subtype_key]['allowed'] = self.subtypes + # Verify that the schema is valid + CisSchemaValidator(self, schema_registry=self.schema_registry) def append_rules(self, new): r"""Add rules from new class's schema to this one. @@ -310,19 +348,20 @@ def append_rules(self, new): new (dict): New schema to add. 
""" + old = self for k, v in new.items(): - if k not in self: - self[k] = v + if k not in old: + old[k] = v else: diff = [] for ik in v.keys(): - if (ik not in self[k]) or (v[ik] != self[k][ik]): + if (ik not in old[k]) or (v[ik] != old[k][ik]): diff.append(ik) if (len(diff) == 0): pass elif (len(diff) == 1) and (diff[0] == 'dependencies'): alldeps = {} - deps = [self[k]['dependencies'], v['dependencies']] + deps = [old[k]['dependencies'], v['dependencies']] for idep in deps: for ik, iv in idep.items(): if ik not in alldeps: @@ -335,16 +374,16 @@ def append_rules(self, new): alldeps[ik] = sorted(list(set(alldeps[ik]))) vcopy = copy.deepcopy(v) vcopy['dependencies'] = alldeps - self[k].update(**vcopy) + old[k].update(**vcopy) else: # pragma: debug print('Existing:') - pprint.pprint(self[k]) + pprint.pprint(old[k]) print('New:') pprint.pprint(v) raise ValueError("Cannot merge schemas.") -class SchemaRegistry(dict): +class SchemaRegistry(cerberus.schema.SchemaRegistry): r"""Registry of schema's for different integration components. Args: @@ -360,42 +399,43 @@ class SchemaRegistry(dict): """ _component_attr = ['_schema_subtypes', '_subtype_attr'] - _subtype_attr = {'model': '_language', 'comm': '_commtype', - 'file': '_filetype'} def __init__(self, registry=None, required=None): + super(SchemaRegistry, self).__init__() comp = {} if registry is not None: if required is None: + # required = ['type', 'comm', 'file', 'model', 'connection'] required = ['comm', 'file', 'model', 'connection'] for k in required: if k not in registry: raise ValueError("Component %s required." 
% k) + # Register dummy schemas for each component + for k in registry.keys(): + self[k] = {'hold': {'type': 'string'}} + # Create schemas for each component for k in registry.keys(): if k not in comp: - isubtype_attr = self._subtype_attr.get(k, None) - comp[k] = ComponentSchema(k, subtype_attr=isubtype_attr) - for x in registry[k]: - subtype = None - if k == 'connection': - subtype = (x._icomm_type, x._ocomm_type, x.direction()) - comp[k].append(x, subtype=subtype) - SchemaValidator(comp[k]) - # Add lists of required properties - comp['file']['filetype']['allowed'] = sorted(comp['file'].subtypes) - comp['model']['language']['allowed'] = sorted(comp['model'].subtypes) - comp['model']['inputs'] = {'type': 'list', 'required': False, - 'schema': {'type': 'dict', - 'schema': comp['comm']}} - comp['model']['outputs'] = {'type': 'list', 'required': False, - 'schema': {'type': 'dict', - 'schema': comp['comm']}} - comp['connection']['input_file']['schema'] = comp['file'] - comp['connection']['output_file']['schema'] = comp['file'] + comp[k] = ComponentSchema.from_registry(k, registry[k], + schema_registry=self) + self[k] = comp[k] # Make sure final versions are valid schemas for x in comp.values(): - SchemaValidator(x) - super(SchemaRegistry, self).__init__(**comp) + CisSchemaValidator(x, schema_registry=self) + + def __getitem__(self, k): + return self.get(k) + + def __setitem__(self, k, v): + return self.add(k, v) + + def keys(self): + return self.all().keys() + + def __eq__(self, other): + if not hasattr(other, 'all'): + return False + return (self.all() == other.all()) @classmethod def from_file(cls, fname): @@ -419,6 +459,9 @@ def load(self, fname): with open(fname, 'r') as f: contents = f.read() schema = yaml.load(contents, Loader=SchemaLoader) + if schema is None: + raise Exception("Failed to load schema from %s" % fname) + comp_list = [] for k, v in schema.items(): is_attr = False for iattr in self._component_attr: @@ -427,11 +470,18 @@ def load(self, fname): break 
if is_attr: continue - self[k] = ComponentSchema(k, **v) + comp_list.append(k) + # Add dummy schemas to registry + for k in comp_list: + self[k] = {'hold': {'type': 'string'}} + # Create components + for k in comp_list: + icomp = ComponentSchema(k, schema_registry=self, **schema[k]) for iattr in self._component_attr: kattr = k + iattr if kattr in schema: - setattr(self[k], iattr, schema[kattr]) + setattr(icomp, iattr, schema[kattr]) + self[k] = icomp def save(self, fname): r"""Save the schema to a file. @@ -478,13 +528,11 @@ def conntype2class(self): @property def validator(self): r"""Compose complete schema for parsing yaml.""" - out = {'models': {'type': 'list', - 'schema': {'type': 'dict', - 'schema': self['model']}}, - 'connections': {'type': 'list', - 'schema': {'type': 'dict', - 'schema': self['connection']}}} - return SchemaValidator(out) + out = {'models': {'type': 'list', 'schema': {'type': 'dict', + 'schema': 'model'}}, + 'connections': {'type': 'list', 'schema': {'type': 'dict', + 'schema': 'connection'}}} + return CisSchemaValidator(out, schema_registry=self) class SchemaLoader(yaml.SafeLoader): @@ -508,11 +556,13 @@ def represent_ComponentSchema(self, data): return self.represent_data(out) def represent_SchemaRegistry(self, data): - out = dict(**data) - for k in data.keys(): + out = dict(**data.all()) + comp_list = [k for k in out.keys()] + for k in comp_list: for iattr in data._component_attr: - if getattr(data[k], iattr, None): - out[k + iattr] = getattr(data[k], iattr) + icomp = data[k] + if getattr(icomp, iattr, None): + out[k + iattr] = getattr(icomp, iattr) return self.represent_data(out) diff --git a/cis_interface/tests/__init__.py b/cis_interface/tests/__init__.py index 2dfd75452..6c1ab9c7f 100644 --- a/cis_interface/tests/__init__.py +++ b/cis_interface/tests/__init__.py @@ -271,9 +271,10 @@ def shortDescription(self): class CisTestClass(CisTestBase): r"""Test class for a CisClass.""" + _mod = None + _cls = None + def __init__(self, *args, 
**kwargs): - self._mod = None - self._cls = None self._inst_args = list() self._inst_kwargs = dict() super(CisTestClass, self).__init__(*args, **kwargs) diff --git a/cis_interface/tests/test_schema.py b/cis_interface/tests/test_schema.py index 64447b37e..a8816f286 100644 --- a/cis_interface/tests/test_schema.py +++ b/cis_interface/tests/test_schema.py @@ -23,9 +23,9 @@ def test_str_to_function(): '%s:invalid' % __name__) -def test_SchemaValidator(): +def test_CisSchemaValidator(): r"""Test schema validator.""" - v = schema.SchemaValidator() + v = schema.CisSchemaValidator() test_vals = { 'string': [('s', 's'), (1, '1'), (1.0, '1.0'), (['1', 1], ['1', '1']), @@ -69,6 +69,7 @@ def test_create_schema(): # Test saving/loading schema s0 = schema.create_schema() s0.save(fname) + assert(s0 is not None) assert(os.path.isfile(fname)) s1 = schema.get_schema(fname) nt.assert_equal(s1, s0) diff --git a/cis_interface/tests/test_units.py b/cis_interface/tests/test_units.py index 98d865549..2bc71768b 100644 --- a/cis_interface/tests/test_units.py +++ b/cis_interface/tests/test_units.py @@ -1,12 +1,95 @@ +import nose.tools as nt +import numpy as np +from cis_interface.tests import CisTestBase from cis_interface import units -def test_is_unit(): - r"""Test is_unit.""" - assert(units.is_unit('n/a')) - assert(units.is_unit('')) - assert(units.is_unit('cm/s**2')) - assert(units.is_unit('cm/s^2')) - assert(units.is_unit('umol')) - assert(units.is_unit('mmol')) - assert(not units.is_unit('invalid')) +class TestPint(CisTestBase): + r"""Tests for using pint for units.""" + _unit_package = 'pint' + + def setup(self, *args, **kwargs): + r"""Set use_unyt for tests.""" + self._old_use_unyt = units._use_unyt + if self._unit_package == 'unyt': + units._use_unyt = True + else: + units._use_unyt = False + self._vars_nounits = [1.0, np.zeros(5), int(1)] + self._vars_units = [units.add_units(v, 'cm') for v in self._vars_nounits] + super(TestPint, self).setup(*args, **kwargs) + + def teardown(self, 
*args, **kwargs): + r"""Reset use_unyt to default.""" + units._use_unyt = self._old_use_unyt + super(TestPint, self).teardown(*args, **kwargs) + + def test_has_units(self): + r"""Test has_units.""" + for v in self._vars_nounits: # + ['string']: + assert(not units.has_units(v)) + for v in self._vars_units: + assert(units.has_units(v)) + + def test_get_data(self): + r"""Test get_data.""" + for v in self._vars_nounits: + np.testing.assert_array_equal(units.get_data(v), v) + for vno, v in zip(self._vars_nounits, self._vars_units): + np.testing.assert_array_equal(units.get_data(v), np.array(vno)) + + def test_get_units(self): + r"""Test get_units.""" + for v in self._vars_nounits: + nt.assert_equal(units.get_units(v), '') + for v in self._vars_units: + nt.assert_equal(units.get_units(v), str(units.as_unit('cm').units)) + + def test_add_units(self): + r"""Test add_units.""" + for v in self._vars_nounits: + x = units.add_units(v, 'cm') + assert(units.has_units(x)) + nt.assert_equal(units.add_units(1.0, ''), 1.0) + nt.assert_equal(units.add_units(1.0, 'n/a'), 1.0) + + def test_is_null_unit(self): + r"""Test is_null_unit.""" + assert(units.is_null_unit('n/a')) + assert(units.is_null_unit('')) + assert(not units.is_null_unit('cm')) + + def test_as_unit(self): + r"""Test as_unit.""" + units.as_unit('cm') + nt.assert_raises(ValueError, units.as_unit, 'invalid') + + def test_is_unit(self): + r"""Test is_unit.""" + assert(units.is_unit('n/a')) + assert(units.is_unit('')) + assert(units.is_unit('cm/s**2')) + # Not supported by unyt + # assert(units.is_unit('cm/s^2')) + assert(units.is_unit('umol')) + assert(units.is_unit('mmol')) + assert(not units.is_unit('invalid')) + + def test_convert_to(self): + r"""Test convert_to.""" + units.convert_to(1, 'm') + for v in self._vars_units: + units.convert_to(v, 'm') + nt.assert_raises(ValueError, units.convert_to, v, 's') + + def test_are_compatible(self): + r"""Test are_compatible.""" + assert(units.are_compatible('cm', 'm')) + 
assert(units.are_compatible('cm', '')) + assert(not units.are_compatible('cm', 's')) + assert(not units.are_compatible('cm', 'invalid')) + + +class TestUnyt(TestPint): + r"""Test for using unyt for units.""" + _unit_package = 'unyt' diff --git a/cis_interface/units.py b/cis_interface/units.py index 632077875..7158ce2d6 100644 --- a/cis_interface/units.py +++ b/cis_interface/units.py @@ -1,24 +1,188 @@ -import pint +import numpy as np from cis_interface import backwards -_ureg = pint.UnitRegistry() -_ureg.define('micro_mole = 1e-6 * mole = uMol = umol') +import unyt +import pint +_ureg_unyt = unyt.UnitRegistry() +_ureg_pint = pint.UnitRegistry() +_ureg_pint.define('micro_mole = 1e-6 * mole = uMol = umol') +_use_unyt = True + + +def has_units(obj): + r"""Determine if a Python object has associated units. + + Args: + obj (object): Object to be tested for units. + + Returns: + bool: True if the object has units, False otherwise. + + """ + return hasattr(obj, 'units') + + +def get_units(obj): + r"""Get the string representation of the units. + + Args: + obj (object): Object to get units for. + + Returns: + str: Units, empty if input object has none. + + """ + if has_units(obj): + out = str(obj.units) + else: + out = '' + return out + + +def get_data(obj): + r"""Get the array/scalar associated with the object. + + Args: + obj (object): Object to get data for. + + Returns: + np.ndarray: Numpy array representation of the underlying data. + + """ + if has_units(obj): + if _use_unyt: + out = obj.to_ndarray() + else: + out = np.array(obj) + else: + out = obj + return out + + +def add_units(arr, unit_str): + r"""Add units to an array or scalar. + + Args: + arr (np.ndarray, float, int): Scalar or array of data to add units to. + unit_str (str): Unit string. + + Returns: + unyt.unyt_array: Array with units.
+ + """ + if is_null_unit(unit_str): + return arr + if _use_unyt: + out = unyt.unyt_array(arr, unit_str) + else: + out = _ureg_pint.Quantity(arr, unit_str) + return out + + +def are_compatible(units1, units2): + r"""Check if two units are compatible. + + Args: + units1 (str): First units string. + units2 (str): Second units string. + + Returns: + bool: True if the units are compatible, False otherwise. + + """ + # Empty units always compatible + if is_null_unit(units1) or is_null_unit(units2): + return True + if (not is_unit(units1)) or (not is_unit(units2)): + return False + x = add_units(1, units1) + try: + convert_to(x, units2) + except ValueError: + return False + return True + + +def is_null_unit(ustr): + r"""Determines if a string is a null unit. + + Args: + ustr (str): String to test. + + Returns: + bool: True if the string is '' or 'n/a', False otherwise. + + """ + if (len(ustr) == 0) or (ustr == 'n/a'): + return True + return False + + +def as_unit(ustr): + r"""Get unit object for the string. + + Args: + + Returns: + + Raises: + ValueError: If the string is not a recognized unit. + + """ + if _use_unyt: + try: + out = unyt.Unit(ustr) + except unyt.exceptions.UnitParseError as e: + raise ValueError(str(e)) + else: + try: + out = _ureg_pint(ustr) + except pint.errors.UndefinedUnitError as e: + raise ValueError(str(e)) + return out def is_unit(ustr): r"""Determine if a string is a valid unit. Args: - ustr: String representation to test. + ustr (str): String representation to test. Returns: bool: True if the string is a valid unit. False otherwise. """ ustr = backwards.bytes2unicode(ustr) - if ustr == 'n/a': + if is_null_unit(ustr): return True try: - _ureg(ustr) - except pint.errors.UndefinedUnitError: + as_unit(ustr) + except ValueError: return False return True + + +def convert_to(arr, new_units): + r"""Convert quantity with units to new units. Objects without units + will be returned with the new units.
+ + Args: + arr (np.ndarray, float, int, unyt.unyt_array): Quantity with or + without units. + new_units (str): New units that should be applied. + + Returns: + unyt.unyt_array: Array with new units. + + """ + if is_null_unit(new_units): + return arr + if not has_units(arr): + return add_units(arr, new_units) + if _use_unyt: + try: + out = arr.to(new_units) + except unyt.exceptions.UnitConversionError as e: + raise ValueError(str(e)) + else: + out = arr.to(new_units) + return out diff --git a/docs/source/conf.py b/docs/source/conf.py index 6396e570e..d25df70f6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,8 @@ import sphinx_rtd_theme # sys.path.insert(0, os.path.abspath('.')) doxydir = os.path.join(os.path.abspath('../'), "doxy", "xml") -srcdir = os.path.join(os.path.abspath('../../'), "cis_interface") +rootdir = os.path.abspath('../../') +srcdir = os.path.join(rootdir, "cis_interface") sys.path.append(doxydir) @@ -99,10 +100,16 @@ # |version| and |release|, also used in various other places throughout the # built documents. # +with open(os.path.join(rootdir, 'VERSION')) as version_file: + cis_ver = version_file.read().strip() +ver_parts = cis_ver.split('.') # The short X.Y version. -version = u'0.3' +version = '.'.join(ver_parts[:2]) # The full version, including alpha/beta/rc tags. -release = u'0.3.0' +if len(ver_parts) <= 2: + release = version + '.0' +else: + release = cis_ver # Substitutions # ..
_Docs: http://cis_interface.readthedocs.io/en/latest/ diff --git a/setup.py b/setup.py index d413d2087..7c4461278 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,9 @@ PY2 = (PY_MAJOR_VERSION == 2) IS_WINDOWS = (sys.platform in ['win32', 'cygwin']) -cis_ver = "0.3" + +with open(os.path.join(os.path.dirname(__file__), 'VERSION')) as version_file: + cis_ver = version_file.read().strip() try: @@ -217,8 +219,9 @@ def rm_excl_rule(excl_list, new_rule): raise IOError("Could not find README.rst or README.md") # Create requirements list based on platform -requirements = ["numpy", "scipy", "pyyaml", "pystache", "nose", "zmq", "psutil", - "matplotlib", "cerberus", +requirements = ["numpy", "scipy", "pyyaml", + "pystache", "nose", "zmq", "psutil", + "matplotlib", "cerberus", "jsonschema", 'pandas; python_version >= "3.5"', 'pandas; python_version == "2.7"', 'pandas<0.21; python_version == "3.4"', @@ -262,7 +265,6 @@ def rm_excl_rule(excl_list, new_rule): "License :: OSI Approved :: BSD License", "Natural Language :: English", "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Bio-Informatics", "Development Status :: 3 - Alpha", ], entry_points = { @@ -270,6 +272,7 @@ def rm_excl_rule(excl_list, new_rule): 'ciscc=cis_interface.command_line:ciscc', 'cisccflags=cis_interface.command_line:cc_flags', 'cisldflags=cis_interface.command_line:ld_flags', + 'cistest=cis_interface:run_nose', 'cisschema=cis_interface.command_line:regen_schema'], }, license="BSD",