In [1]:
# Import some useful modules.
import jax.numpy as np
import numpy as onp
import os
import pypardiso
import scipy

In [2]:
# Import JAX-FEM specific modules.
from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh
from jax_fem import logger

       __       ___      ___   ___                _______  _______ .___  ___. 
      |  |     /   \     \  \ /  /               |   ____||   ____||   \/   | 
      |  |    /  ^  \     \  V  /      ______    |  |__   |  |__   |  \  /  | 
.--.  |  |   /  /_\  \     >   <      |______|   |   __|  |   __|  |  |\/|  | 
|  `--'  |  /  _____  \   /  .  \                |  |     |  |____ |  |  |  | 
 \______/  /__/     \__\ /__/ \__\               |__|     |_______||__|  |__| 
                                                                              



In [3]:
# Emit log records of DEBUG severity and above from the jax_fem logger.
import logging
logger.setLevel(logging.DEBUG)

# Linear elasticity

- PARallel DIrect SOlver (PARDISO) is a software library written in Fortran and C for the dedicated purpose of efficiently solving large sparse linear systems of equations, such as $Ax=b$. `pypardiso` is a Python interface to it.
- Portable, Extensible Toolkit for Scientific Computation (PETSc) is a library that is able to solve linear, non-linear, and time-dependent problems. Although it overlaps with PARDISO in solving sparse linear systems, PETSc offers wider functionality beyond that and is optimized for high-performance computing.
- In `pardiso_solver`, first the PETSc sparse matrix is converted to a scipy sparse matrix. Then, a PARDISO solver is invoked to solve the algebraic equation `Ax = b`.

In [4]:
def pardiso_solver(A, b, x0, solver_options):
    """
    Solve the linear system Ax = b with PARDISO (via pypardiso).

    A: PETSc sparse matrix; read out in CSR form and rebuilt as a scipy sparse array
    b: JAX array; converted to a NumPy array before the solve
    x0: JAX array (forward problem) or None (adjoint problem). Accepted so the
        signature matches the custom-solver interface; it is not read here.
    solver_options: anything the user defines, at least satisfying
        solver_options['custom_solver'] = pardiso_solver. Only logged here.

    Returns whatever pypardiso.spsolve returns for the converted system.
    """
    logger.debug("Pardiso Solver - Solving linear system")
    # Report the options through the logger rather than a bare print, so the
    # message obeys the log level configured for the notebook.
    logger.debug(f"Pardiso Solver - solver_options = {solver_options}")

    # Convert PETSc -> scipy: getValuesCSR yields the three CSR arrays, which
    # scipy expects in (data, indices, indptr) order.
    indptr, indices, data = A.getValuesCSR()
    A_sp_scipy = scipy.sparse.csr_array((data, indices, indptr), shape=A.getSize())
    x = pypardiso.spsolve(A_sp_scipy, onp.array(b))
    return x

The Lamé coefficients are defined as $\mu = \frac{E}{2(1 + \nu)}$ and $\lambda = \frac{E\nu}{(1 + \nu)(1 - 2\nu)}$.

In [5]:
# Material properties.
E = 7.0e4   # Young's modulus
nu = 0.3    # Poisson's ratio

# Derived coefficients: mu = E / (2 (1 + nu)), lambda = E nu / ((1 + nu) (1 - 2 nu)).
mu = E / (2.0 * (1.0 + nu))
lmbda = (E * nu) / ((1.0 + nu) * (1.0 - 2.0 * nu))

- `get_tensor_map` defines a function of the displacement gradient $\nabla u$.

In [6]:
# Weak forms.
class LinearElasticity(Problem):
    """
    Linear elasticity with a constant traction on the Neumann boundary.

    The stress law reads the module-level coefficients `lmbda` and `mu`.
    """

    # Overrides the base-class method 'get_tensor_map'. Generally, JAX-FEM
    # solves -div(f(u_grad)) = b; here f(u_grad) = sigma.
    def get_tensor_map(self):
        def stress(u_grad):
            # Symmetric part of the displacement gradient.
            strain = 0.5 * (u_grad + u_grad.T)
            # sigma = lambda * tr(strain) * I + 2 * mu * strain
            volumetric_part = lmbda * np.trace(strain) * np.eye(self.dim)
            shear_part = 2 * mu * strain
            return volumetric_part + shear_part
        return stress

    # Neumann boundary condition: one map returning the traction t = [0, 0, 100],
    # independent of the solution u and the position x.
    def get_surface_maps(self):
        def surface_map(u, x):
            traction = np.array([0., 0., 100.])
            return traction
        return [surface_map]

In [7]:
# Mesh-related information (second-order tetrahedron element).
ele_type = 'TET10'
cell_type = get_meshio_cell_type(ele_type)

# Data directory, resolved against the current working directory.
# (Script-style alternative: os.path.join(os.path.dirname(__file__), '03_data').)
data_dir = os.path.join(os.getcwd(), '03_data')

# Box edge lengths; unit is meter.
Lx = 10.
Ly = 2.
Lz = 2.

# The number of cells along each axis (nodes = cells + 1).
Nx = 25
Ny = 5
Nz = 5

meshio_mesh = box_mesh_gmsh(
    Nx=Nx, Ny=Ny, Nz=Nz,
    Lx=Lx, Ly=Ly, Lz=Lz,
    data_dir=data_dir,
    ele_type=ele_type,
)

# Build the mesh from the node coordinates and the connectivity of the chosen cell type.
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])

Info    : Meshing 1D...
Info    : [  0%] Meshing curve 1 (Extruded)
Info    : [ 10%] Meshing curve 2 (Extruded)
Info    : [ 20%] Meshing curve 3 (Extruded)
Info    : [ 30%] Meshing curve 4 (Extruded)
Info    : [ 40%] Meshing curve 7 (Extruded)
Info    : [ 50%] Meshing curve 8 (Extruded)
Info    : [ 50%] Meshing curve 9 (Extruded)
Info    : [ 60%] Meshing curve 10 (Extruded)
Info    : [ 70%] Meshing curve 12 (Extruded)
Info    : [ 80%] Meshing curve 13 (Extruded)
Info    : [ 90%] Meshing curve 17 (Extruded)
Info    : [100%] Meshing curve 21 (Extruded)
Info    : Done meshing 1D (Wall 0.000549775s, CPU 0s)
Info    : Meshing 2D...
Info    : [  0%] Meshing surface 5 (Extruded)
Info    : [ 20%] Meshing surface 14 (Extruded)
Info    : [ 40%] Meshing surface 18 (Extruded)
Info    : [ 50%] Meshing surface 22 (Extruded)
Info    : [ 70%] Meshing surface 26 (Extruded)
Info    : [ 90%] Meshing surface 27 (Extruded)
Info    : Done meshing 2D (Wall 0.00388237s, CPU 0s)
Info    : Meshing 3D...
Info   

- `points` is a 2-D array of node coordinates with shape `[n, dims]`, where `n` is the number of nodes and `dims` is the spatial dimension. Each row is one node, and its columns give the x-, y-, and z-coordinates in that order.
- `[0, 1, 2]` lists the components of the displacement variable $u(x, y, z)$ to which the Dirichlet condition is applied.

In [8]:
# Define boundary locations.
def left(point):
    """Return whether the first coordinate of `point` is 0 (within atol=1e-5)."""
    x_coord = point[0]
    return np.isclose(x_coord, 0., atol=1e-5)

def right(point):
    """Return whether the first coordinate of `point` equals Lx (within atol=1e-5)."""
    x_coord = point[0]
    return np.isclose(x_coord, Lx, atol=1e-5)

# Dirichlet boundary values.
# On the 'left' side, the function 'zero_dirichlet_val' is applied
# to all components of the displacement variable u.
def zero_dirichlet_val(point):
    """Prescribed value at `point`: always zero, whatever the location."""
    return 0.0

# Dirichlet data as three parallel lists: location function, component index
# of u, and value function. All three components use 'left' and a zero value.
dirichlet_bc_info = [
    [left, left, left],
    [0, 1, 2],
    [zero_dirichlet_val, zero_dirichlet_val, zero_dirichlet_val],
]

# Neumann boundary locations.
# On the 'right' side, the surface integral is performed to get the tractions
# with the function 'get_surface_maps' defined in the class 'LinearElasticity'.
location_fns = [right]

In [9]:
# Instantiate the problem from the mesh, element type and boundary data.
problem = LinearElasticity(
    mesh,
    vec=3,
    dim=3,
    ele_type=ele_type,
    dirichlet_bc_info=dirichlet_bc_info,
    location_fns=location_fns,
)

[01-08 21:36:07][DEBUG] jax_fem: Computing shape function values, gradients, etc.
[01-08 21:36:07][DEBUG] jax_fem: ele_type = TET10, quad_points.shape = (num_quads, dim) = (4, 3)
[01-08 21:36:07][DEBUG] jax_fem: face_quad_points.shape = (num_faces, num_face_quads, dim) = (4, 3, 3)
[01-08 21:36:07][DEBUG] jax_fem: Done pre-computations, took 0.5220680236816406 [s]
[01-08 21:36:07][INFO] jax_fem: Solving a problem with 3750 cells, 6171x3 = 18513 dofs.


In [10]:
# Solve the problem, handing the linear solve to the custom 'pardiso_solver'.
sol_list = solver(
    problem,
    solver_options={'custom_solver': pardiso_solver},
)
# Alternative: sol_list = solver(problem, solver_options={'umfpack_solver': {}})

[01-08 21:36:10][DEBUG] jax_fem: Calling the row elimination solver for imposing Dirichlet B.C.
[01-08 21:36:10][DEBUG] jax_fem: Start timing
[01-08 21:36:10][DEBUG] jax_fem: Computing cell Jacobian and cell residual...
[01-08 21:36:10][DEBUG] jax_fem: Function split_and_compute_cell took 0.4206 seconds
[01-08 21:36:11][DEBUG] jax_fem: Creating sparse matrix with scipy...
[01-08 21:36:11][DEBUG] jax_fem: Before, res l_2 = 44.62186808181723
[01-08 21:36:11][DEBUG] jax_fem: Solving linear system...
[01-08 21:36:11][DEBUG] jax_fem: JAX Solver - Solving linear system
[01-08 21:36:26][DEBUG] jax_fem: JAX Solver - Finshed solving, res = 8.043486185162176e-09
[01-08 21:36:26][DEBUG] jax_fem: Computing cell Jacobian and cell residual...
[01-08 21:36:26][DEBUG] jax_fem: Function split_and_compute_cell took 0.0919 seconds
[01-08 21:36:26][DEBUG] jax_fem: Creating sparse matrix with scipy...
[01-08 21:36:26][DEBUG] jax_fem: res l_2 = 8.023659491709278e-09
[01-08 21:36:26][INFO] jax_fem: Solve too

In [11]:
# Postprocessing: cell-averaged stress and von Mises stress.

# Displacement gradient: (num_cells, num_quads, vec, dim)
u_grad = problem.fes[0].sol_to_grad(sol_list[0])
# Symmetric part of the gradient (the last two axes are swapped for the transpose).
epsilon = 0.5 * (u_grad + np.transpose(u_grad, (0, 1, 3, 2)))

identity = np.eye(problem.dim)

# Stress at every quadrature point:
# (num_cells, num_quads, 1, 1) * (dim, dim) + (num_cells, num_quads, vec, dim)
# -> (num_cells, num_quads, vec, dim)
trace_eps = np.trace(epsilon, axis1=2, axis2=3)[:, :, None, None]
sigma = lmbda * trace_eps * identity + 2 * mu * epsilon

# Quadrature weights: (num_cells, num_quads)
cells_JxW = problem.JxW[:, 0, :]
# Per-cell sum of the weights: (num_cells,)
JxW_total = np.sum(cells_JxW, axis=1)

# Weighted average of sigma over the quadrature points of each cell:
# (num_cells, vec, dim) / (num_cells, 1, 1) -> (num_cells, vec, dim)
weighted_sigma = np.sum(sigma * cells_JxW[:, :, None, None], axis=1)
sigma_average = weighted_sigma / JxW_total[:, None, None]

# Von Mises stress.
# Deviatoric part of the averaged stress: (num_cells, dim, dim)
mean_stress = 1/problem.dim * np.trace(sigma_average, axis1=1, axis2=2)[:, None, None]
s_dev = sigma_average - mean_stress * identity[None, :, :]
# sqrt(3/2 * s_dev : s_dev) per cell: (num_cells,)
vm_stress = np.sqrt(1.5 * np.sum(s_dev * s_dev, axis=(1, 2)))

In [12]:
# Write the solution to a local VTU file, attaching the per-cell von Mises stress.
vtk_path = os.path.join(data_dir, 'vtk/u.vtu')
save_sol(
    problem.fes[0],
    sol_list[0],
    vtk_path,
    cell_infos=[('vm_stress', vm_stress)],
)