In [1]:
%load_ext autoreload
%autoreload 2 
%matplotlib nbagg
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from pathlib import Path

import haiku as hk
os.environ['XLA_FLAGS']='--xla_gpu_cuda_data_dir=/gpfslocalsys/cuda/10.1.2'
import jax
from jax.experimental import optix
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from nsec.normalization import SNParamsTree as CustomSNParamsTree
import pickle
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
except IndexError:
    pass
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
from tqdm.notebook import tqdm

from nsec.datasets.fastmri import mri_noisy_generator
from nsec.mri.model import get_model
from nsec.samplers import ScoreHamiltonianMonteCarlo
from nsec.tempered_sampling import TemperedMC

In [2]:
# Render every image in grayscale unless a colormap is passed explicitly.
plt.rcParams.update({'image.cmap': 'gray'})

In [3]:
# Validation data configuration.
batch_size = 32
contrast = 'CORPD_FBK'
magnitude_images = True
noise_power_spec = 30.0  # same value as 3 * 1e1

val_mri_gen = mri_noisy_generator(
    split='val',
    contrast=contrast,
    magnitude=magnitude_images,
    image_size=320,
    batch_size=batch_size,
    scale_factor=1e6,
    noise_power_spec=noise_power_spec,
)
# Batch layout: (image_noisy, noise_power), noise_realisation
# noise_realisation is the full noise that was added to the image, not the
# standard-normal epsilon it was scaled from.

In [4]:
# Build the score/denoiser network. Of the eight values returned by get_model,
# only `model`, `loss_fn` and `rng_seq` are bound; the rest are discarded.
model, loss_fn, _, _, _, _, _, rng_seq = get_model(
    opt=False,
    # Reuse the config flag so the network and the data generator agree on
    # whether images are magnitude-only.
    magnitude_images=magnitude_images,
    pad_crop=False,
    stride=False,
)



In [5]:
# Load the trained network weights: (params, state, sn_state).
# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
checkpoint_path = Path('../conv-dae-L2-mri-30.0_mag_no_stride_backup.pckl')
with checkpoint_path.open('rb') as checkpoint_file:
    params, state, sn_state = pickle.load(checkpoint_file)

In [6]:
from functools import partial

# Bind the weights, the state and a single RNG key (drawn once, reused on every
# call) so that `score(x, sigma, is_training=...)` returns (output, new_state).
eval_rng_key = next(rng_seq)
score = partial(model.apply, params, state, eval_rng_key)

In [7]:
# Draw one validation batch and run the network on it. The noise level fed to
# the network is the generator's noise power scaled by 0.1, with three trailing
# singleton axes added (presumably giving (batch, 1, 1, 1) -- confirm).
(x, noise_power), su = next(val_mri_gen)
s = noise_power[..., None, None, None] * 1e-1
# NOTE(review): this rebinds the global `state`; `score` keeps the state it captured.
res, state = score(x, s, is_training=False)

In [8]:
# For the first 10 images show: noisy input | x - su (noisy minus its noise
# realisation) | |network output| | x + s**2 * output (one-step denoised).
for idx in range(10):
    fig, axs = plt.subplots(1, 4, sharex=True, sharey=True)
    correction = s[idx, :, :, 0] ** 2 * res[idx, ..., 0]
    panels = (
        jnp.abs(x[idx, ..., 0]),
        jnp.abs((x - su)[idx, ..., 0]),
        jnp.abs(res[idx, ..., 0]),
        jnp.abs(x[idx, ..., 0] + correction),
    )
    for ax, panel in zip(axs, panels):
        ax.imshow(panel, cmap='gray')
        ax.axis('off')
    axs[0].set_title("%0.3f" % s[idx, 0, 0, 0])
    axs[2].set_title("%0.3f" % jnp.std(correction))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [9]:
# Posterior sampling: fill in a masked region of the first validation image by sampling from the model

In [10]:
# Quick look at the first two validation images (left: x[0], middle: x[1]).
# The 1x3 grid layout is kept; the third slot is intentionally left empty.
fig = plt.figure()
for position, idx in ((131, 0), (132, 1)):
    ax = fig.add_subplot(position)
    # x[idx, ..., 0] is already a 320x320 image (image_size=320), so no reshape.
    ax.imshow(jnp.abs(x[idx, ..., 0]))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fc208832908>

In [11]:
# Binary inpainting mask, same shape/dtype as one image: 1 = observed pixel,
# 0 = hidden rectangle (axis-1 range 130:170, axis-2 range 170:200).
hole_rows, hole_cols = slice(130, 170), slice(170, 200)
mask = np.ones_like(x[0:1])
mask[:, hole_rows, hole_cols] = 0

In [12]:
# Observed data: the first image with the masked rectangle zeroed out.
# The multiplication already allocates a new array, so no explicit copy is needed.
x_black = np.asarray(x[0:1]) * mask

In [13]:
def likelihood_fn(x_, sigma, y=None, noise_std=None, obs_mask=None):
    """Masked Gaussian *negative* log-likelihood (up to a constant) of one chain.

    Returns ``0.5 * sum(obs_mask * ((y - x_) / (noise_std + sigma)) ** 2)``.
    Because this is a negative log-likelihood, ``score_fn`` subtracts its
    gradient from the prior score.

    Args:
        x_: candidate image for a single chain (the batch axis is removed by
            ``jax.vmap`` below).
        sigma: current tempering noise level, broadcastable against ``x_``.
        y: observed image. Defaults to ``x_black[0]``, the single masked image
            the sampler is conditioned on, not the whole 32-image batch ``x``.
        noise_std: observation noise level. Defaults to ``s[0]``, the level of
            that same image, not the whole batch ``s``.
        obs_mask: 1 where a pixel is observed, 0 inside the hole.
            Defaults to ``mask[0]``.

    Returns:
        Scalar negative log-likelihood.
    """
    # Defaults are resolved at call time so the notebook globals are picked up
    # exactly as the original code did.
    if y is None:
        y = x_black[0]
    if noise_std is None:
        noise_std = s[0]
    if obs_mask is None:
        obs_mask = mask[0]
    # NOTE(review): the std is combined as noise_std + sigma; a Gaussian
    # convolution would give sqrt(noise_std**2 + sigma**2) -- confirm intent.
    return jnp.sum(obs_mask * ((y - x_) / (noise_std + sigma)) ** 2) / 2.

# Per-chain gradient of the negative log-likelihood w.r.t. the image.
score_likelihood = jax.vmap(jax.grad(likelihood_fn))

In [14]:
def score_fn(x, sigma):
    """Posterior score: the network's prior score minus the likelihood gradient.

    ``x`` holds flattened 320x320 images (one row per chain) and ``sigma`` one
    noise level per chain; the result is flattened back to ``(-1, 320*320)``.
    """
    images = x.reshape((-1, 320, 320, 1))
    sigmas = sigma.reshape((-1, 1, 1, 1))
    # score(...) returns (output, new_state); only the output is needed.
    prior_score = score(images, sigmas, is_training=False)[0]
    # likelihood_fn is a negative log-likelihood, hence the subtraction.
    nll_grad = score_likelihood(images, sigmas)
    return (prior_score - nll_grad).reshape((-1, 320 * 320))

In [15]:
# Chain start: the masked observation plus Gaussian noise of std 0.1, flattened
# to one row per chain. A seeded generator makes the start reproducible.
init_rng = np.random.default_rng(0)
init_noise = 0.1 * init_rng.standard_normal(x_black.shape)
init_image = (x_black + init_noise).reshape((-1, 320 * 320)).astype('float32')

In [None]:
s0 = 0.1  # reference noise level; also the starting temperature below


def make_kernel_fn(target_log_prob_fn, target_score_fn, sigma):
    """Build the score-based HMC kernel used at noise level ``sigma``.

    The step size scales as sqrt(sigma / s0), so it equals 0.001 at sigma == s0
    and shrinks as sigma decreases.
    """
    scaled_step = 0.001 * (sigma / s0) ** 0.5
    kernel_kwargs = dict(
        target_log_prob_fn=target_log_prob_fn,
        target_score_fn=target_score_fn,
        step_size=scaled_step,
        num_leapfrog_steps=3,
        num_delta_logp_steps=4,
    )
    return ScoreHamiltonianMonteCarlo(**kernel_kwargs)

num_results = 100
num_burnin_steps = 10

# One chain, starting at noise level s0.
# NOTE(review): the argument is named inverse_temperatures but receives a noise
# level -- confirm TemperedMC's convention.
tmc = TemperedMC(
    target_score_fn=score_fn,
    inverse_temperatures=s0 * np.ones([1]),
    make_kernel_fn=make_kernel_fn,
    gamma=0.98,
    min_steps_per_temp=10,
    num_delta_logp_steps=4,
)


def _trace_fn(_state, kernel_results):
    """Record acceptance, post-tempering temperatures and tempering log-accept ratio."""
    return (kernel_results.pre_tempering_results.is_accepted,
            kernel_results.post_tempering_inverse_temperatures,
            kernel_results.tempering_log_accept_ratio)


samples, trace = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_image,
    kernel=tmc,
    num_burnin_steps=num_burnin_steps,
    trace_fn=_trace_fn,
    seed=jax.random.PRNGKey(0),
)



[Traced<ShapedArray(float32[1,102400]):JaxprTrace(level=1/0)>]


In [None]:
# Side by side: noisy input vs. masked observation, shown with the same display range.
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
for ax, image in zip(axs, (x[0], x_black[0])):
    ax.imshow(jnp.squeeze(jnp.abs(image)), vmin=10, vmax=150)

In [None]:
# Noisy input | masked observation | last posterior sample.
# All panels share one display range so their brightness is directly comparable
# (previously the sample panel was auto-scaled).
display_range = dict(vmin=10, vmax=150)
fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 5))
axs[0].imshow(jnp.squeeze(jnp.abs(x[0])), **display_range)
axs[0].set_title('noisy input')
axs[1].imshow(jnp.squeeze(jnp.abs(x_black[0])), **display_range)
axs[1].set_title('masked input')
axs[2].imshow(jnp.squeeze(jnp.abs(samples[-1].reshape((320, 320)))), **display_range)
axs[2].set_title('last sample');