In [None]:
#default_exp pde

In [None]:
#export
import torch
import numpy as np
import time
import importlib.util
import warnings
from scipy.sparse.linalg import factorized, use_solver, spsolve
from scipy.sparse import csc_matrix
from typing import Callable

# Module-level scipy sparse-solver configuration (affects every solve in this module).
# NOTE(review): per the scipy `use_solver` docs, `assumeSortedIndices=True` lets UMFPACK skip
# sorting the matrix indices and is only forwarded when scikits.umfpack is installed — confirm
# against the installed scipy version, and make sure matrices passed in have sorted indices.
use_solver(assumeSortedIndices=True)

In [None]:
#hide
from nbdev.showdoc import show_doc

# Linear solvers

In [None]:
#export
class AutogradLinearSolver(torch.autograd.Function):
    """
    Differentiable linear solve `x = A(θ)⁻¹ b` backed by scipy.

    The forward pass solves `A_mat x = b` on the CPU with scipy. The backward pass uses the
    adjoint method: it solves one more system for the adjoint variable `y` and then lets
    `torch.autograd` differentiate `yᵀ (b - A_op(x, θ))` with respect to `θ`.

    NOTE(review): the adjoint system is solved with `A_mat` itself rather than `A_matᵀ`, so the
    gradient is only exact when `A_mat` is symmetric (presumably true for the linear-elasticity
    systems this module targets — confirm before using it with non-symmetric operators).
    """
    @staticmethod
    def forward(ctx, θ, A_op, b, solver, A_mat, factorize=True):
        """
        Solve `A_mat x = b` and stash what the adjoint backward pass needs.

        Parameters
        ----------
        θ : torch.Tensor
            Parameters of the operator; the only input that receives a gradient.
        A_op : callable
            `A_op(x, θ)` applies the system operator to `x` differentiably; it must agree
            with `A_mat`. Only used in `backward`.
        b : torch.Tensor
            Flattened right-hand side.
        solver : callable
            `solver(A, rhs) -> numpy array`; ignored (replaced) when `factorize` is True.
        A_mat : scipy sparse matrix
            The assembled system matrix.
        factorize : bool
            If True, factorize `A_mat` once with `scipy.sparse.linalg.factorized` and reuse
            the factorization for the backward solve.

        Returns
        -------
        torch.Tensor
            The solution `x`, with the dtype and device of `b`.
        """
        np_b = b.detach().cpu().numpy()

        if factorize:
            # Keep the factorization itself so the backward solve can reuse it.
            solver = factorized(A_mat)
            x_np = solver(np_b)
        else:
            x_np = solver(A_mat, np_b)

        # scipy solves on the CPU and may return a different dtype; go back to b's dtype/device.
        x = torch.from_numpy(x_np.astype(np_b.dtype)).to(b.device)
        ctx.save_for_backward(θ, x, b)
        ctx.intermediate = (A_mat, solver, A_op, factorize)
        return x


    @staticmethod
    def backward(ctx, grad_output):
        """
        Adjoint-method gradient with respect to `θ`.

        Solves `A_mat y = ∂L/∂x` for the adjoint `y` (see the class note on symmetry) and
        differentiates `yᵀ (b - A_op(x, θ))` with respect to `θ`, i.e. `-yᵀ (∂A/∂θ) x`.

        Returns
        -------
        (torch.Tensor, None, None, None, None, None)
            The gradient for `θ`; the other five inputs receive no gradient.
        """
        θ, x, b = ctx.saved_tensors
        A_mat, solver, A_op, factorize = ctx.intermediate

        with torch.no_grad():
            rhs = grad_output.detach().flatten().cpu().numpy()
            y_np = solver(rhs) if factorize else solver(A_mat, rhs)
            y = torch.from_numpy(y_np).to(x.device)
            x_const = x.detach().clone()

        # Fresh leaf so that only the dependence of A_op on θ is differentiated.
        θ_leaf = θ.detach().clone().requires_grad_(True)
        # backward may run with grad mode disabled; enable it only for this scope instead of
        # flipping the flag with a bare `torch.set_grad_enabled(True)` statement.
        with torch.enable_grad():
            residual = b.detach() - A_op(x_const, θ_leaf).flatten()
            expr = torch.sum(y * residual)
            (grad_θ,) = torch.autograd.grad(expr, θ_leaf)
        return grad_θ, None, None, None, None, None

In [None]:
#export
class LinearSolver:
    """
    Base class for the linear solvers that solve the discretized PDE.

    Calling an instance routes through `AutogradLinearSolver`, so the returned solution is
    differentiable with respect to `θ`; gradients come from the adjoint method in the backward
    pass. Subclasses provide the actual solve routine by overriding `_solver`.
    """
    def __init__(self, 
                 factorize:bool=True # Whether the system matrix should be factorized.
                ):
        self.factorize = factorize
        self.autograd_linear_solver = AutogradLinearSolver.apply


    def _solver(self):
        """Return a callable `solver(A, b)`; subclasses must override this."""
        raise NotImplementedError("Solver must be overridden.")


    def __call__(self, 
                 θ:torch.Tensor, # The density for which the PDE is solved.
                 A_op:Callable[[torch.Tensor, torch.Tensor], torch.Tensor], # A function that takes `u` and `θ` and applies the system operator to `u`, i.e. returns the left-hand side `A(θ) u` of the discretized PDE.
                 b:torch.Tensor, # A flattened version of the right side of the PDE.
                 A_mat:csc_matrix # The system matrix in sparse format.
                ):
        """
        Solves the PDE for the density `θ`. Returns the solution as a `torch.Tensor` object.
        """
        return self.autograd_linear_solver(
            θ, A_op, b, self._solver(), A_mat, self.factorize
        )

In [None]:
# Render nbdev documentation for `LinearSolver.__call__`; the parameter table is built from
# the inline comments next to each argument.
show_doc(LinearSolver.__call__)

<h4 id="LinearSolver.__call__" class="doc_header"><code>LinearSolver.__call__</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>LinearSolver.__call__</code>(**`θ`**:`Tensor`, **`A_op`**:`Callable`\[`Tensor`, `Tensor`, `Tensor`\], **`b`**:`Tensor`, **`A_mat`**:`csc_matrix`)

Solves the PDE for the density `θ`. Returns the solution as a `torch.Tensor` object.

||Type|Default|Details|
|---|---|---|---|
|**`θ`**|`Tensor`||The density for which the PDE is solved.|
|**`A_op`**|`typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`||A function that takes `u` and `theta` as input and outputs the right hand side of the PDE. In other words, this is an operator representing the system matrix.|
|**`b`**|`Tensor`||A flattened version of the right side of the PDE.|
|**`A_mat`**|`csc_matrix`||The system matrix in sparse format.|


In [None]:
#export
class SparseLinearSolver(LinearSolver):
    """
    A sparse linear solver based on `scipy.sparse.linalg.spsolve` (or on
    `scipy.sparse.linalg.factorized` when `factorize=True`) that is used to solve the PDE of linear elasticity.
    """
    def __init__(self, 
                 use_umfpack:bool=True, # Whether to use umfpack. If false, then the LU solver from `scipy.sparse` is used, which is usually slower.
                 factorize:bool=False # Whether the system matrix should be factorized.
                ):
        if use_umfpack or factorize:
            # Look for the actual `scikits.umfpack` module, not merely the `scikits` namespace
            # (which can exist without umfpack). For dotted names `find_spec` imports the parent
            # package and raises if it is missing, which also means "not installed".
            try:
                has_umfpack = importlib.util.find_spec('scikits.umfpack') is not None
            except ImportError:
                has_umfpack = False
            if not has_umfpack:
                warnings.warn("The package scikits.umfpack is not installed. Therefore, the LU solver from scipy.sparse is used, which is usually slower.", stacklevel=2)

        self.use_umfpack = use_umfpack
        super().__init__(factorize)


    def _solver(self):
        """Return `solver(A, b)`, a wrapper around `spsolve` that honours `self.use_umfpack`."""
        return lambda A, b: spsolve(A, b, use_umfpack=self.use_umfpack)

In [None]:
#hide
import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
#hide
def get_operator_and_b(n:int=4):
    """
    Build a small diagonal test problem `A(θ) x = b` on an `n×n×n` grid with 3 components.

    Returns `(A_op, A_mat, b, θ, θ_triple, n)`:
    - `θ`: random leaf tensor of shape `(1, n, n, n)` with `requires_grad=True`;
    - `θ_triple`: `θ` concatenated three times, shape `(3, n, n, n)`;
    - `A_op(x, θ)`: reshapes `x` to `(3, n, n, n)` and multiplies it element-wise by the stacked `θ`;
    - `A_mat`: the matching diagonal matrix in CSC format (float64);
    - `b`: all ones, shape `(3, n, n, n)`.
    """
    θ = torch.rand(1,n,n,n, requires_grad=True)
    θ_triple = torch.cat([θ, θ, θ])

    def A_op(x, θ):
        x = x.view(3,n,n,n)
        θ_triple = torch.cat([θ, θ, θ])
        return θ_triple*x

    # Assemble the diagonal directly in CSC format: spsolve warns about (and converts) dense
    # input, and a dense (3n³)×(3n³) matrix grows quadratically with the number of unknowns.
    diag = θ_triple.detach().flatten().numpy().astype(np.float64)
    idx = np.arange(diag.size)
    A_mat = csc_matrix((diag, (idx, idx)), shape=(diag.size, diag.size))

    b = torch.ones(3,n,n,n)

    return A_op, A_mat, b, θ, θ_triple, n

In [None]:
%%time
#hide

def test_that_we_can_solve_a_system():
    """The solver's output must satisfy `A_op(x, θ) == b` and match the exact solution `1/θ`."""
    op, matrix, rhs, density, density_triple, size = get_operator_and_b()
    grid = (3, size, size, size)

    linear_solver = SparseLinearSolver()
    solution = linear_solver(θ=density, A_op=op, b=rhs.flatten(), A_mat=matrix).view(*grid)

    assert torch.allclose(op(solution, density), rhs)
    assert torch.allclose(solution, 1/density_triple)


test_that_we_can_solve_a_system()

CPU times: user 3.87 ms, sys: 8.62 ms, total: 12.5 ms
Wall time: 30 ms


  warn('spsolve requires A be CSC or CSR matrix format',


In [None]:
%%time
#hide
#slow
def test_that_we_can_differentiate_solution(verbose=False):
    """
    End-to-end gradient check: optimize θ with Adam for 20 steps so that the solution of the
    diagonal test system approaches all ones, and require the final MSE loss to drop below 0.2.
    If `verbose`, plot the loss curve and print the final loss.
    """
    A_op, _, b, θ, _, n = get_operator_and_b()
    solver = SparseLinearSolver()
    b_flat = b.flatten()  # loop-invariant

    optimizer = torch.optim.Adam([θ],lr=1e-1)
    losses = []

    for _ in range(20):
        θ_triple = torch.cat([θ, θ, θ])
        # Re-assemble the diagonal system matrix for the current θ directly in CSC format,
        # avoiding spsolve's warning about (and conversion of) dense input.
        diag = θ_triple.detach().flatten().numpy().astype(np.float64)
        idx = np.arange(diag.size)
        A_mat = csc_matrix((diag, (idx, idx)), shape=(diag.size, diag.size))

        x = solver(θ=θ, A_op=A_op, b=b_flat, A_mat=A_mat).view(3, n, n, n)
        loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if verbose:
        fig, ax = plt.subplots()
        ax.plot(losses)
        ax.set(xlabel="Adam step", ylabel="MSE loss", title="Fitting θ through the linear solver")
        plt.show()
        print(losses[-1])

    # NOTE(review): θ comes from an unseeded torch.rand, so this threshold check may be flaky.
    assert losses[-1] < .2, f"final loss {losses[-1]:.3f} is not below 0.2"


test_that_we_can_differentiate_solution(verbose=False)

CPU times: user 64.5 ms, sys: 5 ms, total: 69.5 ms
Wall time: 67.7 ms
