In [9]:
from __future__ import annotations

import sys
import time
from math import ceil
import numpy as np
import jax
import jax.numpy as jnp
from jax import device_put, lax

import numpy as np
import matplotlib.pyplot as plt
import h5py
from scipy.stats import skewnorm

In [10]:
def _pass(carry):
    return carry


def init(xc, u0=1.0, du=0.1):
    """Sine initial condition ``u0 * sin(pi * (xc + 1))``.

    Args:
        xc: cell-centre coordinates.
        u0: amplitude of the profile.
        du: accepted for interface compatibility; not used by this profile.

    Returns:
        The initial field, same shape as ``xc``.
    """
    phase = jnp.pi * (xc + 1.0)
    return jnp.sin(phase) * u0

In [11]:
def Courant(u, dx):
    """Advective time-step bound ``dx / max|u|``.

    A 1e-8 offset in the denominator keeps the result finite when ``u`` is all zeros.
    """
    peak_speed = jnp.abs(u).max()
    return dx / (peak_speed + 1.0e-8)


def Courant_diff(dx, epsilon=1.0e-3):
    """Explicit-diffusion time-step bound ``0.5 * dx**2 / epsilon``.

    A 1e-8 offset keeps the bound finite when ``epsilon`` is zero.
    """
    denominator = epsilon + 1.0e-8
    return 0.5 * dx**2 / denominator

In [12]:
def generate_noise(shape, noise_level):
    a = np.random.rand()
    if a < 0.5:
        parameter = np.random.uniform(-5, 5)
        return skewnorm.rvs(a=parameter, size=shape) * noise_level
    else:
        parameter = np.random.uniform(0, 4)
        return np.random.exponential(scale=parameter, size=shape) * noise_level

In [None]:
def bc(u, Ncell, mode="periodic", noise_level=0.0):
    """Pad ``u`` with two ghost cells on each side.

    The interior occupies indices 2 .. Ncell + 1 of the result (length Ncell + 4).
    For ``mode == "periodic"`` the ghosts wrap around (last two cells on the left,
    first two on the right); any other mode leaves them at zero.
    ``noise_level`` is accepted for interface compatibility and is not used.
    """
    padded = jnp.zeros(Ncell + 4).at[2 : Ncell + 2].set(u)
    if mode != "periodic":
        return padded
    return padded.at[:2].set(u[-2:]).at[Ncell + 2 :].set(u[:2])

In [14]:
def VLlimiter(a, b, c, alpha=2.0):
    """Limited slope from one-sided differences ``a``, ``b`` and a candidate ``c``.

    Returns ``sign(c) * min(alpha * min(|a|, |b|), |c|)`` when ``a`` and ``b`` share
    a sign, and 0 when they have opposite signs (the gate is 0.5 when a*b == 0).
    When ``c`` is the central difference this is the monotonized-central form.
    """
    same_sign = 0.5 * (1.0 + jnp.sign(a * b))
    bounded = jnp.minimum(alpha * jnp.minimum(jnp.abs(a), jnp.abs(b)), jnp.abs(c))
    return jnp.sign(c) * same_sign * bounded


def limiting(u, Ncell, if_second_order):
    """Reconstruct left/right face values of each cell from a limited linear slope.

    Args:
        u: padded cell values; must have at least Ncell + 4 entries (presumably the
            output of ``bc``: two ghost cells per side — TODO confirm for other callers).
        Ncell: number of interior cells.
        if_second_order: multiplier on the slope; 0 gives piecewise-constant faces.

    Returns:
        ``(uL, uR)``, arrays shaped like ``u``. Indices 1 .. Ncell + 2 hold
        ``u -/+ slope / 2``; the first and last entries stay zero.
    """
    inner = slice(1, Ncell + 3)
    centre = u[inner]
    left = centre - u[0 : Ncell + 2]
    right = u[2 : Ncell + 4] - centre
    central = 0.5 * (u[2 : Ncell + 4] - u[0 : Ncell + 2])
    slope = if_second_order * VLlimiter(left, right, central)
    half = 0.5 * slope
    face_left = jnp.zeros_like(u).at[inner].set(centre - half)
    face_right = jnp.zeros_like(u).at[inner].set(centre + half)
    return face_left, face_right

In [15]:
# --- Run configuration -------------------------------------------------------
# Time window [ini_time, fin_time] and the interval between stored snapshots.
ini_time, fin_time = 0.0, 2.0
dt_save = 0.01
# Spatial domain [xL, xR] discretised into nx cells.
xL, xR = -1.0, 1.0
nx = 1024
# Multiplies the limited slope in `limiting`: 1.0 gives second-order
# reconstruction, 0.0 gives piecewise-constant (first-order) face values.
if_second_order = 1.0
# Number of solver steps per outer while-loop iteration inside `gen`.
show_steps = 100

In [16]:
# Uniform cell-centred grid on [xL, xR] and the snapshot time axis.
dx = (xR - xL) / nx
xe = jnp.linspace(xL, xR, nx + 1)  # cell edges
xc = xe[:-1] + 0.5 * dx  # cell centres (nx values)
# Number of stored snapshots: t = ini_time, ini_time + dt_save, ..., fin_time.
it_tot = ceil((fin_time - ini_time) / dt_save) + 1
# One time stamp per stored snapshot. `gen` allocates `it_tot` rows, so the
# axis needs exactly `it_tot` entries (the old `it_tot + 1` was off by one).
# Offset by ini_time because the solver starts there.
tc = ini_time + jnp.arange(it_tot) * dt_save

In [None]:
# Create (or truncate) the dataset file and store the shared grid once;
# each call to `gen` later appends its own group in append mode.
# The `with` block closes the file; no explicit close() is needed.
with h5py.File("data.h5", "w") as h5file:
    coords_group = h5file.create_group("coords")
    coords_group.create_dataset("x-coordinates", data=np.asarray(xc))
    coords_group.create_dataset("t-coordinates", data=np.asarray(tc))

In [None]:
def gen(path) -> None:
    """Simulate one randomly parameterized viscous-Burgers run and append it to ``data.h5``.

    Draws ``epsilon``, ``u0``, ``du`` and ``CFL`` from uniform ranges, evolves
    ``init(xc, u0)`` from ``ini_time`` to ``fin_time`` and writes the
    ``(it_tot, nx)`` snapshot array plus the drawn parameters into a new HDF5
    group named ``f"{path}"``.

    Numerics: periodic ghost cells (``bc``), limited second-order reconstruction
    (``limiting``), flux ``0.5*u**2`` with dissipation proportional to
    ``|uL + uR| / 2``, an explicit diffusive flux ``-(epsilon/pi) * du/dx``, and a
    two-stage midpoint time step bounded by both advective and diffusive limits.

    Args:
        path: Name of the group to create; ``create_group`` raises if it exists.

    NOTE(review): evolve/simulation_fn/update are new closures on every call, so
    JAX traces and compiles them once per sample — presumably the dominant
    per-sample cost; TODO profile and hoist if needed.
    """
    epsilon = np.random.uniform(1.0e-4, 1.0e-1)
    u0 = np.random.uniform(0.5, 2.0)
    # NOTE(review): `du` is stored as metadata, but `init` does not use it.
    du = np.random.uniform(0.0, 0.5)
    CFL = np.random.uniform(0.1, 0.9)

    pi_inv = 1.0 / jnp.pi
    dx = (xR - xL) / nx
    dx_inv = 1.0 / dx

    @jax.jit
    def evolve(u, noise_level=0.0, equation_noise_level=0.0):
        """Advance ``u`` to ``fin_time``; return ``(snapshots, final_time)``.

        A snapshot is stored whenever ``t`` reaches the next save time (every
        ``dt_save``). The final state is written to the last row.
        """
        t = ini_time
        tsave = t
        steps = 0
        i_save = 0
        dt = 0.0
        uu = jnp.zeros([it_tot, u.shape[0]])
        uu = uu.at[0].set(u)

        cond_fun = lambda x: x[0] < fin_time

        def _body_fun(carry):
            def _save(_carry):
                u, tsave, i_save, uu = _carry
                uu = uu.at[i_save].set(u)
                tsave += dt_save
                i_save += 1
                return (u, tsave, i_save, uu)

            t, tsave, steps, i_save, dt, u, uu = carry

            # Store a snapshot once the solution has reached the next save time.
            carry = (u, tsave, i_save, uu)
            u, tsave, i_save, uu = lax.cond(t >= tsave, _save, _pass, carry)

            # Up to `show_steps` solver steps; simulation_fn clips dt so it never
            # steps past `tsave` or `fin_time`.
            carry = (u, t, dt, steps, tsave)
            u, t, dt, steps, tsave = lax.fori_loop(
                0,
                show_steps,
                lambda i, carry: simulation_fn(
                    i, carry, noise_level, equation_noise_level
                ),
                carry,
            )
            return (t, tsave, steps, i_save, dt, u, uu)

        carry = t, tsave, steps, i_save, dt, u, uu
        t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
        uu = uu.at[-1].set(u)
        return uu, t

    @jax.jit
    def simulation_fn(i, carry, noise_level=0.0, equation_noise_level=0.0):
        """One solver step: choose a stable dt and apply the two-stage update.

        The update is skipped when dt <= 1e-8, but ``t`` is still advanced by dt.
        """
        u, t, dt, steps, tsave = carry
        dt_adv = Courant(u, dx) * CFL
        dt_dif = Courant_diff(dx, epsilon * pi_inv) * CFL
        # Never step past fin_time or the next save time.
        dt = jnp.min(jnp.array([dt_adv, dt_dif, fin_time - t, tsave - t]))

        def _update(carry):
            u, dt = carry
            # Midpoint rule: predictor to t + dt/2, then a full step using fluxes
            # evaluated at the predicted state.
            u_tmp = update(u, u, dt * 0.5, noise_level, equation_noise_level)
            u = update(u, u_tmp, dt, noise_level, equation_noise_level)
            return u, dt

        carry = u, dt
        u, dt = lax.cond(dt > 1.0e-8, _update, _pass, carry)

        t += dt
        steps += 1
        return u, t, dt, steps, tsave

    @jax.jit
    def update(u, u_tmp, dt, noise_level=0.0, equation_noise_level=0.0):
        """Conservative finite-volume update of ``u`` using fluxes evaluated at ``u_tmp``."""
        f = flux(u_tmp, noise_level)
        # NOTE(review): generate_noise draws from NumPy's global RNG while this
        # function is being traced, so the result is one fixed sample baked into
        # the compiled step, not fresh noise per step. With the default
        # equation_noise_level=0.0 used below it contributes zero.
        noise = generate_noise(f.shape, equation_noise_level)
        f = f + noise  # Add noise to the flux
        u -= dt * dx_inv * (f[1 : nx + 1] - f[0:nx])
        return u

    def flux(u, noise_level=0.0):
        """Numerical flux at the nx + 1 cell interfaces for cell values ``u``."""
        _u = bc(
            u, Ncell=nx, noise_level=noise_level
        )  # index 2 for _u is equivalent with index 0 for u
        uL, uR = limiting(_u, nx, if_second_order=if_second_order)
        fL = 0.5 * uL**2
        fR = 0.5 * uR**2
        # Interface j lies between padded cells j + 1 and j + 2. Combine the right
        # face of the left cell (fR/uR) with the left face of the right cell
        # (fL/uL), plus dissipation proportional to |uL + uR| / 2.
        f_upwd = 0.5 * (
            fR[1 : nx + 2]
            + fL[2 : nx + 3]
            - 0.5
            * jnp.abs(uL[2 : nx + 3] + uR[1 : nx + 2])
            * (uL[2 : nx + 3] - uR[1 : nx + 2])
        )
        # Diffusive flux -(epsilon / pi) * du/dx, central difference across each interface.
        f_upwd += -epsilon * pi_inv * (_u[2 : nx + 3] - _u[1 : nx + 2]) * dx_inv
        return f_upwd

    # Initial condition (sine profile scaled by u0).
    u = init(xc=xc, u0=u0, du=du)

    # Only the clean trajectory is simulated; noisy observations are derived
    # afterwards from the stored "clean" array.
    u_clean = device_put(u)  # place the initial state on the default device
    uu_clean, _t_final = evolve(u_clean)

    # The `with` block closes the file; no explicit close() is needed.
    with h5py.File("data.h5", "a") as h5file:
        group = h5file.create_group(f"{path}")
        group.create_dataset("epsilon", data=epsilon)
        group.create_dataset("u0", data=u0)
        group.create_dataset("du", data=du)
        group.create_dataset("CFL", data=CFL)
        group.create_dataset("clean", data=np.asarray(uu_clean))

Generate 1000 samples for future training: one HDF5 group per simulation, each with randomly drawn parameters.

In [None]:
# Number of trajectories to simulate; each `gen` call appends one group to data.h5.
# A single constant keeps the loop bound and the progress message in sync.
n_samples = 1000
for i in range(n_samples):
    print(f"Running simulation {i + 1}/{n_samples}")
    gen(i)

In [22]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from tqdm import tqdm


def visualize_burgers(xcrd, data, name):
    """Save an animated GIF ``{name}.gif`` of the snapshots ``data[i]`` plotted against ``xcrd``.

    Args:
        xcrd: 1-D x-coordinates.
        data: snapshots stacked along axis 0; each ``data[i].squeeze()`` must
            match ``xcrd`` in length.
        name: output path stem; ``.gif`` is appended. Also used as the figure title.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("x")
        ax.set_ylabel("u")
        ax.set_title(name)

        frames = []
        for i in tqdm(range(data.shape[0]), desc=name):
            # One line artist per snapshot; ArtistAnimation shows one list per frame.
            (line,) = ax.plot(xcrd, data[i].squeeze(), animated=True, color="blue")
            frames.append([line])
        ani = animation.ArtistAnimation(
            fig, frames, interval=50, blit=True, repeat_delay=1000
        )

        writer = animation.PillowWriter(fps=15, bitrate=1800)
        ani.save(f"{name}.gif", writer=writer)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory in loops.
        plt.close(fig)

In [None]:
from scipy.stats import skewnorm

# Add a "noisy" copy of every stored trajectory: the clean snapshots plus
# skew-normal noise (shape a=1, scale 0.2). Filter out the shared "coords" group
# by name rather than relying on key order. Samples that already have a
# "noisy" dataset (e.g. on a re-run) are skipped explicitly instead of hiding
# every error behind a bare `except:`.
with h5py.File("data.h5", "a") as h5file:
    sample_keys = [key for key in h5file.keys() if key != "coords"]
    for key in sample_keys:
        group = h5file[key]
        if "noisy" in group:
            print(f"{key}: 'noisy' already exists, skipping")
            continue
        clean = group["clean"][:]
        noise = skewnorm.rvs(a=1, scale=0.2, size=clean.shape)
        group.create_dataset("noisy", data=clean + noise)



In [24]:
import random

# Animate the clean and noisy trajectories of a few randomly chosen samples.
# NOTE(review): the generation cells write "data.h5"; change this path if the
# dataset was renamed (e.g. to "simulation_data.h5").
n_preview = 3

with h5py.File("data.h5", "r") as h5file:
    sample_keys = [key for key in h5file.keys() if key != "coords"]
    # Sample without replacement so no trajectory is animated twice, and never pick "coords".
    chosen = random.sample(sample_keys, k=min(n_preview, len(sample_keys)))
    print(chosen)
    # Separate name so the solver grid `xc` used by `gen` is not overwritten.
    x_coords = h5file["coords"]["x-coordinates"][:]
    for key in chosen:
        print(f"Visualizing simulation {key}")
        visualize_burgers(x_coords, h5file[key]["clean"][:], f"{key}_clean")
        visualize_burgers(x_coords, h5file[key]["noisy"][:], f"{key}_noisy")

['30', '74', '63']
Visualizing simulation 30


100%|██████████| 201/201 [00:00<00:00, 1867.91it/s]
100%|██████████| 201/201 [00:00<00:00, 2546.29it/s]


Visualizing simulation 74


100%|██████████| 201/201 [00:00<00:00, 2537.61it/s]
100%|██████████| 201/201 [00:00<00:00, 612.85it/s]


Visualizing simulation 63


100%|██████████| 201/201 [00:00<00:00, 2704.81it/s]
100%|██████████| 201/201 [00:00<00:00, 1540.74it/s]
