Skip to content

Commit

Permalink
Merge pull request #50 from markovianhq/fix/slicing
Browse files Browse the repository at this point in the history
Fix/slicing
  • Loading branch information
volkale committed Jul 29, 2017
2 parents 943625e + cfe7dec commit daef5a6
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 22 deletions.
80 changes: 61 additions & 19 deletions bonspy/bonsai.py
Expand Up @@ -103,7 +103,8 @@ def _transform_splits(self):
self.node[node_id]['split'] = OrderedDict()

for child_id in self.successors_iter(node_id):
self.node[node_id]['split'][child_id] = split
if not self.node[child_id].get('is_default_leaf', self.node[child_id].get('is_default_node')):
self.node[node_id]['split'][child_id] = split

def _slice_graph(self):
for slice_feature in self.slice_features:
Expand All @@ -123,7 +124,7 @@ def _slice_feature_out_of_graph(self, slice_feature):
next_nodes = self.successors(node_id)
queue.extend(next_nodes)
else:
self._update_sub_graph(node_id, slice_feature)
queue = self._update_sub_graph(node_id, slice_feature, queue)

def _split_contains_slice_feature(self, node_id, slice_feature):
try:
Expand All @@ -132,24 +133,41 @@ def _split_contains_slice_feature(self, node_id, slice_feature):
except KeyError: # default leaf or leaf
return False

def _update_sub_graph(self, node_id, slice_feature):
def _update_sub_graph(self, node_id, slice_feature, queue):
self._prune_unwanted_children(node_id, slice_feature)

default_child = next((n for n in self.successors_iter(node_id) if self.node[n].get('is_default_leaf')))
normal_child = next((n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf')))
assert len([n for n in self.successors_iter(node_id) if n not in {default_child, normal_child}]) == 0

if self.node[normal_child].get('is_leaf'):
self._remove_leaves_and_update_parent_default(node_id, slice_feature, normal_child, default_child)
else:
self._splice_out_node(normal_child, slice_feature, slicing=True)
try:
normal_child = self._get_normal_child(node_id, slice_feature)
other_children = [n for n in self.successors_iter(node_id) if n not in {normal_child, default_child}]
queue.extend(other_children)

if self.node[normal_child].get('is_leaf'):
self._remove_leaves_and_update_parent_default(
node_id, slice_feature, normal_child, default_child, other_children
)
else:
self._splice_out_node(normal_child, slice_feature, slicing=True)

except StopIteration: # slice feature value not present in subtree
other_children = [n for n in self.successors_iter(node_id) if n != default_child]
if other_children:
queue.extend(other_children)
else:
self._cut_single_default_child(node_id, default_child)

return queue

def _prune_unwanted_children(self, node_id, slice_feature):
    """Remove the children of ``node_id`` that are pinned to a non-selected slice value.

    A child is pruned when it is not a default leaf, its ``state`` contains
    ``slice_feature``, and the value stored there differs from
    ``self.slice_feature_values[slice_feature]``.  For every pruned child the
    matching entry of the parent's ``split`` mapping is deleted (only while the
    parent still has a non-empty ``split``) and the child is handed to
    ``_remove_sub_graph``.

    :param node_id: id of the parent node whose successors are inspected.
    :param slice_feature: name of the feature that is being sliced out.
    """
    # Collect the victims up front so that the graph is not modified while its
    # successors are still being iterated.
    doomed = []
    for child in self.successors_iter(node_id):
        attrs = self.node[child]
        if attrs.get('is_default_leaf'):
            continue
        if slice_feature not in attrs['state']:
            continue
        # The wanted value is only looked up once a candidate child exists.
        if attrs['state'].get(slice_feature) != self.slice_feature_values[slice_feature]:
            doomed.append(child)

    for child in doomed:
        # Re-read the parent's split on every pass: it shrinks as entries are deleted.
        if self.node[node_id].get('split'):
            del self.node[node_id]['split'][child]
        self._remove_sub_graph(child)

def _remove_sub_graph(self, node):
Expand All @@ -160,18 +178,30 @@ def _remove_sub_graph(self, node):
self.remove_node(current_node)
queue.extend(next_nodes)

def _remove_leaves_and_update_parent_default(self, node_id, slice_feature, normal_child, default_child):
del self.node[node_id]['split']
self._remove_feature_from_state(node_id, slice_feature)
self.node[node_id] = self.node[normal_child].copy()
def _get_normal_child(self, node_id, slice_feature):
    """Return the first non-default child of ``node_id`` whose state mentions ``slice_feature``.

    :param node_id: id of the parent node whose successors are searched.
    :param slice_feature: feature name that must be a key of the child's ``state``.
    :returns: id of the first successor that is not flagged ``is_default_leaf``
        and has ``slice_feature`` in its ``state``.
    :raises StopIteration: if no successor qualifies.
    """
    for child in self.successors_iter(node_id):
        attrs = self.node[child]
        if attrs.get('is_default_leaf'):
            continue
        if slice_feature in attrs['state']:
            return child
    # Same signal an exhausted ``next(...)`` would give.
    raise StopIteration

self.remove_edge(node_id, normal_child)
self.remove_node(normal_child)
def _remove_leaves_and_update_parent_default(self, node_id, slice_feature, normal_child,
default_child, other_children):
if not other_children:
del self.node[node_id]['split']
self._remove_feature_from_state(node_id, slice_feature)
self.node[node_id] = self.node[normal_child].copy()

self.remove_edge(node_id, default_child)
self.remove_node(default_child)
self.remove_edge(node_id, default_child)
self.remove_node(default_child)
else:
del self.node[node_id]['split'][normal_child]
self._remove_feature_from_state(node_id, slice_feature)
self.node[default_child] = self.node[normal_child].copy()
del self.node[default_child]['is_leaf']
self.node[default_child]['is_default_leaf'] = True

self.node[node_id]['is_leaf'] = True
self.remove_edge(node_id, normal_child)
self.remove_node(normal_child)

def _remove_feature_from_state(self, source, feature):
for node_id in self.bfs_nodes(source):
Expand All @@ -195,6 +225,15 @@ def _skip_node(self, node_id, slicing):
self.remove_edge(parent_id, node_id)
self.remove_node(node_id)

def _cut_single_default_child(self, parent_id, default_child):
    """Collapse ``default_child`` into ``parent_id`` and remove the child node.

    The parent's attribute mapping is replaced by the child's (assigned without
    copying).  If the parent was not flagged ``is_default_node`` beforehand, the
    adopted attributes are turned into those of a regular leaf: the
    ``is_default_leaf`` key is deleted and ``is_leaf`` is set to ``True``.
    Otherwise the adopted attributes are kept as they are.

    :param parent_id: id of the node that takes over the child's attributes.
    :param default_child: id of the default child that is removed afterwards.
    """
    # Must be read before the parent's attributes are overwritten below.
    parent_is_default_node = self.node[parent_id].get('is_default_node')

    self.node[parent_id] = self.node[default_child]

    if not parent_is_default_node:
        adopted = self.node[parent_id]
        del adopted['is_default_leaf']
        adopted['is_leaf'] = True

    self.remove_node(default_child)

def _replace_absent_values(self):
root_id = self._get_root()

Expand Down Expand Up @@ -647,7 +686,10 @@ def _get_out_statement(self, parent, child):
def _get_feature(self, parent, child, state_node):
feature = self.node[parent].get('split')
if isinstance(feature, dict):
feature = feature[child]
try:
feature = feature[child]
except KeyError:
assert self.node[child].get('is_default_leaf', self.node[child].get('is_default_node', False))
if isinstance(feature, (list, tuple)):
return self._get_formatted_multidimensional_compound_feature(feature, state_node)
elif '.' in feature:
Expand Down
40 changes: 40 additions & 0 deletions bonspy/tests/conftest.py
Expand Up @@ -1056,3 +1056,43 @@ def small_unsliced_graph():
g.add_edge(0, 'default_one')

return g


@pytest.fixture
def small_unsliced_graph_single_slice_feature_value():
    """Build a two-level graph whose root splits on ``slice_feature`` only.

    The root (node ``0``) has one leaf for ``slice_feature == 'value'``
    (output ``5.``) and one default leaf ``'default_one'`` (output ``1.``).
    """
    graph = nx.DiGraph()

    # Root node.
    graph.add_node(0, split='slice_feature', state=OrderedDict())

    # Level one: the single slice-feature leaf and the default leaf.
    sliced_state = OrderedDict([('slice_feature', 'value')])
    graph.add_node(1, state=sliced_state, is_leaf=True, output=5.)
    graph.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=1.)

    # Wire the root to both level-one nodes.
    graph.add_edge(0, 1, value='value', type='assignment')
    graph.add_edge(0, 'default_one')

    return graph


@pytest.fixture
def small_unsliced_graph_mixed_split():
    """Build a two-level graph whose root splits on two different features.

    Children ``1`` and ``2`` are reached via ``slice_feature`` (``'good'`` and
    ``'bad'``), child ``3`` via ``other_feature``; ``'default_one'`` is the
    default leaf of the root.
    """
    graph = nx.DiGraph()

    # Root node with a per-child split mapping.
    root_split = OrderedDict([(1, 'slice_feature'), (2, 'slice_feature'), (3, 'other_feature')])
    graph.add_node(0, split=root_split, state=OrderedDict())

    # Level one leaves: (node id, feature, state value, output).
    leaves = (
        (1, 'slice_feature', 'good', 5.),
        (2, 'slice_feature', 'bad', 1.),
        (3, 'other_feature', 'value', 3.),
    )
    for leaf_id, feature, state_value, output in leaves:
        graph.add_node(leaf_id, state=OrderedDict([(feature, state_value)]), is_leaf=True, output=output)
    graph.add_node('default_one', is_default_leaf=True, state=OrderedDict(), output=1.)

    # Connect the root with level one.
    # NOTE(review): node 3's state holds 'value' but its edge carries
    # 'other_value' -- confirm whether this mismatch is intentional.
    edge_values = ((1, 'good'), (2, 'bad'), (3, 'other_value'))
    for leaf_id, edge_value in edge_values:
        graph.add_edge(0, leaf_id, value=edge_value, type='assignment')
    graph.add_edge(0, 'default_one')

    return graph
41 changes: 39 additions & 2 deletions bonspy/tests/test_bonsai.py
Expand Up @@ -457,7 +457,7 @@ def test_feature_slicer(unsliced_graph, small_unsliced_graph):
)

assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node])
assert all(['slice_feature' not in tree.edge[n].get('split', dict()).values() for n in tree.node])
assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node])

tree = BonsaiTree(
small_unsliced_graph,
Expand All @@ -466,5 +466,42 @@ def test_feature_slicer(unsliced_graph, small_unsliced_graph):
)

assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node])
assert all(['slice_feature' not in tree.edge[n].get('split', dict()).values() for n in tree.node])
assert all(['slice_feature' not in tree.node[n].get('split', dict()).values() for n in tree.node])
assert tree.node[0]['output'] == 5.


def test_feature_slicer_single_wrong_slice_feature_value(small_unsliced_graph_single_slice_feature_value):
    """Slice on ``'value'`` and expect the root to end up with output ``5.``.

    NOTE(review): the name says "wrong", yet ``'value'`` is the slice-feature
    value used by the fixture's leaf; the names of this test and of its
    "correct" sibling look swapped -- confirm.
    """
    tree = BonsaiTree(
        small_unsliced_graph_single_slice_feature_value,
        slice_features=('slice_feature',),
        slice_feature_values={'slice_feature': 'value'}
    )

    # The sliced feature must be gone from every node state ...
    state_is_clean = [
        'slice_feature' not in tree.node[node_id].get('state', set()) for node_id in tree.node
    ]
    assert all(state_is_clean)

    # ... and from every split mapping.
    split_is_clean = [
        'slice_feature' not in tree.node[node_id].get('split', dict()).values() for node_id in tree.node
    ]
    assert all(split_is_clean)

    assert tree.node[0]['output'] == 5.


def test_feature_slicer_single_correct_slice_feature_value(small_unsliced_graph_single_slice_feature_value):
    """Slice on ``'other_value'`` and expect the root to end up with output ``1.``.

    NOTE(review): the name says "correct", yet ``'other_value'`` is not the
    slice-feature value used by the fixture's leaf; the names of this test and
    of its "wrong" sibling look swapped -- confirm.
    """
    tree = BonsaiTree(
        small_unsliced_graph_single_slice_feature_value,
        slice_features=('slice_feature',),
        slice_feature_values={'slice_feature': 'other_value'}
    )

    # The sliced feature must be gone from every node state ...
    state_is_clean = [
        'slice_feature' not in tree.node[node_id].get('state', set()) for node_id in tree.node
    ]
    assert all(state_is_clean)

    # ... and from every split mapping.
    split_is_clean = [
        'slice_feature' not in tree.node[node_id].get('split', dict()).values() for node_id in tree.node
    ]
    assert all(split_is_clean)

    assert tree.node[0]['output'] == 1.


def test_feature_slicer_mixed_split(small_unsliced_graph_mixed_split):
    """Slice ``slice_feature == 'good'`` out of a root that also splits on another feature.

    Expects the root to carry no ``output`` of its own and the default leaf
    ``'default_one'`` to end up with output ``5.``.
    """
    tree = BonsaiTree(
        small_unsliced_graph_mixed_split,
        slice_features=('slice_feature',),
        slice_feature_values={'slice_feature': 'good'}
    )

    # The sliced feature must be gone from every node state ...
    state_is_clean = [
        'slice_feature' not in tree.node[node_id].get('state', set()) for node_id in tree.node
    ]
    assert all(state_is_clean)

    # ... and from every split mapping.
    split_is_clean = [
        'slice_feature' not in tree.node[node_id].get('split', dict()).values() for node_id in tree.node
    ]
    assert all(split_is_clean)

    root_attrs = tree.node[0]
    assert 'output' not in root_attrs
    assert tree.node['default_one']['output'] == 5.
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -8,7 +8,7 @@

setup(
name='bonspy',
version='1.2.6',
version='1.2.7',
description='Library that converts bidding trees to the AppNexus Bonsai language.',
author='Alexander Volkmann, Georg Walther',
author_email='contact@markovian.com',
Expand Down

0 comments on commit daef5a6

Please sign in to comment.