In [None]:
#default_exp pde

In [None]:
#exporti
import torch
import torch.autograd.functional as F

In [None]:
#hide
from nbdev.showdoc import show_doc

# FDM derivatives

In [None]:
#export
class FDMDerivatives():
    """First-order finite-difference derivatives of 4-D tensors.

    The x, y and z derivatives act along tensor axes 1, 2 and 3 and are scaled
    by ``h[0]``, ``h[1]`` and ``h[2]`` respectively; axis 0 is never
    differentiated. Every method returns a new tensor shaped like ``u``.

    ``du_dx`` / ``du_dy`` / ``du_dz`` validate the input and dispatch to the
    ``*_forward`` or ``*_central`` variant; those variants perform no checks
    themselves.
    """

    @staticmethod
    def _at(axis, index):
        """Build a 4-D index applying `index` on `axis` and `:` everywhere else."""
        key = [slice(None)] * 4
        key[axis] = index
        return tuple(key)

    @staticmethod
    def _check_input(u, axis):
        """Assert that `u` is 4-D and has at least two samples along `axis`."""
        assert len(u.shape) == 4, "u must be a 4-D tensor"
        # Two samples are enough: every stencil below only reads index pairs
        # that exist for an extent of 2 (the interior slice is simply empty).
        assert u.shape[axis] >= 2, "need at least two samples along the axis"

    @staticmethod
    def _central(u, spacing, axis):
        """Central differences inside, one-sided differences on both boundaries.

        Interior: ``(u[i+1] - u[i-1]) / (2 * spacing)``; first sample:
        ``(u[1] - u[0]) / spacing``; last sample: ``(u[-1] - u[-2]) / spacing``.
        """
        def at(index):
            return FDMDerivatives._at(axis, index)

        du = torch.zeros_like(u)
        du[at(slice(1, -1))] = (u[at(slice(2, None))] - u[at(slice(0, -2))]) / (2 * spacing)
        du[at(0)] = (u[at(1)] - u[at(0)]) / spacing
        du[at(-1)] = (u[at(-1)] - u[at(-2)]) / spacing
        return du

    @staticmethod
    def _forward(u, spacing, axis):
        """Forward differences everywhere except the last sample.

        ``(u[i+1] - u[i]) / spacing`` for all but the last sample, which falls
        back to the backward difference ``(u[-1] - u[-2]) / spacing``.
        """
        def at(index):
            return FDMDerivatives._at(axis, index)

        du = torch.zeros_like(u)
        du[at(slice(0, -1))] = (u[at(slice(1, None))] - u[at(slice(0, -1))]) / spacing
        du[at(-1)] = (u[at(-1)] - u[at(-2)]) / spacing
        return du

    @staticmethod
    def du_dx_central(u, h):
        """Central-difference derivative along axis 1 with spacing ``h[0]``."""
        return FDMDerivatives._central(u, h[0], 1)

    @staticmethod
    def du_dy_central(u, h):
        """Central-difference derivative along axis 2 with spacing ``h[1]``."""
        return FDMDerivatives._central(u, h[1], 2)

    @staticmethod
    def du_dz_central(u, h):
        """Central-difference derivative along axis 3 with spacing ``h[2]``."""
        return FDMDerivatives._central(u, h[2], 3)

    @staticmethod
    def du_dx_forward(u, h):
        """Forward-difference derivative along axis 1 with spacing ``h[0]``."""
        return FDMDerivatives._forward(u, h[0], 1)

    @staticmethod
    def du_dy_forward(u, h):
        """Forward-difference derivative along axis 2 with spacing ``h[1]``."""
        return FDMDerivatives._forward(u, h[1], 2)

    @staticmethod
    def du_dz_forward(u, h):
        """Forward-difference derivative along axis 3 with spacing ``h[2]``."""
        return FDMDerivatives._forward(u, h[2], 3)

    @staticmethod
    def du_dx(u, h, use_forward_differences=True):
        """Derivative of `u` along axis 1.

        `u` must be 4-D with at least two samples along axis 1 (otherwise an
        AssertionError is raised). `h` is indexable and ``h[0]`` is the spacing.
        `use_forward_differences` selects the forward stencil (default) or the
        central one.
        """
        FDMDerivatives._check_input(u, 1)
        if use_forward_differences:
            return FDMDerivatives.du_dx_forward(u, h)
        return FDMDerivatives.du_dx_central(u, h)

    @staticmethod
    def du_dy(u, h, use_forward_differences=True):
        """Derivative of `u` along axis 2.

        `u` must be 4-D with at least two samples along axis 2 (otherwise an
        AssertionError is raised). `h` is indexable and ``h[1]`` is the spacing.
        `use_forward_differences` selects the forward stencil (default) or the
        central one.
        """
        FDMDerivatives._check_input(u, 2)
        if use_forward_differences:
            return FDMDerivatives.du_dy_forward(u, h)
        return FDMDerivatives.du_dy_central(u, h)

    @staticmethod
    def du_dz(u, h, use_forward_differences=True):
        """Derivative of `u` along axis 3.

        `u` must be 4-D with at least two samples along axis 3 (otherwise an
        AssertionError is raised). `h` is indexable and ``h[2]`` is the spacing.
        `use_forward_differences` selects the forward stencil (default) or the
        central one.
        """
        FDMDerivatives._check_input(u, 3)
        if use_forward_differences:
            return FDMDerivatives.du_dz_forward(u, h)
        return FDMDerivatives.du_dz_central(u, h)

### Hardcoded analytical adjoints of numerical derivatives

In [None]:
#export
class FDMAdjointDerivatives():
    """Hand-written adjoints (transposes) of the finite-difference stencils.

    For a derivative operator ``D`` acting along one axis, each method returns
    ``D^T ε`` so that ``<u, D^T ε> == <ε, D u>``. The x, y and z adjoints act
    along tensor axes 1, 2 and 3 and use ``h[0]``, ``h[1]`` and ``h[2]`` as
    spacing; axis 0 is never mixed. Every method returns a new tensor shaped
    like ``ε``.

    ``du_dx_adj`` / ``du_dy_adj`` / ``du_dz_adj`` are the entry points; they
    pick the formula that matches the stencil type and the axis extent.
    """

    @staticmethod
    def _at(axis, index):
        """Build a 4-D index applying `index` on `axis` and `:` everywhere else."""
        key = [slice(None)] * 4
        key[axis] = index
        return tuple(key)

    @staticmethod
    def _central_adj_large(ε, spacing, axis):
        """Adjoint of the central stencil, valid for an extent of at least 4.

        The two samples next to each boundary get their own formula because
        the one-sided boundary rows of the forward operator touch them.
        """
        def at(index):
            return FDMAdjointDerivatives._at(axis, index)

        u = torch.zeros_like(ε)
        u[at(0)] = -(2 * ε[at(0)] + ε[at(1)]) / (2 * spacing)
        u[at(1)] = (2 * ε[at(0)] - ε[at(2)]) / (2 * spacing)
        u[at(slice(2, -2))] = (ε[at(slice(1, -3))] - ε[at(slice(3, -1))]) / (2 * spacing)
        u[at(-2)] = -(2 * ε[at(-1)] - ε[at(-3)]) / (2 * spacing)
        u[at(-1)] = (2 * ε[at(-1)] + ε[at(-2)]) / (2 * spacing)
        return u

    @staticmethod
    def _central_adj_small(ε, spacing, axis):
        """Adjoint of the central stencil for an extent of exactly 2 or 3.

        Any other extent leaves the result at zero.
        """
        def at(index):
            return FDMAdjointDerivatives._at(axis, index)

        u = torch.zeros_like(ε)
        extent = u.shape[axis]

        if extent == 2:
            u[at(0)] = -(ε[at(0)] + ε[at(1)]) / spacing
            u[at(1)] = -u[at(0)]
        elif extent == 3:
            u[at(0)] = -(2 * ε[at(0)] + ε[at(1)]) / (2 * spacing)
            u[at(1)] = (ε[at(0)] - ε[at(2)]) / spacing
            u[at(2)] = (2 * ε[at(2)] + ε[at(1)]) / (2 * spacing)

        return u

    @staticmethod
    def _forward_adj(ε, spacing, axis):
        """Adjoint of the forward stencil (backward difference on the last sample).

        Works for an extent of 2 or more; smaller extents fail on indexing.
        """
        def at(index):
            return FDMAdjointDerivatives._at(axis, index)

        u = torch.zeros_like(ε)

        if ε.shape[axis] == 2:
            # With two samples both rows of the operator are (u[1] - u[0]) / h,
            # so the general formulas below (which read ε[-3]) do not apply.
            total = (ε[at(0)] + ε[at(1)]) / spacing
            u[at(0)] = -total
            u[at(1)] = total
            return u

        u[at(0)] = (-ε[at(0)]) / spacing
        u[at(slice(1, -2))] = (ε[at(slice(0, -3))] - ε[at(slice(1, -2))]) / spacing
        # The second-to-last sample is read by three rows: the forward rows at
        # -3 and -2, and the backward-difference row at -1.
        u[at(-2)] = (ε[at(-3)] - ε[at(-2)] - ε[at(-1)]) / spacing
        u[at(-1)] = (ε[at(-2)] + ε[at(-1)]) / spacing
        return u

    @staticmethod
    def _adj(ε, spacing, axis, use_forward_differences):
        """Validate `ε` and dispatch to the matching adjoint formula."""
        assert len(ε.shape) == 4, "ε must be a 4-D tensor"
        if use_forward_differences:
            return FDMAdjointDerivatives._forward_adj(ε, spacing, axis)
        if ε.shape[axis] > 3:
            return FDMAdjointDerivatives._central_adj_large(ε, spacing, axis)
        return FDMAdjointDerivatives._central_adj_small(ε, spacing, axis)

    @staticmethod
    def du_dx_adj_for_a_sufficiently_large_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 1; needs an extent of at least 4."""
        return FDMAdjointDerivatives._central_adj_large(ε, h[0], 1)

    @staticmethod
    def du_dy_adj_for_a_sufficiently_large_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 2; needs an extent of at least 4."""
        return FDMAdjointDerivatives._central_adj_large(ε, h[1], 2)

    @staticmethod
    def du_dz_adj_for_a_sufficiently_large_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 3; needs an extent of at least 4."""
        return FDMAdjointDerivatives._central_adj_large(ε, h[2], 3)

    @staticmethod
    def du_dx_adj_for_a_sufficiently_small_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 1 for an extent of 2 or 3 (else zeros)."""
        return FDMAdjointDerivatives._central_adj_small(ε, h[0], 1)

    @staticmethod
    def du_dy_adj_for_a_sufficiently_small_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 2 for an extent of 2 or 3 (else zeros)."""
        return FDMAdjointDerivatives._central_adj_small(ε, h[1], 2)

    @staticmethod
    def du_dz_adj_for_a_sufficiently_small_number_of_voxels(ε, h):
        """Central-stencil adjoint along axis 3 for an extent of 2 or 3 (else zeros)."""
        return FDMAdjointDerivatives._central_adj_small(ε, h[2], 3)

    @staticmethod
    def du_dx_adj_forward(ε, h):
        """Forward-stencil adjoint along axis 1 with spacing ``h[0]``."""
        return FDMAdjointDerivatives._forward_adj(ε, h[0], 1)

    @staticmethod
    def du_dy_adj_forward(ε, h):
        """Forward-stencil adjoint along axis 2 with spacing ``h[1]``."""
        return FDMAdjointDerivatives._forward_adj(ε, h[1], 2)

    @staticmethod
    def du_dz_adj_forward(ε, h):
        """Forward-stencil adjoint along axis 3 with spacing ``h[2]``."""
        return FDMAdjointDerivatives._forward_adj(ε, h[2], 3)

    @staticmethod
    def du_dx_adj(ε, h, use_forward_differences=True):
        """Adjoint of the axis-1 derivative applied to the 4-D tensor `ε`.

        `h` is indexable and ``h[0]`` is the spacing. `use_forward_differences`
        must match the stencil used by the forward operator. Raises
        AssertionError when `ε` is not 4-D.
        """
        return FDMAdjointDerivatives._adj(ε, h[0], 1, use_forward_differences)

    @staticmethod
    def du_dy_adj(ε, h, use_forward_differences=True):
        """Adjoint of the axis-2 derivative applied to the 4-D tensor `ε`.

        `h` is indexable and ``h[1]`` is the spacing. `use_forward_differences`
        must match the stencil used by the forward operator. Raises
        AssertionError when `ε` is not 4-D.
        """
        return FDMAdjointDerivatives._adj(ε, h[1], 2, use_forward_differences)

    @staticmethod
    def du_dz_adj(ε, h, use_forward_differences=True):
        """Adjoint of the axis-3 derivative applied to the 4-D tensor `ε`.

        `h` is indexable and ``h[2]`` is the spacing. `use_forward_differences`
        must match the stencil used by the forward operator. Raises
        AssertionError when `ε` is not 4-D.
        """
        return FDMAdjointDerivatives._adj(ε, h[2], 3, use_forward_differences)

## Tests

In [None]:
import numpy as np
import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
from hypothesis import given, settings

In [None]:
# Grid spacing handed to the derivative operators, one entry per spatial axis.
h = [1e-3] * 3
# Absolute tolerance for the torch.allclose comparisons in the tests.
atol = 1e-2

In [None]:
# Shape strategy: a fixed leading axis of 6 (split into two halves of 3 by the
# tests) followed by three spatial extents drawn independently from [3, 32].
st_u_ε_shape = st.tuples(
    st.integers(min_value=6, max_value=6),
    *(st.integers(min_value=3, max_value=32) for _ in range(3)),
)
# Float arrays of that shape with entries bounded to [-1e3, 1e3].
st_u_ε = npst.arrays(
    float,
    shape=st_u_ε_shape,
    elements=st.floats(-1e3, 1e3),
)

In [None]:
%%time

@given(u_ε=st_u_ε)
@settings(max_examples=10, deadline=None)
def test_that_adjoint_derivative_really_is_the_adjoint(u_ε):
    """Check <u, D^T ε> == <ε, D u> for the central-difference operators.

    The first three slices of `u_ε` along axis 0 serve as `u`, the remaining
    three as `ε`; the identity is asserted for x, y and z in turn.
    """
    u, ε = (torch.tensor(part) for part in (u_ε[:3], u_ε[3:]))
    forwDif = False

    operator_pairs = (
        (FDMAdjointDerivatives.du_dx_adj, FDMDerivatives.du_dx),
        (FDMAdjointDerivatives.du_dy_adj, FDMDerivatives.du_dy),
        (FDMAdjointDerivatives.du_dz_adj, FDMDerivatives.du_dz),
    )
    for adjoint, derivative in operator_pairs:
        via_adjoint = torch.dot(u.flatten(), adjoint(ε, h, forwDif).flatten())
        via_derivative = torch.dot(ε.flatten(), derivative(u, h, forwDif).flatten())
        assert torch.allclose(via_adjoint, via_derivative, atol=atol)


test_that_adjoint_derivative_really_is_the_adjoint()

Falsifying example: test_that_adjoint_derivative_really_is_the_adjoint(
    u_ε=array([[[[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]]],
    
    
           [[[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]]],
    
    
           [[[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]]],
    
    
           [[[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]],
    
            [[0., 0., 0.],
         

AttributeError: 'NoneType' object has no attribute 'flatten'

In [None]:
%%time

@given(u_ε=st_u_ε)
@settings(max_examples=10, deadline=None)
def test_that_adjoint_derivative_really_is_the_adjoint_for_forward_diffs(u_ε):
    """Check <u, D^T ε> == <ε, D u> for the forward-difference operators.

    The first three slices of `u_ε` along axis 0 serve as `u`, the remaining
    three as `ε`; the identity is asserted for x, y and z in turn.
    """
    u, ε = (torch.tensor(part) for part in (u_ε[:3], u_ε[3:]))
    forwDif = True

    operator_pairs = (
        (FDMAdjointDerivatives.du_dx_adj, FDMDerivatives.du_dx),
        (FDMAdjointDerivatives.du_dy_adj, FDMDerivatives.du_dy),
        (FDMAdjointDerivatives.du_dz_adj, FDMDerivatives.du_dz),
    )
    for adjoint, derivative in operator_pairs:
        via_adjoint = torch.dot(u.flatten(), adjoint(ε, h, forwDif).flatten())
        via_derivative = torch.dot(ε.flatten(), derivative(u, h, forwDif).flatten())
        assert torch.allclose(via_adjoint, via_derivative, atol=atol)


test_that_adjoint_derivative_really_is_the_adjoint_for_forward_diffs()

CPU times: user 43.8 s, sys: 374 ms, total: 44.2 s
Wall time: 492 ms


In [None]:
# One-element shape tuple, i.e. arrays of shape (1,).
st_x_shape = st.tuples(
    st.integers(min_value=1, max_value=1),
)
# Single evaluation point drawn from [-1e2, 1e2].
st_x = npst.arrays(
    float,
    shape=st_x_shape,
    elements=st.floats(-1e2, 1e2),
)

In [None]:
%%time

@given(x=st_x)
@settings(max_examples=10, deadline=None)
def test_that_fdm_derivative_and_torch_derivative_coincide_for_sin(x):
    """Compare autograd's d/dx sin(x) with the FDM estimate along each axis.

    The two samples x-h and x+h are laid out along axis 1, 2 and 3 in turn and
    the spacing 2*h is passed to the matching FDM operator.
    """
    x = torch.tensor(np.array([x]), requires_grad=True)
    y = torch.sin(x)
    y.backward()
    dy_dx_torch = x.grad

    h = 1e-4
    axis_cases = (
        (FDMDerivatives.du_dx, (1, 2, 1, 1)),
        (FDMDerivatives.du_dy, (1, 1, 2, 1)),
        (FDMDerivatives.du_dz, (1, 1, 1, 2)),
    )
    for derivative, sample_shape in axis_cases:
        x_ = torch.tensor([x-h, x+h]).view(*sample_shape)
        x_.requires_grad_(True)
        y_ = torch.sin(x_)
        dy_dx_fdm = derivative(y_, 3 * [2*h])[0, 0]
        assert torch.allclose(dy_dx_torch, dy_dx_fdm, atol=atol)


test_that_fdm_derivative_and_torch_derivative_coincide_for_sin()

Falsifying example: test_that_fdm_derivative_and_torch_derivative_coincide_for_sin(
    x=array([0.]),
)


AssertionError: 

In [None]:
%%time

@given(x=st_x)
@settings(max_examples=10, deadline=None)
def test_that_fdm_derivative_and_torch_derivative_coincide_for_exp(x):
    """Compare autograd's d/dx exp(x) with the FDM estimate along each axis.

    The two samples x-h and x+h are laid out along axis 1, 2 and 3 in turn and
    the spacing 2*h is passed to the matching FDM operator.
    """
    x = torch.tensor(np.array([x]), requires_grad=True)
    y = torch.exp(x)
    y.backward()
    dy_dx_torch = x.grad

    h = 1e-4
    axis_cases = (
        (FDMDerivatives.du_dx, (1, 2, 1, 1)),
        (FDMDerivatives.du_dy, (1, 1, 2, 1)),
        (FDMDerivatives.du_dz, (1, 1, 1, 2)),
    )
    for derivative, sample_shape in axis_cases:
        x_ = torch.tensor([x-h, x+h]).view(*sample_shape)
        x_.requires_grad_(True)
        y_ = torch.exp(x_)
        dy_dx_fdm = derivative(y_, 3 * [2*h])[0, 0]
        assert torch.allclose(dy_dx_torch, dy_dx_fdm, atol=atol)


test_that_fdm_derivative_and_torch_derivative_coincide_for_exp()

Falsifying example: test_that_fdm_derivative_and_torch_derivative_coincide_for_exp(
    x=array([0.]),
)


AssertionError: 