In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp


from diffuse.conditional import CondSDE
from diffuse.inference import generate_cond_sampleV2
from diffuse.sde import SDE, SDEState, LinearSchedule
from diffuse.unet import UNet

from vraie_vie.create_dataset import WMH
from vraie_vie.design_wmh import main
from vraie_vie.utils import maskAno, maskSpiral, slice_inverse_fourier

import matplotlib.pyplot as plt
from functools import partial
import os

# Explicitly keep JAX in 32-bit mode (64-bit floats/ints disabled).
jax.config.update("jax_enable_x64", False)

  """
  """


In [None]:
# Root of the WMH project data. Defaults to the cluster location; set the
# WMH_ROOT environment variable to run against a copy stored elsewhere
# (e.g. a local checkout of the dataset).
WMH_ROOT = os.environ.get(
    "WMH_ROOT", "/lustre/fswork/projects/rech/hlp/uha64uw/projet_p/WMH"
)

config = {
    "modality": "FLAIR",
    "slice_size_template": 49,
    "begin_slice": 26,
    "flair_template_path": os.path.join(WMH_ROOT, "MNI-FLAIR-2.0mm.nii.gz"),
    "path_dataset": WMH_ROOT,
    # Trailing "" keeps the trailing separator of the original path ("models/").
    "save_path": os.path.join(WMH_ROOT, "models", ""),
    "n_epochs": 4000,
    "batch_size": 32,
    "num_workers": 0,
    "n_t": 32,
    "tf": 2.0,  # final diffusion time used throughout the notebook
    "lr": 2e-4,
}

# ODE Sampling

In [None]:
# Retrieve trained parameters.
# jnp.load returns an open NpzFile for .npz archives; using it as a context
# manager closes the file handle once the parameters have been extracted.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load checkpoints
# you produced yourself.
with jnp.load(
    os.path.join(config["save_path"], "ann_3795.npz"), allow_pickle=True
) as checkpoint:
    # "params" holds a pickled object in a single-element object array;
    # .item() unwraps it.
    params = checkpoint["params"].item()

# Get the datasets (`.dataset` is indexed directly to fetch single samples).
wmh = WMH(config)
wmh.setup()
train_loader = wmh.get_train_dataloader().dataset

# Get the ScoreNet
nn_unet = UNet(config["tf"] / config["n_t"], 64, upsampling="pixel_shuffle")


def nn_score_(x, t, scoreNet, params):
    """Evaluate the score network on a batch.

    Args:
        x: network input.
        t: diffusion time(s) passed alongside ``x``.
        scoreNet: module exposing ``apply(params, x, t)``.
        params: trained parameters handed to ``scoreNet.apply``.

    Returns:
        Whatever ``scoreNet.apply(params, x, t)`` returns.
    """
    score = scoreNet.apply(params, x, t)
    return score


# Bind the network and its trained parameters so the score is called as (x, t).
nn_score = partial(nn_score_, params=params, scoreNet=nn_unet)

In [None]:
from scipy import integrate
import numpy as np

# Shape of a single training sample; ode_fn uses it to unflatten the flat state
# vector that scipy integrates.
shape_data = train_loader[0].shape


def ode_fn(t, x):
    """Adapter exposing the JAX reverse-time drift to scipy's ``solve_ivp``.

    scipy passes a scalar time and a flat vector; the drift expects an
    ``SDEState`` holding an array shaped like a training sample and a
    one-element time array. The drift is returned flattened as a NumPy array.
    """
    sample = jnp.reshape(jnp.asarray(x), shape_data)
    velocity = drift_fn(SDEState(sample, jnp.array([t])))
    return np.asarray(jnp.ravel(velocity))


# Reverse-time ODE drift built from the trained score network (read by ode_fn).
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta=beta)
drift_fn = sde.reverso_ode(config["tf"], nn_score)

# Integrate from t = tf down to t = 0, starting from the first training sample
# (flattened, since solve_ivp works on 1-D state vectors).
# NOTE(review): the diffrax cell below integrates the same drift from 0 to tf
# starting from Gaussian noise; here the direction is reversed and the start is
# a data sample -- presumably a data -> latent encoding check; confirm.
x = np.asarray(train_loader[0].flatten())
res = integrate.solve_ivp(
    fun=ode_fn,
    t_span=(config["tf"], 0),
    y0=x,
    method="RK45",
    rtol=1e-5,
    atol=1e-5,
)

In [None]:
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController

# Schedule, SDE and the reverse-time ODE drift evaluated by ode_fn below.
beta = LinearSchedule(t0=0.0, T=2.0, b_min=0.02, b_max=5.0)
sde = SDE(beta=beta)
drift_fn = sde.reverso_ode(config["tf"], nn_score)

# Per-sample shape (the solver state is the flattened sample), adaptive-step
# tolerances, initial-step resolution and RNG seed.
shape_data = train_loader[0].shape
stepsize_controller = PIDController(atol=1e-5, rtol=1e-5)
n_steps = 1000
key = jax.random.PRNGKey(0)


def ode_fn(t, x, args=None):
    """Vector field for diffrax: the reverse-time drift on flattened states.

    diffrax's ``ODETerm`` always calls its vector field as ``f(t, y, args)``,
    so a third parameter is required. ``args`` is unused and defaults to
    ``None``, so ``ode_fn(t, x)`` still works.

    Args:
        t: current integration time (scalar).
        x: flat state vector; reshaped to ``shape_data`` for the drift.
        args: diffrax's pass-through argument (ignored).

    Returns:
        The drift, flattened back to the solver's 1-D layout.
    """
    x = x.reshape(shape_data)
    state = SDEState(x, jnp.array([t]))
    drift = drift_fn(state)
    return drift.flatten()


# Draw 10 samples by solving the reverse-time ODE from Gaussian noise with an
# adaptive Dopri5 solver, and plot channel 0 of each result.
term = ODETerm(ode_fn)
solver = Dopri5()
for _ in range(10):
    key, subkey = jax.random.split(key)
    # Flattened because ode_fn works on 1-D state vectors.
    y0 = jax.random.normal(subkey, shape_data).flatten()
    solution = diffeqsolve(
        term,
        solver,
        t0=0,
        t1=config["tf"],
        dt0=config["tf"] / n_steps,
        y0=y0,
        stepsize_controller=stepsize_controller,
    )
    # diffeqsolve returns a Solution object, not an array. With the default
    # SaveAt(t1=True), solution.ys has shape (1, n_features), so take the last
    # saved state and restore the sample shape before plotting.
    x_final = solution.ys[-1].reshape(shape_data)
    plt.imshow(x_final[..., 0], cmap="gray")
    plt.show()

# Sampling

In [None]:
# Retrieve trained parameters (re-created here so this section can be run on
# its own).
# jnp.load returns an open NpzFile for .npz archives; using it as a context
# manager closes the file handle once the parameters have been extracted.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load checkpoints
# you produced yourself.
with jnp.load(
    os.path.join(config["save_path"], "ann_3795.npz"), allow_pickle=True
) as checkpoint:
    # "params" holds a pickled object in a single-element object array;
    # .item() unwraps it.
    params = checkpoint["params"].item()

# Get the datasets (`.dataset` is indexed directly to fetch single samples).
wmh = WMH(config)
wmh.setup()
train_loader = wmh.get_train_dataloader().dataset

# Get the ScoreNet
nn_unet = UNet(config["tf"] / config["n_t"], 64, upsampling="pixel_shuffle")


def nn_score_(x, t, scoreNet, params):
    """Evaluate the score network on a batch.

    Args:
        x: network input.
        t: diffusion time(s) passed alongside ``x``.
        scoreNet: module exposing ``apply(params, x, t)``.
        params: trained parameters handed to ``scoreNet.apply``.

    Returns:
        Whatever ``scoreNet.apply(params, x, t)`` returns.
    """
    score = scoreNet.apply(params, x, t)
    return score


# Bind the network and its trained parameters so the score is called as (x, t).
nn_score = partial(nn_score_, params=params, scoreNet=nn_unet)

In [None]:
# Unconditional sampling: run the reverse SDE from Gaussian noise at t = tf
# over `n_steps` uniform steps and plot channel 0 of the final state.
n_steps = 1000
key = jax.random.PRNGKey(0)

ts = jnp.array([config["tf"]])
# Uniform steps covering the full horizon [0, tf]; derived from config instead
# of a hard-coded 2.0 so they stay consistent with `ts` if tf changes.
dts = jnp.full((n_steps,), config["tf"] / n_steps)

beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta=beta)

# Loop-invariant work: build the sampler and look up the sample shape once.
revert_sde = partial(sde.reverso, score=nn_score, dts=dts)
sample_shape = train_loader[0].shape

for _ in range(10):
    key, subkey = jax.random.split(key)
    init_samples = jax.random.normal(subkey, sample_shape)
    state_f = SDEState(position=init_samples, t=ts)

    key, subkey = jax.random.split(key)
    state_0, state_Ts = revert_sde(subkey, state_f)

    plt.imshow(state_Ts.position[-1][..., 0], cmap="gray")
    plt.show()

In [None]:
# Add noise to a batch of training images at random diffusion times, then run
# the reverse SDE from one of the noised images.
key = jax.random.PRNGKey(0)

# Build the SDE before `sde.path` is used below, so this cell does not depend on
# whichever `sde` an earlier cell left in the kernel.
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta=beta)

# --- Add noise ---
x0_samples = jnp.array([train_loader[k] for k in range(config["batch_size"])])
n_x0 = x0_samples.shape[0]

# One diffusion time per sample: keys, states and times are vmapped together,
# so their leading dimensions must all equal n_x0. The last sample is pushed to
# the terminal time tf. Every random draw uses a fresh subkey (never reuse keys).
key, subkey = jax.random.split(key)
ts_noise = jax.random.uniform(
    subkey, (n_x0 - 1, 1), minval=1e-5, maxval=config["tf"]
)
ts_noise = jnp.concatenate([ts_noise, jnp.array([[config["tf"]]])], axis=0)

state_0 = SDEState(x0_samples, jnp.zeros((n_x0, 1)))
key, subkey = jax.random.split(key)
keys_x = jax.random.split(subkey, n_x0)
state = jax.vmap(sde.path, in_axes=(0, 0, 0))(keys_x, state_0, ts_noise)

# --- Sample (reverse SDE) ---
n_steps = 1000
sample_idx = 30  # batch element to denoise; the next cell plots the same index

ts = jnp.array([config["tf"]])
dts = jnp.full((n_steps,), config["tf"] / n_steps)

# NOTE(review): element `sample_idx` was noised only up to ts_noise[sample_idx]
# (random, below tf unless it is the last element), yet the reverse SDE below
# starts at t = tf and integrates over the whole horizon. Confirm this is
# intended; otherwise start from ts_noise[sample_idx] with matching `dts`.
init_samples = state.position[sample_idx]
state_f = SDEState(position=init_samples, t=ts)

revert_sde = partial(sde.reverso, score=nn_score, dts=dts)

key, subkey = jax.random.split(key)
state_0, state_Ts = revert_sde(subkey, state_f)

In [None]:
# Side by side: the clean training image (batch index 30) and the reverse-SDE
# output started from its noised version.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(x0_samples[30][..., 0], cmap="gray")
axs[0].set_title("Original Image")
axs[1].imshow(state_Ts.position[-1][..., 0], cmap="gray")
axs[1].set_title("Reverse-SDE sample")
plt.show()

# Conditional

In [None]:
# Load the *test* split. It is bound to `test_loader`, not `train_loader`, so
# earlier cells that read `train_loader` keep seeing the training split if they
# are re-run after this one.
wmh = WMH(config)
wmh.setup()
test_loader = wmh.get_test_dataloader().dataset

key = jax.random.PRNGKey(0)
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)

# jnp.load returns an open NpzFile for .npz archives; the context manager closes
# it once the parameters are extracted.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load checkpoints
# you produced yourself.
with jnp.load(
    os.path.join(config["save_path"], "ann_3795.npz"), allow_pickle=True
) as checkpoint:
    params = checkpoint["params"].item()

nn_unet = UNet(config["tf"] / config["n_t"], 64, upsampling="pixel_shuffle")


def nn_score_(x, t, scoreNet, params):
    """Evaluate ``scoreNet`` with ``params`` at ``(x, t)`` via ``apply``."""
    return scoreNet.apply(params, x, t)


nn_score = partial(nn_score_, scoreNet=nn_unet, params=params)

sde = SDE(beta=beta)


# Test image used for the conditional reconstruction below (channel 0 shown).
x = jnp.array(test_loader[680])
plt.imshow(x[..., 0], cmap="gray")
plt.show()

In [None]:
# Conditional reconstruction from spiral-subsampled Fourier measurements.
size = (92, 112)  # NOTE(review): presumably x.shape[:2]; confirm

mask_spiral = maskSpiral(img_shape=size, num_spiral=3, num_samples=50000, sigma=0.2)
# Use the notebook-wide final time rather than a second hard-coded 2.0.
cond_sde = CondSDE(beta=beta, mask=mask_spiral, tf=config["tf"], score=nn_score)

xi = jnp.array([3.5, 2.0])  # FOV, k_max
y = mask_spiral.measure(xi, x)
# Inverse Fourier transform of the measured channel-0 data, shown as a baseline.
x_sub = slice_inverse_fourier(y[..., 0])

mask = mask_spiral.make(xi)

# NOTE(review): 1000 and 300 are positional arguments of generate_cond_sampleV2;
# their meaning is defined there -- confirm before tuning.
res = generate_cond_sampleV2(y, mask, key, cond_sde, x.shape, 1000, 300)

In [None]:
# 2x3 panel. Top row: original image, Fourier mask, reconstructed image.
# Bottom row: original anomaly map, subsampled inverse-Fourier reconstruction,
# reconstructed anomaly map.
mask = mask_spiral.make(xi)
recon = res[0][0].position[-1]

fig, axs = plt.subplots(2, 3, figsize=(20, 10))
panels = (
    ((0, 0), x[..., 0], "Original Image"),
    ((0, 1), mask, "Fourier Mask"),
    ((0, 2), recon[..., 0], "Reconstructed Image"),
    ((1, 0), x[..., 1], "Original Anomaly Map"),
    ((1, 1), x_sub, "Subsampled reconstruction"),
    ((1, 2), recon[..., 1], "Reconstructed Anomaly Map"),
)
for (row, col), image, title in panels:
    axs[row, col].imshow(image, cmap="gray")
    axs[row, col].set_title(title)

plt.tight_layout()
plt.show()

# Design

In [None]:
# Run `main` from vraie_vie.design_wmh (presumably the experimental-design
# routine) with a fixed seed.
# NOTE(review): the chained assignment also resets the global `key` to
# PRNGKey(0); the next cell re-seeds `key` itself, so this is harmless here.
rng_key = key = jax.random.PRNGKey(0)
state = main(rng_key)

# Evaluation

In [None]:
# Test split and fixed RNG / noise schedule for the evaluation section.
wmh = WMH(config)
wmh.setup()
test_loader = wmh.get_test_dataloader().dataset

key = jax.random.PRNGKey(0)
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)

# jnp.load returns an open NpzFile for .npz archives; the context manager closes
# it once the parameters are extracted.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load checkpoints
# you produced yourself.
with jnp.load(
    os.path.join(config["save_path"], "ann_3795.npz"), allow_pickle=True
) as checkpoint:
    params = checkpoint["params"].item()

nn_unet = UNet(config["tf"] / config["n_t"], 64, upsampling="pixel_shuffle")


def nn_score_(x, t, scoreNet, params):
    """Evaluate the score network on a batch.

    Args:
        x: network input.
        t: diffusion time(s) passed alongside ``x``.
        scoreNet: module exposing ``apply(params, x, t)``.
        params: trained parameters handed to ``scoreNet.apply``.

    Returns:
        Whatever ``scoreNet.apply(params, x, t)`` returns.
    """
    score = scoreNet.apply(params, x, t)
    return score


# Score function with the trained network bound in, and the (unconditional) SDE.
nn_score = partial(nn_score_, params=params, scoreNet=nn_unet)
sde = SDE(beta=beta)

# Preview channel 0 of the test image used in the evaluation below.
x = jnp.array(test_loader[680])

plt.imshow(x[..., 0], cmap="gray")
plt.tight_layout()
plt.show()

In [None]:
# Conditional sampling with the anomaly mask.
size = (92, 112)  # NOTE(review): presumably x.shape[:2]; confirm

mask_ano = maskAno(img_shape=size)
# Use the notebook-wide final time rather than a second hard-coded 2.0.
cond_sde = CondSDE(beta=beta, mask=mask_ano, tf=config["tf"], score=nn_score)

xi = jnp.array(0.0)
y = mask_ano.measure(xi, x)

mask = mask_ano.make(xi)

# NOTE(review): 1000 and 300 are positional arguments of generate_cond_sampleV2;
# their meaning is defined there -- confirm before tuning.
res = generate_cond_sampleV2(y, mask, key, cond_sde, x.shape, 1000, 300)

In [None]:
# Evaluation figure. Top row: original image and anomaly-map channels.
# Bottom row: the same two channels of the conditional sample.
fig, axs = plt.subplots(2, 2, figsize=(20, 10))
axs[0, 0].imshow(x[..., 0], cmap="gray")
axs[0, 0].set_title("Original Image")
axs[0, 1].imshow(x[..., 1], cmap="gray")
axs[0, 1].set_title("Original Anomaly Map")
# NOTE(review): the spiral cell displays position[-1] of the same
# generate_cond_sampleV2 output; index 1 here may be a typo for -1 -- confirm.
axs[1, 0].imshow(res[0][0].position[1, ..., 0], cmap="gray")
axs[1, 0].set_title("Conditional sample (channel 0)")
axs[1, 1].imshow(res[0][0].position[1, ..., 1], cmap="gray")
axs[1, 1].set_title("Conditional sample (channel 1)")

plt.tight_layout()
plt.show()