In [5]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int

from utils import print_shape, print_name
from features.sig_trp import SigVanillaTensorizedRandProj

# Force JAX onto the CPU backend; this option also accepts other platform names (e.g. 'gpu').
jax.config.update('jax_platform_name', 'cpu')
# Compact numpy printing: 3 digits after the decimal point, and summarise ("...")
# any array whose total element count exceeds 5.
np.set_printoptions(threshold=5, precision=3)

def gen_BM(N:int, T:int, D:int, seed=0):
    """
    Simulate N independent D-dimensional random-walk (Brownian-motion-like) paths of T steps.

    Each path is the running sum, along the time axis, of i.i.d. standard normal
    increments drawn from a JAX PRNG seeded with `seed`. The increments are not
    rescaled (unit time step), and the first sample of each path is its first
    increment rather than zero.

    Args:
        N: number of paths.
        T: number of time steps per path.
        D: dimension of each path.
        seed: integer seed for jax.random.PRNGKey; same seed -> same paths.

    Returns:
        Array of shape (N, T, D) with the cumulative sums over axis 1.
    """
    increments = jax.random.normal(jax.random.PRNGKey(seed), (N, T, D))
    # Accumulate over axis 1 (time) to turn increments into path positions.
    return increments.cumsum(axis=1)

In [16]:
# Simulated input: N random-walk paths of length T in D dimensions (see gen_BM).
N = 12
T = 1000
D = 100
SEED = 0  # gen_BM's default seed, passed explicitly so the run is visibly reproducible
X = gen_BM(N, T, D, seed=SEED)

# Hyper-parameters for the tensorized random-projection signature features,
# named here so they can be tuned in one place instead of as inline literals.
N_FEATURES = 2000
TRUNC_LEVEL = 5
MAX_BATCH = 8  # presumably caps how many paths are processed per batch — TODO confirm in features.sig_trp

trp = SigVanillaTensorizedRandProj(n_features=N_FEATURES, trunc_level=TRUNC_LEVEL,
                                   max_batch=MAX_BATCH)
trp.fit(X)
feat = trp.transform(X)
# The recorded output shape was (12, 5, 2000) == (N, TRUNC_LEVEL, N_FEATURES);
# assert it so a change in the feature layout fails loudly on Restart & Run All.
assert feat.shape == (N, TRUNC_LEVEL, N_FEATURES), feat.shape
print_name(feat)

(12, 5, 2000) feat
[[[ 2.982e-01  3.793e-01 -8.266e-01 ... -6.623e-01 -3.995e-01  4.455e-01]
  [ 9.572e-02  1.584e-01  2.819e-01 ... -1.369e-02  2.446e-01 -1.110e-01]
  [-1.024e-01 -3.014e-02  1.563e-01 ... -1.497e-02 -4.037e-02 -1.880e-01]
  [-3.626e-02  6.335e-03 -4.583e-02 ...  6.939e-03 -1.852e-02 -1.751e-02]
  [ 2.251e-02 -4.565e-03 -9.680e-03 ... -2.429e-03 -3.692e-03 -6.093e-03]]

 [[-6.347e-01  1.681e+00 -2.687e-01 ... -7.269e-01  3.379e-02  2.455e-01]
  [-1.822e-01 -1.003e+00 -5.391e-01 ... -4.616e-01 -2.255e-01  4.308e-02]
  [-4.746e-02  1.852e-01 -6.477e-04 ... -3.067e-01  3.664e-02  1.611e-03]
  [ 1.848e-03 -3.166e-02  3.662e-03 ... -7.386e-02  1.676e-02  5.587e-03]
  [ 3.281e-03 -9.992e-03 -1.658e-03 ...  6.911e-03 -2.730e-03  2.302e-03]]

 [[-4.780e-01  3.697e-01  2.827e-02 ... -1.631e-01  8.570e-01  4.148e-01]
  [-2.142e-01 -5.473e-02  2.220e-01 ...  3.771e-02  8.302e-02 -3.224e-01]
  [-2.613e-02  6.401e-02 -4.391e-02 ... -1.500e-02 -1.946e-02  4.867e-02]
  [ 4.711e-03 -