In [8]:
import re
import numpy as np 
from tree_solvers import RandomForest
from rule_filter import Condition, Path, Rule, filter_linearly_dependent_rules

class SIRUS_Solver:
    """Fit a ``RandomForest`` and turn the paths of its trees into ``Rule`` objects.

    The solver keeps the train/test split, the forest hyperparameters and the
    rules extracted from the fitted forest:

    * ``raw_all_rules``     -- per-tree ``{path string: prediction}`` mappings.
    * ``all_rules``         -- one ``Rule`` per pair of sibling paths (duplicates collapsed).
    * ``path_counts``       -- number of trees in which each path string occurs.
    * ``independent_rules`` -- output of ``filter_linearly_dependent_rules``.
    """

    def __init__(self, X_train, X_test, y_train, y_test, num_subsampled_points, frequency_threshold, max_tree_depth, q, number_trees):
        """Store the data and hyperparameters and precompute the candidate split values.

        Args:
            X_train, X_test: feature matrices (indexed as 2-D arrays via ``shape[0]``/``shape[1]``).
            y_train, y_test: target vectors.
            num_subsampled_points: number of samples bootstrapped for each tree.
            frequency_threshold: stored as ``p0``; not used by any method of this class yet.
            max_tree_depth: depth limit handed to the forest.
            q: step, in percent, between the percentiles used as candidate splits.
            number_trees: number of trees in the forest.
        """
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.q = q
        self.p0 = frequency_threshold
        self.a_n = num_subsampled_points  # number of samples bootstrapped for each tree in the RF
        # Number of features considered in each node of the RF.
        # NOTE(review): this evaluates to floor(shape[0] / shape[1]), i.e. presumably
        # n_samples / n_features -- confirm this is the intended value rather than a
        # function of the number of features alone.
        self.max_features = int(np.floor(X_train.shape[0] / X_train.shape[1]))
        self.num_trees = number_trees
        self.max_tree_depth = max_tree_depth
        self.quantiles = self.empirical_quantiles()  # valid split values for each node
        self.rf_model = None

        self.raw_all_rules = {}
        self.all_rules = []
        self.path_counts = {}
        self.rules_high_freq = {}
        self.independent_rules = []

    def empirical_quantiles(self):
        """Return the 0, q, 2q, ... (< 100) percentiles of ``X_train`` along axis 0."""
        return np.percentile(self.X_train, np.arange(0, 100, self.q), axis=0)

    def _store_raw_rules(self):
        """Copy each fitted tree's ``rules`` mapping into ``raw_all_rules`` under the tree's key."""
        for key, tree in self.rf_model.trees.items():
            self.raw_all_rules[key] = tree.rules

    @staticmethod
    def _parse_path(path):
        """Split a path string into ``(feature_idx, operation, split_value)`` tuples.

        Conditions are separated by ``&``; each condition consists of four
        space-separated tokens of which the first (a label) is ignored.

        Raises:
            ValueError: if a condition does not have exactly that shape or its
                feature/split tokens are not numeric.
        """
        parsed = []
        for condition in path.split("&"):
            feature, operation, split = condition.strip().split(" ")[1:]
            parsed.append((int(feature), operation, float(split)))
        return parsed

    def _store_formated_rules(self):
        """Build ``all_rules`` from ``raw_all_rules``.

        All per-tree mappings are merged into one dict keyed by path string, so a
        path that occurs in several trees is kept once (with the prediction of the
        last tree containing it). Because that merge discards how often a path
        occurs, the number of trees containing each path is recorded separately in
        ``path_counts``.

        The merged entries are then consumed two at a time: consecutive entries are
        expected to be the two branches of the same node. Only the first path of
        each pair is kept as the rule's path; the two predictions become the rule's
        ``[then, else]`` outputs.

        Raises:
            AssertionError: if the merged dict has an odd number of entries.
            ValueError: if the two paths of a pair do not test the same features at
                the same split values (i.e. they do not describe the same node).
        """
        rules_flatten = {}
        path_counts = {}
        for tree_rules in self.raw_all_rules.values():
            for path, rule in tree_rules.items():
                rules_flatten[path] = rule
                path_counts[path] = path_counts.get(path, 0) + 1

        if len(rules_flatten) % 2 != 0:
            raise AssertionError("The length of rules_flatten is not even.")

        # Zipping an iterator with itself yields non-overlapping consecutive pairs.
        iter_rules_items = iter(rules_flatten.items())
        all_rules = []

        for (path_first, pred_first), (path_second, pred_second) in zip(iter_rules_items, iter_rules_items):
            conditions_first = self._parse_path(path_first)
            conditions_second = self._parse_path(path_second)

            # Double check that both entries describe the same node: every condition
            # must test the same feature at the same split value (only the direction
            # of a comparison may differ between the two branches).
            same_node = len(conditions_first) == len(conditions_second) and all(
                first[0] == second[0] and first[2] == second[2]
                for first, second in zip(conditions_first, conditions_second)
            )
            if not same_node:
                raise ValueError("Error: Feature mismatch (i.e the same node is not being processed)")

            path_to_node = Path([Condition(feature, operation, split) for feature, operation, split in conditions_first])
            all_rules.append(Rule(path_to_node, [float(pred_first), float(pred_second)]))

        self.all_rules = all_rules
        self.path_counts = path_counts

    def fit_trees(self):
        """Fit the random forest on the training data and extract its rules."""
        self.rf_model = RandomForest(self.num_trees, self.a_n, self.max_tree_depth, self.max_features, self.quantiles)
        self.rf_model.fit(self.X_train, self.y_train)
        self._store_raw_rules()
        self._store_formated_rules()

    def get_independent_rules(self):
        """Filter ``all_rules`` with ``filter_linearly_dependent_rules``; store and return the result."""
        self.independent_rules = filter_linearly_dependent_rules(self.all_rules)
        return self.independent_rules

    def print_rules(self, rules):
        """Call ``print_rule()`` on every rule in ``rules``."""
        for rule in rules:
            rule.print_rule()

    def get_accuracy(self):
        """Print and return the fraction of forest predictions on ``X_test`` equal to ``y_test``."""
        y_pred = self.rf_model.predict(self.X_test)
        test_acc = np.mean(y_pred == self.y_test)
        print(f"test accuracy: {test_acc}")
        return test_acc


if __name__ == "__main__":

    import os

    from preprocess.get_data import get_BW_data
    from sklearn.model_selection import train_test_split

    # Location of the Breast Wisconsin CSV. The default is the original hard-coded
    # path; set the SIRUS_DATA_PATH environment variable to run this on another
    # machine without editing the code.
    data_path = os.environ.get("SIRUS_DATA_PATH", "/Users/norahallqvist/Code/SIRUS/data/BreastWisconsin.csv")

    X, y = get_BW_data(data_path)
    # Fixed random_state keeps the 80/20 split reproducible between runs.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

    # Hyperparameters
    num_subsampled_points = int((X_train.shape[0]) * 0.95)  # 95% of the training rows per tree
    frequency_threshold = 0.5
    tree_depth = 1
    q = 10
    num_trees = 1000

    solver = SIRUS_Solver(X_train, X_test, y_train, y_test, num_subsampled_points, frequency_threshold, tree_depth, q, num_trees)
    solver.fit_trees()
    solver.get_accuracy()


    # NOTE(review): the linearly independent rules are computed here, but the rules
    # printed below are the unfiltered `all_rules` -- confirm which list is meant.
    solver.get_independent_rules()
    solver.print_rules(solver.all_rules)
    
    # import json 
    # with open('/Users/norahallqvist/Code/SIRUS/results/example_rules_dict.json', 'w') as json_file:
    #     json.dump(solver.raw_all_rules, json_file)



test accuracy: 0.9035087719298246
Feature 22 < 106.4 then Y = 0.0 else Y = 1.0
Feature 22 < 106.4 & Feature 7 < 0.049199999999999994 then Y = 0.0 else Y = 0.0
Feature 22 >= 106.4 & Feature 6 < 0.08802999999999998 then Y = 1.0 else Y = 1.0
Feature 22 < 117.7 then Y = 0.0 else Y = 1.0
Feature 22 < 117.7 & Feature 21 < 26.567999999999998 then Y = 0.0 else Y = 0.0
Feature 22 >= 117.7 & Feature 15 < 0.009169 then Y = 0.0 else Y = 1.0
Feature 23 < 963.1799999999992 then Y = 0.0 else Y = 1.0
Feature 23 < 963.1799999999992 & Feature 7 < 0.049199999999999994 then Y = 0.0 else Y = 1.0
Feature 23 >= 963.1799999999992 & Feature 8 < 0.1506 then Y = 0.0 else Y = 1.0
Feature 7 < 0.049199999999999994 then Y = 0.0 else Y = 1.0
Feature 7 < 0.049199999999999994 & Feature 22 < 97.66 then Y = 0.0 else Y = 0.0
Feature 7 >= 0.049199999999999994 & Feature 0 < 15.213999999999999 then Y = 1.0 else Y = 1.0
Feature 23 < 963.1799999999992 & Feature 27 < 0.12544 then Y = 0.0 else Y = 1.0
Feature 23 >= 963.179999999

In [44]:
def rule_count(rules):
    """Count, for every rule, how many *other* rules have an equal ``node_path``.

    Each rule is compared against every rule at a different position in
    ``rules`` using ``==`` on their ``node_path`` attributes.

    Returns:
        dict mapping each rule object to the number of other rules whose
        ``node_path`` compares equal to its own.
    """
    def _matches_elsewhere(position, rule):
        # Skip the rule's own position so a rule never counts itself.
        return sum(
            1
            for other_position, other in enumerate(rules)
            if other_position != position and rule.node_path == other.node_path
        )

    return {rule: _matches_elsewhere(position, rule) for position, rule in enumerate(rules)}


In [47]:
# For every rule in solver.all_rules, count the other rules with an equal node_path,
# then display the total over all rules.
counts = rule_count(solver.all_rules)
sum(counts.values())

0

In [46]:
# Display the rule -> count mapping held in `counts`.
counts

{<rule_filter.Rule at 0x16d360450>: 0,
 <rule_filter.Rule at 0x16d362650>: 0,
 <rule_filter.Rule at 0x16d3604d0>: 0,
 <rule_filter.Rule at 0x16d360a90>: 0,
 <rule_filter.Rule at 0x16d360d10>: 0,
 <rule_filter.Rule at 0x16d361110>: 0,
 <rule_filter.Rule at 0x16d361310>: 0,
 <rule_filter.Rule at 0x16d361650>: 0,
 <rule_filter.Rule at 0x16d361690>: 0,
 <rule_filter.Rule at 0x16d3612d0>: 0,
 <rule_filter.Rule at 0x16d361d10>: 0,
 <rule_filter.Rule at 0x16d361b90>: 0,
 <rule_filter.Rule at 0x16d362110>: 0,
 <rule_filter.Rule at 0x16d361f90>: 0,
 <rule_filter.Rule at 0x16d361e50>: 0,
 <rule_filter.Rule at 0x16d3618d0>: 0,
 <rule_filter.Rule at 0x16d362150>: 0,
 <rule_filter.Rule at 0x16d362450>: 0,
 <rule_filter.Rule at 0x16d3624d0>: 0,
 <rule_filter.Rule at 0x16d362790>: 0,
 <rule_filter.Rule at 0x16d362210>: 0,
 <rule_filter.Rule at 0x16d361ed0>: 0,
 <rule_filter.Rule at 0x16d362910>: 0,
 <rule_filter.Rule at 0x16d360410>: 0,
 <rule_filter.Rule at 0x16d361490>: 0,
 <rule_filter.Rule at 0x1

In [17]:
# NOTE(review): both operands are the *same* object, so this is True by identity
# for any class that does not override __eq__. It does not show whether two
# distinct paths with identical conditions compare equal -- build or pick a
# second, separate path object to check that.
(solver.all_rules[0]).node_path == (solver.all_rules[0]).node_path

True

In [21]:
# Print the path of the first stored rule.
solver.all_rules[0].node_path.print_path()

Feature 22 < 106.4


In [4]:
# Sanity check: two separate list objects with equal elements compare equal with ==.
["hello"] == ["hello"]

True