## Model

The input data is a coordinate vector, $\mathbf{x}_\phi$, of the image coordinates.

$$
\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}
$$

where the $D_\phi$ dimensions are the image coordinates $[\text{x}, \text{y}]$. We want to learn a function, $\boldsymbol{f}$, that takes a coordinate vector as input and outputs the pixel value at that location as a scalar or a vector.

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$
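
To make the mapping concrete, here is a minimal sketch of a coordinate network in plain JAX. The layer sizes, the tanh activation, the three output channels, and the helper names `init_mlp` and `f` are assumptions for illustration, not the architecture used in this notebook.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def init_mlp(key, sizes):
    """Initialise a list of (weight, bias) pairs for a small MLP."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jrandom.split(key)
        w = jrandom.normal(subkey, (d_out, d_in)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def f(params, x_phi):
    """Map one coordinate vector x_phi of shape (2,) to a pixel value."""
    h = x_phi
    for w, b in params[:-1]:
        h = jnp.tanh(w @ h + b)
    w, b = params[-1]
    return w @ h + b


# Hypothetical sizes: 2 input coordinates, 3 output channels (e.g. RGB).
theta = init_mlp(jrandom.PRNGKey(0), sizes=[2, 32, 32, 3])

# A 4x4 grid of (x, y) coordinates in [-1, 1], flattened to shape (16, 2).
xs = jnp.linspace(-1.0, 1.0, 4)
coords = jnp.stack(jnp.meshgrid(xs, xs, indexing="ij"), axis=-1).reshape(-1, 2)

# Evaluate a single coordinate, then all coordinates at once with vmap.
u = f(theta, coords[0])
u_batch = jax.vmap(f, in_axes=(None, 0))(theta, coords)
```

The function is written for a single coordinate and batched with `jax.vmap`, which is the same pattern the encoder cells in this section follow.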

**Identity**

$$
\gamma_T(t) = t
$$

In [17]:
from jejeqx._src.nets import time_net

latent_dim = 128

time_encoder = time_net.TimeIdentity(out_features=latent_dim, key=jrandom.PRNGKey(123))

out = time_encoder(t_init[0])

assert out.shape == (latent_dim,)

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape == (t_init.shape[0],latent_dim)

out_batch.min(), out_batch.max()

(Array(-1., dtype=float64), Array(1., dtype=float64))

**Linear**

$$
\gamma_T(t) = \mathbf{w}t + \mathbf{b}
$$

In [18]:
latent_dim = 128

time_encoder = eqx.nn.Linear(
    in_features=1, out_features=latent_dim, key=jrandom.PRNGKey(123)
)


out = time_encoder(t_init[0])

assert out.shape == (latent_dim,)

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape == (t_init.shape[0],latent_dim)

out_batch.min(), out_batch.max()

(Array(-1.70982319, dtype=float64), Array(1.7617026, dtype=float64))
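
For reference, the linear encoding can be written out by hand. This is a rough sketch in plain JAX: the uniform initialisation, the name `time_linear`, and the stand-in array `t_demo` (used in place of `t_init`) are assumptions, not what `eqx.nn.Linear` does internally.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

latent_dim = 128
key_w, key_b = jrandom.split(jrandom.PRNGKey(0))

# Hypothetical parameters: one weight and one bias per latent feature.
w = jrandom.uniform(key_w, (latent_dim, 1), minval=-1.0, maxval=1.0)
b = jrandom.uniform(key_b, (latent_dim,), minval=-1.0, maxval=1.0)


def time_linear(t):
    """gamma_T(t) = w t + b for a single time stamp t of shape (1,)."""
    return w @ t + b


# Stand-in for t_init: 10 time stamps in [-1, 1], shape (10, 1).
t_demo = jnp.linspace(-1.0, 1.0, 10)[:, None]

out = time_linear(t_demo[0])               # shape (latent_dim,)
out_batch = jax.vmap(time_linear)(t_demo)  # shape (10, latent_dim)
```

Unlike the identity encoding, the output range here depends on the weights and biases, so it is not confined to the range of the input.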

**Tanh**

$$
\gamma_T(t) = \tanh\left(\mathbf{w}t+\mathbf{b}\right)
$$

In [19]:
latent_dim = 128

time_encoder = time_net.TimeTanh(
    in_features=1, out_features=latent_dim, 
    use_bias=True,
    key=jrandom.PRNGKey(123)
)

out = time_encoder(t_init[0])

assert out.shape == (latent_dim,)

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape == (t_init.shape[0],latent_dim)

out_batch.min(), out_batch.max()

(Array(-0.93662585, dtype=float64), Array(0.94269286, dtype=float64))
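
The range printed above stays inside $[-1, 1]$ because of the tanh. The sketch below illustrates this with deliberately large, hypothetical parameters; `time_tanh` and `t_demo` are illustrative names and not part of `time_net`.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

latent_dim = 128
key_w, key_b = jrandom.split(jrandom.PRNGKey(1))

# Hypothetical parameters, scaled up on purpose to stress the bound.
w = 10.0 * jrandom.normal(key_w, (latent_dim, 1))
b = 10.0 * jrandom.normal(key_b, (latent_dim,))


def time_tanh(t):
    """gamma_T(t) = tanh(w t + b) for a single time stamp t of shape (1,)."""
    return jnp.tanh(w @ t + b)


# Stand-in for t_init: 10 time stamps in [-1, 1], shape (10, 1).
t_demo = jnp.linspace(-1.0, 1.0, 10)[:, None]
out_batch = jax.vmap(time_tanh)(t_demo)

# The pre-activation is large, but the encoding stays within [-1, 1].
pre_max = jnp.abs(jax.vmap(lambda t: w @ t + b)(t_demo)).max()
enc_max = jnp.abs(out_batch).max()
```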

**Log**

$$
\gamma_T(t) = \log\left(\exp\left(\mathbf{w}t+\mathbf{b}\right) +1\right)
$$

In [20]:
latent_dim = 128

time_encoder = time_net.TimeLog(
    in_features=1, out_features=latent_dim, 
    use_bias=True,
    key=jrandom.PRNGKey(123)
)

out = time_encoder(t_init[0])

assert out.shape == (latent_dim,)

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape == (t_init.shape[0],latent_dim)

out_batch.min(), out_batch.max()

(Array(0.16627497, dtype=float64), Array(1.92020283, dtype=float64))
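
The formula $\log(\exp(z) + 1)$ is the softplus function, which is always positive and is consistent with the positive range printed above. As a small sketch (whether `TimeLog` uses a stable implementation internally is not checked here), the naive expression overflows for large inputs while `jax.nn.softplus` does not:

```python
import jax
import jax.numpy as jnp

# Pre-activations z = w t + b, including extreme values.
z = jnp.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])

# Direct transcription of the formula: exp(1000) overflows to inf.
naive = jnp.log(jnp.exp(z) + 1.0)

# Same function, evaluated in a numerically stable way.
stable = jax.nn.softplus(z)
```

Both versions agree for moderate inputs; only the last entry differs, where the naive form returns `inf`.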

**Fourier**

$$
\gamma_T(t) = \log\left(\exp\left(\mathbf{w}t+\mathbf{b}\right) +1\right)
$$

In [21]:
latent_dim = 128
bounded = True
time_encoder = time_net.TimeFourier(
    in_features=1, out_features=latent_dim, 
    bounded=bounded,
    key=jrandom.PRNGKey(123)
)

out = time_encoder(t_init[0])

assert out.shape == (latent_dim,)

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape == (t_init.shape[0],latent_dim)

out_batch.min(), out_batch.max()

(Array(-0.49999261, dtype=float64), Array(0.49999261, dtype=float64))

#### Positional Encoding

In [18]:
from jejeqx._src.nets.nerfs import encoders

#### Gaussian Random Features


Sources:

* [Blog](https://gregorygundersen.com/blog/2019/12/23/random-fourier-features/)
* [Tutorial](https://random-walks.org/content/misc/rff/rff.html)

In [19]:
latent_dim = 128
in_dim = 1
sigma = 1.0
key = jrandom.PRNGKey(42)

projection = jnp.eye(in_dim)

time_encoder = encoders.GaussianFourierFeatureEncoding(in_dim=in_dim, num_features=latent_dim, sigma=sigma, key=key)

out = time_encoder(t_init[0])

assert out.shape[-1] == time_encoder.out_dim

out_batch = jax.vmap(time_encoder)(t_init)

out.shape, out_batch.shape

assert out_batch.shape[-1] == time_encoder.out_dim
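
As a rough illustration of the idea described in the sources above, Gaussian random Fourier features can be sketched in a few lines of JAX. The helper names `init_rff` and `rff_encode`, the $2\pi$ scaling, and the cos/sin concatenation are assumptions; the `encoders.GaussianFourierFeatureEncoding` implementation may differ in scaling, ordering, or output dimension.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def init_rff(key, in_dim, num_features, sigma):
    """Draw a fixed random projection B ~ N(0, sigma^2), shape (num_features, in_dim)."""
    return sigma * jrandom.normal(key, (num_features, in_dim))


def rff_encode(B, x):
    """Encode one input x of shape (in_dim,) as [cos(2 pi B x), sin(2 pi B x)]."""
    proj = 2.0 * jnp.pi * (B @ x)
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)])


B = init_rff(jrandom.PRNGKey(42), in_dim=1, num_features=128, sigma=1.0)

# Stand-in for t_init: 10 time stamps in [-1, 1], shape (10, 1).
t_demo = jnp.linspace(-1.0, 1.0, 10)[:, None]

out = rff_encode(B, t_demo[0])                                  # shape (256,)
out_batch = jax.vmap(rff_encode, in_axes=(None, 0))(B, t_demo)  # shape (10, 256)
```

In this sketch the projection `B` is sampled once and kept fixed, and `sigma` controls how high the sampled frequencies are. The output has twice `num_features` entries because each frequency contributes a cosine and a sine.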

#### Identity

In [20]:
latent_dim = 128
in_dim = 1

projection = jnp.eye(in_dim)

time_encoder = encoders.IdentityEncoding(in_dim=in_dim)

out = time_encoder(t_init[0])

assert out.shape[-1] == time_encoder.out_dim

out_batch = jax.vmap(time_encoder)(t_init)

assert out_batch.shape[-1] == time_encoder.out_dim

#### NeRF-Like Positional Encoding

*NeRF - Neural Radiance Fields*

In [21]:
latent_dim = 128
in_dim = 1

projection = jnp.eye(in_dim)

time_encoder = encoders.SinusoidalEncoding(in_dim=in_dim, num_features=latent_dim)

out = time_encoder(t_init[0])

assert out.shape[-1] == time_encoder.out_dim

out_batch = jax.vmap(time_encoder)(t_init)

out.shape, out_batch.shape

assert out_batch.shape[-1] == time_encoder.out_dim
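
For comparison, the positional encoding from the NeRF paper uses fixed frequencies $2^0\pi, \dots, 2^{L-1}\pi$ rather than random ones. Below is a minimal sketch; the name `nerf_encode`, the choice of `num_freqs`, and the ordering of the features (all sines, then all cosines) are assumptions, and `encoders.SinusoidalEncoding` may be organised differently.

```python
import jax
import jax.numpy as jnp


def nerf_encode(p, num_freqs):
    """Encode p of shape (in_dim,) with frequencies 2^0 pi, ..., 2^(L-1) pi."""
    freqs = (2.0 ** jnp.arange(num_freqs)) * jnp.pi  # shape (L,)
    angles = p[:, None] * freqs[None, :]             # shape (in_dim, L)
    feats = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
    return feats.reshape(-1)                         # shape (in_dim * 2 * L,)


num_freqs = 4

# Stand-in for t_init: 10 time stamps in [-1, 1], shape (10, 1).
t_demo = jnp.linspace(-1.0, 1.0, 10)[:, None]

out_zero = nerf_encode(jnp.zeros(1), num_freqs)
out_batch = jax.vmap(lambda p: nerf_encode(p, num_freqs))(t_demo)
```

At `p = 0` every sine feature is 0 and every cosine feature is 1, which makes a quick sanity check for the layout of the output.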