In [20]:
import numpy as np
from captum.attr import ShapleyValueSampling
from tqdm import trange

from load_data import load_data
from train_models import *
from segmentation import *
from utils import *
import os

In [1]:
import numpy as np

In [10]:
# TODO: move these helpers into utils.py

def change_points_to_lengths(change_points, array_length):
    """Convert segment start indices (change points) into segment lengths.

    Each change point is treated as the start index of a new segment; a segment
    ends where the next one starts, and the last segment ends at ``array_length``.

    Args:
        change_points: 1D iterable of int start indices, assumed sorted ascending.
            NOTE(review): positions before the first change point are not
            covered, so the first change point presumably is 0 — confirm with
            the segmentation routine.
        array_length: length of the full series (end of the final segment).

    Returns:
        np.ndarray with one segment length per change point.
    """
    start_points = np.asarray(change_points)
    # each segment ends at the next start point; the final one ends at array_length
    end_points = np.append(start_points[1:], array_length)
    return end_points - start_points

def lengths_to_weights(lengths):
    """Expand segment lengths into per-position weights equal to 1 / segment length.

    Every position inside a segment of length L receives weight 1 / L, so the
    weights of each segment sum to 1 regardless of how long the segment is.

    Args:
        lengths: 1D iterable (list, tuple or np.ndarray) of non-negative ints.

    Returns:
        1D float64 np.ndarray with ``sum(lengths)`` entries.
    """
    lengths = np.asarray(lengths, dtype=np.intp)
    # Zero-length segments (repeated change points) get weight inf, but np.repeat
    # emits them zero times, so the divide-by-zero warning is harmless noise.
    with np.errstate(divide="ignore"):
        segment_weights = 1.0 / lengths
    return np.repeat(segment_weights, lengths)


In [22]:
# Torch device: use CUDA when a GPU is visible, otherwise fall back to the CPU.
from torch.cuda import is_available as is_GPU_available
device = "cpu"
if is_GPU_available():
	device = "cuda"

# Which predictors are torch models and which are not. Captum receives torch
# models directly; the others are called through a wrapper forward function.
predictors = dict(
	torch=['resNet'],
	scikit=['miniRocket', 'randomForest'],
)

In [23]:
# load data
dataset_name = 'gunpoint'
predictor_name = 'resNet'

# load_data also returns a label encoder (`enc`) mapping the dataset labels, which may be
# strings, to the integer class indices Captum needs as attribution targets
X_train, X_test, y_train, y_test, enc = load_data(subset='all', dataset_name=dataset_name)

# train model
if predictor_name == 'resNet':
	clf, preds = train_ResNet(X_train, y_train, X_test, y_test, dataset_name, device=device)
elif predictor_name == 'miniRocket':
	clf, preds = train_miniRocket(X_train, y_train, X_test, y_test, dataset_name)
else:
	# fail here instead of with a NameError on `clf`/`preds` further down
	raise ValueError(
		f"no training routine for predictor {predictor_name!r}; expected 'resNet' or 'miniRocket'"
	)

# Container for the per-sample segmentations: one object cell per (sample, channel) for
# multichannel data, one per sample otherwise. It is filled by the attribution loop.
segments = np.empty((X_test.shape[0], X_test.shape[1]), dtype=object) if X_test.shape[1] > 1 else (
	np.empty(X_test.shape[0], dtype=object))

# Everything that is dumped to disk: attributions per background type plus metadata.
results = {
	'attributions' : {},
	'segments' : segments,
	'y_test_true' : y_test,
	'y_test_pred' : preds,
	'label_mapping' : enc,
}

training ResNet
Epoch 1: train loss:  0.675, 	 train accuracy  0.660 
          test loss:  0.597,  	 test accuracy  0.787
Epoch 11: train loss:  0.267, 	 train accuracy  0.940 
          test loss:  0.326,  	 test accuracy  0.927
Epoch 21: train loss:  0.119, 	 train accuracy  1.000 
          test loss:  0.164,  	 test accuracy  1.000
Epoch 31: train loss:  0.067, 	 train accuracy  1.000 
          test loss:  0.109,  	 test accuracy  1.000
Epoch 41: train loss:  0.049, 	 train accuracy  1.000 
          test loss:  0.089,  	 test accuracy  1.000
training early stopped! Final stats are:
Epoch 45: train loss:  0.066, 	 train accuracy  1.000 
          test loss:  0.079,  	 test accuracy  1.000
accuracy for resNet is  1.0


In [10]:
# Baselines used by the attribution loop:
#   "average"  - mean of n_background sampled training series
#   "zero"     - an all-zeros series
#   "sampling" - n_background sampled training series; attributions are averaged over them
n_background = 50
background_types = ["average", "zero", "sampling"]
# one float32 attribution array (same shape as X_test) per background type; the comprehension
# variable does not leak, so the builtin `type` is no longer shadowed at notebook scope
results['attributions'].update(
	{background_type: np.zeros(X_test.shape, dtype=np.float32) for background_type in background_types}
)

In [11]:
# Shapley Value Sampling over segments: the feature mask groups each segment into one feature,
# and every test sample is explained once per background type.
with torch.no_grad():
	is_torch_predictor = predictor_name in predictors['torch']
	if not is_torch_predictor and predictor_name not in predictors['scikit']:
		raise ValueError(f"unknown predictor {predictor_name!r}; add it to `predictors`")
	# torch models are called directly by Captum; the others go through forward_classification,
	# which receives the fitted model via additional_forward_args
	SHAP = ShapleyValueSampling(clf) if is_torch_predictor else ShapleyValueSampling(forward_classification)

	for i in trange(X_test.shape[0]):

		# current sample and its label as a 1-element tensor (the Captum target)
		ts, y = X_test[i], torch.tensor(y_test[i:i+1])

		# segment the sample and turn the segments into a Captum feature mask
		current_segments = get_claSP_segmentation(ts)[:X_test.shape[1]]
		results['segments'][i] = current_segments
		mask = get_feature_mask(current_segments, ts.shape[-1])

		# add a leading batch dimension; move inputs to the device once per sample
		ts = torch.tensor(ts).repeat(1, 1, 1)
		if is_torch_predictor:
			ts, y, mask = ts.to(device), y.to(device), mask.to(device)

		for background_type in background_types:

			# baseline(s) for this background type
			if background_type == "zero":
				background_dataset = torch.zeros((1,) + X_train.shape[1:])
			elif background_type == "sampling":
				background_dataset = sample_background(X_train, n_background)
			elif background_type == "average":
				background_dataset = sample_background(X_train, n_background).mean(axis=0, keepdim=True)
			else:
				raise ValueError(f"unknown background type {background_type!r}")

			# "sampling" pairs one copy of the sample with each background series. Build that batch
			# in a separate name so `ts` keeps its single-sample shape for the other background types,
			# whatever their order in background_types.
			ts_batch = ts.repeat(background_dataset.shape[0], 1, 1) if background_type == "sampling" else ts

			if is_torch_predictor:
				attr = SHAP.attribute(ts_batch, target=y, feature_mask=mask,
					baselines=background_dataset.to(device))
			else:
				attr = SHAP.attribute(ts_batch, target=y, feature_mask=mask,
					baselines=background_dataset, additional_forward_args=clf)
			# TODO: random forest expects every instance as a 1D tensor; its attributions would need
			# reshaping, e.g. attr.reshape(actual_size, X_test.shape[1], X_test.shape[2]).

			# "sampling": final explanation is the mean of the per-baseline explanations
			results['attributions'][background_type][i] = (
				torch.mean(attr, dim=0) if background_type == "sampling" else attr[0]
			).cpu().numpy()

  1%|          | 1/150 [00:01<03:52,  1.56s/it]


KeyboardInterrupt: 

In [12]:
 # normalized weights
# Re-weight the attributions by segment length: every time step gets 1 / (length of its segment),
# so each segment's weights sum to 1. Applied to every background type, not only the loop
# variable left over from the attribution cell.
# NOTE: multiplies in place — re-running this cell applies the weights a second time.
weights = np.array([
	[lengths_to_weights(change_points_to_lengths(change_points, X_train.shape[-1]))
		for change_points in sample_segments]
	for sample_segments in results["segments"]
])
for background_type in background_types:
	results["attributions"][background_type] *= weights

In [25]:
# Dump results to attributions/<dataset>_<predictor>.npy.
# np.save pickles the dict (including the label encoder); reload it with
# np.load(file_path, allow_pickle=True).item() — only for files you trust.
file_name = f"{dataset_name}_{predictor_name}.npy"
os.makedirs("attributions", exist_ok=True)  # np.save does not create missing directories
file_path = os.path.join("attributions", file_name)
np.save(file_path, results)

In [10]:
# Per-sample total attribution for one background type. Name it explicitly (the last entry,
# "sampling") instead of relying on the loop variable leaked by earlier cells.
results['attributions'][background_types[-1]].sum(axis=(1, 2))

array([ 1.70120507e-01,  1.52841201e-02,  9.28617418e-02,  1.13407120e-01,
        1.83220327e-01,  2.10560542e-02,  1.27464145e-01,  1.06733963e-01,
        1.75181329e-02,  1.57249138e-01,  1.40664607e-01,  1.49215162e-01,
        1.59901679e-02,  1.27517417e-01,  1.47278994e-01,  1.60658777e-01,
        2.17270672e-01,  5.09824380e-02,  3.32470946e-02,  3.66214737e-02,
        1.30641073e-01,  1.23184487e-01,  1.96215227e-01,  1.38463348e-01,
        8.75410587e-02,  9.70023274e-02,  1.71967596e-01,  1.50519937e-01,
        5.04426882e-02,  1.61408171e-01,  2.16436416e-01,  1.41229391e-01,
        1.41947210e-01,  1.62317842e-01,  1.64650887e-01,  1.64535791e-01,
        1.46388069e-01,  1.16750002e-01,  1.22873716e-01,  4.31026220e-02,
        1.57310590e-01,  1.00012690e-01,  1.55799270e-01,  1.55326679e-01,
        8.71996731e-02,  1.57125652e-01,  1.36757612e-01,  3.19495164e-02,
        1.55807734e-01,  3.47837806e-02,  9.90368426e-05,  1.06972188e-01,
        1.50174171e-01,  