### Tutorial on Space-time FEM with FEniCSx
Dominik Kern ORCID [0000-0002-1958-2982](https://orcid.org/0000-0002-1958-2982) 

This notebook is a supplement to the tutorial (**TODO**: add Zenodo DOI).

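**Solving the non-dimensional heat equation in a 1D bar
using finite elements in space and time stepping by the midpoint rule (Crank–Nicolson)**

$$u_t = u_{xx}, \quad x \in (0, 1),\ t \in (0, T], \qquad u(0, t) = u(1, t) = 0, \qquad u(x, 0) = \sin(\pi x).$$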

In [None]:
import numpy as np
import ufl
from dolfinx import geometry
from dolfinx.fem import (Constant, Function, functionspace, dirichletbc, form)
from dolfinx.fem.petsc import LinearProblem
from dolfinx.fem import locate_dofs_geometrical
from dolfinx.mesh import create_unit_interval
from mpi4py import MPI
import pyvista as pv

#### parameters

In [None]:
# --- time discretization ---
T = 1.0               # final (total) simulation time
nt = 8                # how many time steps are taken
dt_val = T / nt       # uniform time-step size

# --- parallel communicator the mesh is created on ---
comm = MPI.COMM_WORLD

# --- space discretization ---
nx = 4                # number of cells of the unit-interval mesh
order = 1             # Lagrange degree in space (the time integrator is fixed)

#### discretization

In [None]:
# Unit-interval mesh and a scalar Lagrange space of degree `order` on it
domain = create_unit_interval(comm, nx)
V = functionspace(domain, ("Lagrange", order))


def initial_condition(x):
    """Initial profile u(x, 0) = sin(pi * x); x[0] holds the x-coordinates of the points."""
    return np.sin(np.pi * x[0])


# Solution at the previous time level, initialised by interpolating the initial condition
u_n = Function(V)
u_n.name = "u_n"
u_n.interpolate(initial_condition)


def _homogeneous_bc_at(x_end):
    """Return (dofs, bc): the dofs located at x == x_end and a Dirichlet BC fixing them to 0."""
    end_dofs = locate_dofs_geometrical(V, lambda x: np.isclose(x[0], x_end))
    return end_dofs, dirichletbc(np.float64(0.0), end_dofs, V)


# u = 0 at both ends of the bar
boundary_dofs_L, bc_L = _homogeneous_bc_at(0)
boundary_dofs_R, bc_R = _homogeneous_bc_at(1)
bcs = [bc_L, bc_R]

# Trial function (the unknown u^{n+1}) and test function
u = ufl.TrialFunction(V)
Du = ufl.TestFunction(V)

# Time-step size as a form constant
dt = Constant(domain, np.float64(dt_val))

# Crank-Nicolson / midpoint rule: the diffusion term is averaged over levels n and n+1,
#   (u^{n+1} - u^n) / dt = 1/2 * (u^{n+1}_xx + u^n_xx)
half_dt = 0.5 * dt
a = u * Du * ufl.dx + half_dt * ufl.dot(ufl.grad(u), ufl.grad(Du)) * ufl.dx
L = u_n * Du * ufl.dx - half_dt * ufl.dot(ufl.grad(u_n), ufl.grad(Du)) * ufl.dx

# Compiled bilinear and linear forms handed to the solver
problem_a = form(a)
problem_L = form(L)

#### solution

In [None]:
# Function that receives the new solution u^{n+1} of each time step
uh = Function(V)
uh.name = "u_timestep"

# Direct LU solve; the same problem object is re-solved in every step. u_n is updated in
# place after each solve, so (assuming LinearProblem.solve re-assembles L on each call)
# the next step sees the new previous-level values.
problem = LinearProblem(problem_a, problem_L, bcs=bcs, u=uh, petsc_options={"ksp_type": "preonly", "pc_type": "lu"})

# The mesh and hence the dof layout are fixed, so the dof coordinates and the permutation
# sorting the dofs by x are computed once, outside the time loop.
# NOTE(review): this collects the process-local dofs only, i.e. it assumes a serial run.
x_coords = V.tabulate_dof_coordinates()[:, 0]
sort_order = np.argsort(x_coords)

# u_sol[n, :] holds u at time n * dt_val, ordered by increasing x
u_sol = np.zeros((nt + 1, x_coords.size))
u_sol[0, :] = u_n.x.array[sort_order]

t = 0.0
for n in range(nt):
    t = (n + 1) * dt_val          # exact time level, no round-off accumulation from t += dt_val
    problem.solve()
    # Advance in time: the new solution becomes the previous-level value
    u_n.x.array[:] = uh.x.array
    u_sol[n + 1, :] = uh.x.array[sort_order]

#### post-processing
Collect all results on a space-time grid and plot them as a surface $u(x, t)$.

In [None]:
# Space-time grid: axis 0 runs over the spatial dofs (sorted by x), axis 1 over the time levels.
# The actual dof coordinates are used instead of an equispaced linspace, because the Lagrange
# dofs need not be equispaced for higher `order` (depends on the Lagrange variant).
x_plot = x_coords[sort_order]
t_plot = np.linspace(0, T, nt + 1)
# Separate names, so the total simulation time `T` is not overwritten (re-running stays safe)
X_grid, T_grid = np.meshgrid(x_plot, t_plot, indexing="ij")

u_grid = u_sol.T  # shape (n_dofs, nt + 1), matching X_grid / T_grid

# StructuredGrid points with the first (x) index varying fastest -> Fortran-order ravel
points = np.column_stack([
    X_grid.ravel(order="F"),   # x
    T_grid.ravel(order="F"),   # t
    u_grid.ravel(order="F"),   # u as height
])

grid = pv.StructuredGrid()
grid.points = points
grid.dimensions = [X_grid.shape[0], X_grid.shape[1], 1]
grid["u"] = u_grid.ravel(order="F")

plotter = pv.Plotter()
plotter.add_mesh(grid, scalars="u", cmap="viridis", show_edges=True, scalar_bar_args={'vertical': True})
plotter.show_grid(xlabel="x", ylabel="t", zlabel="u")
plotter.show(title="u(x, t) surface plot")