In [1]:
import math
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
import numpy as np

from dataclasses import dataclass

In [2]:
@dataclass
class ModelParams:
    """
    Container for the arrays that make up the model's state.

    Instances are created by ``init`` and passed to ``update``.
    """
    # int32 counter, zero-initialised by `init`.
    count: ArrayLike
    # Random projections drawn in `init`; shape
    # (num_features, input_dim, feature_dim).
    projection_matrix: ArrayLike
    # Zero-initialised by `init`.
    # NOTE(review): presumably the accumulated X^T X statistics — confirm.
    xtx: ArrayLike
    # Zero-initialised by `init` with shape (embedding_dim, output_dim).
    # NOTE(review): presumably the accumulated X^T y statistics — confirm.
    xty: ArrayLike
    # Zero-initialised by `init` with shape (embedding_dim, output_dim).
    # NOTE(review): presumably the regression weights — confirm.
    w: ArrayLike

In [18]:
@dataclass
class HyperParams:
    """
    Hyperparameters for the model.

    ``edges``, ``num_grid_per_feature`` and ``embedding_dim`` are derived from
    the other fields. Their class-level values below are computed from the
    defaults (so ``HyperParams.embedding_dim`` keeps working as a class
    attribute), and ``__post_init__`` recomputes them for every instance so
    that e.g. ``HyperParams(num_bins=9)`` is internally consistent.

    Attributes:
        input_dim: Size of one input sample.
        output_dim: Size of one target sample.
        num_features: Number of independent random projections.
        num_bins: Number of bins per grid dimension.
        feature_dim: Dimensionality of each projected feature.
        edges: Grid points per dimension, always ``num_bins + 1``.
        num_grid_per_feature: Grid points per feature,
            ``edges ** feature_dim``.
        embedding_dim: Total embedding size,
            ``num_grid_per_feature * num_features``.
        eps: Small constant (not used by the code visible in this notebook).
    """
    input_dim: int = 1
    output_dim: int = 1
    num_features: int = 2
    num_bins: int = 5
    feature_dim: int = 2
    edges: int = num_bins + 1
    # Deliberately left un-annotated: annotating it would turn it into a
    # dataclass field and shift the positional constructor arguments.
    num_grid_per_feature = edges ** feature_dim
    embedding_dim: int = num_grid_per_feature * num_features
    eps: float = 1e-5

    def __post_init__(self) -> None:
        """Re-derive the dependent sizes from this instance's base fields.

        Values passed explicitly for ``edges`` or ``embedding_dim`` are
        overwritten, because they must stay consistent with ``num_bins``,
        ``feature_dim`` and ``num_features``.
        """
        self.edges = self.num_bins + 1
        self.num_grid_per_feature = self.edges ** self.feature_dim
        self.embedding_dim = self.num_grid_per_feature * self.num_features


hyper_params = HyperParams()

In [19]:
@dataclass
class Data:
    """
    Toy dataset: inputs evenly spaced on [-pi, pi] with sin(x) as target.

    All defaults are evaluated once, when the class body runs, and the rest
    of the notebook reads them straight off the class (``Data.x_train``).
    """
    # Number of evaluation points.
    num_tests: int = 200
    # Number of training points (one twentieth of the evaluation grid).
    num_trains: int = num_tests // 20
    # Training inputs, shaped (num_trains, input_dim).
    x_train: ArrayLike = jnp.reshape(
        jnp.linspace(-jnp.pi, jnp.pi, num_trains),
        (num_trains, HyperParams.input_dim),
    )
    # Training targets, same shape as x_train.
    y_train: ArrayLike = jnp.sin(x_train)
    # Evaluation inputs, shaped (num_tests, input_dim).
    x_test: ArrayLike = jnp.reshape(
        jnp.linspace(-jnp.pi, jnp.pi, num_tests),
        (num_tests, HyperParams.input_dim),
    )

In [20]:
# PRNGKey init: build the root JAX PRNG key from a fixed seed.
seed = 0
rng = jax.random.PRNGKey(seed)

In [21]:
# Initialize the model's parameters
def init(
    rng: ArrayLike,
    hyper_parasm: HyperParams,
) -> Tuple[ArrayLike, ModelParams]:
    """
    Create the initial model parameters.

    Args:
        rng: JAX PRNG key; it is split so the caller gets an unused key back.
        hyper_parasm: Hyperparameters that determine all array shapes. The
            misspelled name is kept so existing keyword callers keep working.

    Returns:
        A ``(rng, model_params)`` tuple: the next PRNG key and a
        ``ModelParams`` holding a random projection matrix plus
        zero-initialised ``count``, ``xtx``, ``xty`` and ``w``.
    """
    # Use the argument that was passed in rather than the notebook-level
    # `hyper_params` global / the `HyperParams` class defaults.
    hyper_params = hyper_parasm

    _rng, rng = jax.random.split(rng)
    std = 1 / jnp.sqrt(hyper_params.input_dim)
    # clip the projection matrix to avoid numerical instability
    proj_matrix_clip = 2
    projection_matrix = std * jax.random.truncated_normal(
        _rng,
        lower=-proj_matrix_clip,
        upper=proj_matrix_clip,
        shape=(hyper_params.num_features, hyper_params.input_dim, hyper_params.feature_dim),
    )

    embedding_dim = hyper_params.embedding_dim
    output_dim = hyper_params.output_dim
    dtype = projection_matrix.dtype
    return rng, ModelParams(
        # Scalar counter: shape () rather than the empty shape-(0,) array.
        count=jnp.zeros((), dtype=jnp.int32),
        projection_matrix=projection_matrix,
        # Square (embedding_dim, embedding_dim) matrix, not a flat vector of
        # embedding_dim * embedding_dim elements.
        xtx=jnp.zeros((embedding_dim, embedding_dim), dtype=dtype),
        xty=jnp.zeros((embedding_dim, output_dim), dtype=dtype),
        w=jnp.zeros((embedding_dim, output_dim), dtype=dtype),
    )

In [22]:
# Build the initial parameters; `init` also returns a new key, which replaces `rng`.
rng, model_params = init(rng, hyper_params)

In [23]:
def sparse_indices_and_values(
    projection_matrix: ArrayLike, x: ArrayLike, hyper_params: HyperParams
) -> Tuple[ArrayLike, ArrayLike]:
    """
    Calculate the indices and values of the sparse embedding vector.

    Each feature projects ``x`` into a ``feature_dim``-dimensional grid with
    ``edges`` points per dimension. The projected point is surrounded by
    ``2 ** feature_dim`` grid corners; this returns, for every feature, the
    flat embedding index of each corner and its multilinear interpolation
    weight.

    Args:
        projection_matrix: Array of shape
            (num_features, input_dim, feature_dim).
        x: Input batch of shape (batch, input_dim).
        hyper_params: Supplies ``num_features``, ``feature_dim``,
            ``num_bins``, ``edges`` and ``num_grid_per_feature``.

    Returns:
        ``(indices, values)``, both of shape
        (batch, num_features * 2 ** feature_dim). ``indices`` is int32 and
        addresses the flat embedding; ``values`` holds the matching weights,
        which sum to 1 within each feature.
    """
    num_features = hyper_params.num_features
    feature_dim = hyper_params.feature_dim

    # Apply every feature's projection to x; the mapped axis goes last, giving
    # shape (batch, feature_dim, num_features).
    latent = jax.vmap(jnp.matmul, in_axes=(None, 0), out_axes=-1)(x, projection_matrix)
    # Squash into (0, 1), then scale to grid coordinates in (0, num_bins).
    latent = jax.nn.sigmoid(latent) * hyper_params.num_bins
    lower = jnp.floor(latent).astype(jnp.int32)
    offsets = latent - lower

    # Lower/upper neighbour along each grid dimension, with its linear
    # interpolation weight: the lower point gets 1 - offset, the upper offset.
    corner_indices = jnp.stack([lower, lower + 1], axis=-1)
    corner_weights = jnp.stack([1.0 - offsets, offsets], axis=-1)

    # Row-major strides of the per-feature grid, e.g. [edges, 1] for 2-D.
    multiplier = jnp.power(hyper_params.edges, jnp.arange(feature_dim - 1, -1, -1))
    corner_indices = corner_indices * multiplier[None, :, None, None]

    # Dimension i's two choices are placed on their own trailing axis so that
    # broadcasting enumerates all 2 ** feature_dim corner combinations.
    shape_suffix = [
        tuple(2 if j == i else 1 for j in range(feature_dim))
        for i in range(feature_dim)
    ]
    indices = sum(
        jnp.reshape(corner_indices[:, i], (-1, num_features, *suffix))
        for i, suffix in enumerate(shape_suffix)
    )
    values = math.prod(
        jnp.reshape(corner_weights[:, i], (-1, num_features, *suffix))
        for i, suffix in enumerate(shape_suffix)
    )

    # Shift each feature into its own block of the flat embedding.
    indices = indices + jnp.expand_dims(
        hyper_params.num_grid_per_feature * jnp.arange(num_features),
        axis=tuple(range(-feature_dim, 1, 1)),
    )

    num_active = num_features * 2 ** feature_dim
    indices = jnp.reshape(indices, (-1, num_active))
    values = jnp.reshape(values, (-1, num_active))
    return indices, values

def update_memory():
    """Placeholder for the memory update step; currently does nothing."""
    return None

def update_w():
    """Placeholder for the weight update step; currently does nothing."""
    return None

In [89]:
def update(model_params: ModelParams, x: ArrayLike, y: ArrayLike, hyper_params: HyperParams) -> ModelParams:
    """
    Update the model's parameters with a single (x, y) sample.

    Args:
        model_params: Current model parameters.
        x: Input of shape (1, input_dim).
        y: Target of shape (1, output_dim).
        hyper_params: Hyperparameters describing the expected shapes.

    Returns:
        The model parameters after the update. The memory and weight steps
        are still placeholders, so this is currently ``model_params``
        unchanged.

    Raises:
        AssertionError: If ``x`` or ``y`` is not a single-row batch of the
            expected width.
    """
    assert x.shape == (1, hyper_params.input_dim)  # non batched
    assert y.shape == (1, hyper_params.output_dim)  # non batched
    sparse_indices_and_values(model_params.projection_matrix, x, hyper_params)
    update_memory()
    update_w()
    # Honour the declared return type so callers can carry the state forward
    # instead of receiving None.
    return model_params

In [90]:
# Online update: feed the training samples one at a time; the `i : i + 1`
# slices keep a leading batch axis of size 1, as `update` asserts.
# NOTE(review): the result is stored in `model_state` and `model_params` is
# never rebound, so every iteration starts from the same parameters —
# confirm whether the state is meant to be carried across iterations.
for i in range(Data.num_trains):
    x = Data.x_train[i : i + 1]
    y = Data.y_train[i : i + 1]
    model_state = update(model_params, x, y, hyper_params)

In [8]:
# Pick the training sample at index 1 as a single-row batch (overwrites the
# notebook-level `x` and `y`).
x = Data.x_train[1 : 1 + 1]
y = Data.y_train[1 : 1 + 1]

In [None]:
# Prototype of the sparse embedding lookup for the single sample `x`.
# Project x with every feature's matrix; the mapped axis is placed last.
latent = jax.vmap(jnp.matmul, in_axes=(None, 0), out_axes=-1)(x, model_params.projection_matrix)
assert latent.shape == (1, hyper_params.feature_dim, hyper_params.num_features)
# Squash into (0, 1), then scale to grid coordinates in (0, num_bins).
latent = jax.nn.sigmoid(latent)
latent = latent * hyper_params.num_bins
indices = jnp.floor(latent).astype(jnp.int32)
offsets = latent - indices
indices = jnp.stack([indices, indices + 1], axis=-1)  # (1, hyper_params.feature_dim, hyper_params.num_features, 2)
# Linear interpolation weights, in the same order as the stacked indices:
# the lower grid point gets 1 - offset and the upper one gets offset.
values = jnp.stack([1.0 - offsets, offsets], axis=-1)
# Row-major strides of the per-feature grid, e.g. [edges, 1] for feature_dim == 2.
multiplier = jnp.power(hyper_params.edges, jnp.arange(hyper_params.feature_dim - 1, -1, -1))
# indices *= multiplier[None, :, None, None]  # scale each feature dimension by its own multiplier so the dimensions can be summed later

In [33]:
# Inspect the stacked indices, their shape, and the per-dimension multiplier.
indices, indices.shape, multiplier

(Array([[[[2, 3],
          [3, 4]],
 
         [[0, 1],
          [2, 3]]]], dtype=int32),
 (1, 2, 2, 2),
 Array([6, 1], dtype=int32))

In [35]:
# Scale the indices along axis 1 by the matching multiplier entry.
# NOTE: this rebinds `indices`, so re-running the cell multiplies again.
indices *= multiplier[None, :, None, None]
indices

Array([[[[12, 18],
         [18, 24]],

        [[ 0,  1],
         [ 2,  3]]]], dtype=int32)

In [26]:
# One tuple per row of (identity + 1): entry i is 2 and every other entry is 1.
# Used below as the trailing reshape dimensions for broadcasting.
shape_suffix = [tuple(*p) for p in np.split(np.eye(hyper_params.feature_dim, dtype=np.int32) + 1, hyper_params.feature_dim)]
shape_suffix

[(np.int32(2), np.int32(1)), (np.int32(1), np.int32(2))]

In [36]:
# Reshape each axis-1 slice with its own suffix and add them up; broadcasting
# combines the two choices of every dimension.
sum(jnp.reshape(indices[:, i], (-1, hyper_params.num_features, *suffix)) for i, suffix in enumerate(shape_suffix))

Array([[[[12, 13],
         [18, 19]],

        [[20, 21],
         [26, 27]]]], dtype=int32)

In [15]:
# Show the individual reshaped pieces that the sum above adds together.
[jnp.reshape(indices[:, i], (-1, hyper_params.num_features, *suffix)) for i, suffix in enumerate(shape_suffix)]

[Array([[[[ 0],
          [ 6]],
 
         [[ 0],
          [ 6]],
 
         [[12],
          [18]],
 
         [[24],
          [30]],
 
         [[24],
          [30]],
 
         [[24],
          [30]],
 
         [[ 0],
          [ 6]],
 
         [[18],
          [24]],
 
         [[12],
          [18]],
 
         [[12],
          [18]],
 
         [[12],
          [18]],
 
         [[ 0],
          [ 6]],
 
         [[18],
          [24]],
 
         [[12],
          [18]],
 
         [[24],
          [30]],
 
         [[ 6],
          [12]],
 
         [[ 0],
          [ 6]],
 
         [[ 6],
          [12]],
 
         [[ 0],
          [ 6]],
 
         [[ 0],
          [ 6]],
 
         [[ 0],
          [ 6]],
 
         [[ 6],
          [12]],
 
         [[ 0],
          [ 6]],
 
         [[18],
          [24]],
 
         [[ 0],
          [ 6]],
 
         [[12],
          [18]],
 
         [[ 0],
          [ 6]],
 
         [[24],
          [30]],
 
         [[24],
    

In [11]:
# Check the current shape of `indices`.
indices.shape

(1, 2, 50, 2)

In [216]:
# Exploratory: broadcast-add the axis-2 slice at index 0 (given a new trailing
# axis) with the axis-2 slice at index 1 (given a new axis before the last).
indices[:, :, 0, :][:, :, :, None] + indices[:, :, 1, :][:, :, None, :]

Array([[[[ 4,  5],
         [ 8,  9]],

        [[ 9, 10],
         [13, 14]]]], dtype=int32)

In [99]:
# Same definition as the earlier `shape_suffix` cell: one tuple per row of
# (identity + 1), i.e. entry i is 2 and every other entry is 1.
shape_suffix = [tuple(*p) for p in np.split(np.eye(hyper_params.feature_dim, dtype=np.int32) + 1, hyper_params.feature_dim)]
shape_suffix

[(np.int32(2), np.int32(1)), (np.int32(1), np.int32(2))]

In [101]:
# Replace `indices` with the broadcast sum of its reshaped axis-1 slices.
# NOTE: this rebinds `indices`, so the cell is not safe to re-run.
indices = sum(jnp.reshape(indices[:, i], (-1, hyper_params.num_features, *suffix)) for i, suffix in enumerate(shape_suffix))

In [104]:
# Add num_grid_per_feature * feature_index to each feature's indices; the
# offsets are expanded with extra axes so they broadcast against `indices`.
# NOTE: this rebinds `indices`, so the cell is not safe to re-run.
indices += jnp.expand_dims(
    hyper_params.num_grid_per_feature * jnp.arange(hyper_params.num_features), axis=tuple(range(-hyper_params.feature_dim, 1, 1))
)
# Flatten to one row of num_features * 2**feature_dim indices per leading entry.
indices = jnp.reshape(indices, (-1, hyper_params.num_features * 2**hyper_params.feature_dim))
indices

Array([[ 5,  6, 10, 11]], dtype=int32)

In [106]:
# Multiply the reshaped axis-1 slices of `values`; broadcasting yields one
# product per combination.
math.prod(jnp.reshape(values[:, i], (-1, hyper_params.num_features, *suffix)) for i, suffix in enumerate(shape_suffix))

Array([[[[0.08443601, 0.063721  ],
         [0.48547298, 0.36637005]]]], dtype=float32)

In [108]:
# Same expression as the previous cell, re-evaluated: broadcast product of
# the reshaped axis-1 slices of `values`.
math.prod(jnp.reshape(values[:, i], (-1, hyper_params.num_features, *suffix)) for i, suffix in enumerate(shape_suffix))

Array([[[[0.36637005, 0.48547298],
         [0.063721  , 0.08443601]]]], dtype=float32)