In [1]:
import numpy as np
import torch
from captum.attr import ShapleyValueSampling
from tqdm import trange

from load_data import load_data
from train_models import *
from segmentation import *
from utils import *
import os

In [2]:
# TODO: move the helper functions below to utils.py

def change_points_to_lengths(change_points, array_length):
	"""Turn segment start indices into segment lengths.

	change_points: 1D iterable of indices; every entry is taken to be the
		first index of a segment (change points == start points).
	array_length: total length of the segmented array; it closes the last segment.

	Returns a 1D numpy array holding one length per change point.
	"""
	# boundaries = all start points followed by the end of the array, so the
	# difference between consecutive boundaries is (end - start) of each segment
	boundaries = np.append(change_points, array_length)
	return np.diff(boundaries)

def lengths_to_weights(lengths):
	"""Build a per-position weight vector from segment lengths.

	Every position that belongs to a segment of length ``n`` receives the
	weight ``1 / n``, so the weights of each segment sum to 1.

	lengths: 1D iterable (list, tuple or numpy array) of positive ints.

	Returns a 1D float numpy array of size ``sum(lengths)``.
	"""
	# accept any iterable, not only numpy arrays (plain lists support neither
	# `1 / lengths` nor `.sum()`)
	lengths = np.asarray(lengths)
	if lengths.size == 0:
		# no segments -> no positions to weight
		return np.ones(0)
	# repeat each segment's weight once per position of that segment
	return np.repeat(1 / lengths, lengths)


In [3]:
# choose the torch device: the GPU when CUDA is available, the CPU otherwise
from torch.cuda import is_available as is_GPU_available
if is_GPU_available():
	device = "cuda"
else:
	device = "cpu"

# predictor names grouped by backend (torch vs. everything else);
# Captum has to be set up differently for the two groups
predictors = dict(
	torch=['resNet'],
	scikit=['miniRocket', 'randomForest'],
)

In [4]:
# load data
dataset_name = 'gunpoint'
predictor_name = 'resNet'

# load_data also returns a label encoder: dataset labels may be strings,
# while Captum requires integer class indices
X_train, X_test, y_train, y_test, enc = load_data(subset='all', dataset_name=dataset_name)

# train model
if predictor_name=='resNet':
	clf,preds = train_ResNet(X_train, y_train, X_test, y_test, dataset_name,device=device)
elif predictor_name=='miniRocket':
	clf,preds = train_miniRocket(X_train, y_train, X_test, y_test, dataset_name)
else:
	# fail loudly instead of leaving clf / preds undefined (or stale from a
	# previous run of this cell with a different predictor)
	raise ValueError(f"no training routine configured for predictor '{predictor_name}'")

# container for the segments of every test sample: one cell per (sample, channel)
# when X_test.shape[1] > 1, otherwise one cell per sample
segments =  np.empty( (X_test.shape[0] , X_test.shape[1]), dtype=object) if X_test.shape[1] > 1  else (
	np.empty( X_test.shape[0] , dtype=object))

# dictionary to be dumped to disk, holding the attributions and their metadata
results = {
	'attributions' : {},
	'segments' : segments,
	'y_test_true' : y_test,
	'y_test_pred' : preds,
	'label_mapping' : enc,
}

training ResNet
Epoch 1: train loss:  0.673, 	 train accuracy  0.540 
          test loss:  0.583,  	 test accuracy  0.787
Epoch 11: train loss:  0.267, 	 train accuracy  0.900 
          test loss:  0.323,  	 test accuracy  0.913
Epoch 21: train loss:  0.119, 	 train accuracy  1.000 
          test loss:  0.168,  	 test accuracy  1.000
Epoch 31: train loss:  0.098, 	 train accuracy  1.000 
          test loss:  0.116,  	 test accuracy  1.000
Epoch 41: train loss:  0.052, 	 train accuracy  1.000 
          test loss:  0.093,  	 test accuracy  1.000
training early stopped! Final stats are:
Epoch 51: train loss:  0.049, 	 train accuracy  1.000 
          test loss:  0.078,  	 test accuracy  1.000
accuracy for resNet is  1.0


In [5]:
# number of background samples
n_background = 50
# one attribution array is kept per background type
background_types = ["average", "zero","sampling"]
# pre-allocate one float32 array shaped like X_test for every background type;
# a comprehension is used so that no loop variable (previously named `type`,
# which shadowed the builtin) leaks into the notebook namespace
results['attributions'].update({
	name: np.zeros( X_test.shape ,dtype=np.float32 ) for name in background_types
})

In [6]:
with torch.no_grad():
	# torch predictors are handed to Captum directly; any other predictor goes
	# through forward_classification (the model is then passed as an additional forward argument)
	SHAP = ShapleyValueSampling(clf) if predictor_name in predictors['torch'] else ShapleyValueSampling(forward_classification)

	for i in trange ( X_test.shape[0] ) :

		# get current sample and label
		ts, y = X_test[i] , torch.tensor( y_test[i:i+1] )

		# get the segmentation of the sample and the matching Captum feature mask
		current_segments = get_claSP_segmentation(ts)[:X_test.shape[1]]
		results['segments'][i] = current_segments
		mask = get_feature_mask(current_segments,ts.shape[-1])

		# tensor version of the sample with a leading dimension of size 1
		ts = torch.tensor(ts).repeat(1,1,1)

		for background_type in background_types:

			# background data
			if background_type=="zero":
				background_dataset = torch.zeros((1,) + X_train.shape[1:])
			elif background_type=="sampling":
				background_dataset = sample_background(X_train, n_background)
			elif background_type=="average":
				background_dataset = sample_background(X_train, n_background).mean(axis=0, keepdim=True)
			else:
				# do not silently reuse the baseline of the previous iteration
				raise ValueError(f"unknown background type '{background_type}'")

			# "sampling" needs one copy of the sample per baseline. The copies live in a
			# per-iteration variable so that `ts` itself is never overwritten: the other
			# background types then always see the single sample, whatever the order
			# of background_types is.
			ts_batch = ts.repeat(background_dataset.shape[0],1,1) if background_type=="sampling" else ts

			if predictor_name in predictors['scikit']:
				tmp = SHAP.attribute( ts_batch, target=y , feature_mask=mask, baselines=background_dataset, additional_forward_args=clf)
			elif predictor_name in predictors['torch']:
				# everything handed to Captum must live on the same device as the model
				ts_batch = ts_batch.to(device); y = y.to(device)
				mask = mask.to(device) ; background_dataset =  background_dataset.to(device)
				tmp = SHAP.attribute( ts_batch, target=y , feature_mask=mask, baselines=background_dataset)
			else:
				# do not silently reuse the attribution of the previous iteration
				raise ValueError(f"predictor '{predictor_name}' is not listed in predictors")

			# final explanation: for "sampling" the mean of the explanations obtained with
			# the different baselines, otherwise the single explanation that was computed
			results['attributions'][background_type][i] = torch.mean(tmp, dim=0).cpu().numpy() if \
				background_type=="sampling" else tmp[0].cpu().numpy()

100%|██████████| 150/150 [02:12<00:00,  1.13it/s]


In [7]:
 # normalized weights
# every time point is weighted by 1 / (length of the segment it belongs to);
# one weight vector is built per (sample, channel) entry of results["segments"]
weights = np.array([
	[
		lengths_to_weights(change_points_to_lengths(channel_change_points, X_train.shape[-1]))
		for channel_change_points in sample_segments
	]
	for sample_segments in results["segments"]
])
# normalize the attributions of EVERY background type, not only the one left over
# in the loop variable of the attribution cell.
# NOTE: the multiplication is in place, so re-running this cell applies the weights again.
for normalized_type in background_types:
	results["attributions"][normalized_type] *= weights

In [8]:
# dump result to disk
# output file: attributions/<predictor>_<dataset>.npy
file_name = "_".join ( (predictor_name,dataset_name) )+".npy"
file_path = os.path.join("attributions", file_name)
# make sure the output directory exists, otherwise np.save fails on a fresh checkout
os.makedirs( os.path.dirname(file_path), exist_ok=True )
# results is a dict, so numpy stores it as a pickled object array:
# read it back with np.load(file_path, allow_pickle=True).item()
np.save( file_path, results )

In [9]:
# total attribution of every test sample (sum over axes 1 and 2) for the background
# type currently held in `background_type`, i.e. the last one the attribution loop processed
np.sum(results['attributions'][background_type], axis=(1, 2))

array([-0.00771312,  0.03489595,  0.08431351, -0.04803198,  0.00195293,
        0.06863715,  0.00032749,  0.06003411,  0.01224429,  0.03213294,
        0.00200376, -0.00859401,  0.08302087, -0.00203428, -0.01158601,
        0.00246663,  0.02346536,  0.04300815,  0.04242647,  0.06092416,
       -0.00733136,  0.08475581,  0.02149864,  0.00419958, -0.0500086 ,
        0.00815944, -0.02794795, -0.03676845, -0.01178761, -0.00268818,
        0.00474527,  0.05431886,  0.03923873,  0.02237679, -0.00610698,
        0.00618449,  0.01734989,  0.00603125, -0.01472872,  0.03765767,
       -0.00510494,  0.0699935 ,  0.02377506,  0.00785602,  0.00218885,
       -0.01827871,  0.05209655,  0.02040819,  0.00573549,  0.03228654,
        0.01540048,  0.04855904,  0.02197419,  0.06047125,  0.03856869,
       -0.01358054,  0.01103381,  0.01240072,  0.00826397,  0.00489563,
       -0.02412014,  0.09228799,  0.00876795, -0.01159254,  0.02882524,
        0.0091797 , -0.04600435,  0.00671154, -0.00639732,  0.05