In [163]:
import dctkit as dt
from dctkit import config
from dctkit.mesh import util, simplex
from dctkit.dec import cochain as C
from dctkit.dec import vector as V
import jax.numpy as jnp
from jax import Array
import numpy as np
import scipy
from scipy.sparse.linalg import cg, spilu, LinearOperator
from typing import Callable
from time import perf_counter

In [275]:
# apply dctkit's global configuration before building any arrays
# NOTE(review): presumably sets float precision / JAX backend (later cells rely on
# dt.float_dtype) — confirm against dctkit.config
config()

In [276]:
# generate mesh and simplicial complex
mesh, _ = util.generate_square_mesh(0.05)
S = util.build_complex_from_mesh(mesh, is_well_centered=True)
S.get_hodge_star()
S.get_flat_PDP_weights()
n = S.num_nodes
bnd_nodes = util.get_nodes_for_physical_group(mesh, 1, "boundary")
# Later cells extract interior quantities with the slice [idx_last_bnd_node+1:], which is
# only valid if the boundary nodes are exactly 0..idx_last_bnd_node. Take the max (not the
# last entry, which depends on ordering) and fail loudly if that numbering does not hold.
idx_last_bnd_node = int(np.max(bnd_nodes))
assert np.array_equal(np.unique(bnd_nodes), np.arange(idx_last_bnd_node + 1)), \
    "boundary nodes are expected to be numbered first and contiguously"
node_coord = S.node_coords
print("Number of nodes = ", n)

Number of nodes =  513


In [315]:
def assemble_Laplace_operator(S: simplex.SimplicialComplex, k: C.CochainD1V = None) -> Array:
    """Assemble the dense matrix of the operator ``u -> d * (k * d u)`` on 0-cochains.

    Parameters
    ----------
    S : simplex.SimplicialComplex
        The simplicial complex (presumably with ``get_hodge_star()`` already called,
        since ``C.star`` is used — confirm).
    k : C.CochainD1V, optional
        Dual 1-cochain multiplying ``* d u`` edgewise. Defaults to all ones with shape
        ``(num_edges, 1)``, which gives the plain operator ``d * d``.

    Returns
    -------
    Array
        Square matrix whose column ``j`` is the operator applied to the indicator
        cochain of node ``j``.
    """
    n = S.num_nodes
    if k is None:
        num_edges = S.S[1].shape[0]
        k = C.CochainD1V(S, jnp.ones((num_edges, 1), dtype=dt.float_dtype))

    # vector-valued 0-cochain where node i carries the i-th unit vector: applying the
    # operator componentwise yields every column of the matrix in one pass
    I = C.CochainP0V(S, jnp.eye(n, dtype=dt.float_dtype))
    dI = C.coboundary(I)
    # extract the matrix of the operator d * k d
    return C.coboundary(C.cochain_mul(k, C.star(dI))).coeffs

def get_source_vector(S: simplex.SimplicialComplex) -> Array:
    """Right-hand side ``b`` of ``A u = b`` for the manufactured solution u = x² + y².

    The Laplacian of x² + y² is 4, so ``Au + f = 0`` is closed by the constant
    0-cochain ``f = -4``; the returned vector holds the coefficients of ``-*f``.
    """
    source_coeffs = -4. * jnp.ones(S.num_nodes, dtype=dt.float_dtype)
    source = C.Cochain(0, True, S, source_coeffs)
    hodge_source = C.star(source)
    return -hodge_source.coeffs

In [316]:
# assemble the full operator and right-hand side on all nodes (boundary rows are removed later)
A, b = assemble_Laplace_operator(S), get_source_vector(S)

In [317]:
# define true solution for the discrete problem Au + f = 0 (u = x² + y² at the nodes)
u_true = np.array(node_coord[:, 0]**2 + node_coord[:, 1]**2, dtype=dt.float_dtype)
ub = u_true[bnd_nodes]

# Coupling block (interior rows x boundary columns). Index explicitly with bnd_nodes so the
# column order matches ub and no assumption is made that boundary nodes come first.
bnd_idx = np.asarray(bnd_nodes, dtype=int)
interior_idx = np.setdiff1d(np.arange(A.shape[0]), bnd_idx)
Ab = A[np.ix_(interior_idx, bnd_idx)]

# apply Dirichlet BC and reduce the system to the unknown values of u
A = jnp.delete(A, bnd_idx, axis=0)
A = jnp.delete(A, bnd_idx, axis=1)
b = jnp.delete(b, bnd_idx)
# move the known boundary contributions to the right-hand side
b -= Ab @ ub

In [318]:
# convert jax arrays to numpy (for use with scipy); np.asarray uses the array protocol
# instead of calling the __array__ dunder directly
A = np.asarray(A)
b = np.asarray(b)

In [319]:
# sanity check: a direct dense solve of the reduced system must reproduce the
# exact solution at the interior nodes
x = scipy.linalg.solve(A, b)
u_interior = u_true[idx_last_bnd_node+1:]
print(x)
print(u_interior)
np.testing.assert_allclose(x, u_interior, atol=1e-6)

[1.14089746 0.23025553 0.2775     1.20717082 0.4575     1.37089746
 0.1075     1.0201433  1.02089746 0.10741051 1.38167064 0.4575
 1.59589746 0.68335815 0.03345048 0.68568906 0.94589746 1.59415417
 0.94358365 0.0325     0.18449278 1.09081908 0.33345139 1.24589746
 1.76946122 0.84920293 0.00741125 0.85383085 1.09589746 1.03679492
 0.99429492 0.93769238 0.98269238 0.88608984 0.93358984 0.8394873
 0.8894873  0.79788476 0.75038476 0.71128222 0.76128222 0.67717968
 0.72967968 0.64807714 0.59807714 0.5714746  0.6239746  0.54987206
 0.49987206 0.48076952 0.43326952 0.53326952 0.41666698 0.37166698
 0.39076952 0.33166698 0.35756443 0.35326952 0.29666698 0.32076952
 0.26666698 0.24506443 0.37987206 0.34987206 0.4114746  0.3839746
 0.44807714 0.21756443 0.19846189 0.81628222 0.78717968 0.87628222
 0.84967968 0.94128222 0.91717968 1.01128222 0.98967968 0.89807714
 0.40506443 0.34846189 0.39846189 0.34435935 0.39685935 0.34525681
 0.29525681 0.40025681 0.45435935 0.46025681 0.51685935 0.52525681
 

In [320]:
# 2-norm condition number of the reduced operator (p=None is numpy's default 2-norm)
print(np.linalg.cond(A, p=None))

139.18332114017858


In [321]:
# Jacobi preconditioner: the diagonal part of A
# NOTE: scipy's cg expects M to approximate A^{-1}, i.e. the reciprocal of this diagonal
P = np.diag(A.diagonal())
print(P)

[[-3.46410162  0.          0.         ...  0.          0.
   0.        ]
 [ 0.         -3.39561079  0.         ...  0.          0.
   0.        ]
 [ 0.          0.         -3.46529762 ...  0.          0.
   0.        ]
 ...
 [ 0.          0.          0.         ... -3.66513687  0.
   0.        ]
 [ 0.          0.          0.         ...  0.         -3.66716746
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
  -3.67996542]]


In [322]:
# 0-cochains sampling the x- and y-coordinate of every node
x0, x1 = (
    C.CochainP0(S, jnp.array(S.node_coords[:, axis].ravel(), dtype=dt.float_dtype))
    for axis in (0, 1)
)
# map each coordinate cochain through flat_PDP followed by the Hodge star
x01, x11 = (C.star(V.flat_PDP(coord)) for coord in (x0, x1))
# store the coefficients as column vectors so they act as vector-valued dual 1-cochains
# (TODO: the reshape could live inside the assemble function, keeping these scalar-valued)
x01.coeffs = x01.coeffs.reshape(-1, 1)
x11.coeffs = x11.coeffs.reshape(-1, 1)
print(x01.coeffs.shape)

(1456, 1)


In [331]:
def fitness(individual: Callable):
    """Solve the reduced system ``A u = b`` with CG and report its cost.

    Parameters
    ----------
    individual : callable or None
        Maps the dual 1-cochains ``(x01, x11)`` to the cochain ``k`` used to assemble
        the preconditioning operator ``P = d * k d`` restricted to the interior nodes.
        ``None`` runs plain, unpreconditioned CG.

    Returns
    -------
    tuple
        ``(number of CG iterations, wall-clock seconds spent building the
        preconditioner and solving)``.
    """
    from scipy.linalg import lu_factor, lu_solve

    num_iters = 0

    def callback(_xk):
        # invoked by cg once per iteration with the current iterate
        nonlocal num_iters
        num_iters += 1

    # build the preconditioner and start timing (NOTE: should average over several runs)
    tic = perf_counter()
    if individual is not None:
        print("Using preconditioner...")
        k = individual(x01, x11)
        P = assemble_Laplace_operator(S, k)
        P = jnp.delete(P, np.array(bnd_nodes), axis=0)
        P = jnp.delete(P, np.array(bnd_nodes), axis=1)
        P = np.asarray(P)
        # cg's M must approximate A^{-1}, not A: factorize P once and apply P^{-1}
        # to the residual at every iteration
        P_lu = lu_factor(P)
        M = LinearOperator(P.shape, matvec=lambda r: lu_solve(P_lu, r), dtype=P.dtype)
    else:
        M = None

    # solve the system
    x, _ = cg(A, b, callback=callback, M=M, tol=0., atol=1e-7, maxiter=1000)
    toc = perf_counter()

    # squared 2-norm error against the exact solution at the interior nodes
    error = np.linalg.norm(x - u_true[idx_last_bnd_node+1:])**2
    print(error)
    return num_iters, toc - tic

In [332]:
# baseline: unpreconditioned CG -> (iterations, seconds)
print(fitness(individual=None))

8.514935926232876e-15
(63, 0.0037283490019035526)


In [328]:
def func(x, y):
    """Candidate ``individual`` for ``fitness``: returns the cochain x² + y² (elementwise)."""
    return C.add(C.square(x), C.square(y))

r = func(x01, x11)
# the cochain operations must agree with the plain elementwise computation
expected = x01.coeffs**2 + x11.coeffs**2
print(expected)
print(r.coeffs)
np.testing.assert_allclose(r.coeffs, expected)

# evaluate this candidate as a preconditioner (currently disabled)
#print(fitness(func))

[[8.41368395e-08]
 [8.41368395e-08]
 [2.40473581e-07]
 ...
 [9.66793767e-05]
 [2.35715775e-04]
 [2.12927490e-04]]
[[8.41368395e-08]
 [8.41368395e-08]
 [2.40473581e-07]
 ...
 [9.66793767e-05]
 [2.35715775e-04]
 [2.12927490e-04]]
