In [22]:
import math
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
import numpy as np

from dataclasses import dataclass

In [23]:
@dataclass
class ModelParams:
    """
    model's parameters
    """
    count: ArrayLike
    projection_matrix: ArrayLike
    xtx: ArrayLike
    xty: ArrayLike
    w: ArrayLike

In [26]:
@dataclass
class HyperParams:
    """
    Hyperparameters for the model.
    """
    input_dim: int = 1
    output_dim: int = 1
    num_features: int = 50
    num_bins: int = 5
    feature_dim: int = 2
    edges: int = num_bins + 1
    num_grid_per_feature = edges ** feature_dim
    embedding_dim: int = num_grid_per_feature * num_features
    eps: float = 1e-5

In [25]:
@dataclass
class Data:
    """
    Dataset initialization
    """
    num_tests: int = 200
    num_trains: int = num_tests // 20
    x_train: ArrayLike = jnp.linspace(-jnp.pi, jnp.pi, num_trains).reshape(num_trains, HyperParams.input_dim)
    y_train: ArrayLike = jnp.sin(x_train)
    x_test: ArrayLike = jnp.linspace(-jnp.pi, jnp.pi, num_tests).reshape(num_tests, HyperParams.input_dim)

In [None]:
# PRNGKey init
seed = 0
rng = jax.random.PRNGKey(seed)

In [27]:
# Initialize the model's parameters
def init(
    rng: ArrayLike,
    input_dim: int,
    output_dim: int,
    num_features: int,
    num_bins: int,
    feature_dim: int,
    eps: float,
    Mo,
) -> ModelParams:
    _rng, rng = jax.random.split(rng)
    std = 1 / jnp.sqrt(input_dim)
    # clip the projection matrix to avoid numerical instability
    proj_matrix_clip = 2
    projection_matrix = std * jax.random.truncated_normal(
        _rng,
        lower=-proj_matrix_clip,
        upper=proj_matrix_clip,
        shape=(num_features, input_dim, feature_dim),
    )
    return rng, ModelParams(
        count=jnp.zeros(0, dtype=jnp.int32),
        projection_matrix=projection_matrix,
        xtx=jnp.zeros(
            HyperParams.embedding_dim * HyperParams.embedding_dim,
            dtype=projection_matrix.dtype,
        ),
        xty=jnp.zeros(
            (HyperParams.embedding_dim, HyperParams.output_dim),
            dtype=projection_matrix.dtype,
        ),
        w=jnp.zeros(
            (HyperParams.embedding_dim, HyperParams.output_dim),
            dtype=projection_matrix.dtype,
        ),
    )

In [None]:
rng, ModelParams

In [None]:
ModelParams


AttributeError: type object 'ModelParams' has no attribute 'w'

In [21]:
m1 = ModelParams()

TypeError: ModelParams.__init__() missing 5 required positional arguments: 'count', 'projection_matrix', 'xtx', 'xty', and 'w'