In [4]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int

from utils.utils import print_shape, print_name
from utils.gen_data import gen_BM
from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.random_fourier_features import RandomFourierFeatures

jax.config.update('jax_platform_name', 'cpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options



In [3]:
N=10
D=100
rff_dim = 1020
seed = jax.random.PRNGKey(9999)
X = jax.random.normal(seed, (N, D))

rff = RandomFourierFeatures(seed, n_features=rff_dim, sigma=1.0, max_batch=2000)
rff.fit(X)
feat = rff.transform(X)
print_name(feat)

(10, 1020) feat
[[-0.014  0.022  0.002 ... -0.019 -0.003 -0.031]
 [-0.016  0.031 -0.01  ... -0.018  0.012  0.014]
 [-0.03   0.024  0.    ... -0.03   0.03   0.029]
 ...
 [ 0.026 -0.018 -0.025 ... -0.031 -0.031 -0.028]
 [-0.018  0.021 -0.005 ...  0.03  -0.018 -0.014]
 [ 0.014  0.006  0.03  ...  0.004  0.031  0.028]] 



In [5]:
N=200
T=1000
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)

max_batch=10
trunc_level=5
n_features=500

linear_trp = SigVanillaTensorizedRandProj(seed, n_features, trunc_level, max_batch)
linear_trp.fit(X)
feat = linear_trp.transform(X)
print_name(feat)

(200, 2500) feat
[[ 3.491e+00 -1.389e+00  4.524e+00 ...  3.774e+05  5.749e+06 -8.935e+06]
 [ 7.341e+00 -4.142e+00  5.985e+00 ... -3.231e+07 -8.474e+06 -1.849e+06]
 [-7.234e+00 -4.291e+00  3.824e+00 ... -2.768e+07 -4.017e+07 -1.715e+07]
 ...
 [ 8.514e-01  2.538e+00 -1.787e+00 ... -3.527e+06 -1.322e+06 -1.371e+06]
 [ 1.838e+00 -4.224e+00  4.689e+00 ... -7.838e+05  8.610e+06 -4.785e+06]
 [ 4.043e+00  5.993e+00 -4.015e+00 ...  2.302e+06  2.067e+05 -5.566e+07]] 



In [10]:
N=100
max_batch = 22


T=100
D=10
X = gen_BM(N, T, D)
seed = jax.random.PRNGKey(0)
n_features = 2000
trunc_level = 3
rbf_dimension = 512
sigma = 1.0
rff_max_batch = 2000

rbf_trp = SigRBFTensorizedRandProj(seed, 
                n_features, trunc_level, rbf_dimension, 
                sigma, max_batch, rff_max_batch)
rbf_trp.fit(X)

print("now\n\n\n")
feat = rbf_trp.transform(X)
print_name(feat)

now



(100, 6000) feat
[[-1.587e-02 -4.481e-03 -1.055e-02 ...  6.407e-01  9.482e-02  2.484e-01]
 [ 3.395e-02  1.335e-02  1.085e-02 ... -6.206e-01 -2.549e-02 -2.171e-01]
 [-8.108e-05  1.382e-02  2.019e-02 ... -4.643e-01 -9.111e-02  1.418e-01]
 ...
 [-1.968e-02 -1.603e-02 -1.636e-02 ...  1.783e-01 -2.041e-02  8.169e-02]
 [-2.829e-03 -8.637e-03 -3.370e-02 ... -3.632e-02 -8.979e-02 -6.773e-02]
 [-2.575e-02 -2.767e-02 -3.243e-02 ... -1.920e-01 -5.353e-03 -8.936e-02]] 

