In [None]:
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp
import optax

from inox.random import PRNG
from tqdm import trange
from typing import *

from utils import *

In [None]:
# Single PRNG (seed 0) shared by every cell below. It is presumably stateful, since
# `rng.split()` is called repeatedly for fresh keys, so results depend on running
# the cells in order from a fresh kernel.
rng = PRNG(0)

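## Data

Synthetic ground truth: 65536 points in 5 dimensions scattered around 8 random cluster centers. Each point is observed through its own random $2 \times 5$ measurement matrix, with added noise scaled by $10^{-3}$.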

In [None]:
# Ground truth: 65536 points in 5-D around 8 cluster centers drawn with
# minval=-2 and maxval=2.
modes = rng.uniform((8, 5), minval=-2.0, maxval=2.0)
i = rng.randint((65536,), minval=0, maxval=len(modes))  # cluster index of each point

# Spread each point around its center with normal noise scaled by 1/8.
x = modes[i] + rng.normal((65536, 5)) / 8

show(x)

In [None]:
def measure(A, x):
    r"""Applies the linear measurement operator, i.e. ``A @ x`` over the last axes.

    ``A`` has shape ``(..., m, n)`` and ``x`` has shape ``(..., n)``; leading batch
    dimensions broadcast against each other. The result has shape ``(..., m)``.
    """
    # Lift x to a column vector so matmul performs a batched matrix-vector product,
    # then drop the trailing singleton axis again.
    return jnp.matmul(A, x[..., None])[..., 0]

# One random 2x5 measurement matrix per point, each row rescaled to unit norm.
A = rng.normal((65536, 2, 5))
A = A / jnp.linalg.norm(A, axis=-1)[..., None]

# Observations: project each point through its own matrix, then add noise scaled by 1e-3.
y = measure(A, x) + rng.normal((65536, 2)) * 1e-3

In [None]:
def sample(model, A, y, steps=64, sigma_y=1e-3 ** 2):
    r"""Draws one posterior sample per observation with a DDPM sampler.

    The sampler is driven by a ``PosteriorDenoiser`` that conditions ``model`` on the
    observations ``y`` through the linear operator ``measure(A, .)``. Randomness comes
    from the notebook-level ``rng``.

    Arguments:
        model: Denoiser used as the prior.
        A: Per-observation measurement matrices. The signal dimension is read from
            the last axis, which is the axis ``measure`` contracts against ``x``.
        y: Observations, one row per sample.
        steps: Number of DDPM sampling steps.
        sigma_y: Observation-noise parameter forwarded to ``PosteriorDenoiser``. The
            default ``1e-3 ** 2`` is the square of the noise scale used to generate
            ``y`` in this notebook. NOTE(review): confirm whether PosteriorDenoiser
            expects a variance or a standard deviation here.

    Returns:
        The sampler output, presumably shaped like the initial noise,
        ``(len(y), A.shape[-1])``.
    """
    sampler = DDPM(
        PosteriorDenoiser(
            model=model,
            A=inox.Partial(measure, A),
            y=y,
            sigma_y=sigma_y,
        ),
    )

    # Take the signal dimension from A rather than hard-coding 5, so other problem
    # sizes work without editing this function.
    z = rng.normal((len(y), A.shape[-1]))

    return sampler(z, steps=steps, key=rng.split())

# Initial posterior samples, using `GaussianDenoiser` (presumably from `utils`) as
# the starting prior before any training.
x_bis = sample(model=GaussianDenoiser(), A=A, y=y)
show(x_bis)

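## Training

Alternate two steps: train a denoiser on the current posterior samples `x_bis`, then redraw `x_bis` from the posterior using that denoiser. Each round continues from the previous one, so the cells below must run in order.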

In [None]:
def train(x, model=None, steps=65536, A=None, y=None, batch_size=1024, log_every=100):
    r"""Trains a denoiser on samples ``x`` observed through ``A`` with observations ``y``.

    Each step draws a random minibatch of indices, samples Gaussian noise ``z`` and
    times ``t ~ Beta(3, 3)``, and takes one Adam step on the ``DenoiserLoss`` objective.
    The learning rate decays linearly from 1e-3 to 1e-6 over ``steps``.

    Arguments:
        x: Training samples, indexed along the first axis.
        model: Model to continue training. If None, a fresh one is built with
            ``make_model``.
        steps: Number of optimization steps.
        A: Per-sample measurement matrices, aligned with ``x`` along the first axis.
            Defaults to the module-level ``A`` (the previous implicit behavior).
        y: Per-sample observations, aligned with ``x``. Defaults to the module-level ``y``.
        batch_size: Number of samples per optimization step.
        log_every: Positive interval, in steps, at which the loss is copied to the host
            and shown in the progress bar.

    Returns:
        The trained model, reassembled from its static part, parameters and other state.

    Raises:
        ValueError: If ``x``, ``A`` and ``y`` differ in length, or ``log_every < 1``.
    """
    # A and y used to be read implicitly from the notebook namespace. Keep that as
    # the default so existing `train(x_bis)` calls behave the same.
    if A is None:
        A = globals()["A"]
    if y is None:
        y = globals()["y"]

    # Minibatches index x, A and y with the same indices, so they must be aligned.
    if not (len(x) == len(A) == len(y)):
        raise ValueError(
            f"x, A and y must have the same length, got {len(x)}, {len(A)} and {len(y)}"
        )
    if log_every < 1:
        raise ValueError(f"log_every must be positive, got {log_every}")

    if model is None:
        model = make_model(key=rng.split())

    static, params, others = model.partition(nn.Parameter)

    scheduler = optax.linear_schedule(init_value=1e-3, end_value=1e-6, transition_steps=steps)
    optimizer = optax.adam(learning_rate=scheduler)
    opt_state = optimizer.init(params)

    objective = DenoiserLoss()

    def ell(params, others, x, A, y, key):
        # Fresh noise and diffusion times for every minibatch.
        keys = jax.random.split(key, 2)

        z = jax.random.normal(keys[0], shape=x.shape)
        # Beta(3, 3) is symmetric on [0, 1] and concentrates times around the middle.
        t = jax.random.beta(keys[1], a=3, b=3, shape=x.shape[:1])

        return objective(static(params, others), x, z, t, A=inox.Partial(measure, A), y=y)

    @jax.jit
    def sgd_step(params, others, opt_state, x, A, y, key):
        loss, grads = jax.value_and_grad(ell)(params, others, x, A, y, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        return loss, params, opt_state

    for step in (bar := trange(steps)):
        i = rng.randint(shape=(batch_size,), minval=0, maxval=len(x))

        loss, params, opt_state = sgd_step(params, others, opt_state, x[i], A[i], y[i], rng.split())

        # float() blocks until the step has finished. Syncing only periodically lets
        # JAX keep dispatching steps asynchronously instead of stalling every step.
        if step % log_every == 0 or step == steps - 1:
            bar.set_postfix(loss=float(loss))

    return static(params, others)

In [None]:
# Round 1: train a denoiser on the current samples, then redraw posterior samples with it.
model = train(x=x_bis)
x_bis = sample(model=model, A=A, y=y)
show(x_bis)

In [None]:
# Round 2: continues from the previous round's `x_bis`. Re-running this cell on its
# own advances the iteration instead of repeating it.
model = train(x=x_bis)
x_bis = sample(model=model, A=A, y=y)
show(x_bis)

In [None]:
# Round 3: same train-then-resample step, starting from the samples of round 2.
model = train(x=x_bis)
x_bis = sample(model=model, A=A, y=y)
show(x_bis)

In [None]:
# Round 4: last train-then-resample step. The resulting `model` is reused below.
model = train(x=x_bis)
x_bis = sample(model=model, A=A, y=y)
show(x_bis)

In [None]:
# Unconditional samples: run DDPM with the last trained model alone (no
# PosteriorDenoiser, so no conditioning on `y`), starting from noise shaped like the
# ground-truth data `x` so the result can be compared with it visually.
sampler = DDPM(model)
x_ter = sampler(rng.normal(x.shape), steps=64, key=rng.split())
show(x_ter)