In [1]:
from sklearn.linear_model import LinearRegression, LogisticRegressionCV
from aeon.datasets import  load_from_tsfile
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from aeon.transformations.collection.convolution_based import MiniRocketMultivariate
from aeon.datasets import load_basic_motions


In [2]:
from tsCaptum.ts_captum_method import Feature_Ablation, Feature_Permutation


# load regression data


In [3]:
# Regression dataset stored as .ts files with a predefined train/test split.
# The directory and dataset name live in one place so another .ts regression
# dataset can be swapped in by editing a single line.
DATA_DIR = "./data"
REGRESSION_DATASET = "AppliancesEnergy"

X_train, y_train = load_from_tsfile(f"{DATA_DIR}/{REGRESSION_DATASET}_TRAIN.ts")
X_test, y_test = load_from_tsfile(f"{DATA_DIR}/{REGRESSION_DATASET}_TEST.ts")

# train regression model

In [4]:
# MiniRocket features -> standardisation -> ordinary least squares.
# MiniRocketMultivariate is stochastic unless random_state is fixed; seeding it
# makes the test score (and the attributions computed from this model below)
# reproducible on Restart & Run All.
# NOTE(review): MiniRocket presumably emits ~10k features by default; with few
# training series plain LinearRegression is under-determined. RidgeCV is the
# usual head for ROCKET-family regressors -- consider switching (changes results).
clf = make_pipeline(
    MiniRocketMultivariate(random_state=0),
    StandardScaler(),
    LinearRegression(n_jobs=-1),
)
clf.fit(X_train, y_train)
# For a regressor, Pipeline.score delegates to LinearRegression.score, i.e. R^2.
clf.score(X_test, y_test)

0.6480603985628843

# Test attribution methods on the regression model (TBC)

In [5]:
# Explain the first two test series with each perturbation-based method.
# Feature_Permutation is run both with clf_type inferred (default) and with
# clf_type="regressor" given explicitly, to check the two agree in practice.
# `explain` returns a numpy array; its shape was (2, 24, 144) in the saved
# output, presumably (samples, channels, timepoints) like X_test -- confirm.
# Only a 5x5 corner of each sample is printed to keep the output readable.
X_explain = X_test[:2]

regression_explainers = {
    "Feature_Ablation": Feature_Ablation(clf=clf),
    "Feature_Permutation (default clf_type)": Feature_Permutation(clf=clf),
    "Feature_Permutation (clf_type='regressor')": Feature_Permutation(
        clf=clf, clf_type="regressor"
    ),
}

for method_name, explainer in regression_explainers.items():
    exp = explainer.explain(X_explain)
    print(method_name, type(exp), exp.shape, exp[:, :5, :5], "\n\n")

<class 'numpy.ndarray'> (2, 24, 144) [[[-0.02985191 -0.05060482 -0.03701019 -0.02889252 -0.04920673]
  [-0.1837759  -0.2010374  -0.24809074 -0.14790344 -0.18824673]
  [ 0.00544357  0.0478611   0.00166702  0.08009243  0.08660412]
  [ 0.03286648  0.04365444  0.09275818  0.07201195  0.08699036]
  [-0.15802383 -0.16177368 -0.21934032 -0.24625969 -0.23500061]]

 [[-0.01483345 -0.00751305  0.02280998  0.02466583  0.00994301]
  [-0.16393661 -0.10100174 -0.11806297 -0.06227112 -0.13948631]
  [-0.1089077   0.04946327 -0.04631233  0.00365639  0.03973007]
  [ 0.03846359  0.06023979  0.10733986  0.09062958  0.12396812]
  [-0.21626472 -0.2823372  -0.3541088  -0.41034698 -0.43452644]]] 


<class 'numpy.ndarray'> (2, 24, 144) [[[-0.01372719 -0.0133791  -0.03288364 -0.04853439 -0.06876087]
  [-0.01695347 -0.05210972 -0.03399658 -0.03756428 -0.04555035]
  [ 0.00052738 -0.00313473  0.00271416  0.00058556  0.0007515 ]
  [-0.01420116  0.00169659  0.01017666 -0.02364159 -0.04451275]
  [-0.06788254 -0.07894

# load classification data



In [6]:
# BasicMotions is bundled with aeon and comes with a fixed train/test split.
# Note: this rebinds X_train/y_train/X_test/y_test, replacing the regression
# arrays loaded earlier in the notebook.
X_train, y_train, X_test, y_test = (
    *load_basic_motions(split="train"),
    *load_basic_motions(split="test"),
)


# Train classification model

In [7]:
# MiniRocket features -> standardisation -> cross-validated logistic regression.
# Both stochastic stages are seeded (MiniRocketMultivariate and
# LogisticRegressionCV) so the fitted model and the attributions computed from
# it are reproducible on Restart & Run All.
# Note: this rebinds `clf`, replacing the regression pipeline trained earlier.
clf = make_pipeline(
    MiniRocketMultivariate(random_state=0),
    StandardScaler(),
    LogisticRegressionCV(cv=5, random_state=0, n_jobs=-1, max_iter=1000),
)
clf.fit(X_train, y_train)
# For a classifier, Pipeline.score returns mean accuracy on the test split.
clf.score(X_test, y_test)

1.0

# Test XAI attribution methods on the classification model (TBC)

In [8]:
# Explain the first two test series with each perturbation-based method,
# passing their ground-truth labels as `labels` (presumably the class whose
# output is attributed -- confirm against tsCaptum docs).
# Feature_Permutation is run with clf_type both explicit and inferred.
# Only a 5x5 corner of each sample's attribution array is printed.
X_explain, y_explain = X_test[:2], y_test[:2]

classification_explainers = {
    "Feature_Permutation (clf_type='classifier')": Feature_Permutation(
        clf=clf, clf_type="classifier"
    ),
    "Feature_Permutation (default clf_type)": Feature_Permutation(clf=clf),
    "Feature_Ablation (clf_type='classifier')": Feature_Ablation(
        clf=clf, clf_type="classifier"
    ),
}

for method_name, explainer in classification_explainers.items():
    exp = explainer.explain(X_explain, labels=y_explain)
    print(method_name, type(exp), exp.shape, exp[:, :5, :5], "\n\n")

<class 'numpy.ndarray'> (2, 6, 100) [[[ 3.50786370e-05  2.96636450e-05 -7.31043140e-05 -7.38552912e-08
   -3.83516598e-05]
  [ 2.90222083e-07  2.38001400e-05  4.15881745e-04  2.55555096e-04
    4.37418041e-04]
  [-5.96477451e-06  1.02141177e-05  7.53462997e-04  3.57395141e-04
    8.28554525e-05]
  [ 3.18053138e-05  4.26747908e-05  1.09903476e-03  1.58509910e-04
    3.15730063e-04]
  [ 5.72559282e-07  2.86787025e-07  1.37396042e-05  1.20624138e-04
    2.43623382e-04]]

 [[-4.15011192e-05 -7.95864548e-05 -1.14619998e-03 -5.57865792e-04
    1.63815660e-05]
  [ 2.62339451e-05  1.82291608e-05 -1.07716269e-03 -9.67546780e-04
   -7.20140760e-04]
  [ 9.39743630e-05  6.38675746e-05 -1.52333410e-03 -1.03562440e-03
   -1.00507412e-04]
  [-9.36591442e-05 -9.42794679e-05 -1.94986244e-03 -1.12417073e-03
   -5.67322612e-04]
  [-7.00433497e-06  1.50377404e-06 -8.75788445e-04 -4.80310024e-04
   -4.94830494e-04]]] 


<class 'numpy.ndarray'> (2, 6, 100) [[[ 3.50786370e-05  2.96636450e-05 -7.31043140e-05 