In [2]:
import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int

from utils import print_shape, print_name
from preprocessing.timeseries_augmentation import normalize_mean_std, normalize_mean_std_traindata
from preprocessing.timeseries_augmentation import avg_pool_time, augment_time, add_basepoint_zero, I_visibility, T_visibility

# Pin JAX to the CPU backend; this runs before any JAX arrays are created in the next cell.
jax.config.update('jax_platform_name', 'cpu')
# Compact array printing: 3 decimal places, and summarize (with "...") any array
# whose total element count exceeds 5.
np.set_printoptions(threshold=5, precision=3)

In [6]:
# Synthetic demo data: standard-normal arrays X of shape (N1, T, D) and Y of shape (N2, T, D),
# drawn from independent sub-keys of a single seeded PRNG key so the cell is reproducible.
seed = 1234
key = jax.random.key(seed)
X_key, Y_key = jax.random.split(key)

N1 = 20     # size of axis 0 of X (number of series)
N2 = 10     # size of axis 0 of Y (number of series)
T = 105     # size of axis 1 (presumably time steps)
max_T = 30  # passed to avg_pool_time; presumably the pooled time length — TODO confirm
D = 4       # size of axis 2 (channels)

# Keep the raw draws under their own names. X and Y are rebound to normalized
# arrays below, and later cells may still need the un-normalized data.
X_raw = jax.random.normal(X_key, (N1, T, D))
Y_raw = jax.random.normal(Y_key, (N2, T, D))
assert X_raw.shape == (N1, T, D), X_raw.shape
assert Y_raw.shape == (N2, T, D), Y_raw.shape

# X normalized on its own (not used further in this cell).
X_normalized = normalize_mean_std(X_raw)
# Joint normalization with X passed first; presumably Y is scaled with X's
# statistics — verify in preprocessing.timeseries_augmentation.
X, Y = normalize_mean_std_traindata(X_raw, Y_raw)

# Every transform below is applied to the normalized X.
pooled_X = avg_pool_time(X, max_T)
augment_time_X = augment_time(X)
add_basepoint_X = add_basepoint_zero(X)
I_visibility_X = I_visibility(X)
T_visibility_X = T_visibility(X)

# print_name reports each array under the caller's variable name (see the output below),
# so these calls pass the named variables directly rather than looping over them.
print_name(X)
print_name(pooled_X)
print_name(augment_time_X)
print_name(add_basepoint_X)
print_name(I_visibility_X)
print_name(T_visibility_X)

(20, 105, 4) X
[[[ 2.028  2.748  1.458 -0.99 ]
  [-1.207  1.335 -1.078  1.591]
  [ 1.876  2.188 -0.285 -0.63 ]
  ...
  [ 0.303  2.199 -0.181  0.55 ]
  [-0.429 -1.465 -0.101 -0.196]
  [ 0.11  -1.096  0.195  1.48 ]]

 [[-0.681  0.026  0.812  0.608]
  [-0.24   0.613  0.72  -0.414]
  [ 0.482  0.101 -1.315  0.501]
  ...
  [-0.463  0.181 -0.63   1.534]
  [ 0.677  0.084  0.509 -0.239]
  [-1.093  0.067  0.622 -0.851]]

 [[-1.017  0.651  1.07   0.114]
  [ 0.135  0.18   0.594 -0.573]
  [ 0.349 -0.831  0.064  0.03 ]
  ...
  [-0.551  0.369  1.851  0.749]
  [ 0.335 -2.068 -1.281  0.255]
  [ 0.021 -0.718 -1.189 -0.511]]

 ...

 [[-0.643  0.429 -0.19   0.26 ]
  [ 1.228  0.461 -1.326 -1.254]
  [ 0.51  -1.135  0.306  0.841]
  ...
  [-1.607 -1.874  1.447 -0.117]
  [-1.713  0.605 -0.186  1.908]
  [ 1.891  0.489  0.614 -0.588]]

 [[ 0.493 -0.589 -0.612 -0.569]
  [-0.605 -0.37  -0.424  0.625]
  [ 0.523 -0.035 -0.034 -0.11 ]
  ...
  [-0.645 -0.342 -1.601  0.44 ]
  [-0.292  0.556 -0.965  0.242]
  [ 0.034 -1.