In [1]:
import json
import time
import timeit
import pickle
import numpy as np
from pandas import DataFrame, Series
import multiprocessing as mp
from forest_surveyor import p_count, p_count_corrected
import forest_surveyor.datasets as ds
from forest_surveyor.structures import forest_walker, batch_getter, rule_tester, loo_encoder, rule_accumulator
from forest_surveyor.routines import tune_rf, train_rf, evaluate_model, run_batch_explanations, anchors_preproc, anchors_explanation
from scipy.stats import chi2_contingency
from math import sqrt
from sklearn.metrics import confusion_matrix, cohen_kappa_score, precision_recall_fscore_support, accuracy_score

In [10]:
def mine_path_segments(walked, data_container,
                        support_paths=0.1, alpha_paths=0.5,
                        disc_path_bins=4, disc_path_eqcounts=False,
                        which_trees='majority'):
    """Discretize the paths of one instance and mine patterns from them.

    Parameters
    ----------
    walked : object exposing ``discretize_paths`` and ``mine_patterns``
        Paths container for a single instance; it is modified by those calls.
    data_container : object exposing ``var_dict``
        ``var_dict`` is handed to ``walked.discretize_paths``.
    support_paths : float
        Forwarded to ``walked.mine_patterns`` as ``support``.
    alpha_paths : float
        Accepted for signature compatibility; not used in this function.
    disc_path_bins : int
        Forwarded to ``walked.discretize_paths`` as ``bins``.
    disc_path_eqcounts : bool
        Forwarded to ``walked.discretize_paths`` as ``equal_counts``.
    which_trees : str
        Accepted for signature compatibility; not used in this function.

    Returns
    -------
    The same ``walked`` object that was passed in. Its patterns have been
    mined but are not scored or sorted yet.
    """
    # step 1: discretize any numeric features
    discretize_options = dict(bins=disc_path_bins,
                              equal_counts=disc_path_eqcounts)
    walked.discretize_paths(data_container.var_dict, **discretize_options)

    # step 2: find the patterns (scoring and sorting happen elsewhere)
    walked.mine_patterns(support=support_paths)

    # progress marker printed once per mined instance
    print('done m')
    return walked

def score_sort_path_segments(walked, data_container,
                                sample_instances, sample_labels, encoder,
                                alpha_paths=0.5, weighting='chisq'):
    """Score the mined patterns of ``walked`` and sort them.

    With ``weighting == 'chisq'`` every pattern is applied as a rule to the
    sample, and the class counts of covered vs. not-covered instances form a
    contingency table. The pattern weight is the square root of the chi-square
    statistic of that table. Any other ``weighting`` value sorts without
    weights.

    Parameters
    ----------
    walked : object exposing ``patterns`` and ``sort_patterns``
        Paths container whose patterns have already been mined.
    data_container : object exposing ``class_names``
        Also handed to ``rule_tester``.
    sample_instances, sample_labels
        Sample used to evaluate each pattern. ``sample_labels`` is indexed
        with the result of ``rule_tester.apply_rule()`` and its ``~`` negation.
    encoder : object exposing ``transform``
        Applied to the rule tester's sample instances before the rule is applied.
    alpha_paths : float
        Forwarded to ``walked.sort_patterns`` as ``alpha``. The original notes
        say: best at -1 < alpha < 1; alpha > 0 favours longer patterns,
        0 is neutral, < 0 favours shorter ones.
    weighting : str
        ``'chisq'`` enables the chi-square weighting described above.

    Returns
    -------
    The same ``walked`` object, after ``sort_patterns`` has been called on it.
    """
    if weighting != 'chisq':
        walked.sort_patterns(alpha=alpha_paths) # with only support/alpha sorting
        return walked

    # the class codes are the same for every pattern, so build them once
    class_codes = list(range(len(data_container.class_names)))

    weights = []
    for wp in walked.patterns:
        rt = rule_tester(data_container=data_container,
                        rule=wp,
                        sample_instances=sample_instances)
        rt.sample_instances = encoder.transform(rt.sample_instances)
        idx = rt.apply_rule()
        covered = p_count_corrected(sample_labels[idx], class_codes)['counts']
        not_covered = p_count_corrected(sample_labels[~idx], class_codes)['counts']
        observed = np.array((covered, not_covered))

        # this is the chisq based weighting. can add other options
        if covered.sum() > 0 and not_covered.sum() > 0:
            # drop classes that do not occur at all; a boolean column mask
            # keeps the table two-dimensional
            present_classes = observed.sum(axis=0) != 0
            chisq = chi2_contingency(observed=observed[:, present_classes],
                                     correction=True)[0]
            weights.append(sqrt(chisq))
        else:
            # the rule covers everything or nothing, so no test is possible:
            # reuse the largest weight seen so far. max() of an empty list
            # raises ValueError, so fall back to a placeholder when this is
            # the very first pattern.
            # NOTE(review): 1.0 is assumed to be a neutral weight - confirm
            # against how sort_patterns consumes `weights`.
            weights.append(max(weights, default=1.0))

    # now the patterns are scored and sorted. alpha > 0 favours longer patterns. 0 neutral. < 0 shorter.
    walked.sort_patterns(alpha=alpha_paths, weights=weights) # with chi2 and support sorting
    return walked

def get_rule(rule_acc, encoder, sample_instances, sample_labels, pred_model,
                        greedy='precision', precis_threshold=0.95):
    """Build and prune a rule on ``rule_acc`` and return its lite instance.

    Parameters
    ----------
    rule_acc : object exposing ``build_rule``, ``prune_rule`` and ``lite_instance``
        The rule accumulator to run; the three methods are called in that order.
    encoder, sample_instances, sample_labels
        Forwarded unchanged to ``rule_acc.build_rule``.
    pred_model
        Forwarded to ``rule_acc.build_rule`` as ``prediction_model``.
    greedy : str
        Forwarded to ``rule_acc.build_rule`` as ``greedy``.
    precis_threshold : float
        Forwarded to ``rule_acc.build_rule`` as ``precis_threshold``.

    Returns
    -------
    Whatever ``rule_acc.lite_instance()`` returns.
    """
    build_options = {
        'encoder': encoder,
        'sample_instances': sample_instances,
        'sample_labels': sample_labels,
        'greedy': greedy,
        'prediction_model': pred_model,
        'precis_threshold': precis_threshold,
    }

    # grow the rule, then prune it
    rule_acc.build_rule(**build_options)
    rule_acc.prune_rule()

    # hand back the lite instance of the completed accumulator
    return rule_acc.lite_instance()

def as_chirps(walked, data_container,
                        sample_instances, sample_labels, encoder, pred_model,
                        support_paths=0.1, alpha_paths=0.5,
                        disc_path_bins=4, disc_path_eqcounts=False,
                        which_trees='majority', weighting='chisq',
                        greedy='precis', precis_threshold=0.95,
                        batch_idx=None):
    """Run the CHIRPS steps for the paths of one instance.

    The steps are: mine the paths for frequent patterns, score and sort the
    patterns, then build a rule with a ``rule_accumulator``.

    Parameters
    ----------
    walked
        Paths container for a single instance.
    data_container
        Passed to every step and to the ``rule_accumulator`` constructor.
    sample_instances, sample_labels, encoder
        Passed to the scoring step and to ``get_rule``.
    pred_model
        Passed to ``get_rule``.
    support_paths, alpha_paths, disc_path_bins, disc_path_eqcounts, which_trees
        Passed to ``mine_path_segments``; ``alpha_paths`` also goes to
        ``score_sort_path_segments``.
    weighting
        Passed to ``score_sort_path_segments``.
    greedy, precis_threshold
        Passed to ``get_rule``.
    batch_idx
        Returned untouched alongside the result.

    Returns
    -------
    tuple
        ``(batch_idx, ra_lite)`` where ``ra_lite`` is the value returned by
        ``get_rule``.
    """
    # 1. mine paths for freq patts
    mined = mine_path_segments(walked, data_container,
                               support_paths=support_paths,
                               alpha_paths=alpha_paths,
                               disc_path_bins=disc_path_bins,
                               disc_path_eqcounts=disc_path_eqcounts,
                               which_trees=which_trees)

    # 2. score and sort
    ranked = score_sort_path_segments(mined, data_container,
                                      sample_instances=sample_instances,
                                      sample_labels=sample_labels,
                                      encoder=encoder,
                                      alpha_paths=alpha_paths,
                                      weighting=weighting)

    # 3. greedily add terms to create rule
    accumulator = rule_accumulator(data_container=data_container,
                                   paths_container=ranked)
    completed = get_rule(accumulator, encoder,
                         sample_instances, sample_labels, pred_model,
                         greedy=greedy, precis_threshold=precis_threshold)

    return (batch_idx, completed)


def run_b(f_walker, getter,
 data_container, encoder, sample_instances, sample_labels,
 batch_size = 1, n_batches = 1,
 support_paths=0.1, alpha_paths=0.5,
 disc_path_bins=4, disc_path_eqcounts=False,
 alpha_scores=0.5, which_trees='majority',
 precis_threshold=0.95, weighting='chisq', greedy='greedy',
 forest_walk_async=False, chirps_explanation_async=False):
    """Debugging variant of the batch explanation loop.

    For each batch the forest is walked for ``batch_size`` instances. Then, per
    instance, the paths are mined, scored and sorted, and a rule is built. The
    built rule is currently NOT collected; the function returns its working
    objects so they can be inspected in the notebook.

    Parameters
    ----------
    f_walker : object exposing ``prediction_model`` and ``forest_walk``
    getter : object exposing ``get_next(batch_size)``
        Must return ``(instances, labels)``; ``labels.index`` supplies the
        instance ids.
    data_container, encoder, sample_instances, sample_labels
        Passed to the mining, scoring and rule-building steps.
    batch_size, n_batches : int
        Instances per batch and number of batches; both must be at least 1.
    support_paths, alpha_paths, disc_path_bins, disc_path_eqcounts, which_trees
        Passed to ``mine_path_segments``; ``alpha_paths`` also goes to
        ``score_sort_path_segments`` and ``which_trees`` to
        ``get_instance_paths``.
    precis_threshold, weighting, greedy
        Passed to ``score_sort_path_segments`` / ``get_rule``.
    forest_walk_async
        Passed to ``f_walker.forest_walk`` as its ``async`` argument.
    alpha_scores, chirps_explanation_async
        Accepted for signature compatibility; not used in this function.

    Returns
    -------
    tuple
        ``(instances, labels, instance_ids, batch_walked, walked)`` as they
        stand after the LAST batch and the LAST instance processed.

    Raises
    ------
    ValueError
        If ``batch_size`` or ``n_batches`` is smaller than 1. The returned
        names would otherwise never be bound.
    """
    if batch_size < 1 or n_batches < 1:
        raise ValueError('batch_size and n_batches must both be at least 1')

    pred_model = f_walker.prediction_model

    for b in range(n_batches):
        print('walking forest for batch ' + str(b) + ' of batch size ' + str(batch_size))
        instances, labels = getter.get_next(batch_size)
        instance_ids = labels.index.tolist()
        # get all the tree paths for the instances of this batch.
        # `async` is a reserved word from Python 3.7 on and cannot be written
        # as a literal keyword argument there, so it is passed by unpacking.
        batch_walked = f_walker.forest_walk(instances=instances,
                                            labels=labels,
                                            **{'async': forest_walk_async})

        for batch_idx in range(batch_size):
            instance_id = instance_ids[batch_idx]
            # extract the current instance paths for freq patt mining, filtered by which_trees
            walked = batch_walked.get_instance_paths(batch_idx, which_trees=which_trees)
            walked.instance_id = instance_id
            print(batch_idx, instance_id)
            walked = mine_path_segments(walked, data_container,
                        support_paths, alpha_paths,
                        disc_path_bins, disc_path_eqcounts,
                        which_trees)
            walked = score_sort_path_segments(walked, data_container,
                                sample_instances, sample_labels, encoder,
                                alpha_paths, weighting)

            # build the rule so that this step is exercised too.
            # NOTE(review): ra_lite is discarded; collect it per instance once
            # this function goes back to returning completed rule accumulators.
            ra = rule_accumulator(data_container=data_container, paths_container=walked)
            ra_lite = get_rule(ra, encoder, sample_instances, sample_labels, pred_model,
                                greedy, precis_threshold)

    return(instances, labels, instance_ids, batch_walked, walked)


In [11]:
# --- configuration for this debugging run ---
random_state = 125
override_tuning=False  # forwarded to tune_rf below
add_trees=0            # added to the tuned n_estimators below
n_instances=3          # requested number of test instances (capped below)
n_batches=1

# load the data set and split it
# NOTE(review): tt_split uses a hard-coded 123 rather than random_state above - confirm this is intended
mydata = ds.credit_data(random_state=random_state)
tt = mydata.tt_split(random_state=123)

# get the random forest hyper-parameters
best_params = tune_rf(tt['X_train_enc'], tt['y_train'],
 save_path = mydata.pickle_path(),
 random_state=mydata.random_state,
 override_tuning=override_tuning)

best_params['n_estimators'] = best_params['n_estimators'] + add_trees

# train a rf model
rf, enc_rf = train_rf(X=tt['X_train_enc'], y=tt['y_train'],
 best_params=best_params,
 encoder=tt['encoder'],
 random_state=mydata.random_state)

# fit the forest_walker
f_walker = forest_walker(forest = rf,
 data_container=mydata,
 encoder=tt['encoder'],
 prediction_model=enc_rf)

# run the batch based forest walker
getter = batch_getter(instances=tt['X_test'], labels=tt['y_test'])

# faster to do one batch, avoids the overhead of setting up many but consumes more mem
# get a round number of instances no more than what's available in the test set
n_instances = min(n_instances, len(tt['y_test']))
batch_size = int(n_instances / n_batches)
n_instances = batch_size * n_batches

# run_b is the debugging variant defined in this notebook: it returns its
# working objects for inspection in the cells that follow.
# previous unpacking, kept for reference:
#batch_walked, walked, ra, ra_lite, batch_idx = run_b(f_walker=f_walker,
instances, labels, instance_ids, batch_walked, walked = run_b(f_walker=f_walker,                                                     
 getter=getter,
 data_container=mydata,
 encoder=tt['encoder'],
 sample_instances=tt['X_train'],
 sample_labels=tt['y_train'],
 support_paths=0.05,
 alpha_paths=0.5,
 disc_path_bins=4,
 disc_path_eqcounts=False,
 alpha_scores=0.5,
 which_trees='majority',
 precis_threshold=0.95,
 batch_size=batch_size,
 n_batches=n_batches,
 weighting='chisq',
 greedy='precis',
 forest_walk_async=False,
 chirps_explanation_async=False)

# position within the batch that the inspection cells below look at
batch_idx = 0

using previous tuning parameters
walking forest for batch 0 of batch size 3
0 399
done m


  np.histogram(lowers, lower_bins)[0]).round(5)


1 250
done m


  np.histogram(uppers, upper_bins)[0]).round(5)


2 396
done m


In [11]:
# Show the record stored under key 648 of path_detail for the batch position batch_idx.
# NOTE(review): 648 is presumably a tree index - other cells inspect forest.estimators_[648].
batch_walked.path_detail[648][batch_idx]

{'instance_id': 399,
 'path': {'feature_idx': [],
  'feature_name': [],
  'feature_value': [],
  'leq_threshold': [],
  'threshold': []},
 'pred_class': 0,
 'pred_class_label': 'minus',
 'pred_proba': [0.5445134575569358, 0.4554865424430642],
 'tree_correct': True,
 'true_class': 0}

In [12]:
# Walk the forest again for the instances and labels returned by run_b.
# `async` is a reserved word from Python 3.7 on and cannot be written as a
# literal keyword argument there, so it is passed by unpacking a dict.
batch_walked = f_walker.forest_walk(instances=instances,
                                    labels=labels,
                                    **{'async': False})

In [13]:
# Encode the batch instances and predict them with estimator 648 alone.
# NOTE(review): this rebinds the notebook-global `instances` to its encoded
# form, so running the cell a second time would presumably encode the already
# encoded data - confirm, or restart the kernel before re-running.
instances = tt['encoder'].transform(instances)
if 'todense' in dir(instances): # it's a sparse matrix
    instances = instances.todense()
n_features = instances.shape[1]

# last expression: the per-instance predictions of estimator 648
f_walker.forest.estimators_[648].predict(instances)

array([0., 0., 0.])

In [48]:
# Show the `feature` array of estimator 648's fitted tree structure.
# NOTE(review): the output holds a single -2, which presumably is scikit-learn's
# undefined-feature marker for a tree with no split - confirm.
# abandoned attempt, kept for reference:
# f_walker.forest.estimators_[648].tree_.decision_path()
f_walker.forest.estimators_[648].tree_.feature

array([-2], dtype=int64)

In [42]:
# Repeat the mining and scoring steps by hand for one instance of the batch.
# The id is looked up with the same batch_idx that selects the paths, so the
# id stored on `walked` always belongs to the instance whose paths it holds.
instance_id = instance_ids[batch_idx]
# extract the current instance paths for freq patt mining, filter by majority trees only
walked = batch_walked.get_instance_paths(batch_idx, which_trees='majority')
walked.instance_id = instance_id

walked = mine_path_segments(walked, mydata,
                            support_paths=0.05, alpha_paths=0.5,
                            disc_path_bins=4, disc_path_eqcounts=False,
                            which_trees='majority')

walked = score_sort_path_segments(walked, mydata,
                                  encoder=tt['encoder'],
                                  sample_instances=tt['X_train'],
                                  sample_labels=tt['y_train'],
                                  alpha_paths=0.5, weighting='chisq')

  np.histogram(lowers, lower_bins)[0]).round(5)


In [19]:
# Show the path stored at position 648 of the current instance's paths.
# The output lists (feature name, boolean flag, threshold) style triples.
# NOTE(review): position 648 here need not be the same tree as estimators_[648],
# because `walked` was filtered with which_trees='majority' - confirm.
walked.paths[648]

[('A14', False, 85.14286),
 ('A10_t', True, 0.5),
 ('A9_f', False, 0.5),
 ('A7_v', True, 0.5),
 ('A13_g', False, 0.5),
 ('A2', False, 21.54912)]

In [14]:
# Call the discretization step directly on `walked`, with 4 bins and equal_counts off.
# NOTE(review): mine_path_segments makes the same call on `walked`, so running
# both presumably discretizes these paths twice - confirm this is harmless.
walked.discretize_paths(mydata.var_dict,
                        bins=4, equal_counts=False)

[39.66999816894531, 48.125, 36.459999084472656, 55.665000915527344, 58.0, 36.459999084472656, 36.459999084472656, 36.625, 54.375, 52.665000915527344, 29.915000915527344, 35.209999084472656, 36.459999084472656, 35.0, 51.084999084472656, 40.959999084472656, 34.875, 52.459999084472656, 48.125, 36.625, 41.04499816894531, 45.5, 40.959999084472656, 46.70500183105469, 48.290000915527344, 29.915000915527344, 31.78408432006836, 41.04499816894531, 32.375, 35.790000915527344, 36.459999084472656, 43.084999084472656, 34.66999816894531, 30.329999923706055, 31.78408432006836, 40.959999084472656, 36.459999084472656, 33.084999084472656, 34.290000915527344, 32.040000915527344, 45.58000183105469, 51.040000915527344, 52.209999084472656, 41.25, 37.415000915527344, 33.625, 41.25, 36.459999084472656, 41.04499816894531, 53.33000183105469, 34.415000915527344, 47.709999084472656, 63.290000915527344, 50.834999084472656, 37.625, 34.959999084472656, 41.25, 34.875, 34.875, 29.920000076293945, 46.125, 33.375, 42.169

  np.histogram(lowers, lower_bins)[0]).round(5)


In [24]:
# List the attribute names available on the forest walker object.
dir(f_walker)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'child_features',
 'class_col',
 'class_names',
 'encoder',
 'features',
 'forest',
 'forest_stats',
 'forest_stats_by_label',
 'forest_walk',
 'full_survey',
 'get_label',
 'lower_features',
 'n_features',
 'prediction_model',
 'root_features',
 'structure',
 'tree_structures']

In [27]:
# Display estimator 648 of the forest; its repr shows the tree's settings.
f_walker.forest.estimators_[648]

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=16,
            max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=5, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False,
            random_state=1609987677, splitter='best')