## Lou-Hou Solutions

We are looking for solutions of the profile chosen by Lou. Hou (2014). Starting with a single, purely rotating eddy in a cylinder, they demonstrated evidence of a singularity forming in finite-time. The initial profile required an odd function in the z-direction of the cylinder.

They claimed that as the simulation progressed, the singularity will become asymptotically self-similar.

## Generalized De Gregorio Equations

The De Gregorio equations, specifiaclly the derived CCF equation, has been shown to have a stable singularities in finite time. The generalized De Gregorio equations are:

$$\omega_t + a u \omega_x = \omega u_x$$

Where $\omega$ is the vorticity and $u$ is the velocity field. We will be analyzing a particular form of the De Gregorio equation, with $a=-1$. This is also known as the CCF equation. 

Using the self-similar ansatz for singularities from Wang et. al. (2023):

$$\omega(x,t) = \frac{1}{1-t}\Omega(\frac{x}{(1-t)^{1+\lambda}})$$

The equation is parametrized by $\lambda$. Then if we define the velocity $u = \int_0^y H\Omega ds$, and the change of coordinates $y=\frac{x}{(1-t)^{1+\lambda}}$, the De Gregorio equation becomes:

$$\Omega + ((1+\lambda)y-u)\frac{\partial \Omega}{\partial y}-\Omega\frac{\partial u}{\partial y}=0$$

Where $a=-1$. 

### Loss Function

The loss function is composed of a *condition loss* and a *equation loss* to track the residues of the Boundary Conditions and governing equations respectively. We implicitly impose the odd condition on the function, and the decay to infinity as:

$$q = (\frac{NN_q(z) - NN_q(-z)}{2})\cdot (1+z^2)^{-1/(2(1+\lambda))}$$

Where $q(z)\in\{u(z), \Omega(z)\}$.

To force the NN away from the trivial solution, we impose normalization conditions:

$$g_1 = \partial_y \Omega(0)+2$$
$$g_2 = \Omega(0.5) + 0.05$$
$$g_3 = \sum_{y\in Y_\infty} \Omega(y)^2$$

Where $g_1, g_2$ normalize the solution away from zero, and $g_3$ guides the function to decay far away from the origin.

Due to the nonlocality of the Hilbert Tranform, the De Gregorio equations need to be solved in a large domain. Therefore we define a new $z$-coordinate with the relation:

$$y=\sinh(z) \iff z=\sinh^{-1}(y)$$

In practice, we sample in the range $d\in[30,30]$.

The equation losses become:

$$f_1 = \Omega(z) + ((1+\lambda)\sinh(z)+au(z)) \cosh^{-1}(z)\partial_z[\Omega(z)] - \Omega(z) \cosh^{-1}(z)\partial_z[u(z)]$$

$$f_2 = \cosh^{-1}(z)\partial_z[u(z)]-H_n[\Omega(z)]$$

Where $f_2$ was derived from the definition $u_y=H\Omega$, $H_n$ is the numerical Hilbert transform. 

For the CCF equations in particular, in order to avoid optimizing for local maxima, we add an additional smoothness constraint in the form of a 3rd order loss term:

$$loss_s = \frac{1}{N_s}\sum_{i=1}^{N_s}|\frac{d^3}{dy^3}(y_i,\hat q(y_i))|^2$$

The final loss function takes the form:

$$J(y) = \hat c_s(\frac{1}{n_b}\sum_{j=1}^{n_b}loss^{(j)}_g) + \hat c_e(\frac{1}{n_e}\sum_{k=1}^{n_e}loss^{(k)}_f) + \hat c_s (\frac{1}{n_e}\sum_{k=1}^{n_e}loss^{(k)}_s)$$

Where $n_b = 3$, $n_e=2$ are the total number of solution conditions and governing equations used.  

## Network Architecture

The solutions for $u, \Omega$ in the CCF equations are solved using a fully-connected neural network with 6 hidden layers and 30 units per hidden layer. $\tanh(x)$ is used as the activation function. The final layer is an exponential activation function.

In [1]:
import haiku as hk
import jax 
import jax.numpy as jnp
import optax
import numpy as np

In [None]:
# Haiku-based NN used to learn profiles for Omega and U
# 6 hidden layers with tanh activation functions, and an ELU as the final activation
def nnet(z : jax.Array) -> jax.Array: 
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(30), jax.nn.tanh,
        hk.Linear(1), jax.nn.elu
    ])
    return mlp(z)

# defining the output function q for the DG equations
def q_DG(z, lambda_val):
    if z.ndim == 0:
        z = jnp.expand_dims(z, axis=0)

    # evaluating the function
    nn_z = nnet(z)
    nn_neg_z = nnet(-z)
    q = (nn_z - nn_neg_z) / 2*(1+z**2)**(-1/(2*(1+lambda_val)))
    return jnp.squeeze(q)
# making q JAX-compatible
q_DG_jax = hk.transform(q_DG)

# autodiff gradient function wrt z 
def dq_dz(params,rng):
    # fixing params and rng
    apply_fn_with_params = lambda z_arg, lambda_arg: q_DG_jax.apply(params, rng, z_arg, lambda_arg)
    return jax.grad(apply_fn_with_params, argnums=0)

# hilbert transform found on the internet
# def Hn(f, grid, x, hilb_grid, h):
#     eval_pt = (x - hilb_grid * h) / h
#     return jnp.sum(jnp.interp(hilb_grid * h, grid, f) * jnp.sinc(eval_pt / 2) * jnp.sin(eval_pt / 2))

# accurate numerical hilbert transform


# defining the loss terms
def loss(Omega_p, U_p, rng, z, lambda_val):
    # conditional losses
    # values of z where the NNs should decay to zero
    Z_inf = [-30, -29, -28, 28, 29, 30]

    # normalization
    g1 = q_DG_jax.apply(Omega_p, rng, 0.5, lambda_val) + 0.05
    # decay at infinity
    g2 = jnp.sum([q_DG_jax.apply(Omega_p, rng, zinf, lambda_val)**2 for zinf in Z_inf])

    # equation losses
    Omega_z = q_DG_jax.apply(Omega_p, rng, z, lambda_val)
    dOmega_dz = dq_dz(Omega_p, rng)
    dOmega_dz_z = dOmega_dz_z(z, lambda_val)

    U_z = q_DG_jax.apply(U_p, rng, z, lambda_val)
    dU_dz = dq_dz(U_p, rng)
    dU_dz_z = dU_dz(z, lambda_val)
    
    f1 = (Omega_z + ((1+lambda_val)*jnp.sinh(z) - U_z)*(1/jnp.cosh(z))*dOmega_dz
          - Omega_z*(1/jnp.cosh(z))*dU_dz_z)
    f2 = (1/jnp.cosh(z))*dU_dz_z-Hn(Omega_z)

In [13]:
rng = jax.random.PRNGKey(42)
z_in = jnp.array(0.5)
lambda_in = jnp.array(1.0)
params = q_DG_jax.init(rng, z_in, lambda_in)
dqdz = dq_dz(params, rng)
dqdz(z_in, lambda_in)

Array(0.01195263, dtype=float32)

In [18]:
jnp.sech(0.5)

AttributeError: module 'jax.numpy' has no attribute 'sech'

In [17]:
q_DG_jax.apply(params, rng, z_in, lambda_in)

Array(-0.03434513, dtype=float32)

In [None]:
"""
JAX + Haiku starter for the Córdoba–Córdoba–Fontelos (CCF) profile search.

Features included:
  - Haiku MLP implementation (convertible from Flax) with simple mean-zero enforcement.
  - Spectral derivative and periodic Hilbert transform implemented via FFT.
  - Two-stage precision strategy:
      Stage A: pretrain with Adam in float32 (or float64 if jax_enable_x64=True).
      Stage B: cast parameters to float64 and run dense Gauss–Newton refinement in float64.
    (Comment explains how to go beyond float64 using MPFR/mpmath for critical GN steps.)
  - Adam pretraining + full-Jacobian Gauss–Newton refinement using jax.jacrev.

Notes:
  - JAX support for float128 is limited on most backends. For >64-bit you must use
    an external multiprecision library (mpmath / mpfr) to evaluate residuals/Jacobians
    outside JAX and then solve the linear system in high precision. I include notes
    and a small helper showing how to export the residual function for that purpose.
  - Dense GN is expensive; for large networks switch to Jacobian-vector products and CG.

Run: python jax_ccf_haiku.py

Dependencies:
  - jax, jaxlib
  - dm-haiku
  - optax
  - numpy

"""

import jax
import jax.numpy as jnp
from jax import random, jit
from jax.flatten_util import ravel_pytree
from functools import partial

# Haiku imports
import haiku as hk
import optax

# Optionally enable 64-bit for Stage B (do this before creating arrays if you want x64 everywhere)
# jax.config.update('jax_enable_x64', True)

# ----------------------
# Domain
# ----------------------

N = 512
x = jnp.linspace(-0.5, 0.5, N, endpoint=False)
dx = x[1] - x[0]

rng = random.PRNGKey(42)

# ----------------------
# Haiku MLP
# ----------------------

def mlp_fn(x, hidden=(128,128,128)):
    h = x
    for size in hidden:
        h = hk.Linear(size)(h)
        h = jax.nn.relu(h)
    h = hk.Linear(1)(h)
    return jnp.squeeze(h, -1)

net = hk.without_apply_rng(hk.transform(mlp_fn))

# wrapper to evaluate network on grid and enforce mean-zero

def net_forward(params, x_grid):
    xin = x_grid.reshape(-1,1)
    y = net.apply(params, xin)
    # enforce mean-zero (common for CCF)
    y = y - jnp.mean(y)
    return y

# ----------------------
# Profile residual (CCF self-similar, kappa=1)
# R = -Phi - x Phi' - (H Phi) Phi'
# ----------------------

@jit
def profile_residual_haiku(flat_params, unravel_fn, x_grid):
    params = unravel_fn(flat_params)
    Phi = net_forward(params, x_grid)
    Phi_x = spectral_derivative(Phi)
    HPhi = hilbert_transform(Phi)
    R = -Phi - x_grid * Phi_x - (HPhi) * Phi_x
    return R

# ----------------------
# Loss
# ----------------------

def loss_and_residual_haiku(flat_params, unravel_fn, x_grid):
    R = profile_residual_haiku(flat_params, unravel_fn, x_grid)
    loss = 0.5 * jnp.mean(R**2)
    return loss, R

# ----------------------
# Initialize params
# ----------------------

# init in default precision (float32 unless jax_enable_x64 True)
init_rng, rng = random.split(rng)
dummy_in = jnp.ones((1,1))
init_params = net.init(init_rng, dummy_in)
flat_init, unravel_fn = ravel_pytree(init_params)

# ----------------------
# Stage A: Adam pretraining (lower precision / faster)
# ----------------------

def adam_pretrain(params, x_grid, n_steps=2000, lr=1e-3):
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    @jit
    def step(params, opt_state):
        (loss, R), grads = jax.value_and_grad(loss_and_residual_haiku, has_aux=True)(params, unravel_fn, x_grid)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    p = params
    for i in range(n_steps):
        p, opt_state, loss = step(p, opt_state)
        if i % 200 == 0:
            print(f'[Stage A] Adam step {i:5d}  loss={loss:.3e}')
    return p

# ----------------------
# Stage B: Cast to high precision (float64) and run Gauss-Newton
# ----------------------

# Helper to cast flat params to target dtype

def cast_flat_params(flat_params, dtype):
    return jax.tree_map(lambda a: a.astype(dtype), flat_params)

@partial(jit, static_argnums=(3,))
def gn_step_haiku(flat_params, unravel_fn, x_grid, damping=1e-8):
    # compute residual and Jacobian J (N x P)
    R = profile_residual_haiku(flat_params, unravel_fn, x_grid)
    J = jax.jacrev(lambda fp: profile_residual_haiku(fp, unravel_fn, x_grid))(flat_params)

    JTJ = J.T @ J
    JTr = J.T @ R
    dim = JTJ.shape[0]
    lhs = JTJ + jnp.eye(dim) * damping
    delta = -jnp.linalg.solve(lhs, JTr)
    new_flat = flat_params + delta
    return new_flat, delta, R

# Top-level driver showing the two-stage flow

def two_stage_training(x_grid, n_adam=1000, n_gn=6):
    # Stage A: Adam pretrain in default precision
    print('
=== Stage A: Adam pretraining (fast, lower precision) ===')
    flat_params, _ = ravel_pytree(init_params)
    flat_params = adam_pretrain(flat_params, x_grid, n_steps=n_adam)

    # Stage B: enable x64 (do this before casting if necessary). If you want to enable 64-bit
    # globally, uncomment the config line at top of file. Otherwise cast params here.
    print('
=== Stage B: Gauss-Newton refinement (float64) ===')
    # Cast to float64 for refinement
    flat_params64 = cast_flat_params(flat_params, jnp.float64)

    for i in range(n_gn):
        flat_params64, delta, R = gn_step_haiku(flat_params64, unravel_fn, x, damping=1e-12)
        res_norm = jnp.linalg.norm(R)
        delta_norm = jnp.linalg.norm(delta)
        print(f'[Stage B] GN iter {i:3d}  res_norm={res_norm:.3e}  delta_norm={delta_norm:.3e}')

    return flat_params64

# ----------------------
# Notes on >64-bit (MPFR / mpmath)
# ----------------------
# If you require >64-bit precision for the GN linear solve, a practical approach is:
# 1. Export the residual function as a Python function that evaluates R at a given parameter
#    vector using high-precision arithmetic (mpmath). This requires re-implementing spectral
#    FFT/Hilbert in mpmath or evaluating the Fourier series in high precision.
# 2. Compute the Jacobian with finite differences in high precision (or use automatic
#    differentiation in mpmath-like frameworks) and form the normal equations in high
#    precision, then solve for delta using mpmath's linear solver.
# This is slower but robust; include it only in the final refinement stages.

# ----------------------
# If run as script
# ----------------------
if __name__ == '__main__':
    # Use full grid as collocation
    x_grid = x
    # Stage A + B
    final_flat = two_stage_training(x_grid, n_adam=800, n_gn=4)
    final_params = unravel_fn(final_flat)
    Phi = net_forward(final_params, x)
    print('
Sample Phi[0:8]=', Phi[:8])
    R_final = profile_residual_haiku(final_flat, unravel_fn, x)
    print('Final mean |R| =', jnp.mean(jnp.abs(R_final)))

# ----------------------
# Converting back to Flax or to other refinements:
# - Replace dense GN with matrix-free Jvps and CG for large param counts.
# - Add spectral dealiasing and filtering to stabilize high-wave-number components.
# - Use checkpointing / parameter partitioning if memory is a bottleneck during Jacobian assembly.
# ----------------------
