diff --git a/src/machinable/component.py b/src/machinable/component.py index 4f974c2d..4155eac8 100644 --- a/src/machinable/component.py +++ b/src/machinable/component.py @@ -67,6 +67,10 @@ def executions() -> ExecutionCollection: return Execution + @property + def components(self) -> "ComponentCollection": + return ComponentCollection([self]) + @property def execution(self) -> "Execution": from machinable.execution import Execution diff --git a/src/machinable/interface.py b/src/machinable/interface.py index 3de9e4ee..2d986960 100644 --- a/src/machinable/interface.py +++ b/src/machinable/interface.py @@ -18,7 +18,11 @@ import dill as pickle from machinable import errors, schema -from machinable.collection import Collection, InterfaceCollection +from machinable.collection import ( + Collection, + ComponentCollection, + InterfaceCollection, +) from machinable.element import _CONNECTIONS as connected_elements from machinable.element import Element, get_dump, get_lineage from machinable.types import VersionType @@ -289,10 +293,10 @@ def push_related(self, key: str, value: "Interface") -> None: self.__related__[key] = value self._relation_cache[key] = True - def is_staged(self): + def is_staged(self) -> bool: return self.__model__.uuid[-12:] != "0" * 12 - def stage(self): + def stage(self) -> Self: self.__model__.context = context = self.compute_context() self.__model__.uuid = update_uuid_payload(self.__model__.uuid, context) @@ -300,6 +304,8 @@ def stage(self): assert self.config is not None self.__model__.predicate = self.compute_predicate() + return self + def is_committed(self) -> bool: from machinable.index import Index @@ -740,12 +746,19 @@ def save_file(self, filepath: Union[str, List[str]], data: Any) -> str: def launch(self) -> Self: ... - def cached(self): - from machinable.execution import Execution + @property + def components(self) -> "ComponentCollection": + if "components" not in self._cache: + from machinable.execution import Execution + + with Execution().deferred() as e: + self.launch() + self._cache["components"] = e.executables - with Execution().deferred() as e: - self.launch() - return e.executables.reduce( + return self._cache["components"] + + def cached(self): + return self.components.reduce( lambda result, x: result and x.cached(), True ) diff --git a/tests/test_interface.py b/tests/test_interface.py index 3ba0f146..009bcfa2 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -362,6 +362,16 @@ def test(self): p.__exit__() +def test_interface_components(tmp_storage): + class T(Interface): + def launch(self): + get("machinable.component").launch() + + assert len(T().components) == 1 + t = get("machinable.component") + assert t == t.components[0] + + def test_interface_cachable(tmp_storage): counts = { "test": 0,