In [None]:
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp
import optax

from inox.random import PRNG
from typing import *

from utils import *

In [None]:
# Single seeded generator shared by every cell below. Later cells draw keys
# from it via `rng.split()`, so outputs depend on running cells in order.
SEED = 1
rng = PRNG(SEED)

## Data

In [None]:
# Synthetic dataset: ground-truth samples `x`, operator `A` and observations `y`
# (roles presumed from how `A`/`y` are used by `fit_moments`/`sample` below).
data_key = rng.split()
x, A, y = make_data(65536, key=data_key)
show(x)

## Fit

In [None]:
# Wrap the measurement with `A` bound so fit_moments can call it directly.
measurement = inox.Partial(measure, A)

# NOTE(review): sigma_y is presumably the observation-noise variance (std 1e-3, squared).
mu_x, sigma_x = fit_moments(
    features=5,
    rank=3,
    A=measurement,
    y=y,
    sigma_y=1e-3 ** 2,
    key=rng.split(),
)

In [None]:
# Initial samples from the Gaussian denoiser built on the fitted moments.
gaussian_denoiser = GaussianDenoiser(mu_x, sigma_x)
x_bis = sample(gaussian_denoiser, A, y, rng.split())
show(x_bis)

## Training

In [None]:
def train(x, model=None, steps=65536, batch_size=1024):
    r"""Trains a denoiser on the samples `x` with Adam and a linear LR decay.

    Arguments:
        x: Training samples. Minibatches are drawn (with replacement) by
            indexing the first axis of `x`.
        model: Model to train. If `None`, a fresh one is built with `make_model`.
        steps: Number of optimization steps; also the length of the
            learning-rate schedule (1e-3 down to 1e-6).
        batch_size: Number of samples per minibatch.

    Returns:
        The trained model, rebuilt from the optimized parameters.
    """

    if model is None:
        model = make_model(key=rng.split())

    # Only `nn.Parameter` leaves are optimized; the rest is carried along as-is.
    static, params, others = model.partition(nn.Parameter)

    scheduler = optax.linear_schedule(init_value=1e-3, end_value=1e-6, transition_steps=steps)
    optimizer = optax.adam(learning_rate=scheduler)
    opt_state = optimizer.init(params)

    objective = DenoiserLoss()

    def ell(params, others, x, key):
        keys = jax.random.split(key, 2)

        # Standard normal noise with the batch's shape, and one Beta(3, 5)
        # draw per sample along the first axis.
        z = jax.random.normal(keys[0], shape=x.shape)
        t = jax.random.beta(keys[1], a=3, b=5, shape=x.shape[:1])

        return objective(static(params, others), x, z, t)

    @jax.jit
    def sgd_step(params, others, opt_state, x, key):
        loss, grads = jax.value_and_grad(ell)(params, others, x, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        return loss, params, opt_state

    for _ in range(steps):
        # Index into `x` itself (not the global observations `y`) so the
        # batches are valid for any training-set size.
        i = rng.randint(shape=(batch_size,), minval=0, maxval=len(x))
        loss, params, opt_state = sgd_step(params, others, opt_state, x[i], rng.split())

    return static(params, others)

In [None]:
# Refinement round: train a fresh model on the current samples, then resample.
model = train(x_bis)
round_key = rng.split()
x_bis = sample(model, A, y, round_key)
show(x_bis)

In [None]:
# Refinement round: train a fresh model on the current samples, then resample.
model = train(x_bis)
round_key = rng.split()
x_bis = sample(model, A, y, round_key)
show(x_bis)

In [None]:
# Refinement round: train a fresh model on the current samples, then resample.
model = train(x_bis)
round_key = rng.split()
x_bis = sample(model, A, y, round_key)
show(x_bis)

In [None]:
# Refinement round: train a fresh model on the current samples, then resample.
model = train(x_bis)
round_key = rng.split()
x_bis = sample(model, A, y, round_key)
show(x_bis)

In [None]:
# Samples from the final model without conditioning on (A, y) (presumably
# unconditional). Draw as many as the dataset, which make_data created with
# 65536 samples; the earlier hard-coded 65535 was off by one.
x_ter = sample_any(model, shape=(len(x), 5), key=rng.split())
show(x_ter)