In [1]:
import torch
import numpy as np
from tqdm import tqdm
import sys

sys.path.append("../../src")
from explainer import Archipelago
from synthetic_utils import *

sys.path.append("../../baselines/shapley_interaction_index")
from si_explainer import SiExplainer

sys.path.append("../../baselines/shapley_taylor_interaction_index")
from sti_explainer import StiExplainer, subset_before

sys.path.append("../../baselines/mahe_madex/madex")
from utils import general_utils as nid_utils
import neural_interaction_detection as nid

import statsmodels.api as sm
from statsmodels.formula.api import ols

%load_ext autoreload
%autoreload 2

## Parameters

In [2]:
# Interaction detector to evaluate: one of "archdetect", "si", "sti", "nid", "anova".
method = "archdetect"
# Index of the synthetic test function passed to synth_model.
function_id = 4

p = 40  # number of input features

# Every feature of the explained input takes input_value;
# every feature of the baseline takes base_value.
input_value = 1
base_value = -1

## Get Data and Synthetic Function

In [3]:
# The explained point and the reference point are constant length-p vectors.
# NOTE(review): `input` shadows the Python builtin; later cells read this name,
# so it is kept as-is here.
input = np.full(p, input_value)
baseline = np.full(p, base_value)

print("function id:", function_id)
model = synth_model(function_id, input_value, base_value)
# Presumably the ground-truth interaction sets that get_auc scores against — confirm.
gts = model.get_gts(p)

function id: 4


## Run Interaction Detection Method

In [4]:
# Score feature pairs with the selected detector. Every branch binds
# `inter_scores` to a list of ((i, j), score) entries for the AUC cell;
# presumably a larger score means a stronger detected interaction.
if method == "archdetect":
    apgo = Archipelago(model, input=input, baseline=baseline, output_indices=0, batch_size=20)
    # NOTE(review): assumes archdetect()["interactions"] is already in ((i, j), score) form — confirm.
    inter_scores = apgo.archdetect()["interactions"]

elif method == "si":
    si_method = SiExplainer(model, input=input, baseline=baseline, output_indices=0, batch_size=20, seed=42)

    # Sampling budget forwarded to SiExplainer.attribution.
    # NOTE(review): presumably the number of sampled subsets T per pair — confirm.
    num_T = 20
    inter_scores = []
    for i in range(p):
        for j in range(i + 1, p):
            S = (i, j)
            att = si_method.attribution(S, num_T)
            # Squared so pairs are ranked by magnitude, regardless of sign.
            inter_scores.append((S, att**2))

elif method == "sti":
    sti_method = StiExplainer(model, input=input, baseline=baseline, output_indices=0, batch_size=20)

    # All pairwise attributions in one call, indexed below as inter_atts[i, j].
    inter_atts = sti_method.batch_attribution(num_orderings=20, pairwise=True, seed=42)
    inter_scores = []
    for i in range(p):
        for j in range(i + 1, p):
            # Upper triangle only (i < j); squared to rank by magnitude.
            inter_scores.append(((i, j), inter_atts[i, j] ** 2))

elif method == "nid":
    # Presumably draws (input, output) samples from the synthetic model for fitting.
    X, Y = gen_data_samples(model, input_value, base_value, p, n=30000, seed=42)
    Xs, Ys = nid_utils.proprocess_data(X, Y, valid_size=10000, test_size=10000, std_scale_X=True, std_scale=True)
    # mlp_loss is kept for inspection only; it is not used below.
    inter_scores, mlp_loss = nid.detect_interactions(Xs, Ys, pairwise=True, seed=42)

elif method == "anova":
    X, Y = gen_data_samples(model, input_value, base_value, p, n=30000, seed=42)
    Xs, Ys = nid_utils.proprocess_data(X, Y, valid_size=10000, test_size=10000, std_scale_X=True, std_scale=True)
    X_train = Xs["train"]
    Y_train = Ys["train"]
    n_features = X_train.shape[1]

    # Columns for the formula: response `y` plus one `X<i>` column per feature.
    data = {"y": Y_train.squeeze()}
    for i in range(n_features):
        data[f"X{i}"] = X_train[:, i]

    # "(X0+...+Xk):(X0+...+Xk)" expands to every main effect plus every pairwise
    # product term Xa:Xb, so each feature pair gets its own ANOVA row.
    st = "(" + "+".join(f"X{i}" for i in range(n_features)) + ")"
    formula = "y ~ " + st + ":" + st

    lm = ols(formula, data=data).fit()

    # Type-II ANOVA; each pair is scored by its term's sum of squares.
    table = sm.stats.anova_lm(lm, typ=2)
    inter_scores = []
    for name, sum_sq in table["sum_sq"].items():
        if name == "Residual":
            continue
        inter = tuple(int(x) for x in name.replace("X", "").split(":"))
        if len(inter) == 1:
            continue  # main effect, not a pairwise interaction
        inter_scores.append((inter, sum_sq))

else:
    # Fail loudly: otherwise a stale `inter_scores` from an earlier run would be
    # scored without any warning (or the AUC cell would hit a NameError).
    raise ValueError(
        f"Unknown method {method!r}; expected one of "
        "'archdetect', 'si', 'sti', 'nid', 'anova'."
    )

In [5]:
# Evaluate the detected interaction scores against the ground truth `gts`.
# NOTE(review): get_auc presumably comes from the synthetic_utils star import and
# computes an AUC of the pair ranking against gts — confirm.
print("auc", get_auc(inter_scores, gts))

auc 1.0
