## Luo-Hou Solutions

We look for solutions with the profile chosen by Luo and Hou (2014). Starting from a single, purely rotating eddy in a cylinder, they presented numerical evidence of a singularity forming in finite time. The initial profile is required to be an odd function of $z$, the axial coordinate of the cylinder.

They claimed that, as the simulation progresses, the singularity becomes asymptotically self-similar.

## Generalized De Gregorio Equations

The De Gregorio equations, specifically the derived CCF equation, have been shown to develop stable singularities in finite time. The generalized De Gregorio equations are:

$$\omega_t + a u \omega_x = \omega u_x$$

where $\omega$ is the vorticity and $u$ is the velocity field. We analyze one particular case of the generalized De Gregorio equation, $a=-1$, which is also known as the CCF (Córdoba-Córdoba-Fontelos) equation.

We use the self-similar ansatz for singularities from Wang et al. (2023):

$$\omega(x,t) = \frac{1}{1-t}\,\Omega\left(\frac{x}{(1-t)^{1+\lambda}}\right)$$

The ansatz is parametrized by $\lambda$. With the change of coordinates $y=\frac{x}{(1-t)^{1+\lambda}}$ and the velocity defined as $u = \int_0^y H\Omega \, ds$, where $H$ denotes the Hilbert transform, the De Gregorio equation becomes:

$$\Omega + ((1+\lambda)y-u)\frac{\partial \Omega}{\partial y}-\Omega\frac{\partial u}{\partial y}=0$$

Here $a=-1$ has been substituted; for general $a$ the advection coefficient is $(1+\lambda)y + au$.
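For reference, the reduction can be checked term by term. This is a sketch under one stated assumption: the velocity scales as $u(x,t) = (1-t)^{\lambda}U(y)$, which follows from $u_x = H\omega$ because the Hilbert transform commutes with dilations. The symbol $U$ is introduced here only to distinguish the rescaled velocity from the physical one; it is the quantity written as $u$ in the self-similar equation. Using $\partial_t y = \frac{(1+\lambda)\,y}{1-t}$ and $\partial_x y = (1-t)^{-(1+\lambda)}$:

$$\omega_t = \frac{1}{(1-t)^2}\left[\Omega + (1+\lambda)\,y\,\Omega_y\right], \qquad a\,u\,\omega_x = \frac{a}{(1-t)^2}\,U\,\Omega_y, \qquad \omega\,u_x = \frac{1}{(1-t)^2}\,\Omega\,U_y$$

Substituting into $\omega_t + a u \omega_x = \omega u_x$ and multiplying by $(1-t)^2$ gives

$$\Omega + \left((1+\lambda)\,y + a\,U\right)\Omega_y - \Omega\,U_y = 0$$

which reduces to the equation above for $a=-1$.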

### Loss Function

The loss function is composed of a *condition loss* and an *equation loss*, which track the residuals of the boundary conditions and of the governing equations respectively. We impose the odd symmetry of the function and its decay at infinity implicitly, through the ansatz:

$$q = \left(\frac{NN_q(z) - NN_q(-z)}{2}\right)\cdot (1+z^2)^{-1/(2(1+\lambda))}$$

where $q(z)\in\{u(z), \Omega(z)\}$.
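The symmetrization and decay envelope can be checked in isolation. The sketch below is a minimal illustration, with a hypothetical stand-in `raw_net` in place of the trained network $NN_q$ and a hypothetical helper name `q_ansatz`. It shows that the construction yields an odd function that vanishes at the origin, whatever the underlying network outputs.

```python
import numpy as np

def raw_net(z):
    # hypothetical stand-in for NN_q: deliberately neither odd nor even
    return np.exp(0.3 * z) + z**2

def q_ansatz(z, lam):
    # odd part of the raw output, multiplied by the even decay envelope
    odd_part = 0.5 * (raw_net(z) - raw_net(-z))
    envelope = (1.0 + z**2) ** (-1.0 / (2.0 * (1.0 + lam)))
    return odd_part * envelope

z = np.linspace(-5.0, 5.0, 11)
lam = 1.0
q_vals = q_ansatz(z, lam)
```

For large $|z|$ the envelope behaves like $|z|^{-1/(1+\lambda)}$, so the decay rate of $q$ is tied to $\lambda$ rather than learned.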

To force the NN away from the trivial solution, we impose normalization conditions:

$$g_1 = \partial_y \Omega(0)+2$$
$$g_2 = \Omega(0.5) + 0.05$$
$$g_3 = \sum_{y\in Y_\infty} \Omega(y)^2$$

where $g_1$ and $g_2$ normalize the solution away from zero, and $g_3$ guides the function to decay far away from the origin.

Due to the nonlocality of the Hilbert transform, the De Gregorio equations must be solved on a large domain. We therefore define a new $z$-coordinate through the relation:

$$y=\sinh(z) \iff z=\sinh^{-1}(y)$$

In practice, we sample in the range $z\in[-30,30]$.
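To make the role of the stretched coordinate concrete, the sketch below checks the chain rule $\partial_y = \frac{1}{\cosh(z)}\partial_z$ implied by $y=\sinh(z)$. The test function `f_y` (an arctangent) is a hypothetical choice made only for illustration; it is not a profile from this document.

```python
import jax
import jax.numpy as jnp

def f_y(y):
    # hypothetical smooth test function of the physical coordinate y
    return jnp.arctan(y)

def f_z(z):
    # the same function expressed in the stretched coordinate, y = sinh(z)
    return f_y(jnp.sinh(z))

df_dy = jax.vmap(jax.grad(f_y))   # derivative taken directly in y
df_dz = jax.vmap(jax.grad(f_z))   # derivative taken in z

z = jnp.linspace(-3.0, 3.0, 13)
y = jnp.sinh(z)

lhs = df_dy(y)                    # d f / d y
rhs = df_dz(z) / jnp.cosh(z)      # (1 / cosh z) * d f / d z
```

Because $\sinh$ grows exponentially, a bound of 30 in the stretched coordinate corresponds to $|y|$ up to $\sinh(30)$, which is on the order of $5\times 10^{12}$.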

The equation losses become:

$$f_1 = \Omega(z) + \left((1+\lambda)\sinh(z)+au(z)\right) \frac{1}{\cosh(z)}\,\partial_z\Omega(z) - \Omega(z)\, \frac{1}{\cosh(z)}\,\partial_z u(z)$$

$$f_2 = \frac{1}{\cosh(z)}\,\partial_z u(z)-H_n[\Omega(z)]$$

where $f_2$ follows from the definition $u_y=H\Omega$ together with the chain rule $\partial_y = \frac{1}{\cosh(z)}\partial_z$, and $H_n$ is the numerical Hilbert transform.
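As an illustration of how a numerical Hilbert transform can be validated, the following sketch evaluates a truncated principal-value integral by singularity subtraction and composite Simpson quadrature, then compares it with a known closed form. It is a sketch under stated assumptions, not the implementation used for $H_n$: the sign convention $Hf(y) = \frac{1}{\pi}\,\mathrm{P.V.}\int \frac{f(s)}{y-s}\,ds$ is assumed, the names `hilbert_truncated` and `lorentzian` are hypothetical, and the quadrature is carried out directly in $y$ on $[-L, L]$.

```python
import numpy as np

def simpson_weights(n_points):
    # composite Simpson weights [1, 4, 2, 4, ..., 2, 4, 1]; needs an odd point count
    if n_points % 2 == 0:
        raise ValueError("Simpson's rule requires an odd number of grid points.")
    w = np.ones(n_points)
    w[1:-1:2] = 4.0
    w[2:-1:2] = 2.0
    return w

def hilbert_truncated(f, y, L=200.0, n_points=40001):
    # assumed convention: H f(y) = (1/pi) P.V. int f(s) / (y - s) ds, truncated to [-L, L]
    s = np.linspace(-L, L, n_points)
    h = s[1] - s[0]
    diff = y - s
    on_node = diff == 0.0
    safe = np.where(on_node, 1.0, diff)
    # subtracting f(y) removes the singularity; an exact grid hit is set to zero
    integrand = np.where(on_node, 0.0, (f(s) - f(y)) / safe)
    smooth_part = (h / 3.0) * np.dot(simpson_weights(n_points), integrand)
    # P.V. int_{-L}^{L} ds / (y - s) = log((L + y) / (L - y)) for |y| < L
    pv_part = f(y) * np.log((L + y) / (L - y))
    return (smooth_part + pv_part) / np.pi

def lorentzian(s):
    # test function with a known transform under the assumed convention
    return 1.0 / (1.0 + s**2)

y_test = np.array([0.123, 1.777, -2.503])
approx = np.array([hilbert_truncated(lorentzian, y) for y in y_test])
exact = y_test / (1.0 + y_test**2)   # H[1/(1+s^2)](y) = y/(1+y^2)
```

The logarithmic term accounts for the principal value over the finite interval; what remains is the tail beyond $\pm L$, which is why a slowly decaying $\Omega$ needs a very large domain.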

For the CCF equation in particular, to keep the optimization from settling in undesired local optima, we add a smoothness constraint in the form of a third-order derivative loss term:

$$loss_s = \frac{1}{N_s}\sum_{i=1}^{N_s}\left|\frac{d^3 f}{dy^3}(y_i,\hat q(y_i))\right|^2$$
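A third derivative of this kind can be obtained by nesting automatic differentiation three times. The sketch below is a minimal illustration that uses a hypothetical stand-in residual, `toy_residual`, whose third derivative is known in closed form; the helper name `third_derivative` is also an assumption. The derivative is taken with respect to the function's own scalar argument.

```python
import jax
import jax.numpy as jnp

def third_derivative(fn):
    # nest jax.grad three times to get d^3 fn / dz^3 for a scalar-in, scalar-out fn
    return jax.grad(jax.grad(jax.grad(fn)))

def toy_residual(z):
    # hypothetical stand-in for an equation residual f(z)
    return jnp.sin(z)

# vectorize over the collocation points
d3_residual = jax.vmap(third_derivative(toy_residual))

z_pts = jnp.linspace(-1.0, 1.0, 9)            # smoothness collocation points
d3_vals = d3_residual(z_pts)
loss_s = jnp.mean(jnp.abs(d3_vals) ** 2)      # (1/N_s) * sum |d^3 f / dz^3|^2
```

Each additional level of `jax.grad` increases the cost of a loss evaluation, so the smoothness term is typically evaluated on far fewer points than the equation residuals.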

The final loss function takes the form:

$$J(y) = \hat c_b\left(\frac{1}{n_b}\sum_{j=1}^{n_b}loss^{(j)}_g\right) + \hat c_e\left(\frac{1}{n_e}\sum_{k=1}^{n_e}loss^{(k)}_f\right) + \hat c_s \left(\frac{1}{n_e}\sum_{k=1}^{n_e}loss^{(k)}_s\right)$$

where $n_b = 3$ and $n_e=2$ are the numbers of solution conditions and governing equations used.
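The way the three groups of terms combine can be sketched with plain arrays. This is an illustration only: the function name `combine_losses` and the weight names `w_cond`, `w_eq`, `w_smooth` are hypothetical, the default weight of 1 is a placeholder because no values are given here, and the residual arrays are made-up numbers rather than results.

```python
import numpy as np

def combine_losses(g_residuals, f_residuals, s_residuals,
                   w_cond=1.0, w_eq=1.0, w_smooth=1.0):
    # each argument is a list of residual arrays, one array per condition / equation;
    # every array is reduced to a mean square, then averaged within its group
    loss_g = np.mean([np.mean(np.square(g)) for g in g_residuals])
    loss_f = np.mean([np.mean(np.square(f)) for f in f_residuals])
    loss_s = np.mean([np.mean(np.square(s)) for s in s_residuals])
    return w_cond * loss_g + w_eq * loss_f + w_smooth * loss_s

# made-up residuals: 3 solution conditions, 2 governing equations
g_res = [np.array([0.1]), np.array([0.2]), np.array([0.0, 0.1])]
f_res = [np.array([1.0, -1.0]), np.array([0.0, 2.0])]
s_res = [np.array([0.5]), np.array([0.5])]

J = combine_losses(g_res, f_res, s_res)
```

Averaging within each group before weighting keeps the relative importance of the groups independent of how many collocation points each one uses.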

## Network Architecture

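The profiles $u$ and $\Omega$ of the CCF equation are each approximated by a fully-connected neural network with 6 hidden layers and 30 units per hidden layer. $\tanh(x)$ is used as the activation function of the hidden layers, and the output layer uses an exponential linear unit (ELU) activation. Note that the code cell below currently builds a smaller network, with 3 hidden layers of 20 units each.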

In [1]:
import haiku as hk
import jax 
import jax.numpy as jnp
import jax.random as jr
import optax
import numpy as np
from hilbert_toolkit import hilbert_haar # not the same high-accuracy hilbert transform 

In [None]:
# Haiku-based NN used to learn profiles for Omega and U
# 3 hidden layers with tanh activation functions, and an ELU as the final activation
def nnet(z : jax.Array) -> jax.Array: 
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(20), jax.nn.tanh,
        hk.Linear(20), jax.nn.tanh,
        hk.Linear(20), jax.nn.tanh,
        hk.Linear(1), jax.nn.elu
    ])
    return mlp(z)

# defining the output function q for the DG equations
def q_DG(z, lambda_val):
    # accept Python floats as well as arrays
    z = jnp.asarray(z)
    if z.ndim == 0:
        z = jnp.expand_dims(z, axis=0)

    # odd part of the network output, multiplied by the decay envelope
    nn_z = nnet(z)
    nn_neg_z = nnet(-z)
    q = ((nn_z - nn_neg_z) / 2) * (1+z**2)**(-1/(2*(1+lambda_val)))
    return jnp.squeeze(q)

# making q JAX-compatible
q_DG_jax = hk.transform(q_DG)
# autodiff gradients
dq_dz = jax.grad(q_DG_jax.apply, argnums=2)
d3q_dz3 = jax.grad(jax.grad(jax.grad(q_DG_jax.apply, argnums=2), argnums=2), argnums=2)

# vmappings
in_axes = (None, None, 0, None)
q_DG_vmap = jax.vmap(q_DG_jax.apply, in_axes=in_axes)
dq_dz_vmap = jax.vmap(dq_dz, in_axes=in_axes)
d3q_dz3_vmap = jax.vmap(d3q_dz3, in_axes=in_axes)

In [103]:
# Custom high-accuracy Hilbert transform
import jax
import jax.numpy as jnp
from functools import partial

# composite Simpson weights (piecewise 2nd order Lagrange interpolation)
def _simpson_weights(n_points):
    if n_points % 2 == 0:
        raise ValueError("Simpson's rule requires an odd number of grid points.")
    
    # [1, 4, 2, 4, ..., 2, 4, 1]
    weights = jnp.ones(n_points)
    weights = weights.at[1::2].set(4.0)
    weights = weights.at[2:-1:2].set(2.0)
    return weights

def jax_simps(y, x):
    n_points = y.shape[0]
    h = (x[-1] - x[0]) / (n_points - 1)
    weights = _simpson_weights(n_points)
    return (h / 3.0) * jnp.dot(weights, y)

def make_Hn(q_vmap, q, s_grid):    
    # Get the domain limits L from the grid
    L = s_grid[-1]
    
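    # hilbert transform evaluated at a single z, using singularity subtraction:
    # P.V. int w(s)/(z-s) ds = int (w(s)-w(z))/(z-s) ds + w(z) * P.V. int ds/(z-s)
    # NOTE: the quadrature runs over s_grid directly; no Jacobian for y = sinh(z) is applied
    def _hilbert_fn_internal(params, rng, z, lambda_val):
        omega_at_z = q(params, rng, z, lambda_val)
        omega_on_grid = q_vmap(params, rng, s_grid, lambda_val)
        
        # This integrand is smooth at s=z; guard the 0/0 case when z lands exactly
        # on a grid node (that node's contribution is set to zero)
        diff = z - s_grid
        on_node = diff == 0
        safe_diff = jnp.where(on_node, 1.0, diff)
        integrand = jnp.where(on_node, 0.0, (omega_on_grid - omega_at_z) / safe_diff)
        integral_part = jax_simps(integrand, s_grid)
        
        # P.V. integral of 1/(z-s) from -L to L is log(|(L+z)/(L-z)|)
        # Small epsilon for numerical stability if z == L or z == -L
        epsilon = 1e-10
        log_term = jnp.log(jnp.abs((L + z + epsilon) / (L - z + epsilon)))
        analytical_part = omega_at_z * log_term
        
        return (integral_part + analytical_part) / jnp.pi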

    return _hilbert_fn_internal

In [None]:
# Loss functions

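def conditional_loss(Omega_p, rng, lambda_val):
      # normalization condition Omega(0.5) = -0.05 (g_2 in the text);
      # the derivative condition g_1 is not included in this function
      g1 = (q_DG_jax.apply(Omega_p, rng, jnp.array(0.5), lambda_val) + 0.05)**2
      # uniform sampling of far-field points where Omega should decay (g_3 in the text);
      # separate keys so the two tails are sampled independently
      key_pos, key_neg = jr.split(rng)
      bd_pts = jnp.concatenate([jr.uniform(key_pos, shape=(10,), minval=29, maxval=30), 
                              jr.uniform(key_neg, shape=(10,), minval=-30, maxval=-29)])
      g2 = 1/20*jnp.sum(q_DG_vmap(Omega_p, rng, bd_pts, lambda_val)**2)

      return 1/2*(g1+g2)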

# first equation residue
def f1(Omega_p, U_p, rng, z, lambda_val):
      Omega_z = q_DG_jax.apply(Omega_p, rng, z, lambda_val)
      dOmega_dz = dq_dz(Omega_p, rng, z, lambda_val)
      U_z = q_DG_jax.apply(U_p, rng, z, lambda_val)
      dU_dz = dq_dz(U_p, rng, z, lambda_val)

      # sinh and cosh live in jax.numpy, not jax.nn
      return (Omega_z + ((1+lambda_val)*jnp.sinh(z) - U_z)
              *(1/jnp.cosh(z))*dOmega_dz - Omega_z*(1/jnp.cosh(z))*dU_dz)
# taking the 1st and 3rd derivative terms for the smoothness functions 
df1_dz = jax.grad(f1, argnums=3)
d3f1_dz3 = jax.grad(jax.grad(jax.grad(f1, argnums=3), argnums=3), argnums=3)

# vmapping
f1_vmap = jax.vmap(f1, in_axes=(None, None, None, 0, None))
df1_dz_vmap = jax.vmap(df1_dz, in_axes=(None, None, None, 0, None))
d3f1_dz3_vmap = jax.vmap(d3f1_dz3, in_axes=(None, None, None, 0, None))

# second equation residue
# init the custom hilbert transform
s_grid = jnp.linspace(-30,30,50001)
Hn = make_Hn(q_DG_vmap, q_DG_jax.apply, s_grid)

def f2(Omega_p, U_p, rng, z, lambda_val):
      # dq_dz is already a plain function (jax.grad of the apply function)
      dU_dz = dq_dz(U_p, rng, z, lambda_val)
      return (1/jnp.cosh(z))*dU_dz-Hn(Omega_p, rng, z, lambda_val)
# taking the 1st and 3rd derivative terms for the smoothness functions 
df2_dz = jax.grad(f2, argnums=3)
d3f2_dz3 = jax.grad(jax.grad(jax.grad(f2, argnums=3), argnums=3), argnums=3)

# vmapping
f2_vmap = jax.vmap(f2, in_axes=(None, None, None, 0, None))
df2_dz_vmap = jax.vmap(df2_dz, in_axes=(None, None, None, 0, None))
d3f2_dz3_vmap = jax.vmap(d3f2_dz3, in_axes=(None, None, None, 0, None))

def equation_loss(Omega_p, U_p, rng, lambda_val):
      # independent keys for each randomized end point
      key1_lo, key1_hi, key2_lo, key2_hi = jr.split(rng, 4)

      # first equation condition: random end points, uniform grid in between
      start_end = jnp.concatenate([jr.uniform(key1_lo, shape=(1,), minval=-30, maxval=-20),
                                   jr.uniform(key1_hi, shape=(1,), minval=20, maxval=30)])
      colloc_pts_1 = jnp.linspace(start_end[0], start_end[1], 10000)
      # mean squared residual (named f1_loss to avoid shadowing the function f1)
      f1_loss = jnp.mean(f1_vmap(Omega_p, U_p, rng, colloc_pts_1, lambda_val)**2)

      # second equation condition
      start_end = jnp.concatenate([jr.uniform(key2_lo, shape=(1,), minval=-30, maxval=-29),
                                   jr.uniform(key2_hi, shape=(1,), minval=29, maxval=30)])
      colloc_pts_2 = jnp.linspace(start_end[0], start_end[1], 10000)
      f2_loss = jnp.mean(f2_vmap(Omega_p, U_p, rng, colloc_pts_2, lambda_val)**2)

      return 1/2*(f1_loss+f2_loss)

def smoothness_loss(Omega_p, U_p, rng, lambda_val):
      # jr.uniform expects the shape as a tuple
      colloc_pts = jr.uniform(rng, shape=(80,), minval=-1, maxval=1)
      # first derivatives are used here to find the first smooth lambda value;
      # d3f1_dz3_vmap / d3f2_dz3_vmap give the third-order term described in the text
      f1s = 1/colloc_pts.shape[0]*jnp.sum(jnp.abs(df1_dz_vmap(Omega_p, U_p, rng, colloc_pts, lambda_val))**2)
      f2s = 1/colloc_pts.shape[0]*jnp.sum(jnp.abs(df2_dz_vmap(Omega_p, U_p, rng, colloc_pts, lambda_val))**2)
      return 1/2*(f1s+f2s)

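# TODO - move colloc pt generation outside of loss functions?
def total_loss(Omega_p, U_p, rng, lambda_val):
      # conditional_loss only depends on the Omega network
      return (conditional_loss(Omega_p, rng, lambda_val) 
              + equation_loss(Omega_p, U_p, rng, lambda_val) 
              + smoothness_loss(Omega_p, U_p, rng, lambda_val))

# jax.jit wraps a function, not a value, so the whole loss is compiled here
total_loss_jit = jax.jit(total_loss)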

In [None]:
# Training loop 


In [90]:
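rng = jax.random.PRNGKey(42)
z_in = jnp.array([0.5, 0.25, 1.0])
lambda_in = jnp.array(1.0)
# separate keys so the Omega and U networks do not start from identical weights
key_omega, key_u = jr.split(rng)
Omega_p = q_DG_jax.init(key_omega, jnp.array([0.5]), lambda_in)
U_p = q_DG_jax.init(key_u, jnp.array([0.75]), lambda_in)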

# first equation condition
colloc_pts_1 = jr.uniform(rng, shape=(10000,), minval=-10, maxval=10)
Omega_z = q_DG_vmap(Omega_p, rng, colloc_pts_1, lambda_in)
dOmega_dz = dq_dz_vmap(Omega_p, rng, colloc_pts_1, lambda_in)

U_z = q_DG_vmap(U_p, rng, colloc_pts_1, lambda_in)
dU_dz = dq_dz_vmap(U_p, rng, colloc_pts_1, lambda_in)

# residual arrays are named *_res so they do not overwrite the functions f1 and f2
f1_res = (Omega_z + ((1+lambda_in)*jnp.sinh(colloc_pts_1) - U_z)*(1/jnp.cosh(colloc_pts_1))*dOmega_dz
        - Omega_z*(1/jnp.cosh(colloc_pts_1))*dU_dz)

# second equation condition
key_lo, key_hi = jr.split(rng)
start_end = jnp.concatenate([jr.uniform(key_lo, shape=(1,), minval=-30, maxval=-29), jr.uniform(key_hi, shape=(1,), minval=29, maxval=30)])
colloc_pts_2 = jnp.linspace(start_end[0], start_end[1], 5001)

# eval for second colloc
Omega_z2 = q_DG_vmap(Omega_p, rng, colloc_pts_2, lambda_in)
dU_dz2 = dq_dz_vmap(U_p, rng, colloc_pts_2, lambda_in)

s_grid = jnp.linspace(-30,30,50001)
# make_Hn takes (q_vmap, q, s_grid); lambda is passed when the transform is evaluated
Hn = make_Hn(q_DG_vmap, q_DG_jax.apply, s_grid)
Hn_vmap = jax.vmap(Hn, in_axes=(None, None, 0, None))

f2_res = (1/jnp.cosh(colloc_pts_2))*dU_dz2-Hn_vmap(Omega_p, rng, colloc_pts_2, lambda_in)