In [None]:
import numpy as np
from scipy.linalg import solve_banded


def _upwind_divergence(mu, f, axis, h, bc):
    """Return the explicit first-order upwind divergence ``-(F_+ - F_-)/h``.

    Face velocities are the arithmetic mean of the two neighbouring cell
    values along ``axis``; the advected value is taken from the upwind cell.

    Parameters
    ----------
    mu : ndarray
        Cell-centred drift along ``axis`` (same shape as ``f``).
    f : ndarray
        Cell-centred density.
    axis : int
        Axis along which the flux difference is taken.
    h : float
        Grid spacing along ``axis``.
    bc : str
        ``'open'``   -> only outflow is allowed through the two boundary faces.
        ``'noflux'`` -> the two boundary faces carry zero flux.
        Any other value leaves the wrap-around fluxes produced by ``np.roll``
        in place (i.e. the axis is treated as periodic).
    """
    # Flux through the "+" face (between index k and k+1 along `axis`).
    face_plus = 0.5 * (mu + np.roll(mu, -1, axis=axis))
    flux_plus = np.where(face_plus >= 0,
                         face_plus * f,
                         face_plus * np.roll(f, -1, axis=axis))

    # Flux through the "-" face (between index k-1 and k along `axis`).
    face_minus = 0.5 * (np.roll(mu, 1, axis=axis) + mu)
    flux_minus = np.where(face_minus >= 0,
                          face_minus * np.roll(f, 1, axis=axis),
                          face_minus * f)

    # Index tuples selecting the first / last slab along `axis`.
    first = [slice(None)] * f.ndim
    last = [slice(None)] * f.ndim
    first[axis] = 0
    last[axis] = -1
    first = tuple(first)
    last = tuple(last)

    if bc == "open":
        # Only let mass leave: negative drift at the low edge,
        # positive drift at the high edge.
        flux_minus[first] = np.minimum(0, mu[first]) * f[first]
        flux_plus[last] = np.maximum(0, mu[last]) * f[last]
    elif bc == "noflux":
        flux_minus[first] = 0.0
        flux_plus[last] = 0.0

    return -(flux_plus - flux_minus) / h


def fokker_planck_step_vectorized_implicit_phi(
        f, X, Z, PHI, t, dt, dx, dz, dphi,
        bc_x='open', bc_z='noflux', bc_phi='noflux'  # periodic not handled here
    ):
    """
    Advance ``f`` by one IMEX Fokker–Planck step.

      - x,z directions: explicit first-order upwind advection
      - phi direction: implicit (backward Euler) Chang–Cooper
        advection–diffusion, one tridiagonal solve per (i, j)

    Parameters
    ----------
    f : ndarray, shape (Nx, Nz, Nphi)
        Density at time ``t``; axis 0 is x, axis 1 is z, axis 2 is phi.
    X, Z, PHI : ndarray
        Coordinate arrays; only forwarded to ``mu_x_vec``, ``mu_z_vec``,
        ``mu_phi_vec`` and ``diffusion_vec`` (defined elsewhere).
    t : float
        Current time, forwarded to the coefficient functions.
    dt : float
        Time step.
    dx, dz, dphi : float
        Grid spacings along x, z and phi.
    bc_x, bc_z : str
        ``'open'`` (outflow only) or ``'noflux'`` (zero boundary flux).
        Any other value leaves the periodic wrap-around of ``np.roll``.
    bc_phi : str
        ``'noflux'`` sets the Chang–Cooper flux through the first and last
        phi faces to zero, so the implicit phi step conserves
        ``sum(f, axis=2)``.  Any other value keeps the boundary-face terms on
        the main diagonal without a neighbour coupling (equivalent to a zero
        ghost cell); periodic coupling in phi is not implemented.

    Returns
    -------
    ndarray, shape (Nx, Nz, Nphi)
        Updated density, with negative entries clipped to zero.
    """
    Nx, Nz, Nphi = f.shape

    # ===================== drift & diffusion ======================
    MU_X   = mu_x_vec(X, Z, PHI, t)
    MU_Z   = mu_z_vec(X, Z, PHI, t)
    MU_PHI = mu_phi_vec(X, Z, PHI, t)
    D      = diffusion_vec(X, Z, PHI, t)

    # ===================== explicit x,z update ====================
    div_x = _upwind_divergence(MU_X, f, 0, dx, bc_x)
    div_z = _upwind_divergence(MU_Z, f, 1, dz, bc_z)
    f_star = f + dt * (div_x + div_z)

    # ==============================================================
    #              IMPLICIT Chang–Cooper in phi
    # ==============================================================

    # Face-averaged drift and diffusion at k+1/2 (_p) and k-1/2 (_m).
    MU_p = 0.5 * (MU_PHI + np.roll(MU_PHI, -1, axis=2))
    MU_m = 0.5 * (np.roll(MU_PHI, 1, axis=2) + MU_PHI)

    D_p = 0.5 * (D + np.roll(D, -1, axis=2))
    D_m = 0.5 * (np.roll(D, 1, axis=2) + D)

    # Cell Peclet numbers; the tiny offset avoids division by zero when D == 0.
    Pe_p = MU_p * dphi / (D_p + 1e-16)
    Pe_m = MU_m * dphi / (D_m + 1e-16)

    delta_p = delta_cc_vec(Pe_p)
    delta_m = delta_cc_vec(Pe_m)

    # Chang–Cooper face fluxes:
    #   J_{k+1/2} = c_p_c * f_k     + c_p_l * f_{k+1}
    #   J_{k-1/2} = c_m_l * f_{k-1} + c_m_c * f_k
    c_p_l = MU_p * (1 - delta_p) - D_p / dphi
    c_p_c = MU_p * delta_p       + D_p / dphi

    c_m_l = MU_m * delta_m       + D_m / dphi
    c_m_c = MU_m * (1 - delta_m) - D_m / dphi

    if bc_phi == 'noflux':
        # Zero flux through the outermost faces: J_{-1/2} = J_{Nphi-1/2} = 0.
        # Both the neighbour coefficient AND the centre coefficient of those
        # faces must vanish; otherwise the main diagonal keeps a boundary
        # flux term (built from np.roll wrap-around values) and mass leaks.
        c_m_l[:, :, 0] = 0.0
        c_m_c[:, :, 0] = 0.0
        c_p_l[:, :, -1] = 0.0
        c_p_c[:, :, -1] = 0.0

    # With div_phi_k = -(J_{k+1/2} - J_{k-1/2}) / dphi, backward Euler reads
    #   (I - dt L_phi) f^{n+1} = f_star
    # whose tridiagonal entries (row k) are:
    alpha = dt / dphi
    main  = 1.0 + alpha * (c_p_c - c_m_c)   # coefficient of f_k
    upper = alpha * c_p_l                   # coefficient of f_{k+1}
    lower = -alpha * c_m_l                  # coefficient of f_{k-1}

    f_new = np.empty_like(f)

    # Banded storage for solve_banded((1, 1), ...):
    #   ab[0, k] = A[k-1, k] (upper), ab[1, k] = A[k, k], ab[2, k] = A[k+1, k] (lower)
    # ab[0, 0] and ab[2, -1] are unused padding and stay zero.
    ab = np.zeros((3, Nphi))

    for i in range(Nx):
        for j in range(Nz):
            ab[0, 1:]  = upper[i, j, :-1]
            ab[1, :]   = main[i, j, :]
            ab[2, :-1] = lower[i, j, 1:]

            f_new[i, j, :] = solve_banded((1, 1), ab, f_star[i, j, :])

    # Clip small negative values (explicit upwind step / round-off).
    return np.maximum(f_new, 0.0)
