In [1]:
import numpy as np
from captum.attr import ShapleyValueSampling
from tqdm import trange

from load_data import load_data
from train_models import *
from segmentation import *
from utils import *

In [2]:
# to utils.py

def change_points_to_lengths(change_points, max_length):
    # change points is 1D iterable of idxs
    change_points = list(change_points)
    start_points = [0] + change_points
    end_points = change_points + [max_length]
    lengths = np.array(end_points) - np.array(start_points)
    return lengths

def lengths_to_weights(lengths):
    # lengths is 1D iterable of positive ints
    start_idx = 0
    end_idx = 0
    segment_weights = 1 / lengths
    weights = np.ones(lengths.sum())
    for segment_weight, length in zip(segment_weights, lengths):
        end_idx += length
        weights[start_idx: end_idx] = segment_weight
        start_idx = end_idx
    return weights


In [3]:
# device for torch
from torch.cuda import is_available as is_GPU_available
device = "cuda" if is_GPU_available() else "cpu"

# dictionary of predictor types; useful for Captum 
predictors = {
	'torch' : ['resNet'],
	'scikit' : ['randomForest']
}

In [7]:
# load data
dataset_name = 'gunpoint'
predictor_name = 'miniRocket'

# I've returned also a Label encoder from load_data to have a mapping between dataset label
# which can be string while captum requires idx (integers)
X_train, X_test, y_train, y_test, enc = load_data(subset='all', dataset_name=dataset_name)

# X_test = X_test[:2
# y_test = y_test[:2]

# train model
#clf, preds = train_randomForest(X_train,y_train,X_test,y_test, dataset_name)
if predictor_name=='resNet':
	clf,preds = train_ResNet(X_train, y_train, X_test, y_test, dataset_name,device=device)
elif predictor_name=='miniRocket':
	clf,preds = train_miniRocket(X_train, y_train, X_test, y_test, dataset_name)

# create a dictionary to be dumped containing attribution and metadata
# initialize data structure meant to contain the segments
segments =  np.empty( (X_test.shape[0] , X_test.shape[1]), dtype=object) if X_test.shape[1] > 1  else (
	np.empty( X_test.shape[0] , dtype=object))

all_attributions = {
	'attributions' : np.empty( X_test.shape ,dtype=np.float32 ),
	'segments' : segments,
	'y_test_true' : y_test,
	'y_test_pred' : preds,
	'label_mapping' : enc,
}

training miniRocket
accuracy for miniRocket is  0.9733333333333334


In [8]:
clf

In [9]:
# explain
n_background = 50
background_type = "average" # zero, constant, average, multisample
batch_size = 32

# TODO for each baseline so that I don't retrain a model each time 

with torch.no_grad():
    SHAP = ShapleyValueSampling(forward_classification) if predictor_name in predictors['scikit'] else ShapleyValueSampling(clf)
    
    for i in range ( X_test.shape[0] ) : # 
        # get current sample and label
        ts, y = X_test[i] , torch.tensor( y_test[i:i+1] )

        # get segment and its tensor representation
        current_segments = get_claSP_segmentation(ts)[:X_test.shape[1]]
        all_attributions['segments'][i] = current_segments
        mask = get_feature_mask(current_segments,ts.shape[-1])

        # background data
        
        if background_type=="zero":
            background_dataset = torch.zeros((1,) + X_train.shape[1:])
        elif background_type=="sampling":
            background_dataset = sample_background(X_train, n_background)
        elif background_type=="average":
            background_dataset = sample_background(X_train, n_background).mean(axis=0, keepdim=True)

        print("\n explaining sample n.",i,"\n")
        # data structure with room for each sample in the background dataset
        current_attr = torch.zeros(background_dataset.shape[0], ts.shape[0], ts.shape[1])
        for j in trange(0,background_dataset.shape[0] ,batch_size):

            sample = background_dataset[j:j+batch_size]
            actual_size = sample.shape[0]
            batched_ts = torch.tensor( np.array([ts]*actual_size) )

            ##### only for random forest as every instance should be a 1D tensor #######
            #batched_ts , sample = batched_ts.reshape(actual_size,-1), sample.reshape(actual_size,-1)
            #mask = mask.reshape(1,-1)
            ###############################################################################

            if predictor_name in predictors['scikit']:
                tmp = SHAP.attribute( batched_ts, target=y , feature_mask=mask, baselines=sample, additional_forward_args=clf)
            elif predictor_name in predictors['torch']:
                batched_ts = batched_ts.to(device); y = y.to(device)
                mask = mask.to(device) ; sample =  sample.to(device)
                
                tmp = SHAP.attribute( batched_ts, target=y , feature_mask=mask, baselines=sample)

            ########  only for random forest as every instance should be a 1D tensor    ########
            #current_attr[j:j+actual_size] = tmp.reshape(actual_size,X_test.shape[1],X_test.shape[2])
            ###############################################################################

        # compute as final explanation mean of each explanation using a different baseline
        all_attributions['attributions'][i] =torch.mean(current_attr,dim=0)

weights = np.array(list(map(lambda x: list(map(lambda y: lengths_to_weights(change_points_to_lengths(y, X_train.shape[-1])), x)), all_attributions["segments"])))
all_attributions["attributions"] *= weights


 explaining sample n. 0 



100%|██████████| 1/1 [00:00<00:00, 3454.95it/s]



 explaining sample n. 1 



100%|██████████| 1/1 [00:00<00:00, 3155.98it/s]



 explaining sample n. 2 



100%|██████████| 1/1 [00:00<00:00, 3379.78it/s]



 explaining sample n. 3 



100%|██████████| 1/1 [00:00<00:00, 3410.00it/s]



 explaining sample n. 4 



100%|██████████| 1/1 [00:00<00:00, 3360.82it/s]



 explaining sample n. 5 



100%|██████████| 1/1 [00:00<00:00, 1183.83it/s]



 explaining sample n. 6 



100%|██████████| 1/1 [00:00<00:00, 3669.56it/s]



 explaining sample n. 7 



100%|██████████| 1/1 [00:00<00:00, 1860.00it/s]



 explaining sample n. 8 



100%|██████████| 1/1 [00:00<00:00, 6000.43it/s]



 explaining sample n. 9 



100%|██████████| 1/1 [00:00<00:00, 3024.01it/s]



 explaining sample n. 10 



100%|██████████| 1/1 [00:00<00:00, 1174.88it/s]



 explaining sample n. 11 



100%|██████████| 1/1 [00:00<00:00, 1018.28it/s]



 explaining sample n. 12 



100%|██████████| 1/1 [00:00<00:00, 3666.35it/s]



 explaining sample n. 13 



100%|██████████| 1/1 [00:00<00:00, 1818.87it/s]



 explaining sample n. 14 



100%|██████████| 1/1 [00:00<00:00, 2146.52it/s]



 explaining sample n. 15 



100%|██████████| 1/1 [00:00<00:00, 2805.55it/s]



 explaining sample n. 16 



100%|██████████| 1/1 [00:00<00:00, 5412.01it/s]



 explaining sample n. 17 



100%|██████████| 1/1 [00:00<00:00, 4723.32it/s]



 explaining sample n. 18 



100%|██████████| 1/1 [00:00<00:00, 5309.25it/s]



 explaining sample n. 19 



100%|██████████| 1/1 [00:00<00:00, 4905.62it/s]



 explaining sample n. 20 



100%|██████████| 1/1 [00:00<00:00, 2978.91it/s]



 explaining sample n. 21 



100%|██████████| 1/1 [00:00<00:00, 1253.53it/s]



 explaining sample n. 22 



100%|██████████| 1/1 [00:00<00:00, 4324.02it/s]



 explaining sample n. 23 



100%|██████████| 1/1 [00:00<00:00, 5349.88it/s]



 explaining sample n. 24 



100%|██████████| 1/1 [00:00<00:00, 3184.74it/s]



 explaining sample n. 25 



100%|██████████| 1/1 [00:00<00:00, 4424.37it/s]



 explaining sample n. 26 



100%|██████████| 1/1 [00:00<00:00, 1098.56it/s]



 explaining sample n. 27 



100%|██████████| 1/1 [00:00<00:00, 3563.55it/s]



 explaining sample n. 28 



100%|██████████| 1/1 [00:00<00:00, 2972.58it/s]


 explaining sample n. 29 




100%|██████████| 1/1 [00:00<00:00, 3919.91it/s]



 explaining sample n. 30 



100%|██████████| 1/1 [00:00<00:00, 1269.85it/s]



 explaining sample n. 31 



100%|██████████| 1/1 [00:00<00:00, 4975.45it/s]



 explaining sample n. 32 



100%|██████████| 1/1 [00:00<00:00, 4438.42it/s]



 explaining sample n. 33 



100%|██████████| 1/1 [00:00<00:00, 1737.49it/s]



 explaining sample n. 34 



100%|██████████| 1/1 [00:00<00:00, 5047.30it/s]



 explaining sample n. 35 



100%|██████████| 1/1 [00:00<00:00, 2822.55it/s]



 explaining sample n. 36 



100%|██████████| 1/1 [00:00<00:00, 1216.80it/s]



 explaining sample n. 37 



100%|██████████| 1/1 [00:00<00:00, 3751.61it/s]



 explaining sample n. 38 



100%|██████████| 1/1 [00:00<00:00, 790.33it/s]



 explaining sample n. 39 



100%|██████████| 1/1 [00:00<00:00, 2801.81it/s]



 explaining sample n. 40 



100%|██████████| 1/1 [00:00<00:00, 1797.05it/s]



 explaining sample n. 41 



100%|██████████| 1/1 [00:00<00:00, 1716.16it/s]



 explaining sample n. 42 



100%|██████████| 1/1 [00:00<00:00, 5562.74it/s]



 explaining sample n. 43 



100%|██████████| 1/1 [00:00<00:00, 1016.55it/s]



 explaining sample n. 44 



100%|██████████| 1/1 [00:00<00:00, 6423.13it/s]



 explaining sample n. 45 



100%|██████████| 1/1 [00:00<00:00, 3132.42it/s]



 explaining sample n. 46 



100%|██████████| 1/1 [00:00<00:00, 3457.79it/s]



 explaining sample n. 47 



100%|██████████| 1/1 [00:00<00:00, 3097.71it/s]



 explaining sample n. 48 



100%|██████████| 1/1 [00:00<00:00, 2908.67it/s]



 explaining sample n. 49 



100%|██████████| 1/1 [00:00<00:00, 4288.65it/s]



 explaining sample n. 50 



100%|██████████| 1/1 [00:00<00:00, 3979.42it/s]



 explaining sample n. 51 



100%|██████████| 1/1 [00:00<00:00, 2418.86it/s]



 explaining sample n. 52 



100%|██████████| 1/1 [00:00<00:00, 3360.82it/s]



 explaining sample n. 53 



100%|██████████| 1/1 [00:00<00:00, 6543.38it/s]



 explaining sample n. 54 



100%|██████████| 1/1 [00:00<00:00, 2991.66it/s]



 explaining sample n. 55 



100%|██████████| 1/1 [00:00<00:00, 3603.35it/s]



 explaining sample n. 56 



100%|██████████| 1/1 [00:00<00:00, 3454.95it/s]



 explaining sample n. 57 



100%|██████████| 1/1 [00:00<00:00, 4987.28it/s]



 explaining sample n. 58 



100%|██████████| 1/1 [00:00<00:00, 4202.71it/s]



 explaining sample n. 59 



100%|██████████| 1/1 [00:00<00:00, 1729.61it/s]



 explaining sample n. 60 



100%|██████████| 1/1 [00:00<00:00, 4258.18it/s]



 explaining sample n. 61 



100%|██████████| 1/1 [00:00<00:00, 2353.71it/s]



 explaining sample n. 62 



100%|██████████| 1/1 [00:00<00:00, 2939.25it/s]



 explaining sample n. 63 



100%|██████████| 1/1 [00:00<00:00, 10538.45it/s]



 explaining sample n. 64 



100%|██████████| 1/1 [00:00<00:00, 908.25it/s]



 explaining sample n. 65 



100%|██████████| 1/1 [00:00<00:00, 2416.07it/s]



 explaining sample n. 66 



100%|██████████| 1/1 [00:00<00:00, 2933.08it/s]



 explaining sample n. 67 



100%|██████████| 1/1 [00:00<00:00, 3998.38it/s]



 explaining sample n. 68 



100%|██████████| 1/1 [00:00<00:00, 3560.53it/s]



 explaining sample n. 69 



100%|██████████| 1/1 [00:00<00:00, 1273.32it/s]



 explaining sample n. 70 



100%|██████████| 1/1 [00:00<00:00, 1973.79it/s]



 explaining sample n. 71 



100%|██████████| 1/1 [00:00<00:00, 5398.07it/s]



 explaining sample n. 72 



100%|██████████| 1/1 [00:00<00:00, 3382.50it/s]



 explaining sample n. 73 



100%|██████████| 1/1 [00:00<00:00, 2730.67it/s]



 explaining sample n. 74 



100%|██████████| 1/1 [00:00<00:00, 3701.95it/s]



 explaining sample n. 75 



100%|██████████| 1/1 [00:00<00:00, 2695.57it/s]



 explaining sample n. 76 



100%|██████████| 1/1 [00:00<00:00, 1937.32it/s]



 explaining sample n. 77 



100%|██████████| 1/1 [00:00<00:00, 3666.35it/s]



 explaining sample n. 78 



100%|██████████| 1/1 [00:00<00:00, 3512.82it/s]



 explaining sample n. 79 



100%|██████████| 1/1 [00:00<00:00, 4848.91it/s]



 explaining sample n. 80 



100%|██████████| 1/1 [00:00<00:00, 3347.41it/s]



 explaining sample n. 81 



100%|██████████| 1/1 [00:00<00:00, 5262.61it/s]



 explaining sample n. 82 



100%|██████████| 1/1 [00:00<00:00, 4495.50it/s]



 explaining sample n. 83 



100%|██████████| 1/1 [00:00<00:00, 2711.25it/s]



 explaining sample n. 84 



100%|██████████| 1/1 [00:00<00:00, 1931.08it/s]



 explaining sample n. 85 



100%|██████████| 1/1 [00:00<00:00, 3701.95it/s]



 explaining sample n. 86 



100%|██████████| 1/1 [00:00<00:00, 2700.78it/s]



 explaining sample n. 87 



100%|██████████| 1/1 [00:00<00:00, 2291.97it/s]



 explaining sample n. 88 



100%|██████████| 1/1 [00:00<00:00, 3675.99it/s]



 explaining sample n. 89 



100%|██████████| 1/1 [00:00<00:00, 2434.30it/s]



 explaining sample n. 90 



100%|██████████| 1/1 [00:00<00:00, 4165.15it/s]



 explaining sample n. 91 



100%|██████████| 1/1 [00:00<00:00, 3437.95it/s]



 explaining sample n. 92 



100%|██████████| 1/1 [00:00<00:00, 2981.03it/s]



 explaining sample n. 93 



100%|██████████| 1/1 [00:00<00:00, 2898.62it/s]



 explaining sample n. 94 



100%|██████████| 1/1 [00:00<00:00, 4415.06it/s]



 explaining sample n. 95 



100%|██████████| 1/1 [00:00<00:00, 3666.35it/s]



 explaining sample n. 96 



100%|██████████| 1/1 [00:00<00:00, 4529.49it/s]



 explaining sample n. 97 



100%|██████████| 1/1 [00:00<00:00, 2035.08it/s]



 explaining sample n. 98 



100%|██████████| 1/1 [00:00<00:00, 2081.54it/s]



 explaining sample n. 99 



100%|██████████| 1/1 [00:00<00:00, 3938.31it/s]



 explaining sample n. 100 



100%|██████████| 1/1 [00:00<00:00, 4723.32it/s]



 explaining sample n. 101 



100%|██████████| 1/1 [00:00<00:00, 3734.91it/s]



 explaining sample n. 102 



100%|██████████| 1/1 [00:00<00:00, 4793.49it/s]



 explaining sample n. 103 



100%|██████████| 1/1 [00:00<00:00, 2663.05it/s]



 explaining sample n. 104 



100%|██████████| 1/1 [00:00<00:00, 3106.89it/s]



 explaining sample n. 105 



100%|██████████| 1/1 [00:00<00:00, 2475.98it/s]



 explaining sample n. 106 



100%|██████████| 1/1 [00:00<00:00, 3472.11it/s]



 explaining sample n. 107 



100%|██████████| 1/1 [00:00<00:00, 4096.00it/s]



 explaining sample n. 108 



100%|██████████| 1/1 [00:00<00:00, 2851.33it/s]



 explaining sample n. 109 



100%|██████████| 1/1 [00:00<00:00, 1102.02it/s]



 explaining sample n. 110 



100%|██████████| 1/1 [00:00<00:00, 3880.02it/s]



 explaining sample n. 111 



100%|██████████| 1/1 [00:00<00:00, 3566.59it/s]



 explaining sample n. 112 



100%|██████████| 1/1 [00:00<00:00, 4686.37it/s]



 explaining sample n. 113 



100%|██████████| 1/1 [00:00<00:00, 4443.12it/s]



 explaining sample n. 114 



100%|██████████| 1/1 [00:00<00:00, 3326.17it/s]



 explaining sample n. 115 



100%|██████████| 1/1 [00:00<00:00, 5256.02it/s]



 explaining sample n. 116 



100%|██████████| 1/1 [00:00<00:00, 4405.78it/s]



 explaining sample n. 117 



100%|██████████| 1/1 [00:00<00:00, 2920.82it/s]



 explaining sample n. 118 



100%|██████████| 1/1 [00:00<00:00, 2805.55it/s]



 explaining sample n. 119 



100%|██████████| 1/1 [00:00<00:00, 1181.16it/s]



 explaining sample n. 120 



100%|██████████| 1/1 [00:00<00:00, 7133.17it/s]



 explaining sample n. 121 



100%|██████████| 1/1 [00:00<00:00, 4202.71it/s]



 explaining sample n. 122 



100%|██████████| 1/1 [00:00<00:00, 404.78it/s]



 explaining sample n. 123 



100%|██████████| 1/1 [00:00<00:00, 3492.34it/s]



 explaining sample n. 124 



100%|██████████| 1/1 [00:00<00:00, 1065.63it/s]



 explaining sample n. 125 



100%|██████████| 1/1 [00:00<00:00, 2989.53it/s]



 explaining sample n. 126 



100%|██████████| 1/1 [00:00<00:00, 2064.13it/s]



 explaining sample n. 127 



100%|██████████| 1/1 [00:00<00:00, 3788.89it/s]



 explaining sample n. 128 



100%|██████████| 1/1 [00:00<00:00, 4629.47it/s]



 explaining sample n. 129 



100%|██████████| 1/1 [00:00<00:00, 1194.96it/s]



 explaining sample n. 130 



100%|██████████| 1/1 [00:00<00:00, 4905.62it/s]



 explaining sample n. 131 



100%|██████████| 1/1 [00:00<00:00, 887.31it/s]



 explaining sample n. 132 



100%|██████████| 1/1 [00:00<00:00, 3292.23it/s]



 explaining sample n. 133 



100%|██████████| 1/1 [00:00<00:00, 3310.42it/s]



 explaining sample n. 134 



100%|██████████| 1/1 [00:00<00:00, 3432.33it/s]



 explaining sample n. 135 



100%|██████████| 1/1 [00:00<00:00, 1081.84it/s]



 explaining sample n. 136 



100%|██████████| 1/1 [00:00<00:00, 3754.97it/s]



 explaining sample n. 137 



100%|██████████| 1/1 [00:00<00:00, 6150.01it/s]



 explaining sample n. 138 



100%|██████████| 1/1 [00:00<00:00, 4346.43it/s]



 explaining sample n. 139 



100%|██████████| 1/1 [00:00<00:00, 6689.48it/s]



 explaining sample n. 140 



100%|██████████| 1/1 [00:00<00:00, 1043.88it/s]



 explaining sample n. 141 



100%|██████████| 1/1 [00:00<00:00, 2748.56it/s]



 explaining sample n. 142 



100%|██████████| 1/1 [00:00<00:00, 5511.57it/s]



 explaining sample n. 143 



100%|██████████| 1/1 [00:00<00:00, 7397.36it/s]



 explaining sample n. 144 



100%|██████████| 1/1 [00:00<00:00, 1122.67it/s]



 explaining sample n. 145 



100%|██████████| 1/1 [00:00<00:00, 3116.12it/s]



 explaining sample n. 146 



100%|██████████| 1/1 [00:00<00:00, 4144.57it/s]



 explaining sample n. 147 



100%|██████████| 1/1 [00:00<00:00, 2968.37it/s]



 explaining sample n. 148 



100%|██████████| 1/1 [00:00<00:00, 2432.89it/s]



 explaining sample n. 149 



100%|██████████| 1/1 [00:00<00:00, 2129.09it/s]
  segment_weights = 1 / lengths


In [10]:
all_attributions['attributions'].sum(axis=(1,2))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)