# Sampling

In [1]:
from functools import partial
import os

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from optax import EmaState, EmptyState, ScaleByAdamState, ScaleByScheduleState

from diffuse.sde import SDE, SDEState, LinearSchedule
from diffuse.unet import UNet
from vraie_vie.create_dataset import WMH

In [2]:
# Root directory of the WMH data. It defaults to the original cluster location
# but can be redirected with the WMH_ROOT environment variable, so the notebook
# also runs on machines where that absolute path does not exist.
WMH_ROOT = os.environ.get("WMH_ROOT", "/lustre/fswork/projects/rech/hlp/uha64uw/aistat24/WMH")

# Notebook-wide settings; every entry below is read by later cells or passed to WMH(config).
config = {
    "modality": "FLAIR",
    "slice_size_template": 91,
    # All three paths are derived from the single root above.
    "flair_template_path": f"{WMH_ROOT}/MNI-FLAIR-2.0mm.nii.gz",
    "path_dataset": WMH_ROOT,
    "save_path": f"{WMH_ROOT}/models/",
    "n_epochs": 4000,
    "batch_size": 32,
    "num_workers": 0,
    "n_t": 32,
    "tf": 2.0,
    "lr": 2e-4,
}

In [3]:
# Restore the trained parameters and the optimiser / EMA states from the checkpoint.
#
# NOTE(security): allow_pickle=True lets the loader unpickle arbitrary Python
# objects stored in the archive. Only open checkpoints from a trusted source.
checkpoint_file = os.path.join(config["save_path"], "ann_215.npz")
checkpoint = jnp.load(checkpoint_file, allow_pickle=True)

# Pull every stored entry out of the archive once and reuse it below.
params_entry = checkpoint["params"]
ema_entry = checkpoint["ema_state"]
adam_entry = checkpoint["opt_state_2"]
schedule_entry = checkpoint["opt_state_3"]

# "params" is stored as an array wrapping a single object; .item() unwraps it.
params = params_entry.item()

# The EMA state is stored positionally: [0] is the count, [1] the EMA values.
ema_state = EmaState(count=ema_entry[0], ema=ema_entry[1])

# Rebuild the optimiser state tuple: an empty state followed by the
# (Adam, schedule) pair. Adam is stored positionally as count, mu, nu.
adam_state = ScaleByAdamState(count=adam_entry[0], mu=adam_entry[1], nu=adam_entry[2])
schedule_state = ScaleByScheduleState(schedule_entry[0])
opt_state = (EmptyState(), (adam_state, schedule_state))

# Build the WMH data module from the notebook configuration and prepare it.
wmh = WMH(config)
wmh.setup()

# NOTE: despite its name, `train_loader` is the `.dataset` attribute of the
# training dataloader, not the dataloader itself; later cells index it
# directly (`train_loader[0]`, `train_loader[k]`).
train_dataloader = wmh.get_train_dataloader()
train_loader = train_dataloader.dataset

# Instantiate the score network.
# NOTE(review): the first positional argument is tf / n_t and the second is 64;
# their meaning is defined by diffuse.unet.UNet - presumably a time step and a
# feature width, confirm against that class.
nn_unet = UNet(
    config["tf"] / config["n_t"],
    64,
    upsampling="pixel_shuffle",
)

def nn_score_(x, t, scoreNet, params):
    """Evaluate ``scoreNet`` on ``x`` at time ``t`` with the given parameters.

    Args:
        x: Input forwarded unchanged to ``scoreNet.apply``.
        t: Time value forwarded unchanged to ``scoreNet.apply``.
        scoreNet: Any object exposing ``apply(params, x, t)``.
        params: Parameters passed as the first argument of ``apply``.

    Returns:
        Whatever ``scoreNet.apply(params, x, t)`` returns.
    """
    score = scoreNet.apply(params, x, t)
    return score

# Freeze the network and its trained parameters so callers only supply (x, t).
nn_score = partial(nn_score_, params=params, scoreNet=nn_unet)

The history saving thread hit an unexpected error (OperationalError('disk I/O error')).History will not be written to the database.


In [None]:
# Draw several samples by integrating the reverse SDE from Gaussian noise.
n_steps = 1000  # number of reverse-time integration steps
n_samples = 10  # number of independent samples to draw and display
key = jax.random.PRNGKey(0)

# The reverse integration starts at t = tf and uses n_steps equal steps.
# The step size is derived from config["tf"] (not a separate literal) so the
# steps always add up to the start time stored in `ts`.
ts = jnp.array([config["tf"]])
dts = jnp.array([config["tf"] / n_steps] * n_steps)

# NOTE(review): T=2.0 currently equals config["tf"]; confirm whether the
# schedule horizon is meant to follow config["tf"] or stay fixed.
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta=beta)

# The reverse sampler does not depend on the loop variable, so build it once.
revert_sde = partial(sde.reverso, score=nn_score, dts=dts)

for sample_idx in range(n_samples):
    # Fresh subkey for the initial noise, shaped like one dataset item.
    key, subkey = jax.random.split(key)
    init_samples = jax.random.normal(subkey, train_loader[0].shape)
    state_f = SDEState(position=init_samples, t=ts)

    # Separate subkey for the randomness used by the reverse integration.
    key, subkey = jax.random.split(key)
    state_0, state_Ts = revert_sde(subkey, state_f)

    # Show channel 0 of the last entry of the returned trajectory.
    plt.imshow(state_Ts.position[-1][..., 0], cmap="gray")
    plt.title(f"Sample {sample_idx + 1}/{n_samples}")
    plt.show()

In [None]:
# --- Forward pass: add noise to a batch of training images ---
key = jax.random.PRNGKey(0)

# Define the SDE at the top of this cell so the forward pass below does not
# depend on an `sde` object left over from a previously executed cell.
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta=beta)

x0_samples = jnp.array([train_loader[k] for k in range(config["batch_size"])])
n_x0 = x0_samples.shape[0]

# Every random draw gets its own subkey; a JAX key must never be consumed twice.
key, key_t, key_path = jax.random.split(key, 3)

# One noising time per image: n_t - 1 uniform draws in [1e-5, tf), followed by
# tf itself as the last entry.
noise_ts = jax.random.uniform(key_t, (config["n_t"] - 1, 1), minval=1e-5, maxval=config["tf"])
noise_ts = jnp.concatenate([noise_ts, jnp.array([[config["tf"]]])], axis=0)

# The vmap below pairs image i with time i, so both leading axes must match.
assert noise_ts.shape[0] == n_x0, (
    f"need one noising time per image: n_t={noise_ts.shape[0]} but batch_size={n_x0}"
)

state_x0 = SDEState(x0_samples, jnp.zeros((n_x0, 1)))
keys_x = jax.random.split(key_path, n_x0)
state = jax.vmap(sde.path, in_axes=(0, 0, 0))(keys_x, state_x0, noise_ts)

# --- Reverse pass: denoise one of the noised images ---
n_steps = 1000

# Start time and step sizes are both derived from config["tf"] so they stay consistent.
ts = jnp.array([config["tf"]])
dts = jnp.array([config["tf"] / n_steps] * n_steps)

# NOTE(review): image 30 was noised at a uniformly drawn time (noise_ts[30]);
# only the last image (index -1) was noised at exactly tf, yet the reverse
# pass below starts at t = tf. Confirm that index 30 is intended.
init_samples = state.position[30]
state_f = SDEState(position=init_samples, t=ts)

revert_sde = partial(sde.reverso, score=nn_score, dts=dts)

key, subkey = jax.random.split(key)
state_0, state_Ts = revert_sde(subkey, state_f)

In [None]:
# Side-by-side comparison: the training image that was noised (left) and the
# last entry of the reverse-SDE trajectory (right). Channel 0 is shown for both.
fig, (ax_ref, ax_gen) = plt.subplots(1, 2)

ax_ref.imshow(x0_samples[30][..., 0], cmap="gray")
ax_ref.set_title("Training image x0_samples[30]")

ax_gen.imshow(state_Ts.position[-1][..., 0], cmap="gray")
ax_gen.set_title("Reverse-SDE output")

plt.show()