In [3]:
import warnings

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from aeon.datasets import load_italy_power_demand
from aeon.registry import all_estimators
from aeon.transformations.collection.shapelet_based import RandomShapeletTransform

# Keep the rendered notebook free of library warning noise.
warnings.filterwarnings("ignore")

# Show every registered classifier whose "algorithm_type" tag is "shapelet".
shapelet_tag_filter = {"algorithm_type": "shapelet"}
all_estimators("classifier", filter_tags=shapelet_tag_filter)

[('MrSQMClassifier',
  aeon.classification.shapelet_based._mrsqm.MrSQMClassifier),
 ('RDSTClassifier', aeon.classification.shapelet_based._rdst.RDSTClassifier),
 ('ShapeletTransformClassifier',
  aeon.classification.shapelet_based._stc.ShapeletTransformClassifier)]

In [4]:
# Load the training split of the Italy power demand dataset as a 2D NumPy array.
X, y = load_italy_power_demand(split="train", return_type="numpy2d")

In [12]:
# Display the training array; NumPy abbreviates the repr of large arrays with "...".
X

array([[-0.71051757, -1.1833204 , -1.3724416 , ...,  0.58181015,
         0.1720477 , -0.26923494],
       [-0.99300935, -1.4267865 , -1.5798843 , ...,  0.69106647,
        -0.04890624, -0.38061813],
       [ 1.3190669 ,  0.56977448,  0.19512825, ...,  2.3493441 ,
         2.2556825 ,  1.6000516 ],
       ...,
       [-1.159152  , -1.3014    , -1.5249326 , ...,  0.46653962,
        -0.12277359, -0.65112336],
       [-0.6949193 , -1.2358295 , -1.4161329 , ...,  0.65735611,
         0.20659763, -0.19908499],
       [ 0.98403309,  0.16966088, -0.51942331, ...,  2.4248455 ,
         1.9236934 ,  1.1719652 ]])

In [13]:
# Dimensions of the 2D training array (presumably cases x time points — confirm).
X.shape

(67, 24)

In [14]:
# Fit the random shapelet transform on the training data and map every training
# series to its shapelet-distance features.
rst = RandomShapeletTransform(n_shapelet_samples=100, max_shapelets=10, random_state=42)
st = rst.fit_transform(X, y)
print(" Shape of transformed data = ", st.shape)
print(" Distance of second series to third shapelet = ", st[1][2])

# Transform the test split with the already-fitted shapelets.
# NOTE(review): the training split was loaded with return_type "numpy2d" while this
# call relies on the loader's default return type — confirm the transformer
# handles both layouts identically.
testX, testy = load_italy_power_demand(split="test")
tr_test = rst.transform(testX)

# Train a random forest on the transformed training data and score it on the
# transformed test data.
rf = RandomForestClassifier(random_state=10)
rf.fit(st, y)
preds = rf.predict(tr_test)
# accuracy_score expects (y_true, y_pred): ground-truth labels first, predictions second.
print(" Shapelets + random forest acc = ", accuracy_score(testy, preds))

In [15]:
# Dimensions of the training label array returned alongside X.
y.shape

(67,)

In [16]:
# Inspect the first shapelet stored on the fitted transform. Each entry is indexed
# positionally; the labels below name the field printed for each position.
running_shapelet = rst.shapelets[0]
field_labels = (
    "Quality = ",
    "Length = ",
    "position = ",
    "Channel = ",
    "Origin Instance Index = ",
    "Class label = ",
    "Shapelet = ",
)
for field_index, field_label in enumerate(field_labels):
    print(field_label, running_shapelet[field_index])

In [17]:
# NOTE(review): looks like a leftover scratch/sanity-check cell — confirm whether it is still needed.
print("Hello")