Skip to content

Commit 75c2cf9

Browse files
authored
Merge pull request #352 from neurodata/honestoblique_yuxin
FIX attempt to correct pruning
2 parents ccccf9c + 82390b2 commit 75c2cf9

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

treeple/tree/_honest_tree.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -761,6 +761,7 @@ def _fit_leaves(self, X, y, sample_weight):
761761
)
762762
self.tree_ = pruned_tree
763763

764+
y = y_encoded
764765
# Fit leaves using other subsample
765766
honest_leaves = self.tree_.apply(X[self.honest_indices_])
766767

@@ -883,6 +884,7 @@ class in a leaf.
883884

884885
if self.n_outputs_ == 1:
885886
proba = proba[:, : self._tree_n_classes_]
887+
886888
if not self.kernel_method:
887889
normalizer = proba.sum(axis=1)[:, np.newaxis]
888890
normalizer[normalizer == 0.0] = 1.0
@@ -896,11 +898,13 @@ class in a leaf.
896898

897899
for k in range(self.n_outputs_):
898900
proba_k = proba[:, k, : self._tree_n_classes_[k]]
901+
899902
if not self.kernel_method:
900903
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
901904
normalizer[normalizer == 0.0] = 1.0
902905
proba_k /= normalizer
903906
proba_k = self._empty_leaf_correction(proba_k, k)
907+
904908
all_proba.append(proba_k)
905909

906910
return all_proba

treeple/tree/honesty/_honest_prune.pyx

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ cdef class HonestPruner(Splitter):
110110
The original tree to be pruned.
111111
"""
112112
self.tree = orig_tree
113-
self.capacity = 0
113+
self.capacity = 2047
114114

115115
cdef int init(
116116
self,
@@ -133,7 +133,6 @@ cdef class HonestPruner(Splitter):
133133
right-most end of `samples`, that is `samples[end_non_missing:end]`.
134134
"""
135135
cdef float64_t threshold = self.tree.nodes[node_idx].threshold
136-
cdef intp_t feature = self.tree.nodes[node_idx].feature
137136
cdef intp_t n_missing = 0
138137
cdef intp_t pos = self.start
139138
cdef intp_t p
@@ -147,10 +146,9 @@ cdef class HonestPruner(Splitter):
147146
sample_idx = self.samples[p]
148147

149148
# missing-values are always placed at the right-most end
150-
if isnan(X_ndarray[sample_idx, feature]):
149+
if isnan(self.tree._compute_feature(X_ndarray, sample_idx, &self.tree.nodes[node_idx])):
151150
self.samples[p], self.samples[current_end] = \
152151
self.samples[current_end], self.samples[p]
153-
154152
n_missing += 1
155153
current_end -= 1
156154

0 commit comments

Comments
 (0)