In [None]:
%load_ext autoreload

In [None]:
%autoreload
import numpy as np
import jax.numpy as jnp
from dctkit.mesh import util, simplex
from dctkit.math.opt import optctrl as oc
import dctkit.dec.cochain as C
import dctkit as dt
import pygmsh
import pyvista as pv
from pyvista import themes

In [None]:
# dctkit global setup; presumably this fixes dt.float_dtype, which every array
# below is created with — TODO confirm against the dctkit documentation.
dt.config()
# Optional: pv.set_jupyter_backend('trame') selects a different Jupyter
# rendering backend for PyVista; left disabled here.
# All PyVista plots in this notebook use the ParaView-like theme.
pv.global_theme = themes.ParaViewTheme()

In [None]:
# Mesh the cube and turn it into a dctkit simplicial complex.
# lc is the mesh-size parameter given to the generator (presumably the gmsh
# characteristic length — smaller values give more nodes; confirm).
lc = 0.2
mesh, _ = util.generate_cube_mesh(lc)
# To inspect the mesh visually: pv.plot(mesh)
S = util.build_complex_from_mesh(mesh)
num_nodes = S.num_nodes

print("number of nodes = ", num_nodes)
# S.S[3] holds the 3-simplices (tetrahedra); one row per tet.
print("number of tets = ", len(S.S[3]))

# Precompute the Hodge star; presumably required by C.inner in the
# time-stepping objective further down — confirm.
S.get_hodge_star()

In [None]:
# boundary conditions
bottom_nodes = np.argwhere(S.node_coords[:,2]<1e-6).flatten()
top_nodes = np.argwhere(abs(S.node_coords[:,2]-1.)<1e-6).flatten()
values = np.zeros(len(bottom_nodes)+len(top_nodes), dtype=dt.float_dtype)
boundary_values = (np.hstack((bottom_nodes,top_nodes)), values)

In [None]:
from functools import partial
from dctkit.physics import poisson as p

def disspot(u, u_prev, deltat):
    """Dissipation term of one implicit time step.

    Wraps ``u`` and ``u_prev`` as primal 0-cochains on the global complex
    ``S`` and returns ``0.5 * <u - u_prev, u - u_prev> / deltat``, where
    ``<., .>`` is dctkit's cochain inner product ``C.inner``.
    """
    increment = C.sub(C.CochainP0(S, u), C.CochainP0(S, u_prev))
    return 0.5 * C.inner(increment, increment) / deltat

# Poisson energy with the complex bound once; everything else goes by keyword.
energy = partial(p.energy_poisson, S=S)

def obj(u, u_prev, f, k, boundary_values, gamma, deltat):
    """Per-step objective: Poisson energy of ``u`` plus ``disspot(u, u_prev, deltat)``.

    Minimizing this over ``u`` yields the state at the next time level.
    """
    return (energy(x=u, f=f, k=k, boundary_values=boundary_values, gamma=gamma)
            + disspot(u, u_prev, deltat))

# Problem and time-stepping parameters.
k = 1.          # coefficient forwarded to energy_poisson as `k`
f_vec = np.ones(num_nodes, dtype=dt.float_dtype)  # `f` for energy_poisson, one value per node
gamma = 1000.   # forwarded to energy_poisson; presumably the boundary penalty weight
deltat = 0.1    # time-step size

# Initial state: zero at every node.
u_0 = np.zeros(num_nodes, dtype=dt.float_dtype)
u_prev = u_0

In [None]:
# Implicit (minimizing-movement) time stepping: each step minimizes
#   energy(u) + <u - u_prev, u - u_prev> / (2 * deltat)
# starting the optimizer from the previous state.
num_steps = 10
# Restart from the initial condition every time this cell runs; otherwise a
# re-run would silently continue from the last computed state.
u_prev = u_0
sols = []
prb = oc.OptimizationProblem(dim=num_nodes, state_dim=num_nodes, objfun=obj)
for i in range(num_steps):
    print("t = ", (i+1)*deltat)
    args = {'u_prev': u_prev, 'f': f_vec, 'k': k, 'boundary_values': boundary_values,
            'gamma': gamma, 'deltat': deltat}
    prb.set_obj_args(args)
    u = prb.solve(u_prev, ftol_abs=1e-8, ftol_rel=1e-8)
    # Store an independent NumPy copy so later solves cannot alias stored states.
    u_prev = np.array(u)
    sols.append(u_prev)  # sols[i] is the state at t = (i+1)*deltat
prb.last_opt_result

In [None]:
# Plot the nodal solution at the last time step on the mesh.
# Named `plotter` rather than `p` so it does not shadow the
# `dctkit.physics.poisson as p` module alias imported earlier.
plotter = pv.Plotter()
plotter.add_mesh(mesh, scalars=sols[-1])
plotter.show()

In [None]:
import meshio
# Write the time series to XDMF: the tetrahedral mesh once, then one nodal
# field "u" per computed step, tagged with its physical time.
filename = "timedata.xdmf"
points = mesh.points
cells = {"tetra": mesh.cells_dict["tetra"]}
with meshio.xdmf.TimeSeriesWriter(filename) as writer:
    writer.write_points_cells(points, cells)
    # Iterate over what was actually computed instead of a hard-coded count.
    # sols[i] is the state at t = (i+1)*deltat, matching the times printed
    # by the time-stepping loop.
    for i, u_step in enumerate(sols):
        writer.write_data((i + 1) * deltat, point_data={"u": u_step})