# Imports

In [1]:
import functools
import os
# Limit the fraction of GPU memory that XLA preallocates. This is set here,
# ahead of the `import jax` further down, so that it is already in the
# environment when JAX is imported.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

from clu import metric_writers
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import sparse
import matplotlib.pyplot as plt
import optax
import orbax.checkpoint as ocp

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

In [2]:
from ISP_baseline.src import models, trainers, utils
from ISP_baseline.models import Uncompressed 

from swirl_dynamics import templates
from swirl_dynamics.lib import metrics
from pysteps.utils.spectral import rapsd

Pysteps configuration file found at: /share/data/willett-group/oortsang/miniconda/envs/jaxisp-v3/lib/python3.13/site-packages/pysteps/pystepsrc



In [3]:
# For use with our MFISNet-style dataset
import os
from ISP_baseline.src.data_io import (
    load_hdf5_to_dict,
    load_cart_multifreq_dataset,
    load_single_dir_slice,
    load_multi_dir_slice,
    get_multifreq_dset_dirs,
)
from ISP_baseline.src.datasets import (
    convert_mfisnet_data_dict,
    setup_tf_dataset,
)

In [4]:
# Hide every GPU from TensorFlow so that it does not use GPU memory.
tf.config.set_visible_devices(devices=[], device_type="GPU")

## Load the dataset

In [5]:
# Root of the repository checkout that holds the dataset.
# The default is the original cluster location; set the RLC_REPO environment
# variable to run the notebook against a checkout somewhere else.
_default_rlc_repo = os.path.join("/", "home-nfs", "oortsang", "rlc-repo")
rlc_repo = os.environ.get("RLC_REPO", _default_rlc_repo)

# All train/test measurement directories are looked up under this directory.
dataset_dir = os.path.join(rlc_repo, "dataset")

In [6]:
# --- Problem size -----------------------------------------------------------
L = 4  # number of levels (must be even)
s = 6  # leaf size (earlier experiments: 5; 12 for N_x = 192)
# A rank parameter (r = 3) was listed here originally but appears to be unused.

# Omega is discretized on an n_eta x n_eta grid.
neta = s * 2**L

# Number of sources/detectors (n_sc). The domain of alpha is discretized in
# polar coordinates on an n_theta x n_rho grid; for simplicity all three are
# set equal (n_sc = n_theta = n_rho), which facilitates computation.
nx = s * 2**L

# --- Data / training sizes --------------------------------------------------
# Standard deviation of the Gaussian blur.
blur_sigma = 0.5

# Number of samples per batch.
batch_size = 16

# Number of training datapoints. Kept tiny for debugging
# (other values used before: 21000 for the full set, 2000 for a reduced run).
NTRAIN = 100

In [7]:
# Wavenumber labels, kept as strings.
# The intended set is ["2.5", "5", "10"]; "2" stands in for "2.5" because the
# 2.5 data has not been prepared yet.
kbar_str_list = [
    "2",
    "5",
    "10",
]

# Number of wavenumbers (frequencies).
nk = len(kbar_str_list)

In [8]:
# get_dset_dirs = lambda dset, kbar_list: [
#     os.path.join(dataset_dir, f"{dset}_train_measurements_nu_{kbar}")
#     for kbar in kbar_list
# ]
# train_dirs = get_dset_dirs("train", kbar_str_list)
# Resolve the training directories for the requested wavenumbers
# (presumably one directory per entry of kbar_str_list -- see
# get_multifreq_dset_dirs for the exact naming).
train_dirs = get_multifreq_dset_dirs(
    "train", kbar_str_list, base_dir=dataset_dir, dir_fmt="{0}_measurements_nu_{1}"
)

# Load samples [0, NTRAIN) and report the shape of every loaded array.
train_mfisnet_dd = load_cart_multifreq_dataset(
    train_dirs, global_idx_start=0, global_idx_end=NTRAIN
)
train_shape_strs = [f"{key}{val.shape}" for key, val in train_mfisnet_dd.items()]
print(f"Loaded: {', '.join(train_shape_strs)}")

# Convert to a dict with "eta" and "scatter" entries (blur_sigma is forwarded
# to the converter).
train_wb_dd = convert_mfisnet_data_dict(train_mfisnet_dd, blur_sigma=blur_sigma)

# Downsample by keeping every other grid point, because the
# SparsePolarToCartesian step is very slow at full resolution.
# "eta" is strided on its last two axes; "scatter" on the two axes before its
# final axis.
train_eta = train_wb_dd["eta"][..., ::2, ::2]
train_scatter = train_wb_dd["scatter"][..., ::2, ::2, :]

Loaded: x_vals(192,), q_cart(100, 192, 192), sample_completion(100,), d_rs(100, 3, 192, 192)


In [9]:
# dd = load_multi_dir_slice(
#     train_dirs,
#     global_idx_start=0,
#     global_idx_end=10,
#     load_keys=["q_cart", "d_rs"],
#     sample_keys=["q_cart", "d_rs"],
#     freq_dep_keys=["d_rs"],
# )
# dd = load_single_dir_slice(
#     train_dirs[-1],
#     global_idx_start=0,
#     global_idx_end=10,
#     load_keys=["q_cart", "d_rs"],
#     ignore_keys=["d_rs"],
#     sample_keys=["q_cart", "d_rs"],
# )
# dd_shapes = [f"{key}{val.shape}" for (key, val) in dd.items()]
# print(f"dd shapes: {', '.join(dd_shapes)}")
# dd["q_cart"].shape

In [10]:
# load_hdf5_to_dict("/home-nfs/oortsang/rlc-repo/dataset/train_measurements_nu_2/measurements_0.h5").keys()

In [11]:
# Sanity check: display the shapes of the downsampled training arrays.
train_eta.shape, train_scatter.shape

((100, 96, 96), (100, 96, 96, 3))

In [12]:
# Wrap the training arrays for batched iteration. The helper returns a
# (dataset, dataloader) pair; the batch layout is defined in
# ISP_baseline.src.datasets.setup_tf_dataset (not visible in this notebook).
train_dataset, train_dloader = setup_tf_dataset(
    train_eta,
    train_scatter,
    batch_size=batch_size,
)

## Architecture

In [13]:
# CNN hyperparameters; all three are passed to
# Uncompressed.UncompressedModelFlexible when the core module is built.
N_cnn_layers = 3
N_cnn_channels = 6
kernel_size = 3

In [14]:
# utils.load_mats_from_fp(os.path.join("tmp", "cart_and_rot_mats", f"mats_neta{neta}_nx{nx}.npz"))

(BCOO(float32[2304, 2304], nse=67049),
 Array([[   0,    1,    2, ...,   45,   46,   47],
        [  48,   49,   50, ...,   93,   94,   95],
        [  96,   97,   98, ...,  141,  142,  143],
        ...,
        [2159, 2112, 2113, ..., 2156, 2157, 2158],
        [2207, 2160, 2161, ..., 2204, 2205, 2206],
        [2255, 2208, 2209, ..., 2252, 2253, 2254]], dtype=int32))

In [15]:
%%time
# Directory where the matrices for a given (neta, nx) are stored between runs.
mats_cache_dir = os.path.join("tmp", "cart_and_rot_mats")

# Load the polar-to-Cartesian matrix and rotation index for this grid size,
# creating (and saving) them if they are not available yet.
cart_mat, r_index = utils.load_or_create_mats(
    neta,
    nx,
    mats_dir=mats_cache_dir,
    mats_format="mats_neta{0}_nx{1}.npz",
    save_if_created=True,
)
# cart_mat = utils.SparsePolarToCartesian(neta, nx)
# r_index  = utils.rotationindex(nx)

CPU times: user 2.26 ms, sys: 5.01 ms, total: 7.27 ms
Wall time: 6.97 ms


In [16]:
# import scipy.sparse
# sp_cart_mat = scipy.sparse.coo_array(
#     (cart_mat.data, cart_mat.indices.T),
#     shape=cart_mat.shape,
# )

In [17]:
# from jax.experimental.sparse.bcoo import BCOO

# jsp_cart_mat = BCOO((cart_mat.data, cart_mat.indices), shape=cart_mat.shape)

In [18]:
# Settings that only the "Flexible" variant of the model accepts.
flexible_kwargs = dict(
    nk=nk,
    N_cnn_layers=N_cnn_layers,
    N_cnn_channels=N_cnn_channels,
    kernel_size=kernel_size,
)

# Core network: grid sizes and precomputed operators, plus the settings above.
core_module = Uncompressed.UncompressedModelFlexible(
    nx=nx,
    neta=neta,
    cart_mat=cart_mat,
    r_index=r_index,
    **flexible_kwargs,
)

# core_module = Uncompressed.UncompressedModel(
#     nx = nx,
#     neta = neta,
#     cart_mat = cart_mat,
#     r_index = r_index,
#     # # Doesn't support these
#     # nk=nk,
#     # N_cnn_layers=N_cnn_layers,
#     # N_cnn_channels=N_cnn_channels,
#     # kernel_size=kernel_size,
# )

In [19]:
# Shape of a single input sample (first entry of the training scatter data).
sample_input_shape = train_scatter[0].shape

# Wrap the core module in the deterministic model used by the trainer.
Model = models.DeterministicModel(
    input_shape=sample_input_shape,
    core_module=core_module,
)

In [20]:
# Initialize the model parameters from a fixed seed.
rng = jax.random.PRNGKey(888)
params = Model.initialize(rng)

# Total number of scalar entries across all leaves of the parameter pytree.
param_leaves = jax.tree_util.tree_leaves(params)
param_count = sum(leaf.size for leaf in param_leaves)
print('Number of trainable parameters:', param_count)

Number of trainable parameters: 70138


In [21]:
# Debug check: shape of one input sample (the value passed as `input_shape`).
train_scatter[0].shape

(96, 96, 3)

In [22]:
# Debug check: container type of the loaded polar-to-Cartesian matrix.
type(cart_mat)

jax.experimental.sparse.bcoo.BCOO

In [23]:
# Debug check: container type of the loaded rotation index.
type(r_index)

jaxlib.xla_extension.ArrayImpl

## Training

In [24]:
# Number of passes over the NTRAIN training samples.
epochs = 100
# Total optimizer steps for `epochs` passes. Derived from the configured
# `batch_size` (previously a hard-coded 16) so that changing the batch size
# in the configuration cell keeps the step count consistent.
num_train_steps = NTRAIN * epochs // batch_size  #@param
# Working directory for metrics and checkpoints, under the current directory.
workdir = os.path.join(os.path.abspath(''), "tmp", "Uncompressed10squares")  #@param
# Learning-rate schedule: warm up from initial_lr to peak_lr over the first
# 5% of the steps, then decay towards end_lr.
initial_lr = 1e-5  #@param
peak_lr = 5e-3  #@param
warmup_steps = num_train_steps // 20  #@param
end_lr = 1e-8  #@param
# Checkpointing: save interval (in steps) and number of checkpoints retained.
# NOTE(review): with the debugging value NTRAIN = 100 the run has only 625
# steps, fewer than ckpt_interval, so interval-based saves never trigger --
# confirm that the checkpoint callback also saves at the end of training.
ckpt_interval = 2000  #@param
max_ckpt_to_keep = 3  #@param

In [25]:
# Learning rate: linear warm-up to the peak value followed by cosine decay.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=initial_lr,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=num_train_steps,
    end_value=end_lr,
)

# Adam optimizer driven by the schedule above; the trainer gets its own
# fixed-seed PRNG key.
trainer = trainers.DeterministicTrainer(
    model=Model,
    rng=jax.random.PRNGKey(42),
    optimizer=optax.adam(learning_rate=lr_schedule),
)
     

In [26]:
# NOTE(review): the output saved with this cell is
#   ValueError: `pred` (16, 48, 48) and `true` (16, 96, 96) must have the same shape.
# i.e. the model output has half the resolution of the targets. This has to be
# resolved before the training/inference cells can run.

# Evaluation uses the very same loader object as training (there is no
# held-out loader here). NOTE(review): if the loader is a stateful iterator,
# evaluation batches are drawn from -- and advance -- the training stream;
# confirm that this is intended.
eval_dloader = train_dloader

# Synchronous metric writer that logs into the working directory.
metric_writer = metric_writers.create_default_writer(workdir, asynchronous=False)

# Progress bar showing the training loss and the evaluation relative RMSE.
progress_bar = templates.TqdmProgressBar(
    total_train_steps=num_train_steps,
    train_monitors=("train_loss",),
    eval_monitors=("eval_rrmse_mean",),
)

# Periodic train-state checkpoints, keeping only the most recent few.
ckpt_options = ocp.CheckpointManagerOptions(
    save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
)
ckpt_callback = templates.TrainStateCheckpoint(base_dir=workdir, options=ckpt_options)

templates.run_train(
    train_dataloader=train_dloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writer,
    metric_aggregation_steps=10,
    eval_dataloader=eval_dloader,
    eval_every_steps=100,
    num_batches_per_eval=1,
    callbacks=(progress_bar, ckpt_callback),
)

ValueError: `pred` (16, 48, 48) and `true` (16, 96, 96) must have the same shape.

## Inference

In [None]:
# Checkpoints are read from the "checkpoints" folder inside the working
# directory; step=None leaves the choice of step to the restore routine.
ckpt_dir = f"{workdir}/checkpoints"
trained_state = trainers.TrainState.restore_from_orbax_ckpt(ckpt_dir, step=None)

# Build the prediction function from the restored state and the core module.
inference_fn = trainers.DeterministicTrainer.build_inference_fn(
    trained_state,
    core_module,
)

In [None]:
# Number of test samples to load.
NTEST = 100

# Resolve the test directories for the same wavenumbers as in training.
test_dirs = get_multifreq_dset_dirs(
    "test", kbar_str_list, base_dir=dataset_dir, dir_fmt="{0}_measurements_nu_{1}"
)

# Load samples [0, NTEST) and report the shape of every loaded array.
test_mfisnet_dd = load_cart_multifreq_dataset(
    test_dirs, global_idx_start=0, global_idx_end=NTEST
)
test_shape_strs = [f"{key}{val.shape}" for key, val in test_mfisnet_dd.items()]
print(f"Loaded: {', '.join(test_shape_strs)}")

# Same conversion as for the training data (same blur_sigma).
test_wb_dd = convert_mfisnet_data_dict(test_mfisnet_dd, blur_sigma=blur_sigma)

# Same downsampling as for the training data: keep every other grid point,
# because the SparsePolarToCartesian step is very slow at full resolution.
test_eta = test_wb_dd["eta"][..., ::2, ::2]
test_scatter = test_wb_dd["scatter"][..., ::2, ::2, :]

In [None]:
# Wrap the test arrays with the same helper and batch size as the training
# data; the helper returns a (dataset, dataloader) pair.
test_dataset, test_dloader = setup_tf_dataset(
    test_eta,
    test_scatter,
    batch_size=batch_size,
)

In [None]:
# Evaluate the trained model on the test set.
#   validation_errors_rrmse: one relative-RMSE entry per batch
#   validation_errors_rapsd: one |log power-spectrum ratio| array per sample
#   pred_eta:                predictions, stored in loader order
validation_errors_rrmse = []
validation_errors_rapsd = []
pred_eta = np.zeros(test_eta.shape)

# Relative root-mean-square error, summed over the last two axes.
rrmse = functools.partial(
    metrics.mean_squared_error,
    sum_axes=(-1, -2),
    relative=True,
    squared=False,
)

n_test = test_eta.shape[0]
n_done = 0  # number of test samples processed so far
for batch in test_dloader:
    # Batches are read as dicts, matching the "scatter" access below.
    # NOTE(review): the target key is assumed to be "eta" (the name used by
    # the converted data dict) -- confirm against setup_tf_dataset.
    pred = np.asarray(inference_fn(batch["scatter"]))
    true = np.asarray(batch["eta"])

    # Use the actual batch length rather than a fixed stride (the previous
    # `test_batch` name was never defined), and never go past n_test samples:
    # this copes with a short final batch and with a loader that repeats.
    n_batch = min(pred.shape[0], n_test - n_done)
    pred, true = pred[:n_batch], true[:n_batch]
    pred_eta[n_done:n_done + n_batch, :, :] = pred
    n_done += n_batch

    validation_errors_rrmse.append(rrmse(pred=pred, true=true))
    for i in range(n_batch):
        # Absolute log-ratio of the radially averaged power spectra.
        spectrum_pred = rapsd(pred[i], fft_method=np.fft)
        spectrum_true = rapsd(true[i], fft_method=np.fft)
        validation_errors_rapsd.append(np.abs(np.log(spectrum_pred / spectrum_true)))

    if n_done >= n_test:
        break

print('relative root-mean-square error = %.3f' % (np.mean(validation_errors_rrmse)*100), '%')
print('mean energy log ratio = %.3f' % np.mean(validation_errors_rapsd))

In [None]:
# Debug leftover: pull one batch from the test loader to inspect its contents.
# NOTE(review): this advances the loader, and if the evaluation loop above has
# already consumed it this may raise StopIteration -- consider removing it.
next(test_dloader)

In [None]:
# test_data_path = os.path.abspath('../..') + '/data/10hsquares_testdata'
# test_data_path = os.path.join("data", "testdata")

# with h5py.File(f'{test_data_path}/eta.h5', 'r') as f:
#     # Read eta data, apply Gaussian blur, and reshape
#     eta_re = f[list(f.keys())[0]][:, :].reshape(-1, neta, neta)
#     blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
#     eta_test = np.stack([blur_fn(img.T) for img in eta_re]).astype('float32')
    
# # Loading and preprocessing scatter data (Lambda)
# # with h5py.File(f'{test_data_path}/scatter_order_8.h5', 'r') as f:
# with h5py.File(f'{test_data_path}/scatter.h5', 'r') as f:
#     keys = natsort.natsorted(f.keys())

#     # Process real part of scatter data
#     tmp1 = f[keys[3]][:, :]
#     tmp2 = f[keys[4]][:, :]
#     tmp3 = f[keys[5]][:, :]
#     scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

#     # Process imaginary part of scatter data
#     tmp1 = f[keys[0]][:, :]
#     tmp2 = f[keys[1]][:, :]
#     tmp3 = f[keys[2]][:, :]
#     scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
#     # Combine real and imaginary parts
#     scatter_test = np.stack((scatter_re, scatter_im), axis=1).astype('float32')
    
# # Clean up temporary variables to free memory
# del scatter_re, scatter_im, tmp1, tmp2, tmp3
