diff --git a/bonspy/bonsai.py b/bonspy/bonsai.py index dccaf35..0818bf4 100644 --- a/bonspy/bonsai.py +++ b/bonspy/bonsai.py @@ -44,22 +44,23 @@ class BonsaiTree(nx.DiGraph): of the form {feature: [feature value 1, feature value 2, ...]}. :param absence_values: (optional), Dictionary feature name -> iterable of values whose communal absence signals absence of the respective feature. - :param slice_feature: (optional) str, feature to be used for slicing. The private _slice_graph method slices out the - part of the graph where the "slice_feature" has a value that is contained in the "slice_feature_values". - Moreover, it splices out the level where the "splice_feature" is split. - The "slice" method assumes that a node never splits on the "slice_feature" together with another feature. - :param slice_feature_values: (optional) iterable of strings, feature values to not be sliced off the graph. + :param slice_features: (optional) iterable, features to be used for slicing. The private _slice_graph method slices + out the part of the graph where the "slice_features" have a value that is equal to the value of the + "slice_feature_values" dict. + Moreover, it splices out the levels where the "splice_features" are split. + The "slice" method assumes that a node never splits on the "slice_features" together with another feature. + :param slice_feature_values: (optional) dict, slice_feature -> feature values to not be sliced off the graph. """ def __init__(self, graph=None, feature_order=(), feature_value_order={}, absence_values=None, - slice_feature=None, slice_feature_values=(), **kwargs): + slice_features=None, slice_feature_values=(), **kwargs): if graph is not None: super(BonsaiTree, self).__init__(graph) self.feature_order = self._convert_to_dict(feature_order) self.feature_value_order = self._get_feature_value_order(feature_value_order) self.absence_values = absence_values or {} - self.slice_feature = slice_feature - self.slice_feature_values = slice_feature_values + self.slice_features = slice_features or () + self.slice_feature_values = slice_feature_values or {} for key, value in kwargs.items(): setattr(self, key, value) self._transform_splits() @@ -105,6 +106,10 @@ def _transform_splits(self): self.node[node_id]['split'][child_id] = split def _slice_graph(self): + for slice_feature in self.slice_features: + self._slice_feature_out_of_graph(slice_feature) + + def _slice_feature_out_of_graph(self, slice_feature): root_id = self._get_root() queue = deque([root_id]) @@ -112,37 +117,37 @@ def _slice_graph(self): node_id = queue.popleft() if self.node[node_id].get('is_default_leaf'): continue - split_contains_slice_feature = self._split_contains_slice_feature(node_id) + split_contains_slice_feature = self._split_contains_slice_feature(node_id, slice_feature) if not split_contains_slice_feature: next_nodes = self.successors(node_id) queue.extend(next_nodes) else: - self._update_sub_graph(node_id) + self._update_sub_graph(node_id, slice_feature) - def _split_contains_slice_feature(self, node_id): + def _split_contains_slice_feature(self, node_id, slice_feature): try: split = self.node[node_id]['split'] - return self.slice_feature in split.values() + return slice_feature in split.values() except KeyError: # default leaf or leaf return False - def _update_sub_graph(self, node_id): - self._prune_unwanted_children(node_id) + def _update_sub_graph(self, node_id, slice_feature): + self._prune_unwanted_children(node_id, slice_feature) default_child = next((n for n in self.successors_iter(node_id) if self.node[n].get('is_default_leaf'))) normal_child = next((n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf'))) assert len([n for n in self.successors_iter(node_id) if n not in {default_child, normal_child}]) == 0 if self.node[normal_child].get('is_leaf'): - self._remove_leaves_and_update_parent_default(node_id, normal_child, default_child) + self._remove_leaves_and_update_parent_default(node_id, slice_feature, normal_child, default_child) else: - self._splice_out_node(normal_child, self.slice_feature, slicing=True) + self._splice_out_node(normal_child, slice_feature, slicing=True) - def _prune_unwanted_children(self, node_id): + def _prune_unwanted_children(self, node_id, slice_feature): prunable_children = [ n for n in self.successors_iter(node_id) if not self.node[n].get('is_default_leaf') and - self.node[n]['state'].get(self.slice_feature) not in self.slice_feature_values + self.node[n]['state'].get(slice_feature) != self.slice_feature_values[slice_feature] ] for prunable_child in prunable_children: self._remove_sub_graph(prunable_child) @@ -155,9 +160,9 @@ def _remove_sub_graph(self, node): self.remove_node(current_node) queue.extend(next_nodes) - def _remove_leaves_and_update_parent_default(self, node_id, normal_child, default_child): + def _remove_leaves_and_update_parent_default(self, node_id, slice_feature, normal_child, default_child): del self.node[node_id]['split'] - self._remove_feature_from_state(node_id, self.slice_feature) + self._remove_feature_from_state(node_id, slice_feature) self.node[node_id] = self.node[normal_child].copy() self.remove_edge(node_id, normal_child) diff --git a/bonspy/tests/test_bonsai.py b/bonspy/tests/test_bonsai.py index 3480673..98fd5c5 100644 --- a/bonspy/tests/test_bonsai.py +++ b/bonspy/tests/test_bonsai.py @@ -452,8 +452,8 @@ def test_negated_values(negated_values_graph): def test_feature_slicer(unsliced_graph, small_unsliced_graph): tree = BonsaiTree( unsliced_graph, - slice_feature='slice_feature', - slice_feature_values=('good',) + slice_features=('slice_feature',), + slice_feature_values={'slice_feature': 'good'} ) assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) @@ -461,8 +461,8 @@ def test_feature_slicer(unsliced_graph, small_unsliced_graph): tree = BonsaiTree( small_unsliced_graph, - slice_feature='slice_feature', - slice_feature_values=('good',) + slice_features=('slice_feature',), + slice_feature_values={'slice_feature': 'good'} ) assert all(['slice_feature' not in tree.node[n].get('state', set()) for n in tree.node]) diff --git a/setup.py b/setup.py index e7de20d..a23a7bf 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='bonspy', - version='1.2.5', + version='1.2.6', description='Library that converts bidding trees to the AppNexus Bonsai language.', author='Alexander Volkmann, Georg Walther', author_email='contact@markovian.com',