In [4]:
from typing import Tuple, List, Union, Any, Optional, Dict, Set, Literal, Callable
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray
from jax.random import PRNGKey
import aeon
import pandas as pd

from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
from features.sig import SigTransform, LogSigTransform
from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures, RandomGuesser
from features.sig_neural import RandomizedSignature
from utils.utils import print_name, print_shape
from preprocessing.timeseries_augmentation import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero

jax.config.update('jax_platform_name', 'cpu') # Select the JAX platform; 'cpu' here (other names, e.g. 'gpu', pick another backend).
np.set_printoptions(precision=3, threshold=5) # Display only: print 3 decimal digits and summarize arrays with more than 5 elements.

In [48]:
def init_single_SWIM_layer(
        X: Float[Array, "N  d"],
        y: Float[Array, "N  D"],
        n_features: int,
        seed: PRNGKeyArray,
    ) -> Tuple[Float[Array, "d  n_features"], Float[Array, "1  n_features"]]:
    """
    Samples the weights and biases of one SWIM layer from pairs of data points.

    Candidate pairs (x1, x2) of distinct samples are drawn, each pair is scored
    by the finite-difference 'gradient' ||y2 - y1|| / ||x2 - x1||, and
    `n_features` pairs are then sampled (with replacement) with probability
    proportional to that score. Every selected pair defines one neuron with
    w = (x2 - x1) / ||x2 - x1||^2 and b = -<w, x1> - 0.5, so the pre-activation
    is -0.5 at x1 and +0.5 at x2.

    Args:
        X (Float[Array, "N  d"]): Previous layer's output.
        y (Float[Array, "N  D"]): Targets aligned row-wise with X.
        n_features (int): Next hidden layer size.
        seed (PRNGKeyArray): Random seed for the weights and biases.
    Returns:
        Weights (d, n_features) and biases (1, n_features) for the next layer.
    Raises:
        ValueError: If fewer than two samples are given or n_features < 1.
    """
    seed_idxs, seed_sample = jax.random.split(seed, 2)
    N, _ = X.shape
    EPS = 1e-06

    if N < 2:
        raise ValueError(f"Need at least 2 samples to form pairs, got N={N}.")
    if n_features < 1:
        raise ValueError(f"n_features must be positive, got {n_features}.")

    #obtain pair indices
    # Python-level min keeps the candidate count a static int (it is used as a shape).
    n_pairs_pre = min(N * (N - 1), 3 * n_features) #maybe this should be a parameter TODO
    # Wrap around the dataset: n_pairs_pre may exceed N, and JAX silently clamps
    # out-of-range gather indices, which would collapse those pairs onto the last row.
    idx1 = jnp.arange(n_pairs_pre) % N
    # delta in [1, N-1] guarantees idx2 != idx1.
    delta = jax.random.randint(seed_idxs, shape=(n_pairs_pre,), minval=1, maxval=N)
    idx2 = (idx1 + delta) % N

    #calculate 'gradients'
    dx = X[idx2] - X[idx1]
    dy = y[idx2] - y[idx1]
    # EPS guards against division by zero for duplicated data points.
    dists = jnp.maximum(EPS, jnp.linalg.norm(dx, axis=1, keepdims=True))
    gradients = (jnp.linalg.norm(dy, axis=1, keepdims=True) / dists).reshape(-1)
    #gradients = (np.max(np.abs(dy), axis=1, keepdims=True) / dists ).reshape(-1) #NOTE paper uses this instead

    # Normalise to a distribution; if every gradient is zero (e.g. constant y)
    # fall back to uniform sampling instead of producing NaN probabilities.
    total = gradients.sum()
    has_signal = total > 0
    probs = jnp.where(
        has_signal,
        gradients / jnp.where(has_signal, total, 1.0),
        jnp.full_like(gradients, 1.0 / n_pairs_pre),
    )

    #sample pairs, weighted by gradients     NOTE make replace a parameter
    selected_idx = jax.random.choice(
        seed_sample,
        n_pairs_pre,
        shape=(n_features,),
        replace=True,
        p=probs,
        )
    idx1 = idx1[selected_idx]
    dx = dx[selected_idx]
    dists = dists[selected_idx]

    #define weights and biases
    weights = (dx / dists**2).T
    biases = -jnp.sum(weights * X[idx1].T, axis=0, keepdims=True) - 0.5  # NOTE experiment with this. also +-0.5 ?
    return weights, biases


# Smoke test: sample a single SWIM layer on random toy data.
# Problem sizes: N samples, d input dimensions, dim_y target dimensions.
N, d, dim_y = 10, 2, 3
n_features = 6

# One independent key each for the inputs, the targets and the layer sampling.
seed1, seed2, seed3 = jax.random.split(PRNGKey(0), num=3)

X = jax.random.normal(seed1, shape=(N, d))
y = jax.random.normal(seed2, shape=(N, dim_y))

weights, biases = init_single_SWIM_layer(
    X=X,
    y=y,
    n_features=n_features,
    seed=seed3,
)


# TODO (next step): use lax.scan to iterate over the layers and to implement the forward pass
