In [1]:
import numpy as np
import sys

sys.path.append("../../src")
from sti_explainer import StiExplainer
sys.path.append("../../experiments/1. archdetect")
from synthetic_utils import *

%load_ext autoreload
%autoreload 2

## Parameters

In [2]:
# Id of the synthetic function handed to `synth_model` in the next section.
function_id = 4

p = 40  # number of input features
input_value = 1  # value taken by every feature of the explained input
base_value = -1  # value taken by every feature of the baseline

## Get Data and Synthetic Function

In [3]:
# Constant input / baseline vectors: every one of the `p` features is set to
# `input_value` (resp. `base_value`).
# NOTE(review): `input` shadows the Python builtin; kept because later cells use it.
input = np.full(p, input_value)
baseline = np.full(p, base_value)

print("function id:", function_id)
model = synth_model(function_id, input_value, base_value)
gts = model.get_gts(p)  # presumably the ground-truth interactions of this function — confirm

function id: 4


In [4]:
# Total output change from baseline to input; the completeness checks below
# require the attributions to add up to this value.
output_at_input = model(input)
output_at_baseline = model(baseline)
f_diff = (output_at_input - output_at_baseline).item()

## Get Explanation

In [5]:
# Shapley-Taylor interaction explainer for `model`, explaining `input` relative
# to `baseline` at output index 0.
sti_method = StiExplainer(
    model,
    input=input,
    baseline=baseline,
    output_indices=0,
    batch_size=20,
)

### Individual Attributions (one sampled ordering)

In [6]:
np.random.seed(42)  # makes the feature ordering sampled below reproducible

def subset_before(i, j, ordering, ordering_dict):
    """Return the features that come before both ``i`` and ``j`` in ``ordering``.

    Args:
        i, j: the two features forming the pair.
        ordering: a permutation of the features (sliceable sequence).
        ordering_dict: maps each feature to its position in ``ordering``.

    Returns:
        The prefix of ``ordering`` that stops just before whichever of ``i``
        and ``j`` appears first.
    """
    position_i, position_j = ordering_dict[i], ordering_dict[j]
    cutoff = position_i if position_i <= position_j else position_j
    return ordering[:cutoff]

# One random ordering of the features, plus a feature -> position lookup.
ordering = np.random.permutation(np.arange(p))
ordering_dict = {feature: position for position, feature in enumerate(ordering)}

att_sum = 0
inters = {}

# Pairwise terms: every unordered pair (i, j), i < j, is evaluated in the
# context T of features preceding both of them in `ordering`.
for i in range(p):
    for j in range(i + 1, p):
        S = (i, j)
        T = subset_before(i, j, ordering, ordering_dict)
        att = sti_method.attribution(S, T)
        inters[S] = att
        att_sum += att

# Main effects: each feature on its own, with an empty context.
for i in range(p):
    att = sti_method.attribution([i], [])
    att_sum += att

### Check Completeness

In [7]:
# Completeness: pairwise interactions plus main effects must add up to the total
# output change. The sum accumulates hundreds of floating-point terms, so compare
# with a tolerance instead of exact `==`. The values go in the message so they
# are visible even when the assertion fires.
assert np.isclose(att_sum, f_diff), f"completeness violated: att_sum={att_sum}, f_diff={f_diff}"
print(att_sum, f_diff)

82.0 82.0


### Batch Version

In [8]:
%%time
num_orderings = 50
mat = sti_method.batch_attribution(num_orderings, pairwise=True, seed=4)
arr = sti_method.batch_attribution(num_orderings, main_effects=True, pairwise=False, seed=4)

att_sum = mat.sum() + arr.sum()

CPU times: user 13 s, sys: 93.9 ms, total: 13.1 s
Wall time: 22.8 s


### Check Completeness (batch version)

In [9]:
# The batch estimate is only approximately complete (see the printed values), so
# compare within an explicit absolute tolerance. The previous
# `round(att_sum) == f_diff` could only pass when f_diff is an integer. For
# integer f_diff, |att_sum - f_diff| <= 0.5 accepts exactly the same values.
COMPLETENESS_ATOL = 0.5
assert np.isclose(att_sum, f_diff, rtol=0, atol=COMPLETENESS_ATOL), (
    f"completeness violated: att_sum={att_sum}, f_diff={f_diff}"
)
print(att_sum, f_diff)

81.97242471249145 82.0
