In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int

from utils import print_shape, print_name
from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.random_fourier_features import RandomFourierFeatures

jax.config.update('jax_platform_name', 'cpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options

def gen_BM(N:int, T:int, D:int, seed=0):
    """
    Generate N Brownian Motions of length T with D dimensions
    """
    key = jax.random.PRNGKey(seed)
    normal = jax.random.normal(key, (N, T, D))
    BM = jnp.cumsum(normal, axis=1)
    return BM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
N=10
D=100
rff_dim = 1020
seed = jax.random.PRNGKey(9999)
X = jax.random.normal(seed, (N, D))

rff = RandomFourierFeatures(seed, n_features=rff_dim, sigma=1.0, max_batch=2000)
rff.fit(X)
feat = rff.transform(X)
print_name(feat) #1.7

(10, 1020) feat
[[-0.014  0.022  0.002 ... -0.019 -0.003 -0.031]
 [-0.016  0.031 -0.01  ... -0.018  0.012  0.014]
 [-0.03   0.024  0.    ... -0.03   0.03   0.029]
 ...
 [ 0.026 -0.018 -0.025 ... -0.031 -0.031 -0.028]
 [-0.018  0.021 -0.005 ...  0.03  -0.018 -0.014]
 [ 0.014  0.006  0.03  ...  0.004  0.031  0.028]] 



In [13]:
N=200
T=1000
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)

max_batch=10
trunc_level=5
n_features=500

linear_trp = SigVanillaTensorizedRandProj(seed, n_features, trunc_level, max_batch)
linear_trp.fit(X)
feat = linear_trp.transform(X)
print_name(feat) #0.3

(200, 5, 500) feat
[[[ 1.104e+00 -4.391e-01  1.431e+00 ...  1.606e+00  1.707e+00  1.224e-01]
  [-5.010e-01 -1.444e-02 -1.178e+00 ... -1.780e+00  3.747e-01 -3.762e-01]
  [-3.445e-01  1.319e-01 -8.764e-01 ...  3.357e-01  4.559e-01  5.116e-02]
  [-3.039e-01  2.454e-04  4.710e-01 ... -1.177e-01 -2.809e-02 -3.352e-01]
  [ 6.265e-02 -8.943e-02  3.847e-01 ...  4.774e-03  7.272e-02 -1.130e-01]]

 [[ 2.321e+00 -1.310e+00  1.892e+00 ...  1.300e-01  1.077e+00  1.461e+00]
  [ 1.359e+00 -1.087e+00 -8.916e-01 ... -1.263e+00 -1.807e+00 -1.477e+00]
  [-3.302e-01 -1.747e-01 -5.481e-01 ...  1.872e+00  2.270e-01  1.574e+00]
  [-1.272e-01 -1.195e-01  1.128e-01 ...  3.480e-01 -1.876e-02 -4.041e-01]
  [-8.540e-03 -1.342e-01  3.079e-02 ... -4.087e-01 -1.072e-01 -2.339e-02]]

 [[-2.288e+00 -1.357e+00  1.209e+00 ...  1.500e+00 -2.696e+00 -2.246e+00]
  [ 9.931e-01 -1.233e+00  1.635e+00 ...  9.532e-01  7.869e-02 -6.842e-01]
  [-7.038e-01  1.606e+00 -9.057e-01 ...  1.201e-01 -2.180e+00 -1.526e-02]
  [ 3.296e-01 -

In [4]:
N = 100
max_batch = 70
divmod(N, max_batch)

(1, 30)

In [5]:
N=100
max_batch = 22


T=100
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)
trp_seed, rff_seed = jax.random.split(seed, 2)
n_features = 2000
trunc_level = 3
rbf_dimension = 512
sigma = 1.0
rff_max_batch = 2000

rbf_trp = SigRBFTensorizedRandProj(trp_seed, rff_seed, 
                n_features, trunc_level, rbf_dimension, 
                sigma, max_batch, rff_max_batch)
rbf_trp.fit(X)

print("now\n\n\n")
feat = rbf_trp.transform(X)
print_name(feat) #0.9

now



(100, 3, 2000) feat
[[[-7.013e-04 -1.980e-04 -4.662e-04 ...  4.850e-04  1.275e-04  1.775e-04]
  [-6.172e-06 -4.966e-06  1.112e-06 ... -1.581e-06  3.294e-07 -7.030e-06]
  [-3.680e-09 -1.131e-08 -1.908e-09 ...  2.765e-08  4.092e-09  1.072e-08]]

 [[ 1.500e-03  5.899e-04  4.793e-04 ...  2.591e-04 -4.839e-04 -6.523e-04]
  [-1.861e-05  7.011e-06 -5.312e-06 ...  1.443e-05  2.312e-07  2.029e-06]
  [-3.934e-09  5.460e-09 -4.960e-09 ... -2.679e-08 -1.100e-09 -9.370e-09]]

 [[-3.583e-06  6.109e-04  8.921e-04 ...  1.699e-03 -6.375e-04  1.091e-03]
  [-1.058e-05  9.343e-06  9.211e-06 ...  4.048e-07  4.882e-07  2.742e-06]
  [ 1.221e-08  1.183e-08 -9.901e-09 ... -2.004e-08 -3.932e-09  6.120e-09]]

 ...

 [[-8.697e-04 -7.082e-04 -7.229e-04 ...  4.425e-04  1.014e-03  7.373e-04]
  [-1.037e-05  3.799e-06  3.638e-07 ... -2.578e-07  3.433e-07 -3.549e-06]
  [-4.492e-09 -4.415e-09  1.196e-09 ...  7.695e-09 -8.808e-10  3.526e-09]]

 [[-1.250e-04 -3.817e-04 -1.490e-03 ... -7.405e-04 -1.020e-03  7.201e-0