In [None]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sktime.datasets import load_UCR_UEA_dataset
from sktime.classification.shapelet_based import ShapeletTransformClassifier

from sklearn.metrics import classification_report, precision_recall_fscore_support 

import warnings
# Install a blanket 'ignore' filter: no category, message or module restriction
# is given, so every warning raised after this line is silenced.
# NOTE(review): this also hides deprecation notices from the libraries imported
# above — consider narrowing the filter once the noisy warning is identified.
warnings.filterwarnings('ignore')

In [None]:
# Fetch the 'train' and 'test' splits of the "Strawberry" dataset with one
# call per split; both requests share the same return type and extract path.
(X_train, y_train), (X_test, y_test) = (
    load_UCR_UEA_dataset(
        name="Strawberry",
        split=part,
        return_type='numpy3D',
        extract_path="./data",
    )
    for part in ('train', 'test')
)

In [None]:
# Sanity check: feature-array shape followed by label-array shape, per split.
print(f"{X_train.shape} {y_train.shape}")
print(f"{X_test.shape} {y_test.shape}")

In [None]:
def find_instance_with_label(X, y, label):
    """Return the first instance of ``X`` whose label equals ``label``.

    Parameters
    ----------
    X : sequence or ndarray
        Collection of instances, indexed along its first axis.
    y : array-like
        Labels aligned with ``X``. Converted with ``np.asarray`` so that plain
        lists and tuples are compared element-wise rather than as one object.
    label : object
        Label value to look for; matched with ``==``.

    Returns
    -------
    object
        ``X[i]`` for the smallest index ``i`` at which ``y`` equals ``label``.

    Raises
    ------
    ValueError
        If no entry of ``y`` equals ``label``.
    """
    # atleast_1d guards the case where the comparison collapses to a single
    # boolean (e.g. incomparable dtypes), which np.where cannot index.
    mask = np.atleast_1d(np.asarray(y) == label)
    indices = np.where(mask)[0]
    if indices.size == 0:
        raise ValueError(f"No instance found with label {label}")

    return X[indices[0]]

# Take the first training series found for each of the two label strings.
instance_label_1 = find_instance_with_label(X_train, y_train, label='1')
instance_label_2 = find_instance_with_label(X_train, y_train, label='2')

# Two stacked panels that share the x axis, one per label.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 8))

# The instances are transposed before plotting so that each column is drawn
# as one line (presumably one line per channel — confirm the instance shape).
lines1 = ax1.plot(instance_label_1.T)
ax1.set(ylabel="Value", title="Instance with Label 1")

lines2 = ax2.plot(instance_label_2.T)
ax2.set(xlabel="Timestep", ylabel="Value", title="Instance with Label 2")

fig.tight_layout()
plt.show()

## Visualizing the shapelet transform

In [None]:
from pyts.transformation import ShapeletTransform

In [None]:
# NOTE(review): `st` and `X_train_squeezed` are not created in any cell shown
# here. Presumably `st` is a fitted pyts ShapeletTransform and
# `X_train_squeezed` is a 2-D (sample, time) view of X_train — confirm both
# exist before this cell when running on a fresh kernel.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))
fig.suptitle('The four most discriminative shapelets', fontsize=16)

# Each of the first four rows of st.indices_ unpacks into
# (sample index, slice start, slice end).
for i, (idx, start, end) in enumerate(st.indices_[:4]):
    row, col = divmod(i, 2)  # fill the 2x2 grid row by row
    panel = axs[row, col]
    colour = f'C{i}'

    # Whole series the shapelet is cut from
    panel.plot(X_train_squeezed[idx], color=colour, label='Full series')

    # Thick overlay of the start:end slice, i.e. the shapelet itself
    span = np.arange(start, end)
    panel.plot(span, X_train_squeezed[idx, start:end],
               lw=5, color=colour, label='Shapelet')

    panel.set(title=f'Sample {idx}', xlabel='Time', ylabel='Value')
    panel.legend(loc='best')

fig.tight_layout()
plt.show()

## Shapelet classifier

In [None]:
%%time



# Instantiate the shapelet-based classifier imported at the top of the
# notebook; without this the cell fails with a NameError on a fresh kernel.
# A fixed seed keeps the reported scores reproducible between runs.
clf = ShapeletTransformClassifier(random_state=42)

# Fit on the training split, then label the held-out test split.
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Text summary of the test-set predictions against the true labels.
clf_report = classification_report(y_test, y_pred)

print(clf_report)