# Neutral brush cDFT

In [None]:
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

try:  # NOTE(review): solve_banded is not provided by current jax.scipy.linalg releases
    from jax.scipy.linalg import solve_banded
except ImportError:
    solve_banded = None

In [None]:
def setup_grid(z_points: jnp.ndarray):
    """Return the grid nodes together with their spacing arrays.

    Args:
        z_points: 1-D array of node positions.

    Returns:
        A 4-tuple ``(z, dz, dzm, dzp)`` where ``z`` is ``z_points`` itself,
        ``dz[i] = z[i + 1] - z[i]``, and ``dzm = dz[:-1]`` / ``dzp = dz[1:]``
        are the left / right spacings of the interior nodes ``z[1:-1]``.
    """
    spacing = jnp.diff(z_points)
    left_spacing, right_spacing = spacing[:-1], spacing[1:]
    return z_points, spacing, left_spacing, right_spacing

In [None]:
def crank_nicolson_step(q, w, dzm, dzp, b, dt):
    """Advance the propagator ``q`` by one Crank-Nicolson step of size ``dt``.

    Discretizes ``dq/ds = (b**2 / 6) * d2q/dz2 - w(z) * q`` on a possibly
    non-uniform grid using the three-point second-derivative stencil::

        q''_i ~ 2 * [  q_{i-1} / (h_- * (h_- + h_+))
                     - q_i     / (h_- * h_+)
                     + q_{i+1} / (h_+ * (h_- + h_+)) ]

    and solves ``(I - dt/2 L) q_new = (I + dt/2 L) q``. Both ends use a
    zero-flux (Neumann) condition imposed with a mirrored ghost node
    (``q_{-1} = q_1`` and ``q_N = q_{N-2}``).

    Args:
        q: Propagator values at the N grid nodes (1-D).
        w: Field values at the N grid nodes (1-D, or anything broadcastable
            to ``q``'s shape).
        dzm: Left spacings of the interior nodes, ``dz[:-1]`` (length N - 2),
            as returned by ``setup_grid``.
        dzp: Right spacings of the interior nodes, ``dz[1:]`` (length N - 2).
        b: Length scale of the diffusion term; the diffusion coefficient is
            ``b**2 / 6`` (presumably the statistical segment length).
        dt: Step size along the contour variable.

    Returns:
        The propagator after one step, shape (N,).

    Raises:
        ValueError: If ``q`` is not 1-D with at least 3 nodes, or if ``dzm`` /
            ``dzp`` do not have length N - 2.
    """
    q = jnp.asarray(q)
    if q.ndim != 1 or q.shape[0] < 3:
        raise ValueError("crank_nicolson_step needs a 1-D grid with at least 3 nodes")
    n = q.shape[0]
    dzm = jnp.asarray(dzm)
    dzp = jnp.asarray(dzp)
    if dzm.shape != (n - 2,) or dzp.shape != (n - 2,):
        raise ValueError("dzm and dzp must each have length len(q) - 2")
    w = jnp.broadcast_to(jnp.asarray(w), q.shape)

    # Crank-Nicolson applies dt/2 of the operator on each side, which cancels
    # the factor 2 of the stencil: every diffusion weight is (b^2/6)*dt/(...).
    c = (b ** 2 / 6) * dt

    # Left/right spacing at every node. The wall nodes mirror their only
    # neighbour spacing (ghost node), which imposes zero flux.
    h_left = jnp.concatenate([dzm[:1], dzm, dzp[-1:]])
    h_right = jnp.concatenate([dzm[:1], dzp, dzp[-1:]])
    h_sum = h_left + h_right
    c_left = c / (h_left * h_sum)    # weight of q_{i-1}
    c_right = c / (h_right * h_sum)  # weight of q_{i+1}
    c_diag = c_left + c_right        # equals c / (h_left * h_right)

    # Off-diagonal couplings; at each wall the ghost node's weight is folded
    # into the single real neighbour (q_{-1} = q_1, q_N = q_{N-2}).
    sub = c_left[1:].at[-1].set(c_diag[-1])  # row i -> column i-1, i = 1..N-1
    sup = c_right[:-1].at[0].set(c_diag[0])  # row i -> column i+1, i = 0..N-2

    half_w = 0.5 * dt * w

    # Right-hand side: (I + dt/2 L) q.
    rhs = (1.0 - c_diag - half_w) * q
    rhs = rhs.at[1:].add(sub * q[:-1])
    rhs = rhs.at[:-1].add(sup * q[1:])

    # Left-hand side: (I - dt/2 L) as full-length diagonals; tridiagonal_solve
    # expects dl[0] == 0 and du[-1] == 0.
    diag = 1.0 + c_diag + half_w
    zero = jnp.zeros(1, dtype=sub.dtype)
    dl = jnp.concatenate([zero, -sub])
    du = jnp.concatenate([-sup, zero])

    # tridiagonal_solve requires a common dtype and a 2-D right-hand side.
    dtype = jnp.result_type(dl, diag, du, rhs)
    q_new = tridiagonal_solve(
        dl.astype(dtype),
        diag.astype(dtype),
        du.astype(dtype),
        rhs.astype(dtype)[:, None],
    )
    return q_new[:, 0]