# Install Dependencies

In [None]:
import dolfinx
import polyhedral_net_splines as pns


--2025-10-16 20:34:17--  https://fem-on-colab.github.io/releases/fenicsx-install-release-real.sh
Resolving fem-on-colab.github.io (fem-on-colab.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to fem-on-colab.github.io (fem-on-colab.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4339 (4.2K) [application/x-sh]
Saving to: ‘/tmp/fenicsx-install.sh’


2025-10-16 20:34:17 (14.9 MB/s) - ‘/tmp/fenicsx-install.sh’ saved [4339/4339]

+ INSTALL_PREFIX=/usr/local
++ echo /usr/local
++ awk -F/ '{print NF-1}'
+ INSTALL_PREFIX_DEPTH=2
+ PROJECT_NAME=fem-on-colab
+ SHARE_PREFIX=/usr/local/share/fem-on-colab
+ FENICSX_INSTALLED=/usr/local/share/fem-on-colab/fenicsx.installed
+ [[ ! -f /usr/local/share/fem-on-colab/fenicsx.installed ]]
+ PYBIND11_INSTALL_SCRIPT_PATH=https://github.com/fem-on-colab/fem-on-colab.github.io/raw/a0823bbc/releases/pybind11-install.sh
+ [[ https://github.com/fem-on-colab/fem-on-colab.github.i

In [2]:
# @title PnS FEM
import dolfinx
import basix
import ufl
from mpi4py import MPI
from dolfinx.fem.petsc import LinearProblem
import numpy
import polyhedral_net_splines as pns
import numpy as np

"""# PNS FEM"""

import typing
from petsc4py import PETSc
import dolfinx.fem.petsc
from dolfinx.fem.petsc import assemble_matrix_mat, assemble_vector, create_matrix, _create_form, create_vector
import itertools
import functools
from mpi4py import MPI
import ufl
import numpy as np
import polyhedral_net_splines as pns
import basix
from tqdm import tqdm

import time
from functools import wraps
from contextlib import contextmanager

# Net wall-clock seconds accumulated per @track_time-decorated function name.
total_times: dict[str, float] = {}

# Seconds spent inside exclude_time() during the current tracked call. This is
# a plain module-level dict (NOT thread-local); the key is always 'default'.
_excluded_time: dict[str, float] = {}

@contextmanager
def exclude_time():
    """Exclude the wall time spent inside this ``with`` block from ``track_time``.

    The elapsed time is added to ``_excluded_time['default']``, which the
    ``track_time`` wrapper subtracts from the decorated call's duration. The
    time is recorded even if the body raises, so a tracked function that
    catches an exception from an excluded region is still timed correctly.
    """
    thread_id = 'default'  # single shared slot; not separated per thread
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        _excluded_time[thread_id] = _excluded_time.get(thread_id, 0) + elapsed

def track_time(func):
    """Decorator that adds ``func``'s net run time to ``total_times``.

    The excluded-time slot is reset at the start of every call, and whatever
    ``exclude_time()`` accumulated during the call is subtracted from the
    measured duration before it is added under ``func.__name__``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        slot = 'default'
        _excluded_time[slot] = 0
        t_begin = time.perf_counter()
        outcome = func(*args, **kwargs)
        t_finish = time.perf_counter()
        net = t_finish - t_begin - _excluded_time[slot]
        key = func.__name__
        total_times[key] = total_times.get(key, 0) + net
        return outcome

    return wrapper

# Maps the k-th control point of a 4x4 patch (visited row by row) to its node
# position inside a mesh cell; see createPnsDolfinxMesh.
cubicLagrangeDofOrder = [0, 6, 7, 2, 4, 12, 14, 10, 5, 13, 15, 11, 1, 8, 9, 3]
# Inverse permutation: cell node position -> row-major control-point index.
cubicLagrangeDofOrderInv = [cubicLagrangeDofOrder.index(node) for node in range(16)]

def assemble_prolongation_matrix(pnsObject) -> PETSc.Mat:
    """Construct the prolongation matrix K mapping PNS DOFs to Lagrange DOFs.

    K has size ``(numPatches*16*numElements, numVerts*numElements)``. Row
    blocks follow the Lagrange layout (16 nodes per patch, one block of 16 per
    element/component); columns interleave components per PNS vertex
    (``e + numElements * vertex``).

    Args:
        pnsObject: PNS object; reads 'comm', 'numPatches', 'numVerts',
            'numElements' and 'augmentedLagrangePatchBuilders'.

    Returns:
        Assembled prolongation matrix.
    """
    comm = pnsObject['comm']
    numBB = pnsObject['numPatches'] * 16
    numPns = pnsObject['numVerts']
    numElements = pnsObject['numElements']

    K = PETSc.Mat().createAIJ(size=(numBB*numElements, numPns*numElements), comm=comm)
    K.setUp()
    currCp = 0
    for pb in pnsObject['augmentedLagrangePatchBuilders']:
        mask = pb.mask
        # The neighbourhood may list a vertex more than once; fold duplicate
        # columns together. This depends only on the patch builder, so it is
        # computed once per builder instead of once per mask row.
        uniqueVerts, inverse = np.unique(pb.neighbor_verts, return_inverse=True)
        for maskIdx in range(len(mask)):
            # Read the mask row of the control point that sits at Lagrange node
            # (maskIdx % 16) of the current patch.
            rowIdx = (maskIdx//16)*16 + cubicLagrangeDofOrderInv[maskIdx % 16]
            summedWeights = np.bincount(inverse, weights=mask[rowIdx])
            for neighbor, coeff in zip(uniqueVerts, summedWeights):
                for e in range(numElements):
                    K[e*16 + currCp, e + numElements * neighbor] = coeff
            if maskIdx % 16 == 15:
                # Skip the row blocks of the other components of this patch.
                currCp += (numElements-1)*16
            currCp += 1
    K.assemble()
    return K


def fix_zero_rows(mat: PETSc.Mat, tol: float = 1e-14):
    """For any all-zero row in a PETSc Mat, set the diagonal to 1.

    A row counts as zero when it has no stored entries or every stored entry
    has magnitude below ``tol``. Each rank only inspects the rows it owns
    (``getRow`` is only valid for locally owned rows). Flagged rows and their
    columns are zeroed with ``zeroRowsColumns`` and get 1 on the diagonal.

    Args:
        mat: Assembled matrix, modified in place.
        tol: Magnitude below which an entry is treated as zero.

    Returns:
        The same matrix, re-assembled.
    """
    start, end = mat.getOwnershipRange()
    zero_rows = []
    for i in range(start, end):
        _, vals = mat.getRow(i)
        # np.all of an empty array is True, so rows with no entries qualify.
        if np.all(np.abs(vals) < tol):
            zero_rows.append(i)
    mat.zeroRowsColumns(zero_rows, diag=1)
    mat.assemble()
    return mat

@track_time
def assemble_matrix_pns(pnsObject, a, bc = None, lagrange_bcs=None, result = None, A = None, AK = None, KTAK = None):
    """Assemble a bilinear form and reduce it to the PNS basis as K^T A K.

    Args:
        pnsObject: PNS object; uses 'prolongation_matrix' (K).
        a: Bilinear form as accepted by dolfinx ``create_matrix`` /
            ``assemble_matrix`` (Lagrange basis).
        bc: Optional ``PnsDirichletBC`` applied to the reduced matrix.
        lagrange_bcs: Lagrange-basis Dirichlet conditions applied while
            assembling A. Defaults to none.
        result: Optional matrix that receives a copy of the reduced matrix.
        A: Optional Lagrange-basis matrix to assemble ``a`` into (reused).
        AK: Optional work matrix for the product A*K (reused when given).
        KTAK: Optional work matrix for K^T*(A*K). The reduced matrix is built
            here and then copied into ``result``, because writing straight
            into ``result`` did not work with the nonlinear solver.

    Returns:
        Tuple ``(KTAK, A)``: the reduced PNS matrix (with ``bc`` and the
        zero-row fix applied) and the Lagrange-basis matrix A (with
        ``lagrange_bcs`` applied, without the PNS ``bc``).
    """
    if lagrange_bcs is None:
        lagrange_bcs = []

    # Lagrange-basis assembly is excluded from the PNS timing statistics.
    with exclude_time():
        if A is None:
            A = create_matrix(a)
        A.zeroEntries()
        A = dolfinx.fem.petsc.assemble_matrix(A, a, bcs=lagrange_bcs)
        A.assemble()
    K = pnsObject['prolongation_matrix']
    AK = A.matMult(K, result=AK)
    KTAK = K.transposeMatMult(AK, result=KTAK)
    if bc is not None:
        bc.setMatrix(KTAK)
    KTAK.assemble()
    # PNS DOFs untouched by any Lagrange DOF leave empty rows; make them identity.
    KTAK = fix_zero_rows(KTAK)
    if result is not None:
        result.assemble()
        KTAK.copy(result)
    return KTAK, A

@track_time
def assemble_vector_pns(pnsObject, L, A = None, bc = None, lagrange_bcs=None, bilinear_form=None, result = None, x0=None, alpha=1.0, b = None):
    """Assemble a linear form and reduce it to the PNS basis as K^T b.

    Args:
        pnsObject: PNS object; uses 'prolongation_matrix' (K) and 'comm'.
        L: Linear form as accepted by dolfinx ``create_vector`` /
            ``assemble_vector`` (Lagrange basis).
        A: Lagrange-basis bilinear matrix. Required when ``bc`` is given.
        bc: Optional ``PnsDirichletBC`` on PNS DOFs; lifted as
            ``b - A K x_g`` before reduction and then imposed on the result.
        lagrange_bcs: Lagrange-basis Dirichlet conditions. Defaults to none.
        bilinear_form: Lagrange-basis bilinear form; required if
            ``lagrange_bcs`` is non-empty (used for lifting).
        result: Optional PNS-basis vector to write the result into.
        x0: Lagrange-basis vector forwarded to ``apply_lifting``/``set_bc``.
        alpha: Scale forwarded to ``apply_lifting``/``set_bc``.
        b: Optional work vector for the Lagrange-basis assembly (reused).

    Returns:
        Tuple ``(result, b)``: the PNS-basis vector and the Lagrange-basis
        vector (with ``lagrange_bcs`` lifting applied).

    Raises:
        ValueError: If ``bc`` is given without ``A``.
    """
    if lagrange_bcs is None:
        lagrange_bcs = []
    if bc is not None and A is None:
        raise ValueError("assemble_vector_pns: A is required when a PNS bc is given")

    comm = pnsObject['comm']
    K = pnsObject['prolongation_matrix']
    if result is None:
        result = PETSc.Vec().createSeq(K.getSize()[1], comm=comm)
    # Lagrange-basis assembly is excluded from the PNS timing statistics.
    with exclude_time():
        if b is None:
            b = create_vector(L)
        with b.localForm() as bl:
            bl.set(0)
        dolfinx.fem.petsc.assemble_vector(b, L)
        if lagrange_bcs:
            dolfinx.fem.petsc.apply_lifting(b, [bilinear_form], [lagrange_bcs], x0=[x0] if x0 is not None else [], alpha=alpha)
            dolfinx.fem.petsc.set_bc(b, lagrange_bcs, x0=x0, alpha=alpha)

    # TODO: account for x0 and alpha in the PNS-bc lifting below.
    if bc is None:
        K.multTranspose(b, result)
    else:
        AK = A.matMult(K)
        x_g = bc.getXg()
        lifted_b = PETSc.Vec().createSeq(K.getSize()[0], comm=comm)
        AK.mult(x_g, lifted_b)
        lifted_b.aypx(-1, b)  # lifted_b = b - A K x_g
        K.multTranspose(lifted_b, result)
        # Release the per-call temporaries explicitly.
        lifted_b.destroy()
        x_g.destroy()
        AK.destroy()
        bc.setVector(result)
    result.assemble()
    return result, b

def pnsAssign(source_pns: PETSc.Vec, target_lagrange: dolfinx.fem.Function, pnsObject):
    """Write the Lagrange-basis image of a PNS-basis vector into a Function.

    Computes ``target_lagrange = K @ source_pns``, where K is
    ``pnsObject['prolongation_matrix']``, then updates ghost values.

    Args:
        source_pns: Source vector in PNS basis.
        target_lagrange: Target function in Lagrange basis (overwritten).
        pnsObject: PNS object
    """
    K = pnsObject['prolongation_matrix']
    K.mult(source_pns, target_lagrange.x.petsc_vec)
    # Writing through petsc_vec does not refresh ghost entries.
    target_lagrange.x.scatter_forward()

def pnsLagrangeAssign(source_lagrange: dolfinx.fem.Function, target_pns: PETSc.Vec, pnsObject):
    """Restrict a Lagrange-basis Function to a PNS-basis vector with K^T.

    Computes ``target_pns = K^T @ source_lagrange``, where K is
    ``pnsObject['prolongation_matrix']``. This is the transpose of the map
    used by ``pnsAssign``, not in general its inverse: applying
    ``pnsAssign`` and then this function does not reproduce the original PNS
    vector unless K^T K is the identity.

    Args:
        source_lagrange: Source function in Lagrange basis.
        target_pns: Target vector in PNS basis (overwritten).
        pnsObject: PNS object
    """
    K = pnsObject['prolongation_matrix']
    K.multTranspose(source_lagrange.x.petsc_vec, target_pns)

class PnsLinearProblem:
    """Linear variational problem assembled in Lagrange basis, solved in PNS basis.

    The operator and right-hand side are reduced with the prolongation matrix
    K (``K^T A K`` and ``K^T b``) and solved with a PETSc KSP.
    """

    def __init__(
        self,
        a: ufl.Form,
        L: ufl.Form,
        pnsObject,
        bc: typing.Optional["PnsDirichletBC"] = None,
        lagrange_bcs = None,
        comm = MPI.COMM_WORLD,
        petsc_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        jit_options: typing.Optional[dict] = None,
    ):
        """Initialize solver for a linear variational problem.

        Args:
            a: Bilinear UFL form or a sequence of sequence of bilinear
                forms, the left hand side of the variational problem.
            L: Linear UFL form or a sequence of linear forms, the right
                hand side of the variational problem.
            pnsObject: PNS object
            bc(PnsDirichletBC): Dirichlet boundary condition with PNS DOF.
            lagrange_bcs: Lagrange basis Dirichlet boundary conditions.
                Defaults to none.
            comm: Unused; the communicator is taken from ``pnsObject['comm']``.
            petsc_options: Options that are passed to the linear
                algebra backend PETSc. For available choices for the
                'petsc_options' kwarg, see the `PETSc documentation
                <https://petsc4py.readthedocs.io/en/stable/manual/ksp/>`_.
            form_compiler_options: Options used in FFCx compilation of
                all forms. Run ``ffcx --help`` at the commandline to see
                all available options.
            jit_options: Options used in CFFI JIT compilation of C
                code generated by FFCx. See `python/dolfinx/jit.py` for
                all available options. Takes priority over all other
                option values.

        Example::

            problem = PnsLinearProblem(a, L, pnsObject, bc=bc, petsc_options={
                "ksp_type": "preonly",
                "pc_type": "lu",
                "pc_factor_mat_solver_type": "mumps"
            })
        """
        self.pnsObject = pnsObject
        # Forward the compilation options; they were previously ignored.
        self._a = dolfinx.fem.form(a, form_compiler_options=form_compiler_options, jit_options=jit_options)
        self._L = dolfinx.fem.form(L, form_compiler_options=form_compiler_options, jit_options=jit_options)

        self.bc = bc
        self.lagrange_bcs = [] if lagrange_bcs is None else lagrange_bcs

        comm = pnsObject['comm']

        self._solver = PETSc.KSP().create(comm)
        prefix = f"dolfinx_solve_{id(self)}"
        self._solver.setOptionsPrefix(prefix)

        # Register the user options under this solver's unique prefix only.
        opts = PETSc.Options()
        opts.prefixPush(prefix)
        if petsc_options:
            for key, val in petsc_options.items():
                opts[key] = val
        opts.prefixPop()
        self._solver.setFromOptions()

    def solve(self) -> PETSc.Vec:
        """Assemble the reduced system and solve it.

        Returns:
            Solution vector in PNS basis (use ``pnsAssign`` to map it to a
            Lagrange-basis Function).
        """
        A_reduced, A = assemble_matrix_pns(self.pnsObject, self._a, self.bc, self.lagrange_bcs)
        b_reduced, _ = assemble_vector_pns(self.pnsObject, self._L, A, self.bc, self.lagrange_bcs, self._a)
        self._solver.setOperators(A_reduced)
        x_reduced = b_reduced.duplicate()
        self._solver.solve(b_reduced, x_reduced)

        return x_reduced

    def __del__(self):
        # Release the KSP; guard against a partially constructed instance.
        solver = getattr(self, "_solver", None)
        if solver is not None:
            solver.destroy()

class PnsNewtonSolver(dolfinx.cpp.nls.petsc.NewtonSolver):
    """Newton solver whose unknown is a PNS-basis vector.

    Residual and Jacobian callbacks come from a ``PnsNonlinearProblem``. The
    iterate ``self._x`` is initialised as ``K^T u`` and, after solving, is
    mapped back to the Lagrange function with ``u = K x``.
    """

    def __init__(self, comm: MPI.Intracomm, problem):
        """A Newton solver for non-linear problems.

        Args:
            comm: MPI communicator.
            problem: ``PnsNonlinearProblem`` providing ``F``, ``J``, ``form``,
                ``u`` and ``pnsObject``.
        """
        super().__init__(comm)
        self.problem = problem
        # Create matrix and vector to be used for assembly
        # of the non-linear problem (both sized by the PNS DOF count).
        pnsObject = problem.pnsObject
        K = pnsObject['prolongation_matrix']
        numPnsDof = K.getSize()[1]
        self._A = PETSc.Mat().createAIJ(size=(numPnsDof, numPnsDof), comm=comm)
        self.setJ(problem.J, self._A)
        self._b = PETSc.Vec().createSeq(numPnsDof, comm=comm)
        self.setF(problem.F, self._b)
        self._x = PETSc.Vec().createSeq(numPnsDof, comm=comm)
        K.multTranspose(problem.u.x.petsc_vec, self._x)
        self.set_form(problem.form)

    def __del__(self):
        # Release every PETSc work object (previously _x leaked); getattr
        # guards against an __init__ that failed part-way.
        for name in ("_A", "_b", "_x"):
            obj = getattr(self, name, None)
            if obj is not None:
                obj.destroy()

    def solve(self, u: dolfinx.fem.Function):
        """Solve non-linear problem into function u. Returns the number
        of iterations and if the solver converged."""
        n, converged = super().solve(self._x)
        self.problem.pnsObject['prolongation_matrix'].mult(self._x, u.x.petsc_vec)
        u.x.scatter_forward()
        return n, converged

    @property
    def A(self) -> PETSc.Mat:  # type: ignore
        """Jacobian matrix"""
        return self._A

    @property
    def b(self) -> PETSc.Vec:  # type: ignore
        """Residual vector"""
        return self._b

    def setP(self, P: typing.Callable, Pmat: PETSc.Mat):  # type: ignore
        """
        Set the function for computing the preconditioner matrix

        Args:
            P: Function to compute the preconditioner matrix
            Pmat: Matrix to assemble the preconditioner into

        """
        super().setP(P, Pmat)




class PnsNonlinearProblem:
    """Nonlinear problem class for solving the non-linear problems in PNS basis.

    Solves problems of the form :math:`F(u, v) = 0 \\ \\forall v \\in V` using
    PETSc as the linear algebra backend. The residual and Jacobian are
    assembled in the Lagrange basis and reduced with the prolongation matrix
    K (``K^T b`` and ``K^T J K``). The vectors passed to :meth:`form`,
    :meth:`F` and :meth:`J` are PNS-basis iterates.
    """

    def __init__(
        self,
        F: ufl.form.Form,
        u: dolfinx.fem.Function,
        pnsObject,
        bcs: typing.Optional[list[dolfinx.fem.DirichletBC]] = None,
        J: ufl.form.Form = None,
        form_compiler_options: typing.Optional[dict] = None,
        jit_options: typing.Optional[dict] = None,
    ):
        """Initialize solver for solving a non-linear problem using Newton's method.

        Args:
            F: The PDE residual F(u, v)
            u: The unknown (Lagrange-basis Function)
            pnsObject: PNS object
            bcs: List of Lagrange-basis Dirichlet boundary conditions.
                Defaults to none.
            J: UFL representation of the Jacobian (Optional)
            form_compiler_options: Options used in FFCx
                compilation of this form. Run ``ffcx --help`` at the
                command line to see all available options.
            jit_options: Options used in CFFI JIT compilation of C
                code generated by FFCx. See ``python/dolfinx/jit.py``
                for all available options. Takes priority over all
                other option values.

        Example::

            problem = PnsNonlinearProblem(F, u, pnsObject, [bc0, bc1])
        """
        self.u = u
        self._L = _create_form(
            F, form_compiler_options=form_compiler_options, jit_options=jit_options
        )

        # Create the Jacobian matrix, dF/du
        if J is None:
            V = u.function_space
            du = ufl.TrialFunction(V)
            J = ufl.derivative(F, u, du)

        self._a = _create_form(
            J, form_compiler_options=form_compiler_options, jit_options=jit_options
        )
        # Fresh list per instance (avoids a shared mutable default).
        self.bcs = [] if bcs is None else bcs
        self.pnsObject = pnsObject
        # Work objects reused across Newton iterations.
        self.bTemp = create_vector(self._L)
        self.Atemp = create_matrix(self._a)
        self.AKtemp = PETSc.Mat()
        self.KTAKtemp = PETSc.Mat()

    @property
    def L(self) -> ufl.form.Form:
        """Compiled linear form (the residual form)"""
        return self._L

    @property
    def a(self) -> ufl.form.Form:
        """Compiled bilinear form (the Jacobian form)"""
        return self._a

    def form(self, x: PETSc.Vec) -> None:
        """Map the PNS-basis iterate into ``self.u`` before F/J are assembled.

        Args:
           x: The vector containing the latest solution (PNS basis)
        """
        self.pnsObject['prolongation_matrix'].mult(x, self.u.x.petsc_vec)

    def F(self, x: PETSc.Vec, b: PETSc.Vec) -> None:
        """Assemble the reduced residual into the vector b.

        Args:
            x: The vector containing the latest solution (PNS basis)
            b: Vector to assemble the residual into (PNS basis)
        """
        # apply_lifting/set_bc index x0 with Lagrange DOFs, so pass the
        # Lagrange-basis iterate self.u (refreshed from x by form()) instead
        # of the PNS-basis vector x, whose size and basis do not match.
        assemble_vector_pns(self.pnsObject, self._L, lagrange_bcs=self.bcs, bilinear_form=self._a, result=b, x0=self.u.x.petsc_vec, alpha=-1.0, b=self.bTemp)

    def J(self, x: PETSc.Vec, A: PETSc.Mat) -> None:
        """Assemble the reduced Jacobian matrix into A.

        Args:
            x: The vector containing the latest solution (PNS basis)
            A: Matrix to assemble the Jacobian into (PNS basis)
        """
        assemble_matrix_pns(self.pnsObject, self._a, lagrange_bcs=self.bcs, result=A, A=self.Atemp, AK=self.AKtemp, KTAK=self.KTAKtemp)

class PnsDirichletBC:
    """Dirichlet condition imposed directly on PNS degrees of freedom.

    Args:
        pnsObject: PNS object; 'prolongation_matrix' and 'comm' are used to
            size and create the lifting vector.
        value: Prescribed value(s) written at ``dof``.
        dof: Indices of the constrained PNS DOFs.
    """

    def __init__(self, pnsObject, value: np.ndarray, dof: np.ndarray):
        self.pnsObject = pnsObject
        self.dof = dof
        self.value = value

    def setVector(self, x: PETSc.Vec):
        """Overwrite the constrained entries of ``x`` with the prescribed values."""
        x.setValues(self.dof, self.value)
        x.assemble()

    def setMatrix(self, mat: PETSc.Mat):
        """Zero the constrained rows and columns of ``mat``; put 1 on their diagonal."""
        mat.zeroRowsColumns(self.dof, diag=1)
        mat.assemble()

    def getXg(self):
        """Return a new PNS-sized vector with the prescribed values set at ``dof``."""
        n_pns = self.pnsObject['prolongation_matrix'].getSize()[1]
        lifting = PETSc.Vec().createSeq(n_pns, comm=self.pnsObject['comm'])
        self.setVector(lifting)
        return lifting

def loadPns(file_name: str, numElements = 1, set_boundary_gradient=False, comm = MPI.COMM_WORLD):
    """Load a PNS control mesh from file and build the PNS object dictionary.

    Args:
        file_name: Path of the control-mesh file (``Pns_control_mesh.from_file``).
        numElements: Number of scalar components per point; used to size the
            prolongation matrix.
        set_boundary_gradient: If True, build the "augmented" patch builders
            from ``pns.set_boundary_gradient(control_mesh)``; otherwise the
            augmented builders use the plain control mesh.
        comm: MPI communicator stored in the object.

    Returns:
        Dict with the patch builders, patch/vertex counts, control mesh, flat
        patches, 'numElements', 'comm' and the 'prolongation_matrix'.

    Raises:
        ValueError: If the control mesh yields no patch builders.
    """
    control_mesh = pns.Pns_control_mesh.from_file(file_name)
    if set_boundary_gradient:
        augmented_control_mesh = pns.set_boundary_gradient(control_mesh)
    else:
        augmented_control_mesh = control_mesh
    pnsLagrangePatchBuilders = pns.get_patch_builders(control_mesh)
    pnsPatchBuilders = pns.get_patch_builders(control_mesh)
    pnsAugmentedLagrangePatchBuilders = pns.get_patch_builders(augmented_control_mesh)
    pnsAugmentedPatchBuilders = pns.get_patch_builders(augmented_control_mesh)
    if len(pnsLagrangePatchBuilders) == 0:
        raise ValueError("No patches found")
    flatPatches = []
    numPatches = 0
    for pb in pnsLagrangePatchBuilders:
        pb.toLagrange()
    for pb in pnsAugmentedLagrangePatchBuilders:
        pb.toLagrange()
    for pb in pnsPatchBuilders:
        for patch in pb.build_patches(control_mesh):
            flatPatches.append(patch)
            numPatches += 1
        pb.degRaise()
    for pb in pnsAugmentedPatchBuilders:
        pb.degRaise()
    # NOTE(review): vertex count comes from the un-augmented mesh while the
    # prolongation matrix uses the augmented builders' neighbor_verts; confirm
    # set_boundary_gradient never adds vertices.
    numVerts = len(control_mesh.get_vertices())
    pnsObject = {'lagrangePatchBuilders': pnsLagrangePatchBuilders, 'patchBuilders': pnsPatchBuilders, 'augmentedLagrangePatchBuilders': pnsAugmentedLagrangePatchBuilders, 'augmentedPatchBuilders': pnsAugmentedPatchBuilders, 'numPatches': numPatches, 'numVerts': numVerts, 'control_mesh': control_mesh, "flat_patches": flatPatches, "numElements": numElements, "comm": comm}
    K = assemble_prolongation_matrix(pnsObject)
    pnsObject["prolongation_matrix"] = K
    return pnsObject

def createPnsDolfinxMesh(pnsObject) -> dolfinx.mesh.Mesh:
    """Build a cubic, equispaced-Lagrange quadrilateral mesh of the PNS surface.

    Every patch produced by ``pnsObject['lagrangePatchBuilders']`` becomes one
    cell with its own 16 nodes (nodes are not shared between cells). The
    patch coefficients, visited row by row, are placed at the node positions
    given by ``cubicLagrangeDofOrder``; only the first three components
    (x, y, z) are used.

    Args:
        pnsObject: PNS object; reads 'lagrangePatchBuilders', 'control_mesh'
            and 'comm'.

    Returns:
        The dolfinx mesh.
    """
    order = np.asarray(cubicLagrangeDofOrder)
    patchNodes = []
    for pb in pnsObject['lagrangePatchBuilders']:
        for patch in pb.build_patches(pnsObject['control_mesh']):
            coeffs = np.asarray(patch.bb_coefs)[..., :3].reshape(-1, 3)
            nodes = np.empty((16, 3), dtype=coeffs.dtype)
            nodes[order] = coeffs  # nodes[order[k]] = k-th coefficient
            patchNodes.append(nodes)
    x = np.concatenate(patchNodes) if patchNodes else np.empty((0, 3))
    # Cell c owns nodes 16*c ... 16*c + 15.
    cells = np.arange(x.shape[0], dtype=np.int64).reshape(-1, 16)
    domain = ufl.Mesh(
        basix.ufl.element(
            "Lagrange",
            "quadrilateral",
            3,
            lagrange_variant=basix.LagrangeVariant.equispaced,
            shape=(3,),
        )
    )

    mesh = dolfinx.mesh.create_mesh(pnsObject["comm"], cells, x, domain)
    return mesh

def pnsFunctionSpace(mesh, pnsObject) -> dolfinx.fem.FunctionSpace:
    """Create the cubic equispaced-Lagrange space used with the PNS mesh.

    With ``pnsObject['numElements'] > 1`` a mixed element with that many
    copies of the scalar element is used; otherwise the scalar element itself.
    """
    scalar_el = basix.ufl.element(
        "Lagrange",
        "quadrilateral",
        3,
        lagrange_variant=basix.LagrangeVariant.equispaced,
    )
    n_fields = pnsObject['numElements']
    space_el = basix.ufl.mixed_element([scalar_el] * n_fields) if n_fields > 1 else scalar_el
    return dolfinx.fem.functionspace(mesh, space_el)



# testProlongationMatrix(pnsObject)

# """## Visualize Pns"""


import numpy as np
import json
from math import comb
from functools import cache

class PnsVisualizer:
    """Render PNS scalar fields on the PNS surface as an interactive three.js page.

    Frames are collected with :meth:`addFunc` and rendered by
    :meth:`visualizePns`; with more than one frame a time slider is shown.

    Args:
        pnsObject: PNS object; ``visualizePns`` reads 'patchBuilders',
            'augmentedPatchBuilders' and 'control_mesh'.
        dt: Stored as ``self.dt``; not used by the rendering code.
    """

    def __init__(self, pnsObject, dt=0):
        self.pnsObject = pnsObject
        self.dt = dt
        self.funcList = []

    def addFunc(self, valDofs):
        """Store a copy of ``valDofs`` (one scalar per control-mesh vertex) as a frame."""
        self.funcList.append(valDofs.copy())

    def visualizePns(self,
                     expectedResult = None,
                     valRange = None,
                     showContour=True,
                     colorMap='blackbody',
                     res_u=20, res_v=20,
                     width=800, height=600, saveToFile = True):
        """
        Render the stored frames on the PNS surface with a GUI time slider.

        Args:
            expectedResult: Optional callable taking an ``(x, y, z)`` tuple;
                when given, ``|value - expectedResult(point)|`` is shown.
            valRange: Optional ``(min, max)`` colour range; defaults to the
                global min/max over all frames.
            showContour: Draw 10 iso-lines when True.
            colorMap: 'rainbow','cooltowarm','blackbody','grayscale'
            res_u, res_v: Samples per patch along each parameter direction.
            width, height: Canvas size in pixels.
            saveToFile: Also write the page to ``result.html``.

        Returns:
            The generated HTML string.

        Raises:
            ValueError: If no frame was added with :meth:`addFunc`.
        """
        if not self.funcList:
            raise ValueError("No frames to visualize; call addFunc() first.")

        us = np.linspace(0, 1, res_u)
        vs = np.linspace(0, 1, res_v)

        @cache
        def bernstein(i, t):
            return comb(3, i) * (t**i) * ((1 - t)**(3 - i))

        # 1) Build static mesh positions & faces
        allPatchCps = np.array([
            patch.bb_coefs
            for pb in self.pnsObject['patchBuilders']
            for patch in pb.build_patches(self.pnsObject['control_mesh'])
        ])  # shape (P,4,4,3)
        P = allPatchCps.shape[0]

        Bu = np.array([[bernstein(i, u) for i in range(4)] for u in us])  # (res_u,4)
        Bv = np.array([[bernstein(j, v) for j in range(4)] for v in vs])  # (res_v,4)

        pts = np.empty((P, res_u, res_v, 3))
        for k in range(3):
            pts[..., k] = np.einsum('ui,pij,vj->puv', Bu, allPatchCps[..., k], Bv)

        vertices = pts.reshape(-1, 3).tolist()

        idx = np.arange(res_u * res_v).reshape(res_u, res_v)
        I0 = idx[:-1, :-1].ravel()
        I1 = idx[1:,  :-1].ravel()
        I2 = idx[1:,   1:].ravel()
        I3 = idx[:-1,  1:].ravel()
        F0 = np.vstack([
            np.stack([I0, I1, I2], axis=1),
            np.stack([I0, I2, I3], axis=1)
        ])  # (2*(res_u-1)*(res_v-1),3)

        faces = (
            F0[None, :, :] +
            (res_u * res_v) * np.arange(P)[:, None, None]
        ).reshape(-1, 3).tolist()

        # 2) Precompute scalar arrays for each time-step
        vals_all = []
        positions = np.array(vertices)
        # The sample points are the same for every frame, so evaluate the
        # reference solution once.
        expectedVals = None
        if expectedResult is not None:
            expectedVals = np.array([
                expectedResult(tuple(pt)) for pt in positions
            ])
        for valDofs in self.funcList:
            # build val-function control-point arrays
            valPatchCps = np.array([
                patch.bb_coefs
                for pb in self.pnsObject['augmentedPatchBuilders']
                for patch in pb.build_patches(
                    pns.Pns_control_mesh.from_data(
                        [(v,0,0) for v in valDofs],
                        []
                    )
                )
            ])  # shape (P,4,4,3)

            valCtrl = valPatchCps[..., 0]  # extract the scalar from x-coord
            spts = np.einsum('ui,pij,vj->puv', Bu, valCtrl, Bv)  # (P,res_u,res_v)
            s_flat = spts.reshape(-1)

            if expectedVals is not None:
                s_flat = np.abs(s_flat - expectedVals)

            vals_all.append(s_flat.tolist())

        # 3) Global min/max
        if valRange is None:
            global_min = min(min(arr) for arr in vals_all)
            global_max = max(max(arr) for arr in vals_all)
        else:
            global_min, global_max = valRange

        print(f"Range min: {global_min}")
        print(f"Range max: {global_max}")

        # 4) JSONify (float() so numpy scalars in valRange serialize)
        verts_json    = json.dumps(vertices)
        faces_json    = json.dumps(faces)
        vals_json_all = json.dumps(vals_all)
        min_json      = json.dumps(float(global_min))
        max_json      = json.dumps(float(global_max))

        # 5) Build HTML/JS with RawShaderMaterial
        html = f"""
        <div id="bezier-view" style="width:{width}px; height:{height}px;"></div>
        <script type="importmap">
        {{
        "imports": {{
            "three":         "https://cdn.jsdelivr.net/npm/three@0.167.1/build/three.module.js",
            "three/addons/": "https://cdn.jsdelivr.net/npm/three@0.167.1/examples/jsm/"
        }}
        }}
        </script>
        <script type="module">
        import * as THREE         from 'three';
        import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
        import {{ Lut }}           from 'three/addons/math/Lut.js';
        import {{ GUI }}           from 'three/addons/libs/lil-gui.module.min.js';

        // SCENE SETUP
        const container = document.getElementById('bezier-view');
        const scene     = new THREE.Scene();
        scene.background = new THREE.Color(0xffffff);
        const camera    = new THREE.PerspectiveCamera(45, {width}/{height}, 0.1, 1000);
        camera.position.set(0,10,20);
        camera.lookAt(0,0,0);

        const renderer = new THREE.WebGLRenderer({{ antialias: true }});
        renderer.setSize({width}, {height});
        renderer.autoClear = false;
        container.appendChild(renderer.domElement);

        // STATIC GEOMETRY
        const verts = {verts_json};
        const faces = {faces_json};
        const geom  = new THREE.BufferGeometry();
        geom.setAttribute('position',
        new THREE.BufferAttribute(new Float32Array(verts.flat()), 3)
        );
        geom.setIndex(
        new THREE.BufferAttribute(new Uint32Array(faces.flat()), 1)
        );
        geom.computeVertexNormals();

        // SCALAR FIELDS & LUT
        const vals_all = {vals_json_all};
        const lut      = new Lut('{colorMap}', 512);
        lut.setMin({min_json});
        lut.setMax({max_json});

        // LUT LEGEND
        const uiScene   = new THREE.Scene();
        const orthoCam  = new THREE.OrthographicCamera(-1,1,1,-1,1,2);
        orthoCam.position.set(0.9, 0, 1);
        const sprite    = new THREE.Sprite(new THREE.SpriteMaterial({{
        map: new THREE.CanvasTexture(lut.createCanvas())
        }}));
        sprite.material.map.colorSpace = THREE.SRGBColorSpace;
        sprite.scale.x = 0.125;
        uiScene.add(sprite);

        // BUFFERS
        const N = vals_all[0].length;
        const colorArr  = new Float32Array(N*3);
        const scalarArr = new Float32Array(vals_all[0]);
        geom.setAttribute('color',
        new THREE.BufferAttribute(colorArr, 3)
        );
        geom.setAttribute('scalar',
        new THREE.BufferAttribute(scalarArr, 1)
        );

        // SHADERS
        const vertexShader = `
        attribute vec3 position;
        attribute vec3 color;
        attribute float scalar;
        uniform mat4 modelViewMatrix;
        uniform mat4 projectionMatrix;
        varying vec3 vColor;
        varying float vScalar;
        void main() {{
        vColor    = color;
        vScalar   = scalar;
        gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
        }}
        `;
        const fragmentShader = `
        precision highp float;
        varying vec3 vColor;
        varying float vScalar;
        uniform float minVal;
        uniform float maxVal;
        uniform float nIso;
        void main() {{
        float t       = (vScalar - minVal) / (maxVal - minVal);
        float isoDist = abs(fract(t * nIso) - 0.5);
        float line    = 1.0 - step(0.02, isoDist);
        vec3 outC     = mix(vColor, vec3(1.0), line);
        gl_FragColor  = vec4(outC, 1.0);
        }}
        `;

        // RAW SHADER MATERIAL
        const rawMat = new THREE.RawShaderMaterial({{
        vertexShader:   vertexShader,
        fragmentShader: fragmentShader,
        uniforms: {{
            minVal: {{ value: {min_json} }},
            maxVal: {{ value: {max_json} }},
            nIso:   {{ value: {10 if showContour else 0} }}
        }},
        vertexColors: true,
        side: THREE.DoubleSide
        }});

        // MESH & CONTROLS
        const mesh = new THREE.Mesh(geom, rawMat);
        scene.add(mesh);
        new OrbitControls(camera, renderer.domElement).update();

        // UPDATE COLORS & SCALARS
        const onColorChange = idx => {{
        const arr = vals_all[Math.floor(idx)];
        for (let i = 0; i < N; i++) {{
            const c = lut.getColor(arr[i]);
            colorArr[3*i]   = c.r;
            colorArr[3*i+1] = c.g;
            colorArr[3*i+2] = c.b;
            scalarArr[i]    = arr[i];
        }}
        geom.attributes.color.needsUpdate  = true;
        geom.attributes.scalar.needsUpdate = true;
        }};

        // GUI SLIDER
        var isPlaying = false;
        const params = {{ t: 0, "Play/Pause": () => {{isPlaying = !isPlaying}} }};
        if (vals_all.length > 1) {{
        const gui = new GUI();

        gui.add(params, 't', 0, vals_all.length-1, 1)
            .name('time')
            .onChange(onColorChange).listen();
        gui.add(params, "Play/Pause")
        }}

        onColorChange(0);

        // GRID HELPER
        const gridHelper = new THREE.GridHelper(20, 10);
        gridHelper.material.opacity     = 0.2;
        gridHelper.material.transparent = true;
        gridHelper.position.y           = -1;
        scene.add(gridHelper);

        // RENDER LOOP
        let lastTime = Date.now()
        function animate() {{
        const now = Date.now()
        const dt = (now - lastTime)/300;
        if (isPlaying) {{
        params.t = (params.t + dt) % vals_all.length;
        onColorChange(params.t)
        }}
        lastTime = now;
        requestAnimationFrame(animate);
        renderer.clear();
        renderer.render(scene, camera);
        renderer.clearDepth();
        renderer.render(uiScene, orthoCam);
        }}
        animate();
        </script>
        """
        if saveToFile:
            with open("result.html", "w", encoding="utf-8") as f:
                f.write(html)
        # Return the page so callers can display it when saveToFile is False.
        return html


# PnS FEA

In [None]:
import dolfinx
import basix
import ufl
from mpi4py import MPI
from dolfinx.fem.petsc import LinearProblem
import numpy
import polyhedral_net_splines as pns
import numpy as np

"""# PNS FEM"""

import typing
from petsc4py import PETSc
import dolfinx.fem.petsc
from dolfinx.fem.petsc import assemble_matrix_mat, assemble_vector, create_matrix, _create_form, create_vector
import itertools
import functools
from mpi4py import MPI
import ufl
import numpy as np
import polyhedral_net_splines as pns
import basix
from tqdm import tqdm

import time
from functools import wraps
from contextlib import contextmanager

# Net wall-clock seconds accumulated per @track_time-decorated function name.
total_times: dict[str, float] = {}

# Seconds spent inside exclude_time() during the current tracked call. This is
# a plain module-level dict (NOT thread-local); the key is always 'default'.
_excluded_time: dict[str, float] = {}

@contextmanager
def exclude_time():
    """Exclude the wall time spent inside this ``with`` block from ``track_time``.

    The elapsed time is added to ``_excluded_time['default']``, which the
    ``track_time`` wrapper subtracts from the decorated call's duration. The
    time is recorded even if the body raises, so a tracked function that
    catches an exception from an excluded region is still timed correctly.
    """
    thread_id = 'default'  # single shared slot; not separated per thread
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        _excluded_time[thread_id] = _excluded_time.get(thread_id, 0) + elapsed

def track_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        thread_id = 'default'
        _excluded_time[thread_id] = 0  # Reset before each call
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        elapsed = end - start - _excluded_time[thread_id]
        total_times[func.__name__] = total_times.get(func.__name__, 0) + elapsed
        return result
    return wrapper

# cubicLagrangeDofOrder[k] is the dolfinx local node index that receives the k-th
# control point when a patch's 4x4 coefficients are walked row by row.
cubicLagrangeDofOrder = [0, 6, 7, 2, 4, 12, 14, 10, 5, 13, 15, 11, 1, 8, 9, 3]
# Inverse permutation: cubicLagrangeDofOrderInv[cubicLagrangeDofOrder[k]] == k.
cubicLagrangeDofOrderInv = [cubicLagrangeDofOrder.index(node) for node in range(16)]

def assemble_prolongation_matrix(pnsObject) -> PETSc.Mat:
    """Construct the prolongation matrix K mapping PnS control-point values to Lagrange DOFs.

    K has ``numPatches * 16 * numElements`` rows and ``numVerts * numElements`` columns.
    Each patch occupies a consecutive run of ``16 * numElements`` rows, in which
    component ``e`` uses rows ``e*16 .. e*16+15``. Columns are interleaved by
    component: vertex ``v`` of component ``e`` is column ``e + numElements * v``.

    Args:
        pnsObject: PNS object. Reads 'comm', 'numPatches', 'numVerts', 'numElements'
            and 'augmentedLagrangePatchBuilders'.

    Returns:
        Assembled PETSc AIJ prolongation matrix.
    """
    comm = pnsObject['comm']
    numBB = pnsObject['numPatches'] * 16
    numPns = pnsObject['numVerts']
    numElements = pnsObject['numElements']

    K = PETSc.Mat().createAIJ(size=(numBB * numElements, numPns * numElements), comm=comm)
    K.setUp()
    currCp = 0
    for pb in pnsObject['augmentedLagrangePatchBuilders']:
        neighborhood = pb.neighbor_verts
        mask = pb.mask
        # The neighbourhood is fixed for this builder, so compute its unique vertices once.
        # `inv` maps every neighbourhood slot to its unique vertex, so weights of
        # repeated vertices are summed by bincount below.
        uniqueVerts, inv = np.unique(neighborhood, return_inverse=True)
        for maskIdx in range(len(mask)):
            # dolfinx local node j (written at row currCp) takes the mask row of PnS
            # control point cubicLagrangeDofOrderInv[j] within the same 16-row patch block.
            rowIdx = (maskIdx // 16) * 16 + cubicLagrangeDofOrderInv[maskIdx % 16]
            cpRow = mask[rowIdx]
            coeffs = dict(zip(uniqueVerts, np.bincount(inv, weights=cpRow)))
            for neighbor, weight in coeffs.items():
                for e in range(numElements):
                    writeRow = e * 16 + currCp
                    writeCol = e + numElements * neighbor
                    K[writeRow, writeCol] = weight
            if maskIdx % 16 == 15:
                # End of a patch: skip the rows that belong to the other components.
                currCp += (numElements - 1) * 16
            currCp += 1
    K.assemble()
    return K


def fix_zero_rows(mat: PETSc.Mat, tol: float = 1e-14):
    """For any (numerically) all-zero row in a PETSc Mat, set the diagonal to 1.

    A row counts as zero when it stores no entries or every stored entry has
    magnitude below ``tol``. Such rows are passed to ``zeroRowsColumns``, which
    also zeroes the matching columns and places 1 on the diagonal.

    Args:
        mat: Assembled matrix. It is modified in place and re-assembled.
        tol: Magnitude below which an entry is treated as zero.

    Returns:
        The same matrix object.
    """
    rowStart, rowEnd = mat.getOwnershipRange()
    zero_rows = []
    # getRow only works for locally owned rows; in serial this covers every row.
    for i in range(rowStart, rowEnd):
        _, vals = mat.getRow(i)
        # np.all of an empty array is True, so rows without stored entries qualify.
        if np.all(np.abs(vals) < tol):
            zero_rows.append(i)
    mat.zeroRowsColumns(zero_rows, diag=1)
    mat.assemble()
    return mat

@track_time
def assemble_matrix_pns(pnsObject, a, bc=None, lagrange_bcs=None, result=None, A=None, AK=None, KTAK=None):
    """Assemble a bilinear form into the PnS-basis matrix ``K^T A K``.

    Args:
        pnsObject: PNS object (its 'prolongation_matrix' K is used).
        a: Compiled bilinear form in the Lagrange space.
        bc: Optional ``PnsDirichletBC`` applied to the reduced matrix.
        lagrange_bcs: Lagrange-basis Dirichlet conditions used when assembling A.
            Defaults to none.
        result: Optional matrix that receives a copy of the reduced matrix.
        A: Optional reusable Lagrange-basis matrix for ``a``. It is created if None.
        AK: Optional reusable storage for the product ``A K``.
        KTAK: Optional reusable storage for ``K^T (A K)``. It is used instead of writing
            into ``result`` directly, because direct assignment did not work with the
            non-linear solver.

    Returns:
        Tuple ``(KTAK, A)``: the reduced PnS matrix (with ``bc`` applied and zero rows
        fixed), and the Lagrange-basis matrix (with ``lagrange_bcs`` but without ``bc``).
    """
    if lagrange_bcs is None:
        lagrange_bcs = []

    # Lagrange-space assembly is excluded from the timing statistics.
    with exclude_time():
        if A is None:
            A = create_matrix(a)
        A.zeroEntries()
        A = dolfinx.fem.petsc.assemble_matrix(A, a, bcs=lagrange_bcs)
        A.assemble()
    K = pnsObject['prolongation_matrix']
    AK = A.matMult(K, result=AK)
    KTAK = K.transposeMatMult(AK, result=KTAK)
    if bc is not None:
        bc.setMatrix(KTAK)
    KTAK.assemble()
    # PnS DOFs with no contribution would leave the system singular.
    KTAK = fix_zero_rows(KTAK)
    if result is not None:
        result.assemble()
        KTAK.copy(result)
    return KTAK, A

@track_time
def assemble_vector_pns(pnsObject, L, A=None, bc=None, lagrange_bcs=None, bilinear_form=None, result=None, x0=None, alpha=1.0, b=None):
    """Assemble a linear form into the PnS-basis vector ``K^T b``.

    Args:
        pnsObject: PNS object (its 'prolongation_matrix' K and 'comm' are used).
        L: Compiled linear form in the Lagrange space.
        A: Lagrange-basis bilinear matrix. It is required when ``bc`` is given and is
            used to lift the PnS Dirichlet data (``b - A K x_g``).
        bc: Optional ``PnsDirichletBC`` on PnS DOFs.
        lagrange_bcs: Lagrange-basis Dirichlet conditions. Defaults to none.
        bilinear_form: Compiled Lagrange-basis bilinear form. It is required when
            ``lagrange_bcs`` is non-empty, for ``apply_lifting``.
        result: Optional PnS vector to write into. A new sequential vector is
            created if None.
        x0: Lagrange-basis vector forwarded to ``apply_lifting``/``set_bc`` together with
            ``alpha``. It is not applied to the PnS ``bc`` (see TODO below).
        alpha: Scaling forwarded to ``apply_lifting``/``set_bc``.
        b: Optional reusable Lagrange-basis work vector.

    Returns:
        Tuple ``(result, b)``: the PnS vector and the Lagrange-basis vector
        (after ``lagrange_bcs``).

    Raises:
        ValueError: if ``bc`` is given without ``A``, or if ``lagrange_bcs`` is
            non-empty without ``bilinear_form``.
    """
    if lagrange_bcs is None:
        lagrange_bcs = []
    if bc is not None and A is None:
        raise ValueError("assemble_vector_pns: A is required when a PnS bc is given")
    if len(lagrange_bcs) > 0 and bilinear_form is None:
        raise ValueError("assemble_vector_pns: bilinear_form is required with lagrange_bcs")

    comm = pnsObject['comm']
    K = pnsObject['prolongation_matrix']
    if result is None:
        result = PETSc.Vec().createSeq(K.getSize()[1], comm=comm)

    # Lagrange-space assembly is excluded from the timing statistics.
    with exclude_time():
        if b is None:
            b = create_vector(L)
        with b.localForm() as bl:
            bl.set(0)
        dolfinx.fem.petsc.assemble_vector(b, L)
        if len(lagrange_bcs) > 0:
            dolfinx.fem.petsc.apply_lifting(b, [bilinear_form], [lagrange_bcs], x0=[x0] if x0 is not None else [], alpha=alpha)
            dolfinx.fem.petsc.set_bc(b, lagrange_bcs, x0=x0, alpha=alpha)

    # TODO: account for x0 and alpha in the PnS bc lifting.
    if bc is None:
        K.multTranspose(b, result)
    else:
        # Lift the PnS Dirichlet data: result = K^T (b - A K x_g).
        AK = A.matMult(K)
        x_g = bc.getXg()
        lifted_b = PETSc.Vec().createSeq(K.getSize()[0], comm=comm)
        AK.mult(x_g, lifted_b)
        lifted_b.aypx(-1, b)
        K.multTranspose(lifted_b, result)
        bc.setVector(result)
        # Free the per-call temporaries.
        for tmp in (AK, x_g, lifted_b):
            tmp.destroy()
    result.assemble()
    return result, b

def pnsAssign(source_pns: PETSc.Vec, target_lagrange: dolfinx.fem.Function, pnsObject):
    """Write the Lagrange-basis values of a PnS coefficient vector into a Function.

    Computes ``target_lagrange = K @ source_pns``, where K is the prolongation matrix.

    Args:
        source_pns: Source vector of PnS coefficients.
        target_lagrange: Function in the Lagrange space. Its vector is overwritten.
        pnsObject: PNS object holding 'prolongation_matrix'.
    """
    K = pnsObject['prolongation_matrix']
    K.mult(source_pns, target_lagrange.x.petsc_vec)
    # Make ghost entries consistent with the newly written owned values.
    target_lagrange.x.scatter_forward()

def pnsLagrangeAssign(source_lagrange: dolfinx.fem.Function, target_pns: PETSc.Vec, pnsObject):
    """Map a Lagrange-basis Function to a PnS-sized vector with the transpose of K.

    Computes ``target_pns = K^T @ source_lagrange``. K^T is the adjoint of the map
    used by ``pnsAssign``, not its inverse. Applying this function and then
    ``pnsAssign`` therefore does not, in general, reproduce the original values.

    Args:
        source_lagrange: Source function in the Lagrange space.
        target_pns: Target vector of PnS size. It is overwritten.
        pnsObject: PNS object holding 'prolongation_matrix'.
    """
    K = pnsObject['prolongation_matrix']
    K.multTranspose(source_lagrange.x.petsc_vec, target_pns)

class PnsLinearProblem:
    """Linear variational problem whose unknowns are PnS coefficients.

    The forms are assembled in the Lagrange space and reduced with the prolongation
    matrix K (``K^T A K`` and ``K^T b``), then solved with a PETSc KSP.
    """

    def __init__(
        self,
        a: ufl.Form,
        L: ufl.Form,
        pnsObject,
        bc: typing.Optional["PnsDirichletBC"] = None,
        lagrange_bcs: typing.Optional[list] = None,
        comm = MPI.COMM_WORLD,
        petsc_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        jit_options: typing.Optional[dict] = None,
    ):
        """Initialize solver for a linear variational problem.

        Args:
            a: Bilinear UFL form or a sequence of sequence of bilinear
                forms, the left hand side of the variational problem.
            L: Linear UFL form or a sequence of linear forms, the right
                hand side of the variational problem.
            pnsObject: PNS object
            bc(PnsDirichletBC): Dirichlet boundary condition with PNS DOF.
            lagrange_bcs: Lagrange basis Dirichlet boundary conditions. Defaults to none.
            comm: Unused; kept for backward compatibility. The communicator is
                taken from ``pnsObject['comm']``.
            petsc_options: Options that are passed to the linear
                algebra backend PETSc. For available choices for the
                'petsc_options' kwarg, see the `PETSc documentation
                <https://petsc4py.readthedocs.io/en/stable/manual/ksp/>`_.
            form_compiler_options: Options used in FFCx compilation of
                all forms. Run ``ffcx --help`` at the commandline to see
                all available options.
            jit_options: Options used in CFFI JIT compilation of C
                code generated by FFCx. See `python/dolfinx/jit.py` for
                all available options. Takes priority over all other
                option values.

        Example::

            problem = PnsLinearProblem(a, L, pnsObject, bc, petsc_options={
                "ksp_type": "preonly",
                "pc_type": "lu",
                "pc_factor_mat_solver_type": "mumps"
            })
        """
        self.pnsObject = pnsObject
        # Forward the compilation options; they were previously accepted but ignored.
        self._a = dolfinx.fem.form(a, form_compiler_options=form_compiler_options, jit_options=jit_options)
        self._L = dolfinx.fem.form(L, form_compiler_options=form_compiler_options, jit_options=jit_options)

        self.bc = bc
        self.lagrange_bcs = [] if lagrange_bcs is None else lagrange_bcs

        comm = pnsObject['comm']

        self._solver = PETSc.KSP().create(comm)
        # A unique prefix keeps this solver's options separate from other instances.
        prefix = f"dolfinx_solve_{id(self)}"
        self._solver.setOptionsPrefix(prefix)

        opts = PETSc.Options()
        opts.prefixPush(prefix)
        if petsc_options:
            for key, val in petsc_options.items():
                opts[key] = val
        opts.prefixPop()
        self._solver.setFromOptions()

    def solve(self) -> PETSc.Vec:
        """Assemble the reduced system and solve it.

        Returns:
            Solution vector in the PnS basis (map it to a Lagrange Function with
            ``pnsAssign``).
        """
        A_reduced, A = assemble_matrix_pns(self.pnsObject, self._a, self.bc, self.lagrange_bcs)
        b_reduced, _ = assemble_vector_pns(self.pnsObject, self._L, A, self.bc, self.lagrange_bcs, self._a)
        self._solver.setOperators(A_reduced)
        x_reduced = b_reduced.duplicate()
        self._solver.solve(b_reduced, x_reduced)

        return x_reduced

    def __del__(self):
        # Release the KSP. Guard in case __init__ failed before creating it.
        solver = getattr(self, "_solver", None)
        if solver is not None:
            solver.destroy()

class PnsNewtonSolver(dolfinx.cpp.nls.petsc.NewtonSolver):
    """Newton solver whose unknown is the PnS coefficient vector.

    The Jacobian and residual are assembled in the PnS basis by the problem's ``J`` and
    ``F`` callbacks. ``solve`` maps the resulting PnS vector back to a Lagrange-basis
    Function through the prolongation matrix.
    """

    def __init__(self, comm: MPI.Intracomm, problem):
        """A Newton solver for non-linear problems.

        Args:
            comm: MPI communicator.
            problem: Object providing ``pnsObject``, ``u``, ``F``, ``J`` and ``form``
                (e.g. ``PnsNonlinearProblem``).
        """
        super().__init__(comm)
        self.problem = problem
        pnsObject = problem.pnsObject
        K = pnsObject['prolongation_matrix']
        numPnsDof = K.getSize()[1]
        # Jacobian and residual storage in the PnS basis.
        self._A = PETSc.Mat().createAIJ(size=(numPnsDof, numPnsDof), comm=comm)
        self.setJ(problem.J, self._A)
        self._b = PETSc.Vec().createSeq(numPnsDof, comm=comm)
        self.setF(problem.F, self._b)
        # Initial iterate K^T u. It persists across solve() calls.
        # NOTE(review): K^T is not the inverse of K, so this is not in general the exact
        # PnS representation of u. Confirm this is the intended starting point.
        self._x = PETSc.Vec().createSeq(numPnsDof, comm=comm)
        K.multTranspose(problem.u.x.petsc_vec, self._x)
        self.set_form(problem.form)

    def __del__(self):
        # Release all PETSc objects owned here. Guard in case __init__ failed part-way.
        for name in ("_A", "_b", "_x"):
            obj = getattr(self, name, None)
            if obj is not None:
                obj.destroy()

    def solve(self, u: dolfinx.fem.Function):
        """Solve non-linear problem into function u. Returns the number
        of iterations and if the solver converged."""
        n, converged = super().solve(self._x)
        self.problem.pnsObject['prolongation_matrix'].mult(self._x, u.x.petsc_vec)
        u.x.scatter_forward()
        return n, converged

    @property
    def A(self) -> PETSc.Mat:  # type: ignore
        """Jacobian matrix"""
        return self._A

    @property
    def b(self) -> PETSc.Vec:  # type: ignore
        """Residual vector"""
        return self._b

    def setP(self, P: dolfinx.fem.Function, Pmat: PETSc.Mat):  # type: ignore
        """
        Set the function for computing the preconditioner matrix

        Args:
            P: Function to compute the preconditioner matrix
            Pmat: Matrix to assemble the preconditioner into

        """
        super().setP(P, Pmat)




class PnsNonlinearProblem:
    """Nonlinear problem class for solving the non-linear problems in the PnS basis.

    Solves problems of the form :math:`F(u, v) = 0 \\ \\forall v \\in V` using
    PETSc as the linear algebra backend. The residual and Jacobian are assembled in
    the Lagrange space of ``u`` and reduced with the prolongation matrix K.
    """

    def __init__(
        self,
        F: ufl.form.Form,
        u: dolfinx.fem.Function,
        pnsObject,
        bcs: typing.Optional[list[dolfinx.fem.DirichletBC]] = None,
        J: ufl.form.Form = None,
        form_compiler_options: typing.Optional[dict] = None,
        jit_options: typing.Optional[dict] = None,
    ):
        """Initialize solver for solving a non-linear problem using Newton's method.

        Args:
            F: The PDE residual F(u, v)
            u: The unknown (Lagrange-space Function; updated to K x in ``form``)
            pnsObject: PNS object
            bcs: List of Lagrange-basis Dirichlet boundary conditions (default: none)
            J: UFL representation of the Jacobian (Optional)
            form_compiler_options: Options used in FFCx
                compilation of this form. Run ``ffcx --help`` at the
                command line to see all available options.
            jit_options: Options used in CFFI JIT compilation of C
                code generated by FFCx. See ``python/dolfinx/jit.py``
                for all available options. Takes priority over all
                other option values.

        Example::

            problem = PnsNonlinearProblem(F, u, pnsObject, [bc0, bc1])
        """
        self.u = u
        self._L = _create_form(
            F, form_compiler_options=form_compiler_options, jit_options=jit_options
        )

        # Create the Jacobian matrix, dF/du
        if J is None:
            V = u.function_space
            du = ufl.TrialFunction(V)
            J = ufl.derivative(F, u, du)

        self._a = _create_form(
            J, form_compiler_options=form_compiler_options, jit_options=jit_options
        )
        # Avoid a shared mutable default. A caller-supplied list is kept by reference.
        self.bcs = [] if bcs is None else bcs
        self.pnsObject = pnsObject
        # Work objects reused by every residual/Jacobian assembly.
        self.bTemp = create_vector(self._L)
        self.Atemp = create_matrix(self._a)
        self.AKtemp = PETSc.Mat()
        self.KTAKtemp = PETSc.Mat()

    @property
    def L(self) -> ufl.form.Form:
        """Compiled linear form (the residual form)"""
        return self._L

    @property
    def a(self) -> ufl.form.Form:
        """Compiled bilinear form (the Jacobian form)"""
        return self._a

    def form(self, x: PETSc.Vec) -> None:
        """This function is called before the residual or Jacobian is
        computed. It writes the Lagrange representation K x into ``u``.

        Args:
           x: The vector containing the latest solution (PnS basis)
        """
        self.pnsObject['prolongation_matrix'].mult(x, self.u.x.petsc_vec)
        self.u.x.scatter_forward()

    def F(self, x: PETSc.Vec, b: PETSc.Vec) -> None:
        """Assemble the residual F into the vector b.

        Args:
            x: The vector containing the latest solution (PnS basis)
            b: Vector to assemble the residual into (PnS basis)
        """
        # Lifting and set_bc act on Lagrange-basis vectors, so x0 must be u's vector
        # (set to K x by ``form``). The PnS vector x has a different size.
        assemble_vector_pns(
            self.pnsObject, self._L, lagrange_bcs=self.bcs, bilinear_form=self._a,
            result=b, x0=self.u.x.petsc_vec, alpha=-1.0, b=self.bTemp,
        )

    def J(self, x: PETSc.Vec, A: PETSc.Mat) -> None:
        """Assemble the Jacobian matrix.

        Args:
            x: The vector containing the latest solution (PnS basis)
            A: Matrix that receives the PnS-basis Jacobian
        """
        assemble_matrix_pns(self.pnsObject, self._a, lagrange_bcs=self.bcs, result=A, A=self.Atemp, AK=self.AKtemp, KTAK=self.KTAKtemp)

class PnsDirichletBC:
    """Dirichlet condition imposed directly on PnS degrees of freedom.

    Args:
        pnsObject: PNS object. ``getXg`` uses its 'prolongation_matrix' and 'comm'.
        value: Prescribed values, one per entry of ``dof``.
        dof: Indices of the constrained PnS DOFs.
    """

    def __init__(self, pnsObject, value: np.ndarray, dof: np.ndarray):
        self.pnsObject = pnsObject
        self.dof = dof
        self.value = value

    def setVector(self, x: PETSc.Vec):
        """Overwrite the constrained entries of ``x`` with the prescribed values, then assemble."""
        x.setValues(self.dof, self.value)
        x.assemble()

    def setMatrix(self, mat: PETSc.Mat):
        """Zero the constrained rows and columns of ``mat`` with 1 on their diagonal, then assemble."""
        mat.zeroRowsColumns(self.dof, diag=1)
        mat.assemble()

    def getXg(self):
        """Return a new PnS-sized sequential vector carrying the prescribed values."""
        numPnsDof = self.pnsObject['prolongation_matrix'].getSize()[1]
        lift = PETSc.Vec().createSeq(numPnsDof, comm=self.pnsObject['comm'])
        self.setVector(lift)
        return lift

def loadPns(file_name: str, numElements = 1, set_boundary_gradient=False, comm = MPI.COMM_WORLD):
    """Load a PnS control mesh and build the PNS object used by the FEA helpers.

    Args:
        file_name: Path of the control mesh (read with ``pns.Pns_control_mesh.from_file``).
        numElements: Number of solution components sharing the PnS basis.
        set_boundary_gradient: If True, the augmented patch builders use the mesh
            returned by ``pns.set_boundary_gradient``. Otherwise they use the original mesh.
        comm: MPI communicator stored in the PNS object.

    Returns:
        dict with 'lagrangePatchBuilders', 'patchBuilders',
        'augmentedLagrangePatchBuilders', 'augmentedPatchBuilders', 'numPatches',
        'numVerts', 'control_mesh', 'flat_patches', 'numElements', 'comm' and the
        assembled 'prolongation_matrix'.

    Raises:
        ValueError: if the control mesh yields no patch builders.
    """
    control_mesh = pns.Pns_control_mesh.from_file(file_name)
    if set_boundary_gradient:
        augmented_control_mesh = pns.set_boundary_gradient(control_mesh)
    else:
        augmented_control_mesh = control_mesh

    pnsLagrangePatchBuilders = pns.get_patch_builders(control_mesh)
    # Fail before doing the remaining (redundant) builder constructions.
    if len(pnsLagrangePatchBuilders) == 0:
        raise ValueError("No patches found")
    pnsPatchBuilders = pns.get_patch_builders(control_mesh)
    pnsAugmentedLagrangePatchBuilders = pns.get_patch_builders(augmented_control_mesh)
    pnsAugmentedPatchBuilders = pns.get_patch_builders(augmented_control_mesh)

    for pb in pnsLagrangePatchBuilders:
        pb.toLagrange()
    for pb in pnsAugmentedLagrangePatchBuilders:
        pb.toLagrange()

    # Patches are collected from each builder *before* that builder is degree-raised.
    flatPatches = []
    for pb in pnsPatchBuilders:
        flatPatches.extend(pb.build_patches(control_mesh))
        pb.degRaise()
    for pb in pnsAugmentedPatchBuilders:
        pb.degRaise()
    numPatches = len(flatPatches)

    # NOTE(review): numVerts comes from the original mesh while the prolongation matrix
    # reads neighbour vertices from the augmented builders. Confirm that
    # set_boundary_gradient does not add vertices.
    numVerts = len(control_mesh.get_vertices())
    pnsObject = {
        'lagrangePatchBuilders': pnsLagrangePatchBuilders,
        'patchBuilders': pnsPatchBuilders,
        'augmentedLagrangePatchBuilders': pnsAugmentedLagrangePatchBuilders,
        'augmentedPatchBuilders': pnsAugmentedPatchBuilders,
        'numPatches': numPatches,
        'numVerts': numVerts,
        'control_mesh': control_mesh,
        "flat_patches": flatPatches,
        "numElements": numElements,
        "comm": comm,
    }
    pnsObject["prolongation_matrix"] = assemble_prolongation_matrix(pnsObject)
    return pnsObject

def createPnsDolfinxMesh(pnsObject) -> dolfinx.mesh.Mesh:
    """Build a dolfinx mesh with one cubic quadrilateral cell per Lagrange patch.

    Each patch's 16 Lagrange control points become the geometry nodes of one
    equispaced cubic Lagrange cell. They are reordered with ``cubicLagrangeDofOrder``
    into dolfinx's local node numbering. Every cell gets 16 fresh node indices,
    so nodes are not shared between cells.

    Args:
        pnsObject: PNS object. Reads 'lagrangePatchBuilders', 'control_mesh' and 'comm'.

    Returns:
        The dolfinx mesh with a 3-component geometry.
    """
    x = []
    cells = []
    for pb in pnsObject['lagrangePatchBuilders']:
        for patch in pb.build_patches(pnsObject['control_mesh']):
            cellX = [0] * 16
            # Walk the patch's control points row by row.
            points = (point for row in patch.bb_coefs for point in row)
            for cpIdx, point in enumerate(points):
                cellX[cubicLagrangeDofOrder[cpIdx]] = [point[0], point[1], point[2]]
            firstNode = len(x)
            x.extend(cellX)
            cells.append(list(range(firstNode, firstNode + 16)))
    domain = ufl.Mesh(
        basix.ufl.element(
            "Lagrange",
            "quadrilateral",
            3,
            lagrange_variant=basix.LagrangeVariant.equispaced,
            shape=(3,),
        )
    )

    mesh = dolfinx.mesh.create_mesh(pnsObject["comm"], cells, x, domain)
    return mesh

def pnsFunctionSpace(mesh, pnsObject) -> dolfinx.fem.FunctionSpace:
    """Create the Lagrange space that PnS coefficients are prolongated into.

    Each component uses an equispaced cubic Lagrange quadrilateral element. When
    ``pnsObject['numElements']`` exceeds 1, the components form a mixed element.
    """
    scalarElement = basix.ufl.element(
        "Lagrange",
        "quadrilateral",
        3,
        lagrange_variant=basix.LagrangeVariant.equispaced,
    )
    numComponents = pnsObject['numElements']
    element = (
        basix.ufl.mixed_element([scalarElement] * numComponents)
        if numComponents > 1
        else scalarElement
    )
    return dolfinx.fem.functionspace(mesh, element)



# testProlongationMatrix(pnsObject)

# """## Visualize Pns"""


import numpy as np
import json
from math import comb
from functools import cache

class PnsVisualizer:
    """Collect PnS solution snapshots and write them in the web visualizer's format.

    Args:
        pnsObject: PNS object from ``loadPns``. Its 'patchBuilders',
            'augmentedPatchBuilders' and 'control_mesh' entries are read.
        dt: Stored as ``self.dt``; not used by the methods of this class.
    """

    def __init__(self, pnsObject, dt=0):
        self.pnsObject = pnsObject
        self.dt = dt
        self.funcList = []

    def addFunc(self, valDofs):
        """Store a copy of one snapshot of PnS DOF values."""
        self.funcList.append(valDofs.copy())

    def visualizePns(self, filename, val_range=None):
        """
        Generates a file that can be visualized by https://cise.ufl.edu/~p.gupta/pns-fea-visualizer/

        File layout (one item per line):
        - number of patches;
        - number of snapshots;
        - minimum value;
        - maximum value;
        - then, for each patch: the degrees "3 3", the patch's geometry control points,
          and its 4x4 scalar control values for every snapshot.

        Args:
            filename: Output path.
            val_range: Optional ``(min, max)`` written to the header. Defaults to the
                extremes of all scalar control values.

        Raises:
            ValueError: if no snapshot has been added with ``addFunc``.
        """
        if not self.funcList:
            raise ValueError("No functions to visualize; call addFunc() first")

        allPatchCps = np.array([
            patch.bb_coefs
            for pb in self.pnsObject['patchBuilders']
            for patch in pb.build_patches(self.pnsObject['control_mesh'])
        ])  # shape (P,4,4,3)

        num_patches = allPatchCps.shape[0]
        num_timesteps = len(self.funcList)
        # Build the scalar field's patches by placing each snapshot's values in the
        # x coordinate of a vertex-only control mesh.
        all_valPatchCps = np.array([
            patch.bb_coefs
            for valDofs in self.funcList
            for pb in self.pnsObject['augmentedPatchBuilders']
            for patch in pb.build_patches(
                pns.Pns_control_mesh.from_data(
                    [(v,0,0) for v in valDofs],
                    []
                )
            )
        ])  # shape (num_timesteps*P,4,4,3)
        all_valPatchCps = all_valPatchCps[:, :, :, 0]  # take x-coord as scalar value
        all_valPatchCps = all_valPatchCps.reshape(num_timesteps, num_patches, 4, 4)  # shape (num_timesteps,P,4,4)
        if val_range is None:
            min_val = np.min(all_valPatchCps)
            max_val = np.max(all_valPatchCps)
        else:
            min_val, max_val = val_range
        out = [str(num_patches), str(num_timesteps), str(min_val), str(max_val)]

        for p in range(num_patches):
            out.append("3 3")  # u-degree, v-degree
            out.append(" ".join(map(str, allPatchCps[p].flatten().tolist())))
            out.append(" ".join(map(str, all_valPatchCps[:, p].flatten().tolist())))

        out_str = "\n".join(out)

        with open(filename, "w") as f:
            f.write(out_str)


# Cahn Hilliard

In [None]:
"""
Useful Docker Commands

Build Image
docker build -t pns_fenicsx .

#Run .ipynb File
jupyter nbconvert --to python cahn_hilliard.ipynb
python3 cahn_hilliard.py

#Copy File from inside container
docker cp cahn_hilliard:/app/cahn-hilliard_single.bvx .

#Running Container
docker run -it -v ${PWD}:/app --name cahn_hilliard pns_fenicsx
"""

In [None]:
from dolfinx import default_real_type
from dolfinx.fem import Function, locate_dofs_geometrical, dirichletbc, locate_dofs_topological
from dolfinx.mesh import meshtags, locate_entities, locate_entities_boundary, exterior_facet_indices, entities_to_geometry
from ufl import dx, grad, inner, Measure, SpatialCoordinate
import meshio

In [None]:
#@title Parameters
lmbda = 0.2    # surface parameter
M = 1          # mobility factor multiplying the gradient terms
dt = 1e-5      # time step
# The filename decides whether a boundary exists: a mesh with a boundary must
# contain "plane" in its name.
filename = "OBJ/plane_2x2.obj"
numSteps = 500  # number of time steps
err = 1e-3      # tolerance used by the geometric boundary test
# Alternative setup:
# lmbda = 1.0  # surface parameter
# M = 1
# dt = 5.0e-4  # time step
# filename = "tee_subd3.obj"

In [7]:
# @title Load PnS and create function space
pnsObject = loadPns(
    filename,
    set_boundary_gradient=True,
    comm=MPI.COMM_WORLD,
)
mesh = createPnsDolfinxMesh(pnsObject)
V = pnsFunctionSpace(mesh, pnsObject)

In [None]:
# @title Test function and solution functions
v = ufl.TestFunction(V)

u = Function(V)   # solution at the current time step
un = Function(V)  # solution of the previously converged time step

# Initial condition: 0.63 plus a uniform random perturbation of roughly +-0.01.
u.x.array[:] = 0.0
rng = np.random.default_rng(900)
u.interpolate(lambda x: 0.63 + 0.02 * (0.5 - rng.random(x.shape[1])))

In [None]:
# @title Mesh bounding box (for naive boundary detection until a B-rep is available)
coords = mesh.geometry.x  # geometry node coordinates, one row per node (3 columns in dolfinx)
bmin, bmax = coords.min(axis=0), coords.max(axis=0)  # per-axis minimum / maximum

In [None]:
# Boundary condition locators
def circle(x, atol=None):
    """Mask points in the annulus 0.05 <= sqrt(x0^2 + x1^2) <= 0.1 that are not on the boundary.

    Args:
        x: Coordinates indexed as ``x[0]``, ``x[1]``, ``x[2]`` (dolfinx locators receive
            an array of shape (3, N)).
        atol: Tolerance forwarded to ``boundary``. Defaults to the global ``err``.
    """
    # NOTE(review): the radius uses x[0] and x[1], while ``boundary`` tests x[0] and x[2].
    # Confirm which plane the annulus is meant to lie in.
    radius = np.sqrt(x[0]**2 + x[1]**2)
    return np.logical_and(
        np.logical_and(np.less_equal(radius, 0.1), np.greater_equal(radius, 0.05)),
        ~boundary(x, atol=atol),
    )


def boundary(x, atol=None):
    """Mask points lying (within ``atol``) on the planes x=0, x=1, z=0 or z=1.

    Args:
        x: Coordinates indexed as ``x[0]`` and ``x[2]``.
        atol: Absolute tolerance. Defaults to the global ``err``.
    """
    if atol is None:
        atol = err
    return (
        np.isclose(x[0], 0, atol=atol) | np.isclose(x[0], 1, atol=atol)
        | np.isclose(x[2], 0, atol=atol) | np.isclose(x[2], 1, atol=atol)
    )


In [None]:
# Detect boundary facets naively. Eventually the PnS library should report which
# vertices lie on the boundary.

boundaries = [
    (0, boundary),
]

# Build facet tags for the boundary conditions; they define the ``ds`` measure.
facet_indices, tag_values = [], []
fdim = mesh.topology.dim - 1
# exterior_facet_indices needs facet-to-cell connectivity.
mesh.topology.create_connectivity(fdim, mesh.topology.dim)

for marker, locator in boundaries:
    if marker == 0:
        # Marker 0 tags every exterior facet; the locator is not used.
        # (Earlier note here: "Not Working Right Now".)
        facets = exterior_facet_indices(mesh.topology)
    else:
        facets = locate_entities(mesh, fdim, locator)
    facet_indices.append(facets)
    tag_values.append(np.full_like(facets, marker))
facet_indices = np.hstack(facet_indices).astype(np.int32)
facet_markers = np.hstack(tag_values).astype(np.int32)
# meshtags expects sorted entity indices; concatenating several markers breaks the order.
sorted_facets = np.argsort(facet_indices)
ftags = dolfinx.mesh.meshtags(
    mesh, fdim, facet_indices[sorted_facets], facet_markers[sorted_facets]
)

# NOTE(review): this measure is restricted to subdomain id 1, but the only tag produced
# above is 0. A bare ``ds`` therefore selects no tagged facets, while ``ds(0)`` selects
# the tagged ones. Confirm this is intended.
ds = ufl.Measure("ds", domain=mesh, subdomain_data=ftags, metadata={"quadrature_degree": 10})(1)
n = ufl.FacetNormal(mesh)
# ds does not make sense if we don't have an axis-aligned plane boundary.
# NOTE(review): with ds = 0, later calls such as ds(marker) will fail.
if "plane" not in filename:
    print("DS = 0")
    ds = 0

In [None]:
# Boundary Condition Class
# @title Boundary Conditions Class

class BoundaryCondition:
    """Build a boundary condition term for the current problem.

    Args:
        type_: "Dirichlet", "Neumann" or "Robin".
        marker: Facet tag (in the global ``ftags``) the condition applies to.
        value:
            Dirichlet: value passed to ``dirichletbc`` together with ``V``.
            Neumann: UFL expression g; yields ``inner(g, v) * ds(marker)``.
            Robin: tuple ``(alpha, g)`` with numeric alpha; yields
                ``alpha * inner(u - g, v) * ds(marker)``.

    Uses the notebook globals ``ftags``, ``fdim``, ``V``, ``u``, ``v`` and ``ds``.

    Raises:
        ValueError: for an unknown type or a malformed Robin value.
    """

    def __init__(self, type_: str, marker: int, value):
        self._type = type_
        if self._type == "Dirichlet":
            facets = ftags.find(marker)
            # NOTE(review): the original author flagged a problem on this line; not diagnosed.
            dofs = locate_dofs_topological(V, fdim, facets)
            self._bc = dirichletbc(value, dofs, V)
        elif self._type == "Neumann":
            self._bc = ufl.inner(value, v) * ds(marker)
        elif (
            self._type == "Robin"
            and isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], (int, float))
        ):
            self._bc = value[0] * inner(u - value[1], v) * ds(marker)
        else:
            # Previously only printed, leaving ``bc`` unset and failing later.
            raise ValueError(f"Unknown Boundary Condition: type={type_!r}, value={value!r}")

    @property
    def bc(self):
        """The DirichletBC (Dirichlet) or UFL boundary form (Neumann/Robin)."""
        return self._bc

    @property
    def type(self):
        """The condition type string."""
        return self._type

In [None]:
# @title Formulation
# c wraps u as a ufl.variable so the free energy can be differentiated with respect to it.
c = ufl.variable((u))
# Double-well free energy f(c) = 100 c^2 (1 - c)^2 and its derivative fp = df/dc.
f = 100 * c**2 * (1 - c) ** 2
fp = ufl.diff(f, c)
# Residual of one time step. Every term except the time difference is evaluated at the
# unknown u, so the step is implicit in time.
F = (
    (u - un)/dt * v * dx  # discrete time derivative
    + M*ufl.inner(ufl.grad(fp), ufl.grad(v))* dx  # gradient of f'(c) tested with grad(v)
    # Facet term; presumably from integrating the fourth-order term by parts (TODO confirm).
    - lmbda*M*ufl.inner(ufl.div(ufl.grad(c)), ufl.dot(ufl.grad(v), n))* ds
    # Second derivatives of both c and v inside each cell.
    + lmbda*M*ufl.inner(ufl.div(ufl.grad(c)), ufl.div(ufl.grad(v)))* dx
)

In [None]:
def u_x(x):
    """Return 1 + x[0]^2 + x[1]^2 + x[2]^2 (works for arrays and UFL coordinates)."""
    x0, x1, x2 = x[0], x[1], x[2]
    return 1 + x0**2 + x1**2 + x2**2
x = SpatialCoordinate(mesh)
boundary_conditions = [
    BoundaryCondition("Neumann", 0, u_x(x)),
]

# Dirichlet conditions are imposed strongly via `bcs`; every other kind is a
# boundary integral added to the residual F.
bcs = []
for condition in boundary_conditions:
    if condition.type != "Dirichlet":
        F += condition.bc
        continue
    bcs.append(condition.bc)



In [None]:
# @title Initialize solver
problem = PnsNonlinearProblem(F, u, pnsObject, bcs=bcs)
solver = PnsNewtonSolver(comm=MPI.COMM_WORLD, problem=problem)

In [None]:
# @title Set solver parameters
solver.max_it = 100
solver.error_on_nonconvergence = False  # keep time stepping even if a step does not converge
solver.convergence_criterion = "incremental"
solver.rtol = np.sqrt(np.finfo(default_real_type).eps) * 1e-2
ksp = solver.krylov_solver
opts = PETSc.Options()  # type: ignore
option_prefix = ksp.getOptionsPrefix()
opts[f"{option_prefix}ksp_type"] = "preonly"
opts[f"{option_prefix}pc_type"] = "lu"

# Named petsc_sys so the stdlib `sys` module name is not shadowed.
petsc_sys = PETSc.Sys()  # type: ignore
# For factorisation prefer superlu_dist, then MUMPS, then default
if petsc_sys.hasExternalPackage("superlu_dist"):
    opts[f"{option_prefix}pc_factor_mat_solver_type"] = "superlu_dist"
    print("superlu_dist")
elif petsc_sys.hasExternalPackage("mumps"):
    opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps"
    print("mumps")
ksp.setFromOptions()

superlu_dist


In [None]:
# @title Time step and solve
t = 0.0
T = numSteps * dt
un.x.array[:] = u.x.array

# Snapshots for the html visualizer; only every `store_every`-th step is stored.
visualizer = PnsVisualizer(pnsObject)
store_every = 20
# Loop over an integer step counter. Accumulating `t += dt` in floating point can
# run one step too many against T, and int(t / dt) can misreport the step number.
for i in range(numSteps):
    step = i + 1
    t = step * dt

    r = solver.solve(u)  # (number of Newton iterations, converged flag)
    error = np.linalg.norm(solver.b.array)  # norm of the last assembled residual
    print(f"Step {step}: num iterations: {r[0]} Error: {error:.2e}")

    un.x.array[:] = u.x.array

    if i % store_every == 0:
        visualizer.addFunc(solver._x.array)

In [None]:
# @title Save visualization
# The output file can be viewed at https://cise.ufl.edu/~p.gupta/pns-fea-visualizer/
visualizer.visualizePns("cahn-hilliard_single.bvx", val_range=(0, 1))