In [14]:
import numpy as np
from sklearn.linear_model import LinearRegression
from aeon.datasets import  load_from_tsfile
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from aeon.transformations.collection.convolution_based import MiniRocketMultivariate


In [7]:
from TSCaptum.ts_captum_method import Feature_Ablation, Feature_Permutation


In [15]:
# load data
X_train, y_train = load_from_tsfile("./data/AppliancesEnergy_TRAIN.ts")
X_test, y_test = load_from_tsfile("./data/AppliancesEnergy_TEST.ts")
print(X_test.shape, y_train, y_train.dtype, y_test, y_test.dtype)

(42, 24, 144) [19.38 12.68  5.34 12.72 13.25 26.28 13.1  14.06 10.92 10.46 20.74 21.31
 21.49 10.25 11.4  10.8  11.64 23.42 11.23 13.56 14.82 16.53 19.94 12.78
 11.49  9.63 11.53 12.97 23.01 11.83 13.37 12.24 14.8  19.01 12.98 12.07
 10.61 17.3   8.62  9.6  10.26  9.82 14.61  5.38 14.62 19.62 19.22 16.25
 16.22  9.17 15.89 10.82 18.18 12.03 11.54 13.21 10.51  7.03 11.63 16.41
 21.69 21.91 21.74 17.   22.1  12.68 10.51 14.99 10.31 12.54 16.05 18.56
 16.77 11.32 15.68 16.02 11.93 20.44 10.89 22.74 10.62 11.92 17.53 10.17
 11.06 10.63  9.99 10.11 13.29 14.28 14.71 13.69 13.87 17.66  8.75] float64 [17.37 20.65 11.42 10.68 12.44 11.17 24.12 10.99 13.76 14.56 14.97 13.43
 11.57  9.33 15.58 20.93 11.99 15.37 10.47 14.41 10.16 15.12 12.32 10.46
 15.06  9.68 10.69 17.8  10.69 17.06 20.88 10.89 13.47 13.62 13.48 14.89
 10.6  15.59 17.89 12.95 10.12 12.14] float64


# train the model


In [18]:
clf =  make_pipeline( MiniRocketMultivariate(), StandardScaler(),
                      LinearRegression( n_jobs = -1))
clf.fit(X_train,y_train)
clf.score(X_test, y_test)

0.5806377276253256

# feature ablation

In [19]:
myfa = Feature_Ablation(clf)
exp = myfa.explain(  X_test[:2] )
print( type(exp), exp.shape, exp[:,:5] )


<class 'numpy.ndarray'> (2, 24, 144) [[[-0.00438499  0.00987053  0.03636551 ... -0.00761414  0.05986595
    0.04519653]
  [-0.05688477 -0.09816551 -0.0851326  ... -0.24951744 -0.20546913
   -0.18114471]
  [-0.00237083  0.01057053  0.01975441 ... -0.01951027  0.01322174
    0.02563095]
  [ 0.04958725  0.06273174  0.0617733  ...  0.10015774  0.13770103
    0.1546278 ]
  [-0.13033485 -0.14841652 -0.19039917 ... -0.08766365 -0.00926018
    0.01327133]]

 [[ 0.01442146  0.03054428  0.03034782 ... -0.08519936  0.0651207
    0.0591526 ]
  [-0.05425072 -0.04529572 -0.03030396 ... -0.23029518 -0.19984627
   -0.1379509 ]
  [-0.01863098  0.00298119  0.0087204  ... -0.05579185  0.00400734
    0.02424622]
  [ 0.10404015  0.13443375  0.13075829 ...  0.02822304  0.08831978
    0.12795448]
  [-0.09727669 -0.10766411 -0.15155602 ... -0.07908058 -0.01750374
    0.02531052]]]


# feature permutation

In [20]:
myfp = Feature_Permutation(clf)
exp = myfp.explain( X_test[:2])
print( type(exp), exp.shape, exp[:,:5] )


<class 'numpy.ndarray'> (2, 24, 144) [[[-0.00222969 -0.01920319 -0.01606941 ... -0.01566887  0.0023365
    0.00990105]
  [-0.00646019 -0.04250145 -0.02719498 ... -0.06228447 -0.03265572
   -0.00814819]
  [-0.00016975  0.00322723  0.00405884 ...  0.00607681 -0.00415802
    0.00585747]
  [ 0.01461983  0.02619743  0.0500927  ...  0.02120781  0.0402813
    0.03228188]
  [-0.08144569 -0.12944794 -0.1618557  ... -0.07964897 -0.06194878
   -0.0467186 ]]

 [[-0.03256607 -0.03700447 -0.06622887 ... -0.02106667 -0.01533699
   -0.006073  ]
  [-0.01356506  0.00361824  0.03347778 ... -0.02192116 -0.01525688
   -0.00905609]
  [-0.00129318 -0.00095558 -0.00157738 ... -0.00587082  0.00432777
   -0.00326157]
  [ 0.00329781  0.00043869  0.02051926 ... -0.0525856  -0.04391479
   -0.01286888]
  [-0.05869865 -0.05541039 -0.10430717 ... -0.08131981 -0.08039284
   -0.06080818]]]
