In [1]:
from typing import List, Dict, Set, Any, Optional, Tuple, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int
import aeon

from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.sig import SigTransform, LogSigTransform
from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures
from features.sig_neural import RandomizedSignature
from utils import print_name, print_shape

# Select the JAX backend for this notebook; swap 'cpu' for e.g. 'gpu' to change platform.
jax.config.update('jax_platform_name', 'cpu')
# Compact NumPy display: 3 decimals, and summarize any array with more than 5 elements.
np.set_printoptions(threshold=5, precision=3)

  from .autonotebook import tqdm as notebook_tqdm


# aeon toolkit

In [2]:
# List the UCR/UEA classification datasets registered in aeon
from aeon.datasets.tsc_datasets import multivariate, univariate, univariate_equal_length
from aeon.datasets import load_classification

def get_aeon_dataset(
        dataset_name: str,
        ) -> Tuple[Any, Any, Any, Any]:
    """Load a UCR/UEA classification dataset through aeon in time-major layout.

    aeon's ``load_classification`` returns series channel-major: a
    ``(n_cases, n_channels, n_timepoints)`` array for equal-length datasets,
    or a list of ``(n_channels, n_timepoints)`` arrays when the series have
    unequal lengths. This helper swaps the last two axes so every series is
    ``(n_timepoints, n_channels)``.

    Args:
        dataset_name (str): Name of the dataset (e.g. ``"Adiac"``).

    Returns:
        Tuple: 4-tuple ``(X_train, y_train, X_test, y_test)``. Each ``X_*`` is a
        ``(n_cases, n_timepoints, n_channels)`` array for equal-length data, or
        a list of ``(n_timepoints, n_channels)`` arrays otherwise. ``y_*`` are
        the labels exactly as returned by aeon.
    """
    def _to_time_major(X):
        # Unequal-length datasets come back as a list of 2D arrays, which has no
        # `.transpose(0, 2, 1)`; transpose each series individually instead.
        if isinstance(X, list):
            return [np.asarray(series).T for series in X]
        return X.transpose(0, 2, 1)

    X_train, y_train = load_classification(dataset_name, split="train")
    X_test, y_test = load_classification(dataset_name, split="test")

    return _to_time_major(X_train), y_train, _to_time_major(X_test), y_test

# Display the univariate dataset names listed in `aeon.datasets.tsc_datasets`.
univariate

{'ACSF1',
 'Adiac',
 'AllGestureWiimoteX',
 'AllGestureWiimoteY',
 'AllGestureWiimoteZ',
 'ArrowHead',
 'BME',
 'Beef',
 'BeetleFly',
 'BirdChicken',
 'CBF',
 'Car',
 'Chinatown',
 'ChlorineConcentration',
 'CinCECGTorso',
 'Coffee',
 'Computers',
 'CricketX',
 'CricketY',
 'CricketZ',
 'Crop',
 'DiatomSizeReduction',
 'DistalPhalanxOutlineAgeGroup',
 'DistalPhalanxOutlineCorrect',
 'DistalPhalanxTW',
 'DodgerLoopDay',
 'DodgerLoopGame',
 'DodgerLoopWeekend',
 'ECG200',
 'ECG5000',
 'ECGFiveDays',
 'EOGHorizontalSignal',
 'EOGVerticalSignal',
 'Earthquakes',
 'ElectricDevices',
 'EthanolLevel',
 'FaceAll',
 'FaceFour',
 'FacesUCR',
 'FiftyWords',
 'Fish',
 'FordA',
 'FordB',
 'FreezerRegularTrain',
 'FreezerSmallTrain',
 'Fungi',
 'GestureMidAirD1',
 'GestureMidAirD2',
 'GestureMidAirD3',
 'GesturePebbleZ1',
 'GesturePebbleZ2',
 'GunPoint',
 'GunPointAgeSpan',
 'GunPointMaleVersusFemale',
 'GunPointOldVersusYoung',
 'Ham',
 'HandOutlines',
 'Haptics',
 'Herring',
 'HouseTwenty',
 'Inli

In [None]:
# Quick sanity check of the loader on a small univariate dataset.
X_train, y_train, X_test, y_test = get_aeon_dataset(dataset_name="ArrowHead")
print_name(X_train)

# Obtain MiniRocket results for univariate time series

# Start here

In [3]:
#from aeon.transformations.collection.convolution_based import MiniRocketMultivariate
from preprocessing.timeseries_augmentation import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero
from aeon.classification.sklearn import RotationForestClassifier
#from sklearn.linear_model import RidgeCV
from sklearn.metrics import accuracy_score
import time

def train_and_test(
        dataset:str,
        transformer:TimeseriesFeatureTransformer,
        apply_augmentation:bool=True,
        max_T:int=1000,
        random_state:Optional[int]=None,
    ) -> float:
    """Featurize a UCR/UEA dataset with `transformer` and evaluate a RotationForest on it.

    Steps:
        1. Load the train/test split with `get_aeon_dataset` and pass both
           splits through `normalize_streams`.
        2. Optionally apply `add_basepoint_zero` and then `augment_time` to both splits.
        3. Fit `transformer` on the training streams only, transform both
           splits, and rescale the features with `normalize_mean_std_traindata`.
        4. Fit a `RotationForestClassifier` on the training features and
           score its predictions on the test features.
    Intermediate arrays (via `print_name`) and wall-clock time per stage are printed.

    Args:
        dataset (str): Name of the UCR/UEA dataset to load via aeon.
        transformer (TimeseriesFeatureTransformer): Feature map exposing
            `fit(X)` and `transform(X)`.
        apply_augmentation (bool): Whether to apply basepoint and time
            augmentation before featurizing. Defaults to True.
        max_T (int): Forwarded to `normalize_streams` as `max_T`. Defaults to
            1000, the value previously hard-coded here.
        random_state (Optional[int]): Seed for the RotationForestClassifier.
            Defaults to None (unseeded, as before); pass an int to make the
            reported accuracy reproducible across runs.

    Returns:
        float: Test-set accuracy (also printed).
    """
    train_X, train_y, test_X, test_y = get_aeon_dataset(dataset)
    train_X, test_X = normalize_streams(train_X, test_X, max_T=max_T)
    print_name(train_X)
    print_name(test_X)

    # Convert to JAX arrays for the JAX-based transformers. (The previous
    # lax.stop_gradient was a value no-op here: nothing is being differentiated.)
    train_X = jnp.asarray(train_X)
    test_X = jnp.asarray(test_X)
    if apply_augmentation:
        train_X = augment_time(add_basepoint_zero(train_X))
        test_X = augment_time(add_basepoint_zero(test_X))

    # Fit the transformer on train only, then featurize both splits.
    # np.array(...) materializes the result, so JAX's asynchronous dispatch
    # cannot leak work past the timing boundary.
    transform_start = time.perf_counter()
    transformer.fit(train_X)
    train_X = np.array(transformer.transform(train_X))
    test_X = np.array(transformer.transform(test_X))
    train_X, test_X = normalize_mean_std_traindata(train_X, test_X)
    transform_secs = time.perf_counter() - transform_start
    print_name(train_X)
    print_name(test_X)
    print(f"Time to transform: {transform_secs} seconds")

    # Train the classifier. The timer starts here so the printing above is not
    # counted as fit time.
    fit_start = time.perf_counter()
    clf = RotationForestClassifier(random_state=random_state)
    clf.fit(train_X, train_y)
    fit_secs = time.perf_counter() - fit_start
    print(f"Time to fit classifier on train: {fit_secs} seconds")

    # Predict and score.
    predict_start = time.perf_counter()
    pred = clf.predict(test_X)
    acc = accuracy_score(test_y, pred)
    predict_secs = time.perf_counter() - predict_start
    print(f"Time to predict: {predict_secs} seconds")
    print(f"{acc} accuracy for {transformer}")
    return acc

In [None]:
# TODO: implement RidgeCV (cross-validated ridge regression) in JAX

In [None]:
# Recorded result:
# 0.6086956521739131 accuracy for SigVanillaTensorizedRandProj(max_batch=10000, n_features=256,
#                              seed=Array([  0, 999], dtype=uint32),
#                              trunc_level=5)
train_and_test(
    "Adiac",
    SigVanillaTensorizedRandProj(
        jax.random.PRNGKey(999),
        max_batch=10000,
        n_features=256,
        trunc_level=5,
    ),
)

In [None]:
# Split one seed into separate keys for the projection and the RBF features.
trp_key, rbf_key = jax.random.split(jax.random.PRNGKey(999))

# Recorded result:
# 0.6445012787723785 accuracy for SigRBFTensorizedRandProj(max_batch=10000, n_features=256, rbf_dimension=1000,
#                          rff_max_batch=10000,
#                          rff_seed=Array([4116651765, 1982142802], dtype=uint32),
#                          trp_seed=Array([3655788082, 2541180754], dtype=uint32))
train_and_test(
    "Adiac",
    SigRBFTensorizedRandProj(
        trp_key,
        rbf_key,
        max_batch=10000,
        n_features=256,
        rbf_dimension=1000,
        rff_max_batch=10000,
        trunc_level=3,
    ),
)


In [None]:
# Recorded result: 0.7902813299232737 accuracy for TabularTimeseriesFeatures()
train_and_test(
    "Adiac",
    TabularTimeseriesFeatures(),
    apply_augmentation=False,  # skip basepoint/time augmentation for this baseline
)

In [None]:
# Recorded result: 0.5498721227621484 accuracy for SigTransform(trunc_level=5)
train_and_test("Adiac", SigTransform(trunc_level=5))

In [None]:
# Recorded result: 0.4884910485933504 accuracy for LogSigTransform(trunc_level=5)
train_and_test("Adiac", LogSigTransform(trunc_level=5))


In [4]:
# Randomized-signature features with a fixed PRNG seed.
train_and_test(
    "Adiac",
    RandomizedSignature(
        jax.random.PRNGKey(999),
        max_batch=10000,
        n_features=256,
    ),
)

(390, 176, 1) train_X
[[[-1.842e-01]
  [-5.851e-02]
  [-1.182e-03]
  ...
  [-4.549e-02]
  [-2.463e-01]
  [-2.213e-01]]

 [[ 3.760e-01]
  [ 3.148e-01]
  [ 2.498e-01]
  ...
  [-3.418e-01]
  [-4.698e-02]
  [ 2.311e-01]]

 [[ 4.911e-01]
  [ 4.442e-01]
  [ 4.474e-01]
  ...
  [ 4.725e-01]
  [ 5.350e-01]
  [ 4.552e-01]]

 ...

 [[ 5.000e+00]
  [ 5.000e+00]
  [ 4.429e+00]
  ...
  [ 3.822e+00]
  [ 3.266e+00]
  [ 2.850e+00]]

 [[ 1.520e-01]
  [ 1.759e-01]
  [ 1.195e-01]
  ...
  [ 5.933e-02]
  [ 1.694e-01]
  [ 1.726e-01]]

 [[ 5.144e-01]
  [ 6.201e-01]
  [ 7.316e-01]
  ...
  [ 1.694e-01]
  [ 1.910e-01]
  [ 3.986e-01]]] 

(391, 176, 1) test_X
[[[-1.396e+00]
  [-1.684e+00]
  [-1.902e+00]
  ...
  [ 2.630e-01]
  [-4.453e-01]
  [-1.046e+00]]

 [[-1.585e+00]
  [-1.967e+00]
  [-2.277e+00]
  ...
  [ 5.753e-01]
  [-3.895e-01]
  [-1.108e+00]]

 [[ 2.296e-01]
  [ 1.743e-01]
  [-1.807e-04]
  ...
  [-7.008e-02]
  [ 1.542e-01]
  [ 2.517e-01]]

 ...

 [[ 1.090e-01]
  [ 4.519e-01]
  [ 6.738e-01]
  ...
  [-4.784e