In [1]:
# Import some generally useful packages.
import jax
import jax.numpy as np
import os

In [2]:
# Import JAX-FEM specific modules.
from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh

       __       ___      ___   ___                _______  _______ .___  ___. 
      |  |     /   \     \  \ /  /               |   ____||   ____||   \/   | 
      |  |    /  ^  \     \  V  /      ______    |  |__   |  |__   |  \  /  | 
.--.  |  |   /  /_\  \     >   <      |______|   |   __|  |   __|  |  |\/|  | 
|  `--'  |  /  _____  \   /  .  \                |  |     |  |____ |  |  |  | 
 \______/  /__/     \__\ /__/ \__\               |__|     |_______||__|  |__| 
                                                                              



# Poisson's Equation

In [3]:
# Define the Poisson problem: constitutive relationship, source term and surface maps.
class Poisson(Problem):
    """Poisson problem with an identity tensor map, a Gaussian-shaped mass
    map centred at (0.5, 0.5) and a sinusoidal surface map.

    Generally, JAX-FEM solves -div.f(u_grad) = b. The methods below supply f
    (tensor map), the source term (mass map) and the surface maps.
    """

    # Overrides the base class method. Here f is the identity function; later
    # examples define a more complicated f to solve non-linear problems.
    def get_tensor_map(self):
        def identity(u_grad):
            return u_grad

        return identity

    # Source term b.
    def get_mass_map(self):
        def mass_map(u, x):
            # Squared distance of x from the point (0.5, 0.5).
            dx = x[0] - 0.5
            dy = x[1] - 0.5
            dist_sq = np.power(dx, 2) + np.power(dy, 2)
            # Negated Gaussian bump, wrapped as a one-component array.
            return -np.array([10*np.exp(-dist_sq / 0.02)])

        return mass_map

    # Surface maps; the same map is returned twice.
    # NOTE(review): presumably one entry per boundary in 'location_fns' - confirm.
    def get_surface_maps(self):
        def surface_map(u, x):
            wave = np.sin(5.*x[0])
            return -np.array([wave])

        return [surface_map] * 2

In [4]:
# Specify mesh-related information.
# The external package 'meshio' is used to create a mesh named 'meshio_mesh',
# which is then converted into a JAX-FEM compatible 'Mesh'.
ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)

# Domain extents in x and y.
Lx = 1.
Ly = 1.

meshio_mesh = rectangle_mesh(
    Nx=32,
    Ny=32,
    domain_x=Lx,
    domain_y=Ly,
)
mesh = Mesh(
    meshio_mesh.points,
    meshio_mesh.cells_dict[cell_type],
)

In [5]:
# Define boundary locations.
def left(point):
    """Return whether the first coordinate of `point` is close to 0.

    Uses `np.isclose` with an absolute tolerance of 1e-5.
    """
    x_coord = point[0]
    return np.isclose(x_coord, 0., atol=1e-5)

def right(point):
    """Return whether the first coordinate of `point` is close to `Lx`.

    `Lx` is the module-level domain extent in x. Uses `np.isclose` with an
    absolute tolerance of 1e-5.
    """
    x_coord = point[0]
    return np.isclose(x_coord, Lx, atol=1e-5)

def bottom(point):
    """Return whether the second coordinate of `point` is close to 0.

    Uses `np.isclose` with an absolute tolerance of 1e-5.
    """
    y_coord = point[1]
    return np.isclose(y_coord, 0., atol=1e-5)

def top(point):
    """Return whether the second coordinate of `point` is close to `Ly`.

    `Ly` is the module-level domain extent in y. Uses `np.isclose` with an
    absolute tolerance of 1e-5.
    """
    y_coord = point[1]
    return np.isclose(y_coord, Ly, atol=1e-5)

In [6]:
# Define Dirichlet boundary values. 
# This means on the 'left' side, we apply the function 'dirichlet_val_left' 
# to the 0 component of the solution variable; on the 'right' side, we apply 
# 'dirichlet_val_right' to the 0 component.
def dirichlet_val_left(point):
    """Dirichlet value function paired with the 'left' location.

    Ignores `point` and always returns zero.
    """
    return 0.0

def dirichlet_val_right(point):
    """Dirichlet value function paired with the 'right' location.

    Ignores `point` and always returns zero.
    """
    return 0.0

# The three lists are parallel: entry i gives the boundary location, the
# solution component it applies to and the value function for that boundary.
location_fns = [left, right]
vecs = [0, 0]
value_fns = [dirichlet_val_left, dirichlet_val_right]

# Bundle in the order [locations, components, values].
dirichlet_bc_info = [location_fns, vecs, value_fns]

In [7]:
# Define Neumann boundary locations.
# On the 'bottom' and 'top' sides, the surface integral is performed with the
# maps returned by 'get_surface_maps' defined in the class 'Poisson'.
# NOTE: this rebinds the name 'location_fns' that previously held the Dirichlet
# locations; 'dirichlet_bc_info' still references that earlier list object.
location_fns = [bottom, top]

In [8]:
# Instantiate the class 'Poisson'.
# 'vec' is the number of components of the solution; 'location_fns' holds the
# Neumann boundary locations.
problem = Poisson(
    mesh=mesh,
    vec=1,
    dim=2,
    ele_type=ele_type,
    dirichlet_bc_info=dirichlet_bc_info,
    location_fns=location_fns,
)

[01-08 21:34:48][DEBUG] jax_fem: Computing shape function values, gradients, etc.
[01-08 21:34:48][DEBUG] jax_fem: ele_type = QUAD4, quad_points.shape = (num_quads, dim) = (4, 2)
[01-08 21:34:48][DEBUG] jax_fem: face_quad_points.shape = (num_faces, num_face_quads, dim) = (4, 2, 2)
[01-08 21:34:48][DEBUG] jax_fem: Done pre-computations, took 0.39591455459594727 [s]
[01-08 21:34:48][INFO] jax_fem: Solving a problem with 1024 cells, 1089x1 = 1089 dofs.


In [9]:
# Solve the problem.
# No 'solver_options' are passed here; the commented-out calls below show how
# other linear solvers can be selected through 'solver_options'.
sol = solver(problem)
# sol = solver(problem, solver_options={'umfpack_solver': {}})
# sol = solver(problem, solver_options={'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'ilu'}})

[01-08 21:34:53][DEBUG] jax_fem: Calling the row elimination solver for imposing Dirichlet B.C.
[01-08 21:34:53][DEBUG] jax_fem: Start timing
[01-08 21:34:53][DEBUG] jax_fem: Computing cell Jacobian and cell residual...
[01-08 21:34:54][DEBUG] jax_fem: Function split_and_compute_cell took 0.1621 seconds
[01-08 21:34:54][DEBUG] jax_fem: Creating sparse matrix with scipy...
[01-08 21:34:54][DEBUG] jax_fem: Before, res l_2 = 0.18688758627660917
[01-08 21:34:54][DEBUG] jax_fem: Solving linear system...
[01-08 21:34:54][DEBUG] jax_fem: JAX Solver - Solving linear system
[01-08 21:34:55][DEBUG] jax_fem: JAX Solver - Finshed solving, res = 8.126895132795366e-11
[01-08 21:34:55][DEBUG] jax_fem: Computing cell Jacobian and cell residual...
[01-08 21:34:55][DEBUG] jax_fem: Function split_and_compute_cell took 0.0044 seconds
[01-08 21:34:55][DEBUG] jax_fem: Creating sparse matrix with scipy...
[01-08 21:34:55][DEBUG] jax_fem: res l_2 = 8.126885865709939e-11
[01-08 21:34:55][INFO] jax_fem: Solve t

In [10]:
# Save the solution to a local folder so that it can be visualized with ParaView.
# The output goes to '<current working directory>/02_data/vtk/u.vtu'.
data_dir = os.path.join(os.getcwd(), '02_data')
vtk_path = os.path.join(data_dir, 'vtk/u.vtu')
# Write the first solution field using the problem's first finite-element object.
save_sol(problem.fes[0], sol[0], vtk_path)