In [None]:
#default_exp pde

In [None]:
#exporti
import torch
import numpy as np
from scipy.sparse import csc_matrix, hstack
import time

In [None]:
#hide
from nbdev.showdoc import show_doc

# FDM assembly

In [None]:
#export
class FDMAssembly():
    """
    This class contains methods that are used for the assembly of the FDM stiffness matrix.

    A matrix-free `operator` is probed with "colored" unit inputs. For one input channel, every
    `filter_shape`-th voxel along each axis is set to one. For each probed voxel, the operator
    response at the voxel itself and at its six face neighbors (in every output channel) is read
    off as the non-zero pattern of the corresponding matrix column.
    """

    @staticmethod
    def apply_dirichlet_zero_rows_to_operator(operator, Ω_dirichlet):
        """
        Returns a version of `operator` that fulfills the dirichlet conditions in the output.

        The wrapped operator evaluates `operator(x)` and sets all output entries selected by the
        boolean mask `Ω_dirichlet` to zero (i.e. it zeroes rows of the assembled matrix).
        Both `Ω_dirichlet` and the input `x` must be 4-dimensional.

        Returns
        -------
        callable
            The wrapped operator.
        """
        def operator_with_dirichlet_rows_zero(x):
            assert len(Ω_dirichlet.shape) == len(x.shape) == 4
            y = operator(x)
            y[Ω_dirichlet] = 0
            return y

        return operator_with_dirichlet_rows_zero


    @staticmethod
    def apply_dirichlet_zero_columns_to_operator(operator, Ω_dirichlet):
        """
        Returns a version of `operator` that fulfills the dirichlet conditions in the input.

        The wrapped operator zeroes the input entries selected by the boolean mask `Ω_dirichlet`
        (on a copy, the caller's tensor is not modified) before applying `operator`, i.e. it
        zeroes columns of the assembled matrix. Both `Ω_dirichlet` and the input `x` must be
        4-dimensional.

        Returns
        -------
        callable
            The wrapped operator.
        """
        def operator_with_dirichlet_columns_zero(x):
            assert len(Ω_dirichlet.shape) == len(x.shape) == 4
            x = x.clone()
            x[Ω_dirichlet] = 0
            y = operator(x)
            return y

        return operator_with_dirichlet_columns_zero


    @staticmethod
    def _get_graph(operator, shape, channels_in, filter_shape):
        """
        Probes `operator` with colored unit inputs.

        For every offset (i, j, k) with 0 <= i < filter_shape[0] (and likewise for j, k) and
        every input channel c, the probe is one at the voxels (c, i::fs0, j::fs1, k::fs2) and
        zero elsewhere.

        Returns
        -------
        list of (numpy.ndarray, numpy.ndarray)
            Pairs of (probe, operator response to the probe).
        """
        operator_graph = []

        for i, j, k in np.ndindex(*filter_shape):
            for c in range(channels_in):
                probe = torch.zeros(channels_in, *shape)
                probe[c, i::filter_shape[0], j::filter_shape[1], k::filter_shape[2]] = 1
                # pass a copy so an operator that works in place cannot corrupt the stored probe
                response = operator(probe.clone())
                # detach: the response may require grad (e.g. operators with trainable parameters)
                operator_graph.append((probe.numpy(), response.detach().cpu().numpy()))

        return operator_graph


    @staticmethod
    def _get_nbh_coordinates(pos_in_nbhs, channels_in, channels_out):
        """
        Lists the candidate output positions (matrix rows) reached from each probed input position.

        Every probed position gets `27 * channels_out` slots: one block of 27 slots per output
        channel, indexed by the stencil offset (i, j, k) in {-1, 0, 1}^3. Only the center and the
        six face neighbors are filled in; all other slots keep the spatial coordinate -2 so that
        `_remove_out_of_bounds_rows` discards them.

        Parameters
        ----------
        pos_in_nbhs : numpy.ndarray
            (N, 4) array of probed positions (channel, x, y, z).
        channels_in : int
            Number of input channels. Not needed for the slot layout; kept for API compatibility.
        channels_out : int
            Number of output channels.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            Neighbor coordinates (output positions) and the probed position each slot belongs to
            (input positions), both of shape (N * 27 * channels_out, 4).
        """
        assert pos_in_nbhs.shape[1:] == (4,)
        n_positions = pos_in_nbhs.shape[0]
        stencil_size = 27  # number of offsets in {-1, 0, 1}^3
        n_slots = stencil_size * channels_out

        nbh_coordinates = -2 * np.ones([n_positions, n_slots, 4])
        # every slot of a probed position refers back to that position (its matrix column)
        centroids = np.repeat(pos_in_nbhs[:, None, :], n_slots, axis=1)

        # output channel of each slot: block c covers the slots [27*c, 27*(c+1))
        nbh_coordinates[:, :, 0] = np.repeat(np.arange(channels_out), stencil_size)

        for i in range(-1, 2):
            for j in range(-1, 2):
                for k in range(-1, 2):
                    if [i, j, k].count(0) < 2:
                        continue  # only the center and the six face neighbors are assembled
                    # linear index of the offset inside the 3x3x3 stencil, independent of channels
                    t = 9 * (i + 1) + 3 * (j + 1) + (k + 1)
                    shifted = pos_in_nbhs[:, 1:] + np.array([i, j, k])
                    nbh_coordinates[:, t::stencil_size, 1:] = shifted[:, None, :]

        assert nbh_coordinates.shape[1:] == (n_slots, 4)
        return nbh_coordinates.reshape(-1, 4), centroids.reshape(-1, 4)


    @staticmethod
    def _remove_out_of_bounds_rows(nbh_coordinates, centroids, shape):
        """
        Drops all rows whose neighbor coordinate is negative or not smaller than `shape`
        (spatially). Both arrays are filtered with the same mask and cast to int.
        """
        upper = np.array([shape[0], shape[1], shape[2]])
        mask = np.all(nbh_coordinates >= 0, axis=1) & np.all(nbh_coordinates[:, 1:] < upper, axis=1)
        return nbh_coordinates[mask].astype(int), centroids[mask].astype(int)


    @staticmethod
    def _get_1d_coordinates(positions, shape):
        """
        Converts (channel, x, y, z) positions into row-major flat indices of a tensor of shape
        (channels, shape[0], shape[1], shape[2]), matching `tensor.flatten()`.
        """
        assert positions.shape[1:] == (4,)
        c, x, y, z = positions[:, 0], positions[:, 1], positions[:, 2], positions[:, 3]
        coords_1d = ((c * shape[0] + x) * shape[1] + y) * shape[2] + z
        return coords_1d.astype(int)


    @staticmethod
    def assemble_operator(operator, shape, channels_in=3, channels_out=9, filter_shape=3, Ω_dirichlet=None, column_wise=True):
        """
        Returns a sparse assembly of `operator`.

        Only couplings between a voxel and itself or its six face neighbors are assembled.
        Couplings to other voxels (e.g. diagonal neighbors) are not captured, so `operator` is
        assumed to have such a 7-point stencil (NOTE(review): confirm for new operators).

        Parameters
        ----------
        operator : callable
            Maps a tensor of shape (channels_in, *shape) to a tensor with
            channels_out * prod(shape) entries, laid out as (channels_out, *shape).
        shape : sequence of 3 ints
            Spatial grid shape.
        channels_in, channels_out : int
            Number of input / output channels of `operator`.
        filter_shape : int or sequence of 3 ints
            Spacing of simultaneously probed voxels per axis. It should be at least 3 so that the
            neighborhoods of simultaneously probed voxels do not overlap.
        Ω_dirichlet : torch.Tensor or None
            4D boolean mask of Dirichlet entries, or None for no Dirichlet treatment.
        column_wise : bool
            If True, zero the Dirichlet columns (inputs); otherwise zero the Dirichlet rows (outputs).

        Returns
        -------
        scipy.sparse.csc_matrix
            Matrix of shape (channels_out * prod(shape), channels_in * prod(shape)).

        Raises
        ------
        ValueError
            If the operator output does not have channels_out * prod(shape) entries.
        """
        if isinstance(filter_shape, (int, np.integer)):
            filter_shape = [filter_shape, filter_shape, filter_shape]

        if Ω_dirichlet is not None:
            if column_wise:
                operator = FDMAssembly.apply_dirichlet_zero_columns_to_operator(operator, Ω_dirichlet)
            else:
                operator = FDMAssembly.apply_dirichlet_zero_rows_to_operator(operator, Ω_dirichlet)

        n_voxels = int(np.prod(shape))
        op_graph = FDMAssembly._get_graph(operator, shape, channels_in, filter_shape)
        col_parts, row_parts, value_parts = [], [], []

        for x, y in op_graph:
            if y.size != channels_out * n_voxels:
                raise ValueError(
                    f"operator output has {y.size} entries, expected channels_out * prod(shape) = {channels_out * n_voxels}"
                )

            pos_in_nbhs = np.stack(np.where(x)).transpose()

            nbh_3d, centroids = FDMAssembly._get_nbh_coordinates(pos_in_nbhs, channels_in, channels_out)
            nbh_3d, centroids = FDMAssembly._remove_out_of_bounds_rows(nbh_3d, centroids, shape)

            col_idx = FDMAssembly._get_1d_coordinates(centroids, shape)
            row_idx = FDMAssembly._get_1d_coordinates(nbh_3d, shape)

            col_parts.append(col_idx)
            row_parts.append(row_idx)
            value_parts.append(np.take(y.reshape(-1), row_idx, axis=0))

        def _concat(parts, dtype):
            # `parts` is only empty in degenerate calls (no input channels or a zero filter size)
            return np.concatenate(parts) if parts else np.empty(0, dtype=dtype)

        values = _concat(value_parts, np.float32)
        rows = _concat(row_parts, int)
        cols = _concat(col_parts, int)
        return csc_matrix((values, (rows, cols)), shape=(channels_out * n_voxels, channels_in * n_voxels))

In [None]:
#hide
from dl4to.datasets import BasicDataset
from dl4to.pde import FDMDerivatives

In [None]:
%%time
#hide
def J_mock(u, problem, use_forward_differences):
    h = problem.h
    u_x = FDMDerivatives.du_dx(u, h, use_forward_differences)
    u_y = FDMDerivatives.du_dy(u, h, use_forward_differences)
    u_z = FDMDerivatives.du_dz(u, h, use_forward_differences)
    return torch.cat([u_x, u_y, u_z], dim=0)


def test_that_the_assembled_J_on_the_ledge_problem_agrees_with_the_matrx_free_version_including_Dirichlet_BCs(resolution, channels_in=3, number=20):
    problem = BasicDataset(resolution=resolution).ledge()
    J = lambda u: J_mock(u, problem=problem, use_forward_differences=False)

    shape = problem.shape
    Ω_dirichlet = problem.Ω_dirichlet

    J_sparse = FDMAssembly.assemble_operator(operator=J, shape=shape, Ω_dirichlet=Ω_dirichlet)
    J_sparse_tensor = torch.tensor(J_sparse.todense())

    operator = FDMAssembly.apply_dirichlet_zero_columns_to_operator(J, Ω_dirichlet)

    for _ in range(number):
        x = torch.randn(channels_in, *shape)
        assert torch.allclose(J_sparse_tensor.mv(x.flatten()), operator(x).flatten(), atol=1e-4)


test_that_the_assembled_J_on_the_ledge_problem_agrees_with_the_matrx_free_version_including_Dirichlet_BCs(resolution=20)

CPU times: user 4.03 s, sys: 28.3 ms, total: 4.06 s
Wall time: 383 ms
