In [1]:
from typing import List, Dict, Set, Any, Optional, Tuple, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray
import aeon

from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.sig import SigTransform, LogSigTransform
from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures, RandomNoInformation
from features.sig_neural import RandomizedSignature
from utils import print_name, print_shape

# Runtime configuration.
# JAX: default backend used for array placement ('cpu', 'gpu', 'tpu', ...).
# NumPy: print 3 decimals and summarise arrays with more than 5 elements.
JAX_PLATFORM = 'gpu'
jax.config.update('jax_platform_name', JAX_PLATFORM)
np.set_printoptions(threshold=5, precision=3)

  from .autonotebook import tqdm as notebook_tqdm


# aeon toolkit

In [2]:
# Print the different datasets
from aeon.datasets.tsc_datasets import multivariate, univariate, univariate_equal_length
from aeon.datasets import load_classification

def get_aeon_dataset(dataset_name: str):
    """Fetch both splits of a UCR/UEA classification problem through aeon.

    aeon's ``load_classification`` returns equal-length collections as
    ``(n_cases, n_channels, n_timepoints)`` arrays; the last two axes are
    swapped here so every case becomes a ``(n_timepoints, n_channels)``
    stream (channels-last).

    Args:
        dataset_name (str): Archive name, e.g. "GunPoint" or "Adiac".

    Returns:
        Tuple: ``(X_train, y_train, X_test, y_test)`` with channels-last X.
    """
    # Load train first, then test; both loads finish before any reshaping.
    (X_train, y_train), (X_test, y_test) = (
        load_classification(dataset_name, split=part) for part in ("train", "test")
    )

    # (cases, channels, time) -> (cases, time, channels)
    return (
        X_train.transpose(0, 2, 1), y_train,
        X_test.transpose(0, 2, 1), y_test,
    )

#univariate

In [3]:
#from aeon.transformations.collection.convolution_based import MiniRocketMultivariate
from preprocessing.timeseries_augmentation import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero
from aeon.classification.sklearn import RotationForestClassifier
#from sklearn.linear_model import RidgeCV
from sklearn.metrics import accuracy_score
import time

def train_and_test(
        dataset:str,
        transformer:TimeseriesFeatureTransformer,
        apply_augmentation:bool=True,
        random_state:Optional[int]=None,
    ) -> float:
    """Extract features with `transformer` and score a RotationForest on them.

    Steps:
    1. Load `dataset` with `get_aeon_dataset`.
    2. Run `normalize_streams` (max_T=1000).
    3. Optionally apply `add_basepoint_zero` then `augment_time` to both splits.
    4. Fit `transformer` on the training streams and transform both splits.
    5. Standardise the features with `normalize_mean_std_traindata`.
    6. Fit a RotationForestClassifier on train and report accuracy on test.

    Shapes and timings are printed along the way.

    Args:
        dataset (str): UCR/UEA dataset name understood by aeon.
        transformer (TimeseriesFeatureTransformer): Feature extractor. Its
            `fit` only sees the training split.
        apply_augmentation (bool): Apply `add_basepoint_zero` then
            `augment_time` to both splits before feature extraction.
        random_state (Optional[int]): Seed forwarded to
            RotationForestClassifier. The default (None) keeps the previous
            unseeded behaviour. Pass an int for reproducible accuracies.

    Returns:
        float: Test-set accuracy (also printed). Returning it lets callers
        collect results instead of copying printed numbers into comments.
    """
    train_X, train_y, test_X, test_y = get_aeon_dataset(dataset)
    train_X, test_X = normalize_streams(train_X, test_X, max_T=1000)
    print_name(train_X)
    print_name(test_X)

    # Convert the streams to JAX arrays.
    train_X = lax.stop_gradient(jnp.array(train_X))
    test_X  = lax.stop_gradient(jnp.array(test_X))
    if apply_augmentation:
        train_X = augment_time(add_basepoint_zero(train_X))
        test_X  = augment_time(add_basepoint_zero(test_X))

    # Fit the transformer and extract features. Separate names keep the raw
    # streams distinguishable from the feature matrices in print_name output.
    t0 = time.perf_counter()
    transformer.fit(train_X)
    # np.array() copies the JAX result to host memory, which waits for JAX's
    # asynchronous dispatch to finish, so t1 - t0 covers the actual work.
    feat_train_X = np.array(transformer.transform(train_X))
    feat_test_X = np.array(transformer.transform(test_X))
    feat_train_X, feat_test_X = normalize_mean_std_traindata(feat_train_X, feat_test_X)
    t1 = time.perf_counter()
    print_name(feat_train_X)
    print_name(feat_test_X)
    print(f"Time to transform: {t1-t0} seconds")

    # Train the classifier.
    clf = RotationForestClassifier(random_state=random_state)
    clf.fit(feat_train_X, train_y)
    t2 = time.perf_counter()
    print(f"Time to fit classifier on train: {t2-t1} seconds")

    # Predict and score.
    pred = clf.predict(feat_test_X)
    acc = accuracy_score(test_y, pred)
    t3 = time.perf_counter()
    print(f"Time to predict: {t3-t2} seconds")
    print(f"{acc} accuracy for {transformer}")
    return acc

In [None]:
# Tensorized random projection of the truncated signature (level 5, 128 features).
vanilla_trp = SigVanillaTensorizedRandProj(
    jax.random.PRNGKey(999),
    n_features=128,
    trunc_level=5,
    max_batch=2000,
)
train_and_test(dataset="Adiac", transformer=vanilla_trp)
# NOTE: the recorded result below came from a different configuration
# (n_features=256, max_batch=10000), not the one above. Re-run to refresh it.
# 0.6086956521739131 accuracy for SigVanillaTensorizedRandProj(max_batch=10000, n_features=256,
#                              seed=Array([  0, 999], dtype=uint32),
#                              trunc_level=5)

In [None]:
# One root key, split into independent keys for the projection and the RBF features.
trp_key, rbf_key = jax.random.split(jax.random.PRNGKey(999))

rbf_trp_config = dict(
    n_features=256,
    trunc_level=3,
    rbf_dimension=1000,
    max_batch=10000,
    rff_max_batch=10000,
)
train_and_test(
    dataset="Adiac",
    transformer=SigRBFTensorizedRandProj(trp_key, rbf_key, **rbf_trp_config),
)
# 0.6445012787723785 accuracy for SigRBFTensorizedRandProj(max_batch=10000, n_features=256, rbf_dimension=1000,
#                          rff_max_batch=10000,
#                          rff_seed=Array([4116651765, 1982142802], dtype=uint32),
#                          trp_seed=Array([3655788082, 2541180754], dtype=uint32))

In [None]:
# Tabular summary-feature baseline. Streams are fed without basepoint/time augmentation.
tabular_features = TabularTimeseriesFeatures()
train_and_test("Adiac", tabular_features, apply_augmentation=False)
# 0.7902813299232737 accuracy for TabularTimeseriesFeatures()

In [None]:
# Exact truncated signature at level 5.
sig_level5 = SigTransform(trunc_level=5)
train_and_test(dataset="Adiac", transformer=sig_level5)
# 0.5498721227621484 accuracy for SigTransform(trunc_level=5)

In [None]:
# Truncated log-signature at level 5.
logsig_level5 = LogSigTransform(trunc_level=5)
train_and_test(dataset="Adiac", transformer=logsig_level5)
# 0.4884910485933504 accuracy for LogSigTransform(trunc_level=5)


In [None]:
randomized_sig = RandomizedSignature(
    jax.random.PRNGKey(999),
    n_features=128,
    max_batch=10000,
)
train_and_test(dataset="Adiac", transformer=randomized_sig)
# 0.27365728900255754 accuracy for RandomizedSignature(max_batch=10000, n_features=128,
#                     seed=Array([  0, 999], dtype=uint32))

In [None]:
# Lower-bound baseline (presumably input-independent random features, judging by
# the name and the near-chance accuracy recorded below).
no_info_key = jax.random.PRNGKey(999)
train_and_test(
    dataset="Adiac",
    transformer=RandomNoInformation(no_info_key, n_features=64),
)
# 0.03324808184143223 accuracy for RandomNoInformation(n_features=64,
#                     seed=Array([1508125853,  174035561], dtype=uint32))

# Ridge classifier (RidgeClassifierCV)

In [6]:
from sklearn.linear_model import RidgeClassifierCV

def train_and_test_ridge(
        dataset:str,
        transformer:TimeseriesFeatureTransformer,
        apply_augmentation:bool=True,
        alphas:Optional[np.ndarray]=None,
    ) -> float:
    """Extract features with `transformer` and score a RidgeClassifierCV on them.

    Steps:
    1. Load `dataset` with `get_aeon_dataset`.
    2. Run `normalize_streams` (max_T=1000).
    3. Optionally apply `add_basepoint_zero` then `augment_time` to both splits.
    4. Fit `transformer` on the training streams and transform both splits.
    5. Standardise the features with `normalize_mean_std_traindata`.
    6. Fit RidgeClassifierCV (alpha chosen by its built-in cross-validation)
       on train and report accuracy on test.

    Shapes, the chosen alpha, predictions and timings are printed.

    Args:
        dataset (str): UCR/UEA dataset name understood by aeon.
        transformer (TimeseriesFeatureTransformer): Feature extractor. Its
            `fit` only sees the training split.
        apply_augmentation (bool): Apply `add_basepoint_zero` then
            `augment_time` to both splits before feature extraction.
        alphas (Optional[np.ndarray]): Regularisation grid searched by
            RidgeClassifierCV. None (default) uses 100 log-spaced values in
            [1e-3, 1e3], matching the previous hard-coded grid.

    Returns:
        float: Test-set accuracy (also printed).
    """
    if alphas is None:
        alphas = np.logspace(-3, 3, 100)

    train_X, train_y, test_X, test_y = get_aeon_dataset(dataset)
    train_X, test_X = normalize_streams(train_X, test_X, max_T=1000)
    print_shape(train_X)
    print_shape(test_X)

    # Convert the streams to JAX arrays, then optionally augment them.
    train_X = lax.stop_gradient(jnp.array(train_X))
    test_X  = lax.stop_gradient(jnp.array(test_X))
    if apply_augmentation:
        train_X = augment_time(add_basepoint_zero(train_X))
        test_X  = augment_time(add_basepoint_zero(test_X))

    # Fit the transformer and extract features. np.array() copies the JAX
    # result to host memory, which waits for asynchronous dispatch to finish.
    t0 = time.perf_counter()
    transformer.fit(train_X)
    feat_train_X = np.array(transformer.transform(train_X))
    feat_test_X = np.array(transformer.transform(test_X))
    print("Before normalization:")
    print_name(feat_test_X)
    # NOTE(review): in the saved GunPoint run the "Before normalization" and
    # "After normalization" printouts are identical. Verify that
    # normalize_mean_std_traindata really standardises 2-D feature matrices.
    feat_train_X, feat_test_X = normalize_mean_std_traindata(feat_train_X, feat_test_X)
    t1 = time.perf_counter()
    print("After normalization:")
    print_name(feat_test_X)
    print(f"Time to transform: {t1-t0} seconds")

    # Train the classifier.
    clf = RidgeClassifierCV(alphas=alphas)
    clf.fit(feat_train_X, train_y)
    t2 = time.perf_counter()
    print(f"Chosen alpha: {clf.alpha_}")
    print(f"Time to fit classifier on train: {t2-t1} seconds")

    # Predict and score.
    pred = clf.predict(feat_test_X)
    print(pred)
    acc = accuracy_score(test_y, pred)
    t3 = time.perf_counter()
    print(f"Time to predict: {t3-t2} seconds")
    print(f"{acc} accuracy for {transformer}")
    return acc

In [7]:
# 10k-dimensional random projection of the level-3 signature. The small
# max_batch value is kept exactly as in the original run.
gunpoint_vanilla_trp = SigVanillaTensorizedRandProj(
    jax.random.PRNGKey(999),
    n_features=10000,
    trunc_level=3,
    max_batch=20,
)
train_and_test_ridge(dataset="GunPoint", transformer=gunpoint_vanilla_trp)

(50, 150, 1) train_X 

(150, 150, 1) test_X 

Before normalization:
(150, 10000) feat_test_X
[[-5.305e-04 -1.540e-02 -3.204e-02 ... -4.086e-01 -1.227e-01 -4.070e-01]
 [-4.436e-02  2.216e-01 -9.703e-02 ...  5.936e-01 -1.397e-01 -1.073e+00]
 [-1.087e-02  6.769e-01 -9.241e-01 ... -3.232e-01 -1.407e+00 -6.737e+00]
 ...
 [-2.878e-03  1.318e-01  3.847e-02 ...  2.486e-01 -7.167e-02 -5.431e-01]
 [-5.253e-03  7.485e-02 -1.062e-01 ... -9.039e-01 -3.531e-01 -1.306e+00]
 [-2.154e-03 -1.980e-03 -1.250e-01 ... -7.798e-01 -4.689e-01 -8.564e-01]] 

After normalization:
(150, 10000) feat_test_X
[[-5.305e-04 -1.540e-02 -3.204e-02 ... -4.086e-01 -1.227e-01 -4.070e-01]
 [-4.436e-02  2.216e-01 -9.703e-02 ...  5.936e-01 -1.397e-01 -1.073e+00]
 [-1.087e-02  6.769e-01 -9.241e-01 ... -3.232e-01 -1.407e+00 -6.737e+00]
 ...
 [-2.878e-03  1.318e-01  3.847e-02 ...  2.486e-01 -7.167e-02 -5.431e-01]
 [-5.253e-03  7.485e-02 -1.062e-01 ... -9.039e-01 -3.531e-01 -1.306e+00]
 [-2.154e-03 -1.980e-03 -1.250e-01 ... -7.798

In [None]:
root_key = jax.random.PRNGKey(999)
trp_key, rbf_key = jax.random.split(root_key)

gunpoint_rbf_trp = SigRBFTensorizedRandProj(
    trp_key, rbf_key,
    n_features=10000, trunc_level=3, rbf_dimension=1000,
    max_batch=10, rff_max_batch=10000,
)
train_and_test_ridge(dataset="GunPoint", transformer=gunpoint_rbf_trp)

In [None]:
# Tabular baseline. Streams are fed without basepoint/time augmentation.
train_and_test_ridge("GunPoint", TabularTimeseriesFeatures(), apply_augmentation=False)

In [None]:
gunpoint_sig = SigTransform(trunc_level=5)
train_and_test_ridge(dataset="GunPoint", transformer=gunpoint_sig)

In [None]:
gunpoint_logsig = LogSigTransform(trunc_level=5)
train_and_test_ridge(dataset="GunPoint", transformer=gunpoint_logsig)

In [None]:
gunpoint_randomized_sig = RandomizedSignature(
    jax.random.PRNGKey(999),
    n_features=1000,
    max_batch=1,
)
train_and_test_ridge(dataset="GunPoint", transformer=gunpoint_randomized_sig)

In [None]:
# Lower-bound baseline (presumably input-independent random features).
# Seeded and sized like the Adiac RotationForest baseline (PRNGKey(999),
# 64 features). The previous bare RandomNoInformation() call had no
# explicit seed, so its result was not reproducible or comparable.
train_and_test_ridge(
    dataset="GunPoint",
    transformer=RandomNoInformation(
        jax.random.PRNGKey(999),
        n_features= 64,
        ),
    )