# Installation

In [None]:
# Install the back_projection_diffusion package (provides src.utils, src.fstars, src.fstar_cnn) from its main branch.
!pip install git+https://github.com/borongzhang/back_projection_diffusion.git@main

In [None]:
# Install/upgrade JAX with CUDA 12 support from the JAX CUDA release index.
# NOTE(review): newer JAX releases replace the `cuda12_pip` extra with `cuda12`
# (no `-f` index needed); confirm which JAX version swirl_dynamics expects.
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 

# Imports

In [None]:
# Project modules: data utilities, NN approximations of the back-scattering
# operator (fstars), and the conditional CNN denoiser (fstar_cnn).
from back_projection_diffusion.src import utils, fstars, fstar_cnn

import functools
import os
from clu import metric_writers
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

from swirl_dynamics import templates
from swirl_dynamics.lib import diffusion as dfn_lib
from swirl_dynamics.lib import solvers as solver_lib
from swirl_dynamics.projects import probabilistic_diffusion as dfn


In [None]:
# Hide every GPU from TensorFlow so it does not claim GPU memory; the GPU is
# left for JAX, which does the training and sampling in this notebook.
tf.config.set_visible_devices(devices=[], device_type='GPU')


# Dataset

In [None]:
# Parameters for the computational task.
L = 4  # number of levels (even number)
s = 5  # leaf size
r = 3  # rank

# Both discretizations have (2**L) * s points per side:
#   neta - discretization of Omega (n_eta x n_eta);
#   nx   - number of sources/detectors (n_sc). It is also used for the domain
#          of alpha in polar coordinates (n_theta x n_rho). Keeping
#          n_sc = n_theta = n_rho simplifies the computation.
neta = nx = (2**L) * s

# Standard deviation of the Gaussian blur applied when loading eta.
blur_sigma = 0.5

# Dataset sizes and batch sizes for training and testing.
ntrain, batch_size_train = 21000, 16
ntest, batch_size_test = 500, 10


In [None]:
# Load training and test data from <cwd>/../../data.
data_root = os.path.join(os.path.abspath('../..'), 'data')

training_eta_path = os.path.join(data_root, '10hsquares_trainingdata', 'eta.h5')
training_scatter_path = os.path.join(data_root, '10hsquares_trainingdata', 'scatter.h5')
# Training eta is normalized; mean_eta/std_eta are kept to un-normalize the
# model's predictions at inference time.
eta_train, mean_eta, std_eta = utils.load_eta_data(
    training_eta_path, ntrain, blur_sigma=blur_sigma, normalize=True)
# The normalization constants returned for the training scatter data are
# reused for the test scatter data below.
scatter_train, norm_constants = utils.load_scatter_data(
    training_scatter_path, ntrain, scatter_norm_constants=None)

test_eta_path = os.path.join(data_root, '10hsquares_testdata', 'eta.h5')
test_scatter_path = os.path.join(data_root, '10hsquares_testdata', 'scatter_order_8.h5')
# Test eta stays un-normalized: it is compared directly against the
# un-normalized predictions.
# NOTE(review): unlike the training calls, these results are bound to a single
# name; presumably the loaders return a bare array when normalize=False /
# scatter_norm_constants is given. Confirm against utils.
eta_test = utils.load_eta_data(test_eta_path, ntest, blur_sigma=blur_sigma, normalize=False)
scatter_test = utils.load_scatter_data(test_scatter_path, ntest, scatter_norm_constants=norm_constants)


In [None]:
# Reshape into the layout the model expects:
#   eta     -> (N, neta, neta, 1)
#   scatter -> (N, nx * nx, 2, 3)
# NOTE(review): assumes nx * nx entries per frequency, 2 components and
# 3 frequencies (matching the three "channel:scatter*" conditions).
# Confirm against utils.load_scatter_data.
eta_train = eta_train.reshape(-1, neta, neta, 1)
scatter_train = scatter_train.reshape(-1, nx * nx, 2, 3)
dataset = utils.create_dataset(eta_train, scatter_train, batch_size=batch_size_train, repeat=True)

eta_test = eta_test.reshape(-1, neta, neta, 1)
scatter_test = scatter_test.reshape(-1, nx * nx, 2, 3)

# Standard deviation of the additive Gaussian noise added to the test scatter
# data. This is an absolute std on the normalized data, not a percentage;
# 0.0 means noise-free.
c = 0.0
if c != 0:
    # Seeded generator so noisy evaluations are reproducible. A negative c
    # raises ValueError from Generator.normal.
    noise_rng = np.random.default_rng(seed=0)
    scatter_test += noise_rng.normal(0.0, c, size=scatter_test.shape)
dataset_test = utils.create_dataset(eta_test, scatter_test, batch_size=batch_size_test, repeat=False)



# Architecture

In [None]:
# SwitchNet F* hyperparameters. All three lengths equal nx, and each Nb* is the
# corresponding length divided (integer division) by its Nw*.
# NOTE(review): presumably Nw* are window sizes and Nb* block counts, so the
# lengths should be divisible by Nw*. Confirm against fstars.SwitchNetFstar.
L1 = L2x = L2y = nx
Nw1, Nw2x, Nw2y = 20, 10, 10
Nb1 = L1 // Nw1
Nb2x, Nb2y = L2x // Nw2x, L2y // Nw2y
r = 3  # rank (same value as in the task parameters)


In [None]:
# One NN approximation of the back-scattering operator per frequency.
# Set n_freq to the number of frequencies in the data.
n_freq = 3

# Every frequency uses the same SwitchNet configuration; each entry is a
# separate instance.
fstar_kwargs = dict(
    L1=L1, L2x=L2x, L2y=L2y,
    Nw1=Nw1, Nb1=Nb1,
    Nw2x=Nw2x, Nw2y=Nw2y, Nb2x=Nb2x, Nb2y=Nb2y,
    r=r,
)
fstarlist = [fstars.SwitchNetFstar(**fstar_kwargs) for _ in range(n_freq)]


In [None]:
# Conditional denoiser built from the per-frequency F* networks above.
cond_denoiser_model = fstar_cnn.PreconditionedDenoiser(
    fstars=fstarlist,
    num_conv=8,
    num_feature=96,  # multiples of 32
    squeeze_ratio=8,
    noise_embed_dim=96,
    cond_embed_iter=10,
    out_channels=1,  # single channel, matching the model's (…, 1) input shape
)


In [None]:
# Variance-preserving diffusion with a tangent noise schedule.
diffusion_scheme = dfn_lib.Diffusion.create_variance_preserving(
    sigma=dfn_lib.tangent_noise_schedule(),
    data_std=1,  # we always use normalized data
)

# One conditioning entry per frequency, keyed "channel:scatter{i}". Deriving it
# from n_freq keeps it consistent with fstarlist. For n_freq = 3 and nx = 80
# this equals the previous hard-coded {"channel:scatter0": (6400, 2), ...}.
# NOTE(review): assumes nx * nx scatter entries per frequency. Confirm against
# the data layout.
cond_shape = {f"channel:scatter{i}": (nx * nx, 2) for i in range(n_freq)}

cond_model = dfn.DenoisingModel(
    input_shape=(neta, neta, 1),
    cond_shape=cond_shape,
    denoiser=cond_denoiser_model,
    noise_sampling=dfn_lib.time_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=dfn_lib.edm_weighting(data_std=1),
)


In [None]:
# Initialize the model parameters with a fixed seed and report the model size.
rng = jax.random.PRNGKey(888)
params = cond_model.initialize(rng)

total_parameters = utils.count_params(params)
print(f"Total parameters in the model: {total_parameters}")


# Training

In [None]:
# Training hyperparameters. The #@param markers expose values as Colab form fields.
epochs = 100
# One epoch is one pass over the ntrain samples in batches of batch_size_train
# (21000 * 100 // 16 with the current settings).
num_train_steps = ntrain * epochs // batch_size_train  #@param
cond_workdir = os.path.join(os.path.abspath('..'), "tmp", "switchnet_cnn_10hsquares")
initial_lr = 1e-5 #@param
peak_lr = 1e-3 #@param
# Warm up over the first 5% of training steps.
warmup_steps = num_train_steps // 20  #@param
end_lr = 1e-8 #@param
ema_decay = 0.999  #@param
ckpt_interval = 2000 #@param
max_ckpt_to_keep = 3 #@param


In [None]:
# Learning rate: linear warmup from initial_lr to peak_lr over warmup_steps,
# then cosine decay to end_lr. decay_steps = num_train_steps is the total
# schedule length, including warmup.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=initial_lr,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=num_train_steps,
    end_value=end_lr,
)

# Adam trainer for the denoising model, keeping an EMA of the parameters.
cond_trainer = dfn.DenoisingTrainer(
    model=cond_model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(learning_rate=lr_schedule),
    ema_decay=ema_decay,
)

In [None]:
# Metrics are written synchronously to cond_workdir.
metric_writer = metric_writers.create_default_writer(cond_workdir, asynchronous=False)

# Progress bar tracking train_loss, plus checkpoints every ckpt_interval steps
# with at most max_ckpt_to_keep retained.
train_callbacks = (
    templates.TqdmProgressBar(
        total_train_steps=num_train_steps,
        train_monitors=("train_loss",),
    ),
    templates.TrainStateCheckpoint(
        base_dir=cond_workdir,
        options=ocp.CheckpointManagerOptions(
            save_interval_steps=ckpt_interval,
            max_to_keep=max_ckpt_to_keep,
        ),
    ),
)

templates.run_train(
    train_dataloader=dataset,
    trainer=cond_trainer,
    workdir=cond_workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writer,
    metric_aggregation_steps=100,
    callbacks=train_callbacks,
)

# Inference

In [None]:
# Restore the trained state from the checkpoints written during training
# (step=None; presumably selects the latest checkpoint).
ckpt_dir = f"{cond_workdir}/checkpoints"
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(ckpt_dir, step=None)

# Construct the inference function from the EMA parameters of the restored state.
cond_denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state,
    denoiser=cond_denoiser_model,
    use_ema=True,
)

In [None]:
# SDE sampler: Euler-Maruyama over 256 exponentially decaying noise levels
# ending at sigma = 1e-3. The denoiser is applied once more at the end. There
# are no guidance transforms, and only final samples (not full paths) are
# returned.
cond_sampler = dfn_lib.SdeSampler(
    input_shape=(neta, neta, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.exponential_noise_decay(diffusion_scheme, num_steps=256, end_sigma=1e-3),
    scheme=diffusion_scheme,
    denoise_fn=cond_denoise_fn,
    guidance_transforms=(),
    apply_denoise_at_end=True,
    return_full_paths=False,
)


In [None]:
# Number of samples drawn for each test condition.
num_samples_per_cond = 10

# Bind the sample count, so the jitted function is called as
# generate(rng, cond, guidance_inputs).
generate_fixed = functools.partial(cond_sampler.generate, num_samples_per_cond)
generate = jax.jit(generate_fixed)


In [None]:
# Predicted samples: (test case, sample, neta, neta, 1).
eta_pred = np.zeros((ntest, num_samples_per_cond, neta, neta, 1))

num_batches = -(-ntest // batch_size_test)  # ceiling division
base_key = jax.random.PRNGKey(68)
offset = 0
for b, batch in enumerate(dataset_test):
    print(f"\rProcessing batch {b + 1} / {num_batches}", end='', flush=True)
    # Actual number of conditions in this batch; the last batch may be smaller
    # than batch_size_test.
    n_batch = jax.tree_util.tree_leaves(batch["cond"])[0].shape[0]
    # Fold the batch index into the base key so every batch gets fresh noise.
    # Splitting the same key for every batch would reuse identical draws.
    keys = jax.random.split(jax.random.fold_in(base_key, b), n_batch)
    cond_samples = jax.device_get(jax.vmap(generate, in_axes=(0, 0, None))(
        keys,
        batch["cond"],
        None,  # Guidance inputs = None since no guidance transforms involved
    ))
    # Undo the training normalization of eta.
    eta_pred[offset:offset + n_batch] = cond_samples * std_eta + mean_eta[:, :, np.newaxis]
    offset += n_batch
print()  # end the carriage-return progress line


In [None]:
# Relative L2 error of the first predicted sample for each test case, measured
# against the ground-truth eta.
errors = [
    np.linalg.norm(eta_test[i, :, :, 0] - eta_pred[i, 0, :, :, 0])
    / np.linalg.norm(eta_test[i, :, :, 0])
    for i in range(ntest)
]

for label, stat in (
    ('Mean of validation relative l2 error:', np.mean),
    ('Median of validation relative l2 error:', np.median),
    ('Min of validation relative l2 error:', np.min),
    ('Max of validation relative l2 error:', np.max),
    ('Standard deviation of validation relative l2 errors:', np.std),
):
    print(label, stat(errors))


In [None]:
#with h5py.File("results_switchnet_cnn_10hsquares.h5", "w") as f:
#    f.create_dataset('eta', data=eta_test)
#    f.create_dataset('eta_pred', data=eta_pred)