In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int

from utils import print_shape, print_name
from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.random_fourier_features import RandomFourierFeatures

jax.config.update('jax_platform_name', 'gpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options

def gen_BM(N:int, T:int, D:int, seed=0):
    """
    Generate N Brownian Motions of length T with D dimensions
    """
    key = jax.random.PRNGKey(seed)
    normal = jax.random.normal(key, (N, T, D))
    BM = jnp.cumsum(normal, axis=1)
    return BM

In [None]:
N=10
D=100
rff_dim = 1000
seed = jax.random.PRNGKey(9999)
X = jax.random.normal(seed, (N, D))

rff = RandomFourierFeatures(seed, n_features=rff_dim, sigma=1.0, max_batch=2000)
rff.fit(X)
feat = rff.transform(X)
print_name(feat)

In [2]:
N=20
T=100
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)

linear_trp = SigVanillaTensorizedRandProj(seed, n_features=2000, trunc_level=5,
                                   max_batch=10)
linear_trp.fit(X)
feat = linear_trp.transform(X)
print_name(feat)

(20, 5, 2000) feat
[[[-5.194e-02 -2.418e-01  1.791e-01 ... -2.623e-01 -9.626e-02 -2.038e-01]
  [ 5.131e-03 -6.003e-02 -2.472e-02 ... -3.565e-02  1.808e-02 -8.411e-03]
  [ 2.400e-04 -1.489e-03  8.551e-04 ...  1.097e-03  3.647e-04 -1.172e-03]
  [-1.290e-05 -1.326e-04 -2.311e-04 ...  1.547e-04 -1.081e-04  4.824e-05]
  [ 6.825e-06 -1.684e-05 -1.880e-06 ...  5.406e-06 -2.249e-05  3.847e-06]]

 [[ 2.753e-01  5.542e-02 -2.051e-01 ... -6.964e-02  3.774e-02 -3.913e-01]
  [ 5.316e-02 -7.530e-03  9.990e-03 ...  2.330e-02 -2.055e-02  1.016e-02]
  [-5.522e-03 -1.577e-03  1.632e-04 ... -9.823e-04  4.257e-03  9.048e-03]
  [-2.558e-05 -1.106e-04 -1.003e-04 ...  1.471e-04  3.109e-05  8.953e-04]
  [-1.180e-05  1.627e-05 -8.178e-07 ... -2.812e-05  4.647e-05 -5.969e-05]]

 [[-6.139e-02 -3.341e-02 -1.820e-01 ...  7.483e-02  1.969e-01 -2.451e-01]
  [-6.863e-03  2.181e-03  4.595e-03 ...  1.232e-03  3.001e-02  1.862e-02]
  [ 7.598e-05 -1.926e-04  3.049e-04 ... -1.042e-03 -1.336e-03  9.122e-05]
  [-1.104e-05 -

In [7]:
N=20
T=100
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)
trp_seed, rff_seed = jax.random.split(seed, 2)
n_features = 200
trunc_level = 3
rbf_dimension = 512
sigma = 1.0
max_batch = 128
rff_max_batch = 2000

rbf_trp = SigRBFTensorizedRandProj(trp_seed, rff_seed, 
                n_features, trunc_level, rbf_dimension, 
                sigma, max_batch, rff_max_batch)
rbf_trp.fit(X)
feat = rbf_trp.transform(X)
print_name(feat)

(20, 3, 512) feat
[[[-2.449e-02 -3.868e-02 -1.076e-02 ... -5.431e-03  1.162e-02 -3.877e-03]
  [ 2.404e-04  4.441e-04  3.636e-04 ... -3.433e-04 -3.455e-04  7.852e-04]
  [ 7.113e-07 -2.158e-05 -3.006e-05 ... -1.651e-06 -8.906e-06  5.273e-05]]

 [[-1.153e-02  1.172e-02  4.146e-02 ... -6.604e-03 -1.676e-02  1.323e-02]
  [-7.761e-04 -6.656e-05  4.034e-05 ... -4.193e-04 -7.240e-05 -8.330e-04]
  [ 1.489e-05  7.772e-06 -8.273e-06 ...  2.521e-06 -1.529e-05  3.347e-05]]

 [[-7.163e-03  2.229e-02 -2.570e-02 ... -6.341e-03  2.897e-02  2.676e-02]
  [ 2.963e-04 -4.584e-04 -7.687e-04 ...  9.799e-04  8.406e-04  6.755e-04]
  [-8.818e-06 -5.723e-06 -5.069e-06 ...  5.809e-06  1.341e-05 -6.862e-07]]

 ...

 [[ 2.772e-02 -1.039e-02 -2.869e-02 ... -7.390e-02 -7.315e-03 -2.296e-02]
  [ 3.872e-04 -1.823e-05  5.722e-04 ... -2.004e-03 -1.039e-03 -7.782e-04]
  [ 2.375e-05 -2.024e-06 -4.394e-05 ...  5.154e-05 -1.027e-05  9.133e-06]]

 [[ 3.097e-02  4.805e-03 -6.746e-03 ...  2.390e-02 -3.228e-03  1.563e-03]
  [-2.