# Notebook implementing fast light field deconvolution, and motion-aware deconvolution

In [2]:
from __future__ import print_function

import os, sys, time, warnings
import math
import numpy as np
import cProfile, pstats
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import scipy.ndimage, scipy.optimize, scipy.io
from scipy.optimize import Bounds
from joblib import Parallel, delayed
from skimage.transform import PiecewiseAffineTransform, warp

from hmatrix import HMatrix, LoadMatrix
import jutils as util
import lfdeconv, projector, lfimage

# Low-level (MKL/OpenMP) threading has reportedly been found not to interact well with joblib's
# Parallel, so each of those libraries is limited to a single thread and parallelism is left to joblib.
# NOTE(review): numpy is already imported above, so these settings may come too late to affect the
# BLAS library loaded in this process; they should still be inherited by child processes started
# afterwards. Confirm they have the intended effect.
for _threadEnvName, _threadEnvValue in (('MKL_NUM_THREADS', '1'),
                                        ('OMP_NUM_THREADS', '1'),
                                        ('MKL_DYNAMIC', 'FALSE')):
    os.environ[_threadEnvName] = _threadEnvValue

# Later code expects these flow-search histories to start as None. They are initialised right at the
# top to minimise the risk of accidentally resetting them to None after hours of valuable data have
# been accumulated in them!
shiftHistoryJoint = shiftHistoryRaw = shiftHistoryNaive = None

In [3]:
# Select the PSF matrix we will use in this notebook.
# Each option maps to (matrix path, warning to emit when it is selected, or None).
_psfMatrixOptions = {
    # This is the definitive, full PSF matrix
    'full': ('PSFmatrix/PSFmatrix_M22.2NA0.5MLPitch125fml3125from-110to110zspacing4Nnum19lambda520n1.33.mat',
             None),
    # This is a smaller one (so that things don't take forever to run)
    'fast': ('PSFmatrix/PSFmatrix_M40NA0.95MLPitch150fml3000from-26to0zspacing2Nnum15lambda520n1.0.mat',
             'WARNING: Switched to faster matrix for testing'),
    # This closer-spaced one is useful for focusing on native-focal-plane artefacts in flow analysis
    'closeSpaced': ('PSFmatrix/PSFmatrix_M40NA0.95MLPitch150fml3000from-13to0zspacing0.5Nnum15lambda520n1.0.mat',
                    'WARNING: Switched to faster and closer-spaced matrix for testing'),
}
# The flow-field analysis later in this notebook asserts that it is using the 'closeSpaced' matrix
# (its choice of z plane is tuned for that PSF), so that is the matrix selected here; selecting
# any other option makes that assertion fail.
psfMatrixChoice = 'closeSpaced'
matPath, _psfWarning = _psfMatrixOptions[psfMatrixChoice]
if _psfWarning is not None:
    warnings.warn(_psfWarning)

# Load the matrix information
_HPathFormat, _HtPathFormat, _HReducedShape, _HtReducedShape = LoadMatrix(matPath)

# Load an input image
# NOTE(review): inputImage is not used by any of the cells shown in this notebook.
inputImage = lfimage.LoadLightFieldTiff('Data/02_Rectified/exampleData/20131219WORM2_small_full_neg_X1_N15_cropped_uncompressed.tif')

# Solve for flow field (single-plane toy example)

In [None]:
# Generate the synthetic object (a cloud of random gaussian spots) from which the A/B image pair is
# built later on. The object is extended by syntheticImageExtendSize extra rows (y axis only).
# scipy.ndimage.filters is a deprecated alias namespace, so gaussian_filter is imported from scipy.ndimage.
from scipy.ndimage import gaussian_filter
if False:
    numSpots = 100
    imageSize = 240
    sigma = 8
    controlPointSpacing = 30
elif False:
    numSpots = 400
    imageSize = 120
    sigma = 2
    controlPointSpacing = 30
else:
    numSpots = 1000
    imageSize = 180
    sigma = 2
    controlPointSpacing = 30
syntheticImageExtendSize = 30
# Seed the spot positions so that Restart & Run All reproduces the same synthetic object
syntheticSeed = 0
_spotRng = np.random.RandomState(syntheticSeed)

syntheticObjectExt = np.zeros((1, imageSize+syntheticImageExtendSize, imageSize))
spotRows = (_spotRng.random_sample(numSpots)*syntheticObjectExt.shape[1]).astype('int')
spotCols = (_spotRng.random_sample(numSpots)*syntheticObjectExt.shape[2]).astype('int')
syntheticObjectExt[0, spotRows, spotCols] = 1
syntheticObjectExt = gaussian_filter(syntheticObjectExt, sigma=(0,sigma,sigma))
plt.imshow(syntheticObjectExt[0])

In [None]:
# Set up the PSF that we will use

# First check we're using the expected PSF - the plane choices used here are intended to work with this PSF.
assert(matPath == 'PSFmatrix/PSFmatrix_M40NA0.95MLPitch150fml3000from-13to0zspacing0.5Nnum15lambda520n1.0.mat')

# Number of z planes in the loaded matrix. (This used to be read from `_H`, which is not defined
# anywhere in this notebook, so the cell raised NameError on a fresh run.)
# NOTE(review): assumes LoadMatrix returns one _HReducedShape entry per z plane - TODO confirm in hmatrix.py.
# The assert guards that assumption: the asserted matrix spans -13 to 0 at 0.5 spacing, i.e. 27 planes.
numZPlanes = len(_HReducedShape)
assert numZPlanes == 27, 'Unexpected number of z planes (%d) - check how _HReducedShape is structured' % numZPlanes

# Candidate planes to model:
#   numZPlanes-1 : the native focal plane
#   7            : some way from the native focal plane, which should perform fairly well
#   numZPlanes-3 : close to the native focal plane. This has artefacts - the previous option is fairly artefact-free
zPlaneToModel = numZPlanes - 2

pivHMatrix = HMatrix(_HPathFormat, _HtPathFormat, _HReducedShape, numZ=1, zStart=zPlaneToModel)

In [None]:
# Choose the data source for the flow analysis:
#  - synthetic: the random-spot object warped by a known example flow, with the object extended in y
#  - piv:       a real image pair, with control points on the image border pinned to zero motion
useSyntheticFlowData = False
if not useSyntheticFlowData:
    shiftType = 'piv-zeroedge'
    source = 'piv'
    actualImageExtendSize = 0
    xMotionPermitted = True
    xSearchRange = 8
    ySearchRange = 8
else:
    shiftType = 'piv'
    source = 'synthetic'
    actualImageExtendSize = syntheticImageExtendSize
    # Allowing an x search range is fairer, but it makes little difference for vertical flow
    xMotionPermitted = False
    xSearchRange = 0
    ySearchRange = 10


def forwardProjectACC_PIV(hMatrix, obj, shiftDescription):
    """Forward-project a single object into an A/B camera image pair.

    The A image comes from `obj` itself and the B image from `obj` warped by ShiftObject according
    to `shiftDescription`. Each copy carries half the object's intensity, so that energy is conserved
    across the pair.
    NOTE(review): forwardProjectACC is not defined or imported by name in the visible cells -
    presumably provided by lfdeconv/projector; confirm.
    """
    halfObject = obj[:, np.newaxis, :, :] / 2.0
    pairedObject = np.tile(halfObject, (1, 2, 1, 1))
    pairedObject[:, 1, :, :] = ShiftObject(pairedObject[:, 1, :, :], shiftDescription)
    return forwardProjectACC(hMatrix, pairedObject, logPrint=False, progress=None)

def dualBackwardProjectACC_PIV(hMatrix, dualProjection, shiftDescription):
    """Back-project an A/B camera image pair into a pair of objects, undoing the B-image shift.

    Both images are back-projected, then the B object is warped by ShiftObject with the negated
    `shiftDescription`. The two objects are returned separately (axis 1 indexes A/B);
    fusedBackwardProjectACC_PIV is the variant that merges them.
    Ideally the A and B objects would match, but in practice they will differ, especially if
    `shiftDescription` is not the correct flow.
    NOTE(review): warping by the negated shift only approximates undoing the forward warp.
    """
    backProjectedPair = backwardProjectACC(hMatrix, dualProjection, logPrint=False, progress=None)
    backProjectedPair[:, 1, :, :] = ShiftObject(backProjectedPair[:, 1, :, :], -shiftDescription)
    return backProjectedPair

def fusedBackwardProjectACC_PIV(hMatrix, dualProjection, shiftDescription):
    """Back-project an A/B image pair and merge the two (un-shifted) objects into one.

    The A and B objects are summed and halved, mirroring the half-intensity split made by
    forwardProjectACC_PIV.
    """
    unshiftedPair = dualBackwardProjectACC_PIV(hMatrix, dualProjection, shiftDescription)
    return unshiftedPair.sum(axis=1) / 2.0

def deconvRL_PIV_OLD(hMatrix, imageAB, maxIter, Xguess, shiftDescription):
    """Older multiplicative (RL-style) joint deconvolution of an A/B image pair, kept for comparison.

    I believed this to be the RL algorithm in the way I have written it in the past. However, it
    gives different results to Prevedel's implementation (this one seems to converge more slowly).
    TODO: check whether this is a mistake, or whether the two are genuinely different.

    Xguess is the single combined initial guess of the object. It is copied, not modified in place.
    NaNs produced by the update are reset to zero on every iteration.
    """
    estimate = Xguess.copy()
    for _ in tqdm(range(maxIter), desc='RL deconv'):
        # Ratio of the observed images to those predicted by the current estimate
        observedOverPredicted = imageAB / forwardProjectACC_PIV(hMatrix, estimate, shiftDescription)
        estimate *= fusedBackwardProjectACC_PIV(hMatrix, observedOverPredicted, shiftDescription)
        estimate[np.isnan(estimate)] = 0
    return estimate

def deconvRL_PIV(hMatrix, imageAB, maxIter, shiftDescription):
    """Joint multiplicative deconvolution of an A/B image pair into a single object.

    Each iteration applies estimate *= Htf / (Ht H estimate), where H is forwardProjectACC_PIV,
    Ht is fusedBackwardProjectACC_PIV, and Htf (the initial back-projection of the camera images)
    also serves as the initial guess. NaNs produced by the update (e.g. from 0/0) are reset to zero.
    """
    Htf = fusedBackwardProjectACC_PIV(hMatrix, imageAB, shiftDescription)
    estimate = Htf.copy()
    for _ in tqdm(range(maxIter), desc='RL deconv'):
        reprojected = fusedBackwardProjectACC_PIV(
            hMatrix, forwardProjectACC_PIV(hMatrix, estimate, shiftDescription), shiftDescription)
        estimate = estimate * (Htf / reprojected)
        estimate[np.isnan(estimate)] = 0
    return estimate

def RollNoninteger(obj, amount, axis=0):
    """Circularly shift `obj` along `axis` by a possibly non-integer `amount`.

    The result linearly interpolates between the two neighbouring integer rolls: by floor(amount)
    with weight (1 - frac), and by floor(amount)+1 with weight frac.
    """
    lowerShift = math.floor(amount)
    upperWeight = amount - lowerShift
    rolledLower = np.roll(obj, lowerShift, axis=axis)
    rolledUpper = np.roll(obj, lowerShift + 1, axis=axis)
    return rolledLower * (1 - upperWeight) + rolledUpper * upperWeight


# Some replacement functions to use for testing (effective PSF is a delta function, 1:1 mapping from image to object)
def forwardProjectTrivial(hMatrix, obj, shiftDescription):
    """Testing stand-in for forwardProjectACC_PIV with a delta-function PSF (identity projection).

    Returns the A/B image pair for the first plane of `obj` (axis 0 indexes A/B). Each image holds
    half the intensity, with the B image warped by ShiftObject according to `shiftDescription`.
    `hMatrix` is unused.
    """
    pairedObject = np.tile(obj[:, np.newaxis, :, :] / 2.0, (1, 2, 1, 1))
    pairedObject[:, 1, :, :] = ShiftObject(pairedObject[:, 1, :, :], shiftDescription)
    firstPlanePair = pairedObject[0]
    return firstPlanePair

def dualBackwardProjectTrivial(hMatrix, dualProjection, shiftDescription):
    """Testing stand-in for dualBackwardProjectACC_PIV with an identity projection.

    Copies the A/B image pair, adds a leading plane axis, and warps the B image by the negated
    shift. `hMatrix` is unused.
    """
    unshiftedPair = dualProjection[np.newaxis].copy()
    unshiftedPair[:, 1, :, :] = ShiftObject(unshiftedPair[:, 1, :, :], -shiftDescription)
    return unshiftedPair

def fusedBackwardProjectTrivial(hMatrix, dualProjection, shiftDescription):
    """Testing stand-in for fusedBackwardProjectACC_PIV: average the un-shifted A and B images."""
    unshiftedPair = dualBackwardProjectTrivial(hMatrix, dualProjection, shiftDescription)
    return unshiftedPair.sum(axis=1) / 2.0

def deconvRLTrivial(hMatrix, imageAB, maxIter, shiftDescription):
    """Testing stand-in for deconvRL_PIV using the identity projectors.

    No iterations are performed: the result is simply the fused back-projection of the image pair.
    `maxIter` is accepted for interface compatibility but unused.
    """
    fusedObject = fusedBackwardProjectTrivial(hMatrix, imageAB, shiftDescription)
    return fusedObject

In [None]:
if (shiftType == 'uniform') or (shiftType == 'uniformSK'):
    if shiftType == 'uniform':
        def ShiftObject(obj, shiftYX):
            # Transform a 3D object according to the flow information provided in shiftYX.
            # For now I just consider a uniform translation in xy: shiftYX[0,0] is applied along the
            # second-last (y) axis and shiftYX[0,1] along the last (x) axis.
            #
            # TODO: We need to worry about conserving energy during the shift.
            # For now I will do a circular shift in order to avoid having to worry about this!
            result = RollNoninteger(obj, shiftYX[0,0], axis=len(obj.shape)-2)
            return RollNoninteger(result, shiftYX[0,1], axis=len(obj.shape)-1)
    else:
        # A lot of code duplication here, but it's just an experiment for now
        def ShiftObject(obj, shiftYX):
            # Uniform shift implemented as a piecewise affine warp with control points at the image corners.
            # NOTE(review): src holds (col, row) pairs while shiftYX[0] is named (y, x), and skimage's warp()
            # treats the transform as an inverse map, so the axis/direction convention may differ from the
            # 'uniform' variant - confirm before relying on this.
            src_cols = np.arange(0, obj.shape[-1]+1, obj.shape[-1])
            src_rows = np.arange(0, obj.shape[-2]+1, obj.shape[-2])
            src_rows, src_cols = np.meshgrid(src_rows, src_cols)
            src = np.dstack([src_cols.flat, src_rows.flat])[0]
            dst = src + shiftYX[0]
            tform = PiecewiseAffineTransform()
            tform.estimate(src, dst)
            # Annoyingly, skimage insists that a float input is scaled between 0 and 1, so I must rescale here
            maxVal = np.max(np.abs(obj))
            if maxVal == 0:
                # An all-zero object stays all-zero; rescaling by maxVal would otherwise give 0/0 = NaN
                return np.zeros(obj.shape)
            if len(obj.shape) == 3:
                result = np.zeros(obj.shape)
                for cc in range(obj.shape[0]):
                    result[cc] = warp(obj[cc]/maxVal, tform, mode='edge') * maxVal
                return result
            else:
                return warp(obj/maxVal, tform, mode='edge') * maxVal
    
    def ExampleShiftDescriptionForObject(obj):
        # A fixed example uniform shift, as [[y, x]]
        return np.array([[-10, 20]])
    
    def VelocityShapeForObject(obj):
        # A uniform shift is described by just two parameters
        return (2,)

    def IWCentresForObject(obj):
        # A single interrogation window at the image centre, as [[row, col]]
        return np.array([[int(obj.shape[-2]/2), int(obj.shape[-1]/2)]])

else:
    # Arbitrary motion described in terms of an array of control points at IWCentresForObject
    assert((shiftType == 'piv') or (shiftType == 'piv-zeroedge'))
    def IWCentresForObject(obj, st=shiftType):
        # Return the control-point (interrogation window centre) positions for obj as an Nx2 array of
        # (col, row) pairs on a grid with spacing controlPointSpacing (based on the skimage
        # PiecewiseAffineTransform example, since that actually does what we need).
        #  - 'piv-zeroedge': the grid starts one spacing in and stops short of the far edges, i.e. it
        #    omits the border points (AddZeroEdgePadding adds them back with zero motion).
        #  - otherwise: the grid starts at 0 and may include the far edges.
        # In both cases the row range is reduced by actualImageExtendSize.
        if st == 'piv-zeroedge':
            firstPos = controlPointSpacing
            colStop = obj.shape[-1]
            rowStop = obj.shape[-2] - actualImageExtendSize
        else:
            firstPos = 0
            colStop = obj.shape[-1] + 1
            rowStop = obj.shape[-2] + 1 - actualImageExtendSize
        gridCols = np.arange(firstPos, colStop, controlPointSpacing)
        gridRows = np.arange(firstPos, rowStop, controlPointSpacing)
        rowMesh, colMesh = np.meshgrid(gridRows, gridCols)
        return np.dstack([colMesh.flat, rowMesh.flat])[0]

    def VelocityShapeForObject(obj):
        # Shape of the full shift array: one (x, y) row per control point from IWCentresForObject
        controlPoints = IWCentresForObject(obj)
        return controlPoints.shape
    
    def ExampleShiftDescriptionForObject(obj):
        # Synthetic example flow: purely vertical (y) motion with a parabolic profile across the image
        # width, equal to peakVelocity at the horizontal centre and zero at x=0 and x=width.
        # Returns an Nx2 (x, y) array, or just the Nx1 y column when x motion is not permitted.
        peakVelocity = 7
        centres = IWCentresForObject(obj)
        halfWidth = obj.shape[-1] / 2
        shiftDescription = np.zeros(VelocityShapeForObject(obj))
        for n, centre in enumerate(centres):
            shiftDescription[n, 1] = (halfWidth**2 - (centre[0] - halfWidth)**2) / (halfWidth**2) * peakVelocity
        return shiftDescription if xMotionPermitted else shiftDescription[:, 1:2]

    def ExtraDuplicateRow(shifts, add=None):
        # Append an extra row to a square grid of per-control-point values by duplicating its last row.
        #
        # shifts: NxC array in IWCentresForObject order, where N must be a perfect square. The grid is
        #         viewed as rowLength x rowLength; presumably the second grid axis runs along image rows
        #         (as IWCentresForObject's meshgrid ordering suggests) - TODO confirm.
        # add:    optional offset added to the duplicated entries (e.g. to move duplicated control
        #         points into the extended image region).
        # Returns an (N + rowLength) x C array in the same flattened order. The input is not modified.
        # Raises ValueError if N is not a perfect square.
        assert(len(shifts.shape) == 2)
        numPoints, numComponents = shifts.shape
        rowLength = int(round(np.sqrt(numPoints)))
        if rowLength * rowLength != numPoints:
            raise ValueError('ExtraDuplicateRow expects a square grid of control points, got %d points' % numPoints)
        grid = np.reshape(shifts, (rowLength, rowLength, numComponents))
        duplicatedRow = grid[:,-1:,:]
        if add is not None:
            # Out-of-place, so that e.g. a float offset can be added to an integer grid
            # (an in-place += would raise a casting error in that case)
            duplicatedRow = duplicatedRow + add
        result = np.append(grid, duplicatedRow, axis=1)
        return result.reshape(result.shape[0]*result.shape[1], result.shape[2])

    def AddZeroEdgePadding(obj, src, shiftYX):
        # Expand (src, shiftYX), given at the 'piv-zeroedge' control points, to the full 'piv' grid of
        # obj. Grid points with no matching src point (presumably the border) keep a zero shift.
        # Every src point must be present in the full grid.
        # Returns (paddedSrc, paddedShifts), where paddedShifts has the same shape as paddedSrc.
        paddedSrc = IWCentresForObject(obj, st='piv')
        paddedShifts = np.zeros(paddedSrc.shape)
        gridIndicesForPoint = {}
        for j, gridPoint in enumerate(paddedSrc):
            gridIndicesForPoint.setdefault(tuple(gridPoint), []).append(j)
        for i, point in enumerate(src):
            matchingIndices = gridIndicesForPoint.get(tuple(point), [])
            assert(len(matchingIndices) > 0)
            for j in matchingIndices:
                paddedShifts[j] = shiftYX[i]
        return paddedSrc, paddedShifts
        
    def ShiftObject(obj, shiftYX):
        # Transform a 2D image or a 3D object (planes along axis 0) according to the flow in shiftYX.
        # I use a piecewise affine transformation between control points, which should approximately
        # correspond to what I use for PIV analysis.
        #
        # shiftYX: one row per control point of IWCentresForObject(obj) - (x, y) columns when
        #          xMotionPermitted, otherwise a single y column.
        # Returns a float array with the same shape as obj.
        src = IWCentresForObject(obj)
        if (src.shape[0] != shiftYX.shape[0]):
            print(src.shape, shiftYX.shape, obj.shape)
            assert(src.shape[0] == shiftYX.shape[0])

        # Work with an explicit Nx2 (x, y) shift array. A y-only (Nx1) shift must not be added to the
        # (x, y) control points by broadcasting, since that would also move them in x by the y shift.
        if xMotionPermitted:
            fullShift = shiftYX
        else:
            fullShift = np.zeros((shiftYX.shape[0], 2))
            fullShift[:,1] = shiftYX[:,0]

        if (shiftType == 'piv-zeroedge'):
            # Add the border control points, with zero motion
            (src, fullShift) = AddZeroEdgePadding(obj, src, fullShift)

        if (actualImageExtendSize > 0):
            # Duplicate the last row of control points (and their shifts) in the extended image region
            src = ExtraDuplicateRow(src, add=np.array([0, actualImageExtendSize]))
            fullShift = ExtraDuplicateRow(fullShift)
        dst = src + fullShift

        tform = PiecewiseAffineTransform()
        tform.estimate(src, dst)
        assert(len(obj.shape) in (2, 3))
        # Annoyingly, skimage insists that a float input is scaled between 0 and 1, so I must rescale here
        maxVal = np.max(np.abs(obj))
        if maxVal == 0:
            # Nothing to warp: an all-zero input stays all-zero (rescaling would give 0/0 = NaN everywhere)
            return np.zeros(obj.shape)
        if len(obj.shape) == 3:
            result = np.zeros(obj.shape)
            for cc in range(obj.shape[0]):
                result[cc] = warp(obj[cc]/maxVal, tform, mode='edge') * maxVal
            return result
        return warp(obj/maxVal, tform, mode='edge') * maxVal

In [None]:
if source == 'synthetic':
    # Generate a synthetic shift in the B image
    dualObject = np.tile(syntheticObjectExt[:,np.newaxis,:,:], (1,2,1,1)) * 1e3   # alternative scaling tried: * 1e7
    if False:
        warnings.warn('Loading previously-saved dualObject')
        dualObject = np.load('dualObject5.npy')
    
    shiftDescription = ExampleShiftDescriptionForObject(dualObject)
    dualObject[:,1,:,:] = ShiftObject(dualObject[:,1,:,:], shiftDescription)

    # Since I am only using a local minimizer, we need to start with a decent guess as to the flow.
    # I think that's ok though: we should have that from a PIV estimate on the with-artefacts AB images.
    # Here that estimate is simulated by adding uniform noise in [0, 4) to the true flow.
    #initialShiftGuess = np.zeros(VelocityShapeForObject(dualObject))
    initialShiftGuess = shiftDescription + np.random.random(shiftDescription.shape) * 4.0
else:
    assert(source == 'piv')
    # tifffile is not imported anywhere in the visible notebook; import it here so that this cell
    # works on a fresh kernel (harmless if it is already imported).
    import tifffile
    pivImagePair = tifffile.imread('piv-raw-data/038298.tif')[24:26,:15*20,:15*16].astype('float64')
    # Note: frames 57-58 (wrong pair) would be an option to investigate bigger motion (~16px) with imperfect AB matches
    #              64-65 (correct pair) are another example of small movement (0-3px)
    dualObject = pivImagePair[np.newaxis]
    # For now, I just guess an initial shift of zero
    shiftDescription = np.zeros(VelocityShapeForObject(dualObject)).astype('float64')
    initialShiftGuess = shiftDescription.copy()
    
    
# Search bounds for the optimizer: a box around shiftDescription, flattened in the same
# (x, y)-per-control-point order that the optimizer uses (y only when x motion is not permitted).
lb = []
ub = []
if xMotionPermitted:
    for n in range(shiftDescription.shape[0]):
        lb.extend([shiftDescription[n,0]-xSearchRange, shiftDescription[n,1]-ySearchRange])
        ub.extend([shiftDescription[n,0]+xSearchRange, shiftDescription[n,1]+ySearchRange])
else:
    for n in range(shiftDescription.shape[0]):
        lb.extend([shiftDescription[n,0]-ySearchRange])
        ub.extend([shiftDescription[n,0]+ySearchRange])
shiftSearchBounds = scipy.optimize.Bounds(lb, ub, True)

plt.subplot(1, 2, 1)
plt.imshow(dualObject[0,0])
plt.subplot(1, 2, 2)
plt.imshow(dualObject[0,1])
plt.show()

In [None]:
# Code used for investigations in which I directly warp the input object/images,
# without any use of light field PSFs and deconvolution

def ScoreShift2(candidateShiftYX, method, imageAB, hMatrix=None, shiftHistory=None, scaling=1.0, log=True, comparator=None, maxIter=8):
    # Scalar objective for scipy.optimize: the SSD score from ScoreShift3, without its diagnostic outputs.
    scoreAndDiagnostics = ScoreShift3(candidateShiftYX, method, imageAB, hMatrix, shiftHistory,
                                      scaling, log, comparator, maxIter=maxIter)
    return scoreAndDiagnostics[0]

def ScoreShift3(candidateShiftYX, method, imageAB, hMatrix=None, shiftHistory=None, scaling=1.0, log=True, comparator=None, maxIter=8):
    """Score how well a candidate flow field explains an observed A/B image pair (lower is better).

    candidateShiftYX: flattened shift parameters as supplied by scipy.optimize. They are reshaped here to
                      Nx2 (or Nx1 when xMotionPermitted is False) and multiplied by `scaling`, which is
                      useful for optimizers that insist on very small initial step sizes.
    method:           'joint' (joint deconvolution with the candidate flow), 'joint-test-trivial'
                      (the same with delta-function test projectors), or 'naive' (warp the raw A image
                      and compare it directly with the B image).
    imageAB:          2xMxN A/B image pair to match.
    hMatrix:          projection matrix used by the 'joint' methods.
    shiftHistory:     optional ShiftHistory that records every evaluation (and periodically plots).
    comparator:       optional debugging reference for the B image; the difference is plotted.
    maxIter:          number of deconvolution iterations for the 'joint' methods.

    Returns (ssdScore, renormHack, mean candidate A image, mean of imageAB, candidateImageAB, res),
    where res is the recovered object (None for 'naive').
    """
    # Our input parameters get flattened, so we need to reshape them like my code is expecting
    if xMotionPermitted:
        candidateShiftYX = candidateShiftYX.reshape(int(candidateShiftYX.shape[0]/2),2) * scaling
    else:
        candidateShiftYX = candidateShiftYX.reshape(candidateShiftYX.shape[0],1) * scaling
    # Sanity check and reminder that we have a 2xMxN AB image pair
    assert(len(imageAB.shape) == 3)  
    assert(imageAB.shape[0] == 2)
        
    if log:
        print('======== Score shift ========')

    if method == 'joint':
        # Perform the joint deconvolution to recover a single object
        res = deconvRL_PIV(hMatrix, imageAB, maxIter=maxIter, shiftDescription=candidateShiftYX)
        # Forward-project the result so it can be compared against the actual camera images
        candidateImageAB = forwardProjectACC_PIV(hMatrix, res, candidateShiftYX)
    elif method == 'joint-test-trivial':
        # Debugging method in which I use trivial projectors that behave like a delta function PSF
        res = deconvRLTrivial(hMatrix, imageAB, maxIter=maxIter, shiftDescription=candidateShiftYX)
        candidateImageAB = forwardProjectTrivial(hMatrix, res, candidateShiftYX)
    else:
        # Just warp the raw A image manually to form a candidate B image, and look at how it compares
        assert(method == 'naive')
        candidateImageAB = imageAB.copy()
        # A bit of dimensional gymnastics here, because ShiftObject expects an *object*,
        # i.e. a 3D volume, whereas in this case we just have a 2D image
        candidateImageAB[1,:,:] = ShiftObject(candidateImageAB[np.newaxis,0,:,:], candidateShiftYX)[0]  
        res = None  # So that we have something to return
    # Sanity check and reminder that we have a 2xMxN AB image pair
    assert(len(candidateImageAB.shape) == 3)  
    assert(candidateImageAB.shape[0] == 2)

    # Score only the interior: drop a 1-pixel border, plus the extra rows added by the synthetic image
    # extension (actualImageExtendSize). That extension is only along y, so the x range is not reduced
    # by it - cropping it there too would throw away valid columns.
    rowsToScore = slice(1, -1-actualImageExtendSize)
    colsToScore = slice(1, -1)
    imageToScore = candidateImageAB[:, rowsToScore, colsToScore]
    referenceImage = imageAB[:, rowsToScore, colsToScore]
    # Score by comparing the A and B images to the ones we are optimizing on.
    # Note: in some simulated or naive cases, the A camera images will always be a perfect match,
    # but for the real case the joint solution will be a compromise for both the A and B camera images.
    #
    # I have tried to renormalize to aid comparison between the images - based on the relative intensity
    # of the candidate and observed A images. I chose the A images because they will be identical in the case
    # of the 'naive' method (direct warping). However, for the 'joint' method they won't be.
    # TODO: I need to think more about whether this normalization is necessary and appropriate.
    # (I think I introduced it in the hope of fixing a problem,
    # but lack of normalization wasn't the fundamental issue in the end)
    renormHack = np.average(candidateImageAB[0]) / np.average(imageAB[0])
    ssdScore = np.sum((imageToScore/renormHack - referenceImage)**2)

    if comparator is not None:
        # Debugging aid: compare the candidate B image with `comparator`, assumed (as in the plot below)
        # to be a 2D image compatible with the cropped B image - TODO confirm.
        # The difference is taken on the B image alone: slicing [1:-1,1:-1] of the full 2xMxN stack
        # would remove the A/B axis entirely, leaving nothing to take the maximum of.
        bDiff = imageToScore[1] - comparator
        absInterior = np.abs(bDiff)[1:-1,1:-1]
        maxLoc = np.argmax(absInterior)
        maxVal = np.max(absInterior)
        print('showing B image diffs')
        plt.imshow(bDiff[170:,150:])
        plt.colorbar()
        plt.title('BRel (max %e)'%maxVal)
        print('Max val %f at %d (image scale %d)' % (maxVal, maxLoc, np.max(comparator)))
        plt.show()

    if shiftHistory is not None:
        shiftHistory.Update(candidateShiftYX, ssdScore)
        if log:
            if shiftHistory.PlotHistory(onlyPlotEvery=20):
                if method == 'joint':
                    # Show the recovered object as an A/B pair, with the B copy warped by the candidate flow
                    # (the flow the deconvolution was actually run with), not by the global true shiftDescription.
                    dualObjectToShow = np.tile(res[:,np.newaxis,:,:] / 2.0, (1,2,1,1))
                    dualObjectToShow[:,1,:,:] = ShiftObject(dualObjectToShow[:,1,:,:], candidateShiftYX)
                    ShowDualObjectAndFlow(dualObjectToShow, candidateShiftYX)
                else:
                    ShowDualObjectAndFlow(candidateImageAB, candidateShiftYX)
                print('Last trial shift: ', candidateShiftYX.T)

    if log:
        print('return %e' % ssdScore)
    return (ssdScore, renormHack, np.average(candidateImageAB[0]), np.average(imageAB), candidateImageAB, res)

def ShowDualObjectAndFlow(dualObject, shiftDescription, otherObject=None, otherObject2=None):
    # Display an A/B pair with the flow overlaid as red lines from each control point.
    # dualObject is either a 4D object (planes, A/B, y, x), whose first plane of A and B is shown side by
    # side, or a 2xMxN image pair, in which case only the B image is shown (in the left panel).
    # Each line runs from a control point by minus half its shift. When xMotionPermitted is False,
    # shiftDescription holds only the y shift, so the lines are vertical.
    # otherObject / otherObject2, if given, have their first plane shown in separate figures.
    plt.subplot(1, 2, 1)
    if len(dualObject.shape) == 4:
        assert(dualObject.shape[1] == 2)
        plt.imshow(dualObject[0,0])
        plt.subplot(1, 2, 2)
        plt.imshow(dualObject[0,1])
    else:
        assert(len(dualObject.shape) == 3)  # It's actually a dual image not an object
        assert(dualObject.shape[0] == 2)
        plt.imshow(dualObject[1])
    controlPoints = IWCentresForObject(dualObject)
    for n, (xStart, yStart) in enumerate(controlPoints):
        if xMotionPermitted:
            xEnd = xStart - shiftDescription[n,0]/2
            yEnd = yStart - shiftDescription[n,1]/2
        else:
            xEnd = xStart
            yEnd = yStart - shiftDescription[n,0]/2
        plt.plot([xStart, xEnd], [yStart, yEnd], color='red')
    plt.xlim(0, dualObject.shape[-1])
    plt.ylim(dualObject.shape[-2], 0)
    plt.show()
    for extraObject in (otherObject, otherObject2):
        if extraObject is not None:
            plt.imshow(extraObject[0])
            plt.show()
        
def CheckConvergence(funcToCall, convergedShift, args, indices=(7, 8, 12, 13), offsets=(0.5, -0.5, 1.5, -1.5)):
    """Probe whether a supposedly converged shift is a local minimum of funcToCall.

    Each selected parameter of the flattened shift is perturbed by each offset in turn and the
    objective is re-evaluated; any perturbation that scores better than the starting point is reported.

    funcToCall:     objective, called as funcToCall(flatShift, *args) (e.g. ScoreShift2)
    convergedShift: shift parameters to test (flattened, as float, before use)
    args:           extra positional arguments for funcToCall
    indices:        flattened parameter indices to perturb (the default is the historical hard-coded
                    choice); indices outside the parameter vector are skipped with a message
    offsets:        perturbations to try for each index
    """
    flatShift = np.asarray(convergedShift, dtype=float).flatten()
    validIndices = [n for n in indices if -flatShift.size <= n < flatShift.size]
    skippedIndices = [n for n in indices if n not in validIndices]
    if skippedIndices:
        print('CheckConvergence: skipping out-of-range indices', skippedIndices)

    initialScore = funcToCall(flatShift, *args)
    print('initial score', initialScore)
    for du in offsets:
        for n in validIndices:
            temp = flatShift.copy()
            temp[n] += du
            score = funcToCall(temp, *args)
            print('offset score', score)
            if (score < initialScore):
                print(n, du, 'BETTER! (by %f%%)' % ((initialScore-score)/score*100))

def ReportOnOptimizerConvergence(shiftHistory, method, obj, hMatrix=None):
    # Print the best score and shift found so far (the shift formatted so that it can be pasted back in
    # as an np.array literal), then probe around it with CheckConvergence using ScoreShift2.
    # Returns the best shift, or None when there is no history.
    if shiftHistory is None:
        print('ReportOnOptimizerConvergence returning - called with shiftHistory=None')
        return None
    bestShift = shiftHistory.BestShift()
    print('Best score:', shiftHistory.BestScore())
    print('Best shift: np.array([' + ''.join('%f, ' % value for value in bestShift.flatten()) + '])')
    scoreArgs = (method, obj, hMatrix, None, 1.0, False)
    CheckConvergence(ScoreShift2, bestShift.flatten(), scoreArgs)
    return bestShift
                
class ShiftHistory:
    """Record of every (shift, score) evaluation made during a flow-field optimization."""

    def __init__(self):
        self.Reset()

    def __copy__(self):
        # Copy the history lists too. Sharing them meant that Update() on a copy also grew the original's
        # lists while leaving the original's counter unchanged, so the two became inconsistent.
        result = ShiftHistory()
        result.shiftHistory = list(self.shiftHistory)
        result.scoreHistory = list(self.scoreHistory)
        result.counter = self.counter
        return result

    def Reset(self):
        """Discard all recorded evaluations."""
        self.scoreHistory = []
        self.shiftHistory = []
        self.counter = 0
    
    def Update(self, shift, score):
        """Record one evaluation (the shift array is stored by reference)."""
        self.shiftHistory.append(shift)
        self.scoreHistory.append(score)
        self.counter = self.counter + 1
        
    def BestScore(self):
        """Lowest score recorded so far."""
        return np.min(self.scoreHistory)

    def BestShift(self):
        """Shift that produced the lowest score recorded so far."""
        return self.shiftHistory[np.argmin(self.scoreHistory)]

    def PlotHistory(self, onlyPlotEvery=1):
        """Every `onlyPlotEvery` updates, plot progress and append the latest entry to scores.txt.

        Plots one control point's last shift component over time, the score history, and which
        parameters changed between consecutive evaluations. Returns True if it plotted, else False.
        """
        if ((self.counter%onlyPlotEvery) != 0) or (len(self.shiftHistory) == 0):
            return False
        print('best score so far: %e' % np.min(self.scoreHistory))
        # Plot one of the shifts (the control-point index is a sqrt-based heuristic, clamped to the valid range)
        shiftShape = self.shiftHistory[0].shape
        selectedItem = np.minimum(int(np.sqrt(shiftShape[0])/2), shiftShape[0]-1)
        selectedShift = np.array(self.shiftHistory)[:, selectedItem, -1]
        plt.plot(selectedShift)
        plt.show()
        # Plot scores, with a suitable y axis scaling to see the interesting parts.
        # We limit the y axis to avoid stupid guesses distorting the plot. With no improvement yet the
        # limits would be identical (a singular range), so the limit is only applied once there is some.
        improvement = self.scoreHistory[0] - np.min(self.scoreHistory)
        if improvement > 0:
            plt.ylim(np.min(self.scoreHistory), self.scoreHistory[0]+2*improvement)
        plt.plot(self.scoreHistory)
        plt.show()
        # Plot an indication of which values are being updated on which iteration
        for n in range(1, len(self.scoreHistory)):
            changes = np.array(np.where((self.shiftHistory[n] == self.shiftHistory[n-1]).flatten() == False))
            if (changes.size > 0):
                plt.plot(n, changes, 'x', color='red')
        plt.show()

        with open('scores.txt', 'a') as f:
            f.write('%f\t' % self.scoreHistory[-1])
            for n in self.shiftHistory[-1]:
                if xMotionPermitted:
                    f.write('%f\t%f\t' % (n[0], n[1]))
                else:
                    f.write('%f\t' % (n[0]))
            f.write('\n')
        return True

In [None]:
# Generate synthetic light-field-recovered AB images (doing it the naive way, not using my new joint deconvolution):
# run the imaging cycle (forward projection, back projection, RL deconvolution) on each of the A and B
# images individually, i.e. introduce light-field reconstruction artefacts into them.
# NOTE(review): forwardProjectACC/backwardProjectACC/deconvRL are not defined or imported by name in the
# visible cells - presumably provided by lfdeconv/projector; confirm.
dualObjectRecovered = dualObject.copy()
for abIndex in (0, 1):
    singleObject = dualObject[:, abIndex, :, :]
    cameraImage = forwardProjectACC(pivHMatrix, singleObject, logPrint=False)
    backProjected = backwardProjectACC(pivHMatrix, cameraImage, logPrint=False)

    # With the shifted images, we have problems with true zeroes in regions that have no features remaining.
    # To avoid this, a very small nonzero background floor is applied so that the deconvolution doesn't fail.
    backgroundFloor = 1e-5 * np.max(backProjected)
    backProjected = np.maximum(backProjected, backgroundFloor)

    dualObjectRecovered[:, abIndex, :, :] = deconvRL(pivHMatrix, backProjected, maxIter=8, Xguess=backProjected, logPrint=False)

In [None]:
# Compare the original A/B object pair with the pair recovered via light-field imaging + deconvolution,
# with the true flow overlaid on both. (iwPos is left defined for interactive inspection; this cell
# does not use it.)
print('Original object')
iwPos = IWCentresForObject(dualObject)
ShowDualObjectAndFlow(dualObject, shiftDescription)
recoveredCaption = 'Recovered from light field images (plane %d)' % zPlaneToModel
print(recoveredCaption)
ShowDualObjectAndFlow(dualObjectRecovered, shiftDescription)

In [None]:
# Optionally give the optimizer the best possible starting point: the actual true shift values
# (but it still may iterate away from that...). Otherwise start from the perturbed initial guess.
startFromTrueShift = True
if not startFromTrueShift:
    startShiftForOptimizer = initialShiftGuess.copy()
else:
    warnings.warn("WARNING: starting guess is actually the correct flow description")
    startShiftForOptimizer = shiftDescription.copy()

    
def OptimizeToRecoverFlowField(method, imageAB, hMatrix, shiftDescription, initialShiftGuess, shiftHistory=None):
    """Search for the flow field that best explains imageAB under `method` (see ScoreShift3).

    shiftDescription is the true flow and is only printed; initialShiftGuess is the optimizer's starting
    point. Every evaluation is recorded in `shiftHistory` (a new ShiftHistory if None), which is returned -
    also if the run is interrupted from the keyboard - so that a long run can be inspected or resumed.
    The search is constrained by the global shiftSearchBounds.
    """
    imageAB = imageAB.copy()    # This is just paranoia - I don't think it should get manipulated
    print('True shift:', shiftDescription.T)

    if False:
        plt.imshow(imageAB[0,:,:])
        plt.show()
        plt.imshow(imageAB[1,:,:])
        plt.show()

    if False:
        print('Score for correct shift:', ScoreShift2(shiftDescription.flatten(), method, imageAB, hMatrix))
        print('Score for initial guess:', ScoreShift2(initialShiftGuess.flatten(), method, imageAB, hMatrix))

    if True:
        optimizationAlgorithm = 'Powell'
        options = {'xtol': 1e-2}
    elif True:
        optimizationAlgorithm = 'L-BFGS-B'
        options = {'eps': 5e-03, 'gtol': 1e-6}
    else:
        optimizationAlgorithm = 'Nelder-Mead'
        # ('eps' is not a Nelder-Mead option, so it is not passed here)
        options = {'xatol': 1e-2, 'adaptive': True}

    if shiftHistory is None:
        shiftHistory = ShiftHistory()

    # scipy.optimize.minimize expects a 1-D x0 (recent SciPy versions reject arrays with ndim > 1);
    # ScoreShift3 reshapes the flat vector back into one row per control point.
    x0 = np.asarray(initialShiftGuess, dtype=float).flatten()

    # Optimize to obtain the best-matching shift
    try:
        shift = scipy.optimize.minimize(ScoreShift2, x0, bounds=shiftSearchBounds, args=(method, imageAB, hMatrix, shiftHistory), method=optimizationAlgorithm, options=options)
        print('Optimizer finished:', str(shift.message), 'Final shift:', shift.x.T)
    except KeyboardInterrupt:
        # Catch keyboard interrupts so that we still return whatever shiftHistory we have built up so far.
        print('KEYBOARD INTERRUPT DURING OPTIMIZATION')
    return shiftHistory

In [None]:
# Perform the reconstruction using direct shift-matching of the raw input images (real experimental SPIM images)
runRawImageMatching = False
if runRawImageMatching:
    shiftHistoryRaw = OptimizeToRecoverFlowField('naive', dualObject[0], None, shiftDescription, startShiftForOptimizer)

# Note: if continuing a previously-interrupted run, we can do the following to pick up roughly where we left off,
# i.e. provide BestShift() for the two shift-related input parameters, and pass the existing shift history as the final (optional) parameter:
#    shiftHistoryRaw = OptimizeToRecoverFlowField('naive', dualObject[0], None, shiftHistoryRaw.BestShift(), shiftHistoryRaw.BestShift(), shiftHistoryRaw)

In [None]:
# Summarise the raw-image optimization. ReportOnOptimizerConvergence itself handles (and reports) a
# history of None; NameError covers the case where the cells defining its inputs have not been run.
try:
    bestShift = ReportOnOptimizerConvergence(shiftHistoryRaw, 'naive', dualObject[0])
except NameError:
    warnings.warn('History probably not available')

In [None]:
# Perform the reconstruction using direct shift-matching of the light-field-deconvolved images
runDeconvolvedImageMatching = False
if runDeconvolvedImageMatching:
    shiftHistoryNaive = OptimizeToRecoverFlowField('naive', dualObjectRecovered[0], None, shiftDescription, startShiftForOptimizer)

In [None]:
# Summarise the deconvolved-image optimization. ReportOnOptimizerConvergence itself handles a history of
# None; NameError covers the case where the cells defining its inputs have not been run.
try:
    ReportOnOptimizerConvergence(shiftHistoryNaive, 'naive', dualObjectRecovered[0])
except NameError:
    warnings.warn('History probably not available')

In [None]:
# Perform the reconstruction using my new joint algorithm
if True:
    # Generate a camera image pair from the object.
    if source == 'piv':
        # The camera AB images are determined by separate forward projection of the AB spim images
        imageAB = forwardProjectACC(pivHMatrix, dualObject)
    else:
        # The synthetic B image is determined with the help of the chosen shift transform:
        # forwardProjectACC_PIV projects the A object and its shifted copy as an AB pair.
        # Forward-projecting dualObject[:,0] on its own would presumably give a single image with no
        # shift applied, rather than the 2xMxN pair that ScoreShift3 asserts on.
        imageAB = forwardProjectACC_PIV(pivHMatrix, dualObject[:,0,:,:], shiftDescription)
    if False:
        # Try starting using the solution obtained by direct warping of AB image pair,
        # to see if that yields a better minimum than the one I had found so far
        # (flat vector of shift parameters, as printed by ReportOnOptimizerConvergence)
        startShiftForOptimizer = np.array([4.179692, 2.422054, 2.277945, -3.442326, 2.588265, -1.628299, -0.701406, -0.351879, 1.371176, -0.887983, -0.329903, -1.814173, -0.039174, -1.546245, 4.530734, 6.433376, -3.923658, 0.765194, -3.139235, 11.998391, 0.159123, 4.712269, 0.000467, 1.881291, -1.076288, 2.571115, -0.752311, 6.346806, 0.705785, 1.599266, 0.526477, 0.520794, 1.613503, -0.944800, -4.052433, 0.938896, -9.285762, 9.058932, -1.427279, 3.465706, -2.667963, 6.611281, -2.711337, 6.691590, 1.149587, 6.157966, 5.232333, 7.419857, 3.860831, 0.035423, 1.096535, -0.919879, 1.315764, 0.783761, -1.649745, 1.829128, -1.506330, 2.903395, -2.640247, 6.333610, -2.974801, 5.116153, -1.640844, 6.727836, 6.771447, 5.497423, 7.318270, 3.963571, -1.879130, 0.376258, -2.545277, 3.033343, -1.405359, 2.988077, -3.664550, 3.713645, -2.404847, 3.906314, -0.068660, 0.731329, 3.443943, 1.132651, 8.621877, 2.114740, 4.915054, 3.548191, 3.346262, 5.315995, -2.250714, 3.869669, -3.046189, 2.948226, -1.592374, 0.569959, -0.875566, 0.699708, -0.309011, -0.220754, 2.785740, 4.885732, 1.892708, 2.223612, 5.977712, 3.248553, 1.007629, 3.325284, -0.803304, 5.829418, -2.376718, 0.837404, 0.982720, 0.897666, -1.992212, 4.365900, -0.497122, 3.024971, -0.809540, 2.668115, 3.179223, -4.673659, 3.866968, 0.850031, 2.411868, 0.370574, 3.250005, 2.128673])

    # Run the joint optimizer to find the shift value for an input frame pair.
    # Passing the existing shiftHistoryJoint (None on the first run) means that re-running this cell
    # keeps accumulating evaluations into the same history.
    shiftHistoryJoint = OptimizeToRecoverFlowField('joint', imageAB, pivHMatrix, shiftDescription, startShiftForOptimizer, shiftHistoryJoint)

In [None]:
# Optionally save the joint-optimization history (and the object it was run on) for later analysis
saveJointResults = False
if saveJointResults:
    for fileName, dataToSave in (('shiftHistoryJoint.npy', shiftHistoryJoint.shiftHistory),
                                 ('scoreHistoryJoint.npy', shiftHistoryJoint.scoreHistory),
                                 ('dualObjectJoint.npy', dualObject)):
        np.save(fileName, dataToSave)

In [None]:
# Summarise the joint optimization. ReportOnOptimizerConvergence itself handles a history of None;
# NameError covers the case where the cells defining its inputs have not been run.
try:
    ReportOnOptimizerConvergence(shiftHistoryJoint, 'joint', imageAB, pivHMatrix)
except NameError:
    warnings.warn('History probably not available')

In [None]:
# Look at how the scores are evolving during the powell iterations

loadHistoryFromFiles = False
if not loadHistoryFromFiles:
    vals = np.array(shiftHistoryJoint.shiftHistory)
    scores = np.array(shiftHistoryJoint.scoreHistory)
else:
    # Load from files previously saved (in the working directory) using:
    #   np.save('scoreHistory.npy', np.array(shiftHistoryJoint.scoreHistory))
    #   np.save('shiftHistory.npy', np.array(shiftHistoryJoint.shiftHistory))
    # Relative paths are used rather than a hard-coded, user-specific absolute directory.
    vals = np.load('shiftHistory.npy')
    scores = np.load('scoreHistory.npy')

# Grid row length used to label plots as (row, col) for a control-point index.
# NOTE(review): 7 matches the grid these labels were written for - TODO derive it from IWCentresForObject.
gridRowLength = 7
# Alternative: iwOfInterest = 5*7+3
iwOfInterest = 5*7+6   # Looking at border control point for shiftHistoryNaive
x = []
y = []

if False:
    # Compare the results from different control points
    # NOTE(review): objectToUse is not defined anywhere in this notebook - set it before enabling this branch.
    sh = vals[3020].flatten()
    sh[iwOfInterest] = 5.011
    # ScoreShift3 returns 6 values; the candidate image pair is the 5th
    (_,_,_,_,comp,_) = ScoreShift3(sh, 'naive', objectToUse, log=False)
    for d in [5.012, 5.0135, 5.0145]:
        sh[iwOfInterest] = d
        sc = ScoreShift2(sh, 'naive', objectToUse[0], log=False, comparator=comp)
        print('score', sc)

if True:
    for iw in [iwOfInterest]:
        # Collect runs of consecutive evaluations in which this control point's first component keeps
        # changing, and plot score against that component for each run of more than two points.
        for n in range(0,vals.shape[0]-1):
            if (vals[n,iw,0] != vals[n+1,iw,0]):
                x.append(vals[n,iw,0])
                y.append(scores[n])
            else:
                if (len(x) > 2):
                    plt.plot(x, y, 'x')
                    plt.title('%d,%d %d(%d)'%(iw//gridRowLength, iw%gridRowLength, n, len(x)))
                    plt.show()
                x = []
                y = []
    if (len(x) > 2):
        plt.plot(x, y, 'x')
        plt.title('%d,%d %d(%d)'%(iw//gridRowLength, iw%gridRowLength, n, len(x)))
        plt.show()



In [None]:
# Code useful for understanding how two images differ, since I have been having
# a lot of problems related to warp(), where tiny changes in shifts make a difference to the result
# (these are largely due to edge effects of one type or another)
def ShowDifferences(im1, im2, fullIm1, sh):
    """Plot im1 - im2, report its largest absolute value, and overlay the control-point shifts.

    fullIm1: the full 2D image that the control points are defined on.
    sh:      shift array with one row per control point of IWCentresForObject(fullIm1[np.newaxis]);
             it is zero-edge padded here when shiftType is 'piv-zeroedge'. Shifts are drawn magnified
             by 1e9 so that even tiny differences are visible.
    """
    diff = im1-im2
    print(diff.shape)
    absDiff = np.abs(diff)
    flatLoc = np.argmax(absDiff)
    print('Largest difference', np.max(absDiff), 'loc', flatLoc, \
          flatLoc%diff.shape[1], int(flatLoc/diff.shape[1]))
    plt.imshow(diff)
    # Control points of fullIm1 itself on the border-inclusive 'piv' grid (the grid AddZeroEdgePadding pads to).
    # Using fullIm1 rather than the global dualObject keeps these consistent with `sh` even when
    # dualObject has a different shape.
    iwPos = IWCentresForObject(fullIm1[np.newaxis], st='piv')
    if False:
        for n in range(iwPos.shape[0]):
            plt.plot(iwPos[n,0], iwPos[n,1], 'x', color='red')
    elif True:
        src = IWCentresForObject(fullIm1[np.newaxis])
        assert(src.shape[0] == sh.shape[0])
        if (shiftType == 'piv-zeroedge'):
            (src, sh) = AddZeroEdgePadding(fullIm1[np.newaxis], src, sh)
            print('padded')
        for n in range(sh.shape[0]):
            plt.plot([iwPos[n,0], iwPos[n,0]+sh[n,0]*1e9], \
                     [iwPos[n,1], iwPos[n,1]+sh[n,1]*1e9], color='red')
            if not sh[n,0] == 0:
                print('x', [iwPos[n,0], iwPos[n,0]+sh[n,0]*1e9])
                print('y', [iwPos[n,1], iwPos[n,1]+sh[n,1]*1e9])
    plt.xlim(-10,60)
    plt.ylim(80,-10)
    
    plt.show()

In [None]:
# To understand how the optimizer is behaving, scan the search space rather than optimizing:
# vary one control point's shift over a range of values and record the joint score for each.

# Generate a camera image pair from the object.
# The B image is determined with the help of the chosen shift transform.
imageAB = forwardProjectACC_PIV(pivHMatrix, dualObject[:,0,:,:], shiftDescription)

# Evaluate (without optimizing) the joint score for each candidate value. Each call receives
# shiftHistorySearch, which the next cell reads scoreHistory from.
# Index of the control point being scanned. This was previously 4*11+3, which gave scores:
#   [5147746000.0, 4475782000.0, 3882488600.0, 3441547500.0, 3223274800.0, 3499643600.0, 4056565800.0, 4845896000.0, 5856099300.0]
scanControlPoint = 3*11+5
scanValues = range(-4, 5, 1)
shiftHistorySearch = ShiftHistory()
candidateShiftYX = startShiftForOptimizer.copy()
for dx in scanValues:
    # NOTE(review): if candidateShiftYX has one row per control point, this assigns dx to
    # every component of that row (not just x) - confirm that is intended.
    candidateShiftYX[scanControlPoint] = dx
    ScoreShift2(candidateShiftYX.flatten(), 'joint', imageAB, pivHMatrix, shiftHistorySearch, log=False)


In [None]:
# Plot the score history accumulated during the scan above
plt.plot(shiftHistorySearch.scoreHistory)
plt.xlabel('score history entry')
plt.ylabel('score')
plt.title('Scores recorded while scanning one control point')
plt.show()