Skip to content

Commit

Permalink
fix #169 with suggestion from @davidefiocco
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 20, 2023
1 parent 1243240 commit cee96ef
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 205 deletions.
220 changes: 140 additions & 80 deletions docs/experimental/figs_ensembles.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ <h2 id="support-for-different-tasks">Support for different tasks</h2>
</tr>
<tr>
<td style="text-align: left;">Skope rule set</td>
<td style="text-align: center;"><a href="https://csinva.io/imodels/rule_set/slipper.html#imodels.rule_set.slipper.SlipperClassifier">SkopeRulesClassifier</a></td>
<td style="text-align: center;"><a href="https://csinva.io/imodels/rule_set/skope_rules.html#imodels.rule_set.skope_rules.SkopeRulesClassifier">SkopeRulesClassifier</a></td>
<td style="text-align: center;"></td>
<td></td>
</tr>
Expand Down
44 changes: 27 additions & 17 deletions docs/rule_list/greedy_rule_list.html
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@
import math
from copy import deepcopy

import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.tree import DecisionTreeClassifier
Expand Down Expand Up @@ -87,9 +85,9 @@
elif np.all(y == y[0]):
return [{&#39;val&#39;: y[0], &#39;num_pts&#39;: y.size}]

# base case 3: max depth reached
elif depth &gt;= self.max_depth:
return []
# base case 3: max depth reached
elif depth == self.max_depth:
return [{&#39;val&#39;: np.mean(y), &#39;num_pts&#39;: y.size}]

# recursively generate rule list
else:
Expand Down Expand Up @@ -128,7 +126,7 @@
&#39;col&#39;: self.feature_names_[col],
&#39;index_col&#39;: col,
&#39;cutoff&#39;: cutoff,
&#39;val&#39;: np.mean(y), # values before splitting
&#39;val&#39;: np.mean(y_left), # will be the values before splitting in the next lower level
&#39;flip&#39;: flip,
&#39;val_right&#39;: np.mean(y_right),
&#39;num_pts&#39;: y.size,
Expand All @@ -155,7 +153,11 @@
for j, rule in enumerate(self.rules_):
if j == len(self.rules_) - 1:
probs[i] = rule[&#39;val&#39;]
elif x[rule[&#39;index_col&#39;]] &gt;= rule[&#39;cutoff&#39;]:
continue
regular_condition = x[rule[&#34;index_col&#34;]] &gt;= rule[&#34;cutoff&#34;]
flipped_condition = x[rule[&#34;index_col&#34;]] &lt; rule[&#34;cutoff&#34;]
condition = flipped_condition if rule[&#34;flip&#34;] else regular_condition
if condition:
probs[i] = rule[&#39;val_right&#39;]
break
return np.vstack((1 - probs, probs)).transpose() # probs (n, 2)
Expand Down Expand Up @@ -434,9 +436,9 @@ <h2 id="params">Params</h2>
elif np.all(y == y[0]):
return [{&#39;val&#39;: y[0], &#39;num_pts&#39;: y.size}]

# base case 3: max depth reached
elif depth &gt;= self.max_depth:
return []
# base case 3: max depth reached
elif depth == self.max_depth:
return [{&#39;val&#39;: np.mean(y), &#39;num_pts&#39;: y.size}]

# recursively generate rule list
else:
Expand Down Expand Up @@ -475,7 +477,7 @@ <h2 id="params">Params</h2>
&#39;col&#39;: self.feature_names_[col],
&#39;index_col&#39;: col,
&#39;cutoff&#39;: cutoff,
&#39;val&#39;: np.mean(y), # values before splitting
&#39;val&#39;: np.mean(y_left), # will be the values before splitting in the next lower level
&#39;flip&#39;: flip,
&#39;val_right&#39;: np.mean(y_right),
&#39;num_pts&#39;: y.size,
Expand All @@ -502,7 +504,11 @@ <h2 id="params">Params</h2>
for j, rule in enumerate(self.rules_):
if j == len(self.rules_) - 1:
probs[i] = rule[&#39;val&#39;]
elif x[rule[&#39;index_col&#39;]] &gt;= rule[&#39;cutoff&#39;]:
continue
regular_condition = x[rule[&#34;index_col&#34;]] &gt;= rule[&#34;cutoff&#34;]
flipped_condition = x[rule[&#34;index_col&#34;]] &lt; rule[&#34;cutoff&#34;]
condition = flipped_condition if rule[&#34;flip&#34;] else regular_condition
if condition:
probs[i] = rule[&#39;val_right&#39;]
break
return np.vstack((1 - probs, probs)).transpose() # probs (n, 2)
Expand Down Expand Up @@ -771,9 +777,9 @@ <h3>Methods</h3>
elif np.all(y == y[0]):
return [{&#39;val&#39;: y[0], &#39;num_pts&#39;: y.size}]

# base case 3: max depth reached
elif depth &gt;= self.max_depth:
return []
# base case 3: max depth reached
elif depth == self.max_depth:
return [{&#39;val&#39;: np.mean(y), &#39;num_pts&#39;: y.size}]

# recursively generate rule list
else:
Expand Down Expand Up @@ -812,7 +818,7 @@ <h3>Methods</h3>
&#39;col&#39;: self.feature_names_[col],
&#39;index_col&#39;: col,
&#39;cutoff&#39;: cutoff,
&#39;val&#39;: np.mean(y), # values before splitting
&#39;val&#39;: np.mean(y_left), # will be the values before splitting in the next lower level
&#39;flip&#39;: flip,
&#39;val_right&#39;: np.mean(y_right),
&#39;num_pts&#39;: y.size,
Expand Down Expand Up @@ -864,7 +870,11 @@ <h3>Methods</h3>
for j, rule in enumerate(self.rules_):
if j == len(self.rules_) - 1:
probs[i] = rule[&#39;val&#39;]
elif x[rule[&#39;index_col&#39;]] &gt;= rule[&#39;cutoff&#39;]:
continue
regular_condition = x[rule[&#34;index_col&#34;]] &gt;= rule[&#34;cutoff&#34;]
flipped_condition = x[rule[&#34;index_col&#34;]] &lt; rule[&#34;cutoff&#34;]
condition = flipped_condition if rule[&#34;flip&#34;] else regular_condition
if condition:
probs[i] = rule[&#39;val_right&#39;]
break
return np.vstack((1 - probs, probs)).transpose() # probs (n, 2)</code></pre>
Expand Down
34 changes: 22 additions & 12 deletions docs/tree/c45_tree/c45_tree.html
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import cross_val_score
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
from imodels.util.arguments import check_fit_arguments

from ..c45_tree.c45_utils import decision, is_numeric_feature, gain, gain_ratio, get_best_split, \
set_as_leaf_node
Expand Down Expand Up @@ -166,17 +167,20 @@

def fit(self, X, y, feature_names: str = None):
self.complexity_ = 0
X, y = check_X_y(X, y)
# X, y = check_X_y(X, y)
X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
self.resultType = type(y[0])
if feature_names is None:
self.feature_names = [f&#39;X_{x}&#39; for x in range(X.shape[1])]
else:
# only include alphanumeric chars / replace spaces with underscores
self.feature_names = [&#39;&#39;.join([i for i in x if i.isalnum()]).replace(&#39; &#39;, &#39;_&#39;)
for x in feature_names]
self.feature_names = [&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names]
self.feature_names = [
&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names
]

assert len(self.feature_names) == X.shape[1]

Expand Down Expand Up @@ -535,17 +539,20 @@ <h2 id="parameters">Parameters</h2>

def fit(self, X, y, feature_names: str = None):
self.complexity_ = 0
X, y = check_X_y(X, y)
# X, y = check_X_y(X, y)
X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
self.resultType = type(y[0])
if feature_names is None:
self.feature_names = [f&#39;X_{x}&#39; for x in range(X.shape[1])]
else:
# only include alphanumeric chars / replace spaces with underscores
self.feature_names = [&#39;&#39;.join([i for i in x if i.isalnum()]).replace(&#39; &#39;, &#39;_&#39;)
for x in feature_names]
self.feature_names = [&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names]
self.feature_names = [
&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names
]

assert len(self.feature_names) == X.shape[1]

Expand Down Expand Up @@ -766,17 +773,20 @@ <h3>Methods</h3>
</summary>
<pre><code class="python">def fit(self, X, y, feature_names: str = None):
self.complexity_ = 0
X, y = check_X_y(X, y)
# X, y = check_X_y(X, y)
X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
self.resultType = type(y[0])
if feature_names is None:
self.feature_names = [f&#39;X_{x}&#39; for x in range(X.shape[1])]
else:
# only include alphanumeric chars / replace spaces with underscores
self.feature_names = [&#39;&#39;.join([i for i in x if i.isalnum()]).replace(&#39; &#39;, &#39;_&#39;)
for x in feature_names]
self.feature_names = [&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names]
self.feature_names = [
&#39;X_&#39; + x if x[0].isdigit()
else x
for x in self.feature_names
]

assert len(self.feature_names) == X.shape[1]

Expand Down
Loading

0 comments on commit cee96ef

Please sign in to comment.