# Poisson Equation Solver using Firedrake FEM-CG

## Summary

This notebook solves the 2D Poisson equation:
$$-\nabla^2 u = f(x,y)$$

on the unit square $[0,1]\times[0,1]$ with:
- Homogeneous Dirichlet BCs at $x=0$ and $x=1$: $u(0,y) = u(1,y) = 0$
- Homogeneous Neumann BCs at $y=0$ and $y=1$ (natural BCs, automatically satisfied)
- Source term: $f(x,y) = 2\pi^2 \sin(\pi x) \cos(\pi y)$
- Exact solution: $u(x,y) = \sin(\pi x) \cos(\pi y)$

## Methods

1. **Method 1**: Manual weak form construction
2. **Method 2**: Variational principle via derivative

Both methods are implemented in an object-oriented design.

In [1]:
from firedrake import *
import os
from abc import ABC, abstractmethod



## Base Class: PoissonSolver

In [2]:
class PoissonSolver(ABC):
    """
    Abstract base class for solving the 2D Poisson equation -∇²u = f
    on the unit square with homogeneous Dirichlet BCs at x=0 and x=1.
    """
    
    def __init__(self, nx=128, ny=128, quadrilateral=True, degree=1):
        """
        Initialize the Poisson solver.
        
        Parameters:
        -----------
        nx : int
            Number of elements in x-direction
        ny : int
            Number of elements in y-direction
        quadrilateral : bool
            Use quadrilateral elements if True, triangular if False
        degree : int
            Polynomial degree for the function space
        """
        self.nx = nx
        self.ny = ny
        self.degree = degree
        
        # Create mesh
        self.mesh = UnitSquareMesh(nx, ny, quadrilateral=quadrilateral)
        
        # Define function space (Continuous Galerkin)
        self.V = FunctionSpace(self.mesh, 'CG', degree)
        
        # Get spatial coordinates
        self.x, self.y = SpatialCoordinate(self.mesh)
        
        # Define test function
        self.v = TestFunction(self.V)
        
        # Initialize solution
        self.solution = None
        
        # Define source term and exact solution
        self._setup_problem()
        
        # Setup boundary conditions
        self._setup_boundary_conditions()
    
    def _setup_problem(self):
        """
        Define the source term f(x,y) = 2π² sin(πx) cos(πy)
        and exact solution u(x,y) = sin(πx) cos(πy).
        """
        # Source term
        self.f = Function(self.V).interpolate(
            2*pi**2*sin(pi*self.x)*cos(pi*self.y)
        )
        
        # Exact solution (for error computation)
        self.u_exact = Function(self.V).interpolate(
            sin(pi*self.x)*cos(pi*self.y)
        )
    
    def _setup_boundary_conditions(self):
        """
        Setup Dirichlet boundary conditions at x=0 and x=1.
        """
        # u = 0 at x=0 (boundary ID 1)
        bc_x0 = DirichletBC(self.V, Constant(0), 1)
        # u = 0 at x=1 (boundary ID 2)
        bc_x1 = DirichletBC(self.V, Constant(0), 2)
        
        self.bcs = [bc_x0, bc_x1]
    
    @abstractmethod
    def solve(self):
        """
        Solve the Poisson equation. Must be implemented by subclasses.
        """
        pass
    
    def compute_l2_error(self):
        """
        Compute L2 error: ||u_numerical - u_exact||_L2
        
        Returns:
        --------
        float
            L2 norm of the error
        """
        if self.solution is None:
            raise ValueError("Solution not computed yet. Call solve() first.")
        
        error = sqrt(assemble(dot(
            self.solution - self.u_exact,
            self.solution - self.u_exact
        ) * dx))
        
        return error
    
    def get_solution(self):
        """
        Get the computed solution.
        
        Returns:
        --------
        Function
            The solution function
        """
        return self.solution
    
    def get_mesh_resolution(self):
        """
        Get the mesh resolution Δx.
        
        Returns:
        --------
        float
            Mesh spacing in x-direction
        """
        return 1.0 / self.nx

## Method 1: Direct Weak Form Solver

In [3]:
class DirectWeakFormSolver(PoissonSolver):
    """
    Solve Poisson equation using direct weak form construction.
    
    Weak formulation of -∇²u = f:
        ∫ ∇u·∇v dx = ∫ f·v dx  for all test functions v
    
    This is obtained by multiplying the PDE by a test function v,
    integrating over the domain, and applying integration by parts.
    """
    
    def __init__(self, nx=128, ny=128, quadrilateral=True, degree=1,
                 solver_params=None):
        """
        Initialize the direct weak form solver.
        
        Parameters:
        -----------
        nx, ny, quadrilateral, degree : see PoissonSolver
        solver_params : dict
            PETSc solver parameters
        """
        super().__init__(nx, ny, quadrilateral, degree)
        
        # Default solver parameters
        if solver_params is None:
            self.solver_params = {'ksp_type': 'cg', 'pc_type': 'none'}
        else:
            self.solver_params = solver_params
        
        # Define trial function
        self.u = TrialFunction(self.V)
        
        # Setup bilinear and linear forms
        self._setup_forms()
    
    def _setup_forms(self):
        """
        Setup bilinear form a(u,v) and linear form L(v).
        """
        # Bilinear form: a(u,v) = ∫ ∇u·∇v dx
        self.a = inner(grad(self.u), grad(self.v)) * dx
        
        # Linear form: L(v) = ∫ f·v dx
        self.L = self.f * self.v * dx
    
    def solve(self):
        """
        Solve the linear system a(u,v) = L(v).
        """
        # Create function to hold solution
        self.solution = Function(self.V, name='Direct_WeakForm')
        
        # Solve linear system
        firedrake.solve(
            self.a == self.L,
            self.solution,
            bcs=self.bcs,
            solver_parameters=self.solver_params
        )
        
        return self.solution

## Method 2: Variational Principle Solver

In [4]:
class VariationalPrincipleSolver(PoissonSolver):
    """
    Solve Poisson equation using variational principle.
    
    Define the Ritz-Galerkin functional (energy functional):
        J[u] = ∫ (½|∇u|² - fu) dx
    
    The solution minimizes this functional. Taking the first variation
    (derivative) with respect to u gives the weak form automatically.
    """
    
    def __init__(self, nx=128, ny=128, quadrilateral=True, degree=1,
                 solver_params=None):
        """
        Initialize the variational principle solver.
        
        Parameters:
        -----------
        nx, ny, quadrilateral, degree : see PoissonSolver
        solver_params : dict
            PETSc solver parameters
        """
        super().__init__(nx, ny, quadrilateral, degree)
        
        # Default solver parameters
        if solver_params is None:
            self.solver_params = {}
        else:
            self.solver_params = solver_params
        
        # Create function for solution (needed for functional)
        self.solution = Function(self.V, name='Variational_Principle')
        
        # Setup variational functional
        self._setup_functional()
    
    def _setup_functional(self):
        """
        Setup the Ritz-Galerkin functional J[u] and its first variation.
        """
        # Ritz-Galerkin functional: J[u] = ∫ (½ ∇u·∇u - u·f) dx
        self.J = (
            0.5 * inner(grad(self.solution), grad(self.solution)) -
            self.solution * self.f
        ) * dx
        
        # First variation (Gateaux derivative): F(u,v) = dJ/du
        self.F = derivative(self.J, self.solution, du=self.v)
    
    def solve(self):
        """
        Solve the nonlinear system F(u,v) = 0.
        """
        # Solve nonlinear system (linear for Poisson)
        firedrake.solve(
            self.F == 0,
            self.solution,
            bcs=self.bcs,
            solver_parameters=self.solver_params
        )
        
        return self.solution

## Analysis and Visualization Class

In [None]:
class PoissonAnalyzer:
    """
    Class for analyzing and comparing multiple Poisson solver solutions.
    """
    
    def __init__(self, solvers):
        """
        Initialize analyzer with a list of solvers.
        
        Parameters:
        -----------
        solvers : list of PoissonSolver
            List of solver instances to analyze
        """
        self.solvers = solvers
        self.results = {}
    
    def solve_all(self):
        """
        Solve all problems and store results.
        """
        for i, solver in enumerate(self.solvers):
            solver.solve()
            error = solver.compute_l2_error()
            
            self.results[i] = {
                'solver': solver,
                'solution': solver.get_solution(),
                'l2_error': error,
                'mesh_resolution': solver.get_mesh_resolution()
            }
    
    def compare_solutions(self):
        """
        Compare solutions between different solvers.
        
        Returns:
        --------
        dict
            Dictionary containing L2 norms of differences
        """
        if len(self.solvers) < 2:
            return {}
        
        differences = {}
        
        for i in range(len(self.solvers)):
            for j in range(i+1, len(self.solvers)):
                sol_i: Function = self.results[i]['solution']
                sol_j: Function = self.results[j]['solution']
                
                diff = sqrt(assemble(dot(sol_i - sol_j, sol_i - sol_j) * dx))
                differences[f'solver_{i}_vs_{j}'] = diff
        
        return differences
    
    def print_summary(self):
        """
        Print summary of all results.
        """
        print("="*70)
        print("POISSON EQUATION SOLVER RESULTS")
        print("="*70)
        
        for i, result in self.results.items():
            solver_name = type(result['solver']).__name__
            print(f"\nSolver {i+1}: {solver_name}")
            print(f"  Mesh resolution: Δx = {result['mesh_resolution']:.6f}")
            print(f"  L2 error: {result['l2_error']:.6e}")
        
        # Print comparisons
        differences = self.compare_solutions()
        if differences:
            print("\n" + "-"*70)
            print("Comparison between methods:")
            for key, diff in differences.items():
                print(f"  L2 norm ({key}): {diff:.6e}")
        
        print("="*70)
    
    def save_output(self, filename='output.pvd', output_dir=None):
        """
        Save all solutions to ParaView file.
        
        Parameters:
        -----------
        filename : str
            Output filename
        output_dir : str, optional
            Output directory (defaults to current directory)
        """
        if output_dir is not None:
            filepath = os.path.join(output_dir, filename)
        else:
            filepath = filename
        
        outfile = VTKFile(filepath)
        
        # Write all solutions
        solutions = [result['solution'] for result in self.results.values()]
        outfile.write(*solutions)
        
        print(f"\nOutput saved to: {filepath}")
        print(f"View with: paraview {filepath}")

## Run the Solvers

In [6]:
# Set mesh resolution
nx = ny = 128

# Create solver instances
solver1 = DirectWeakFormSolver(nx=nx, ny=ny, quadrilateral=True, degree=1)
solver2 = VariationalPrincipleSolver(nx=nx, ny=ny, quadrilateral=True, degree=1)

# Create analyzer
analyzer = PoissonAnalyzer([solver1, solver2])

# Solve all problems
analyzer.solve_all()

# Print summary
analyzer.print_summary()

# Save output
output_dir = os.path.dirname(os.path.abspath('__file__')) if '__file__' in dir() else os.getcwd()
analyzer.save_output('output.pvd', output_dir=output_dir)

POISSON EQUATION SOLVER RESULTS

Solver 1: DirectWeakFormSolver
  Mesh resolution: Δx = 0.007812
  L2 error: 2.509643e-05

Solver 2: VariationalPrincipleSolver
  Mesh resolution: Δx = 0.007812
  L2 error: 2.509643e-05


ValueError: Multiple domains found, making the choice of integration domain ambiguous.

## Optional: Convergence Study

In [None]:
def convergence_study(mesh_sizes=[16, 32, 64, 128], solver_class=DirectWeakFormSolver):
    """
    Perform a convergence study for different mesh resolutions.
    
    Parameters:
    -----------
    mesh_sizes : list of int
        List of mesh resolutions to test
    solver_class : class
        Solver class to use (DirectWeakFormSolver or VariationalPrincipleSolver)
    
    Returns:
    --------
    dict
        Dictionary with mesh sizes and corresponding errors
    """
    results = {'mesh_sizes': [], 'resolutions': [], 'errors': []}
    
    for n in mesh_sizes:
        solver = solver_class(nx=n, ny=n)
        solver.solve()
        error = solver.compute_l2_error()
        
        results['mesh_sizes'].append(n)
        results['resolutions'].append(1.0/n)
        results['errors'].append(error)
        
        print(f"n = {n:3d}, Δx = {1.0/n:.6f}, L2 error = {error:.6e}")
    
    return results

# Example usage:
# print("\nConvergence study for Direct Weak Form Solver:")
# convergence_results = convergence_study([16, 32, 64, 128], DirectWeakFormSolver)