# Dataset Generation

The code has been adapted from https://github.com/pdebench/PDEBench, with modifications that add noise to the data. Burgers' equation is used as a proof of concept. Each noise sample is drawn from a distribution chosen at random between a skew normal and an exponential distribution.

In [1]:
from __future__ import annotations

import sys
import time
from math import ceil
import numpy as np
import jax
import jax.numpy as jnp
from jax import device_put, lax

import matplotlib.pyplot as plt
import h5py
from scipy.stats import skewnorm

## Initial Conditions 

The `init` function sets the initial condition of the PDE according to `mode`:


- if mode is `sin`, the initial condition is a sine wave with amplitude `u0`.
- if mode is `sinsin`, the initial condition is a sum of two sine waves; the second has amplitude `du` and eight times the frequency of the first.
- if mode is `Gaussian`, the initial condition is a Gaussian profile (used for diffusion checks).
- if mode is `react`, the initial condition is a Gaussian bump centred at π (used for reaction-diffusion equations).
- if mode is `possin`, the initial condition is the absolute value of a sine wave with amplitude `u0`, so the profile never drops below zero.
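
As a quick illustration, the three modes that `gen` samples later (`sin`, `sinsin`, `possin`) can be evaluated with plain NumPy. This is only a sketch that mirrors the formulas in `init`; the helper name `init_np`, the small 8-cell grid, and the `_demo` variables are assumptions made for the example.

```python
import numpy as np


def init_np(xc, mode="sin", u0=1.0, du=0.1):
    """NumPy mirror of three `init` modes (hypothetical helper, illustration only)."""
    base = np.sin((xc + 1.0) * np.pi)
    if mode == "sin":
        return u0 * base
    if mode == "sinsin":
        return base + du * np.sin((xc + 1.0) * np.pi * 8.0)
    if mode == "possin":
        return u0 * np.abs(base)
    raise ValueError(f"unknown mode: {mode}")


# Small demo grid of cell centers on [-1, 1] (assumed size, not the notebook's nx).
nx_demo = 8
xc_demo = -1.0 + (np.arange(nx_demo) + 0.5) * (2.0 / nx_demo)

u_sin = init_np(xc_demo, "sin", u0=1.5)
u_possin = init_np(xc_demo, "possin", u0=1.5)
u_sinsin = init_np(xc_demo, "sinsin", du=0.3)
```

The `possin` profile is the pointwise absolute value of the `sin` profile, which is why it never becomes negative.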


In [2]:
def _pass(carry):
    return carry


def init(xc, mode="sin", u0=1.0, du=0.1):
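    """
    :param xc: cell center coordinate
    :param mode: initial condition
    :param u0: amplitude of the "sin" and "possin" modes
    :param du: amplitude of the high-frequency component in "sinsin" mode
    :return: 1D scalar function u at cell center
    """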
    modes = ["sin", "sinsin", "Gaussian", "react", "possin"]
    assert mode in modes, "mode is not defined!!"
    if mode == "sin":  # sinusoidal wave
        u = u0 * jnp.sin((xc + 1.0) * jnp.pi)
    elif mode == "sinsin":  # sum of two sinusoidal waves
        u = jnp.sin((xc + 1.0) * jnp.pi) + du * jnp.sin((xc + 1.0) * jnp.pi * 8.0)
    elif mode == "Gaussian":  # for diffusion check
        t0 = 0.01
        u = jnp.exp(-(xc**2) * jnp.pi / (4.0 * t0)) / jnp.sqrt(2.0 * t0)
    elif mode == "react":  # for reaction-diffusion eq.
        logu = -0.5 * (xc - jnp.pi) ** 2 / (0.25 * jnp.pi) ** 2
        u = jnp.exp(logu)
    elif mode == "possin":  # absolute value of a sinusoidal wave
        u = u0 * jnp.abs(jnp.sin((xc + 1.0) * jnp.pi))
    return u

## Courant Functions 

The Courant number, also known as the Courant-Friedrichs-Lewy (CFL) number, is a dimensionless quantity that governs the stability of explicit numerical solutions to partial differential equations (PDEs). It relates the time step to the spatial grid size and to the characteristic velocity or diffusion coefficient of the problem, and is used here to choose a stable time step.

`Courant` returns the advective time-step limit `dx / max|u|` for the given velocity field `u` and grid spacing `dx`.

`Courant_diff` returns the diffusive time-step limit `0.5 * dx**2 / epsilon` for the given grid spacing `dx` and diffusion coefficient `epsilon`.

Both limits are later multiplied by the `CFL` factor to obtain the actual time step.
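
To make the two limits concrete, the following sketch evaluates them in NumPy for a unit-amplitude sine wave on the 1024-cell grid used in this notebook. The `_np` helper names and the values `cfl_demo = 0.4` and `eps_demo = 1.0e-2 / pi` are assumptions chosen for illustration; they are not the values sampled in `gen`.

```python
import numpy as np


def courant_np(u, dx):
    # Advective limit: time for the fastest signal to cross one cell.
    return dx / (np.max(np.abs(u)) + 1.0e-8)


def courant_diff_np(dx, epsilon=1.0e-3):
    # Diffusive limit for an explicit scheme.
    return 0.5 * dx**2 / (epsilon + 1.0e-8)


nx_demo = 1024
dx_demo = 2.0 / nx_demo
xc_demo = -1.0 + (np.arange(nx_demo) + 0.5) * dx_demo
u_demo = np.sin((xc_demo + 1.0) * np.pi)

cfl_demo = 0.4               # assumed CFL factor
eps_demo = 1.0e-2 / np.pi    # assumed diffusion coefficient

dt_adv = cfl_demo * courant_np(u_demo, dx_demo)
dt_dif = cfl_demo * courant_diff_np(dx_demo, eps_demo)
dt = min(dt_adv, dt_dif)     # the more restrictive limit wins
print(f"dt_adv={dt_adv:.3e}, dt_dif={dt_dif:.3e}, dt={dt:.3e}")
```

For these particular values the diffusive limit is the smaller of the two, so it sets the time step. With a smaller diffusion coefficient the advective limit would take over.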

In [3]:
def Courant(u, dx):
    stability_adv = dx / (jnp.max(jnp.abs(u)) + 1.0e-8)
    return stability_adv


def Courant_diff(dx, epsilon=1.0e-3):
    stability_dif = 0.5 * dx**2 / (epsilon + 1.0e-8)
    return stability_dif

## Noise Generation

`generate_noise` returns a noise array whose distribution varies from call to call, so that the dataset contains noisy data of different forms and a model trained on it becomes more robust to noise.

- It randomly chooses between two noise types: a **skew normal** or an **exponential** distribution.
  - If a random value `a` is below 0.5, it draws from a skew normal distribution with a skewness parameter sampled uniformly between -5 and 5.
  - Otherwise, it draws from an exponential distribution with a scale sampled uniformly between 0 and 4.
- It returns the generated noise scaled by `noise_level`.
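
If reproducible noise is needed, one option is to pass an explicit NumPy `Generator` instead of relying on the global random state. The sketch below is an assumption-laden variant, not the function used in this notebook: the name `generate_noise_seeded` is hypothetical, and the skew normal sample is built from two standard normal draws (a standard construction) rather than with `scipy.stats.skewnorm`, so the example depends only on NumPy.

```python
import numpy as np


def generate_noise_seeded(shape, noise_level, rng):
    """Seeded variant of the noise generator (hypothetical helper)."""
    if rng.random() < 0.5:
        # Skew normal with shape parameter alpha, built as
        # delta * |Z0| + sqrt(1 - delta^2) * Z1 with Z0, Z1 standard normal.
        alpha = rng.uniform(-5.0, 5.0)
        delta = alpha / np.sqrt(1.0 + alpha**2)
        z0 = np.abs(rng.standard_normal(shape))
        z1 = rng.standard_normal(shape)
        sample = delta * z0 + np.sqrt(1.0 - delta**2) * z1
    else:
        # Exponential noise with a randomly chosen scale.
        scale = rng.uniform(0.0, 4.0)
        sample = rng.exponential(scale=scale, size=shape)
    return sample * noise_level


# The same seed gives the same noise array.
noise_a = generate_noise_seeded((4,), 0.1, np.random.default_rng(0))
noise_b = generate_noise_seeded((4,), 0.1, np.random.default_rng(0))
noise_zero = generate_noise_seeded((4,), 0.0, np.random.default_rng(1))
```

Passing the generator explicitly keeps the choice of seed in the caller's hands, which makes individual samples of the dataset easier to regenerate.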

In [4]:
def generate_noise(shape, noise_level):
    a = np.random.rand()
    if a < 0.5:
        parameter = np.random.uniform(-5, 5)
        return skewnorm.rvs(a=parameter, size=shape) * noise_level
    else:
        parameter = np.random.uniform(0, 4)
        return np.random.exponential(scale=parameter, size=shape) * noise_level

## Boundary Conditions

- `bc` pads a 1D array with two ghost cells on each side and fills them according to the boundary condition, optionally adding noise.  
- It takes the input array `u`, the number of interior cells `Ncell`, the boundary condition mode, the noise level, and a flag `retNoise` that additionally returns the noise-free version.  
- Noise from the `generate_noise` function is added to the ghost (boundary) cells only; interior cells are copied unchanged.  
- In `periodic` mode, the ghost cells on each side take the values from the opposite edge of the array.  
- It returns the padded array and, if `retNoise` is set, also a version without noise, so that both can be stored in the dataset.

Periodic boundary conditions connect the boundaries of a system so that values at one edge match those at the opposite edge, creating a continuous, looped domain. This is useful for simulating systems with cyclic or infinite behavior. For Burgers' equation, periodic boundary conditions are often applied to study wave propagation, shock formation, and turbulence in a confined, repeating space, ensuring smooth transitions at the domain edges without artificial boundary effects.
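
The periodic padding itself can be sketched in a few lines of NumPy, without noise. The helper name `pad_periodic` and the six-value demo array are assumptions for illustration; the two-cell ghost width matches the `Ncell + 4` layout used by `bc`.

```python
import numpy as np


def pad_periodic(u, n_ghost=2):
    """Pad `u` with ghost cells copied from the opposite edge (hypothetical helper)."""
    padded = np.zeros(u.size + 2 * n_ghost)
    padded[n_ghost:-n_ghost] = u      # interior cells
    padded[:n_ghost] = u[-n_ghost:]   # left ghost cells come from the right edge
    padded[-n_ghost:] = u[:n_ghost]   # right ghost cells come from the left edge
    return padded


u_demo = np.arange(1.0, 7.0)          # [1, 2, 3, 4, 5, 6]
padded_demo = pad_periodic(u_demo)
print(padded_demo)
```

For a plain NumPy array the same layout is produced by `np.pad(u_demo, 2, mode="wrap")`.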

In [5]:
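def bc(u, Ncell, mode="periodic", noise_level=0.0, retNoise=False):
    # Pad the Ncell interior values with two ghost cells on each side.
    _u = jnp.zeros(Ncell + 4)
    _u = _u.at[2 : Ncell + 2].set(u)
    _u_no_noise = _u
    noise = generate_noise(_u.shape, noise_level)
    if mode == "periodic":
        # Noise-free ghost cells: copy the values from the opposite edge.
        _u_no_noise = _u_no_noise.at[0:2].set(u[-2:])
        _u_no_noise = _u_no_noise.at[Ncell + 2 : Ncell + 4].set(u[0:2])
        # Noisy ghost cells: the same periodic copy plus noise.
        _u = _u.at[0:2].set(u[-2:] + noise[0:2])
        _u = _u.at[Ncell + 2 : Ncell + 4].set(
            u[0:2] + noise[Ncell + 2 : Ncell + 4]
        )

    if retNoise:
        return _u, _u_no_noise
    return _u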

## Van Leer limiter function

This code applies a **Van Leer limiter** to reconstruct left (`uL`) and right (`uR`) states of a solution variable `u` for finite volume methods. The `VLlimiter` limits gradients (`gradu`) between cells to ensure stability and avoid oscillations near discontinuities. These reconstructed states are used to compute numerical fluxes in PDE solvers, particularly for hyperbolic or conservation-law equations, enabling stable and accurate updates of the solution.
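
A small NumPy sketch shows how the limiter formula behaves on a hand-made profile. The helper name `limited_slope` and the demo array are assumptions; the formula is the same expression as in `VLlimiter`, written with NumPy instead of `jax.numpy`.

```python
import numpy as np


def limited_slope(a, b, c, alpha=2.0):
    # Same expression as VLlimiter: zero at extrema, otherwise the central
    # slope c capped by alpha times the smaller one-sided slope.
    return (
        np.sign(c)
        * (0.5 + 0.5 * np.sign(a * b))
        * np.minimum(alpha * np.minimum(np.abs(a), np.abs(b)), np.abs(c))
    )


# Linear ramp, then a local maximum, a local minimum, and a steep jump.
u_demo = np.array([0.0, 1.0, 2.0, 3.0, 2.0, 2.1, 4.1])
du_L = u_demo[1:-1] - u_demo[:-2]        # left one-sided differences
du_R = u_demo[2:] - u_demo[1:-1]         # right one-sided differences
du_M = 0.5 * (u_demo[2:] - u_demo[:-2])  # central differences
slopes_demo = limited_slope(du_L, du_R, du_M)
print(slopes_demo)
```

On the linear part the central slope passes through unchanged, at the two extrema the slope is set to zero, and next to the jump the slope is capped at twice the smaller one-sided difference.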

In [6]:
def VLlimiter(a, b, c, alpha=2.0):
    return (
        jnp.sign(c)
        * (0.5 + 0.5 * jnp.sign(a * b))
        * jnp.minimum(alpha * jnp.minimum(jnp.abs(a), jnp.abs(b)), jnp.abs(c))
    )


def limiting(u, Ncell, if_second_order):
    du_L = u[1 : Ncell + 3] - u[0 : Ncell + 2]
    du_R = u[2 : Ncell + 4] - u[1 : Ncell + 3]
    du_M = (u[2 : Ncell + 4] - u[0 : Ncell + 2]) * 0.5
    gradu = VLlimiter(du_L, du_R, du_M) * if_second_order
    uL, uR = jnp.zeros_like(u), jnp.zeros_like(u)
    uL = uL.at[1 : Ncell + 3].set(u[1 : Ncell + 3] - 0.5 * gradu)
    uR = uR.at[1 : Ncell + 3].set(u[1 : Ncell + 3] + 0.5 * gradu)
    return uL, uR

## Parameters

- `dt_save`: Time interval for saving simulation results.  
- `ini_time`: Initial time of the simulation.  
- `fin_time`: Final time of the simulation.  
- `nx`: Number of grid cells in the spatial domain.  
- `xL`: Left boundary of the spatial domain.  
- `xR`: Right boundary of the spatial domain.  
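- `if_second_order`: Multiplier applied to the limited slopes; `1.0` enables second-order reconstruction and `0.0` reduces the scheme to first order.  
- `show_steps`: Number of solver time steps taken per iteration of the outer loop in `evolve`, between checks for saving a snapshot.  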

In [7]:
dt_save = 0.01
ini_time = 0.0
fin_time = 2.0
nx = 1024
xL = -1.0
xR = 1.0
if_second_order = 1.0
show_steps = 100

## Spatial and Temporal Grids

This code computes spatial and temporal grids for a numerical simulation:

- `dx`: Calculates the spatial resolution (grid spacing) as the total domain length divided by the number of grid cells (`nx`).  
- `xe`: Generates the **edge coordinates** of the grid cells using `jnp.linspace`, creating `nx + 1` evenly spaced points between `xL` and `xR`.  
- `xc`: Computes the **center coordinates** of each grid cell by shifting each left edge by half a cell (`xe[:-1] + 0.5 * dx`).  
- `it_tot`: Determines the number of saved snapshots, including the initial state, from the total simulation time and the saving interval (`dt_save`).  
- `tc`: Creates the time array for saved outputs, with `it_tot + 1` evenly spaced time values starting from 0 in steps of `dt_save`.
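
The resulting array sizes can be checked with a NumPy version of the same formulas. The `_chk` suffixes are assumptions added so that the sketch does not overwrite the notebook's own variables; the parameter values are the ones listed in the Parameters section.

```python
from math import ceil

import numpy as np

dt_save_chk, ini_time_chk, fin_time_chk = 0.01, 0.0, 2.0
nx_chk, xL_chk, xR_chk = 1024, -1.0, 1.0

dx_chk = (xR_chk - xL_chk) / nx_chk
xe_chk = np.linspace(xL_chk, xR_chk, nx_chk + 1)   # cell edges
xc_chk = xe_chk[:-1] + 0.5 * dx_chk                # cell centers
it_tot_chk = ceil((fin_time_chk - ini_time_chk) / dt_save_chk) + 1
tc_chk = np.arange(it_tot_chk + 1) * dt_save_chk

print(xc_chk.shape, it_tot_chk, tc_chk.shape)
```

With these parameters there are 201 snapshots, which matches the 201 frames reported by the progress bars in the visualisation output. Note that `tc` holds one more entry than the snapshot array allocated in `evolve`, which has `it_tot` rows.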

In [8]:
dx = (xR - xL) / nx
xe = jnp.linspace(xL, xR, nx + 1)
xc = xe[:-1] + 0.5 * dx
it_tot = ceil((fin_time - ini_time) / dt_save) + 1
tc = jnp.arange(it_tot + 1) * dt_save

## Hierarchical Storage

HDF5 (`h5py`) is used here for efficiently storing and organizing simulation data in a structured and portable format. The reasons for using HDF5 in this case are:

- **Hierarchical Organization**: The `coords` group organizes data logically, making it easy to manage related datasets (e.g., `x-coordinates` and `t-coordinates`).  
- **Scalability**: HDF5 handles large datasets efficiently, which is useful for high-resolution simulations like this one (`nx = 1024`, potentially large time arrays).  

The spatial (`xc`) and temporal (`tc`) coordinates are saved for later use or analysis in a structured and accessible way.

In [9]:
# The `with` block closes the file on exit, so no explicit close() is needed.
with h5py.File("simulation_data.h5", "w") as g:
    f = g.create_group("coords")
    f.create_dataset("x-coordinates", data=xc)
    f.create_dataset("t-coordinates", data=tc)

## Dataset Generation Function

The main `gen` function generates simulation data for a PDE with configurable physical parameters and noise settings. It initializes the solution with or without noise, evolves the system over time for both cases, and stores the results (e.g., solution states, boundary conditions, and parameters) in an HDF5 file. This enables efficient storage, reproducibility, and analysis of clean and noisy simulations for future training. 

Parameters that influence the behaviour of the system, such as the viscosity (`epsilon`) and the initial condition (`init_mode`), are randomised over a range so that the dataset represents the equation across a wide range of physical regimes.

### Inner Functions

- **`evolve`**: This function handles the temporal evolution of the solution array `u` by repeatedly applying the time-stepping logic in `simulation_fn` until the final simulation time is reached. It saves intermediate results at specified time intervals (`dt_save`) and handles clean and noisy simulations separately by incorporating noise levels. The function ensures efficient looping using `jax.lax.while_loop` for compatibility with JAX's just-in-time (JIT) compilation.

- **`simulation_fn`**: This function manages a single time step of the simulation. It calculates the time step size (`dt`) based on CFL constraints for advection and diffusion processes, applies the `update` function to compute the next state of `u`, and updates the time and step counters. It ensures the stability and consistency of the numerical solution through careful time-stepping.

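- **`update`**: This function applies one stage of the two-stage time-stepping scheme to the solution array `u`. It computes numerical fluxes using the `flux` function and adds noise to them if a noise level is specified. `simulation_fn` calls it twice per time step, first with half the step to obtain an intermediate state and then with the full step using the fluxes of that intermediate state. Note that `generate_noise` uses NumPy's random generator, which under `jax.jit` runs only while the function is traced, so the noise is sampled at trace time rather than redrawn on every step of a compiled run.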

- **`flux`**: This function calculates the fluxes needed to solve the PDE using an upwind scheme for stability and accuracy. It applies boundary conditions via the `bc` function and uses slope limiting with the `limiting` function to prevent unphysical oscillations. Additionally, it includes a source term for diffusion and optional noise to simulate variability in the flux.

- **`bc`**: Defined earlier at module level rather than inside `gen`, this function fills the ghost cells of the solution array `u` using periodic boundaries (the only mode implemented) and optionally adds noise to the boundary values. It ensures that the boundaries are correctly set for accurate flux calculations and consistent evolution of the PDE.
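
The two-stage update can be sketched in NumPy under simplifying assumptions: first-order reconstruction (no slope limiting), no noise, and periodic boundaries handled with `np.roll` instead of ghost cells. The names `flux_np` and `step_midpoint`, the 64-cell grid, and the parameter values are hypothetical and chosen only for illustration.

```python
import numpy as np


def flux_np(u, dx, nu):
    """Flux at the interface between cell i-1 and cell i (periodic, first order)."""
    u_left = np.roll(u, 1)   # state on the left of each interface
    u_right = u              # state on the right of each interface
    f_adv = 0.5 * (
        0.5 * u_left**2
        + 0.5 * u_right**2
        - 0.5 * np.abs(u_left + u_right) * (u_right - u_left)
    )
    f_dif = -nu * (u_right - u_left) / dx   # diffusive contribution
    return f_adv + f_dif


def step_midpoint(u, dt, dx, nu):
    """One two-stage step: half step to a midpoint state, then a full step."""
    def rhs(v):
        f = flux_np(v, dx, nu)
        return -(np.roll(f, -1) - f) / dx   # flux difference across each cell

    u_half = u + 0.5 * dt * rhs(u)
    return u + dt * rhs(u_half)


nx_demo = 64
dx_demo = 2.0 / nx_demo
xc_demo = -1.0 + (np.arange(nx_demo) + 0.5) * dx_demo
u_init = np.sin((xc_demo + 1.0) * np.pi)
nu_demo = 1.0e-2 / np.pi                     # assumed diffusion coefficient
dt_demo = 0.4 * min(dx_demo / np.max(np.abs(u_init)), 0.5 * dx_demo**2 / nu_demo)

u_new = step_midpoint(u_init, dt_demo, dx_demo, nu_demo)
```

Because every interface flux is added to one cell and subtracted from its neighbour, the sum of `u` over the periodic domain is conserved up to rounding error, which is a convenient sanity check for a flux-based update.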

In [10]:
def gen(path) -> None:
    epsilon = np.random.uniform(1.0e-4, 1.0e-1)
    u0 = np.random.uniform(0.5, 2.0)
    du = np.random.uniform(0.0, 0.5)
    CFL = np.random.uniform(0.1, 0.9)
    init_mode = np.random.choice(["sin", "sinsin", "possin"])
    noise_level = np.random.uniform(0.0, 0.5)
    equation_noise_level = np.random.uniform(0.0, 0.1)

    pi_inv = 1.0 / jnp.pi
    dx = (xR - xL) / nx
    dx_inv = 1.0 / dx

    @jax.jit
    def evolve(u, noise_level=0.0, equation_noise_level=0.0):
        t = ini_time
        tsave = t
        steps = 0
        i_save = 0
        dt = 0.0
        uu = jnp.zeros([it_tot, u.shape[0]])
        uu = uu.at[0].set(u)

        tm_ini = time.time()

        cond_fun = lambda x: x[0] < fin_time

        def _body_fun(carry):
            def _save(_carry):
                u, tsave, i_save, uu = _carry
                uu = uu.at[i_save].set(u)
                tsave += dt_save
                i_save += 1
                return (u, tsave, i_save, uu)

            t, tsave, steps, i_save, dt, u, uu = carry

            # if save data
            carry = (u, tsave, i_save, uu)
            u, tsave, i_save, uu = lax.cond(t >= tsave, _save, _pass, carry)

            # Pass noise levels to simulation_fn here
            carry = (u, t, dt, steps, tsave)
            u, t, dt, steps, tsave = lax.fori_loop(
                0,
                show_steps,
                lambda i, carry: simulation_fn(
                    i, carry, noise_level, equation_noise_level
                ),
                carry,
            )
            return (t, tsave, steps, i_save, dt, u, uu)

        carry = t, tsave, steps, i_save, dt, u, uu
        t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
        uu = uu.at[-1].set(u)

        tm_fin = time.time()
        return uu, t

    @jax.jit
    def simulation_fn(i, carry, noise_level=0.0, equation_noise_level=0.0):
        u, t, dt, steps, tsave = carry
        dt_adv = Courant(u, dx) * CFL
        dt_dif = Courant_diff(dx, epsilon * pi_inv) * CFL
        dt = jnp.min(jnp.array([dt_adv, dt_dif, fin_time - t, tsave - t]))

        def _update(carry):
            u, dt = carry
            u_tmp = update(u, u, dt * 0.5, noise_level, equation_noise_level)
            u = update(u, u_tmp, dt, noise_level, equation_noise_level)
            return u, dt

        carry = u, dt
        u, dt = lax.cond(dt > 1.0e-8, _update, _pass, carry)

        t += dt
        steps += 1
        return u, t, dt, steps, tsave

    @jax.jit
    def update(u, u_tmp, dt, noise_level=0.0, equation_noise_level=0.0):
        f = flux(u_tmp, noise_level)
        noise = generate_noise(f.shape, equation_noise_level)
        f = f + noise  # Add noise to the flux
        u -= dt * dx_inv * (f[1 : nx + 1] - f[0:nx])
        return u

    def flux(u, noise_level=0.0):
        _u, _u_no_noise = bc(
            u, Ncell=nx, noise_level=noise_level, retNoise=True
        )  # index 2 for _U is equivalent with index 0 for u
        uL, uR = limiting(_u, nx, if_second_order=if_second_order)
        fL = 0.5 * uL**2
        fR = 0.5 * uR**2
        # Upwind advection scheme
        f_upwd = 0.5 * (
            fR[1 : nx + 2]
            + fL[2 : nx + 3]
            - 0.5
            * jnp.abs(uL[2 : nx + 3] + uR[1 : nx + 2])
            * (uL[2 : nx + 3] - uR[1 : nx + 2])
        )
        # Source term
        f_upwd += -epsilon * pi_inv * (_u[2 : nx + 3] - _u[1 : nx + 2]) * dx_inv
        return f_upwd

    # Initialize the solution without noise
    u = init(xc=xc, mode=init_mode, u0=u0, du=du)

    # Add noise to the initial condition
    noise = generate_noise(u.shape, noise_level)
    u_noisy = u + noise

    # Evolve the solution without noise
    u_clean = device_put(u)  # Putting variables in GPU (not necessary??)
    uu_clean, t_clean = evolve(u_clean)

    # Evolve the solution with noise
    u_noisy = device_put(u_noisy)  # Putting variables in GPU (not necessary??)
    uu_noisy, t_noisy = evolve(
        u_noisy, noise_level=noise_level, equation_noise_level=equation_noise_level
    )

    # Save boundary condition without noise
    _, boundary_condition_no_noise = bc(
        u, Ncell=nx, noise_level=noise_level, retNoise=True
    )
    # Save boundary condition with noise
    boundary_condition_noisy, _ = bc(
        u_noisy, Ncell=nx, noise_level=noise_level, retNoise=True
    )

    # Append this sample as its own group; the `with` block closes the file.
    with h5py.File("simulation_data.h5", "a") as g:
        f = g.create_group(f"{path}")
        f.create_dataset("epsilon", data=epsilon)
        f.create_dataset("u0", data=u0)
        f.create_dataset("du", data=du)
        f.create_dataset("CFL", data=CFL)
        # np.bytes_ replaces np.string_, which was removed in NumPy 2.0.
        f.create_dataset("init_mode", data=np.bytes_(init_mode))
        f.create_dataset("noise_level", data=noise_level)
        f.create_dataset("equation_noise_level", data=equation_noise_level)
        f.create_dataset("initial_condition_clean", data=u)
        f.create_dataset("initial_condition_noisy", data=u_noisy)
        f.create_dataset("boundary_condition_clean", data=boundary_condition_no_noise)
        f.create_dataset("boundary_condition_noisy", data=boundary_condition_noisy)
        f.create_dataset("clean", data=uu_clean)
        f.create_dataset("noisy", data=uu_noisy)

100 samples are generated for future training.

In [11]:
for i in range(100):
    print(f"Running simulation {i + 1}/100")
    gen(i)

Running simulation 1/100
Running simulation 2/100
Running simulation 3/100
Running simulation 4/100
Running simulation 5/100
Running simulation 6/100
Running simulation 7/100
Running simulation 8/100
Running simulation 9/100
Running simulation 10/100
Running simulation 11/100
Running simulation 12/100
Running simulation 13/100
Running simulation 14/100
Running simulation 15/100
Running simulation 16/100
Running simulation 17/100
Running simulation 18/100
Running simulation 19/100
Running simulation 20/100
Running simulation 21/100
Running simulation 22/100
Running simulation 23/100
Running simulation 24/100
Running simulation 25/100
Running simulation 26/100
Running simulation 27/100
Running simulation 28/100
Running simulation 29/100
Running simulation 30/100
Running simulation 31/100
Running simulation 32/100
Running simulation 33/100
Running simulation 34/100
Running simulation 35/100
Running simulation 36/100
Running simulation 37/100
Running simulation 38/100
Running simulation 39

## Visualisation

The `visualize_burgers` function generates an animated GIF of the solution of Burgers' equation over time. It takes spatial coordinates (`xcrd`), the simulation data (`data`), and an identifier (`i`) used to name the output GIF. The function iterates through the saved time steps of the solution, plots each one, and stores the frames for the animation. It then creates an animation using `matplotlib.animation.ArtistAnimation` and saves it as a `.gif` file at 15 frames per second.

In [12]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from tqdm import tqdm


def visualize_burgers(xcrd, data, i):
    fig, ax = plt.subplots()

    ims = []

    # Use a separate loop variable so that the identifier `i` is not
    # overwritten before it is used in the output filename.
    for frame in tqdm(range(data.shape[0])):
        im = ax.plot(xcrd, data[frame].squeeze(), animated=True, color="blue")
        ims.append([im[0]])
    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)

    writer = animation.PillowWriter(fps=15, bitrate=1800)
    ani.save(f"{i}.gif", writer=writer)
    plt.close(fig)

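## Visualising Sample Simulations

A few simulations are picked at random from the HDF5 file, and both the clean and the noisy solution of each are rendered to a GIF.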


In [None]:
import random


with h5py.File("simulation_data.h5", "r") as f:
    # Exclude the coordinate group, then pick three distinct simulations.
    sim_keys = [key for key in f.keys() if key != "coords"]
    xc = f["coords"]["x-coordinates"][:]
    for i in random.sample(sim_keys, k=3):
        print(f"Visualizing simulation {i}")
        data = f[i]["clean"][:]
        visualize_burgers(xc, data, f"{i}_clean")

        data = f[i]["noisy"][:]
        visualize_burgers(xc, data, f"{i}_noisy")

Visualizing simulation 29


100%|██████████| 201/201 [00:00<00:00, 1629.46it/s]
100%|██████████| 201/201 [00:00<00:00, 610.08it/s] 


Visualizing simulation 65


100%|██████████| 201/201 [00:00<00:00, 1703.09it/s]
100%|██████████| 201/201 [00:00<00:00, 1674.15it/s]


Visualizing simulation 98


100%|██████████| 201/201 [00:00<00:00, 1631.57it/s]
100%|██████████| 201/201 [00:00<00:00, 1705.81it/s]
