In [2]:
import json
import numpy as np
import pandas as pd
import os

import importlib
from surrogate_rule import forest_info
from surrogate_rule import tree_node_info
importlib.reload(tree_node_info)
importlib.reload(forest_info)

import altair as alt


In [204]:
# Datasets evaluated in this notebook; each one is loaded from
# output/<name>/test.json by the cells below.
data_list = [
    "penn_wind", "penn_cpu", "penn_satellite",
    "crime", "diabetes", "loan", "income",
    "heart", "dry_bean", "obesity_level", "music_origin",
]



In [205]:
# Pruning constraints handed to the surrogate-rule extractors.
# NOTE(review): the exact meaning of each key is defined in
# surrogate_rule.tree_node_info — confirm there before changing values.
filter_threshold = dict(
    support=5,
    fidelity=0.85,
    num_feat=4,
    num_bin=3,
)
num_bin = filter_threshold["num_bin"]


## HSR: rules extracted from a surrogate random forest

In [195]:
import importlib
from surrogate_rule import forest_info
from surrogate_rule import tree_node_info
importlib.reload(tree_node_info)
importlib.reload(forest_info)

<module 'surrogate_rule.forest_info' from '/Users/junyuan/Documents/_/python/sure_eval/surrogate_rule/forest_info.py'>

In [154]:
def extract_rules_from_RF():
    """Train a surrogate random forest, extract its rules, and summarize the minimal rule set.

    This function takes no arguments; it reads notebook globals that the
    calling cell must set first: ``data`` (dict loaded from a test.json),
    ``df`` (DataFrame built from ``data['data']``), ``attrs`` (column names),
    ``filter_threshold`` (pruning constraints dict) and ``num_bin``.

    Returns:
        tuple: (number of rules returned by ``find_the_min_set``, that result's
        ``coverage`` value, max / min / mean per-rule condition count
        (``len(rule['rules'])``), number of entries in the lattice built over
        the selected rules). When the selected rule set is empty the condition
        statistics are reported as 0, 0 and 0.0 instead of raising.
    """
    # Train the surrogate random forest on the model's predictions and prune it.
    surrogate_obj = tree_node_info.tree_node_info()

    surrogate_obj.initialize(X=np.array(data['data']), y=np.array(data['y_gt']),
                             y_pred=np.array(data['y_pred']), debug_class=-1,
                             attrs=data['columns'], filter_threshold=filter_threshold,
                             n_cls=len(data['target_names']),
                             num_bin=num_bin, verbose=False
    ).train_surrogate_random_forest().tree_pruning()

    # Rebuild the pruned trees and extract candidate rule lists from them.
    forest_obj = tree_node_info.forest()
    forest_obj.initialize(
        trees=surrogate_obj.tree_list, cate_X=surrogate_obj.cate_X,
        y=surrogate_obj.y, y_pred=surrogate_obj.y_pred, attrs=attrs, num_bin=num_bin,
        real_percentiles=surrogate_obj.real_percentiles,
        real_min=surrogate_obj.real_min, real_max=surrogate_obj.real_max,
    ).construct_tree().extract_rules()

    # Select a minimal rule set and build its lattice.
    forest = forest_info.Forest()

    # NOTE(review): the meaning of the trailing positional `2` is defined by
    # forest_info.Forest.initialize — confirm before changing.
    forest.initialize(forest_obj.tree_node_dict, data['real_min'], data['real_max'], surrogate_obj.percentile_info,
        df, data['y_pred'], data['y_gt'],
        forest_obj.rule_lists,
        data['target_names'], 2)
    forest.initialize_rule_match_table()
    forest.initilized_rule_overlapping()  # (sic) name as defined in forest_info
    res = forest.find_the_min_set()
    lattice = forest.get_lattice_structure(res['rules'])

    # Per-rule condition counts; computed with max/min/mean rather than a magic
    # sentinel so no bound leaks into the result and an empty set cannot divide by zero.
    cond_counts = [len(rule['rules']) for rule in res['rules']]
    if cond_counts:
        max_feat = max(cond_counts)
        min_feat = min(cond_counts)
        avg_feat = sum(cond_counts) / len(cond_counts)
    else:
        max_feat, min_feat, avg_feat = 0, 0, 0.0

    return len(res['rules']), res['coverage'], max_feat, min_feat, avg_feat, len(lattice)

## SDT: rules extracted from a surrogate decision tree

In [136]:
# train surrogate DT
def extract_rules_from_DT():
    """Train a surrogate decision tree (SDT), extract its rules, and summarize the minimal rule set.

    This function takes no arguments; it reads notebook globals that the
    calling cell must set first: ``data`` (dict loaded from a test.json),
    ``df`` (DataFrame built from ``data['data']``), ``attrs`` (column names),
    ``filter_threshold`` (pruning constraints dict) and ``num_bin``.

    Returns:
        tuple: (number of rules returned by ``find_the_min_set``, that result's
        ``coverage`` value, max / min / mean per-rule condition count
        (``len(rule['rules'])``), number of entries in ``tree_obj.tree_node_dict``).
        When the selected rule set is empty the condition statistics are
        reported as 0, 0 and 0.0 instead of raising.
    """
    # Train the surrogate decision tree and prune it.
    surrogate_obj2 = tree_node_info.tree_node_info()

    # NOTE(review): semantics of the positional `False` passed to tree_pruning
    # are defined in tree_node_info — confirm.
    surrogate_obj2.initialize(X=np.array(data['data']), y=np.array(data['y_gt']),
                             y_pred=np.array(data['y_pred']), debug_class=-1,
                             attrs=data['columns'], filter_threshold=filter_threshold,
                             n_cls=len(data['target_names']),
                             num_bin=num_bin, verbose=False
    ).train_surrogate_decision_tree().tree_pruning(False)

    # Rebuild the pruned tree and extract candidate rule lists from it.
    tree_obj = tree_node_info.forest()
    tree_obj.initialize(
        trees=surrogate_obj2.tree_list, cate_X=surrogate_obj2.cate_X,
        y=surrogate_obj2.y, y_pred=surrogate_obj2.y_pred, attrs=attrs, num_bin=num_bin,
        real_percentiles=surrogate_obj2.real_percentiles,
        real_min=surrogate_obj2.real_min, real_max=surrogate_obj2.real_max,
    ).construct_tree().extract_rules()

    # Select a minimal rule set from the tree's rules.
    dt = forest_info.Forest()

    # NOTE(review): the meaning of the trailing positional `2` is defined by
    # forest_info.Forest.initialize — confirm before changing.
    dt.initialize(tree_obj.tree_node_dict, data['real_min'], data['real_max'], surrogate_obj2.percentile_info,
        df, data['y_pred'], data['y_gt'],
        tree_obj.rule_lists,
        data['target_names'], 2)
    dt.initialize_rule_match_table()
    dt.initilized_rule_overlapping()  # (sic) name as defined in forest_info
    res = dt.find_the_min_set()

    # Per-rule condition counts; computed with max/min/mean rather than a magic
    # sentinel so no bound leaks into the result and an empty set cannot divide by zero.
    cond_counts = [len(rule['rules']) for rule in res['rules']]
    if cond_counts:
        max_feat = max(cond_counts)
        min_feat = min(cond_counts)
        avg_feat = sum(cond_counts) / len(cond_counts)
    else:
        max_feat, min_feat, avg_feat = 0, 0, 0.0

    return len(res['rules']), res['coverage'], max_feat, min_feat, avg_feat, len(tree_obj.tree_node_dict)

In [196]:
# Load every dataset's test.json in turn. Only the last one survives: after
# the loop, `data` and `attrs` describe the final entry of data_list, which
# the inspection cells below look at.
num_bin = filter_threshold['num_bin']
num_feat_list = [2, 3, 4, 5, 6, 7]

for data_name in data_list:
    file = f"output/{data_name}/test.json"
    with open(file, 'r') as json_input:
        data = json.load(json_input)
    attrs = data['columns']
    

In [197]:
# Number of rows in the most recently loaded dataset (`data` is overwritten by each loading loop).
len(data['data'])

12096

In [198]:
# Top-level keys of the most recently loaded test.json dict.
data.keys()

dict_keys(['columns', 'data', 'target_names', 'real_min', 'real_max', 'y_pred', 'y_gt', 'n_cls'])

In [206]:
# Run the HSR (surrogate random forest) vs. SDT (surrogate decision tree)
# comparison for every dataset and every `num_feat` budget, writing one JSON
# file of summary statistics per dataset to exp_output/compare_<name>.json.
# `filter_threshold` is redefined here because the loop below mutates it.
filter_threshold = {
    "support": 5,
    "fidelity": .85,
    "num_feat": 4,
    "num_bin": 3,
}

num_feat_list = [2, 3, 4, 5, 6, 7]
num_bin = filter_threshold['num_bin']


def _summarize_run(stats, num_feat):
    """Map an extract_rules_from_* result tuple onto the per-method JSON record.

    ``stats`` is the 6-tuple (rule count, coverage, max / min / mean conditions,
    lattice-or-tree size) returned by extract_rules_from_RF / extract_rules_from_DT.
    """
    n_rules, cover, max_feat, min_feat, avg_feat, size = stats
    return {
        'len': n_rules,
        'cover': cover,
        'max_feat': max_feat,
        'min_feat': min_feat,
        'avg_feat': avg_feat,
        'num_feat': num_feat,
        'lattice/tree_size': size,
    }


# Create the output directory up front so the (slow) runs don't fail at write time.
os.makedirs("exp_output", exist_ok=True)

for data_name in data_list:
    file = "output/" + data_name + "/test.json"
    with open(file, 'r') as json_input:
        data = json.load(json_input)
    # `data`, `attrs` and `df` are read as globals by the extract_rules_from_* functions.
    attrs = data['columns']
    df = pd.DataFrame(data=np.array(data['data']), columns=attrs)
    res = []
    for num_feat in num_feat_list:
        print("========", num_feat, data_name, "========")
        # The extractors receive the condition budget through this shared dict.
        filter_threshold['num_feat'] = num_feat

        sure_stats = extract_rules_from_RF()
        dt_stats = extract_rules_from_DT()
        res.append({
            'SuRE': _summarize_run(sure_stats, num_feat),
            'DT': _summarize_run(dt_stats, num_feat),
        })

    with open("exp_output/compare_" + data_name + ".json", "w") as output:
        json.dump({"res": res}, output)



# Plotting

In [3]:
# Flatten the per-dataset comparison files into one row per
# (dataset, method, num_feat) for plotting.
# NOTE(review): 'heart' from data_list is not listed here — confirm whether intentional.
data_name_list = [
    'crime', 'diabetes', 'dry_bean', 'income', 'loan',
    'obesity_level', 'penn_cpu', 'penn_satellite', 'penn_wind', 'music_origin',
]

# Order of the values copied from each per-method record into a row.
metric_keys = ['len', 'cover', 'max_feat', 'min_feat', 'avg_feat', 'num_feat', 'lattice/tree_size']
to_plot_list = []

for data_name in data_name_list:
    with open(f"exp_output/compare_{data_name}.json", 'r') as json_input:
        data = json.load(json_input)
        result = data['res']
    for res in result:
        for method in ('SuRE', 'DT'):
            r = res[method]
            to_plot_list.append([data_name, method] + [r[key] for key in metric_keys])
        


In [4]:
# Show the dataset names alphabetically without mutating data_name_list in place,
# so this cell is side-effect free and safe to re-run in any order.
sorted(data_name_list)

['crime',
 'diabetes',
 'dry_bean',
 'income',
 'loan',
 'music_origin',
 'obesity_level',
 'penn_cpu',
 'penn_satellite',
 'penn_wind']

In [15]:
# Column names in the same order as the values in each to_plot_list row.
plot_columns = ['data_name', 'method', 'set_size', 'coverage', 'max_feat',
                'min_feat', 'avg_feat', 'n_feat', 'node_num']
to_plot = pd.DataFrame(to_plot_list, columns=plot_columns)

In [16]:
# Display the comparison table before rounding and relabelling.
to_plot

Unnamed: 0,data_name,method,set_size,coverage,max_feat,min_feat,avg_feat,n_feat,node_num
0,crime,SuRE,11,1.000000,2,1,1.454545,2,16
1,crime,DT,2,0.668966,2,1,1.500000,2,5
2,crime,SuRE,9,0.997492,3,1,2.000000,3,18
3,crime,DT,2,0.668966,2,1,1.500000,3,5
4,crime,SuRE,9,0.997492,3,1,2.000000,4,18
...,...,...,...,...,...,...,...,...,...
115,music_origin,DT,6,0.793388,5,1,3.666667,5,16
116,music_origin,SuRE,15,0.985832,6,1,3.200000,6,47
117,music_origin,DT,9,0.817001,6,1,4.444444,6,24
118,music_origin,SuRE,17,0.990555,7,1,3.823529,7,61


In [17]:
# Round coverage and mean condition count to two decimals (affects both the table and the chart).
rounded_cols = ['coverage', 'avg_feat']
to_plot[rounded_cols] = to_plot[rounded_cols].round(2)

In [18]:
# Give the plotted metrics reader-facing names and relabel the methods
# (SuRE -> HSR, DT -> SDT). The relabelling is restricted to the `method`
# column; a frame-wide replace would also rewrite any other cell whose value
# happens to be "SuRE" or "DT" (e.g. a dataset name).
to_plot = (
    to_plot
    .rename(columns={"set_size": "number of rules",
                     "avg_feat": "average conditions"})
    .assign(method=lambda frame: frame["method"].replace({"SuRE": "HSR", "DT": "SDT"}))
)

In [19]:
# Display the table after rounding and renaming the methods to HSR / SDT.
to_plot

Unnamed: 0,data_name,method,number of rules,coverage,max_feat,min_feat,average conditions,n_feat,node_num
0,crime,HSR,11,1.00,2,1,1.45,2,16
1,crime,SDT,2,0.67,2,1,1.50,2,5
2,crime,HSR,9,1.00,3,1,2.00,3,18
3,crime,SDT,2,0.67,2,1,1.50,3,5
4,crime,HSR,9,1.00,3,1,2.00,4,18
...,...,...,...,...,...,...,...,...,...
115,music_origin,SDT,6,0.79,5,1,3.67,5,16
116,music_origin,HSR,15,0.99,6,1,3.20,6,47
117,music_origin,SDT,9,0.82,6,1,4.44,6,24
118,music_origin,HSR,17,0.99,7,1,3.82,7,61


In [21]:
# Small-multiples line chart: one column per dataset, one row per metric,
# one line per method, plotted against the num_feat budget.
metric_rows = ['coverage', 'number of rules', 'average conditions']

per_dataset_lines = alt.Chart(to_plot).mark_line().encode(
    x=alt.X('n_feat:Q', axis=alt.Axis(tickCount=6)),
    y=alt.Y(alt.repeat('row'), type='quantitative', axis=alt.Axis(tickCount=5)),
    color="method:N",
).properties(width=80, height=80)

per_dataset_lines.facet(
    column=alt.Column('data_name:N', title=None),
).repeat(row=metric_rows)