In [None]:
from typing import List, Dict, Set, Any, Optional, Tuple, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray
import aeon

# from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
# from features.sig import SigTransform, LogSigTransform
# from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures, RandomNoInformation
# from features.sig_neural import RandomizedSignature
from utils import print_name, print_shape

jax.config.update('jax_platform_name', 'cpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options

In [None]:
from aeon.transformations.collection.convolution_based import Rocket, MultiRocket


class RocketTransform:
    def __init__(self, prng_key = jax.random.PRNGKey(999)):
        self.prng_key = prng_key

    def rocket_features(
            self, 
            X: Float[Array, "batch  time  dim"],
        ) -> Float[Array, "batch  n_features"]:
        pass

    def init_biases_and_ridge_weights(
            self, 
            X_train: Float[Array, "batch  time  dim"]
        ):
        """Biases are initialized to be the quantiles of the random convolutions of the training data,
        and the ridge weights are initialized via efficient ridge leave-one-out CV.
        

        Args:
            X_train (Float[Array, "batch  time  dim"]): Training data.
        """
        pass


# use equinox to write my modules
# training loop with optax

In [4]:
import unittest

import numpy as np
import jax
import jax.numpy as jnp
import iisignature

from features.sig_trp import SigVanillaTensorizedRandProj

N=2
T=100
D=3
prng_key, trp_key = jax.random.split(jax.random.PRNGKey(0))

def test_approx_against_iisignature_trunc_level(trunc_level):
    # test specfic arguments
    n_features = 100000
    max_batch = N
    concat_levels=False

    #input brownian motions
    X, Y = jnp.cumsum(jax.random.normal(prng_key, (2, N, T, D)), axis=2) / np.sqrt(T* D)

    #compare
    trp = SigVanillaTensorizedRandProj(trp_key, n_features, trunc_level, max_batch, concat_levels).fit(X)
    X_out_trp = trp.transform(X) # / np.sqrt(n_features)
    Y_out_trp = trp.transform(Y) # / np.sqrt(n_features)
    dot_trp = jnp.dot(X_out_trp, Y_out_trp.T)

    X_out_sig = iisignature.sig(np.array(X), trunc_level)
    Y_out_sig = iisignature.sig(np.array(Y), trunc_level)
    dot_sig = np.dot(X_out_sig, Y_out_sig.T)

    print(dot_trp)
    print(dot_sig)
    print(np.allclose(np.array(dot_trp), dot_sig, atol=1e-3))
    print("\n")


def test_approx_against_iisignature():
    for i in range(1, 7):
        test_approx_against_iisignature_trunc_level(i)

test_approx_against_iisignature()



[[ 0.7317003  -0.2416127 ]
 [ 0.00615408  0.60030797]]
[[ 0.72179029 -0.2486082 ]
 [ 0.0119348   0.60579233]]
False


[[-0.11314843 -0.12268077]
 [-0.05167956  0.00792035]]
[[ 0.74107915 -0.24782069]
 [ 0.09816106  0.70934334]]
False


[[-0.03255501  0.01003426]
 [ 0.01521384 -0.00229101]]
[[ 0.71679987 -0.24961516]
 [ 0.11030181  0.71682938]]
False


[[ 9.44456995e-05 -1.05956018e-03]
 [-1.41740687e-03 -2.61345864e-04]]
[[ 0.71332502 -0.24974785]
 [ 0.11000816  0.71740739]]
False


[[ 4.43172394e-05  1.57129760e-04]
 [ 2.61033928e-04 -9.76931409e-05]]
[[ 0.71297997 -0.24974343]
 [ 0.11012502  0.71744964]]
False


[[ 1.32186612e-05 -3.50973769e-06]
 [ 7.39902977e-06  4.59126003e-06]]
[[ 0.71296687 -0.24974504]
 [ 0.11011793  0.71744927]]
False


