# Rectified Flow example for Debiasing PDE solutions.

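Here we consider the Kuramoto-Sivashinsky (KS) equation:
$$\partial_t u = -u \partial_x u - \partial_{xx}u - \partial_{xxxx}u,$$
a prime example of a chaotic PDE: the anti-diffusive term $-\partial_{xx}u$ amplifies long waves, the hyperdiffusive term $-\partial_{xxxx}u$ damps short ones, and the nonlinear term transfers energy between scales.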
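With periodic boundary conditions, the right-hand side can be evaluated with FFT-based derivatives, which is the idea behind pseudo-spectral solvers. Below is a minimal NumPy sketch, assuming a uniform grid on a periodic domain of length `length`. The helper `ks_rhs` is hypothetical and only illustrates the sign conventions; it is not the jax-cfd solver used to generate the data, and it does no dealiasing or time stepping.

```python
import numpy as np


def ks_rhs(u: np.ndarray, length: float) -> np.ndarray:
  """Evaluates -u u_x - u_xx - u_xxxx for u on a uniform periodic grid.

  Hypothetical helper for illustration: derivatives are computed spectrally
  (multiplication by powers of i*k in Fourier space), without dealiasing.
  """
  n = u.shape[-1]
  # Angular wavenumbers for a periodic domain of the given length.
  k = 2 * np.pi * np.fft.fftfreq(n, d=length / n)
  u_hat = np.fft.fft(u)
  u_x = np.real(np.fft.ifft(1j * k * u_hat))
  u_xx = np.real(np.fft.ifft(-(k**2) * u_hat))
  u_xxxx = np.real(np.fft.ifft(k**4 * u_hat))
  return -u * u_x - u_xx - u_xxxx


# For u = sin(x) on [0, 2*pi), u_xx = -sin(x) and u_xxxx = sin(x) cancel in the
# linear part, leaving only the nonlinear term -sin(x) cos(x).
x_demo = np.linspace(0.0, 2 * np.pi, 64, endpoint=False)
rhs_demo = ks_rhs(np.sin(x_demo), length=2 * np.pi)
print(np.max(np.abs(rhs_demo + np.sin(x_demo) * np.cos(x_demo))))  # round-off
```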

Here we consider data stemming from two different solvers. For the ground-truth, or reference, dataset we use data generated by solving the KS equation with a pseudo-spectral solver on a fine mesh. For the biased dataset we use data generated by a finite-volume (FV) solver on a much coarser mesh. Both datasets were generated with [jax-cfd](https://github.com/google/jax-cfd) from the same randomly generated initial conditions and with the same warm-up period.

Due to the difference in resolution and the extra dissipation introduced by the FV discretization, the trajectories computed with FV are "biased" with respect to the ones computed with the pseudo-spectral method. This is readily observed in histograms of the snapshots: although both are numerical approximations of the same PDE, the two methods produce different distributions of values.

In this colab we show how to use rectified flows to correct the statistical errors of the coarse, FV-based solution so that its statistics match those of the fine, spectrally accurate solution. This is what we refer to as *debiasing*.

This colab uses the same methodology proposed in [Wan et al](https://arxiv.org/abs/2412.08079) but on a simpler dataset: the dimension is much lower, we only consider one field, there is no climatology, and we assume that the underlying distribution is fixed. In what follows, we walk the reader through how the full pipeline is implemented, including how to define the dataloaders, while leveraging the [swirl-dynamics](https://github.com/google-research/swirl-dynamics) library to do most of the heavy lifting for training and inference.

### Downloading dependencies.

We use the `swirl-dynamics` library for most of the heavy lifting, so we install it using pip.

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

We also import all the necessary libraries.

In [None]:
import jax
# jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from clu import metric_writers

import abc
from collections.abc import Callable, Mapping, Sequence
import types
from typing import Any, Literal, SupportsIndex

import grain.python as pygrain
import ml_collections
import h5py
import optax
import numpy as np
import matplotlib.pyplot as plt
import flax.linen as nn
import flax
import functools

from orbax import checkpoint

from swirl_dynamics.projects.debiasing.rectified_flow import pygrain_transforms as transforms
from swirl_dynamics.projects.debiasing.rectified_flow import models
from swirl_dynamics.projects.debiasing.rectified_flow import trainers
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train
from swirl_dynamics.lib import layers
from swirl_dynamics.lib.diffusion import unets


# Defines shortened types
ConfigDict = ml_collections.config_dict.config_dict.ConfigDict


### Define Hyper-Parameters

For simplicity we define the parameters inside a `ConfigDict`.

In [None]:
config = ml_collections.ConfigDict()

# Parameters of the trainer. If running on a multi-device host, you can set this
# to True. The caveat is that batch_size needs to be a multiple of the number of
# devices.
config.is_distributed = False
config.seed = 37

# Optimizer parameters.
config.initial_lr = 1e-7
config.peak_lr = 5e-4
config.end_lr = 1e-7
config.warmup_steps = 5_000
config.beta1 = 0.9
config.max_norm = 0.6
config.ema_decay = 0.99

# Training/Evaluation parameters.
config.num_train_steps = 25_000
config.num_batches_per_eval = 8
config.metric_aggregation_steps = 100
config.eval_every_steps = 500

# Checkpointing parameters.
config.save_interval_steps = 1_000
config.max_checkpoints_to_keep = 10


# Batch parameters.
config.batch_size = 64
config.eval_batch_size = 8

# Data parameters. Here we have the input/output shape.
config.input_shapes = ((1, 48, 1), (1, 48, 1))
config.normalize = False  # Whether to normalize the data or not.
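The batch-size constraint mentioned in the config comment can be checked before training. Below is a minimal sketch with a hypothetical helper `check_batch_size`; on an actual runtime you would pass `jax.device_count()` as `num_devices` (the literal device counts used here are only for illustration).

```python
def check_batch_size(batch_size: int, num_devices: int) -> int:
  """Returns the per-device batch size, failing if the split is uneven.

  Hypothetical helper illustrating the constraint noted in the config cell.
  """
  if batch_size % num_devices != 0:
    raise ValueError(
        f"batch_size={batch_size} is not a multiple of"
        f" num_devices={num_devices}."
    )
  return batch_size // num_devices


# A batch of 64 split across 4 devices gives 16 samples per device.
per_device_batch = check_batch_size(64, 4)
print(per_device_batch)  # 16
```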

### Downloading the data.

The data was generated using [jax-cfd](https://github.com/google/jax-cfd), and uploaded to a Google Cloud bucket. The file contains both high- and low-resolution datasets.

We use gsutil to download the data. If you are running this notebook in Colab, it should already be installed; otherwise, you can follow these [instructions](https://cloud.google.com/storage/docs/gsutil_install) to install it on your computer.

In [None]:
!gsutil cp gs://gresearch/swirl_dynamics/downscaling/KS_finite_volumes_vs_pseudo_spectral.hdf5 .

In [None]:
file_name = 'KS_finite_volumes_vs_pseudo_spectral.hdf5'

# The file is only read, so we open it in read-only mode.
with h5py.File(file_name, 'r') as file:
    u_lr = file['u_fd'][()]  # Trajectories computed with finite volumes.
    u_hr = file['u_sp'][()]  # Trajectories computed with the pseudo-spectral method.
    x = file['x'][()]  # Grid on which the trajectories are computed.
    t = file['t'][()]  # Time stamps of the trajectories.

In [None]:
# Plots the low-resolution (FV) and high-resolution (pseudo-spectral) data.
# Index of the trajectory to plot.
plot_idx = 1
# Spatial downsampling factor.
ds_x = 4

# Spatial domain (with the periodic endpoint appended), downsampled by ds_x.
x_ = jnp.concatenate([x, jnp.array(x[-1] + x[1] - x[0]).reshape((-1,))])[::ds_x]
print(f"Shape of the spatial domain: {x_.shape}")

# Plots the low-resolution (FV) data.
fig = plt.figure(figsize=(14, 4))
plt.imshow(u_lr[plot_idx, :, :].T, extent=[t[0], t[-1], x_[-1], x_[0]], aspect=5)
plt.title("Finite volumes (coarse mesh)")
plt.xlabel("time")
plt.ylabel("x")
plt.show()

# Plots the high-resolution (pseudo-spectral) data.
fig = plt.figure(figsize=(14, 4))
plt.imshow(u_hr[plot_idx, :, :].T, extent=[t[0], t[-1], x_[-1], x_[0]], aspect=5)
plt.title("Pseudo-spectral (fine mesh)")
plt.xlabel("time")
plt.ylabel("x")
plt.show()


In [None]:
u_lr_hf = u_hr[:, :, ::ds_x]
x_lr_hf = x[::ds_x]  # Same subsampling as u_lr_hf (x_ is already downsampled).
u_lr_lf = u_lr

print(f"Shape of the low-resolution high-fidelity data {u_lr_hf.shape}")
print(f"Shape of the low-resolution grid {x_lr_hf.shape}")
print(f"Shape of the low-resolution low-fidelity data {u_lr_lf.shape}")

### Defining the dataloaders.

We use [pygrain](https://github.com/google/grain) to feed the data to the models during training. To this end we need to define a data source, which provides random access to the samples by index, and the dataloader itself, which lets us easily define transformations such as normalization, concatenation, and renaming of keys, among others.

We use the simple tensorized (independent) coupling during training, i.e., the pairs are sampled following
$$(x_0, x_1) \sim \mu_0 \otimes \mu_1,$$
where $\mu_0$ and $\mu_1$ are the distributions of the input and output data, respectively. Therefore, we define one dataloader per dataset, and the two are sampled independently.
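A minimal NumPy sketch of drawing a batch from $\mu_0 \otimes \mu_1$ with two empirical datasets: the indices into each dataset are drawn independently, so no sample of one set is tied to a particular sample of the other. The helper `sample_independent_pairs` and the random stand-in arrays are hypothetical; this version samples with replacement, whereas the shuffled pygrain loaders iterate over each dataset in a random order.

```python
import numpy as np


def sample_independent_pairs(
    data_0: np.ndarray,
    data_1: np.ndarray,
    batch_size: int,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
  """Draws a batch of pairs (x_0, x_1) from the product of empirical measures."""
  # Independent indices: no pairing between the two datasets.
  idx_0 = rng.integers(0, data_0.shape[0], size=batch_size)
  idx_1 = rng.integers(0, data_1.shape[0], size=batch_size)
  return data_0[idx_0], data_1[idx_1]


rng_demo = np.random.default_rng(0)
data0_demo = rng_demo.normal(size=(1000, 48, 1))  # stand-in for FV snapshots
data1_demo = rng_demo.normal(size=(800, 48, 1))  # stand-in for spectral ones
x0_demo, x1_demo = sample_independent_pairs(
    data0_demo, data1_demo, batch_size=64, rng=rng_demo
)
print(x0_demo.shape, x1_demo.shape)  # (64, 48, 1) (64, 48, 1)
```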


Here we define the `DataSource`. As the dataset is small, we keep it in memory (RAM) and return slices of a NumPy array.

In [None]:
class SourceInMemoryNumpy(pygrain.RandomAccessDataSource):
  """Source class for Numpy array.

  Attributes:
    source: Numpy array with the data.
    normalize_stats: Dictionary containing the mean and std statistics of the
      data.
  """

  def __init__(
      self,
      dataset: np.ndarray,
  ):
    """Build the source of pre-computed trajectories stored in a numpy array.

    Args:
      dataset: Dataset in numpy format.

    Returns:
      loader, stats: Tuple of dataloader and dictionary containing mean and std.
    """

    self.source = dataset

    # Computes the metrics and store them in a dictionary.
    mean = np.mean(self.source, axis=0)
    std = np.std(self.source, axis=0)
    self.normalize_stats = {"mean": mean, "std": std}

  def __len__(self) -> int:
    """Returns the number of samples in the source."""
    return self.source.shape[0]

  def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, np.ndarray]:
    """Returns the data record for a given record key.

    Args:
      record_key: The index of the sample in the array.

    Returns:
      A dictionary with the slice of the underlying array under the key "u".
    """
    idx = record_key.__index__()

    if idx >= self.__len__():
      raise IndexError("Index out of range.")
    # here we return a dictionary with "u"
    return {"u": self.source[idx]}

In [None]:
def create_loader_from_ndarray(
    dataset: np.ndarray,
    batch_size: int,
    seed: int = 999,
    shuffle: bool = True,
    normalize: bool = True,
    normalize_stats: dict[str, np.ndarray] | None = None,
    drop_remainder: bool = True,
    worker_count: int = 0,
    output_name: Literal['x_0', 'x_1'] = 'x_0',
    num_epochs: int | None = None,
) -> tuple[pygrain.DataLoader, dict[str, np.ndarray] | None]:
  """Load pre-computed trajectories dumped to hdf5 file.

  Args:
    dataset: Absolute path to dataset file.
    batch_size: Batch size returned by dataloader. If set to -1, use entire
      dataset size as batch_size.
    seed: Random seed to be used in data sampling.
    shuffle: Whether to randomly shuffle the data.
    normalize: Flag for adding data normalization (subtact mean divide by std.).
    normalize_stats: Dictionary with mean and std stats to avoid recomputing, or
      if they need to be computed from a different dataset.
    output_name: Name of the output feature in the dictionary.
    drop_remainder: Flag for dropping the last batch if it is not complete.
    worker_count: Number of workers to use in the dataloader.
    output_name: Name of the field in the output, here we only consider either
      the initial or terminal value of the ODE, i.e., "x_0" or "x_1".
    num_epochs: Number of epochs, by default it will never stop.

  Returns:
    loader, stats (optional): Tuple of dataloader and dictionary containing
                              mean and std stats (if normalize=True, else dict
                              contains NoneType values).
  """

  # Creates the source file, which is a random access file wrapping a numpy
  # array
  source = SourceInMemoryNumpy(dataset)

  if normalize_stats is None:
    normalize_stats = source.normalize_stats

  if "mean" not in normalize_stats or "std" not in normalize_stats:
    raise ValueError(
        "The normalize_stats dictionary should contain keys 'mean' and 'std'."
    )

  transformations = []
  if normalize:
    transformations.append(
        transforms.Standardize(
            input_fields=["u",],
            mean={"u": normalize_stats["mean"]},
            std={"u": normalize_stats["std"]},
        )
    )

  if output_name is not None and output_name != "u":
    if not isinstance(output_name, str):
      raise ValueError(
          "The output_name should be a string, but it is a ",
          type(output_name),
      )
    transformations.append(
        transforms.SelectAs(
            select_features=["u",],
            as_features=[output_name,]
        )
    )

  loader = pygrain.load(
      source=source,
      num_epochs=num_epochs,
      shuffle=shuffle,
      seed=seed,
      shard_options=pygrain.ShardByJaxProcess(drop_remainder=True),
      transformations=transformations,
      batch_size=batch_size,
      drop_remainder=drop_remainder,
      worker_count=worker_count,
  )
  return loader, normalize_stats

Now, we combine the two individual dataloaders, which sample from $\mu_0$ and $\mu_1$ respectively, into a dataloader that samples from $\mu_0 \otimes \mu_1$.

In [None]:
class UnpairedDataLoaderNumpy:
  """Unpaired dataloader for loading samples from two distributions.

  Attributes:
    loader_a: PyGrain dataloader for the input data.
    loader_b: PyGrain dataloader for the output data.
    normalize_stats_a: Dictionary with the statistics (mean and std) of the
      input data.
    normalize_stats_b: Dictionary with the statistics (mean and std) of the
      output data.
  """

  def __init__(
      self,
      dataset_a: np.ndarray,
      dataset_b: np.ndarray,
      batch_size: int,
      seed: int = 37,
      normalize: bool = False,
      normalize_stats_a: dict[str, np.ndarray] | None = None,
      normalize_stats_b: dict[str, np.ndarray] | None = None,
      drop_remainder: bool = True,
      worker_count: int = 0,
  ):

    loader, normalize_stats_a = create_loader_from_ndarray(
        dataset=dataset_a,
        batch_size=batch_size,
        seed=seed,
        normalize=normalize,
        normalize_stats=normalize_stats_a,
        output_name="x_0",
        drop_remainder=drop_remainder,
        worker_count=worker_count,
    )

    self.loader_a = iter(loader)

    loader, normalize_stats_b = create_loader_from_ndarray(
        dataset=dataset_b,
        batch_size=batch_size,
        seed=seed+1,
        normalize=normalize,
        normalize_stats=normalize_stats_b,
        output_name="x_1",
        drop_remainder=drop_remainder,
        worker_count=worker_count,
    )
    self.loader_b = iter(loader)

    self.normalize_stats_a = normalize_stats_a
    self.normalize_stats_b = normalize_stats_b

  def __iter__(self):
    return self

  def __next__(self) -> dict[str, np.ndarray]:
    """Defines the next function for the iterator."""

    b = next(self.loader_b)
    a = next(self.loader_a)

    # Returns dictionary with keys "x_0" and "x_1".
    return {**a, **b}

We instantiate both the training and evaluation dataloaders.

In [None]:
# We use the last 16 trajectories for evaluation.
num_test_trajs = 16

# Each snapshot in time is treated as an independent sample of shape (48, 1).
train_dataloader = UnpairedDataLoaderNumpy(
    dataset_a=u_lr_lf[:-num_test_trajs].reshape((-1, u_lr_lf.shape[-1], 1)),
    dataset_b=u_lr_hf[:-num_test_trajs].reshape((-1, u_lr_hf.shape[-1], 1)),
    batch_size=config.batch_size,
    normalize=config.normalize,
)

eval_dataloader = UnpairedDataLoaderNumpy(
    dataset_a=u_lr_lf[-num_test_trajs:].reshape((-1, u_lr_lf.shape[-1], 1)),
    dataset_b=u_lr_hf[-num_test_trajs:].reshape((-1, u_lr_hf.shape[-1], 1)),
    # We use the same stats as the train data.
    normalize_stats_a=train_dataloader.normalize_stats_a,
    normalize_stats_b=train_dataloader.normalize_stats_b,
    batch_size=config.eval_batch_size,
    normalize=config.normalize,
)

We quickly test the batches produced by the dataloader: each batch should contain both the input and output samples, indexed by the keys `x_0` and `x_1`, respectively.

In [None]:
batch = next(iter(train_dataloader))
print(batch.keys())
assert batch["x_0"].shape[0] == config.batch_size
print(batch["x_0"].shape)

### Defining the rectified flow optimizer.

Here we define the learning rate schedule. For simplicity we use a linear warm-up followed by a cosine decay. This can be further tuned, but empirically it has provided reasonable results for this type of problem.

For the optimizer we use Adam, and we also clip the gradients by their global norm to help avoid instabilities.
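To see the shape of this schedule, here is a hand-written NumPy sketch of a linear warm-up followed by a cosine decay, together with clipping by global norm. It is meant to mirror the behavior documented for `optax.warmup_cosine_decay_schedule` (where `decay_steps` counts the warm-up steps too) and `optax.clip_by_global_norm`, but the helpers `warmup_cosine` and `clip_grads_by_global_norm` are hypothetical and are not optax's implementation.

```python
import numpy as np


def warmup_cosine(step, init_lr, peak_lr, end_lr, warmup_steps, decay_steps):
  """Linear warm-up to peak_lr, then cosine decay reaching end_lr at decay_steps."""
  if step < warmup_steps:
    return init_lr + (peak_lr - init_lr) * step / warmup_steps
  cosine_steps = decay_steps - warmup_steps
  frac = min(step - warmup_steps, cosine_steps) / cosine_steps
  return end_lr + (peak_lr - end_lr) * 0.5 * (1.0 + np.cos(np.pi * frac))


def clip_grads_by_global_norm(grads, max_norm):
  """Rescales a list of gradient arrays so that their joint L2 norm <= max_norm."""
  global_norm = np.sqrt(sum(np.sum(g**2) for g in grads))
  scale = min(1.0, max_norm / max(global_norm, 1e-12))
  return [g * scale for g in grads]


# The schedule used in this notebook, evaluated at a few steps.
for step in (0, 2_500, 5_000, 15_000, 25_000):
  print(step, warmup_cosine(step, 1e-7, 5e-4, 1e-7, 5_000, 25_000))

# A gradient with global norm 5 is rescaled to global norm 0.6.
clipped = clip_grads_by_global_norm([np.array([3.0, 4.0])], max_norm=0.6)
print(clipped)  # [0.36, 0.48] up to rounding
```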


In [None]:
# Defining experiments through the config file.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=config.initial_lr,
    peak_value=config.peak_lr,
    warmup_steps=config.warmup_steps,
    decay_steps=config.num_train_steps,
    end_value=config.end_lr,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(config.max_norm),
    optax.adam(
        learning_rate=lr_schedule,
        b1=config.beta1,
    ),
)

## Defining the model

In this case the model is a fully convolutional network with several ResNet blocks, which are conditioned on the flow time through a Fourier embedding.

This model parametrizes the velocity vector field in the rectified flow formulation, in which samples are transported by solving the ODE
$$\dot{x} = v_{\theta}(x, t),$$
so the model defined below parametrizes $v_{\theta}(x, t)$.
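To illustrate what a Fourier embedding of the flow time does, here is a generic sinusoidal embedding in NumPy that maps a scalar $t$ to a vector of cosines and sines at log-spaced frequencies. It is a sketch under stated assumptions (even `dim`, frequencies between 1 and an assumed `max_freq`), not the exact formula used by `unets.FourierEmbedding`; the function name is hypothetical.

```python
import numpy as np


def fourier_time_embedding(
    t: np.ndarray, dim: int = 32, max_freq: float = 64.0
) -> np.ndarray:
  """Maps flow times of shape (batch,) to features of shape (batch, dim).

  Generic sinusoidal embedding with log-spaced frequencies between 1 and
  max_freq (an assumed value). Not the exact formula of unets.FourierEmbedding.
  Assumes dim is even.
  """
  half = dim // 2
  freqs = np.exp(np.linspace(0.0, np.log(max_freq), half))
  angles = 2 * np.pi * t[:, None] * freqs[None, :]
  return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)


t_demo = np.random.default_rng(0).uniform(size=(4,))  # flow times in [0, 1)
emb_demo = fourier_time_embedding(t_demo, dim=32)
print(emb_demo.shape)  # (4, 32)
```

Low frequencies vary slowly in $t$ and high frequencies resolve small differences in $t$, which makes it easier for the convolutional blocks to condition on the flow time.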

In [None]:
Array = jax.Array

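class ConvModel(nn.Module):
  """Fully convolutional velocity model v_theta(x, t) with periodic padding.

  The flow time t is passed as `sigma`, embedded with a Fourier embedding, and
  used to condition each convolutional block.
  """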
  out_channels: int = 1
  num_channels: tuple[int,...] = (32, 32, 32)
  num_blocks: int = 4
  noise_embed_dim: int = 32
  kernel_dim: tuple[int] = (5,)
  padding: str = "CIRCULAR"
  precision: jax.lax.Precision | None = 'fastest'
  dtype: jnp.dtype = jnp.float32
  param_dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(
      self,
      x: Array,
      sigma: Array,
      cond: dict[str, Array] | None = None,
      *,
      is_training: bool,
  ) -> Array:

    if sigma.ndim < 1:
      sigma = jnp.broadcast_to(sigma, (x.shape[0],))

    if sigma.ndim != 1 or x.shape[0] != sigma.shape[0]:
      raise ValueError(
          "sigma must be 1D and have the same leading (batch) dimension as x"
          f" ({x.shape[0]})!"
      )

    emb = unets.FourierEmbedding(dims=self.noise_embed_dim)(sigma)

    h = layers.ConvLayer(
          features=128,
          kernel_size=self.kernel_dim,
          padding=self.padding,
          kernel_init=unets.default_init(1.0),
          precision=self.precision,
          dtype=self.dtype,
          param_dtype=self.param_dtype,
          name="conv_in",
      )(x)

    for i, channel in enumerate(self.num_channels):
      for j in range(self.num_blocks):
        h = unets.ConvBlock(
                  out_channels=channel,
                  kernel_size=self.kernel_dim,
                  padding=self.padding,
                  # dropout=self.dropout_rate,
                  precision=self.precision,
                  dtype=self.dtype,
                  param_dtype=self.param_dtype,
                  name=f"block_number{i}_size_{channel}.num_subblock_{j}.",
              )(h, emb, is_training=is_training)

    h = layers.ConvLayer(
          features=self.out_channels,
          kernel_size=self.kernel_dim,
          padding=self.padding,
          kernel_init=unets.default_init(1.0),
          precision=self.precision,
          dtype=self.dtype,
          param_dtype=self.param_dtype,
          name="conv_out",
      )(h)

    return h

### Instantiating the model.

Here we use the wrapper in the `models` module of `swirl_dynamics`, which saves us from writing much of the boilerplate we would otherwise need.

In [None]:
def build_model(config: ConfigDict):
  """Builds the model from config file."""

  flow_model = ConvModel()

  model = models.ReFlowModel(
      # Input shape without the batch dimension.
      input_shape=(
          config.input_shapes[0][1],
          config.input_shapes[0][2],
      ),
      # Instance of the nn.Module.
      flow_model=flow_model,
      # Sampling schedule for the reflow time.
      time_sampling=functools.partial(
          jax.random.uniform, dtype=jnp.float32
      ),
  )

  return model

In [None]:
model = build_model(config)
init_vars = model.initialize(jax.random.PRNGKey(0))
mutables, params = flax.core.pop(init_vars, "params")
print("Checking the keys in the parames. So it was properly initialized")
print(params.keys())

### Building the trainer

In [None]:

# Builds the trainer; use the distributed version when training on more than one
# GPU/TPU.
if config.is_distributed:
  trainer_class = trainers.DistributedReFlowTrainer
else:
  trainer_class = trainers.ReFlowTrainer

trainer = trainer_class(
    model=model,
    rng=jax.random.PRNGKey(config.seed),
    optimizer=optimizer,
    ema_decay=config.ema_decay,
)

# Sets up checkpointing.
ckpt_options = checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.save_interval_steps,
    max_to_keep=config.max_checkpoints_to_keep,
)

# Sets up the working directory.
workdir = "/content" # typical current position in Colab.

In [None]:
### If you need to remove the checkpoint to start from scratch.
# !rm -Rf /content/checkpoints

### Running the training loop.

We run the training loop.

Here we seek to solve the problem

$$ \min_{\theta} \mathbb{E}_{t \sim U[0, 1]} \mathbb{E}_{(x_0, x_1) \sim \mu_0 \otimes \mu_1} \left \| (x_1 - x_0) - v_{\theta}(x_t, t)  \right \|^2,$$
where $x_t = (1-t) x_0 + t x_1$, and $\mu_0$ and $\mu_1$ are the distributions of the data stemming from the FV and spectral discretizations, respectively. The regression target $x_1 - x_0$ is exactly $\frac{d}{dt} x_t$ along this straight-line interpolation.
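A NumPy sketch of a Monte Carlo estimate of this objective on a single batch, assuming a generic callable `v_fn` in place of the trained network. The helper `reflow_loss` is hypothetical and is not the loss implemented in `swirl_dynamics`, which may normalize differently (for example, averaging over all entries rather than summing over features).

```python
import numpy as np


def reflow_loss(v_fn, x_0, x_1, t):
  """Monte Carlo estimate of the rectified flow objective on one batch.

  Args:
    v_fn: Callable (x_t, t) -> velocity with the same shape as x_t.
    x_0: Samples from mu_0, shape (batch, ...).
    x_1: Independent samples from mu_1, same shape as x_0.
    t: Flow times in [0, 1], shape (batch,).
  """
  # Broadcasts t over the non-batch dimensions.
  t_b = t.reshape((-1,) + (1,) * (x_0.ndim - 1))
  x_t = (1.0 - t_b) * x_0 + t_b * x_1
  residual = (x_1 - x_0) - v_fn(x_t, t)
  # Squared norm per sample, then average over the batch.
  return np.mean(np.sum(residual**2, axis=tuple(range(1, x_0.ndim))))


rng_demo = np.random.default_rng(0)
x0_demo = np.zeros((8, 48, 1))
x1_demo = np.ones((8, 48, 1))
t_demo = rng_demo.uniform(size=(8,))
print(reflow_loss(lambda x, t: np.ones_like(x), x0_demo, x1_demo, t_demo))  # 0.0
print(reflow_loss(lambda x, t: np.zeros_like(x), x0_demo, x1_demo, t_demo))  # 48.0
```

The first call uses the exact velocity for this toy pair, so the loss vanishes; the second predicts zero velocity and pays the full squared norm of $x_1 - x_0$.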


The training speed depends heavily on the accelerator used. For reference:

- A100
  - with batch size 16: 60 it/s
  - with batch size 32: 48 it/s
  - with batch size 64: 35 it/s, i.e., around 13 minutes for 25k iterations.

In [None]:
# Run training loop.
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=config.num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        workdir, asynchronous=False
    ),
    metric_aggregation_steps=config.metric_aggregation_steps,
    eval_dataloader=eval_dataloader,
    eval_every_steps=config.eval_every_steps,
    num_batches_per_eval=config.num_batches_per_eval,
    callbacks=(
        # This callback saves model checkpoint periodically.
        callbacks.TrainStateCheckpoint(
            base_dir=workdir,
            options=ckpt_options,
        ),
        # Displays a progress bar with iterations/second and the training loss.
        callbacks.TqdmProgressBar(
            total_train_steps=config.num_train_steps,
            train_monitors=("train_loss",),
        ),
    ),
)

# Running Inference.

Loading extra libraries for running inference.

In [None]:
from swirl_dynamics.lib.solvers import ode as ode_solvers
from tqdm import tqdm

In [None]:
num_sampling_steps = 64
batch_size_sampling = 64

Define the dataloader to run inference.

In [None]:
test_dataloader, _ = create_loader_from_ndarray(
    dataset=u_lr_lf[-num_test_trajs:].reshape((-1, u_lr_lf.shape[-1], 1)),
    batch_size=batch_size_sampling,
    shuffle=False,
    num_epochs=1,
    seed=0,
    normalize=config.normalize,
    normalize_stats=train_dataloader.normalize_stats_a,
    output_name="x_0",
    # Keeps the last (possibly smaller) batch so that every test sample is
    # debiased.
    drop_remainder=False,
)

### Load the last trained model

In [None]:
# TODO: add the option to download an already trained model.
trained_state = trainers.TrainState.restore_from_orbax_ckpt(
        f"{workdir}/checkpoints", step=None
)

In [None]:
# Wraps the flow model as the right-hand side of the ODE dx/dt = v(x, t).
latent_dynamics_fn = ode_solvers.nn_module_to_dynamics(
    model.flow_model, autonomous=False, is_training=False
)

In [None]:
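integrator = ode_solvers.RungeKutta4()
# Uniform time grid on [0, 1] including the endpoint, so that the last state of
# the trajectory corresponds to t = 1 (num_sampling_steps RK4 steps).
integrate_fn = jax.jit(functools.partial(
    integrator,
    latent_dynamics_fn,
    tspan=jnp.linspace(0.0, 1.0, num_sampling_steps + 1),
    params=trained_state.model_variables,
))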

In [None]:
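# Quick check on a single batch: the solver returns the states at every time in
# tspan, stacked along the leading axis.
batch = next(iter(test_dataloader))
print(batch['x_0'].shape)
trajectory = integrate_fn(batch['x_0'])
print(trajectory.shape)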

### Running Inference Loop.

This step solves the ODE
$$\dot{x} = v_{\theta}(x, t),$$
with initial condition $x(0) = x_0$, for every $x_0 \sim \mu_0$ in the test dataset, and keeps the terminal state $x(1)$ as the debiased sample.
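The sketch below integrates this ODE with the classical fourth-order Runge-Kutta scheme on a uniform grid over $[0, 1]$, using a generic velocity callable. The helper `rk4_flow` is a hypothetical stand-in for `ode_solvers.RungeKutta4`, not its implementation. Note that the time grid must reach $t = 1$ for the final state to be the transported sample.

```python
import numpy as np


def rk4_flow(v_fn, x_0, num_steps):
  """Integrates dx/dt = v(x, t) from t = 0 to t = 1 with classical RK4.

  Uses num_steps uniform steps, i.e., the time grid
  np.linspace(0, 1, num_steps + 1), so the final state is at t = 1.
  """
  dt = 1.0 / num_steps
  x = x_0
  for n in range(num_steps):
    t = n * dt
    k1 = v_fn(x, t)
    k2 = v_fn(x + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = v_fn(x + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = v_fn(x + dt * k3, t + dt)
    x = x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
  return x


# With the ideal straight-line velocity v = x_1 - x_0, x_0 is carried to x_1.
x0_demo, x1_demo = np.zeros(3), np.array([1.0, -2.0, 0.5])
print(rk4_flow(lambda x, t: x1_demo - x0_demo, x0_demo, num_steps=64))

# Sanity check on dx/dt = -x, whose exact value at t = 1 is exp(-1).
print(rk4_flow(lambda x, t: -x, np.array([1.0]), num_steps=64), np.exp(-1.0))
```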

In [None]:
input_list = []
output_list = []

mean_output = train_dataloader.normalize_stats_b["mean"]
std_output = train_dataloader.normalize_stats_b["std"]

for batch in tqdm(test_dataloader):
  input_list.append(batch["x_0"])

  # Keeps only the last state of the trajectory, i.e., the transported sample.
  output = np.array(
      integrate_fn(batch["x_0"])[-1].reshape(
          (-1, config.input_shapes[1][1], config.input_shapes[1][2])
      )
  )
  # Denormalizes the output if the model was trained with normalized data.
  if config.normalize:
    output = std_output * output + mean_output
  output_list.append(output)

### Evaluation of the transformed data.

In [None]:
# Defines the ground truth, the input and the output used to compute the
# point-wise statistics at a particular point of the domain.
num_x = config.input_shapes[1][1]
output_array = np.concatenate(output_list, axis=0).reshape(
    num_test_trajs, -1, num_x
)
output = output_array.reshape(-1, num_x)
test_u_lr_lf = u_lr_lf[-num_test_trajs:].reshape(-1, num_x)
test_u_lr_hf = u_lr_hf[-num_test_trajs:].reshape(-1, num_x)

We plot a simple histogram to compare the marginal statistics at one particular point of the domain. Since the problem is translation equivariant (periodic domain), the marginal statistics are expected to be the same at every point, so a single point is representative.

This is similar to the results obtained with a Sinkhorn iteration solver when solving an optimal transport problem for debiasing; see this [notebook](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/projects/debiasing/optimal_transport/colab/optimal_transport_sinkhorn.ipynb). The main advantage of the present approach is that we do not need to store the full training data, which allows us to process much larger amounts of data.
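One way to summarize the histogram comparison with a single number (an addition, not part of the original pipeline) is the 1D Wasserstein-1 distance between the empirical marginals. For two samples of equal size, the optimal 1D coupling matches sorted values, so the distance is the mean absolute difference of the sorted samples. The helper `wasserstein_1d` is hypothetical.

```python
import numpy as np


def wasserstein_1d(a: np.ndarray, b: np.ndarray) -> float:
  """W1 distance between two empirical 1D distributions of equal sample size."""
  if a.shape != b.shape:
    raise ValueError("This sketch assumes samples of equal size.")
  # In 1D the optimal transport plan pairs the sorted samples.
  return float(np.mean(np.abs(np.sort(a) - np.sort(b))))


print(wasserstein_1d(np.array([0.0, 1.0, 2.0]), np.array([3.0, 1.0, 2.0])))  # 1.0
```

After running the histogram cell, comparing `wasserstein_1d(test_u_lr_lf[:, idx_x], test_u_lr_hf[:, idx_x])` with `wasserstein_1d(output[:, idx_x], test_u_lr_hf[:, idx_x])` would quantify the two gaps visible in the plot, provided the arrays have the same number of samples; `scipy.stats.wasserstein_distance` handles samples of different sizes.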

In [None]:
plt.figure()
idx_x = 2
plt.hist(test_u_lr_lf[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Finite Volumes')
plt.hist(test_u_lr_hf[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Spectral')
plt.hist(output[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Reflow')
plt.legend()
plt.show()

We define functions to compare the energy spectra of the different datasets, and the debiased data.
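As a quick check of what the energy computed below measures, here is a NumPy sketch on a single Fourier mode: all of the energy $|\hat{u}_k|^2$ sits at $k = \pm 3$, so after grouping by $|k|$ it falls into the single bin containing $|k| = 3$. The grid size matches the 48 points used in this notebook; the `_demo` variable names are hypothetical.

```python
import numpy as np

n_demo = 48
x_demo = 2 * np.pi * np.arange(n_demo) / n_demo
u_demo = np.cos(3 * x_demo)  # a single Fourier mode with wavenumber 3

# |FFT(u)|^2; an fftshift would only reorder these entries.
energy_demo = np.abs(np.fft.fft(u_demo)) ** 2
k_demo = np.fft.fftfreq(n_demo, d=1 / n_demo)  # integer wavenumbers

# The energy sits at k = +3 and k = -3 only, each close to (n_demo / 2)**2 = 576.
print(k_demo[energy_demo > 1e-6])
print(energy_demo[np.abs(k_demo) == 3])
```

By Parseval's identity for the unnormalized FFT, the total energy equals `n_demo * np.sum(u_demo**2)`, here $48 \cdot 24 = 1152$.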

In [None]:
def _energy(x: jax.Array) -> jax.Array:
  """Computes energy of a given array."""
  return jnp.square(jnp.abs(jnp.fft.fftshift(jnp.fft.fft(x))))

def energy_spectra(
    num_bins: int = 24, nx: int = 48
) -> tuple[Callable[[jax.Array], jax.Array], jax.Array]:
  """Computes radial spectra of a given array.

  Args:
    num_bins: Number of bins for the energy spectrum.
    nx: Number of discretization points in the periodic spatial direction.

  Returns:
    A tuple with the function that computes the radial spectra and the
    frequencies associated to each bin (without the 0-th frequency)
  """
  kx = np.fft.fftshift(np.fft.fftfreq(nx, d=1 / nx))
  k = np.abs(kx)
  bins = np.linspace(np.min(k), np.max(k), num=num_bins)
  indices = np.digitize(k, bins)

  def _radial_spectra(x):
    energy = _energy(x)
    energy_k = lambda kk: jnp.sum((indices == kk) * energy)
    rad_energy = jax.vmap(energy_k)(jnp.arange(1, num_bins))
    return rad_energy

  return jax.jit(_radial_spectra), bins[1:]

def build_vectorized_function(
    num_bins: int = 24, nx: int = 48
) -> Callable[[jax.Array], jax.Array]:
  """Builds vectorized function radial spectra.

  Args:
    num_bins: Number of bins for the energy spectrum.
    nx: Number of discretization points in the periodic spatial direction.

  Returns:
    A function that computes the radial spectra of a given array.
  """
  rad_spec_fn, _ = energy_spectra(num_bins, nx)
  vmapped_radspec = jax.vmap(rad_spec_fn)

  def sample_radspec(samples: jax.Array) -> jax.Array:
    spec = vmapped_radspec(samples)
    return np.mean(spec, axis=0)

  return sample_radspec

def mean_log_ratio(
    radspec1: jax.Array, radspec2: jax.Array, range_freq: slice
) -> jax.Array:
  """Computes the mean log ratio of two radial spectra.

  Args:
    radspec1: First radial spectra.
    radspec2: Second radial spectra.
    range_freq: Range of frequencies to compute the error.

  Returns:
    The mean log ratio between the two radial spectra.
  """
  return jnp.mean(
      jnp.abs(jnp.log10(radspec1[range_freq] / radspec2[range_freq]))
  )

We compare the mean absolute log ratio between the energy spectra of the FV data (before and after debiasing) and the spectral reference.
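For reference, the quantity computed by `mean_log_ratio` can be written as follows, where $E_1$ and $E_2$ are the two binned spectra and $K$ is the set of bins selected by `range_freq` (all bins in the cell below):
$$\mathrm{MLR}(E_1, E_2) = \frac{1}{|K|} \sum_{k \in K} \left| \log_{10} \frac{E_1(k)}{E_2(k)} \right|.$$
A value of $0$ means the spectra agree on every bin, while a spectrum that is uniformly off by a factor of $2$ gives $\log_{10} 2 \approx 0.301$, so lower is better.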

In [None]:
fun_vect = build_vectorized_function()
# Computes the mean energy spectra.
spectrum_lf = fun_vect(test_u_lr_lf)
spectrum_hf = fun_vect(test_u_lr_hf)
spectrum_reflow = fun_vect(output)

# Computes the mean log ratio over all frequency bins.
err_log_ratio_ref = mean_log_ratio(spectrum_lf,
                                   spectrum_hf,
                                   range_freq=slice(None))
err_log_ratio_reflow = mean_log_ratio(spectrum_reflow,
                                      spectrum_hf,
                                      range_freq=slice(None))

print("Error of the log ratio of the energy spectrum of FV data",
      f"{err_log_ratio_ref}")
print("Error of the log ratio of the energy spectrum of debiased FV data",
      f"{err_log_ratio_reflow}")