# Imports

In [None]:
# Change the current working directory to the sibling "src" folder (../src
# relative to the notebook's directory). Presumably this is so the project-local
# modules imported below (utils, fstars, fstar_cnn) can be imported and the
# relative "../data/..." paths used later resolve from inside "src".
# NOTE: %cd is IPython magic (not plain Python) and changes the kernel's working
# directory for the rest of the session. Re-running it applies "../src" again
# from the new location, so check the result (e.g. with %pwd) before re-running.
%cd ../src


In [None]:
import utils
import fstars 
import fstar_cnn

import functools
import os
from clu import metric_writers
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

from swirl_dynamics import templates
from swirl_dynamics.lib import diffusion as dfn_lib
from swirl_dynamics.lib import solvers as solver_lib
from swirl_dynamics.projects import probabilistic_diffusion as dfn


In [None]:
# Hide every GPU from TensorFlow. TensorFlow is only used here for the tf.data
# input pipeline, so it should not reserve GPU memory that JAX needs.
tf.config.set_visible_devices([], 'GPU')


# Dataset

In [None]:
# Parameters for the computational task.

# Hierarchical discretization parameters.
L = 4  # number of levels (even number)
s = 5  # leaf size
r = 3  # rank

# Omega is discretized on an n_eta x n_eta grid.
neta = s * 2**L

# Number of sources/detectors (n_sc). The domain of alpha is discretized in polar
# coordinates (n_theta x n_rho); for simplicity n_sc = n_theta = n_rho, and this
# value is the same as neta.
nx = s * 2**L

# Standard deviation (in pixels) of the Gaussian blur applied to eta.
blur_sigma = 0.5

# Number of training samples read from disk.
NTRAIN = 21_000



In [None]:
name = '../data/10hsquares_trainingdata'


def blur_fn(x):
    """Apply the notebook-wide Gaussian blur (sigma=blur_sigma) to one image."""
    return gaussian_filter(x, sigma=blur_sigma)


# Load the perturbations (eta). Each row is reshaped to a neta x neta image,
# transposed and blurred. Iterating over the rows actually loaded (instead of
# range(NTRAIN)) avoids an IndexError when the file holds fewer than NTRAIN rows.
with h5py.File(f'{name}/eta.h5', 'r') as f:
    eta_re = f[list(f.keys())[0]][:NTRAIN, :].reshape(-1, neta, neta)
eta_re = np.stack([blur_fn(img.T) for img in eta_re]).astype('float32')

# Normalize eta: subtract the pixel-wise mean image, then divide by a single
# global standard deviation. mean_eta / std_eta are reused later to map the
# generated samples back to physical values.
mean_eta = np.mean(eta_re, axis=0)
eta_re -= mean_eta
std_eta = np.std(eta_re)
eta_re /= std_eta

# Load the scattering data (Lambda). After natural sorting, keys[3:6] hold the
# real parts and keys[0:3] the imaginary parts of the three channels.
with h5py.File(f'{name}/scatter.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())
    scatter_re = np.stack([f[k][:NTRAIN, :] for k in keys[3:6]], axis=-1)
    scatter_im = np.stack([f[k][:NTRAIN, :] for k in keys[0:3]], axis=-1)

# Layout (n_samples, n_points, 2, 3): axis -2 is (real, imag), axis -1 is the
# channel.
scatter = np.stack((scatter_re, scatter_im), axis=-2).astype('float32')
del scatter_re, scatter_im  # free the un-stacked copies

# Standardize each channel (real and imaginary parts together). The statistics
# are kept as mean0/std0 ... mean2/std2 so the test data can be normalized the
# same way. Channels are independent slices, so normalizing one channel does not
# affect the statistics of another.
channel_stats = []
for c in range(scatter.shape[-1]):
    mu, sd = np.mean(scatter[..., c]), np.std(scatter[..., c])
    scatter[..., c] -= mu
    scatter[..., c] /= sd
    channel_stats.append((mu, sd))
(mean0, std0), (mean1, std1), (mean2, std2) = channel_stats
del channel_stats, c, mu, sd


In [None]:
# Add a trailing channel axis to eta: (n_samples, neta, neta, 1).
eta_train = eta_re.reshape(-1, neta, neta, 1)

# View each sample's scattering data as an nx x nx grid (presumably sources x
# detectors; confirm), swap the two grid axes, and flatten back to
# (n_samples, nx*nx, 2, n_channels). Sizes come from nx instead of hard-coded
# 80/6400, so changing L or s in the parameters cell stays consistent.
scatter_train = np.swapaxes(
    scatter.reshape(-1, nx, nx, *scatter.shape[-2:]), 1, 2
).reshape(-1, nx * nx, *scatter.shape[-2:])


In [None]:
batch_size = 16
# Shuffle buffer size. A full-size buffer would hold another copy of the whole
# dataset (several GB at float32), so a bounded buffer is used.
shuffle_buffer = 2048
shuffle_seed = 888

# Training examples: the normalized eta images are the targets "x". The
# scattering channels are passed as separate conditioning inputs.
dict_data = {
    "x": eta_train,
    "cond": {
        f"channel:scatter{c}": scatter_train[:, :, :, c]
        for c in range(scatter_train.shape[-1])
    },
}

dataset = (
    tf.data.Dataset.from_tensor_slices(dict_data)
    # Without shuffling, every epoch would see identical batches in identical
    # order. Shuffling before repeat() reshuffles at each epoch boundary.
    .shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=True)
    .repeat()
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
    .as_numpy_iterator()
)


# Architecture

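The architecture is similar to the unconditional case. The denoiser also receives arguments that specify how to map the conditioning signal (the scattering data) into a form compatible with the noisy sample for channel-wise concatenation. Here these are the analytical back-scattering operators (`fstars`) computed in the next cell.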

In [None]:
# Analytical back-scattering (adjoint) operators, presumably one per scattering
# channel (three, matching channel:scatter0..2).
# NOTE(review): the argument of utils.compute_F_adj looks like a frequency or
# wavenumber; confirm. Also note that F_adj_3 is built from 2.5, not 3.
F_adj_3, F_adj_5, F_adj_10 = (utils.compute_F_adj(k) for k in (2.5, 5, 10))
fstarlist = [fstars.analytical_fstar(op) for op in (F_adj_3, F_adj_5, F_adj_10)]


In [None]:
# Conditional denoiser built on the analytical back-scattering operators.
cond_denoiser_model = fstar_cnn.PreconditionedDenoiser(
    # One operator per conditioning channel (see fstarlist above).
    fstars=fstarlist,
    # Embedding settings.
    cond_embed_iter=10,
    noise_embed_dim=96,
    # Convolutional trunk settings.
    num_conv=8,
    num_feature=96,  # multiples of 32
    squeeze_ratio=8,
    # Single-channel output, matching the (…, 1) eta images.
    out_channels=1,
)

In [None]:
# Variance-preserving diffusion with a tangent noise schedule. data_std=1 because
# eta_train was standardized (mean image removed, divided by std_eta).
diffusion_scheme = dfn_lib.Diffusion.create_variance_preserving(
    sigma=dfn_lib.tangent_noise_schedule(),
    data_std=1,  # we always use normalized data
)

# Shapes are derived from neta / nx instead of hard-coded 80 / 6400, so they
# track the parameters cell. There is one conditioning entry per scattering
# channel, each of shape (nx*nx, 2), matching scatter_train[:, :, :, c].
cond_model = dfn.DenoisingModel(
    input_shape=(neta, neta, 1),
    cond_shape={f"channel:scatter{c}": (nx * nx, 2) for c in range(3)},
    denoiser=cond_denoiser_model,
    noise_sampling=dfn_lib.time_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=dfn_lib.edm_weighting(data_std=1),
)


In [None]:
epochs = 100
# Total optimizer steps = epochs * (samples per epoch / batch size). This is
# derived from NTRAIN and batch_size instead of repeating the literals 21000 and
# 16, so changing either value keeps the schedule consistent.
num_train_steps = NTRAIN * epochs // batch_size
cond_workdir = os.path.join(os.path.abspath(''), "tmp", "analytical_cnn_10hsquares")
initial_lr = 1e-5  #@param
peak_lr = 5e-4  #@param
warmup_steps = num_train_steps // 20  # first 5% of the steps are warm-up
end_lr = 1e-8  #@param
ema_decay = 0.999  #@param
ckpt_interval = 2000  #@param
max_ckpt_to_keep = 3  #@param


In [None]:
def count_params(params, parent_name=''):
    """Recursively count the parameters in a (possibly nested) parameter tree.

    Prints one line per leaf with its slash-separated path and its size.

    Args:
      params: Mapping from names to either nested mappings or array-like leaves.
        Any ``collections.abc.Mapping`` is recursed into, including a plain
        ``dict`` and immutable mappings such as Flax ``FrozenDict``. The
        previous ``isinstance(value, dict)`` check crashed on those.
      parent_name: Path prefix of ``params`` used in the printed layer names.

    Returns:
      Total number of scalar entries over all leaves.
    """
    import collections.abc  # local import keeps this cell self-contained

    total_params = 0
    for key, value in params.items():
        layer_name = f"{parent_name}/{key}" if parent_name else key
        if isinstance(value, collections.abc.Mapping):
            # Recurse into the nested mapping.
            total_params += count_params(value, layer_name)
        else:
            # np.size also handles scalar leaves that have no .size attribute.
            layer_params = int(np.size(value))
            total_params += layer_params
            print(f"Layer: {layer_name}, Parameters: {layer_params}")
    return total_params
    
# Initialize the model parameters with a fixed seed and report the model size.
rng = jax.random.PRNGKey(888)
params = cond_model.initialize(rng)
total_parameters = count_params(params)
print("Total parameters in the model:", total_parameters)


# Training

In [None]:
# Learning rate: warm up from initial_lr to peak_lr over warmup_steps, then
# cosine-decay to end_lr. In optax, decay_steps is the total schedule length,
# including the warm-up.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=initial_lr,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=num_train_steps,
    end_value=end_lr,
)

cond_trainer = dfn.DenoisingTrainer(
    model=cond_model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(learning_rate=lr_schedule),
    # EMA of the weights; the EMA weights are used at inference (use_ema=True).
    ema_decay=ema_decay,
)


In [None]:
# Synchronous metric writer that logs into the working directory.
train_writer = metric_writers.create_default_writer(cond_workdir, asynchronous=False)

# Callbacks: a tqdm progress bar on the training loss, plus Orbax checkpointing
# every ckpt_interval steps that keeps at most max_ckpt_to_keep checkpoints.
train_callbacks = (
    templates.TqdmProgressBar(
        total_train_steps=num_train_steps,
        train_monitors=("train_loss",),
    ),
    templates.TrainStateCheckpoint(
        base_dir=cond_workdir,
        options=ocp.CheckpointManagerOptions(
            save_interval_steps=ckpt_interval,
            max_to_keep=max_ckpt_to_keep,
        ),
    ),
)

templates.run_train(
    train_dataloader=dataset,
    trainer=cond_trainer,
    workdir=cond_workdir,
    total_train_steps=num_train_steps,
    metric_writer=train_writer,
    metric_aggregation_steps=100,
    callbacks=train_callbacks,
)


# Inference

In [None]:
# Restore the trained state from the checkpoints written during training
# (step=None presumably selects the latest checkpoint; confirm).
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{cond_workdir}/checkpoints", step=None
)

# Build the inference-time denoising function from the EMA weights.
cond_denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, denoiser=cond_denoiser_model, use_ema=True
)


In [None]:
# SDE sampler using Euler-Maruyama over the noise levels from
# exponential_noise_decay (256 steps, ending at sigma = 1e-3). No guidance
# transforms are used, and only the final samples are returned.
cond_sampler = dfn_lib.SdeSampler(
    # Derived from neta (not hard-coded 80) so it matches the training targets.
    input_shape=(neta, neta, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.exponential_noise_decay(
        diffusion_scheme, num_steps=256, end_sigma=1e-3,
    ),
    scheme=diffusion_scheme,
    denoise_fn=cond_denoise_fn,
    guidance_transforms=(),
    apply_denoise_at_end=True,
    return_full_paths=False,
)


We again JIT-compile the `generate` function to make repeated sampling calls faster. `functools.partial` fixes the number of samples drawn per condition (`num_samples_per_cond`, set to 50 below), which makes it easy to vectorize across the batch dimension with `jax.vmap`.

In [None]:
# Number of samples drawn for each test condition.
num_samples_per_cond = 50

# JIT-compile `generate` with the sample count bound as its first positional
# argument. The remaining arguments (rng key, conditioning dict, guidance
# inputs) are passed per call and vmapped over the batch in the sampling loop.
generate = jax.jit(functools.partial(cond_sampler.generate, num_samples_per_cond))


In [None]:
NTEST: int = 500  # Number of test points to evaluate.


Load the first `NTEST` test samples and normalize their scattering data with the training statistics:

In [None]:
name = '../data/10hsquares_testdata'


def blur_fn(x):
    """Apply the same Gaussian blur that was used for the training eta images."""
    return gaussian_filter(x, sigma=blur_sigma)


# Test perturbations (eta): reshaped, transposed and blurred like the training
# data, but left un-normalized. Generated samples are mapped back with
# mean_eta/std_eta before being compared to these.
with h5py.File(f'{name}/eta.h5', 'r') as f:
    eta_re = f[list(f.keys())[0]][:NTEST, :].reshape(-1, neta, neta)
eta_re = np.stack([blur_fn(img.T) for img in eta_re]).astype('float32')

# Test scattering data. After natural sorting, keys[3:6] are the real parts and
# keys[0:3] the imaginary parts, stacked as (…, real/imag, channel).
with h5py.File(f'{name}/scatter_order_8.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())
    scatter = np.stack(
        [
            np.stack([f[k][:NTEST, :] for k in keys[3:6]], axis=-1),  # real
            np.stack([f[k][:NTEST, :] for k in keys[0:3]], axis=-1),  # imaginary
        ],
        axis=-2,
    ).astype('float32')

# Standardize each channel with the *training* statistics so the test inputs
# are on the same scale the model was trained on.
for c, (mu, sd) in enumerate([(mean0, std0), (mean1, std1), (mean2, std2)]):
    scatter[:, :, :, c] -= mu
    scatter[:, :, :, c] /= sd
del c, mu, sd


In [None]:
# Add a trailing channel axis to eta: (n_samples, neta, neta, 1).
eta_test = eta_re.reshape(-1, neta, neta, 1)

# Same grid transpose as for the training data: view each sample as an nx x nx
# grid (presumably sources x detectors; confirm), swap the grid axes, and flatten
# back to (n_samples, nx*nx, 2, n_channels). Sizes come from nx, not 80/6400.
scatter_test = np.swapaxes(
    scatter.reshape(-1, nx, nx, *scatter.shape[-2:]), 1, 2
).reshape(-1, nx * nx, *scatter.shape[-2:])


In [None]:
batch_size_test = 10

# Conditioning inputs only (no targets), keyed the same way as in training.
dict_data_test = {
    "cond": {
        f"channel:scatter{c}": scatter_test[:, :, :, c] for c in range(3)
    }
}

# Finite (non-repeating), batched iterator over the test conditions.
dataset_test = (
    tf.data.Dataset.from_tensor_slices(dict_data_test)
    .batch(batch_size_test)
    .prefetch(tf.data.AUTOTUNE)
    .as_numpy_iterator()
)


In [None]:
eta_pred = np.zeros((NTEST, num_samples_per_cond, neta, neta, 1))

# vmap the jitted generator over the batch axis of the rng keys and the
# conditioning dict. The guidance inputs (None) are shared. The wrapper is built
# once, outside the loop.
batched_generate = jax.vmap(generate, in_axes=(0, 0, None))
base_key = jax.random.PRNGKey(888)

start = 0
for b, batch in enumerate(dataset_test):
    print(b)
    # Actual batch size: the last batch can be smaller than batch_size_test.
    n_batch = len(batch["cond"]["channel:scatter0"])
    # Fold the batch index into the base key so every test condition gets its
    # own noise. Splitting the same key for every batch would make conditions
    # i and i + batch_size_test share identical noise.
    batch_keys = jax.random.split(jax.random.fold_in(base_key, b), n_batch)
    cond_samples = jax.device_get(
        batched_generate(
            batch_keys,
            batch["cond"],
            None,  # Guidance inputs = None since no guidance transforms involved
        )
    )
    # Undo the training normalization of eta: x * std_eta + mean_eta.
    eta_pred[start:start + n_batch] = (
        cond_samples * std_eta + mean_eta[:, :, np.newaxis]
    )
    start += n_batch


In [None]:
# Relative L2 error ||eta_true - eta_pred|| / ||eta_true|| for every
# (test point, sample) pair. Each test point uses one vectorized norm over all of
# its samples (instead of one Python-level call per sample). Memory stays small
# because the full (NTEST, samples, neta, neta) difference is never allocated.
# The result is ordered test-point-major, sample-minor.
n_eval = min(NTEST, len(eta_test))
errors = np.concatenate([
    np.linalg.norm(eta_test[i, :, :, 0] - eta_pred[i, :, :, :, 0], axis=(-2, -1))
    / np.linalg.norm(eta_test[i, :, :, 0])
    for i in range(n_eval)
])

print('Mean of validation relative l2 error:', np.mean(errors))
print('Median of validation relative l2 error:', np.median(errors))
print('Min of validation relative l2 error:', np.min(errors))
print('Max of validation relative l2 error:', np.max(errors))
print('Standard deviation of validation relative l2 errors:', np.std(errors))


In [None]:
#with h5py.File("results_analytical_cnn_squares.h5", "w") as f:
#    f.create_dataset('eta', data=eta_test)
#    f.create_dataset('eta_pred', data=eta_pred)