In [94]:
%matplotlib notebook
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [95]:
import os
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import xarray as xr
import pygalmesh
import fenics as fe
import torch

import sys
sys.path.append('..')
import project

In [150]:
%autoreload

shape = (31, 31, 31)
resolution = (1, 1, 1)

disp_x = np.linspace(-1, 1, shape[0])
disp_y = np.linspace(-1, 1, shape[1])
disp_z = np.linspace(-1, 1, shape[2])

disp = xr.DataArray(
    data=np.stack(np.meshgrid(disp_x, disp_y, disp_z, indexing='ij'), axis=-1),
    dims=['x', 'y', 'z', 'component'],
    coords={
        'x': np.arange(shape[0]) * resolution[0], 
        'y': np.arange(shape[1]) * resolution[1],
        'z': np.arange(shape[2]) * resolution[2],
        'component': ['x', 'y', 'z']
    }
)
disp.name = 'displacement'

disp2 = np.sin(disp * 2 * np.pi / 2)
disp3 = disp.sel(component='y')

project.visual.XArrayViewer(disp, x='x', y='y')
project.visual.XArrayViewer(disp2, x='x', y='y')
project.visual.XArrayViewer(disp3, x='x', y='y')

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14d420bcc610>

In [152]:
mask = ((disp**2).sum('component') <= 1).astype(np.uint8)
mask.name = 'mask'

project.visual.XArrayViewer(mask)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14d42021a7a0>

In [156]:
mesh = pygalmesh.generate_from_array(mask.values, voxel_size=[1,1,1], max_cell_circumradius=2.0, odt=True)

vertices = mesh.points
facets = mesh.cells[0].data
tetras = mesh.cells[1].data

angles = project.meshing.compute_angles_to_interior(vertices, facets, tetras)

vertices.shape, facets.shape, tetras.shape

((1512, 3), (2364, 3), (7196, 4))

In [157]:
# view surface mesh

fig, ax = project.meshing.plot_mesh(
    vertices,
    facets[angles > 0],
    facecolors='white',
    edgecolors='black',
    shade=True,
    alpha=1.0,
    linewidth=0.4,
    figsize=(8,8)
)
#ax.grid(False)
#ax.xaxis.set_pane_color((1,1,1,0))
#ax.yaxis.set_pane_color((1,1,1,0))
#ax.zaxis.set_pane_color((1,1,1,0))
#ax.set_axis_off()

<IPython.core.display.Javascript object>

In [161]:
import meshio
from mpi4py import MPI

def convert_to_fe_mesh(points, cells):
    '''
    Convert from meshio to fenics mesh.
    '''
    mesh_file = 'temp.xdmf'
    meshio.write_points_cells(mesh_file, points, [(cells.type, cells.data)])
    fe_mesh = fe.Mesh()
    with fe.XDMFFile(MPI.COMM_WORLD, mesh_file) as f:
        f.read(fe_mesh)
    return fe_mesh

fe_mesh = convert_to_fe_mesh(mesh.points, mesh.cells[1])

In [162]:
V = fe.FunctionSpace(fe_mesh, 'P', 1)
f = fe.Function(V)
f.vector().get_local().shape

(1512,)

In [163]:
%%time
import torch
import torch.nn.functional as F

# functions for converting between image-like arrays 
#   and vectors of coefficients for a linear FEM basis

def image_to_dofs(image, resolution, V):
    '''
    Args:
        image: (n_x, n_y, n_z, n_c) torch.Tensor
        V: fenics.FunctionSpace
            defined on (mesh_size, 3) coordinates
    Returns:
        dofs: (batch_size, mesh_size, n_channels) torch.Tensor
    '''    
    if V.num_sub_spaces() == 0:
        image = image.unsqueeze(-1)

    n_x, n_y, n_z, n_c = image.shape
    
    coords = V.tabulate_dof_coordinates()
    if V.num_sub_spaces() > 0:
        coords = coords[::V.num_sub_spaces(),:]
    
    mesh_size, n_dims = coords.shape

    coords = torch.as_tensor(coords, dtype=image.dtype, device=image.device)

    shape = torch.as_tensor([n_x, n_y, n_z], dtype=image.dtype, device=image.device)
    resolution = torch.as_tensor(resolution, dtype=image.dtype, device=image.device)
    extent = (shape - 1) * resolution

    dofs = F.grid_sample(
        input=image[None,...].permute(0,4,3,2,1), # xyzc -> bczyx
        grid=(coords[None,None,None,...] / extent) * 2 - 1,
        align_corners=True
    )
    if V.num_sub_spaces() > 0:
        return dofs.view(n_c, mesh_size).permute(1,0)
    else:
        return dofs.view(mesh_size)


u_tensor = torch.as_tensor(disp3.values)
u_func_space = fe.FunctionSpace(fe_mesh, 'P', 1)
u_func_dofs = image_to_dofs(u_tensor, resolution, u_func_space)
u_func_dofs.shape

CPU times: user 7.92 ms, sys: 123 µs, total: 8.04 ms
Wall time: 8.38 ms


torch.Size([1512])

In [164]:
import torch_fenics

u_func = torch_fenics.numpy_fenics.numpy_to_fenics(
    u_func_dofs.detach().cpu().numpy(), fe.Function(u_func_space)
)
u_func.set_allow_extrapolation(True)

In [165]:
u_func([0,0,0])

-0.9999999999999996

In [166]:
u_func([(shape[0] - 1) * resolution[0],0,0])

-1.0000000000000004

In [167]:
u_func([0,(shape[1] - 1) * resolution[1],0])

1.0

In [168]:
u_func([0,0,(shape[2] - 1) * resolution[2]])

-1.0

In [169]:
u_func_space.num_sub_spaces()

0

In [170]:
%%time
import torch_fenics

def dofs_to_image(dofs, V, image_shape, resolution):
    '''
    Args:
        dofs: (mesh_size, n_c) torch.Tensor
        V: fenics.FunctionSpace
            defined on (mesh_size, 3) coordinates
        image_shape: (int, int, int) tuple
    Returns:
        image: (n_x, n_y, n_z, n_c) torch.Tensor
    '''
    if V.num_sub_spaces() > 0:
        mesh_size, n_c = dofs.shape
    else:
        mesh_size, = dofs.shape
        n_c = 1

    n_x, n_y, n_z = image_shape

    x = np.arange(n_x) * resolution[0]
    y = np.arange(n_y) * resolution[1]
    z = np.arange(n_z) * resolution[2]

    grid = np.stack(np.meshgrid(x, y, z, indexing='ij'), axis=-1)
    print(image_shape, grid.shape)

    func = torch_fenics.numpy_fenics.numpy_to_fenics(
        dofs.detach().cpu().numpy(), fe.Function(V)
    )
    func.set_allow_extrapolation(True)

    image = np.zeros((n_x, n_y, n_z, n_c))

    for i in range(n_x):
        for j in range(n_y):
            for k in range(n_z):
                func.eval(image[i,j,k], grid[i,j,k])

    if V.num_sub_spaces() == 0:
        return image.squeeze(-1)
    else:
        return image

u_interp = dofs_to_image(u_func_dofs, u_func_space, shape, resolution)
u_interp.shape

(31, 31, 31) (31, 31, 31, 3)
CPU times: user 503 ms, sys: 42.3 ms, total: 545 ms
Wall time: 490 ms


(31, 31, 31)

In [171]:
project.visual.XArrayViewer(disp3 * 0 + u_interp)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14d4201addb0>