Skip to content

Commit

Permalink
Feature / Define model attributes in code (#185)
Browse files Browse the repository at this point in the history
* API for defining model attributes in code

* Implementation for importing attrs from model code

* Support array attributes in model attribute definition

* Pick up model attrs in import model jobs

* Add an end-to-end test case for importing model attrs

* Add documentation to new runtime API methods

* Mark model attr APIs as experimental

* Prevent defining control attributes in model code
  • Loading branch information
martin-traverse committed Oct 11, 2022
1 parent 0469389 commit cfa246c
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 24 deletions.
8 changes: 8 additions & 0 deletions examples/models/python/src/tutorial/schema_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@

class SchemaFilesModel(trac.TracModel):

def define_attributes(self) -> tp.List[trac.TagUpdate]:

    """
    Declare the attributes this example model defines in code.

    Three attributes are declared: a description, a business segment flagged
    as categorical, and a multi-valued list of classifiers with an explicit
    STRING attribute type.

    :return: The list of tag updates produced by trac.define_attributes()
    """

    description = trac.A("model_description", "A example model, for testing purposes")
    segment = trac.A("business_segment", "retail_products", categorical=True)

    # The attribute type is given explicitly for the multi-valued attribute
    classifiers = trac.A("classifiers", ["loans", "uk", "examples"], attr_type=trac.STRING)

    return trac.define_attributes(description, segment, classifiers)

def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]:

return trac.define_parameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ option java_package = "org.finos.tracdap.config";
option java_multiple_files = true;

import "tracdap/metadata/object_id.proto";
import "tracdap/metadata/object.proto";;
import "tracdap/metadata/object.proto";
import "tracdap/metadata/job.proto";
import "tracdap/metadata/tag_update.proto";


// A list of tag updates wrapped in a single message
// JobResult.attrs uses this message as the value type of its map
message TagUpdateList {

repeated metadata.TagUpdate attrs = 1;
}

message JobResult {

metadata.TagHeader jobId = 1;
Expand All @@ -33,4 +39,5 @@ message JobResult {
string statusMessage = 3;

map<string, metadata.ObjectDefinition> results = 4;
map<string, TagUpdateList> attrs = 5;
}
2 changes: 1 addition & 1 deletion tracdap-runtime/python/src/tracdap/rt/_exec/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def on_signal(self, signal: Signal) -> tp.Optional[bool]:
error = signal.error
else:
error = _ex.ETracInternal("An unknown error occurred")
self.job_failed(failed_job_key, error)
self.actors().send("job_failed", failed_job_key, error)

# Failed signal has been handled, do not propagate
return True
Expand Down
29 changes: 28 additions & 1 deletion tracdap-runtime/python/src/tracdap/rt/_exec/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ def _execute(self, ctx: NodeContext) -> _config.JobResult:
obj_def = _ctx_lookup(node_id, ctx)
job_result.results[obj_id] = obj_def

for obj_id, node_id in self.node.attrs.items():
attrs = _ctx_lookup(node_id, ctx)
job_result.attrs[obj_id] = attrs

for bundle_id in self.node.bundles:
bundle = _ctx_lookup(bundle_id, ctx)
job_result.results.update(bundle.items())
Expand Down Expand Up @@ -504,6 +508,25 @@ def _execute(self, ctx: NodeContext) -> meta.ObjectDefinition:
return meta.ObjectDefinition(meta.ObjectType.MODEL, model=model_def)


class ImportAttrsFunc(NodeFunction[_config.TagUpdateList]):

    """
    Node function that loads a model class and scans the attributes it defines in code.

    The model is located using the import details held on the node,
    the scan itself is delegated to the model loader.
    """

    def __init__(self, node: ImportAttrsNode, models: _models.ModelLoader):
        self.node = node
        self._models = models

    def _execute(self, ctx: NodeContext) -> _config.TagUpdateList:

        details = self.node.import_details

        # Stub definition carrying only the import details
        # (language, repository, path, entry point and version)
        stub_model_def = meta.ModelDefinition(
            language=details.language,
            repository=details.repository,
            path=details.path,
            entryPoint=details.entryPoint,
            version=details.version)

        loader = self._models
        model_class = loader.load_model_class(self.node.model_scope, stub_model_def)

        return loader.scan_model_attrs(model_class)


class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):

def __init__(self, node: RunModelNode, model_class: _api.TracModel.__class__):
Expand Down Expand Up @@ -620,6 +643,9 @@ def resolve_dynamic_data_spec(self, node: DynamicDataSpecNode):
def resolve_import_model_node(self, node: ImportModelNode):
return ImportModelFunc(node, self._models)

def resolve_import_attrs_node(self, node: ImportAttrsNode):
return ImportAttrsFunc(node, self._models)

def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction:

model_class = self._models.load_model_class(node.model_scope, node.model_def)
Expand Down Expand Up @@ -650,5 +676,6 @@ def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction:
SaveDataNode: resolve_save_data,
DynamicDataSpecNode: resolve_dynamic_data_spec,
RunModelNode: resolve_run_model_node,
ImportModelNode: resolve_import_model_node
ImportModelNode: resolve_import_model_node,
ImportAttrsNode: resolve_import_attrs_node
}
11 changes: 10 additions & 1 deletion tracdap-runtime/python/src/tracdap/rt/_exec/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ class ImportModelNode(Node[meta.ObjectDefinition]):
import_details: meta.ImportModelJob


# Graph node for scanning the attributes a model defines in its own code
# The node result is a cfg.TagUpdateList, it is executed by ImportAttrsFunc
@_node_type
class ImportAttrsNode(Node[cfg.TagUpdateList]):

# Scope passed to the model loader when the model class is loaded
model_scope: str
# Import job details used to locate the model (language, repository, path, entry point, version)
import_details: meta.ImportModelJob


@_node_type
class RunModelNode(Node[Bundle[_data.DataView]]):

Expand All @@ -363,10 +370,12 @@ class BuildJobResultNode(Node[cfg.JobResult]):
job_id: meta.TagHeader

objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict)
attrs: tp.Dict[str, NodeId[cfg.TagUpdateList]] = dc.field(default_factory=dict)

bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list)

def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
dep_ids = [*self.bundles, *self.objects.values()]
dep_ids = [*self.bundles, *self.objects.values(), *self.attrs.values()]
return {node_id: DependencyType.HARD for node_id in dep_ids}


Expand Down
21 changes: 14 additions & 7 deletions tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,17 @@ def build_import_model_job(
import_id = NodeId.of("trac_import_model", job_namespace, meta.ObjectDefinition)
import_node = ImportModelNode(import_id, model_scope, import_details, explicit_deps=[job_push_id])

main_section = GraphSection(nodes={import_id: import_node}, must_run=[import_id])
import_attrs_id = NodeId.of("trac_import_attrs", job_namespace, config.TagUpdateList)
import_attrs_node = ImportAttrsNode(import_attrs_id, model_scope, import_details, explicit_deps=[import_id])

# Build job-level metadata outputs
main_section = GraphSection(nodes={import_id: import_node, import_attrs_id: import_attrs_node})

result_objects = {new_model_key: import_id}
# Build job-level metadata outputs

result_section = cls.build_job_results(
job_config, job_namespace,
result_spec, objects=result_objects,
job_config, job_namespace, result_spec,
objects={new_model_key: import_id},
attrs={new_model_key: import_attrs_id},
explicit_deps=[job_push_id, *main_section.must_run])

return cls._join_sections(main_section, result_section)
Expand Down Expand Up @@ -327,7 +329,9 @@ def build_job_outputs(
@classmethod
def build_job_results(
cls, job_config: cfg.JobConfig, job_namespace: NodeNamespace, result_spec: JobResultSpec,
objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = None, bundles: tp.List[NodeId[ObjectBundle]] = None,
objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = None,
attrs: tp.Dict[str, NodeId[config.TagUpdateList]] = None,
bundles: tp.List[NodeId[ObjectBundle]] = None,
explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
-> GraphSection:

Expand All @@ -337,9 +341,12 @@ def build_job_results(

results_inputs = set(objects.values())

if attrs is not None:
results_inputs.update(attrs.values())

build_result_node = BuildJobResultNode(
build_result_id, job_config.jobId,
objects=objects, explicit_deps=explicit_deps)
objects=objects, attrs=attrs, explicit_deps=explicit_deps)

elif bundles is not None:

Expand Down
37 changes: 37 additions & 0 deletions tracdap-runtime/python/src/tracdap/rt/_impl/api_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tracdap.rt.metadata as _meta
import tracdap.rt.exceptions as _ex
import tracdap.rt._impl.schemas as _schemas
import tracdap.rt._impl.type_system as _type_system
import tracdap.rt._impl.util as _util

# Import hook interfaces into this module namespace
Expand Down Expand Up @@ -204,6 +205,42 @@ def register_impl(cls):

log.warning("Runtime API hook is already registered")

def define_attributes(
        self, *attrs: _tp.Union[_meta.TagUpdate, _tp.List[_meta.TagUpdate]]) \
        -> _tp.List[_meta.TagUpdate]:

    """
    Collect attribute definitions into a single list.

    Attributes can be supplied either as individual arguments or as one list.

    :param attrs: Individual tag updates, or a single list of tag updates
    :return: A list of tag updates; a single list argument is returned as-is
    """

    ApiGuard.validate_signature(self.define_attributes, *attrs)

    # Exactly one argument that is already a list is passed straight through
    is_single_list = len(attrs) == 1 and isinstance(attrs[0], list)

    return attrs[0] if is_single_list else list(attrs)

def define_attribute(
        self, attr_name: str, attr_value: _tp.Any,
        attr_type: _tp.Optional[_meta.BasicType] = None,
        categorical: bool = False) \
        -> _meta.TagUpdate:

    """
    Define a single attribute as a tag update.

    If attr_type is not supplied the value is encoded using its native Python type.
    If attr_type is supplied the value is converted to that type; a list value is
    converted to an array whose items have type attr_type.

    The categorical flag is validated here but is not stored on the returned tag update.

    :param attr_name: Name of the attribute
    :param attr_value: Value of the attribute, a list for multi-valued attributes
    :param attr_type: Basic type of the attribute, required for multi-valued attributes
    :param categorical: Flag the attribute as categorical, only allowed for STRING attributes
    :return: A tag update with operation CREATE_OR_APPEND_ATTR for this attribute
    :raises EModelValidation: If a multi-valued attribute has no attr_type,
        or the categorical flag is set on a non-STRING attribute
    """

    ApiGuard.validate_signature(self.define_attribute, attr_name, attr_value, attr_type, categorical)

    is_multi_valued = isinstance(attr_value, list)

    if is_multi_valued and attr_type is None:
        raise _ex.EModelValidation(f"Attribute type must be specified for multi-valued attribute [{attr_name}]")

    # The string check must look at the attribute value, not its name
    # The name is always a string, so checking the name would let every attribute through
    is_string_attr = isinstance(attr_value, str) or attr_type == _meta.BasicType.STRING

    if categorical and not is_string_attr:
        raise _ex.EModelValidation("Categorical flag is only allowed for STRING attributes")

    if attr_type is None:
        # No explicit type, infer the TRAC type from the Python value
        trac_value = _type_system.MetadataCodec.encode_value(attr_value)
    else:
        type_desc = _meta.TypeDescriptor(attr_type)
        if is_multi_valued:
            # Multi-valued attributes are arrays of the requested item type
            type_desc = _meta.TypeDescriptor(_meta.BasicType.ARRAY, arrayType=type_desc)
        trac_value = _type_system.MetadataCodec.convert_value(attr_value, type_desc)

    return _meta.TagUpdate(_meta.TagOperation.CREATE_OR_APPEND_ATTR, attr_name, trac_value)

def define_parameter(
self, param_name: str, param_type: _tp.Union[_meta.TypeDescriptor, _meta.BasicType],
label: str, default_value: _tp.Optional[_tp.Any] = None) \
Expand Down
19 changes: 19 additions & 0 deletions tracdap-runtime/python/src/tracdap/rt/_impl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,32 @@ def load_model_class(self, scope: str, model_def: _meta.ModelDefinition) -> _api
scope_state.model_cache[model_key] = model_class
return model_class

def scan_model_attrs(self, model_class: _api.TracModel.__class__) -> _cfg.TagUpdateList:

    """
    Instantiate a model class and collect the attributes it defines in code.

    Every attribute is logged. Attribute names starting with "trac_" or "_"
    are treated as controlled and are rejected.

    :param model_class: The model class to scan
    :return: A TagUpdateList built from the attributes returned by the model
    :raises EModelValidation: If the model defines a controlled attribute
    """

    # Create a bare instance, then run the class initialiser on it explicitly
    model: _api.TracModel = object.__new__(model_class)
    model_class.__init__(model)

    attributes = model.define_attributes()
    controlled_prefixes = ("trac_", "_")

    for attr in attributes:

        name = attr.attrName

        if name.startswith(controlled_prefixes):
            err = f"Controlled attribute [{name}] cannot be defined in model code"
            self.__log.error(err)
            raise _ex.EModelValidation(err)

        self.__log.info(f"Attribute [{name}] - {_types.MetadataCodec.decode_value(attr.value)}")

    return _cfg.TagUpdateList(attributes)

def scan_model(self, model_class: _api.TracModel.__class__) -> _meta.ModelDefinition:

try:

model: _api.TracModel = object.__new__(model_class)
model_class.__init__(model)

attributes = model.define_attributes()
parameters = model.define_parameters()
inputs = model.define_inputs()
outputs = model.define_outputs()
Expand Down
29 changes: 27 additions & 2 deletions tracdap-runtime/python/src/tracdap/rt/_impl/type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ def decode_value(value: _meta.Value) -> tp.Any:

raise _ex.ETracInternal("Missing type information")

basic_type = value.type.basicType
return MetadataCodec._decode_value_for_type(value, value.type)

@staticmethod
def _decode_value_for_type(value: _meta.Value, type_desc: _meta.TypeDescriptor):

basic_type = type_desc.basicType

if basic_type == _meta.BasicType.BOOLEAN:
return value.booleanValue
Expand All @@ -125,6 +130,10 @@ def decode_value(value: _meta.Value) -> tp.Any:
if basic_type == _meta.BasicType.DATETIME:
return dt.datetime.fromisoformat(value.datetimeValue.isoDatetime)

if basic_type == _meta.BasicType.ARRAY:
items = value.arrayValue.items
return list(map(lambda x: MetadataCodec._decode_value_for_type(x, type_desc.arrayType), items))

raise _ex.ETracInternal(f"Decoding value type [{basic_type}] is not supported yet")

@classmethod
Expand Down Expand Up @@ -162,7 +171,7 @@ def encode_value(cls, value: tp.Any) -> _meta.Value:
type_desc = _meta.TypeDescriptor(_meta.BasicType.DATE)
return _meta.Value(type_desc, dateValue=_meta.DateValue(value.isoformat()))

raise _ex.ETracInternal(f"Encoding value type [{type(value)}] is not supported yet")
raise _ex.ETracInternal(f"Value type [{type(value)}] is not supported yet")

@classmethod
def convert_value(cls, raw_value: tp.Any, type_desc: _meta.TypeDescriptor):
Expand All @@ -188,8 +197,24 @@ def convert_value(cls, raw_value: tp.Any, type_desc: _meta.TypeDescriptor):
if type_desc.basicType == _meta.BasicType.DATETIME:
return cls.convert_datetime_value(raw_value)

if type_desc.basicType == _meta.BasicType.ARRAY:
return cls.convert_array_value(raw_value, type_desc.arrayType)

raise _ex.ETracInternal(f"Conversion to value type [{type_desc.basicType.name}] is not supported yet")

@staticmethod
def convert_array_value(raw_value: tp.List[tp.Any], array_type: _meta.TypeDescriptor) -> _meta.Value:

    """
    Convert a Python list into a TRAC array value.

    Each item is converted to array_type using MetadataCodec.convert_value().

    :param raw_value: The list of raw values to convert
    :param array_type: Type descriptor for the items of the array
    :return: A value with an ARRAY type descriptor holding the converted items
    :raises ETracInternal: If raw_value is not a list
    """

    array_desc = _meta.TypeDescriptor(_meta.BasicType.ARRAY, array_type)

    if not isinstance(raw_value, list):
        msg = f"Value of type [{type(raw_value)}] cannot be converted to {_meta.BasicType.ARRAY.name}"
        raise _ex.ETracInternal(msg)

    converted_items = [MetadataCodec.convert_value(item, array_type) for item in raw_value]

    return _meta.Value(array_desc, arrayValue=_meta.ArrayValue(converted_items))

@staticmethod
def convert_boolean_value(raw_value: tp.Any) -> _meta.Value:

Expand Down
16 changes: 16 additions & 0 deletions tracdap-runtime/python/src/tracdap/rt/api/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ def runtime(cls) -> _RuntimeHook:

return cls.__runtime_hook

# Abstract runtime hook for defining a set of attributes
# Per the signature, takes individual TagUpdate items or lists of TagUpdate and returns a list of TagUpdate
# The body is a placeholder, the behaviour is supplied by whichever class implements the hook
@_abc.abstractmethod
def define_attributes(
self, *attrs: _tp.Union[_meta.TagUpdate, _tp.List[_meta.TagUpdate]]) \
-> _tp.List[_meta.TagUpdate]:

pass

# Abstract runtime hook for defining a single attribute
# Per the signature, takes an attribute name and value, an optional basic type
# and a categorical flag (default False), and returns a TagUpdate
# The body is a placeholder, the behaviour is supplied by whichever class implements the hook
@_abc.abstractmethod
def define_attribute(
self, attr_name: str, attr_value: _tp.Any,
attr_type: _tp.Optional[_meta.BasicType] = None,
categorical: bool = False) \
-> _meta.TagUpdate:

pass

@_abc.abstractmethod
def define_parameter(
self, param_name: str, param_type: _tp.Union[_meta.TypeDescriptor, _meta.BasicType],
Expand Down
Loading

0 comments on commit cfa246c

Please sign in to comment.