Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed May 19, 2023
1 parent 252bbfd commit f7a5190
Show file tree
Hide file tree
Showing 20 changed files with 231 additions and 211 deletions.
26 changes: 14 additions & 12 deletions src/machinable/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,30 +1445,32 @@ def launch(self) -> "ComponentCollection":

return self


class ExecutionCollection(ElementCollection):
def status(self, status="started"):
"""Filters the collection by a status attribute
# Arguments
status: String, status field: 'started', 'finished', 'alive'
"""
try:
return self.filter(lambda item: getattr(item, "is_" + status)())
return self.filter(
lambda item: item.execution
and getattr(item.execution, "is_" + status)()
)
except AttributeError as _ex:
raise ValueError(f"Invalid status field: {status}") from _ex

def finished(self) -> "ExecutionCollection":
return self.filter(lambda x: x.is_finished())

def started(self) -> "ExecutionCollection":
return self.filter(lambda x: x.is_started())

def active(self) -> "ExecutionCollection":
return self.filter(lambda x: x.is_active())
class ExecutionCollection(ElementCollection):
def status(self, status="started"):
"""Filters the collection by a status attribute
def incomplete(self) -> "ExecutionCollection":
return self.filter(lambda x: x.is_incomplete())
# Arguments
status: String, status field: 'started', 'finished', 'alive'
"""
try:
return self.filter(lambda item: getattr(item, "is_" + status)())
except AttributeError as _ex:
raise ValueError(f"Invalid status field: {status}") from _ex

def __str__(self):
return f"Executions <{len(self.items)}>"
40 changes: 30 additions & 10 deletions src/machinable/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

from typing import Dict

from machinable import errors
from machinable import errors, schema
from machinable.collection import ComponentCollection, ExecutionCollection
from machinable.element import Element
from machinable.interface import Interface, belongs_to, has_many
from machinable.element import Element, get_dump, get_lineage
from machinable.interface import Interface, belongs_to, belongs_to_many
from machinable.project import Project
from machinable.storage import Storage
from machinable.types import VersionType
Expand All @@ -30,13 +30,33 @@ class Component(Interface):
kind = "Component"
default = get_settings().default_component

@has_many
def __init__(
self,
version: VersionType = None,
uses: Union[None, "Interface", List["Interface"]] = None,
derived_from: Optional["Interface"] = None,
):
super().__init__(version=version, uses=uses, derived_from=derived_from)
self.__model__ = schema.Component(
kind=self.kind,
module=self.__model__.module,
config=self.__model__.config,
version=self.__model__.version,
lineage=get_lineage(self),
)
self.__model__._dump = get_dump(self)

@belongs_to
def group():
return Group

@belongs_to_many(key="execution_history")
def executions() -> ExecutionCollection:
from machinable.execution import Execution

return Execution

@belongs_to(cached=False)
@belongs_to(key="execution_history", cached=False)
def execution() -> "Execution":
from machinable.execution import Execution

Expand Down Expand Up @@ -84,7 +104,7 @@ def dispatch(self) -> Self:
if writes_meta_data:
self.execution.mark_started()
self.save_file(
"env.json",
"host.json",
data=Project.get().provider().get_host_info(),
)

Expand Down Expand Up @@ -123,6 +143,10 @@ def beat():
f"{self.__class__.__name__} dispatch failed"
) from _ex

@property
def host_info(self) -> Optional[Dict]:
return self.load_file("host.json", None)

def cached(self) -> bool:
if self.execution is None:
return False
Expand Down Expand Up @@ -224,7 +248,3 @@ def group_as(self, group: Union[Group, str]) -> Self:
self.push_related("group", group)

return self

@belongs_to
def group():
return Group
20 changes: 10 additions & 10 deletions src/machinable/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def version(
if version is sentinel:
return self.__model__.version

if self.mounted():
if hasattr(self, "is_mounted") and self.is_mounted():
raise MachinableError(
f"Cannot change version of mounted element {self}"
)
Expand Down Expand Up @@ -625,15 +625,15 @@ def _clear_caches(self) -> None:
self._config = None
self.__model__.config = None

# def __getattr__(self, name) -> Any:
# attr = getattr(self.__mixin__, name, None)
# if attr is not None:
# return attr
# raise AttributeError(
# "{!r} object has no attribute {!r}".format(
# self.__class__.__name__, name
# )
# )
def __getattr__(self, name) -> Any:
attr = getattr(self.__mixin__, name, None)
if attr is not None:
return attr
raise AttributeError(
"{!r} object has no attribute {!r}".format(
self.__class__.__name__, name
)
)

def __enter__(self) -> Self:
_CONNECTIONS[self.kind].append(self)
Expand Down
10 changes: 3 additions & 7 deletions src/machinable/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def timestamp(self) -> int:
def schedule() -> "Schedule":
return Schedule

@has_many
@has_many(key="execution_history")
def executables() -> ComponentCollection:
return Component

Expand Down Expand Up @@ -229,10 +229,6 @@ def on_after_dispatch(self) -> None:
def host_info(self) -> Optional[Dict]:
return self.load_file("host.json", None)

@property
def env_info(self) -> Optional[Dict]:
return self.load_file("env.json", None)

@property
def nickname(self) -> str:
return self.__model__.nickname
Expand Down Expand Up @@ -291,7 +287,7 @@ def output(self, incremental: bool = False) -> Optional[str]:
read_length = self._cache.get("output_read_length", 0)
if read_length == -1:
return ""
output = load_file(self.local_directory("output.log"), None)
output = self.load_file("output.log", None)
if output is None:
return None

Expand All @@ -304,7 +300,7 @@ def output(self, incremental: bool = False) -> Optional[str]:
if "output" in self._cache:
return self._cache["output"]

output = load_file(self.local_directory("output.log"), None)
output = self.load_file("output.log", None)

if self.is_finished():
self._cache["output"] = output
Expand Down
1 change: 1 addition & 0 deletions src/machinable/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def load(database: str, create=False) -> sqlite3.Connection:
if database.startswith("sqlite:///"):
database = database[10:]
if ":memory:" not in database:
database = os.path.expanduser(database)
if not os.path.isfile(database):
if not create:
raise FileNotFoundError(
Expand Down
15 changes: 13 additions & 2 deletions src/machinable/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def push_related(self, key: str, value: "Interface") -> None:
def commit(self) -> Self:
Storage.get().commit(self)

# write deferred files
for filepath, data in self._deferred_data.items():
self.save_file(filepath, data)
self._deferred_data = {}

return self

@belongs_to
Expand Down Expand Up @@ -308,7 +313,7 @@ def find_by_predicate(

return cls.collect(
[
cls.from_model(interface)
cls.find(interface.uuid)
for interface in Storage.get().index.find_by_predicate(
module
if isinstance(module, str)
Expand All @@ -332,11 +337,17 @@ def from_directory(cls, directory: str) -> "Element":
# TODO: users should have an option to register custom interface types
raise ValueError(f"Invalid interface kind: {model['kind']}")

return cls.from_model(model(**data))
interface = model(**data)
if interface.module.startswith("__session__"):
interface._dump = load_file(os.path.join(directory, "dump.p"), None)

return cls.from_model(interface)

def to_directory(self, directory: str, relations=True) -> Self:
save_file(os.path.join(directory, ".machinable"), self.__model__.uuid)
save_file(os.path.join(directory, "model.json"), self.__model__)
if self.__model__._dump is not None:
save_file(os.path.join(directory, "dump.p"), self.__model__._dump)
if relations:
for k, v in self.__related__.items():
if hasattr(v, "uuid"):
Expand Down
69 changes: 35 additions & 34 deletions src/machinable/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import platform
import socket
import sys
from dataclasses import dataclass

import machinable
from commandlib import Command
from machinable import schema
from machinable.config import Field, validator
from machinable.element import Element, get_lineage, instantiate, normversion
from machinable.errors import ConfigurationError
from machinable.interface import Interface
Expand Down Expand Up @@ -163,37 +165,31 @@ def import_element(
class Project(Interface):
kind = "Project"

@dataclass
class Config:
directory: Optional[str] = Field(default_factory=lambda: os.getcwd())

@validator("directory")
def normalize_directory(cls, v):
return os.path.normpath(os.path.abspath(os.path.expanduser(v)))

def __init__(
self,
directory: Optional[str] = None,
version: VersionType = None,
name: Optional[str] = None,
):
super().__init__()
if directory is None:
directory = os.getcwd()
directory = os.path.abspath(directory)
if name is None:
name = os.path.basename(directory)
if isinstance(version, str) and not version.startswith("~"):
# interpret as shortcut for directory
version = {"directory": version}
super().__init__(version=version)
self.__model__ = schema.Project(
kind=self.kind,
directory=directory,
name=name,
version=normversion(version),
lineage=get_lineage(self),
)
self._parent: Optional[Project] = None
self._provider: str = "_machinable/project"
self._resolved_provider: Optional[Project] = None

@classmethod
def instance(
cls,
directory: Optional[str] = None,
version: VersionType = None,
) -> "Project":
return Project(directory, version).provider()

def provider(self, reload: Union[str, bool] = False) -> "Project":
"""Resolves and returns the provider instance"""
if isinstance(reload, str):
Expand All @@ -203,11 +199,11 @@ def provider(self, reload: Union[str, bool] = False) -> "Project":
if isinstance(self._provider, str):
self._resolved_provider = find_subclass_in_module(
module=import_from_directory(
self._provider, self.__model__.directory
self._provider, self.config.directory
),
base_class=Project,
default=Project,
)(self.__model__.directory, version=self.__model__.version)
)(version=self.__model__.version)
else:
self._resolved_provider = Project(
version=self.__model__.version
Expand All @@ -224,19 +220,19 @@ def module(self) -> Optional[str]:

def add_to_path(self) -> None:
if (
os.path.exists(self.__model__.directory)
and self.__model__.directory not in sys.path
os.path.exists(self.config.directory)
and self.config.directory not in sys.path
):
if self.is_root():
sys.path.insert(0, self.__model__.directory)
sys.path.insert(0, self.config.directory)
else:
sys.path.append(self.__model__.directory)
sys.path.append(self.config.directory)

def name(self) -> str:
return self.__model__.name
return os.path.basename(self.config.directory)

def path(self, *append: str) -> str:
return os.path.join(self.__model__.directory, *append)
return os.path.join(self.config.directory, *append)

def is_root(self) -> bool:
return self._parent is None
Expand Down Expand Up @@ -264,8 +260,8 @@ def get_vendors(self) -> List[str]:

def get_code_version(self) -> dict:
return {
"id": get_root_commit(self.__model__.directory),
"project": get_commit(self.__model__.directory),
"id": get_root_commit(self.config.directory),
"project": get_commit(self.config.directory),
"vendor": {
vendor: get_commit(self.path("vendor", vendor))
for vendor in self.get_vendors()
Expand All @@ -290,23 +286,28 @@ def element(
if not isinstance(element_class, Element):
element_class = import_element(self.path(), module, base_class)

return instantiate(
element = instantiate(
module,
element_class,
version,
**constructor_kwargs,
)

if isinstance(element, Interface):
element.push_related("project", self)

return element

def get_diff(self) -> Union[str, None]:
return get_diff(self.path())

def exists(self) -> bool:
return os.path.exists(self.__model__.directory)
return os.path.exists(self.config.directory)

def serialize(self) -> dict:
return {
"directory": self.__model__.directory,
"provider": self._provider,
"directory": self.config.directory,
"name": self.name,
}

@classmethod
Expand Down Expand Up @@ -358,5 +359,5 @@ def global_predicate(self) -> Dict:
"""Project-wide element predicates."""
return {}

# def __repr__(self) -> str:
# return f"Project({self.__model__.directory})"
def __repr__(self) -> str:
return f"Project({self.config.directory})"
Loading

0 comments on commit f7a5190

Please sign in to comment.