<a href="https://colab.research.google.com/github/michalshavitNYU/michalshavitnyu.github.io/blob/master/CauchyRiemann.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
import optax
import numpy as np
import matplotlib.pyplot as plt

# Initialize a simple fully connected neural network
def init_nn_params(layers, key):
    """Create weights and zero biases for a fully connected network.

    Args:
        layers: Sequence of layer widths, e.g. ``[2, 64, 64, 2]``. One weight
            matrix is created for every consecutive pair of widths.
        key: JAX PRNG key; it is split into one subkey per layer transition so
            the initialization is reproducible.

    Returns:
        A list with one ``{"W": (fan_in, fan_out), "b": (fan_out,)}`` dict per
        layer transition.
    """
    layer_keys = jax.random.split(key, len(layers) - 1)
    fan_pairs = zip(layers[:-1], layers[1:])

    net = []
    for layer_key, (fan_in, fan_out) in zip(layer_keys, fan_pairs):
        # Only the first half of this split draws the weights; the second half
        # is unused because biases start at zero.
        weight_key = jax.random.split(layer_key)[0]
        # He-style scale sqrt(2 / fan_in); dropping the factor 2 gives the
        # sqrt(1 / fan_in) variant.
        scale = jnp.sqrt(2.0 / fan_in)
        net.append({
            "W": jax.random.normal(weight_key, (fan_in, fan_out)) * scale,
            "b": jnp.zeros((fan_out,)),
        })
    return net

# Forward pass through the network #This is the forward pass of the neural network — the part where I feed input through the layers to get the output.
def nn_forward(params, xy):
    """Evaluate the tanh MLP on ``xy``.

    Args:
        params: Layer list as produced by ``init_nn_params``.
        xy: A single point of shape ``(2,)`` or a batch ``(N, 2)``; ``jnp.dot``
            contracts the last axis, so both forms work.

    Returns:
        The output of the final (linear) layer; its last axis holds ``[u, v]``.
    """
    hidden, final = params[:-1], params[-1]
    activation = xy
    for layer in hidden:
        pre_activation = jnp.dot(activation, layer["W"]) + layer["b"]
        activation = jnp.tanh(pre_activation)
    # No activation on the output layer, so u and v are unbounded.
    return jnp.dot(activation, final["W"]) + final["b"]

# Compute partial derivatives correctly #need to improve to vector gradients
def compute_derivatives(params, x, y):
    """Return the four first-order partial derivatives of (u, v) at one point.

    The full 2x2 Jacobian of ``(x, y) -> (u, v)`` comes from a single
    forward-mode ``jax.jacfwd`` call (one JVP per input coordinate). This
    replaces four separate reverse-mode ``jax.grad`` passes, each of which
    re-ran the whole network.

    Args:
        params: Layer list as produced by ``init_nn_params``.
        x: The x coordinate, either a scalar or a shape ``(1,)`` array (the
            latter is what ``jax.vmap`` passes when mapping over the rows of an
            ``(N, 1)`` column).
        y: The y coordinate, same convention as ``x``.

    Returns:
        Tuple ``(u_x, u_y, v_x, v_y)`` of scalar partial derivatives.
    """
    # Build a length-2 point from either scalars or (1,) arrays.
    xy = jnp.stack([jnp.ravel(x)[0], jnp.ravel(y)[0]])

    # jac[i, j] = d(output_i) / d(input_j), outputs (u, v), inputs (x, y).
    jac = jax.jacfwd(lambda point: nn_forward(params, point))(xy)
    return jac[0, 0], jac[0, 1], jac[1, 0], jac[1, 1]

# Loss function (Physics-Informed)
def loss_fn(params, batch, bc_weight=0.1):
    """Physics-informed loss: Cauchy-Riemann residual plus a boundary term.

    Args:
        params: Layer list as produced by ``init_nn_params``.
        batch: Collocation points of shape ``(N, 2)``; column 0 is x and
            column 1 is y.
        bc_weight: Weight of the boundary term relative to the physics term
            (0.1, as before, by default).

    Returns:
        Scalar ``mean(CR residual^2) + bc_weight * mean((u(x, 0) - cos x)^2)``.
    """
    x, y = batch[:, 0:1], batch[:, 1:2]

    # Per-point partial derivatives. vmap maps over rows, so each call
    # receives x and y as shape (1,) slices.
    u_x, u_y, v_x, v_y = jax.vmap(compute_derivatives, in_axes=(None, 0, 0))(params, x, y)

    # Cauchy-Riemann equations: u_x = v_y and u_y = -v_x.
    physics_loss = jnp.mean((u_x - v_y) ** 2 + (u_y + v_x) ** 2)

    # Boundary condition on y = 0: u(x, 0) = cos(x).
    x_bc = jnp.linspace(-jnp.pi, jnp.pi, 100)  # 1-D, shape (100,)
    xy_bc = jnp.stack([x_bc, jnp.zeros_like(x_bc)], axis=-1)  # (100, 2)
    uv_bc = nn_forward(params, xy_bc)
    # Both operands are (100,), so the residual is pointwise. Subtracting cos
    # of a (100, 1) column from the (100,) vector uv_bc[:, 0] would broadcast
    # to a (100, 100) all-pairs matrix. That term is minimized by a constant u
    # equal to mean(cos x) rather than by u(x, 0) = cos(x).
    boundary_loss = jnp.mean((uv_bc[:, 0] - jnp.cos(x_bc)) ** 2)
    # NOTE: only u is constrained on y = 0. To also pin v(x, 0) = sin(x),
    # add (uv_bc[:, 1] - jnp.sin(x_bc)) ** 2 inside the mean above.

    return physics_loss + bc_weight * boundary_loss

# Generate training data.
# Each random draw gets its own subkey. Reusing one key for x and y would make
# x_train == y_train, placing every collocation point on the diagonal y = x,
# and reusing it again for init would correlate the weights with the data.
key = jax.random.PRNGKey(42)
key_x, key_y, key_params = jax.random.split(key, 3)
x_train = jax.random.uniform(key_x, (1000, 1), minval=-jnp.pi, maxval=jnp.pi)
y_train = jax.random.uniform(key_y, (1000, 1), minval=-jnp.pi, maxval=jnp.pi)
train_data = jnp.hstack((x_train, y_train))  # (1000, 2) collocation points

# Initialize network
layers = [2, 64, 64, 64, 2]  # Input: (x,y), Output: (u,v)
params = init_nn_params(layers, key_params)

# Define Adam optimizer
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, batch):
    """Apply one optimizer update on ``batch``; return (params, opt_state, loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


# Training loop. jax.jit compiles the whole step into one XLA computation
# instead of dispatching every operation eagerly on each epoch.
epochs = 10000
for epoch in range(epochs):
    params, opt_state, loss = train_step(params, opt_state, train_data)

    if epoch % 500 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.6f}")

print("Training complete!")

# Sample the line x = 0: 100 evenly spaced y values in [-pi, pi] as a column.
y_test = jnp.linspace(-jnp.pi, jnp.pi, 100)[:, None]
x_test = jnp.zeros_like(y_test)
xy_test = jnp.concatenate([x_test, y_test], axis=1)  # (100, 2)

# Evaluate the trained network in one batched call; column 1 is v(x, y).
uv_test = nn_forward(params, xy_test)
v_test = uv_test[:, 1]

# Plot v(0, y) against y.
_, ax = plt.subplots(figsize=(6,4))
ax.plot(y_test, v_test, label='v(0,y)', color='b')
ax.set(xlabel='y', ylabel='v(0,y)', title='Imaginary Part v(0,y)')
ax.legend()
ax.grid()
plt.show()

Epoch 0, Loss: 1.246207
Epoch 500, Loss: 0.050505
Epoch 1000, Loss: 0.050495
Epoch 1500, Loss: 0.050492
Epoch 2000, Loss: 0.050491
Epoch 2500, Loss: 0.050491
Epoch 3000, Loss: 0.050490
Epoch 3500, Loss: 0.050491
Epoch 4000, Loss: 0.050490
Epoch 4500, Loss: 0.050490
Epoch 5000, Loss: 0.050490
Epoch 5500, Loss: 0.050490
Epoch 6000, Loss: 0.050492
Epoch 6500, Loss: 0.050490
Epoch 7000, Loss: 0.050490
Epoch 7500, Loss: 0.050490
Epoch 8000, Loss: 0.050490
Epoch 8500, Loss: 0.050493


In [None]:
# --- Plot v(x, 0): sample the line y = 0 with 100 x values in [-pi, pi] ---
x_test_vx0 = jnp.linspace(-jnp.pi, jnp.pi, 100)[:, None]
y_test_vx0 = jnp.zeros_like(x_test_vx0)
xy_test_vx0 = jnp.concatenate([x_test_vx0, y_test_vx0], axis=1)  # (100, 2)
uv_test_vx0 = nn_forward(params, xy_test_vx0)
v_x0 = uv_test_vx0[:, 1]  # column 1 holds v

_, ax = plt.subplots(figsize=(6,4))
ax.plot(x_test_vx0, v_x0, label='v(x, 0)', color='g')
ax.set(xlabel='x', ylabel='v(x,0)', title='Imaginary Part v(x, 0)')
ax.legend()
ax.grid()
plt.show()

In [None]:
# ---- Check Harmonicity: Laplacians Δu and Δv ----
def laplacians(params, xy):
    """Return ``(Δu, Δv)`` of the network at a single point ``xy`` of shape (2,).

    Each Laplacian is the trace of the 2x2 Hessian of one output component,
    i.e. ∂²/∂x² + ∂²/∂y². If (u, v) is analytic, both should be close to zero.
    """
    def component(index):
        # Scalar function picking output `index` (0 -> u, 1 -> v).
        return lambda point: nn_forward(params, point)[index]

    # Hessian as forward-mode Jacobian of the reverse-mode gradient.
    return tuple(
        jnp.trace(jax.jacfwd(jax.grad(component(index)))(xy))
        for index in (0, 1)
    )

# Sample a 100 x 100 grid over [-pi, pi]^2 and evaluate both Laplacians at
# every grid point in a single vmapped call over the flattened grid.
grid_x, grid_y = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, 100),
                              jnp.linspace(-jnp.pi, jnp.pi, 100))
xy_grid = jnp.column_stack((grid_x.ravel(), grid_y.ravel()))  # (10000, 2)
lap_u, lap_v = jax.vmap(laplacians, in_axes=(None, 0))(params, xy_grid)

# Side-by-side filled contours of Δu and Δv.
plt.figure(figsize=(12, 5))
for panel, (field, title) in enumerate(
        ((lap_u, "Laplacian Δu"), (lap_v, "Laplacian Δv")), start=1):
    plt.subplot(1, 2, panel)
    plt.contourf(grid_x, grid_y, field.reshape(100, 100), levels=50)
    plt.colorbar()
    plt.title(title)

plt.tight_layout()
plt.show()

# ---- 2D Field Plots of u(x,y) and v(x,y) ----
# Evaluate the network on every grid point, then fold each output column back
# into the 100 x 100 grid layout produced by meshgrid.
uv_vals = jax.vmap(nn_forward, in_axes=(None, 0))(params, xy_grid)
u_vals, v_vals = (uv_vals[:, col].reshape(100, 100) for col in (0, 1))

plt.figure(figsize=(12, 5))
for panel, (field, title) in enumerate(
        ((u_vals, "u(x, y)"), (v_vals, "v(x, y)")), start=1):
    plt.subplot(1, 2, panel)
    plt.contourf(grid_x, grid_y, field, levels=50)
    plt.colorbar()
    plt.title(title)

plt.tight_layout()
plt.show()