Skip to content

Commit

Permalink
Merge pull request #49 from markovianhq/feature/multi_feature_slicing
Browse files Browse the repository at this point in the history
Feature/multi feature slicing
  • Loading branch information
volkale committed Jul 7, 2017
2 parents f51a4a5 + 9fd1db4 commit 943625e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
45 changes: 25 additions & 20 deletions bonspy/bonsai.py
Expand Up @@ -44,22 +44,23 @@ class BonsaiTree(nx.DiGraph):
of the form {feature: [feature value 1, feature value 2, ...]}.
:param absence_values: (optional), Dictionary feature name -> iterable of values whose communal absence
signals absence of the respective feature.
:param slice_feature: (optional) str, feature to be used for slicing. The private _slice_graph method slices out the
part of the graph where the "slice_feature" has a value that is contained in the "slice_feature_values".
Moreover, it splices out the level where the "slice_feature" is split.
The "slice" method assumes that a node never splits on the "slice_feature" together with another feature.
:param slice_feature_values: (optional) iterable of strings, feature values to not be sliced off the graph.
:param slice_features: (optional) iterable, features to be used for slicing. The private _slice_graph method keeps
only the part of the graph where each of the "slice_features" has the value given for it in the
"slice_feature_values" dict; children with any other value of a slice feature are removed.
Moreover, it splices out the levels where the "slice_features" are split.
The "slice" method assumes that a node never splits on the "slice_features" together with another feature.
:param slice_feature_values: (optional) dict, slice_feature -> the single feature value to not be sliced off the graph.
"""

def __init__(self, graph=None, feature_order=(), feature_value_order={}, absence_values=None,
slice_feature=None, slice_feature_values=(), **kwargs):
slice_features=None, slice_feature_values=(), **kwargs):
if graph is not None:
super(BonsaiTree, self).__init__(graph)
self.feature_order = self._convert_to_dict(feature_order)
self.feature_value_order = self._get_feature_value_order(feature_value_order)
self.absence_values = absence_values or {}
self.slice_feature = slice_feature
self.slice_feature_values = slice_feature_values
self.slice_features = slice_features or ()
self.slice_feature_values = slice_feature_values or {}
for key, value in kwargs.items():
setattr(self, key, value)
self._transform_splits()
Expand Down Expand Up @@ -105,44 +106,48 @@ def _transform_splits(self):
self.node[node_id]['split'][child_id] = split

def _slice_graph(self):
for slice_feature in self.slice_features:
self._slice_feature_out_of_graph(slice_feature)

def _slice_feature_out_of_graph(self, slice_feature):
root_id = self._get_root()

queue = deque([root_id])
while queue:
node_id = queue.popleft()
if self.node[node_id].get('is_default_leaf'):
continue
split_contains_slice_feature = self._split_contains_slice_feature(node_id)
split_contains_slice_feature = self._split_contains_slice_feature(node_id, slice_feature)

if not split_contains_slice_feature:
next_nodes = self.successors(node_id)
queue.extend(next_nodes)
else:
self._update_sub_graph(node_id)
self._update_sub_graph(node_id, slice_feature)

def _split_contains_slice_feature(self, node_id):
def _split_contains_slice_feature(self, node_id, slice_feature):
try:
split = self.node[node_id]['split']
return self.slice_feature in split.values()
return slice_feature in split.values()
except KeyError: # default leaf or leaf
return False

def _update_sub_graph(self, node_id):
self._prune_unwanted_children(node_id)
def _update_sub_graph(self, node_id, slice_feature):
self._prune_unwanted_children(node_id, slice_feature)

default_child = next((n for n in self.successors_iter(node_id) if self.node[n].get('is_default_leaf')))
normal_child = next((n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf')))
assert len([n for n in self.successors_iter(node_id) if n not in {default_child, normal_child}]) == 0

if self.node[normal_child].get('is_leaf'):
self._remove_leaves_and_update_parent_default(node_id, normal_child, default_child)
self._remove_leaves_and_update_parent_default(node_id, slice_feature, normal_child, default_child)
else:
self._splice_out_node(normal_child, self.slice_feature, slicing=True)
self._splice_out_node(normal_child, slice_feature, slicing=True)

def _prune_unwanted_children(self, node_id):
def _prune_unwanted_children(self, node_id, slice_feature):
prunable_children = [
n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf') and
self.node[n]['state'].get(self.slice_feature) not in self.slice_feature_values
self.node[n]['state'].get(slice_feature) != self.slice_feature_values[slice_feature]
]
for prunable_child in prunable_children:
self._remove_sub_graph(prunable_child)
Expand All @@ -155,9 +160,9 @@ def _remove_sub_graph(self, node):
self.remove_node(current_node)
queue.extend(next_nodes)

def _remove_leaves_and_update_parent_default(self, node_id, normal_child, default_child):
def _remove_leaves_and_update_parent_default(self, node_id, slice_feature, normal_child, default_child):
del self.node[node_id]['split']
self._remove_feature_from_state(node_id, self.slice_feature)
self._remove_feature_from_state(node_id, slice_feature)
self.node[node_id] = self.node[normal_child].copy()

self.remove_edge(node_id, normal_child)
Expand Down
8 changes: 4 additions & 4 deletions bonspy/tests/test_bonsai.py
Expand Up @@ -452,17 +452,17 @@ def test_negated_values(negated_values_graph):
def test_feature_slicer(unsliced_graph, small_unsliced_graph):
tree = BonsaiTree(
unsliced_graph,
slice_feature='slice_feature',
slice_feature_values=('good',)
slice_features=('slice_feature',),
slice_feature_values={'slice_feature': 'good'}
)

assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node])
assert all(['slice_feature' not in tree.edge[n].get('split', dict()).values() for n in tree.node])

tree = BonsaiTree(
small_unsliced_graph,
slice_feature='slice_feature',
slice_feature_values=('good',)
slice_features=('slice_feature',),
slice_feature_values={'slice_feature': 'good'}
)

assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node])
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -8,7 +8,7 @@

setup(
name='bonspy',
version='1.2.5',
version='1.2.6',
description='Library that converts bidding trees to the AppNexus Bonsai language.',
author='Alexander Volkmann, Georg Walther',
author_email='contact@markovian.com',
Expand Down

0 comments on commit 943625e

Please sign in to comment.