From cfa246c91b16e41e80a700f735afe73207f7f89d Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Tue, 11 Oct 2022 01:52:28 +0100 Subject: [PATCH] Feature / Define model attributes in code (#185) * API for defining model attributes in code * Implementation for importing attrs from model code * Support array attributes in model attribute definition * Pick up model attrs im import model jobs * Add an end-to-end test case for importing model attrs * Add documentation to new runtime API methods * mark model attr APIs as experimental * Prevent defining control attributes in model code --- .../python/src/tutorial/schema_files.py | 8 + .../main/proto/tracdap/config/result.proto | 9 +- .../python/src/tracdap/rt/_exec/engine.py | 2 +- .../python/src/tracdap/rt/_exec/functions.py | 29 +++- .../python/src/tracdap/rt/_exec/graph.py | 11 +- .../src/tracdap/rt/_exec/graph_builder.py | 21 ++- .../python/src/tracdap/rt/_impl/api_hook.py | 37 +++++ .../python/src/tracdap/rt/_impl/models.py | 19 +++ .../src/tracdap/rt/_impl/type_system.py | 29 +++- .../python/src/tracdap/rt/api/hook.py | 16 ++ .../python/src/tracdap/rt/api/model_api.py | 30 +++- .../python/src/tracdap/rt/api/static_api.py | 84 +++++++++++ .../tracdap/svc/orch/jobs/ImportModelJob.java | 23 ++- .../tracdap/svc/orch/jobs/ModelAttrsTest.java | 142 ++++++++++++++++++ 14 files changed, 436 insertions(+), 24 deletions(-) create mode 100644 tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/ModelAttrsTest.java diff --git a/examples/models/python/src/tutorial/schema_files.py b/examples/models/python/src/tutorial/schema_files.py index 803ea7477..7295e3364 100644 --- a/examples/models/python/src/tutorial/schema_files.py +++ b/examples/models/python/src/tutorial/schema_files.py @@ -21,6 +21,14 @@ class SchemaFilesModel(trac.TracModel): + def define_attributes(self) -> tp.List[trac.TagUpdate]: + + return trac.define_attributes( + trac.A("model_description", "A example model, for testing purposes"), + trac.A("business_segment", "retail_products", categorical=True), + trac.A("classifiers", ["loans", "uk", "examples"], attr_type=trac.STRING) + ) + def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]: return trac.define_parameters( diff --git a/tracdap-api/tracdap-config/src/main/proto/tracdap/config/result.proto b/tracdap-api/tracdap-config/src/main/proto/tracdap/config/result.proto index 4268f9369..cb3349f5f 100644 --- a/tracdap-api/tracdap-config/src/main/proto/tracdap/config/result.proto +++ b/tracdap-api/tracdap-config/src/main/proto/tracdap/config/result.proto @@ -21,10 +21,16 @@ option java_package = "org.finos.tracdap.config"; option java_multiple_files = true; import "tracdap/metadata/object_id.proto"; -import "tracdap/metadata/object.proto";; +import "tracdap/metadata/object.proto"; import "tracdap/metadata/job.proto"; +import "tracdap/metadata/tag_update.proto"; +message TagUpdateList { + + repeated metadata.TagUpdate attrs = 1; +} + message JobResult { metadata.TagHeader jobId = 1; @@ -33,4 +39,5 @@ message JobResult { string statusMessage = 3; map results = 4; + map attrs = 5; } diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py b/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py index d4cd0dcd0..ba3768f8a 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py @@ -121,7 +121,7 @@ def on_signal(self, signal: Signal) -> tp.Optional[bool]: error = signal.error else: error = _ex.ETracInternal("An unknown error occurred") - self.job_failed(failed_job_key, error) + self.actors().send("job_failed", failed_job_key, error) # Failed signal has been handled, do not propagate return True diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py b/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py index b6284e34e..de9e83ebc 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py @@ -202,6 +202,10 @@ def _execute(self, ctx: NodeContext) -> _config.JobResult: obj_def = _ctx_lookup(node_id, ctx) job_result.results[obj_id] = obj_def + for obj_id, node_id in self.node.attrs.items(): + attrs = _ctx_lookup(node_id, ctx) + job_result.attrs[obj_id] = attrs + for bundle_id in self.node.bundles: bundle = _ctx_lookup(bundle_id, ctx) job_result.results.update(bundle.items()) @@ -504,6 +508,25 @@ def _execute(self, ctx: NodeContext) -> meta.ObjectDefinition: return meta.ObjectDefinition(meta.ObjectType.MODEL, model=model_def) +class ImportAttrsFunc(NodeFunction[_config.TagUpdateList]): + + def __init__(self, node: ImportAttrsNode, models: _models.ModelLoader): + self.node = node + self._models = models + + def _execute(self, ctx: NodeContext) -> _config.TagUpdateList: + + stub_model_def = meta.ModelDefinition( + language=self.node.import_details.language, + repository=self.node.import_details.repository, + path=self.node.import_details.path, + entryPoint=self.node.import_details.entryPoint, + version=self.node.import_details.version) + + model_class = self._models.load_model_class(self.node.model_scope, stub_model_def) + return self._models.scan_model_attrs(model_class) + + class RunModelFunc(NodeFunction[Bundle[_data.DataView]]): def __init__(self, node: RunModelNode, model_class: _api.TracModel.__class__): @@ -620,6 +643,9 @@ def resolve_dynamic_data_spec(self, node: DynamicDataSpecNode): def resolve_import_model_node(self, node: ImportModelNode): return ImportModelFunc(node, self._models) + def resolve_import_attrs_node(self, node: ImportAttrsNode): + return ImportAttrsFunc(node, self._models) + def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction: model_class = self._models.load_model_class(node.model_scope, node.model_def) @@ -650,5 +676,6 @@ def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction: SaveDataNode: resolve_save_data, DynamicDataSpecNode: resolve_dynamic_data_spec, RunModelNode: resolve_run_model_node, - ImportModelNode: resolve_import_model_node + ImportModelNode: resolve_import_model_node, + ImportAttrsNode: resolve_import_attrs_node } diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py b/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py index beee805d7..9054f6b9e 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py @@ -345,6 +345,13 @@ class ImportModelNode(Node[meta.ObjectDefinition]): import_details: meta.ImportModelJob +@_node_type +class ImportAttrsNode(Node[cfg.TagUpdateList]): + + model_scope: str + import_details: meta.ImportModelJob + + @_node_type class RunModelNode(Node[Bundle[_data.DataView]]): @@ -363,10 +370,12 @@ class BuildJobResultNode(Node[cfg.JobResult]): job_id: meta.TagHeader objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict) + attrs: tp.Dict[str, NodeId[cfg.TagUpdateList]] = dc.field(default_factory=dict) + bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list) def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]: - dep_ids = [*self.bundles, *self.objects.values()] + dep_ids = [*self.bundles, *self.objects.values(), *self.attrs.values()] return {node_id: DependencyType.HARD for node_id in dep_ids} diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py b/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py index 07de55a76..0a9d6e9b3 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py @@ -106,15 +106,17 @@ def build_import_model_job( import_id = NodeId.of("trac_import_model", job_namespace, meta.ObjectDefinition) import_node = ImportModelNode(import_id, model_scope, import_details, explicit_deps=[job_push_id]) - main_section = GraphSection(nodes={import_id: import_node}, must_run=[import_id]) + import_attrs_id = NodeId.of("trac_import_attrs", job_namespace, config.TagUpdateList) + import_attrs_node = ImportAttrsNode(import_attrs_id, model_scope, import_details, explicit_deps=[import_id]) - # Build job-level metadata outputs + main_section = GraphSection(nodes={import_id: import_node, import_attrs_id: import_attrs_node}) - result_objects = {new_model_key: import_id} + # Build job-level metadata outputs result_section = cls.build_job_results( - job_config, job_namespace, - result_spec, objects=result_objects, + job_config, job_namespace, result_spec, + objects={new_model_key: import_id}, + attrs={new_model_key: import_attrs_id}, explicit_deps=[job_push_id, *main_section.must_run]) return cls._join_sections(main_section, result_section) @@ -327,7 +329,9 @@ def build_job_outputs( @classmethod def build_job_results( cls, job_config: cfg.JobConfig, job_namespace: NodeNamespace, result_spec: JobResultSpec, - objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = None, bundles: tp.List[NodeId[ObjectBundle]] = None, + objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = None, + attrs: tp.Dict[str, NodeId[config.TagUpdateList]] = None, + bundles: tp.List[NodeId[ObjectBundle]] = None, explicit_deps: tp.Optional[tp.List[NodeId]] = None) \ -> GraphSection: @@ -337,9 +341,12 @@ def build_job_results( results_inputs = set(objects.values()) + if attrs is not None: + results_inputs.update(attrs.values()) + build_result_node = BuildJobResultNode( build_result_id, job_config.jobId, - objects=objects, explicit_deps=explicit_deps) + objects=objects, attrs=attrs, explicit_deps=explicit_deps) elif bundles is not None: diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/api_hook.py b/tracdap-runtime/python/src/tracdap/rt/_impl/api_hook.py index 9c40b1644..304cb8e2a 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/api_hook.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/api_hook.py @@ -20,6 +20,7 @@ import tracdap.rt.metadata as _meta import tracdap.rt.exceptions as _ex import tracdap.rt._impl.schemas as _schemas +import tracdap.rt._impl.type_system as _type_system import tracdap.rt._impl.util as _util # Import hook interfaces into this module namespace @@ -204,6 +205,42 @@ def register_impl(cls): log.warning("Runtime API hook is already registered") + def define_attributes( + self, *attrs: _tp.Union[_meta.TagUpdate, _tp.List[_meta.TagUpdate]]) \ + -> _tp.List[_meta.TagUpdate]: + + ApiGuard.validate_signature(self.define_attributes, *attrs) + + if len(attrs) == 1 and isinstance(attrs[0], list): + return attrs[0] + else: + return [*attrs] + + def define_attribute( + self, attr_name: str, attr_value: _tp.Any, + attr_type: _tp.Optional[_meta.BasicType] = None, + categorical: bool = False) \ + -> _meta.TagUpdate: + + ApiGuard.validate_signature(self.define_attribute, attr_name, attr_value, attr_type, categorical) + + if isinstance(attr_value, list) and attr_type is None: + raise _ex.EModelValidation(f"Attribute type must be specified for multi-valued attribute [{attr_name}]") + + if categorical and not (isinstance(attr_name, str) or attr_type == _meta.BasicType.STRING): + raise _ex.EModelValidation("Categorical flag is only allowed for STRING attributes") + + if attr_type is None: + trac_value = _type_system.MetadataCodec.encode_value(attr_value) + elif isinstance(attr_value, list): + type_desc = _meta.TypeDescriptor(_meta.BasicType.ARRAY, arrayType=_meta.TypeDescriptor(attr_type)) + trac_value = _type_system.MetadataCodec.convert_value(attr_value, type_desc) + else: + type_desc = _meta.TypeDescriptor(attr_type) + trac_value = _type_system.MetadataCodec.convert_value(attr_value, type_desc) + + return _meta.TagUpdate(_meta.TagOperation.CREATE_OR_APPEND_ATTR, attr_name, trac_value) + def define_parameter( self, param_name: str, param_type: _tp.Union[_meta.TypeDescriptor, _meta.BasicType], label: str, default_value: _tp.Optional[_tp.Any] = None) \ diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/models.py b/tracdap-runtime/python/src/tracdap/rt/_impl/models.py index f36a40192..4b5e0eb6a 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/models.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/models.py @@ -149,6 +149,24 @@ def load_model_class(self, scope: str, model_def: _meta.ModelDefinition) -> _api scope_state.model_cache[model_key] = model_class return model_class + def scan_model_attrs(self, model_class: _api.TracModel.__class__) -> _cfg.TagUpdateList: + + model: _api.TracModel = object.__new__(model_class) + model_class.__init__(model) + + attributes = model.define_attributes() + + for attr in attributes: + + if attr.attrName.startswith("trac_") or attr.attrName.startswith("_"): + err = f"Controlled attribute [{attr.attrName}] cannot be defined in model code" + self.__log.error(err) + raise _ex.EModelValidation(err) + + self.__log.info(f"Attribute [{attr.attrName}] - {_types.MetadataCodec.decode_value(attr.value)}") + + return _cfg.TagUpdateList(attributes) + def scan_model(self, model_class: _api.TracModel.__class__) -> _meta.ModelDefinition: try: @@ -156,6 +174,7 @@ def scan_model(self, model_class: _api.TracModel.__class__) -> _meta.ModelDefini model: _api.TracModel = object.__new__(model_class) model_class.__init__(model) + attributes = model.define_attributes() parameters = model.define_parameters() inputs = model.define_inputs() outputs = model.define_outputs() diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/type_system.py b/tracdap-runtime/python/src/tracdap/rt/_impl/type_system.py index d1f238546..3fd06a37f 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/type_system.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/type_system.py @@ -102,7 +102,12 @@ def decode_value(value: _meta.Value) -> tp.Any: raise _ex.ETracInternal("Missing type information") - basic_type = value.type.basicType + return MetadataCodec._decode_value_for_type(value, value.type) + + @staticmethod + def _decode_value_for_type(value: _meta.Value, type_desc: _meta.TypeDescriptor): + + basic_type = type_desc.basicType if basic_type == _meta.BasicType.BOOLEAN: return value.booleanValue @@ -125,6 +130,10 @@ def decode_value(value: _meta.Value) -> tp.Any: if basic_type == _meta.BasicType.DATETIME: return dt.datetime.fromisoformat(value.datetimeValue.isoDatetime) + if basic_type == _meta.BasicType.ARRAY: + items = value.arrayValue.items + return list(map(lambda x: MetadataCodec._decode_value_for_type(x, type_desc.arrayType), items)) + raise _ex.ETracInternal(f"Decoding value type [{basic_type}] is not supported yet") @classmethod @@ -162,7 +171,7 @@ def encode_value(cls, value: tp.Any) -> _meta.Value: type_desc = _meta.TypeDescriptor(_meta.BasicType.DATE) return _meta.Value(type_desc, dateValue=_meta.DateValue(value.isoformat())) - raise _ex.ETracInternal(f"Encoding value type [{type(value)}] is not supported yet") + raise _ex.ETracInternal(f"Value type [{type(value)}] is not supported yet") @classmethod def convert_value(cls, raw_value: tp.Any, type_desc: _meta.TypeDescriptor): @@ -188,8 +197,24 @@ def convert_value(cls, raw_value: tp.Any, type_desc: _meta.TypeDescriptor): if type_desc.basicType == _meta.BasicType.DATETIME: return cls.convert_datetime_value(raw_value) + if type_desc.basicType == _meta.BasicType.ARRAY: + return cls.convert_array_value(raw_value, type_desc.arrayType) + raise _ex.ETracInternal(f"Conversion to value type [{type_desc.basicType.name}] is not supported yet") + @staticmethod + def convert_array_value(raw_value: tp.List[tp.Any], array_type: _meta.TypeDescriptor) -> _meta.Value: + + type_desc = _meta.TypeDescriptor(_meta.BasicType.ARRAY, array_type) + + if not isinstance(raw_value, list): + msg = f"Value of type [{type(raw_value)}] cannot be converted to {_meta.BasicType.ARRAY.name}" + raise _ex.ETracInternal(msg) + + items = list(map(lambda x: MetadataCodec.convert_value(x, array_type), raw_value)) + + return _meta.Value(type_desc, arrayValue=_meta.ArrayValue(items)) + @staticmethod def convert_boolean_value(raw_value: tp.Any) -> _meta.Value: diff --git a/tracdap-runtime/python/src/tracdap/rt/api/hook.py b/tracdap-runtime/python/src/tracdap/rt/api/hook.py index a480c0928..bacfa7893 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/hook.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/hook.py @@ -61,6 +61,22 @@ def runtime(cls) -> _RuntimeHook: return cls.__runtime_hook + @_abc.abstractmethod + def define_attributes( + self, *attrs: _tp.Union[_meta.TagUpdate, _tp.List[_meta.TagUpdate]]) \ + -> _tp.List[_meta.TagUpdate]: + + pass + + @_abc.abstractmethod + def define_attribute( + self, attr_name: str, attr_value: _tp.Any, + attr_type: _tp.Optional[_meta.BasicType] = None, + categorical: bool = False) \ + -> _meta.TagUpdate: + + pass + @_abc.abstractmethod def define_parameter( self, param_name: str, param_type: _tp.Union[_meta.TypeDescriptor, _meta.BasicType], diff --git a/tracdap-runtime/python/src/tracdap/rt/api/model_api.py b/tracdap-runtime/python/src/tracdap/rt/api/model_api.py index c6ec45070..d55cd31a3 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/model_api.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/model_api.py @@ -207,6 +207,30 @@ class TracModel: .. seealso:: :py:class:`TracContext` """ + def define_attributes(self) -> _tp.List[TagUpdate]: # noqa + + """ + Define attributes that will be associated with the model when it is loaded into the TRAC platform + + .. note:: + This is an experimental API that is not yet stabilised, expect changes in future versions of TRAC + + These attributes can be used to index or describe the model, they will be available for metadata searches. + Attributes must be primitive (scalar) values that can be expressed in the TRAC type system. + Multivalued attributes can be supplied as lists, in which case the attribute type must be given explicitly. + Controlled attributes (starting with trac\\_ or \\_) are not allowed and will fail validation. + + To define attributes in code, always use the define_* function in the :py:mod:`tracdap.rt.api` package. + This will ensure attributes are defined the correct format with all the required fields. + Attributes that are defined in the wrong format or with required fields missing + will result in a model validation failure. + + :return: A set of attributes that will be applied to the model when it is loaded into the TRAC platform + :rtype: List[:py:class:`TagUpdate `] + """ + + return [] + @_abc.abstractmethod def define_parameters(self) -> _tp.Dict[str, ModelParameter]: @@ -217,7 +241,7 @@ def define_parameters(self) -> _tp.Dict[str, ModelParameter]: model uses must be defined. Models may choose to ignore some parameters, it is ok to define parameters that are not always used. - To declare model parameters in code, always use the declare_* functions in the :py:mod:`tracdap.rt.api` package. + To define model parameters in code, always use the define_* functions in the :py:mod:`tracdap.rt.api` package. This will ensure parameters are defined in the correct format with all the required fields. Parameters that are defined in the wrong format or with required fields missing will result in a model validation failure. @@ -238,7 +262,7 @@ def define_inputs(self) -> _tp.Dict[str, ModelInputSchema]: model uses must be defined. Models may choose to ignore some inputs, it is ok to define inputs that are not always used. - To declare model inputs in code, always use the declare_* functions in the :py:mod:`tracdap.rt.api` package. + To define model inputs in code, always use the define_* functions in the :py:mod:`tracdap.rt.api` package. This will ensure inputs are defined in the correct format with all the required fields. Model inputs that are defined in the wrong format or with required fields missing will result in a model validation failure. @@ -260,7 +284,7 @@ def define_outputs(self) -> _tp.Dict[str, ModelOutputSchema]: produced. If a model defines an output which is not produced, a runtime validation error will be raised after the model completes. - To declare model outputs in code, always use the declare_* functions in the :py:mod:`tracdap.rt.api` package. + To define model outputs in code, always use the define_* functions in the :py:mod:`tracdap.rt.api` package. This will ensure outputs are defined in the correct format with all the required fields. Model outputs that are defined in the wrong format or with required fields missing will result in a model validation failure. diff --git a/tracdap-runtime/python/src/tracdap/rt/api/static_api.py b/tracdap-runtime/python/src/tracdap/rt/api/static_api.py index af2e501ea..6e4ec9fc0 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/static_api.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/static_api.py @@ -25,6 +25,90 @@ from tracdap.rt.metadata import * # DOCGEN_REMOVE +def define_attributes(*attrs: _tp.Union[TagUpdate, _tp.List[TagUpdate]]) -> _tp.List[TagUpdate]: + + """ + Defined a set of attributes to catalogue and describe a model + + .. note:: + This is an experimental API that is not yet stabilised, expect changes in future versions of TRAC + + Attributes can be supplied either as individual arguments to this function or as a list. + In either case, each attribute should be defined using :py:func:`define_attribute` + (or :py:func:`trac.A `). + + :param attrs: The attributes that will be defined, either as individual arguments or as a list + :return: A set of model attributes, in the correct format to return from + :py:meth:`TracModel.define_attributes` + + :type attrs: :py:class:`TagUpdate ` | + List[:py:class:`TagUpdate `] + :rtype: List[:py:class:`TagUpdate `] + """ + + rh = _RuntimeHook.runtime() + return rh.define_attributes(*attrs) + + +def define_attribute( + attr_name: str, attr_value: _tp.Any, + attr_type: _tp.Optional[BasicType] = None, + categorical: bool = False) \ + -> TagUpdate: + + """ + Define an individual model attribute + + .. note:: + This is an experimental API that is not yet stabilised, expect changes in future versions of TRAC + + Model attributes can be defined using this method (or :py:func:`trac.A `). + The attr_name and attr_value are always required to define an attribute. + attr_type is always required for multivalued attributes but is optional otherwise. + The categorical flag can be applied to STRING attributes if required. + + Once defined attributes can be passed to :py:func:`define_attributes`, + either as a list or as individual arguments, to create the set of attributes for a model. + + :param attr_name: The attribute name + :param attr_value: The attribute value (as a raw Python value) + :param attr_type: The TRAC type for this attribute (optional, except for multivalued attributes) + :param categorical: A flag to indicate whether this attribute is categorical + :return: An attribute for the model, ready for loading into the TRAC platform + + :type attr_name: str + :type attr_value: Any + :type attr_type: Optional[:py:class:`BasicType `] + :type categorical: bool + :rtype: :py:class:`TagUpdate ` + """ + + rh = _RuntimeHook.runtime() + return rh.define_attribute(attr_name, attr_value, attr_type, categorical) + + +def A( # noqa + attr_name: str, attr_value: _tp.Any, + attr_type: _tp.Optional[BasicType] = None, + categorical: bool = False) \ + -> TagUpdate: + + """ + Shorthand alias for :py:func:`define_attribute` + + .. note:: + This is an experimental API that is not yet stabilised, expect changes in future versions of TRAC + + :type attr_name: str + :type attr_value: Any + :type attr_type: Optional[:py:class:`BasicType `] + :type categorical: bool + :rtype: :py:class:`TagUpdate ` + """ + + return define_attribute(attr_name, attr_value, attr_type, categorical) + + def define_parameter( param_name: str, param_type: _tp.Union[TypeDescriptor, BasicType], label: str, default_value: _tp.Optional[_tp.Any] = None) \ diff --git a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/ImportModelJob.java b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/ImportModelJob.java index fbd28e3f9..74b8054ee 100644 --- a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/ImportModelJob.java +++ b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/ImportModelJob.java @@ -20,6 +20,7 @@ import org.finos.tracdap.common.exception.EUnexpected; import org.finos.tracdap.config.JobConfig; import org.finos.tracdap.config.JobResult; +import org.finos.tracdap.config.TagUpdateList; import org.finos.tracdap.metadata.*; import java.util.List; @@ -81,14 +82,23 @@ public JobDefinition setResultIds( @Override public List buildResultMetadata(String tenant, JobConfig jobConfig, JobResult jobResult) { - var modelObjMaybe = jobResult.getResultsMap().values().stream().findFirst(); + var modelKeyMaybe = jobResult.getResultsMap().keySet().stream().findFirst(); - if (modelObjMaybe.isEmpty()) + if (modelKeyMaybe.isEmpty()) throw new EUnexpected(); - var modelObj = modelObjMaybe.get(); + var modelKey = modelKeyMaybe.get(); + var modelObj = jobResult.getResultsOrThrow(modelKey); var modelDef = modelObj.getModel(); + var modelAttrs = jobResult + .getAttrsOrDefault(modelKey, TagUpdateList.newBuilder().build()) + .getAttrsList(); + + var suppliedAttrs = jobConfig.getJob() + .getImportModel() + .getModelAttrsList(); + var controlledAttrs = List.of( TagUpdate.newBuilder() @@ -116,16 +126,13 @@ public List buildResultMetadata(String tenant, JobConfig j .setValue(encodeValue(modelDef.getVersion())) .build()); - var suppliedAttrs = jobConfig.getJob() - .getImportModel() - .getModelAttrsList(); - var modelReq = MetadataWriteRequest.newBuilder() .setTenant(tenant) .setObjectType(ObjectType.MODEL) .setDefinition(modelObj) - .addAllTagUpdates(controlledAttrs) + .addAllTagUpdates(modelAttrs) .addAllTagUpdates(suppliedAttrs) + .addAllTagUpdates(controlledAttrs) .build(); return List.of(modelReq); diff --git a/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/ModelAttrsTest.java b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/ModelAttrsTest.java new file mode 100644 index 000000000..c800cf929 --- /dev/null +++ b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/ModelAttrsTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2022 Accenture Global Solutions Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.finos.tracdap.svc.orch.jobs; + + +import org.finos.tracdap.api.JobRequest; +import org.finos.tracdap.api.MetadataReadRequest; +import org.finos.tracdap.api.MetadataSearchRequest; +import org.finos.tracdap.common.metadata.MetadataCodec; +import org.finos.tracdap.common.metadata.MetadataUtil; +import org.finos.tracdap.metadata.*; +import org.finos.tracdap.metadata.ImportModelJob; +import org.finos.tracdap.test.helpers.GitHelpers; +import org.finos.tracdap.test.helpers.PlatformTest; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +import static org.finos.tracdap.svc.orch.jobs.Helpers.runJob; + + +public abstract class ModelAttrsTest { + + private static final String TEST_TENANT = "ACME_CORP"; + private static final String E2E_CONFIG = "config/trac-e2e.yaml"; + + protected abstract String useTracRepo(); + + public static class LocalRepoTest extends RunFlowTest { + protected String useTracRepo() { return "TRAC_LOCAL_REPO"; } + } + + @EnabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true", disabledReason = "Only run in CI") + public static class GitRepoTest extends RunFlowTest { + protected String useTracRepo() { return "TRAC_GIT_REPO"; } + } + + @RegisterExtension + private static final PlatformTest platform = PlatformTest.forConfig(E2E_CONFIG) + .addTenant(TEST_TENANT) + .startAll() + .build(); + + private final Logger log = LoggerFactory.getLogger(getClass()); + + + @Test + void importModel() throws Exception { + + log.info("Running IMPORT_MODEL job..."); + + var metaClient = platform.metaClientBlocking(); + var orchClient = platform.orchClientBlocking(); + + var modelVersion = GitHelpers.getCurrentCommit(); + + var importModel = ImportModelJob.newBuilder() + .setLanguage("python") + .setRepository(useTracRepo()) + .setPath("examples/models/python/src") + .setEntryPoint("tutorial.using_data.UsingDataModel") + .setVersion(modelVersion) + .addModelAttrs(TagUpdate.newBuilder() + .setAttrName("e2e_test_model") + .setValue(MetadataCodec.encodeValue("run_model:using_data"))) + .build(); + + var jobRequest = JobRequest.newBuilder() + .setTenant(TEST_TENANT) + .setJob(JobDefinition.newBuilder() + .setJobType(JobType.IMPORT_MODEL) + .setImportModel(importModel)) + .addJobAttrs(TagUpdate.newBuilder() + .setAttrName("e2e_test_job") + .setValue(MetadataCodec.encodeValue("run_model:import_model"))) + .build(); + + var jobStatus = runJob(orchClient, jobRequest); + var jobKey = MetadataUtil.objectKey(jobStatus.getJobId()); + + Assertions.assertEquals(JobStatusCode.SUCCEEDED, jobStatus.getStatusCode()); + + var modelSearch = MetadataSearchRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSearchParams(SearchParameters.newBuilder() + .setObjectType(ObjectType.MODEL) + .setSearch(SearchExpression.newBuilder() + .setTerm(SearchTerm.newBuilder() + .setAttrName("trac_create_job") + .setAttrType(BasicType.STRING) + .setOperator(SearchOperator.EQ) + .setSearchValue(MetadataCodec.encodeValue(jobKey))))) + .build(); + + var modelSearchResult = metaClient.search(modelSearch); + + Assertions.assertEquals(1, modelSearchResult.getSearchResultCount()); + + var searchResult = modelSearchResult.getSearchResult(0); + var modelReq = MetadataReadRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSelector(MetadataUtil.selectorFor(searchResult.getHeader())) + .build(); + + var modelTag = metaClient.readObject(modelReq); + var modelDef = modelTag.getDefinition().getModel(); + var modelAttr = modelTag.getAttrsOrThrow("e2e_test_model"); + + Assertions.assertEquals("run_model:using_data", MetadataCodec.decodeStringValue(modelAttr)); + Assertions.assertEquals("tutorial.using_data.UsingDataModel", modelDef.getEntryPoint()); + Assertions.assertTrue(modelDef.getParametersMap().containsKey("eur_usd_rate")); + Assertions.assertTrue(modelDef.getInputsMap().containsKey("customer_loans")); + Assertions.assertTrue(modelDef.getOutputsMap().containsKey("profit_by_region")); + + var descriptionAttr = modelTag.getAttrsOrThrow("model_description"); + var segmentAttr = modelTag.getAttrsOrThrow("business_segment"); + var classifiersAttr = modelTag.getAttrsOrThrow("classifiers"); + + Assertions.assertInstanceOf(String.class, MetadataCodec.decodeValue(descriptionAttr)); + Assertions.assertEquals("retail_products", MetadataCodec.decodeValue(segmentAttr)); + Assertions.assertEquals(List.of("loans", "uk", "examples"), MetadataCodec.decodeValue(classifiersAttr)); + } +}