<a href="https://colab.research.google.com/github/jakinng/a-PINN/blob/main/turbulence_DNS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install netCDF4

Collecting netCDF4
  Downloading netCDF4-1.6.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cftime (from netCDF4)
  Downloading cftime-1.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: cftime, netCDF4
Successfully installed cftime-1.6.3 netCDF4-1.6.5


In [39]:
# Dependencies for generating turbulence

import math
# import numpy as np
import jax.numpy as jnp
import jax
from jax.numpy import pi
import pandas as pd
import netCDF4 as nc
import matplotlib.pyplot as plt
import os

In [None]:
#@title Pseudo-Spectral Method for 2D Turbulence

def _add_convection(self):
    """
    Convective term. To prevent alliasing, we zero-pad the array before
    using the convolution theorem to evaluate it in physical space.
    """
    # initialize padded arrays
    j1f_padded = jnp.zeros((self.mx,self.mk), dtype='complex128')
    j2f_padded = jnp.zeros((self.mx,self.mk), dtype='complex128')
    j3f_padded = jnp.zeros((self.mx,self.mk), dtype='complex128')
    j4f_padded = jnp.zeros((self.mx,self.mk), dtype='complex128')

    # populate
    j1f_padded[self.padder, :self.nk] = 1.0j*self.kx[:self.nk     ]*self.psih[:,:]
    j2f_padded[self.padder, :self.nk] = 1.0j*self.ky[:, np.newaxis]  *self.wh[:,:]
    j3f_padded[self.padder, :self.nk] = 1.0j*self.ky[:, np.newaxis]*self.psih[:,:]
    j4f_padded[self.padder, :self.nk] = 1.0j*self.kx[:self.nk     ]  *self.wh[:,:]

    # backward transform
    j1 = np.fft.irfft2(j1f_padded, axes=(-2,-1))
    j2 = np.fft.irfft2(j2f_padded, axes=(-2,-1))
    j3 = np.fft.irfft2(j3f_padded, axes=(-2,-1))
    j4 = np.fft.irfft2(j4f_padded, axes=(-2,-1))

    # forward transform
    jacpf = np.fft.rfft2(j1*j2 - j3*j4, axes=(-2,-1))

    # this term is the result of padding, padder allows for easier sclicing
    self.dwhdt[:,:] = jacpf[self.padder, :self.nk]*self.pad**(2)



In [80]:
# -*- coding: utf-8 -*-
"""vcurrent generate_turbulence.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1KbxFOBQsy1w0Bc9x_G01wsryWjTn05Qk

this one may or may not work lol
"""
"""
Install dependencies for plotting
"""
# !python3 -m pip install pyDOE
# !python3 -m pip install scipy==1.6.2
# !python3 -m pip install ipykernel==5.3.4
# !python3 -m pip install matplotlib==3.4.2
# !python3 -m pip install tensorflow==2.4.1
# !python3 -m pip install tensorflow-probability==0.12.2
"""
Import statements
"""
# from pyDOE import lhs
# import tensorflow as tf
# import tensorflow_probability as tfp
# import numpy as jnp
# import matplotlib.pyplot as plt
# import time
# import scipy
# from scipy.interpolate import griddata
# import matplotlib.gridspec as gridspec

# from google.colab import drive
# drive.mount('/content/drive')

# @title Turbulence Code
# From https://twitter.com/CFDonia/status/1659560884769832961/photo/1
# https://marinlauber.github.io/2D-Turbulence/

def run_turbulence(
    N = 2 ** 6,
    tend = 50,
    forcing=True,
    symmetric_index=True,
    filename="content/drive/MyDrive/test.nc",
    description="Generated turbulence data using third-order Runge-Kutta method.",
    plot=True,
):
    L = 2 * math.pi  # length of x and y scales in square domain with LxL = [0, 2pi]x[0, 2pi]
    nu = 5e-4  # kinematic viscosity
    Sc = 0.7  # Schmidt number
    beta = 0  # meridional gradient of Coriolis parameter
    ar = 0.02  # random number amplitude
    b = 1  # mean scalar gradient
    CFLmax = 0.8
    seed = 1324

    # Create domain and wavenumbers in (discrete) Fourier domain
    x = jnp.linspace(0, L, N, endpoint=False)
    dx = L / N
    kx = jnp.fft.fftfreq(N, d = L / (N * 2 * pi)) # DFT sample frequencies
    ky = jnp.fft.fftfreq(N, d = L / (N * 2 * pi))
    KX, KY = jnp.meshgrid(kx, ky)

    # 2/3 dealiasing rule, since there are at most second-order derivatives https://math.jhu.edu/~feilu/notes/DealiasingFFT.pdf https://arxiv.org/pdf/1606.05432.pdf https://github.com/ketch/PseudoSpectralPython/blob/master/PSPython_03-FFT-aliasing-filtering.ipynb
    index_kmax = math.ceil(N / 3)
    filter = jnp.ones((N, N))
    if symmetric_index: # filters down to [-floor(N*2/3), floor(N*2/3) - 1]
        idx_l = index_kmax - 1
        idx_u = 2 * index_kmax
    else: # MATLAB code is slightly off from 2/3 rule
        idx_l = index_kmax
        idx_u = 2 * index_kmax + 2
    filter = filter.at[idx_l:idx_u, idx_l:idx_u].set(0)

    # Initialize to N x N arrays of zeros. These arrays are the "sampled values" of the variables
    u = jnp.zeros((N, N))  # x-component of velocity
    v = jnp.zeros((N, N))  # y-component of velocity
    omega = jnp.zeros((N, N))  # vorticity
    psi = jnp.zeros((N, N))  # stream function
    kk = jnp.zeros((N, N))  # TODO : ask about this
    k2 = jnp.zeros((N, N))  # TODO : ask about this

    # Sampled x- and y-derivatives in frequency domain (Algorithm 1 in https://math.mit.edu/~stevenj/fft-deriv.pdf)
    ddx = KX * 1j # d/dx of basis function b_kx,ky = e^(i(kx * x / + ky * y)) is i * kx * b_kx,ky https://marinlauber.github.io/2D-Turbulence/
    ddy = KY * 1j # To compute ddx(v), elementwise multiply ddx * v
    # Spectral inverse of del^2 (Laplacian) in frequency domain, where del2 = d^2f/dx^2 + d^2f/dy^2
    idel2 = 1 / ((1j * KX) ** 2 + (1j * KY) ** 2)
    idel2 = idel2.at[0, 0].set(0) # The mode is the mean value of psi, and results in a division by zero. To avoid rounding error, explicitly set this to zero every computation.
    # k^2 is the kx^2 + ky^2, where psi_hat = omega_hat / (kx^2 + ky^2) (relationship between vorticity and streamfunction)
    k2

    time = 0

# run_turbulence(include_phi = True, forcing = False, symmetric_index = True, filename = "/content/drive/MyDrive/6_27/test_no_forcing.nc")
N = 8
description = f"{N}_"
run_turbulence(
    forcing=True,
    symmetric_index=True,
    filename=f"/content/drive/MyDrive/Stanford/2024_02_01/fake_turbulence_{description}.nc",
    N=N,
    tend=25,
    plot=True,
)

In [25]:
# -*- coding: utf-8 -*-
"""vcurrent generate_turbulence.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1KbxFOBQsy1w0Bc9x_G01wsryWjTn05Qk

this one may or may not work lol
"""
"""
Install dependencies for plotting
"""
# !python3 -m pip install pyDOE
# !python3 -m pip install scipy==1.6.2
# !python3 -m pip install ipykernel==5.3.4
# !python3 -m pip install matplotlib==3.4.2
# !python3 -m pip install tensorflow==2.4.1
# !python3 -m pip install tensorflow-probability==0.12.2
"""
Import statements
"""
# from pyDOE import lhs
# import tensorflow as tf
# import tensorflow_probability as tfp
# import numpy as jnp
# import matplotlib.pyplot as plt
# import time
# import scipy
# from scipy.interpolate import griddata
# import matplotlib.gridspec as gridspec

# from google.colab import drive
# drive.mount('/content/drive')

# @title Turbulence Code
# From https://twitter.com/CFDonia/status/1659560884769832961/photo/1


def run_turbulence(
    N = 2 ** 6,
    tend = 50,
    forcing=True,
    symmetric_index=True,
    filename="content/drive/MyDrive/test.nc",
    description="Generated turbulence data using third-order Runge-Kutta method.",
    plot=True,
):
    Lx = 2 * math.pi  # length of y scale
    Ly = 2 * math.pi  # length of x scale
    nu = 5e-4  # kinematic viscosity
    Sc = 0.7  # Schmidt number
    beta = 0  # meridional gradient of Coriolis parameter
    ar = 0.02  # random number amplitude
    b = 1  # mean scalar gradient
    CFLmax = 0.8
    seed = 1324

    # x-grid
    x = jnp.linspace(0, Lx, N, endpoint=False)
    dx = Lx / N
    kx = jnp.append(
        jnp.arange(0, N / 2 + 1), jnp.arange(-N / 2 + 1, 0)
    )  # [0, 1, ..., M/2, -M/2+1, -M/2+2, ..., -1]

    print(f"{kx=}")

    kx = jnp.fft.fftfreq(N)

    print(f"{kx=}")
    # y-grid
    y = jnp.linspace(0, Ly, N, endpoint=False)
    dy = Ly / N
    ky = jnp.append(
        jnp.arange(0, N / 2 + 1), jnp.arange(-N / 2 + 1, 0)
    )  # [0, 1, ..., N/2, -N/2+1, -N/2+2, ..., -1]

    KX, KY = jnp.meshgrid(kx, ky)

    time = 0
    index_kmax = math.ceil(N / 3)
    kmax = kx[index_kmax]
    filter = jnp.ones((N, N))
    if symmetric_index:
        idx_l = index_kmax - 1
        idx_u = 2 * index_kmax
    else:
        idx_l = index_kmax
        idx_u = 2 * index_kmax + 3
    filter = filter.at[idx_l:idx_u, idx_l:idx_u].set(0)

    # Seed random number generator
    key = jax.random.PRNGKey(seed)

    # Initialize to M x N arrays of zeros. These arrays are the "sampled values" of the variables
    u = jnp.zeros((N, N))  # x-component of velocity
    v = jnp.zeros((N, N))  # y-component of velocity
    omega = jnp.zeros((N, N))  # vorticity
    psi = jnp.zeros((N, N))  # stream function
    ddx = jnp.zeros((N, N))  # Sampled x-derivative in frequency domain
    ddy = jnp.zeros((N, N))  # Sampled y-derivative in frequency domain
    idel2 = jnp.zeros((N, N))  # Inverse of del^2 (Laplacian) in frequency domain
    kk = jnp.zeros((N, N))  # TODO : ask about this
    k2 = jnp.zeros((N, N))  # TODO : ask about this

    # Take sampled derivative in frequency domain (Algorithm 1 in https://math.mit.edu/~stevenj/fft-deriv.pdf)
    ddx = jnp.tile((1j * kx)[:, None], (1, N))  # Spectral x-derivative matrix
    print(f"{ddx=}")
    ddy = jnp.tile((1j * ky)[None, :], (N, 1))  # Spectral y-derivative matrix
    idel2 = jnp.array(
        [
            [
                0 if (i == 0 and j == 0) else 1 / (-1 * kx[i] ** 2 - 1 * ky[j] ** 2)
                for j in range(N)
            ]
            for i in range(N)
        ]
    )  # Spectral inverse of Laplacian del^2 operator
    print(idel2)

    ##### FORCING #######
    if forcing:
        print("forcing")
        kk = jnp.array([[kx[i] ** 2 + ky[j] ** 2 for j in range(N)] for i in range(N)])
        for i in range(N):
            for j in range(N):
                # Simple forcing: invert viscosity sign in band
                if kk[i, j] >= 6**2 and kk[i, j] <= 7**2:
                    kk[i, j] = -kk[i, j]
                # Increase viscosity 8 times for large scale dissipation
                if kk[i, j] <= 2**2:
                    # kk[i, j] = 8 * kk[i, j]
                    kk = kk.at[i, j].multiply(8)
    ##### FORCING #######

    # print(f"kk: {kk}")
    # plt.figure()
    # plt.pcolormesh(kk)
    # plt.show()

    # Initialize velocity vector with random perturbation
    # Taylor-Green initial condition
    u = jnp.array(
        [
            [
                jnp.cos(2 * x[i]) * jnp.sin(2 * y[j]) + ar * jnp.random.random()
                for j in range(N)
            ]
            for i in range(N)
        ]
    )
    v = jnp.array(
        [
            [
                -jnp.sin(2 * x[i]) * jnp.cos(2 * y[j]) + ar * jnp.random.random()
                for j in range(N)
            ]
            for i in range(N)
        ]
    )

    # 2D Fourier Transform
    uhat = jnp.fft.fft2(u)
    vhat = jnp.fft.fft2(v)
    omegahat = ddx * vhat - ddy * uhat  # vorticity: omega = curl((u, v)) in 2D
    omega = jnp.fft.ifft2(omegahat * filter).real

    if os.path.isfile(filename):
        os.remove(filename)  ## COMMENT OUT THIS LINE TO NOT DESTROY FILE
    ds = nc.Dataset(filename, "w", format="NETCDF4")
    ds.description = description
    ds.variation = f"Filter out frequencies outside of {kx[idx_u - 1]} and {kx[idx_l]}."
    ds.source = "Adapted code from https://twitter.com/CFDonia/status/1659560884769832961/photo/1 into Python by Jakin Ng."

    x_dim = ds.createDimension("x", N)
    y_dim = ds.createDimension("y", N)
    time_dim = ds.createDimension("time", None)

    ds.createVariable("x", "f4", ("x"))
    ds.createVariable("y", "f4", ("y"))
    ds.createVariable("time", "f4", ("time"))
    ds.createVariable("u", "f4", ("x", "y", "time"))
    ds.createVariable("v", "f4", ("x", "y", "time"))
    ds.createVariable("vorticity", "f4", ("x", "y", "time"))
    ds.createVariable("psi", "f4", ("x", "y", "time"))
    ds.createVariable("dissipation", "f4", ("x", "y", "time"))

    ds["x"][:] = x
    ds["y"][:] = y
    ds["time"][0] = 0
    ds["u"][:, :, 0] = u
    ds["v"][:, :, 0] = v
    ds["vorticity"][:, :, 0] = omega
    ds["dissipation"][:, :, 0] = jnp.zeros((N, N))
    ds["psi"][:, :, 0] = psi

    ### 3-STEP RUNGE-KUTTA
    dt = 0.5 * min(dx, dy)
    nstep = 1
    # while False:
    while time < tend:
        # for i in range(3):
        # for i in range(20):
        #### Substep 1
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        facto = jnp.exp(-1 * nu * 8 / 15 * dt * kk)
        r0o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = facto * (omegahat + dt * 8 / 15 * r0o)  # update omega

        #### Substep 2
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        r1o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = omegahat + dt * (-17 / 60 * facto * r0o + 5 / 12 * r1o)
        facto = jnp.exp(-1 * nu * (-17 / 60 + 5 / 12) * dt * kk)
        omegahat = omegahat * facto

        #### Substep 3
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        r2o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = omegahat + dt * (-5 / 12 * facto * r1o + 3 / 4 * r2o)
        facto = jnp.exp(-1 * nu * (-5 / 12 + 3 / 4) * dt * kk)
        omegahat = omegahat * facto
        omegahat = filter * omegahat

        ## Increment
        time = time + dt
        nstep = nstep + 1

        # Courant–Friedrichs–Lewy condition
        CFL = jnp.max(jnp.abs(u)) / dx * dt + jnp.max(jnp.abs(v)) / dy * dt

        psi = jnp.fft.ifft2(psihat).real
        omega = jnp.fft.ifft2(omegahat).real
        dissipation = (
            2
            * nu
            * (
                jnp.fft.ifft2(ddx * uhat).real ** 2
                + jnp.fft.ifft2(ddy * uhat).real ** 2
                + jnp.fft.ifft2(ddx * vhat).real ** 2
                + jnp.fft.ifft2(ddy * vhat).real ** 2
            )
        )
        eta = (nu**3 / jnp.mean(dissipation)) ** 0.25

        if nstep % 10 == 0:
            print(
                f"step = {nstep}    time = {time}   dt = {dt}    CFL = {CFL}  kmax*eta = {kmax * eta}   kmax*eta/sqrt(Sc) = {kmax * eta / math.sqrt(Sc)}"
            )
            if plot:
                fig = plt.figure(figsize=[5, 5], dpi=300)

                ax = plt.subplot(221)
                ax.set_title("Vorticity")
                pcolor = ax.pcolormesh(x, y, omega, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(222)
                ax.set_title("Stream Function")
                pcolor = ax.pcolormesh(x, y, psi, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(223)
                ax.set_title("Dissipation")
                pcolor = ax.pcolormesh(x, y, dissipation, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(224)
                ax.set_title("u-velocity")
                pcolor = ax.pcolormesh(x, y, u, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                plt.tight_layout()
                plt.savefig(f"{filename[:-3]}{nstep}.png")
                plt.show()

        if nstep % 10 == 0:
            t_idx = len(ds["time"][:])
            ds["time"][t_idx] = time
            ds["u"][:, :, t_idx] = u
            ds["v"][:, :, t_idx] = v
            ds["vorticity"][:, :, t_idx] = omega
            ds["dissipation"][:, :, t_idx] = dissipation
            ds["psi"][:, :, t_idx] = psi
            ds.close()
            ds = nc.Dataset(filename, "r+", format="NETCDF4")
            dt = CFLmax / CFL * dt
    ds.close()


# run_turbulence(include_phi = True, forcing = False, symmetric_index = True, filename = "/content/drive/MyDrive/6_27/test_no_forcing.nc")
N = 8
description = f"{N}_"
run_turbulence(
    forcing=True,
    symmetric_index=True,
    filename=f"/content/drive/MyDrive/Stanford/2024_02_01/fake_turbulence_{description}.nc",
    N=N,
    tend=25,
    plot=True,
)

kx=Array([ 0.,  1.,  2.,  3.,  4., -3., -2., -1.], dtype=float32)
kx=Array([ 0.   ,  0.125,  0.25 ,  0.375, -0.5  , -0.375, -0.25 , -0.125],      dtype=float32)
ddx=Array([[ 0.+0.j   ,  0.+0.j   ,  0.+0.j   ,  0.+0.j   ,  0.+0.j   ,
         0.+0.j   ,  0.+0.j   ,  0.+0.j   ],
       [ 0.+0.125j,  0.+0.125j,  0.+0.125j,  0.+0.125j,  0.+0.125j,
         0.+0.125j,  0.+0.125j,  0.+0.125j],
       [ 0.+0.25j ,  0.+0.25j ,  0.+0.25j ,  0.+0.25j ,  0.+0.25j ,
         0.+0.25j ,  0.+0.25j ,  0.+0.25j ],
       [ 0.+0.375j,  0.+0.375j,  0.+0.375j,  0.+0.375j,  0.+0.375j,
         0.+0.375j,  0.+0.375j,  0.+0.375j],
       [-0.-0.5j  , -0.-0.5j  , -0.-0.5j  , -0.-0.5j  , -0.-0.5j  ,
        -0.-0.5j  , -0.-0.5j  , -0.-0.5j  ],
       [-0.-0.375j, -0.-0.375j, -0.-0.375j, -0.-0.375j, -0.-0.375j,
        -0.-0.375j, -0.-0.375j, -0.-0.375j],
       [-0.-0.25j , -0.-0.25j , -0.-0.25j , -0.-0.25j , -0.-0.25j ,
        -0.-0.25j , -0.-0.25j , -0.-0.25j ],
       [-0.-0.125j, -0.-0.125j, -0.-0.125j, 

AttributeError: module 'jax.numpy' has no attribute 'random'

In [None]:
# -*- coding: utf-8 -*-
"""vcurrent generate_turbulence.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1KbxFOBQsy1w0Bc9x_G01wsryWjTn05Qk

this one may or may not work lol
"""
"""
Install dependencies for plotting
"""
# !python3 -m pip install pyDOE
# !python3 -m pip install scipy==1.6.2
# !python3 -m pip install ipykernel==5.3.4
# !python3 -m pip install matplotlib==3.4.2
# !python3 -m pip install tensorflow==2.4.1
# !python3 -m pip install tensorflow-probability==0.12.2
"""
Import statements
"""
# from pyDOE import lhs
# import tensorflow as tf
# import tensorflow_probability as tfp
# import numpy as jnp
# import matplotlib.pyplot as plt
# import time
# import scipy
# from scipy.interpolate import griddata
# import matplotlib.gridspec as gridspec

# from google.colab import drive
# drive.mount('/content/drive')

# @title Turbulence Code
# From https://twitter.com/CFDonia/status/1659560884769832961/photo/1


def run_turbulence(
    N = 2 ** 6,
    tend = 50,
    forcing=True,
    symmetric_index=True,
    filename="content/drive/MyDrive/test.nc",
    description="Generated turbulence data using third-order Runge-Kutta method.",
    plot=True,
):
    Lx = 2 * math.pi  # length of y scale
    Ly = 2 * math.pi  # length of x scale
    nu = 5e-4  # kinematic viscosity
    Sc = 0.7  # Schmidt number
    beta = 0  # meridional gradient of Coriolis parameter
    ar = 0.02  # random number amplitude
    b = 1  # mean scalar gradient
    CFLmax = 0.8
    seed = 1324

    # x-grid
    x = jnp.linspace(0, Lx, N, endpoint=False)
    dx = Lx / N
    kx = jnp.append(
        jnp.arange(0, N / 2 + 1), jnp.arange(-N / 2 + 1, 0)
    )  # [0, 1, ..., M/2, -M/2+1, -M/2+2, ..., -1]

    print(f"{kx=}")

    # y-grid
    y = jnp.linspace(0, Ly, N, endpoint=False)
    dy = Ly / N
    ky = jnp.append(
        jnp.arange(0, N / 2 + 1), jnp.arange(-N / 2 + 1, 0)
    )  # [0, 1, ..., N/2, -N/2+1, -N/2+2, ..., -1]

    KX, KY = jnp.meshgrid(kx, ky)

    time = 0
    index_kmax = math.ceil(N / 3)
    kmax = kx[index_kmax]
    filter = jnp.ones((N, N))
    if symmetric_index:
        idx_l = index_kmax - 1
        idx_u = 2 * index_kmax
    else:
        idx_l = index_kmax
        idx_u = 2 * index_kmax + 3
    filter = filter.at[idx_l:idx_u, idx_l:idx_u].set(0)

    # Seed random number generator
    key = jax.random.PRNGKey(seed)

    # Initialize to M x N arrays of zeros. These arrays are the "sampled values" of the variables
    u = jnp.zeros((N, N))  # x-component of velocity
    v = jnp.zeros((N, N))  # y-component of velocity
    omega = jnp.zeros((N, N))  # vorticity
    psi = jnp.zeros((N, N))  # stream function
    ddx = jnp.zeros((N, N))  # Sampled x-derivative in frequency domain
    ddy = jnp.zeros((N, N))  # Sampled y-derivative in frequency domain
    idel2 = jnp.zeros((N, N))  # Inverse of del^2 (Laplacian) in frequency domain
    kk = jnp.zeros((N, N))  # TODO : ask about this
    k2 = jnp.zeros((N, N))  # TODO : ask about this

    # Take sampled derivative in frequency domain (Algorithm 1 in https://math.mit.edu/~stevenj/fft-deriv.pdf)
    ddx = jnp.tile((1j * kx)[:, None], (1, N))  # Spectral x-derivative matrix
    print(f"{ddx=}")
    ddy = jnp.tile((1j * ky)[None, :], (N, 1))  # Spectral y-derivative matrix
    idel2 = jnp.array(
        [
            [
                0 if (i == 0 and j == 0) else 1 / (-1 * kx[i] ** 2 - 1 * ky[j] ** 2)
                for j in range(N)
            ]
            for i in range(N)
        ]
    )  # Spectral inverse of Laplacian del^2 operator
    print(idel2)

    ##### FORCING #######
    if forcing:
        print("forcing")
        kk = jnp.array([[kx[i] ** 2 + ky[j] ** 2 for j in range(N)] for i in range(N)])
        for i in range(N):
            for j in range(N):
                # Simple forcing: invert viscosity sign in band
                if kk[i, j] >= 6**2 and kk[i, j] <= 7**2:
                    kk[i, j] = -kk[i, j]
                # Increase viscosity 8 times for large scale dissipation
                if kk[i, j] <= 2**2:
                    # kk[i, j] = 8 * kk[i, j]
                    kk = kk.at[i, j].multiply(8)
    ##### FORCING #######

    # print(f"kk: {kk}")
    # plt.figure()
    # plt.pcolormesh(kk)
    # plt.show()

    # Initialize velocity vector with random perturbation
    # Taylor-Green initial condition
    u = jnp.array(
        [
            [
                jnp.cos(2 * x[i]) * jnp.sin(2 * y[j]) + ar * jnp.random.random()
                for j in range(N)
            ]
            for i in range(N)
        ]
    )
    v = jnp.array(
        [
            [
                -jnp.sin(2 * x[i]) * jnp.cos(2 * y[j]) + ar * jnp.random.random()
                for j in range(N)
            ]
            for i in range(N)
        ]
    )

    # 2D Fourier Transform
    uhat = jnp.fft.fft2(u)
    vhat = jnp.fft.fft2(v)
    omegahat = ddx * vhat - ddy * uhat  # vorticity: omega = curl((u, v)) in 2D
    omega = jnp.fft.ifft2(omegahat * filter).real

    if os.path.isfile(filename):
        os.remove(filename)  ## COMMENT OUT THIS LINE TO NOT DESTROY FILE
    ds = nc.Dataset(filename, "w", format="NETCDF4")
    ds.description = description
    ds.variation = f"Filter out frequencies outside of {kx[idx_u - 1]} and {kx[idx_l]}."
    ds.source = "Adapted code from https://twitter.com/CFDonia/status/1659560884769832961/photo/1 into Python by Jakin Ng."

    x_dim = ds.createDimension("x", N)
    y_dim = ds.createDimension("y", N)
    time_dim = ds.createDimension("time", None)

    ds.createVariable("x", "f4", ("x"))
    ds.createVariable("y", "f4", ("y"))
    ds.createVariable("time", "f4", ("time"))
    ds.createVariable("u", "f4", ("x", "y", "time"))
    ds.createVariable("v", "f4", ("x", "y", "time"))
    ds.createVariable("vorticity", "f4", ("x", "y", "time"))
    ds.createVariable("psi", "f4", ("x", "y", "time"))
    ds.createVariable("dissipation", "f4", ("x", "y", "time"))

    ds["x"][:] = x
    ds["y"][:] = y
    ds["time"][0] = 0
    ds["u"][:, :, 0] = u
    ds["v"][:, :, 0] = v
    ds["vorticity"][:, :, 0] = omega
    ds["dissipation"][:, :, 0] = jnp.zeros((N, N))
    ds["psi"][:, :, 0] = psi

    ### 3-STEP RUNGE-KUTTA
    dt = 0.5 * min(dx, dy)
    nstep = 1
    # while False:
    while time < tend:
        # for i in range(3):
        # for i in range(20):
        #### Substep 1
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        facto = jnp.exp(-1 * nu * 8 / 15 * dt * kk)
        r0o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = facto * (omegahat + dt * 8 / 15 * r0o)  # update omega

        #### Substep 2
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        r1o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = omegahat + dt * (-17 / 60 * facto * r0o + 5 / 12 * r1o)
        facto = jnp.exp(-1 * nu * (-17 / 60 + 5 / 12) * dt * kk)
        omegahat = omegahat * facto

        #### Substep 3
        psihat = -1 * idel2 * omegahat
        uhat = ddy * psihat
        vhat = -1 * ddx * psihat

        u = jnp.fft.ifft2(uhat).real
        v = jnp.fft.ifft2(vhat).real

        omegadx = jnp.fft.ifft2(ddx * omegahat).real
        omegady = jnp.fft.ifft2(ddy * omegahat).real

        r2o = -1 * jnp.fft.fft2(u * omegadx + v * omegady) + beta * vhat
        omegahat = omegahat + dt * (-5 / 12 * facto * r1o + 3 / 4 * r2o)
        facto = jnp.exp(-1 * nu * (-5 / 12 + 3 / 4) * dt * kk)
        omegahat = omegahat * facto
        omegahat = filter * omegahat

        ## Increment
        time = time + dt
        nstep = nstep + 1

        # Courant–Friedrichs–Lewy condition
        CFL = jnp.max(jnp.abs(u)) / dx * dt + jnp.max(jnp.abs(v)) / dy * dt

        psi = jnp.fft.ifft2(psihat).real
        omega = jnp.fft.ifft2(omegahat).real
        dissipation = (
            2
            * nu
            * (
                jnp.fft.ifft2(ddx * uhat).real ** 2
                + jnp.fft.ifft2(ddy * uhat).real ** 2
                + jnp.fft.ifft2(ddx * vhat).real ** 2
                + jnp.fft.ifft2(ddy * vhat).real ** 2
            )
        )
        eta = (nu**3 / jnp.mean(dissipation)) ** 0.25

        if nstep % 10 == 0:
            print(
                f"step = {nstep}    time = {time}   dt = {dt}    CFL = {CFL}  kmax*eta = {kmax * eta}   kmax*eta/sqrt(Sc) = {kmax * eta / math.sqrt(Sc)}"
            )
            if plot:
                fig = plt.figure(figsize=[5, 5], dpi=300)

                ax = plt.subplot(221)
                ax.set_title("Vorticity")
                pcolor = ax.pcolormesh(x, y, omega, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(222)
                ax.set_title("Stream Function")
                pcolor = ax.pcolormesh(x, y, psi, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(223)
                ax.set_title("Dissipation")
                pcolor = ax.pcolormesh(x, y, dissipation, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                ax = plt.subplot(224)
                ax.set_title("u-velocity")
                pcolor = ax.pcolormesh(x, y, u, shading="auto")
                plt.colorbar(pcolor, ax=ax)

                plt.tight_layout()
                plt.savefig(f"{filename[:-3]}{nstep}.png")
                plt.show()

        if nstep % 10 == 0:
            t_idx = len(ds["time"][:])
            ds["time"][t_idx] = time
            ds["u"][:, :, t_idx] = u
            ds["v"][:, :, t_idx] = v
            ds["vorticity"][:, :, t_idx] = omega
            ds["dissipation"][:, :, t_idx] = dissipation
            ds["psi"][:, :, t_idx] = psi
            ds.close()
            ds = nc.Dataset(filename, "r+", format="NETCDF4")
            dt = CFLmax / CFL * dt
    ds.close()


# run_turbulence(include_phi = True, forcing = False, symmetric_index = True, filename = "/content/drive/MyDrive/6_27/test_no_forcing.nc")
N = 8
description = f"{N}_"
run_turbulence(
    forcing=True,
    symmetric_index=True,
    filename=f"/content/drive/MyDrive/Stanford/2024_02_01/fake_turbulence_{description}.nc",
    N=N,
    tend=25,
    plot=True,
)