# Metrics for counterfactuals

In [1]:
import numpy as np
import pandas as pd
import re
from statistics import mean
import json
import pickle

In [2]:
from PythonScripts.metrics_perturbed import ClassificationMetrics, RegressionMetrics, Comparison

## Load data

### Dataset 

In [3]:
# Available datasets: iris, adult, auto
dataset_name = 'auto'

# Per-dataset configuration (file names, column layout, sample size, task type).
with open(f'../Configs/{dataset_name}.json') as config_file:
    config = json.load(config_file)

dataset = pd.read_csv(f"../Data/{config['filtered_data_with_headers']}", header=0)
perturbed_dataset = pd.read_csv(f"../Data/{config['perturbed_data']}", header=0)

# Original data: keep the first `sample` rows; features are the leading columns.
dataset = dataset.values[:config['sample']]
X = dataset[:, :config['num_features']]
# NOTE(review): the original labels use column target_col - 1 while the perturbed
# labels below use column target_col. Confirm whether the perturbed CSV really has
# an extra column before the target, or whether one of these indices is off by one.
labels = dataset[:, config['target_col'] - 1]

# Perturbed data, sliced the same way.
perturbed_dataset = perturbed_dataset.values[:config['sample']]
perturbed_X = perturbed_dataset[:, :config['num_features']]
perturbed_labels = perturbed_dataset[:, config['target_col']]

# 1-based feature identifiers.
features = np.arange(1, config['num_features'] + 1)

### Paths

In [4]:
# Decision paths and their local bin table, as produced by an earlier step.
paths = pd.read_csv(f"../Outputs/{config['perturbed_paths']}", header=0).values
bins = pd.read_csv(f"../Outputs/{config['perturbed_local_bins']}", header=0)
bin_vals = bins.values

# Map bin key (first column) -> bin value (second column). When the config has a
# 'factors' entry the value is kept as read; otherwise it is cast to float.
if 'factors' in config:
    bin_dict = {row[0]: row[1] for row in bin_vals}
else:
    bin_dict = {row[0]: float(row[1]) for row in bin_vals}

# Case-insensitive pattern used to split each path node into its parts.
regex = re.compile(config['path_regex'], re.I)

### Depths

In [5]:
# Tree depths, flattened into a 1-D array.
depths_frame = pd.read_csv(f"../Outputs/{config['tree_depths']}", header=0)
depths = depths_frame.values.flatten()

### Binning

In [6]:
# Label bins are only loaded for regression datasets; RegressionMetrics consumes them.
# NOTE: pickle.load executes arbitrary code from the file; only use trusted outputs.
if config['type'] == 'regression':
    with open(f"../Outputs/{config['label_bins']}", 'rb') as label_bins_file:
        label_bins = pickle.load(label_bins_file)

## Compute metrics

In [7]:
# Parse the raw path strings into lists of (int(group 1), bin value, group 3) tuples.
# Column 0 of `paths` feeds the original metrics and column 1 the perturbed ones
# (see the constructor calls below).
path_list = []

for i in range(2):
    temp = []
    for path in paths[:, i]:
        newpath = []
        for node in path.split(","):
            matchobj = regex.match(node)
            if matchobj is None:
                # Fail with context instead of an opaque AttributeError on None.group.
                raise ValueError(
                    f"Path node {node!r} (in path {path!r}) does not match "
                    f"config['path_regex'] {config['path_regex']!r}"
                )
            newpath.append((int(matchobj.group(1)), bin_dict[matchobj.group(2)], matchobj.group(3)))
        temp.append(newpath)
    path_list.append(temp)

# Optional factors table, passed through to the metrics classes.
if 'factors' in config:
    factors = pd.read_csv('../Outputs/'+config['factors'], header = 0)
    factors = factors.values
else:
    factors = None

if config['type'] == 'classification':
    metrics = ClassificationMetrics(path_list[0], labels, features, factors)
    perturbed_metrics = ClassificationMetrics(path_list[1], perturbed_labels, features, factors)
elif config['type'] == 'regression':
    metrics = RegressionMetrics(path_list[0], labels, features, factors, label_bins)
    perturbed_metrics = RegressionMetrics(path_list[1], perturbed_labels, features, factors, label_bins)
else:
    # Stop here: every later cell needs `metrics`/`perturbed_metrics`, so only printing
    # would just postpone the failure to a confusing NameError.
    raise ValueError("Type {} not supported".format(config['type']))

## Results

### Decision set size

In [8]:
# Decision set size: (original, perturbed)
set_sizes = (metrics.decision_paths_size(), perturbed_metrics.decision_paths_size())
print(*set_sizes)

392 392


### Decision set length

In [9]:
# Decision set length: (original, perturbed)
set_lengths = (metrics.decision_paths_length(), perturbed_metrics.decision_paths_length())
print(*set_lengths)

2784 2672


### Average rule length

In [10]:
# Average rule length: (original, perturbed)
avg_rule_lengths = (metrics.average_rule_length(), perturbed_metrics.average_rule_length())
print(*avg_rule_lengths)

7.1020408163265305 6.816326530612245


### Average distinct features

In [None]:
# Average number of distinct features per rule: (original, perturbed)
avg_distinct = (metrics.average_distinct_features(), perturbed_metrics.average_distinct_features())
print(*avg_distinct)

4.3061224489795915 4.198979591836735


### Inter-class overlap

In [None]:
# Inter-class overlap, each evaluated on its own feature matrix: (original, perturbed)
inter_overlap = (metrics.interclass_overlap(X), perturbed_metrics.interclass_overlap(perturbed_X))
print(*inter_overlap)

### Intra-class overlap

In [None]:
# Intra-class overlap, each evaluated on its own feature matrix: (original, perturbed)
intra_overlap = (metrics.intraclass_overlap(X), perturbed_metrics.intraclass_overlap(perturbed_X))
print(*intra_overlap)

### Total number of classes covered

In [None]:
# Number of classes covered: (original, perturbed)
classes_covered = (metrics.num_classes_covered(), perturbed_metrics.num_classes_covered())
print(*classes_covered)

### Correct cover

In [None]:
# Total correct cover, each on its own feature matrix: (original, perturbed)
correct_cover = (metrics.total_correct_cover(X), perturbed_metrics.total_correct_cover(perturbed_X))
print(*correct_cover)

### Incorrect cover

In [None]:
# Total incorrect cover, each on its own feature matrix: (original, perturbed)
incorrect_cover = (metrics.total_incorrect_cover(X), perturbed_metrics.total_incorrect_cover(perturbed_X))
print(*incorrect_cover)

### Mean rank

In [None]:
# Mean rank: (original, perturbed)
mean_ranks = (metrics.mean_rank(), perturbed_metrics.mean_rank())
print(*mean_ranks)

### Feature frequencies at all depths

In [None]:
# Feature frequencies at all depths: original first, then perturbed, one per line.
for metric_set in (metrics, perturbed_metrics):
    print(metric_set.frequency_at_all_depths())

## Comparison

In [None]:
# Class-change comparison between original and perturbed metrics; it is only
# computed when the config has no 'factors' entry, otherwise `changes` is None.
if 'factors' in config:
    changes = None
else:
    comp = Comparison(metrics, perturbed_metrics)
    changes = comp.change_of_class()
changes