Skip to content

Commit

Permalink
Merge pull request #87 from csinva/issue82
Browse files Browse the repository at this point in the history
fix dedup of extracted rule candidates
  • Loading branch information
keyan3 committed Feb 5, 2022
2 parents 128e1a1 + 978c517 commit 2963cad
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions imodels/util/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.utils.validation import check_array

from imodels.discretization import BRLDiscretizer, SimpleDiscretizer
from imodels.util.convert import tree_to_rules
from imodels.util import rule, convert


def extract_fpgrowth(X, y,
Expand Down Expand Up @@ -94,13 +94,16 @@ def extract_rulefit(X, y, feature_names,
else:
estimators_ = tree_generator.estimators_

seen_antecedents = set()
seen_rules = set()
extracted_rules = []
for estimator in estimators_:
for rule_value_pair in tree_to_rules(estimator[0], np.array(feature_names), prediction_values=True):
if rule_value_pair[0] not in seen_antecedents:
for estimator in estimators_:
for rule_value_pair in convert.tree_to_rules(estimator[0], np.array(feature_names), prediction_values=True):

rule_obj = rule.Rule(rule_value_pair[0])

if rule_obj not in seen_rules:
extracted_rules.append(rule_value_pair)
seen_antecedents.add(rule_value_pair[0])
seen_rules.add(rule_obj)

extracted_rules = sorted(extracted_rules, key=lambda x: x[1])
extracted_rules = list(map(lambda x: x[0], extracted_rules))
Expand Down Expand Up @@ -170,6 +173,6 @@ def extract_skope(X, y, feature_names,

extracted_rules = []
for estimator, features in zip(estimators_, estimators_features_):
extracted_rules.append(tree_to_rules(estimator, np.array(feature_names)[features]))
extracted_rules.append(convert.tree_to_rules(estimator, np.array(feature_names)[features]))

return extracted_rules, estimators_samples_, estimators_features_

0 comments on commit 2963cad

Please sign in to comment.