In [None]:
import json
import numpy as np
import pandas as pd
import os

from surrogate_rule import forest_info
from surrogate_rule import tree_node_info

In [None]:
import time

In [None]:
# Names of the datasets swept by the experiment loop; each one is loaded
# from output/<name>/test.json and its results written to exp_output/<name>.json.
data_name_list = (
    "crime diabetes dry_bean income loan obesity_level "
    "penn_cpu penn_satellite penn_wind music_origin"
).split()


In [None]:
# Hyper-parameter grid swept for every dataset.
num_bin_list = list(range(2, 6))            # [2, 3, 4, 5]
min_support_list = list(range(10, 60, 10))  # [10, 20, 30, 40, 50]
num_feat_list = list(range(2, 6))           # [2, 3, 4, 5]
# Written out literally (not generated) to avoid floating-point drift.
min_fidelity_list = [0.70, 0.75, 0.80, 0.85, 0.90]
# Result accumulator; the sweep loop rebinds it for each dataset.
res = []

In [None]:
def extract_rules_from_RF():
    """Train a surrogate random forest, extract its rules, and select a minimal rule set.

    Takes no arguments: it reads the notebook-level globals ``data`` (the
    loaded test.json dict), ``df``, ``attrs``, ``num_bin`` and
    ``filter_threshold``. The sweep loop assigns these before every call.

    Returns:
        tuple: ``(n_rules, coverage, duration)`` where ``n_rules`` is
        ``len(min_set['rules'])`` and ``coverage`` is ``min_set['coverage']``
        for ``min_set = Forest.find_the_min_set()``. ``duration`` is the
        elapsed time in seconds for the whole pipeline.

    Raises:
        Any exception raised by the surrogate/forest helpers propagates
        unchanged. The sweep loop catches ValueError and skips that
        configuration.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()

    # Train the surrogate random forest on the model's predictions, then prune it.
    surrogate_obj = tree_node_info.tree_node_info()
    surrogate_obj.initialize(
        X=np.array(data['data']), y=np.array(data['y_gt']),
        y_pred=np.array(data['y_pred']), debug_class=-1,
        attrs=data['columns'], filter_threshold=filter_threshold,
        n_cls=len(data['target_names']),
        num_bin=num_bin, verbose=False,
    ).train_surrogate_random_forest().tree_pruning()

    # Rebuild the pruned trees as a forest structure and pull out its rule lists.
    forest_obj = tree_node_info.forest()
    forest_obj.initialize(
        trees=surrogate_obj.tree_list, cate_X=surrogate_obj.cate_X,
        y=surrogate_obj.y, y_pred=surrogate_obj.y_pred, attrs=attrs, num_bin=num_bin,
        real_percentiles=surrogate_obj.real_percentiles,
        real_min=surrogate_obj.real_min, real_max=surrogate_obj.real_max,
    ).construct_tree().extract_rules()

    # Match rules against the data, compute overlaps, and pick a minimal set.
    forest = forest_info.Forest()
    forest.initialize(
        forest_obj.tree_node_dict, data['real_min'], data['real_max'],
        surrogate_obj.percentile_info,
        df, data['y_pred'], data['y_gt'],
        forest_obj.rule_lists,
        data['target_names'],
        2,  # NOTE(review): meaning is defined by forest_info.Forest.initialize — confirm
    )
    forest.initialize_rule_match_table()
    forest.initilized_rule_overlapping()  # (sic) method name as defined in forest_info
    min_set = forest.find_the_min_set()

    duration = time.perf_counter() - start_time
    return len(min_set['rules']), min_set['coverage'], duration

In [None]:
# Create the output directory up front: otherwise the json write at the end of
# each dataset fails with FileNotFoundError after the whole sweep has run.
os.makedirs('./exp_output', exist_ok=True)

for data_name in data_name_list:
    # Load one dataset. `data`, `attrs`, `df` are module-level globals read by
    # extract_rules_from_RF(), so they must stay bound here.
    file = os.path.join("output", data_name, "test.json")
    with open(file, 'r') as json_input:
        data = json.load(json_input)
    attrs = data['columns']
    df = pd.DataFrame(data=np.array(data['data']), columns=attrs)
    y_pred = data['y_pred']
    print("======== read", data_name, "========")

    res = []
    # Full grid sweep; `num_bin` and `filter_threshold` are also globals
    # consumed by extract_rules_from_RF().
    for num_feat in num_feat_list:
        for num_bin in num_bin_list:
            for min_support in min_support_list:
                for min_fidelity in min_fidelity_list:
                    filter_threshold = {
                        "support": min_support,
                        "fidelity": min_fidelity,
                        "num_feat": num_feat,
                        "num_bin": num_bin,
                    }
                    print("#feat=",num_feat, "#bin=", num_bin, "support=", min_support, "fidelity=",min_fidelity)
                    try:
                        sure_len, sure_cover, duration = extract_rules_from_RF()
                    except ValueError as err:
                        # Keep sweeping, but make the dropped configuration visible
                        # instead of silently omitting it from the results.
                        print("  skipped (ValueError):", err)
                        continue
                    res.append({
                        'min_support': min_support,
                        'num_feat': num_feat,
                        'num_bin': num_bin,
                        'min_fidelity': min_fidelity,
                        'min_set_size': sure_len,
                        'time': duration,
                        'coverage': sure_cover,
                    })
    with open(os.path.join('./exp_output', data_name + '.json'), 'w') as output:
        json.dump({'res': res}, output)