Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chg: ♻️ make count_occurrences a method of ExecNode #203

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 40 additions & 36 deletions tawazi/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,6 @@
Alias = Union[Tag, Identifier, "ExecNode"]


def count_occurrences(id_: str, exec_nodes: Dict[str, "ExecNode"]) -> int:
    """Count how many entries of exec_nodes correspond to a use of the ExecNode id_.

    Ids that belong to arguments passed to previously called ExecNodes are not counted.
    example: id_ = "a"
    ExecNode a is called five times, hence we should have ids a, a<<1>>, a<<2>>, a<<3>>, a<<4>>
    ExecNode a is called with many arguments:
    we want to avoid counting "a>>>nth argument" and "a<<1>>>>nth argument"

    Args:
        id_ (str): the id to count
        exec_nodes (Dict[str, ExecNode]): the dictionary of ExecNodes

    Returns:
        int: the number of occurrences of id_ in exec_nodes
    """
    total = 0
    for key in exec_nodes:
        # keep only the keys whose part before the first USE_SEP_START is exactly the original id
        if key.partition(USE_SEP_START)[0] != id_:
            continue

        # the key is either the original id itself, or it ends with USE_SEP_END
        # (which means it comes from a reuse of the same ExecNode)
        if key == id_ or key.endswith(USE_SEP_END):
            total += 1

    return total


@dataclass(frozen=True)
class ExecNode:
"""Base class for executable node in a DAG.
Expand Down Expand Up @@ -111,8 +87,8 @@ class ExecNode:
unpack_to: Optional[int] = None
resource: Resource = cfg.TAWAZI_DEFAULT_RESOURCE

args: List[UsageExecNode] = field(default_factory=list) # args or []
kwargs: Dict[Identifier, UsageExecNode] = field(default_factory=dict) # kwargs or {}
args: List[UsageExecNode] = field(default_factory=list)
kwargs: Dict[Identifier, UsageExecNode] = field(default_factory=dict)

def __post_init__(self) -> None:
"""Post init to validate attributes."""
Expand Down Expand Up @@ -235,6 +211,32 @@ def execute(self, results: Dict[Identifier, Any], profiles: Dict[Identifier, Pro
logger.debug("Finished executing {} with task {}", self.id, self.exec_function)
return results[self.id]

def count_occurrences(self, exec_nodes: Dict[str, "ExecNode"]) -> int:
    """Count how many entries of exec_nodes correspond to a use of this ExecNode.

    Ids that belong to arguments passed to previously called ExecNodes are not counted.
    example: self.id = "a"
    ExecNode a is called five times, hence we should have ids a, a<<1>>, a<<2>>, a<<3>>, a<<4>>
    ExecNode a is called with many arguments:
    we want to avoid counting "a>>>nth argument" and "a<<1>>>>nth argument"

    Args:
        exec_nodes (Dict[str, ExecNode]): the dictionary of ExecNodes

    Returns:
        int: the number of occurrences of self in exec_nodes
    """
    occurrences = 0
    for key in exec_nodes:
        # an id that is exactly the same as the original id always counts
        if key == self.id:
            occurrences += 1
            continue

        # otherwise the id must start with the original id (up to the first USE_SEP_START)
        # and end with USE_SEP_END (which means it comes from a reuse of the same ExecNode)
        if key.partition(USE_SEP_START)[0] == self.id and key.endswith(USE_SEP_END):
            occurrences += 1

    return occurrences


class ReturnExecNode(ExecNode):
"""ExecNode corresponding to a constant Return value of a DAG."""
Expand Down Expand Up @@ -351,10 +353,16 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RVXN:
return self.exec_function(*args, **kwargs) # type: ignore[no-any-return]

# 1.1 if ExecNode is used multiple times, <<usage_count>> is appended to its ID
id_ = _lazy_xn_id(self.id, count_occurrences(self.id, exec_nodes))
# 1.1 Construct a new LazyExecNode corresponding to the current call
id_ = _lazy_xn_id(self.id, self.count_occurrences(exec_nodes))
exec_nodes[id_] = self._new_lxn(id_, *args, **kwargs)
return exec_nodes[id_]._usage_exec_node # type: ignore[no-any-return,attr-defined]

def _new_lxn(
self, id_: Identifier, *args: P.args, **kwargs: P.kwargs
) -> "LazyExecNode[P, RVXN]":
"""Construct a new LazyExecNode corresponding to the current call."""
values = dataclasses.asdict(self)
# force deepcopying instead of the default behavior of asdict: recursively apply asdict to dataclasses!
# force deep copy instead of the default behavior of asdict: recursively apply asdict to dataclasses!
values["exec_function"] = deepcopy(self.exec_function)
values["id_"] = id_

Expand All @@ -366,13 +374,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RVXN:
values["tag"] = kwargs.get(ARG_NAME_TAG) or self.tag
values["unpack_to"] = kwargs.get(ARG_NAME_UNPACK_TO) or self.unpack_to

new_lxn: LazyExecNode[P, RVXN] = LazyExecNode(**values)

new_lxn._validate_dependencies()

exec_nodes[new_lxn.id] = new_lxn

return new_lxn._usage_exec_node # type: ignore[return-value]
lxn: LazyExecNode[P, RVXN] = LazyExecNode(**values)
lxn._validate_dependencies()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as before here: why isn't it done in the init?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we initialize LazyExecNodes in DAG.config_from_dict, and during that phase we don't have any data in the global variable exec_nodes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, then maybe we should consider not having this as a method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I can remove the code from this method and put it directly in the __call__...
I only did it to refactor the code a little bit and make the call tinier

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually I just further thought about it... I think I have a solution for both cases!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice !!

return lxn

def _validate_dependencies(self) -> None:
for dep in self.dependencies:
Expand Down
Loading