# Inversion benchmark - centroid distance

Define a set of points around the boundary of a lithology, then use each point's distance to the lithology centroid as an inversion variable. Each point has only one degree of freedom (_expansion_ or _contraction_ along the line to the centroid), but with enough points this should not matter.

> **Problem statement**: We need a way to relate the spatial arrangement of nodes to lithology that is _differentiable_.

In [None]:
import numpy as np
from time import clock
from conduction import ConductionND
from conduction.inversion import InvObservation, InvPrior
from conduction import InversionND
import matplotlib.pyplot as plt
%matplotlib inline

from petsc4py import PETSc
from mpi4py import MPI
comm = MPI.COMM_WORLD

from scipy import ndimage

In [None]:
minX, maxX = 0.0, 1000.0
minY, maxY = -1000.0, 0.0
nx, ny = 10, 10
n = nx*ny

mesh = ConductionND((minX, minY), (maxX, maxY), (nx,ny))

# BCs: fixed value (flux=False) on maxY, flux condition (flux=True) on minY
mesh.boundary_condition('maxY', 298.0, flux=False)
mesh.boundary_condition('minY', 1e3, flux=True)


# Global lithology

lithology = np.zeros((ny,nx), dtype='int32')
lithology[3:7,:] = 1
lithology[7:,:]  = 2

# Each node's ratio is 1 / (number of nodes sharing its lithology), computed
# on the *global* lithology. Deriving it from `lithology` (rather than
# repeating the layer slices) keeps a single definition of the layers; it
# also avoids `np.float`, which was removed in NumPy 1.24.
lithology_node_ratios = 1.0 / np.bincount(lithology.ravel())
ratio0, ratio1, ratio2 = lithology_node_ratios[:3]
lithology_ratios = lithology_node_ratios[lithology]   # float64, shape (ny, nx)


# Need to slice this bad boy up: Local lithology
# NOTE(review): lithology_ratios is left global (unsliced) while lithology is
# cut to the local ghost range -- presumably intentional since the inversion
# vector is assembled from it; confirm under MPI with more than one rank.

(minI, maxI), (minJ, maxJ) = mesh.dm.getGhostRanges()
lithology = lithology[minJ:maxJ, minI:maxI]

In [None]:
inv = InversionND(lithology.flatten(), mesh)

k = np.array([3.5, 2.0, 3.2])
H = np.array([0.5e-6, 1e-6, 2e-6])
a = np.array([0.3, 0.3, 0.3])
q0 = 35e-3
sigma_q0 = 5e-3


# Inversion variables, packed as [k, H, a, q0, lithology ratios]
x = np.hstack([k, H, a, [q0], lithology_ratios.flatten()])

# Perturb every variable by 1%, except the leading material properties and
# q0, which get a zero perturbation. The count is derived from the arrays
# instead of a hard-coded 10, so it stays right if the number of
# lithologies changes.
n_unperturbed = k.size + H.size + a.size + 1
dx = x*0.01
dx[:n_unperturbed] = 0.0

## Find boundary points

We only have to do this once to find the point coordinates.

Lithology transitions are sharp, so the boundary points must be located at "imaginary" nodes that lie between mesh nodes.

1. Use the Sobel filter to identify lithology transitions (bands two cells thick).
2. Iterate through the lithologies, splitting the Sobel band points into those inside the lithology and those outside it.
3. Place a boundary point midway (the centroid of the pair) between each band point inside the lithology and each of its neighbours that lies in the band outside it.

In [None]:
# Sobel gradient magnitude of the lithology field: non-zero only where the
# lithology changes. `range` works on both Python 2 and 3 (unlike `xrange`).
sobel_bands = np.linalg.norm(
    [ndimage.sobel(lithology, axis=axis) for axis in range(inv.mesh.dim)],
    axis=0)

bands = np.nonzero(sobel_bands)                  # per-axis indices of band cells
node_bands = np.nonzero(sobel_bands.ravel())[0]  # flat indices of band cells

In [None]:
neighbours = mesh.find_neighbours()

# Boolean mask of nodes that lie inside a Sobel transition band.
node_bands_mask = np.zeros(mesh.nn, dtype=bool)
node_bands_mask[node_bands] = True

# bpoints[l] collects boundary-point coordinates for inv.lithology_index[l].
bpoints = [[] for _ in range(len(inv.lithology_index))]

for l, lith in enumerate(inv.lithology_index):
    lith_mask = lith == inv.lithology
    # band nodes inside this lithology, and band nodes outside it
    lith_band  = np.nonzero(np.logical_and(lith_mask, node_bands_mask))[0]
    other_band = np.nonzero(np.logical_and(~lith_mask, node_bands_mask))[0]
    # Build the lookup set once per lithology; the original re-scanned the
    # whole other_band array for every node in the loop below.
    other_set = set(other_band.tolist())
    for node in lith_band:
        # neighbours of this band node that sit in the band outside the lithology
        neighbour_set = other_set.intersection(neighbours[node])
        # sorted() makes the order of the boundary points deterministic
        for nnode in sorted(neighbour_set):
            # the boundary point sits midway between the two nodes
            bpt = 0.5*(mesh.coords[node] + mesh.coords[nnode])
            bpoints[l].append( bpt )

In [None]:
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.imshow(lithology, extent=mesh.extent)
for l, lith in enumerate(inv.lithology_index):
    if not bpoints[l]:
        # np.vstack raises on an empty list; a lithology with no boundary
        # points (e.g. one touching no other lithology) has nothing to draw.
        continue
    bpts = np.vstack(bpoints[l])
    ax1.scatter(bpts[:,0], bpts[:,1], alpha=0.5, label="lithology {}".format(lith))
ax1.legend(bbox_to_anchor=(1.5,1))

## Map lithologies within hull of boundary points

The boundary points are going to deform, so we need to map the volume enclosed by the boundary points to lithologies. We need to ensure that nodes on the BCs remain fixed.

1. Connect up the boundary points (and BC nodes)
2. Mask all nodes contained within the hull of boundary points using a flood-fill algorithm
3. Repeat for every lithology

In [None]:
def vfill(self):
    """
    Flood-fill algorithm for the vertical axes.

    For each layer index ``l`` in ``0 .. len(self.lithology_index) - 1``:
    evaluate ``spl[l].ev(xq, yq)`` to get heights ``zq``, query
    ``self.ndinterp.tree`` for the nodes nearest to ``(xq, yq, zq)``, and
    for every hit at ``(i, j)`` write ``l`` into rows ``0 .. i-1`` of
    column ``j`` of the output.

    NOTE(review): ``spl``, ``xq`` and ``yq`` are not defined here; they are
    presumably notebook globals (per-layer objects with an ``.ev`` method and
    the query coordinates) -- confirm before calling. ``lithology`` is also
    read from the enclosing scope, and ``query_nearest`` assumes the mask is
    2-D (it unpacks ``np.where`` into two arrays).

    Returns
    -------
    layer_voxel : ndarray
        Array shaped like ``lithology``. Each cell holds the layer index
        written by the last layer that covered it, or 0 if none did.
        (The original computed this array but never returned it.)
    """
    def query_nearest(l):
        # reset the mask, then flag the nodes nearest to layer l's surface
        layer_mask.fill(0)

        zq = spl[l].ev(xq, yq)
        d, idx = tree.query(np.column_stack([xq, yq, zq]))
        layer_mask.flat[idx] = True

        return np.where(layer_mask)

    tree = self.ndinterp.tree
    layer_voxel = np.zeros_like(lithology)
    layer_mask = np.zeros_like(lithology, dtype=bool)

    nl = len(self.lithology_index)

    for l in range(nl):
        i0, j0 = query_nearest(l)

        for i in range(i0.size):
            layer_voxel[:i0[i], j0[i]] = l

        print("mapped layer {}".format(l))

    return layer_voxel

In [None]:
def forward_model(x, nbpts, self):
    """
    Misfit between modelled and observed heat flow.

    Parameters
    ----------
    x : array
        Packed vector ``[H_0 .. H_{nl-1}, q0, boundary points ...]`` where
        ``nl = len(self.lithology_index)``.
    nbpts : int or sequence of int
        Number of boundary-point entries in ``x`` belonging to each
        lithology. A single int keeps ``np.split``'s "equal sections"
        meaning.
    self : object
        Provides ``lithology_index``, ``lithology``, ``mesh.n`` and
        ``grid_delta`` (an InversionND-like object).

    Returns
    -------
    float
        ``sum((qs - qobs)**2 / sigma_qobs**2)``. Here ``qs`` is the sum over
        lithologies of ``H[l] * dz_l``, and ``dz_l`` is the count of nodes of
        lithology ``l`` along axis 0 times ``self.grid_delta[-1]``.
        ``qobs`` and ``sigma_qobs`` are read from module globals.
    """
    nl = len(self.lithology_index)
    H = x[:nl] # just H for demonstration purposes
    q0 = x[nl]  # NOTE(review): parsed but not yet used in the misfit

    # np.split interprets a list of ints as split *indices*, not section
    # sizes, so convert the per-lithology counts to cumulative offsets.
    if np.ndim(nbpts) == 0:
        bpts = np.split(x[nl+1:], nbpts)
    else:
        bpts = np.split(x[nl+1:], np.cumsum(nbpts)[:-1])

    lithology = self.lithology.reshape(self.mesh.n)
    qc = []

    for l, lith in enumerate(self.lithology_index):
        # number of nodes of this lithology along axis 0, times the spacing
        nrow = (lithology == lith).sum(axis=0)
        dz = nrow*self.grid_delta[-1]
        # TODO: derive the thickness differentiably from the boundary points,
        # e.g. qc = H[l]*(z0 - z1) with z0 = bpts[l].
        qc.append( H[l]*dz )

    qs = np.vstack(qc).sum(axis=0)
    return np.sum((qs - qobs)**2/sigma_qobs**2)

def tangent_linear(x, dx, nbpts, self):
    """
    Tangent-linear counterpart of ``forward_model``.

    Computes the same misfit as ``forward_model`` plus its directional
    derivative along the perturbation ``dx``.

    Parameters
    ----------
    x, dx : array
        Packed vector and perturbation, laid out as
        ``[H_0 .. H_{nl-1}, q0, boundary points ...]`` with
        ``nl = len(self.lithology_index)``.
    nbpts : int or sequence of int
        Number of boundary-point entries belonging to each lithology (a
        single int keeps ``np.split``'s "equal sections" meaning).
    self : object
        Provides ``lithology_index``, ``lithology``, ``mesh.n`` and
        ``grid_delta`` (an InversionND-like object).

    Returns
    -------
    cost : float
        ``sum((qs - qobs)**2 / sigma_qobs**2)``.
    dcost : float
        ``sum(2*(qs - qobs)*dqs / sigma_qobs**2)``, where
        ``dqs = sum_l dH[l]*dz_l``. ``qobs`` and ``sigma_qobs`` are read
        from module globals.
    """
    nl = len(self.lithology_index)
    H = x[:nl]
    # q0 sits right after the nl H values. The original read x[n] / dx[n],
    # where `n` is the global node count, i.e. the wrong element.
    q0 = x[nl]
    dH = dx[:nl]
    dq0 = dx[nl]

    # np.split interprets a list of ints as split *indices*, so convert the
    # per-lithology counts to cumulative offsets.
    if np.ndim(nbpts) == 0:
        split_at = nbpts
    else:
        split_at = np.cumsum(nbpts)[:-1]
    bpts = np.split(x[nl+1:], split_at)
    dbpts = np.split(dx[nl+1:], split_at)

    lithology = self.lithology.reshape(self.mesh.n)
    qc = []
    dqc = []

    for l, lith in enumerate(self.lithology_index):
        nrow = (lithology == lith).sum(axis=0)
        dz = nrow*self.grid_delta[-1]

        qc.append( H[l]*dz )
        # qs is linear in H, so its tangent is dH[l]*dz
        dqc.append( dH[l]*dz )

    qs = np.vstack(qc).sum(axis=0)
    dqs = np.vstack(dqc).sum(axis=0)
    cost = np.sum((qs - qobs)**2/sigma_qobs**2)
    dcost = np.sum(2.0*(qs - qobs)*dqs/sigma_qobs**2)
    return cost, dcost

In [None]:
# Synthetic observations: nx unit values, each with a 10% standard deviation.
qobs = np.full(nx, 1.0)
sigma_qobs = np.divide(qobs, 10.0)
