Skip to content

Commit

Permalink
Introduce foreach edge
Browse files Browse the repository at this point in the history
  • Loading branch information
fridex committed Oct 7, 2016
1 parent f8948a4 commit 71c148d
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 18 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ clean:
check:
@# We have to adjust CWD so we use our own Celery and modified Selinon Dispatcher for testing
@python3 --version
@cd test && python3 -m unittest -v test_systemState test_nodeFailures test_storage test_nowait test_flow test_node_args test_others
@cd test && python3 -m unittest -v test_systemState test_nodeFailures test_storage test_nowait test_flow \
test_node_args test_others test_foreach

doc:
@sphinx-apidoc -e -o docs.source/api selinon -f
Expand Down
59 changes: 42 additions & 17 deletions selinon/systemState.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _get_successful_and_failed(self):

return ret_successful, ret_failed

def _start_node(self, node_name, parent, node_args, finished=None):
def _start_node(self, node_name, parent, node_args, finished=None, force_propagate_node_args=False):
"""
Start a node in the system
Expand All @@ -134,13 +134,10 @@ def _start_node(self, node_name, parent, node_args, finished=None):
"""
from .dispatcher import Dispatcher
if Config.is_flow(node_name):
if Config.propagate_node_args.get(self._flow_name):
if Config.propagate_node_args.get(self._flow_name) is True or \
(isinstance(Config.propagate_node_args.get(self._flow_name), list) and
node_name in Config.propagate_node_args.get(self._flow_name)):
node_args = node_args
else:
node_args = None
if force_propagate_node_args or Config.propagate_node_args.get(self._flow_name) is True or \
(isinstance(Config.propagate_node_args.get(self._flow_name), list) and
node_name in Config.propagate_node_args.get(self._flow_name)):
node_args = node_args
else:
node_args = None

Expand Down Expand Up @@ -193,6 +190,34 @@ def _start_node(self, node_name, parent, node_args, finished=None):

return record

def _fire_edge(self, edge, storage_pool, parent, node_args, finished=None):
"""
Fire edge - start new nodes as described in edge table
:param edge: edge that should be fired
:param storage_pool: storage pool which makes results of previous tasks available
:param parent: parent nodes
:param node_args: node arguments
:param finished: finished nodes if propagated
:return: list of nodes that were scheduled
"""
ret = []

if 'foreach' in edge:
for res in edge['foreach'](node_args, storage_pool):
for node_name in edge['to']:
if edge.get('foreach_propagate_result'):
record = self._start_node(node_name, parent, res, finished, force_propagate_node_args=True)
else:
record = self._start_node(node_name, parent, node_args, finished)
ret.append(record)
else:
for node_name in edge['to']:
record = self._start_node(node_name, parent, node_args, finished)
ret.append(record)

return ret

def _run_fallback(self):
"""
Run fallback in the system
Expand Down Expand Up @@ -397,10 +422,9 @@ def _start_new_from_finished(self, new_finished):
# finished_nodes to 'condition' in order to do inspection
storage_pool = StoragePool(storage_id_mapping)
if edge['condition'](storage_pool, self._node_args):
for node_name in edge['to']:
record = self._start_node(node_name, parent=parent, node_args=self._node_args,
finished=self._finished)
ret.append(record)
records = self._fire_edge(edge, storage_pool, parent=parent, node_args=self._node_args,
finished=self._finished)
ret.extend(records)

node_name = node['name']
if not self._finished_nodes.get(node_name):
Expand Down Expand Up @@ -428,11 +452,12 @@ def _start_and_update_retry(self):
for start_edge in start_edges:
storage_pool = StoragePool()
if start_edge['condition'](storage_pool, self._node_args):
for node_name in start_edge['to']:
node = self._start_node(node_name, node_args=self._node_args, parent=self._parent,
finished=self._finished)
if node_name not in Config.nowait_nodes.get(self._flow_name, []):
self._update_waiting_edges(node_name)
records = self._fire_edge(start_edge, storage_pool, node_args=self._node_args, parent=self._parent,
finished=self._finished)

for node in records:
if node['name'] not in Config.nowait_nodes.get(self._flow_name, []):
self._update_waiting_edges(node['name'])
new_started_nodes.append(node)

self._retry = Config.strategy_function(previous_retry=None,
Expand Down
20 changes: 20 additions & 0 deletions test/selinonTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def get_task(task_name, idx=None):
tasks = Config.get_task_instance.task_by_name(task_name)
return tasks[idx if idx is not None else -1]

@staticmethod
def get_all_tasks(task_name):
"""
Get all tasks by task name
:param task_name: name of task
:return: tasks
"""
return Config.get_task_instance.task_by_name(task_name)

@staticmethod
def get_flow(flow_name, idx=None):
"""
Expand All @@ -142,6 +152,16 @@ def get_flow(flow_name, idx=None):
tasks = Config.get_task_instance.flow_by_name(flow_name)
return tasks[idx or -1]

@staticmethod
def get_all_flows(flow_name):
"""
Get all flows by its name
:param flow_name: name of flow
:return: flows
"""
return Config.get_task_instance.flow_by_name(flow_name)

@property
def get_task_instance(self):
"""
Expand Down
159 changes: 159 additions & 0 deletions test/test_foreach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################################
# Copyright (C) 2016 Fridolin Pokorny, fpokorny@redhat.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# ####################################################################

from selinonTestCase import SelinonTestCase

from selinon import SystemState


# Let's make it constant, this shouldn't affect tests at all
_FOREACH_COUNT = 20


class TestForeach(SelinonTestCase):
def test_foreach_start(self):
#
# flow1:
#
# | | |
# Task1 Task1 ... Task1
#
# Note:
# There will be spawned _FOREACH_COUNT Task2
#
edge_table = {
'flow1': [{'from': [], 'to': ['Task1'], 'condition': self.cond_true,
'foreach': lambda x, y: range(_FOREACH_COUNT), 'foreach_propagate_result': False}]
}
self.init(edge_table)

system_state = SystemState(id(self), 'flow1')
retry = system_state.update()
state_dict = system_state.to_dict()

self.assertIsNotNone(retry)
self.assertIsNone(system_state.node_args)
self.assertIn('Task1', self.instantiated_tasks)
self.assertEqual(len(self.get_all_tasks('Task1')), _FOREACH_COUNT)
tasks_state_dict = [node for node in state_dict['active_nodes'] if node['name'] == 'Task1']
self.assertEqual(len(tasks_state_dict), _FOREACH_COUNT)

def test_foreach_basic(self):
#
# flow1:
#
# Task1
# |
# |
# ---------------------
# | | |
# | | |
# Task2 Task2 ... Task2
#
# Note:
# There will be spawned _FOREACH_COUNT Task2
#
edge_table = {
'flow1': [{'from': ['Task1'], 'to': ['Task2'], 'condition': self.cond_true,
'foreach': lambda x, y: range(_FOREACH_COUNT), 'foreach_propagate_result': False},
{'from': [], 'to': ['Task1'], 'condition': self.cond_true}]
}
self.init(edge_table)

system_state = SystemState(id(self), 'flow1')
retry = system_state.update()
state_dict = system_state.to_dict()

self.assertIsNotNone(retry)
self.assertIsNone(system_state.node_args)
self.assertIn('Task1', self.instantiated_tasks)
self.assertNotIn('Task2', self.instantiated_tasks)

# Task1 has finished
task1 = self.get_task('Task1')
self.set_finished(task1, "some result")

system_state = SystemState(id(self), 'flow1', state=state_dict,
node_args=system_state.node_args)
retry = system_state.update()
state_dict = system_state.to_dict()

self.assertIsNotNone(retry)
self.assertIsNone(system_state.node_args)
self.assertIn('Task1', self.instantiated_tasks)
self.assertIn('Task2', self.instantiated_tasks)

self.assertEqual(len(self.get_all_tasks('Task2')), _FOREACH_COUNT)
tasks_state_dict = [node for node in state_dict['active_nodes'] if node['name'] == 'Task2']
self.assertEqual(len(tasks_state_dict), _FOREACH_COUNT)

def test_foreach_propagate_result(self):
#
# flow1:
#
# Task1
# |
# |
# ---------------------
# | | |
# | | |
# flow2 flow2 ... flow2
#
# Note:
# There will be spawned _FOREACH_COUNT flow2, arguments are passed from foreach function
#
edge_table = {
'flow1': [{'from': ['Task1'], 'to': ['flow2'], 'condition': self.cond_true,
'foreach': lambda x, y: range(_FOREACH_COUNT), 'foreach_propagate_result': True},
{'from': [], 'to': ['Task1'], 'condition': self.cond_true}],
'flow2': []
}
self.init(edge_table)

system_state = SystemState(id(self), 'flow1')
retry = system_state.update()
state_dict = system_state.to_dict()

self.assertIsNotNone(retry)
self.assertIsNone(system_state.node_args)
self.assertIn('Task1', self.instantiated_tasks)
self.assertNotIn('Task2', self.instantiated_tasks)

# Task1 has finished
task1 = self.get_task('Task1')
self.set_finished(task1, "some result")

system_state = SystemState(id(self), 'flow1', state=state_dict,
node_args=system_state.node_args)
retry = system_state.update()
state_dict = system_state.to_dict()

self.assertIsNotNone(retry)
self.assertIsNone(system_state.node_args)
self.assertIn('Task1', self.instantiated_tasks)
self.assertIn('flow2', self.instantiated_flows)

tasks_state_dict = [node for node in state_dict['active_nodes'] if node['name'] == 'flow2']
self.assertEqual(len(tasks_state_dict), _FOREACH_COUNT)

# Inspect node_args as we set propagate_result for foreach
all_flow_args = [flow.node_args for flow in self.get_all_flows('flow2')]
self.assertEqual(all_flow_args, list(range(_FOREACH_COUNT)))

0 comments on commit 71c148d

Please sign in to comment.