# Checkpointing in FEniCSx

## Jørgen Schartum Dokken

### Simula Research Laboratory

FEniCS '23

<div>
<p style="text-align:center;"><img src="figures/fenics_banner.png" width="500">
</div>


# What is checkpointing?

> Checkpointing refers to the ability to store the state of a computation in a way that allows it [to] be continued at a later time without changing the computation’s behavior [1]

<div>
<p style="text-align:center;"><img src="figures/checkpointing.drawio.png" width="600">
</div>

[1] Schulz, M. (2011). Checkpointing. In: Padua, D. (eds) Encyclopedia of Parallel Computing. Springer, Boston, MA. [10.1007/978-0-387-09766-4_62](https://doi.org/10.1007/978-0-387-09766-4_62)
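As a toy illustration of this definition, not tied to DOLFINx, the sketch below runs a simple explicit time-stepping loop, saves its full state (the solution vector and the step counter) to a NumPy `.npz` file halfway through, and then resumes from that file. The model problem, the `step` function and the file name are hypothetical. The point is that the resumed run reproduces the uninterrupted one exactly.

```python
import tempfile
from pathlib import Path

import numpy as np


def step(u: np.ndarray, dt: float) -> np.ndarray:
    """One explicit Euler step of du/dt = -u (hypothetical model problem)."""
    return u - dt * u


def run(u0: np.ndarray, num_steps: int, dt: float) -> np.ndarray:
    """Run the full computation without interruption."""
    u = u0.copy()
    for _ in range(num_steps):
        u = step(u, dt)
    return u


def run_with_restart(u0: np.ndarray, num_steps: int, dt: float, checkpoint: Path) -> np.ndarray:
    """Run half of the steps, write a checkpoint, then continue from the file."""
    u = u0.copy()
    for _ in range(num_steps // 2):
        u = step(u, dt)
    # Store everything needed to continue: the state and the step counter
    np.savez(checkpoint, u=u, n=num_steps // 2)

    # Later (possibly in a new process): restore the state and continue
    with np.load(checkpoint) as data:
        u = data["u"]
        n_done = int(data["n"])
    for _ in range(n_done, num_steps):
        u = step(u, dt)
    return u


u0 = np.linspace(0.0, 1.0, 11)
with tempfile.TemporaryDirectory() as tmp:
    u_restart = run_with_restart(u0, num_steps=10, dt=0.1, checkpoint=Path(tmp) / "state.npz")
u_direct = run(u0, num_steps=10, dt=0.1)
print(np.array_equal(u_direct, u_restart))  # True
```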

# What is checkpointing?
- Write a mesh and a function to file in a format that DOLFINx can read back
- Implemented in Python with [ADIOS2](https://adios2.readthedocs.io/en/latest/)
- Available at: https://github.com/jorgensd/adios4dolfinx/

In [1]:
from dolfinx import mesh, fem
from mpi4py import MPI
import numpy as np
import adios4dolfinx as adx
from pathlib import Path

domain = mesh.create_unit_square(MPI.COMM_WORLD, 10, 10)
V = fem.FunctionSpace(domain, ("Lagrange", 5))
u = fem.Function(V)
u.interpolate(lambda x: x[0]**5 + 3*x[1]*x[0]**2)

engine = "BP4"
checkpoint_file = Path("function_checkpoint.bp")
adx.write_mesh(domain, checkpoint_file, engine=engine)
adx.write_function(u, checkpoint_file, engine=engine)

print(f"File {str(checkpoint_file)} exists: {checkpoint_file.exists()}")
domain_in = adx.read_mesh(MPI.COMM_WORLD, checkpoint_file, engine, mesh.GhostMode.shared_facet)
V_in = fem.FunctionSpace(domain_in, ("Lagrange", 5))
u_in = fem.Function(V_in)
adx.read_function(u_in, checkpoint_file, engine=engine)
print(f"Max difference: {np.max(np.abs(u_in.x.array - u.x.array))}")

File function_checkpoint.bp exists: True
Max difference: 0.0


# Shallow checkpointing
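- Only meant to be used during a simulation, on the same mesh and with the same number of processes (each process writes and reads back only its own block of degrees of freedom)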
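The serial NumPy sketch below mimics what a shallow checkpoint stores: each process contributes only its owned slice of the degree-of-freedom array as a separate block, and on reading, process `rank` asks for block `rank` back. The contiguous partitions are a hypothetical stand-in for the partitions DOLFINx actually computes. They show why a restart must use the same mesh and the same number of processes: the blocks carry no global numbering.

```python
import numpy as np


def partition(n: int, num_procs: int) -> list[np.ndarray]:
    """Hypothetical contiguous split of n global dofs over num_procs ranks."""
    return np.array_split(np.arange(n), num_procs)


def shallow_write(global_values: np.ndarray, num_procs: int) -> list[np.ndarray]:
    """Each rank writes its owned values as one block, without global indices."""
    return [global_values[idx].copy() for idx in partition(len(global_values), num_procs)]


def shallow_read(blocks: list[np.ndarray], rank: int) -> np.ndarray:
    """Rank `rank` reads back block number `rank` (cf. SetBlockSelection(rank))."""
    return blocks[rank]


values = np.arange(10, dtype=np.float64) ** 2
blocks = shallow_write(values, num_procs=3)

# Same number of processes: every rank recovers exactly its own values
same = [np.array_equal(shallow_read(blocks, r), values[idx])
        for r, idx in enumerate(partition(10, 3))]
print(all(same))  # True

# Restart on 2 processes: the stored block sizes no longer match the new ownership
new_owned = partition(10, 2)
print([len(shallow_read(blocks, r)) for r in range(2)], [len(idx) for idx in new_owned])
# [4, 3] [5, 5]
```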

Illustrating this using [IPythonParallel](https://ipyparallel.readthedocs.io/en/latest/)

In [2]:
import ipyparallel as ipp
cluster = ipp.Cluster(engines="mpi", n=3)
rc = cluster.start_and_connect_sync()

Starting 3 engines with <class 'ipyparallel.cluster.launcher.MPIEngineSetLauncher'>

In [3]:
%%px
# Using the px magic runs the following commands on each engine
from mpi4py import MPI
from dolfinx import mesh, fem
import adios2
from pathlib import Path
import numpy as np
print(f"{MPI.COMM_WORLD.rank=}, {MPI.COMM_WORLD.size=}")

[stdout:0] MPI.COMM_WORLD.rank=0, MPI.COMM_WORLD.size=3


[stdout:2] MPI.COMM_WORLD.rank=2, MPI.COMM_WORLD.size=3


[stdout:1] MPI.COMM_WORLD.rank=1, MPI.COMM_WORLD.size=3


# How to write a shallow checkpoint

In [4]:
%%px
domain = mesh.create_unit_square(MPI.COMM_WORLD, 10, 10)
V = fem.FunctionSpace(domain, ("Lagrange", 5))
u = fem.Function(V)
u.interpolate(lambda x: x[0]**5 + 3*x[1]*x[0]**2)

def shallow_checkpoint_write(uh, file):
    """Write the owned dofs of `uh`, one ADIOS2 block per process."""
    dofmap = uh.function_space.dofmap
    adios = adios2.ADIOS(uh.function_space.mesh.comm)
    io = adios.DeclareIO("shallow_cp_writer")
    io.SetEngine("BP5")
    outfile = io.Open(str(file), adios2.Mode.Write)
    outfile.BeginStep()
    # Owned entries only (ghosts excluded), scaled by the block size of the space
    num_dofs_local = dofmap.index_map.size_local * dofmap.index_map_bs
    local_dofs = uh.x.array[:num_dofs_local].copy()
    # Local array: no global shape, each process writes a block of its own size
    dofs = io.DefineVariable("dofs", local_dofs, count=[num_dofs_local])
    outfile.Put(dofs, local_dofs, adios2.Mode.Sync)
    outfile.EndStep()
    outfile.Close()
    adios.RemoveIO("shallow_cp_writer")

shallow_checkpoint_file = Path("u.bp")
shallow_checkpoint_write(u, shallow_checkpoint_file)

In [5]:
import subprocess
shallow_checkpoint_file = Path("u.bp")
subprocess.run(["bpls", "-l", str(shallow_checkpoint_file.absolute())])

  double   dofs  [3]*{__} = 2.99914e-53 / 4


CompletedProcess(args=['bpls', '-l', '/root/shared/u.bp'], returncode=0)

# How to read a shallow checkpoint

In [6]:
%%px
def shallow_checkpoint_read(uh, file):
    """Read the owned dofs written by `shallow_checkpoint_write` into `uh`."""
    dofmap = uh.function_space.dofmap
    adios = adios2.ADIOS(uh.function_space.mesh.comm)
    io = adios.DeclareIO("shallow_cp_reader")
    io.SetEngine("BP5")
    infile = io.Open(str(file), adios2.Mode.Read)
    infile.BeginStep()
    num_dofs_local = dofmap.index_map.size_local * dofmap.index_map_bs
    local_dofs = np.zeros(num_dofs_local, dtype=uh.x.array.dtype)
    in_variable = io.InquireVariable("dofs")
    # Select the block written by the process with the same rank
    # (requires the same mesh partition and number of processes as when writing)
    in_variable.SetBlockSelection(uh.function_space.mesh.comm.rank)
    infile.Get(in_variable, local_dofs, adios2.Mode.Sync)
    infile.EndStep()
    infile.Close()
    adios.RemoveIO("shallow_cp_reader")
    # Fill the owned entries, then update the ghost entries from their owners
    uh.x.array[:num_dofs_local] = local_dofs
    uh.x.scatter_forward()

v = fem.Function(V)
shallow_checkpoint_read(v, shallow_checkpoint_file)
assert np.allclose(u.x.array, v.x.array)
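Only the first `size_local * index_map_bs` entries of `x.array`, the values owned by the process, are stored in the checkpoint; the ghost entries that follow them locally are refreshed by `scatter_forward()` after reading. The serial NumPy sketch below imitates that step for two hypothetical processes with made-up owned and ghost indices. The plain loop in `scatter_forward` is only a stand-in for the MPI communication DOLFINx performs.

```python
import numpy as np

# Hypothetical global vector of 6 dofs, split between two processes
global_values = np.array([10.0, 11.0, 12.0, 13.0, 14.0, 15.0])

# Per process: global indices of owned dofs, then of ghosts (owned elsewhere).
# Local arrays are laid out as [owned..., ghosts...].
layout = {
    0: {"owned": np.array([0, 1, 2]), "ghosts": np.array([3])},
    1: {"owned": np.array([3, 4, 5]), "ghosts": np.array([2])},
}

# Shallow checkpoint: each process saves only its owned block
checkpoint = {rank: global_values[lay["owned"]].copy() for rank, lay in layout.items()}

# Restart: fresh local arrays, owned part filled from the block with our rank
local = {}
for rank, lay in layout.items():
    n_owned, n_ghost = len(lay["owned"]), len(lay["ghosts"])
    arr = np.zeros(n_owned + n_ghost)
    arr[:n_owned] = checkpoint[rank]
    local[rank] = arr


def scatter_forward(local, layout):
    """Copy each ghost value from the process that owns it (serial stand-in for MPI)."""
    for rank, lay in layout.items():
        n_owned = len(lay["owned"])
        for k, g in enumerate(lay["ghosts"]):
            for owner, owner_lay in layout.items():
                pos = np.flatnonzero(owner_lay["owned"] == g)
                if pos.size:
                    local[rank][n_owned + k] = local[owner][pos[0]]


scatter_forward(local, layout)
print(local[0], local[1])
```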

# What is the optimal format?

## FEM software
- Most finite element software packages use their own formats for reading and writing meshes and functions
- No general consensus on what format to use

## Post-processing
- Most formats are based on [VTK](https://vtk.org/) in some shape or form
- Several VTK-based file formats are in common use:
  - `.pvd`-files (`xml`-based)
  - `.xdmf`-files (`xml` + binary data in `.h5`)
  - `.bp`-files (binary files in folder)
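To make the `xml`-based entry concrete: a `.pvd` file is a small ParaView collection file that lists one dataset file per time step. The sketch below builds such a collection with Python's standard `xml.etree.ElementTree`. The `.vtu` file names and time values are hypothetical, and the referenced data files themselves are not written here.

```python
import xml.etree.ElementTree as ET


def make_pvd(files_and_times: list[tuple[str, float]]) -> str:
    """Return the XML text of a ParaView .pvd collection (time step -> data file)."""
    root = ET.Element("VTKFile", type="Collection", version="0.1")
    collection = ET.SubElement(root, "Collection")
    for filename, t in files_and_times:
        # One DataSet entry per time step, pointing at a (hypothetical) .vtu file
        ET.SubElement(collection, "DataSet", timestep=str(t), part="0", file=filename)
    return ET.tostring(root, encoding="unicode")


pvd = make_pvd([("u_000000.vtu", 0.0), ("u_000001.vtu", 0.1)])
print(pvd)
```

The actual mesh and field values live in the referenced files, which is why a `.pvd` file on its own stays tiny.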