Skip to content

Commit

Permalink
Merge pull request #90 from evo-company/fix-wrong-graph_link-from-clo…
Browse files Browse the repository at this point in the history
…sure

do not use graph_link and schedule from closure, pass them directly
  • Loading branch information
kindermax committed Nov 7, 2022
2 parents a50cafb + d471d0c commit b47c21f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
7 changes: 5 additions & 2 deletions hiku/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,9 @@ def process_node(
if isinstance(graph_link.requires, list):
done_deps = set()

def add_done_dep_callback(dep: Dep, req: Any) -> None:
def add_done_dep_callback(
dep: Dep, req: Any, graph_link: Link, schedule: Callable
) -> None:
def done_cb() -> None:
done_deps.add(req)
if done_deps == set(graph_link.requires):
Expand All @@ -615,7 +617,8 @@ def done_cb() -> None:
self._queue.add_callback(dep, done_cb)

for req in graph_link.requires:
add_done_dep_callback(to_dep[to_func[req]], req)
add_done_dep_callback(
to_dep[to_func[req]], req, graph_link, schedule)
else:
dep = to_dep[to_func[graph_link.requires]]
self._queue.add_callback(dep, schedule)
Expand Down
18 changes: 16 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,9 @@ def execute(query_node):
def link_song_info(reqs: List[Tuple]):
return reqs

def link_artist(ids):
return ids

@pass_context
def song_info_fields(ctx, fields, ids):
db = ctx[SA_ENGINE_KEY]
Expand All @@ -336,19 +339,24 @@ def get_fields(id_):
return [list(get_fields(id_)) for id_ in ids]

song_query = FieldsQuery(SA_ENGINE_KEY, song_table)
artist_query = FieldsQuery(SA_ENGINE_KEY, artist_table)

graph = Graph([
Node('SongInfo', [
Field('album_name', None, song_info_fields),
Field('artist_name', None, song_info_fields),
]),
Node('Artist', [
Field('id', None, artist_query),
]),
Node('Song', [
Field('id', None, song_query),
Field('name', None, song_query),
Field('album_id', None, song_query),
Field('artist_id', None, song_query),
Link('info', TypeRef['SongInfo'], link_song_info,
requires=['album_id', 'artist_id'])
requires=['album_id', 'artist_id']),
Link('artist', TypeRef['Artist'], link_artist, requires='artist_id')
]),
Root([
Link('song', TypeRef['Song'], link_song, requires=None),
Expand All @@ -360,13 +368,19 @@ def get_fields(id_):
Q.info[
Q.album_name,
Q.artist_name,
],
# we are querying 'artist' here to test that its requires does not
# affect the requires of the 'info' link
Q.artist[
Q.id,
]
]
])
result = execute(query)
check_result(
result,
{'song': {'info': {'album_name': 'Reload', 'artist_name': 'Metallica'}}}
{'song': {'info': {'album_name': 'Reload', 'artist_name': 'Metallica'},
'artist': {'id': 1}}}
)


Expand Down

0 comments on commit b47c21f

Please sign in to comment.