# Imports

In [1]:
import functools
import os
import time
# Ask JAX/XLA to preallocate 95% of GPU memory. This is set before `import jax`
# so that the value is already in the environment when the GPU backend starts.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax.experimental import sparse
# Pin every JAX computation to the first GPU by default.
# NOTE(review): `jax.devices("gpu")` is expected to fail when no GPU backend is
# available, so this notebook cannot run on a CPU-only machine as written.
jax_device = jax.devices("gpu")[0]
jax.config.update("jax_default_device", jax_device)

import optax
import orbax.checkpoint as ocp
from clu import metric_writers

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

In [2]:
from ISP_baseline.src import models, trainers, utils
from ISP_baseline.models import Compressed, Uncompressed 

from swirl_dynamics import templates
from swirl_dynamics.lib import metrics
from pysteps.utils.spectral import rapsd

Pysteps configuration file found at: /share/data/willett-group/oortsang/miniconda/envs/jaxisp-v3/lib/python3.13/site-packages/pysteps/pystepsrc



In [3]:
# For use with our MFISNet-style dataset
import os
from ISP_baseline.src.data_io import (
    load_hdf5_to_dict,
    load_cart_multifreq_dataset,
    load_single_dir_slice,
    load_multi_dir_slice,
    get_multifreq_dset_dirs,
)
from ISP_baseline.src.datasets import (
    convert_mfisnet_data_dict,
    setup_tf_dataset,
    get_io_mean_std,
)
from ISP_baseline.src.more_metrics import (
    l2_error,
)

In [4]:
# Hide all GPUs from TensorFlow so that it does not use GPU memory; the GPU is
# reserved for JAX (see the device setup in the imports cell).
tf.config.set_visible_devices([], device_type='GPU')

## Load the dataset

In [5]:
# Root of the repository that contains the dataset. The default is the original
# cluster location; set the RLC_REPO environment variable to point somewhere
# else without editing the notebook.
rlc_repo = os.environ.get(
    "RLC_REPO",
    os.path.join("/", "home-nfs", "oortsang", "rlc-repo"),
)
# All train/val/test measurement directories are looked up under this folder.
dataset_dir = os.path.join(rlc_repo, "dataset")

In [6]:
# --- Model discretization -------------------------------------------------
L = 4  # number of levels (even number)
r = 3  # rank

# Leaf size: 12 corresponds to N_x = 192 (5 and 6 were tried for smaller grids).
# It shrinks by the same factor as the data when downsampling is enabled.
downsample_ratio = 1
s = 12 // downsample_ratio

# neta: discretization of Omega (n_eta * n_eta).
# nx:   number of sources/detectors (n_sc), which is also the discretization of
#       the domain of alpha in polar coordinates (n_theta * n_rho). For
#       simplicity these are set equal (n_sc = n_theta = n_rho).
neta = s * 2**L
nx = s * 2**L

# --- Data / optimization sizes ---------------------------------------------
blur_sigma = 0.5  # standard deviation for the Gaussian blur
batch_size = 16   # training batch size

# Number of training / validation datapoints to load.
# (The full training set has 21000 samples; 2000 was used for debugging.)
NTRAIN = 1000
NVAL = 1000

In [7]:
# String labels of the frequencies to load, "1" through "10"; they are
# substituted into the dataset directory names below.
# (A coarser alternative that was tried: ["2.5", "5", "10"].)
kbar_str_list = [str(kbar) for kbar in range(1, 11)]
nk = len(kbar_str_list)  # number of frequencies

In [8]:
# Resolve one measurement directory per frequency for the "train" split, then
# load samples [0, NTRAIN) from them.
train_dirs = get_multifreq_dset_dirs(
    "train", kbar_str_list, base_dir=dataset_dir, dir_fmt="{0}_measurements_nu_{1}"
)
train_mfisnet_dd = load_cart_multifreq_dataset(
    train_dirs, global_idx_start=0, global_idx_end=NTRAIN
)
print(f"Loaded: {', '.join([f'{key}{val.shape}' for (key, val) in train_mfisnet_dd.items()])}")

# Convert the MFISNet-style dictionary into the "eta"/"scatter" arrays used by
# the model. The same conversion options are used for the validation split.
train_wb_dd = convert_mfisnet_data_dict(
    train_mfisnet_dd,
    scatter_as_real=True, real_imag_axis=2,
    blur_sigma=blur_sigma, downsample_ratio=downsample_ratio,
    flip_scobj_axes=True,
)
train_eta, train_scatter = train_wb_dd["eta"], train_wb_dd["scatter"]

# Mean/std of the inputs (scatter) and outputs (eta); these are handed to the
# model for optional I/O normalization.
train_scatter_mean, train_scatter_std, train_eta_mean, train_eta_std = get_io_mean_std(
    train_scatter, train_eta
)

Loaded: x_vals(192,), q_cart(1000, 192, 192), sample_completion(1000,), d_rs(1000, 10, 192, 192)


In [9]:
# Wrap the training arrays into a dataset and a batch loader using the project
# helper. `repeats=True` presumably makes the loader cycle over the data
# indefinitely, which the fixed-step training loop relies on — confirm in
# ISP_baseline.src.datasets.
train_dataset, train_dloader = setup_tf_dataset(
    train_eta,
    train_scatter,
    batch_size=batch_size,
    repeats=True,
)

In [10]:
# Resolve one measurement directory per frequency for the "val" split, then
# load samples [0, NVAL) from them.
val_dirs = get_multifreq_dset_dirs(
    "val", kbar_str_list, base_dir=dataset_dir, dir_fmt="{0}_measurements_nu_{1}"
)
val_mfisnet_dd = load_cart_multifreq_dataset(
    val_dirs, global_idx_start=0, global_idx_end=NVAL
)
print(f"Loaded: {', '.join([f'{key}{val.shape}' for (key, val) in val_mfisnet_dd.items()])}")

# Same conversion options as for the training split. `downsample_ratio` can be
# raised to shrink the problem, since the sparse polar-to-Cartesian step is slow.
val_wb_dd = convert_mfisnet_data_dict(
    val_mfisnet_dd,
    scatter_as_real=True, real_imag_axis=2,
    blur_sigma=blur_sigma, downsample_ratio=downsample_ratio,
    flip_scobj_axes=True,
)
val_eta, val_scatter = val_wb_dd["eta"], val_wb_dd["scatter"]

Loaded: x_vals(192,), q_cart(1000, 192, 192), sample_completion(1000,), d_rs(1000, 10, 192, 192)


In [11]:
# Validation loader, built with the same helper as the training loader.
# NOTE(review): `repeats=True` presumably makes this loader endless (it is
# sampled one batch at a time for periodic evaluation during training), so any
# plain `for batch in val_dloader` loop needs its own stopping condition —
# confirm in ISP_baseline.src.datasets.
val_batch_size = 16
val_dataset, val_dloader = setup_tf_dataset(
    val_eta,
    val_scatter,
    batch_size=val_batch_size,
    repeats=True,
)

In [12]:
train_eta.shape, train_scatter.shape

((1000, 192, 192), (1000, 36864, 2, 10))

## Architecture

In [13]:
# Network size (12 CNN channels was tried as an alternative to 6).
N_resnet_layers, N_cnn_layers, N_cnn_channels = 6, 9, 6
# Convolution kernel size; in the visible cells it is only referenced by the
# commented-out Uncompressed model.
kernel_size = 5
# Whether the model normalizes its inputs/outputs with the training statistics.
io_norm = False

In [14]:
# utils.load_mats_from_fp(os.path.join("tmp", "cart_and_rot_mats", f"mats_neta{neta}_nx{nx}.npz"))

In [15]:
%%time
from datetime import datetime
# Timestamp the start so that the cell's wall time can be related to the clock.
print(f"Starting at {datetime.now()}...")
# Obtain the two fixed (non-trainable) operators the model needs for this grid:
# `cart_mat` and `r_index`. Judging by the helper's name and arguments, it
# loads them from an .npz file under `mats_dir` when one exists for this
# (neta, nx) and otherwise builds them and saves the result — confirm in
# ISP_baseline.src.utils.
cart_mat, r_index = utils.load_or_create_mats(
    neta,
    nx,
    mats_dir=os.path.join("tmp", "cart_and_rot_mats"),
    mats_format="mats_neta{0}_nx{1}.npz",
    save_if_created=True,
)

# cart_mat = utils.SparsePolarToCartesian(neta, nx)
# r_index  = utils.rotationindex(nx)

Starting at 2025-10-28 13:30:54.929428...
CPU times: user 164 ms, sys: 292 ms, total: 456 ms
Wall time: 455 ms


In [16]:
# import tornado, asyncio
# print(tornado.version)

In [17]:
# asyncio.__name__

In [18]:
# core_module = Uncompressed.UncompressedModelFlexible(
#     nx = nx,
#     neta = neta,
#     cart_mat = cart_mat,
#     r_index = r_index,
#     # New parameters
#     nk=nk,
#     N_cnn_layers=N_cnn_layers,
#     N_cnn_channels=N_cnn_channels,
#     kernel_size=kernel_size,
# )

# Compressed model used for training (the Uncompressed variant is kept above,
# commented out, for reference).
core_module = Compressed.CompressedModelFlexible(
    # Levels, leaf size and rank.
    L=L, s=s, r=r,
    # Precomputed operators and number of frequencies.
    cart_mat=cart_mat, r_index=r_index, nk=nk,
    # Architecture (other than r).
    N_resnet_layers=N_resnet_layers,
    N_cnn_layers=N_cnn_layers,
    N_cnn_channels=N_cnn_channels,
    grad_checkpoint=True,
    # I/O normalization: both switches follow `io_norm`; the training
    # statistics are passed regardless of whether it is enabled.
    in_norm=io_norm, out_norm=io_norm,
    in_mean=jnp.array(train_scatter_mean), in_std=jnp.array(train_scatter_std),
    out_mean=jnp.array(train_eta_mean), out_std=jnp.array(train_eta_std),
)

In [19]:
# Wrap the core network in the project's DeterministicModel. The input shape is
# that of a single scatter sample (no batch axis).
Model = models.DeterministicModel(
    input_shape = train_scatter[0].shape,
    core_module = core_module
)

In [20]:
# Initialize the model parameters from a fixed seed and report their total
# element count (summed over every leaf of the parameter tree).
rng = jax.random.PRNGKey(888)
params = Model.initialize(rng)
param_count = sum(map(lambda leaf: leaf.size, jax.tree_util.tree_leaves(params)))
print('Number of trainable parameters:', param_count)

Number of trainable parameters: 270595


In [21]:
def recursive_shape(x, depth_left=100):
    """Return a structure mirroring `x` with every array replaced by its shape.

    Dicts, lists and tuples are traversed recursively and rebuilt as the same
    kind of container; NumPy and JAX arrays are replaced by their `.shape`.

    Args:
        x: A dict / list / tuple nesting of NumPy or JAX arrays, or a single
            array.
        depth_left: Number of container levels that may still be expanded.
            A container reached with `depth_left <= 0` is not expanded and is
            summarized by the string "..." instead.

    Returns:
        The same nesting as `x`, with shape tuples in place of arrays (and
        "..." in place of containers that lie beyond the depth limit).

    Raises:
        ValueError: If a value is neither a supported container nor an array.
    """
    if isinstance(x, (np.ndarray, jnp.ndarray)):
        return x.shape
    if not isinstance(x, (dict, list, tuple)):
        raise ValueError(f"Unhandled type {type(x)}")
    if depth_left <= 0:
        # Depth budget exhausted: summarize instead of descending further.
        # (Previously `depth_left` was decremented but never checked.)
        return "..."

    next_depth = depth_left - 1
    if isinstance(x, dict):
        return {
            k: recursive_shape(v, depth_left=next_depth)
            for (k, v) in x.items()
        }
    shapes = [recursive_shape(x_i, depth_left=next_depth) for x_i in x]
    return shapes if isinstance(x, list) else tuple(shapes)
# Print the parameter shapes of two of the model's sub-modules as a sanity check.
for layer_name in ("fstar_layers_0", "convs_8"):
    print(recursive_shape(params["params"][layer_name]))

{'Gs_0': {'gi1': (8, 6, 6), 'gi2': (8, 6, 6), 'gi3': (8, 6, 6), 'gi4': (8, 6, 6), 'gr1': (8, 6, 6), 'gr2': (8, 6, 6), 'gr3': (8, 6, 6), 'gr4': (8, 6, 6)}, 'Gs_1': {'gi1': (8, 6, 6), 'gi2': (8, 6, 6), 'gi3': (8, 6, 6), 'gi4': (8, 6, 6), 'gr1': (8, 6, 6), 'gr2': (8, 6, 6), 'gr3': (8, 6, 6), 'gr4': (8, 6, 6)}, 'Hs_0': {'hi1': (8, 6, 6), 'hi2': (8, 6, 6), 'hi3': (8, 6, 6), 'hi4': (8, 6, 6), 'hr1': (8, 6, 6), 'hr2': (8, 6, 6), 'hr3': (8, 6, 6), 'hr4': (8, 6, 6)}, 'Hs_1': {'hi1': (8, 6, 6), 'hi2': (8, 6, 6), 'hi3': (8, 6, 6), 'hi4': (8, 6, 6), 'hr1': (8, 6, 6), 'hr2': (8, 6, 6), 'hr3': (8, 6, 6), 'hr4': (8, 6, 6)}, 'Ms_0': {'mi1': (16, 3, 3), 'mi2': (16, 3, 3), 'mi3': (16, 3, 3), 'mi4': (16, 3, 3), 'mr1': (16, 3, 3), 'mr2': (16, 3, 3), 'mr3': (16, 3, 3), 'mr4': (16, 3, 3)}, 'Ms_1': {'mi1': (16, 3, 3), 'mi2': (16, 3, 3), 'mi3': (16, 3, 3), 'mi4': (16, 3, 3), 'mr1': (16, 3, 3), 'mr2': (16, 3, 3), 'mr3': (16, 3, 3), 'mr4': (16, 3, 3)}, 'Ms_2': {'mi1': (16, 3, 3), 'mi2': (16, 3, 3), 'mi3': (16, 

In [22]:
train_scatter[0].shape

(36864, 2, 10)

In [23]:
type(cart_mat)

jax.experimental.sparse.bcoo.BCOO

In [24]:
type(r_index)

jaxlib.xla_extension.ArrayImpl

## Training

In [25]:
epochs = 300
# epochs = 1
# One optimizer step consumes one training batch, so the step count is derived
# from the configured `batch_size` (previously a hard-coded 16, which would
# silently give the wrong number of epochs if `batch_size` were changed).
num_train_steps = NTRAIN * epochs // batch_size  #@param
workdir = os.path.abspath('') + f"/tmp/2025-10-19_compressed_nx_{nx}_nk_{nk}"  #@param
# WARNING: every run starts from scratch — an existing work directory
# (checkpoints and logged metrics included) is deleted here.
if os.path.exists(workdir):
    import shutil
    shutil.rmtree(workdir)
# Learning-rate schedule: start at `initial_lr`, reach `peak_lr` after
# `warmup_steps` (5% of training), and end at `end_lr`.
# initial_lr = 1e-5 #@param
initial_lr = 1e-4
peak_lr = 5e-3 #@param
warmup_steps = num_train_steps // 20  #@param
end_lr = 1e-8 #@param
# Checkpointing: save every `ckpt_interval` steps, keep the latest few.
ckpt_interval = 2000  #@param
max_ckpt_to_keep = 3  #@param

In [26]:
# Learning rate: linear warm-up from `initial_lr` to `peak_lr`, followed by a
# cosine decay down to `end_lr` over the whole training run.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=initial_lr,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=num_train_steps,
    end_value=end_lr,
)

# Adam-driven trainer for the deterministic model, with its own fixed seed.
trainer = trainers.DeterministicTrainer(
    model=Model,
    rng=jax.random.PRNGKey(42),
    optimizer=optax.adam(learning_rate=lr_schedule),
)

In [27]:
# untrained_inference_fn = trainers.DeterministicTrainer.build_inference_fn(
#     trainer.train_state, core_module
# )

# print(f"Memory usage associated with a single forward pass...")
# _ = utils.get_memory_info_jax(device=jax_device, print_msg=True)
# # pre-run:     244
# # after 16:   2694
# # after 32:   5286
# # after 64:  10469
# # after 128: 20865
# trial_batch_size = 16

# # first_batch = next(train_dloader)
# trial_input = train_wb_dd["scatter"][:trial_batch_size]
# trial_output = untrained_inference_fn(trial_input)

# _ = utils.get_memory_info_jax(device=jax_device, print_msg=True)

# # Okay, so running the forward pass is not a (big) problem

In [28]:
# _ = utils.get_memory_info_jax(device=jax_device, print_msg=True)
# def trial_loss_fn(params, x):
#     return jnp.linalg.norm(core_module.apply(params, x))
# trial_loss_fn_cp = jax.checkpoint(trial_loss_fn)
# grads = jax.grad(trial_loss_fn_cp)(params, trial_input)
# # grads = jax.grad(trial_loss_fn)(params, trial_input)
# _ = utils.get_memory_info_jax(device=jax_device, print_msg=True)
# print(grads["params"].keys())
# _ = utils.get_memory_info_jax(device=jax_device, print_msg=True)

In [29]:
_ = utils.get_memory_info_jax(device=jax_device, print_msg=True)

RAM Used (MB): 51323
VRAM (MB): 30695 free of 30740 total (within preallocation); usage is currently 44 and peaked at 349


In [31]:
# trial_input.shape

# sdfasdfasd

In [32]:
# Report memory usage before training starts.
utils.get_memory_info_jax(jax_device, print_msg=True)

# Evaluate on the validation loader during training (use train_dloader here to
# monitor the training split instead).
eval_dloader = val_dloader

# Metrics are written to the work directory synchronously.
metric_writer = metric_writers.create_default_writer(workdir, asynchronous=False)

# Callbacks: a tqdm progress bar showing the monitored train/eval metrics, and
# periodic checkpointing of the train state. (templates.ProgressReport is an
# alternative to the progress bar.)
progress_bar = templates.TqdmProgressBar(
    total_train_steps=num_train_steps,
    train_monitors=("train_loss",),
    eval_monitors=("eval_rrmse_mean", "eval_rel_l2_mean"),
)
checkpointer = templates.TrainStateCheckpoint(
    base_dir=workdir,
    options=ocp.CheckpointManagerOptions(
        save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
    ),
)

templates.run_train(
    train_dataloader=train_dloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writer,
    metric_aggregation_steps=10,
    eval_dataloader=eval_dloader,
    eval_every_steps=100,
    num_batches_per_eval=1,
    callbacks=(progress_bar, checkpointer),
)

# Report memory usage again once training has finished.
_ = utils.get_memory_info_jax(jax_device, print_msg=True)

RAM Used (MB): 50423
VRAM (MB): 30695 free of 30740 total (within preallocation); usage is currently 44 and peaked at 349


  0%|          | 0/18750 [00:00<?, ?step/s]

2025-10-28 13:35:17.749808: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 30.00GiB (rounded to 32209323776)requested by op 
2025-10-28 13:35:17.755161: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] *___________________________________________________________________________________________________
E1028 13:35:17.755900 1052084 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 32209323688 bytes. [tf-allocator-allocation-error='']


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 32209323688 bytes.

## Inference

In [None]:
# Restore the train state from the checkpoints written during training
# (`step=None` presumably selects the most recent one — confirm in trainers).
trained_state = trainers.TrainState.restore_from_orbax_ckpt(f"{workdir}/checkpoints", step=None)

# Build the function that maps a batch of scatter data to predicted eta.
inference_fn = trainers.DeterministicTrainer.build_inference_fn(trained_state, core_module)

In [None]:
# Number of test datapoints to load.
NTEST = 100

# Resolve one measurement directory per frequency for the "test" split, then
# load samples [0, NTEST) from them.
test_dirs = get_multifreq_dset_dirs(
    "test",
    kbar_str_list,
    base_dir=dataset_dir,
    dir_fmt="{0}_measurements_nu_{1}"
)

test_mfisnet_dd = load_cart_multifreq_dataset(
    test_dirs,
    global_idx_start=0,
    global_idx_end=NTEST,
)
print(f"Loaded: {', '.join([f'{key}{val.shape}' for (key, val) in test_mfisnet_dd.items()])}")
# Use exactly the same conversion options as for the train/val splits.
# `scatter_as_real` and `real_imag_axis` were previously omitted here, leaving
# the test scatter layout to the helper's defaults rather than to the layout
# the model was built and trained with.
test_wb_dd = convert_mfisnet_data_dict(
    test_mfisnet_dd,
    scatter_as_real=True,
    real_imag_axis=2,
    blur_sigma=blur_sigma,
    downsample_ratio=downsample_ratio,
    flip_scobj_axes=True,
)

# `downsample_ratio` can be raised to shrink the problem, since the sparse
# polar-to-Cartesian step is slow.
test_eta     = test_wb_dd["eta"] # [..., ::2, ::2]
test_scatter = test_wb_dd["scatter"] # [..., ::2, ::2, :]

In [None]:
# Test loader, built with the same helper as the train/val loaders. Unlike
# those, `repeats` is not passed here, so the helper's default applies
# (presumably a single pass over the data — confirm in ISP_baseline.src.datasets).
test_batch_size = 16
test_dataset, test_dloader = setup_tf_dataset(
    test_eta,
    test_scatter,
    # val_eta,
    # val_scatter,
    batch_size=test_batch_size,
)

In [None]:
validation_errors_rrmse = []
validation_errors_rel_l2 = []
validation_errors_rapsd = []

# Split that is evaluated in this cell. These three must refer to the same
# split; to evaluate the test set instead, switch all of them to
# test_eta / test_dloader / test_batch_size.
infer_eta = val_eta
infer_dloader = val_dloader
infer_batch_size = val_batch_size

# Predictions for every sample of the evaluated split. (This used to be sized
# from `test_eta` while the loop ran over the validation loader, so the batch
# slices ran past the end of the array.)
n_infer = infer_eta.shape[0]
pred_eta = np.zeros(infer_eta.shape)

# Relative root-mean-square error over the two image axes.
rrmse = functools.partial(
    metrics.mean_squared_error,
    sum_axes=(-1, -2),
    relative=True,
    squared=False,
)
# Relative L2 error over the two image axes.
rel_l2 = functools.partial(
    l2_error,
    l2_axes=(-1, -2),
    relative=True,
    squared=False,
)

# The validation loader was created with `repeats=True` and may therefore never
# be exhausted, so the loop is limited to one pass: ceil(n_infer / batch size)
# batches. `zip` stops as soon as the range is used up.
n_infer_batches = -(-n_infer // infer_batch_size)
for b, batch in zip(range(n_infer_batches), infer_dloader):
    start = b * infer_batch_size
    # The final batch may be short, or (for a repeating loader) padded with
    # samples from the next pass; keep only the samples of this pass.
    n_valid = min(batch["eta"].shape[0], n_infer - start)
    # Run inference on the full batch, then trim.
    pred = inference_fn(batch["scatter"])[:n_valid]
    true = batch["eta"][:n_valid]
    pred_eta[start:start + n_valid] = pred
    validation_errors_rrmse.append(rrmse(pred=pred, true=true))
    validation_errors_rel_l2.append(rel_l2(pred=pred, true=true))
    # Per-sample absolute log-ratio of the radially averaged power spectra.
    for i in range(n_valid):
        validation_errors_rapsd.append(np.abs(np.log(rapsd(pred[i],fft_method=np.fft)/rapsd(true[i],fft_method=np.fft))))

print(f"Mean rel l2 error: {np.mean(validation_errors_rel_l2)*100:.3f}%")
print('relative root-mean-square error = %.3f' % (np.mean(validation_errors_rrmse)*100), '%') 
print('mean energy log ratio = %.3f' % np.mean(validation_errors_rapsd)) 

In [None]:
# test_data_path = os.path.abspath('../..') + '/data/10hsquares_testdata'
# test_data_path = os.path.join("data", "testdata")

# with h5py.File(f'{test_data_path}/eta.h5', 'r') as f:
#     # Read eta data, apply Gaussian blur, and reshape
#     eta_re = f[list(f.keys())[0]][:, :].reshape(-1, neta, neta)
#     blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
#     eta_test = np.stack([blur_fn(img.T) for img in eta_re]).astype('float32')
    
# # Loading and preprocessing scatter data (Lambda)
# # with h5py.File(f'{test_data_path}/scatter_order_8.h5', 'r') as f:
# with h5py.File(f'{test_data_path}/scatter.h5', 'r') as f:
#     keys = natsort.natsorted(f.keys())

#     # Process real part of scatter data
#     tmp1 = f[keys[3]][:, :]
#     tmp2 = f[keys[4]][:, :]
#     tmp3 = f[keys[5]][:, :]
#     scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

#     # Process imaginary part of scatter data
#     tmp1 = f[keys[0]][:, :]
#     tmp2 = f[keys[1]][:, :]
#     tmp3 = f[keys[2]][:, :]
#     scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
#     # Combine real and imaginary parts
#     scatter_test = np.stack((scatter_re, scatter_im), axis=1).astype('float32')
    
# # Clean up temporary variables to free memory
# del scatter_re, scatter_im, tmp1, tmp2, tmp3
