# Imports

In [1]:
%cd ..

/grad/bzhang388/pisp/physics_aware_diffusion


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import functools
import utils
import fstars 
import fstar_net
import os

from clu import metric_writers
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

from swirl_dynamics import templates
from swirl_dynamics.lib import diffusion as dfn_lib
from swirl_dynamics.lib import solvers as solver_lib
from swirl_dynamics.projects import probabilistic_diffusion as dfn



In [3]:
jax.devices()
tf.config.set_visible_devices([], device_type='GPU')

# Dataset

In [4]:
# Parameters for the computational task.

#L = 4 # number of levels (even number)
#s = 5 # leaf size
#r = 3 # rank
#
## Discretization of Omega (n_eta * n_eta).
#neta = (2**L)*s
#
## Number of sources/detectors (n_sc).
## Discretization of the domain of alpha in polar coordinates (n_theta * n_rho).
## For simplicity, these values are set equal (n_sc = n_theta = n_rho), facilitating computation.
#nx = (2**L)*s
neta = 60
nx = 60
# Standard deviation for the Gaussian blur.
blur_sigma = 0.5

# Number of training datapoints.
NTRAIN = 21000

# Number of testing datapoints.
NTEST = 500

# Total number
NTOTAL = NTRAIN + NTEST


In [5]:
name = '../data/fastmri60'

# Loading and preprocessing perturbation data (eta)
with h5py.File(f'{name}/eta60.h5', 'r') as f:
    # Read eta data, apply Gaussian blur, and reshape
    eta_re = f[list(f.keys())[0]][:NTRAIN, :].reshape((-1,neta,neta))
    blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
    eta_re = np.stack([blur_fn(eta_re[i, :, :].T) for i in range(NTRAIN)]).astype('float32')
    
mean_eta = np.mean(eta_re, axis = 0)
eta_re -= mean_eta
std_eta = np.std(eta_re)
eta_re /= std_eta

# Loading and preprocessing scatter data (Lambda)
with h5py.File(f'{name}/scatter60.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())

    # Process real part of scatter data
    tmp1 = f[keys[3]][:NTRAIN, :]
    tmp2 = f[keys[4]][:NTRAIN, :]
    tmp3 = f[keys[5]][:NTRAIN, :]
    scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

    # Process imaginary part of scatter data
    tmp1 = f[keys[0]][:NTRAIN, :]
    tmp2 = f[keys[1]][:NTRAIN, :]
    tmp3 = f[keys[2]][:NTRAIN, :]
    scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
    # Combine real and imaginary parts
    scatter = np.stack((scatter_re, scatter_im), axis=-2).astype('float32')

mean0, std0 = np.mean(scatter[:,:,:,0]), np.std(scatter[:,:,:,0])
mean1, std1 = np.mean(scatter[:,:,:,1]), np.std(scatter[:,:,:,1])
mean2, std2 = np.mean(scatter[:,:,:,2]), np.std(scatter[:,:,:,2])

scatter[:,:,:,0] -= mean0
scatter[:,:,:,0] /= std0
scatter[:,:,:,1] -= mean1
scatter[:,:,:,1] /= std1
scatter[:,:,:,2] -= mean2
scatter[:,:,:,2] /= std2

# Clean up temporary variables to free memory
del scatter_re, scatter_im, tmp1, tmp2, tmp3


In [6]:
eta_train = eta_re.reshape(-1, neta, neta, 1)
scatter_train = scatter.reshape(-1, nx**2, 2, 3)

In [7]:
# To avoid tf to use GPU memory

batch_size = 16
dict_data = {"x": eta_train}
dict_data["cond"] = {"channel:scatter0": scatter_train[:,:,:,0],
                     "channel:scatter1": scatter_train[:,:,:,1],
                     "channel:scatter2": scatter_train[:,:,:,2]}
dataset = tf.data.Dataset.from_tensor_slices(dict_data)
dataset = dataset.shuffle(NTRAIN) 
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
dataset = dataset.as_numpy_iterator()

The architecture is similar to the unconditional case. We provide additional args that specify how to resize the conditioning signal (in order to be compatible with the noisy sample for channel-wise concatenation).

In [8]:
r_index = utils.rotationindex(nx)
cart_mat = utils.SparsePolarToCartesian(neta, nx)

2024-07-02 02:48:59.596620: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 

In [9]:
import importlib
importlib.reload(fstars)

<module 'fstars' from '/grad/bzhang388/pisp/physics_aware_diffusion/fstars.py'>

In [10]:
fstarlist = [fstars.uncompressed_fstar( 
    nx = nx, 
    neta = neta,
    cart_mat = cart_mat,
    r_index = r_index
) for i in range(3)]

In [11]:
del r_index, cart_mat

In [12]:
import importlib
importlib.reload(fstar_net)

<module 'fstar_net' from '/grad/bzhang388/pisp/physics_aware_diffusion/fstar_net.py'>

In [13]:
cond_denoiser_model = fstar_net.PreconditionedDenoiser(
    fstars=fstarlist,
    out_channels=1,
    squeeze_ratio=8,
    cond_embed_iter=10, 
    noise_embed_dim=96, 
    num_conv=8,
    num_feature=96, # multiples of 32
)

In [14]:
diffusion_scheme = dfn_lib.Diffusion.create_variance_preserving(
    sigma=dfn_lib.tangent_noise_schedule(),
    data_std=1, # we always use normalized data
)

cond_model = dfn.DenoisingModel(
    input_shape=(neta,neta,1),
    cond_shape={"channel:scatter0": (nx**2,2),
                "channel:scatter1": (nx**2,2),
                "channel:scatter2": (nx**2,2)},
    denoiser=cond_denoiser_model,
    noise_sampling=dfn_lib.time_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=dfn_lib.edm_weighting(data_std=1),
)

In [15]:
epochs = 120
num_train_steps = 21000 * epochs // 16  #@param
cond_workdir = os.path.abspath('') + "/tmp/diffusion_uncompressed_fastmri_60"
initial_lr = 1e-5 #@param
peak_lr = 1e-3 #@pawram
warmup_steps = num_train_steps // 20  #@param
end_lr = 1e-8 #@param
ema_decay = 0.999  #@param
ckpt_interval = 2000 #@param
max_ckpt_to_keep = 3 #@param

In [16]:
rng = jax.random.PRNGKey(666)
params = cond_model.initialize(rng)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print('Number of trainable parameters:', param_count)

Number of trainable parameters: 462415


In [17]:
def count_params(params, parent_name=''):
    """ Recursively count the number of parameters in the JAX model. """
    total_params = 0
    for key, value in params.items():
        layer_name = f"{parent_name}/{key}" if parent_name else key
        if isinstance(value, dict):
            # Recurse into nested dictionary
            layer_params = count_params(value, layer_name)
            total_params += layer_params
        else:
            # Assume value is a parameter array
            layer_params = value.size
            total_params += layer_params
            print(f"Layer: {layer_name}, Parameters: {layer_params}")
    return total_params

total_parameters = count_params(params)
print(f"Total parameters in the model: {total_parameters}")


Layer: params/fstars_0/pre1, Parameters: 60
Layer: params/fstars_0/pre2, Parameters: 60
Layer: params/fstars_0/pre3, Parameters: 60
Layer: params/fstars_0/pre4, Parameters: 60
Layer: params/fstars_0/post1, Parameters: 60
Layer: params/fstars_0/post2, Parameters: 60
Layer: params/fstars_0/post3, Parameters: 60
Layer: params/fstars_0/post4, Parameters: 60
Layer: params/fstars_0/cos_kernel1, Parameters: 3600
Layer: params/fstars_0/sin_kernel1, Parameters: 3600
Layer: params/fstars_0/cos_kernel2, Parameters: 3600
Layer: params/fstars_0/sin_kernel2, Parameters: 3600
Layer: params/fstars_0/cos_kernel3, Parameters: 3600
Layer: params/fstars_0/sin_kernel3, Parameters: 3600
Layer: params/fstars_0/cos_kernel4, Parameters: 3600
Layer: params/fstars_0/sin_kernel4, Parameters: 3600
Layer: params/fstars_1/pre1, Parameters: 60
Layer: params/fstars_1/pre2, Parameters: 60
Layer: params/fstars_1/pre3, Parameters: 60
Layer: params/fstars_1/pre4, Parameters: 60
Layer: params/fstars_1/post1, Parameters: 60

# Training

In [18]:
cond_trainer = dfn.DenoisingTrainer(
    model=cond_model,
    rng=jax.random.PRNGKey(666),
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    ema_decay=ema_decay,
)

In [19]:
templates.run_train(
    train_dataloader=dataset,
    trainer=cond_trainer,
    workdir=cond_workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        cond_workdir, asynchronous=False
    ),
    metric_aggregation_steps = 100,
    callbacks=(
        templates.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        templates.TrainStateCheckpoint(
            base_dir=cond_workdir,
            options=ocp.CheckpointManagerOptions( 
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)



  0%|          | 0/157500 [00:00<?, ?step/s]

# Inference

In [20]:
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{cond_workdir}/checkpoints", step=None
)

# Construct the inference function
cond_denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, use_ema=True, denoiser=cond_denoiser_model
)



In [21]:
cond_sampler = dfn_lib.SdeSampler(
    input_shape=(neta,neta,1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.exponential_noise_decay(diffusion_scheme, num_steps=256, end_sigma=1e-3,),
    scheme=diffusion_scheme,
    denoise_fn=cond_denoise_fn,
    guidance_transforms=(),
    apply_denoise_at_end=True,
    return_full_paths=False,
)

We again JIT the generate function for the sake of faster repeated sampling calls. Here we employ `functools.partial` to specify `num_samples=5`, making it easier to vectorize across the batch dimension with `jax.vmap`.

In [22]:
num_samples_per_cond = 1

generate = jax.jit(
    functools.partial(cond_sampler.generate, num_samples_per_cond)
)

In [23]:
NTEST = 500 # Change it!

Loading a test batch of conditions with 4 elements:

In [24]:
name = '../data/fastmri60'

# Loading and preprocessing perturbation data (eta)
with h5py.File(f'{name}/eta60.h5', 'r') as f:
    # Read eta data, apply Gaussian blur, and reshape
    eta_re = f[list(f.keys())[0]][-NTEST:, :].reshape(-1, neta, neta)
    blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
    eta_re = np.stack([blur_fn(img.T) for img in eta_re]).astype('float32')

# Loading and preprocessing scatter data (Lambda)
with h5py.File(f'{name}/scatter60.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())

    # Process real part of scatter data
    tmp1 = f[keys[3]][-NTEST:, :]
    tmp2 = f[keys[4]][-NTEST:, :]
    tmp3 = f[keys[5]][-NTEST:, :]
    scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

    # Process imaginary part of scatter data
    tmp1 = f[keys[0]][-NTEST:, :]
    tmp2 = f[keys[1]][-NTEST:, :]
    tmp3 = f[keys[2]][-NTEST:, :]
    scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
    # Combine real and imaginary parts
    scatter = np.stack((scatter_re, scatter_im), axis=-2).astype('float32')
    
scatter[:,:,:,0] -= mean0
scatter[:,:,:,0] /= std0
scatter[:,:,:,1] -= mean1
scatter[:,:,:,1] /= std1
scatter[:,:,:,2] -= mean2
scatter[:,:,:,2] /= std2

# Clean up temporary variables to free memory
del scatter_re, scatter_im, tmp1, tmp2, tmp3

In [25]:
eta_test = eta_re.reshape(-1, neta, neta, 1)
scatter_test = scatter.reshape(-1, nx**2, 2, 3) 
#scatter_test = np.swapaxes(scatter_test,1,2).reshape(-1, 6400, 2, 3)
#c = 0.4
#scatter_test += np.random.normal(0,c,size=scatter_test.shape)

In [26]:
batch_size_test = 20
dict_data_test = {}
dict_data_test["cond"] = {"channel:scatter0": scatter_test[:,:,:,0],
                          "channel:scatter1": scatter_test[:,:,:,1],
                          "channel:scatter2": scatter_test[:,:,:,2]}

dataset_test = tf.data.Dataset.from_tensor_slices(dict_data_test)
dataset_test = dataset_test.batch(batch_size_test)
dataset_test = dataset_test.prefetch(tf.data.AUTOTUNE)
dataset_test = dataset_test.as_numpy_iterator()

In [27]:
eta_pred = np.zeros((NTEST, num_samples_per_cond, neta, neta, 1))

b = 0
for batch in dataset_test:
    print(b)
    cond_samples = jax.device_get(jax.vmap(generate, in_axes=(0, 0, None))(
        jax.random.split(jax.random.PRNGKey(68), batch_size_test),
        batch["cond"],
        None,  # Guidance inputs = None since no guidance transforms involved
    ))
    eta_pred[b*batch_size_test:(b+1)*batch_size_test,:,:,:,:] = cond_samples*std_eta+mean_eta[:, :, jnp.newaxis]
    b += 1


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


2024-07-02 07:31:22.362990: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [28]:
errors = []
for i in range(NTEST):
    for j in range(num_samples_per_cond):
        errors.append(np.linalg.norm(eta_test[i,:,:,0]-eta_pred[i,0,:,:,0])/np.linalg.norm(eta_test[i,:,:,0]))
        
print('Mean of validation relative l2 error:', np.mean(errors))
print('Median of validation relative l2 error:', np.median(errors))
print('Min of validation relative l2 error:', np.min(errors))
print('Max of validation relative l2 error:', np.max(errors))
print('Standard deviation of validation relative l2 errors:', np.std(errors))


Mean of validation relative l2 error: 0.05363113087343309
Median of validation relative l2 error: 0.054492172578428255
Min of validation relative l2 error: 0.02654575802626506
Max of validation relative l2 error: 0.10394141285320327
Standard deviation of validation relative l2 errors: 0.013645110504661438


In [68]:
#with h5py.File("results_diffusion_uncompressed_squares_final.h5", "w") as f:
#    f.create_dataset('eta', data=eta_test)
#    f.create_dataset('eta_pred', data=eta_pred)


In [69]:
#with h5py.File("results_diffusion_uncompressed_squares.h5", "w") as f:
#    f.create_dataset('eta', data=eta_test)
#    f.create_dataset('eta_pred', data=eta_pred)
