In [None]:
from matplotlib import pyplot as plt
import numpy as np

from firedrake import *

In [None]:
# Pixel mask to be mapped onto the mesh. The grid and the periodic mesh below
# both use mask.shape[0] points/cells per side, and pixels are paired with grid
# rows through mask.flatten(), so the mask must be a square 2-D array.
mask = np.load("data/filled/1.npy")
if mask.ndim != 2 or mask.shape[0] != mask.shape[1]:
    raise ValueError(f"expected a square 2-D mask, got shape {mask.shape}")

mask.shape

In [None]:
# Quick visual check of the loaded mask before mapping it onto the mesh.
plt.matshow(mask)

In [None]:
def create_grid(n):
    """Return an ``n`` x ``n`` grid of points on the unit square.

    Coordinates along each axis are ``i / n`` for ``i = 0 .. n-1``. The value
    ``1.0`` is excluded because the grid is meant for a periodic mesh, where
    it coincides with ``0.0``.

    Args:
        n: number of points per side (positive integer).

    Returns:
        Array of shape ``(n, n, 2)`` built with ``np.meshgrid`` (default
        ``indexing="xy"``). ``[..., 0]`` varies along the last axis and
        ``[..., 1]`` varies along the first axis.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # np.arange(n) / n always has exactly n entries. np.arange(0., 1., 1/n)
    # derives its length from a floating-point division and can return n + 1
    # entries for some n, which would break the one-row-per-pixel pairing.
    x = np.arange(n) / n
    return np.stack(np.meshgrid(x, x), axis=-1)

# One (x, y) row per pixel, in the same row-major order as mask.flatten().
grid = create_grid(mask.shape[0]).reshape(-1, 2)
# Reverse the whole y column so the first image row gets the largest y value.
grid[:, 1] = grid[::-1, 1]

In [None]:
n = len(mask)  # number of rows of the mask; assumes a square mask (n x n)

# Periodic unit-square mesh with n cells along each side.
mesh = PeriodicUnitSquareMesh(nx=n, ny=n)

In [None]:
# Continuous piecewise-linear space; holds the pixel indicator `chi` below.
V = FunctionSpace(mesh, "CG", degree=1)

# Space in which the smoothing problem is solved.
# NOTE(review): "WXRobH3NC" is not a stock Firedrake family name; presumably it
# is registered by a local FIAT/FInAT extension — confirm it is installed.
WX = FunctionSpace(mesh, "WXRobH3NC", degree=7)

In [None]:
# Vector-valued copy of V's element on the same mesh, used to read off the
# (x, y) coordinates of every node of V.
m = V.ufl_domain()
W = VectorFunctionSpace(m, V.ufl_element())

# Function.interpolate returns the populated Function. The free function
# `interpolate(expr, V)` is deprecated, and in recent Firedrake releases it
# returns a symbolic object with no `.dat`.
# NOTE(review): on a periodic mesh the coordinate field is presumably
# discontinuous, so nodes on the identified boundary may come out as 0.0 or 1.0.
# The grid only contains values in [0, 1), so check the lookups in get_chi.
X = Function(W).interpolate(m.coordinates)

# One (x, y) row per node; only read below, so use the read-only view.
# Assumes W shares V's node numbering, so row i is DoF i of V — TODO confirm.
mesh_coords = X.dat.data_ro

In [None]:
def hash_np(coords, digits=6):
    """Turn a coordinate vector into a hashable dictionary key.

    Every entry is rounded to ``digits`` decimal places and the result is
    returned as a tuple. Coordinates that differ only by tiny floating-point
    noise therefore usually map to the same key. Values that straddle a
    rounding boundary can still map to different keys.
    """
    rounded = np.round(coords, decimals=digits)
    return tuple(rounded)

def create_hashmap(mesh_coords, digits=6):
    """Map rounded coordinates to their row index in ``mesh_coords``.

    Each row is keyed by ``hash_np(row, digits)``. If two rows round to the
    same key, the later row's index is kept. The extra string key
    ``"hash_func"`` stores the key function, so lookups use exactly the same
    rounding as construction.
    """
    lookup = {}
    for row_index, coord_row in enumerate(mesh_coords):
        lookup[hash_np(coord_row, digits)] = row_index
    lookup["hash_func"] = lambda coords: hash_np(coords, digits)
    return lookup

def get_chi(data, mask, grid, hashmap):
    """Build a 0/1 indicator vector over mesh nodes from a pixel mask.

    Args:
        data: array whose first dimension sets the output length; only
            ``data.shape[0]`` is used.
        mask: image whose pixels equal to 1 are marked. It is flattened in
            row-major order, so pixel ``k`` of the flattened mask is paired
            with ``grid[k]``.
        grid: one coordinate row per pixel (``len(grid) == mask.size``).
        hashmap: dict mapping ``hashmap["hash_func"](coords)`` keys to node
            indices, as built by ``create_hashmap``.

    Returns:
        Integer array of length ``data.shape[0]``. It holds 1 at every node
        whose coordinates match a marked pixel and 0 elsewhere.

    Raises:
        ValueError: if ``grid`` does not have exactly one row per mask pixel.
        KeyError: if a marked pixel's coordinates are missing from ``hashmap``.
    """
    flat_mask = np.ravel(mask)
    # A size mismatch would silently pair pixels with the wrong coordinates.
    if len(grid) != flat_mask.size:
        raise ValueError(
            f"grid has {len(grid)} rows but mask has {flat_mask.size} pixels"
        )

    vals = np.zeros(data.shape[0], dtype=int)
    hash_func = hashmap["hash_func"]

    for index in np.flatnonzero(flat_mask == 1):
        key = hash_func(grid[index])
        try:
            node = hashmap[key]
        except KeyError:
            raise KeyError(
                f"no mesh node at coordinates {key} (pixel {index})"
            ) from None
        vals[node] = 1

    return vals

In [None]:
hashmap = create_hashmap(mesh_coords, digits=6)  # rounded (x, y) -> node index

In [None]:
data = get_chi(mesh_coords, mask=mask, grid=grid, hashmap=hashmap)  # 0/1 per node

In [None]:
# Indicator field on V: 1 at nodes matched to a marked pixel, 0 elsewhere.
chi = Function(V)
# `data` is indexed like the rows of `mesh_coords`; this assumes those rows
# follow V's DoF numbering — TODO confirm.
chi.dat.data[...] = data

tricontourf(chi)

In [None]:
# Note: v is the trial function and u the test function here (the reverse of
# the common u = trial, v = test naming); the forms below follow this choice.
v = TrialFunction(WX)
u = TestFunction(WX)

In [None]:
def _d3(w, i, j, k):
    # Third partial derivative of w: along axis i, then j, then k.
    return Dx(Dx(Dx(w, i), j), k)


# Bilinear form: mass term plus squared third derivatives weighted 1, 3, 3, 1
# for the xxx, xxy, xyy and yyy components.
a = inner(v, u)*dx + (
    inner(_d3(v, 0, 0, 0), _d3(u, 0, 0, 0))
    + 3 * inner(_d3(v, 1, 0, 0), _d3(u, 1, 0, 0))
    + 3 * inner(_d3(v, 0, 1, 1), _d3(u, 0, 1, 1))
    + inner(_d3(v, 1, 1, 1), _d3(u, 1, 1, 1))
)*dx

# Right-hand side: L2 projection of the indicator field chi.
L = inner(chi, u)*dx

In [None]:
# Solution field on WX, named "field".
res = Function(WX, name="field")
problem = a == L
solve(problem, res)


In [None]:
np.sum(res.dat.data)  # sum of raw DoF values (a quick scalar check, not an integral)

In [None]:
# Filled contour plot of the solution, with a colour bar for its levels.
l = tricontourf(res)

plt.colorbar(mappable=l)
