# Notebook implementing fast light field deconvolution and motion-aware deconvolution

In [None]:
from __future__ import print_function

import os, sys, time, datetime, warnings
import math
import numpy as np
import cProfile, pstats
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
%matplotlib inline

from joblib import Parallel, delayed

import psfmatrix
import jutils as util
import lfdeconv, lfdeconv_piv, projector, lfimage
import motion_recovery as recovery
import flow

# These optimizer histories must start out as None (some of the optimizer cells below pass them
# back in to resume), and they are initialised right at the top to minimise the risk of
# accidentally resetting them to None after hours of accumulating valuable data in them!
shiftHistoryJoint = shiftHistoryRaw = shiftHistoryNaive = None

In [None]:
# Select the PSF matrix we will use in this notebook.
# This closer-spaced one is useful for focusing on native-focal-plane artefacts in flow analysis.
matPath = 'PSFmatrix/PSFmatrix_M40NA0.95MLPitch150fml3000from-13to0zspacing0.5Nnum15lambda520n1.0.mat'

# For now I am only actually simulating flow in a single plane (imaged with light field microscopy),
# so we set up a single-plane PSF to work from.
fullHMatrix = psfmatrix.LoadMatrix(matPath)
# Choose which z plane of the full matrix to model. Alternatives that have been tried:
#   fullHMatrix.numZ-1 : the native focal plane
#   7                  : some way from the native focal plane, which should perform fairly well
#   fullHMatrix.numZ-3 : close to the native focal plane. This has artefacts (plane 7 is fairly artefact-free)
zPlaneToModel = fullHMatrix.numZ-2   # The plane adjacent to the native focal plane
pivHMatrix = psfmatrix.LoadMatrix(matPath, numZ=1, zStart=zPlaneToModel)

In [None]:
# Generate (or reload) a single synthetic object, which a later cell duplicates into the A and B images.
# For now the object consists of a cloud of random gaussian spots.
numSpots = 1000
imageSize = 180
sigma = 2
controlPointSpacing = 30
# Set this to None to generate (and save) a fresh random object instead of reloading a saved one
previouslySavedSynthetic = '2019-06-23 18.10.19 syntheticInput.npy'

# The object has this many more rows than columns: shape (1, imageSize+syntheticImageExtendSize, imageSize).
# The same value is also passed to flow.PIVShifter in a later cell.
# TODO: document the purpose of this (presumably headroom for the vertical flow - confirm)
syntheticImageExtendSize = 30
syntheticObjectExt = np.zeros((1, imageSize+syntheticImageExtendSize, imageSize), dtype='float32')
if previouslySavedSynthetic is not None:
    # Load a synthetic object that we saved from a previous run
    # (This is useful for reproducible behaviour)
    syntheticObjectExt[0] = np.load(previouslySavedSynthetic)
else:
    # Generate a synthetic object: set numSpots randomly-chosen pixels to 1 (the same pixel may be
    # chosen more than once), then blur in-plane only (sigma is 0 along the leading axis).
    syntheticObjectExt[0, (np.random.random(numSpots)*syntheticObjectExt.shape[1]).astype('int'), \
                          (np.random.random(numSpots)*syntheticObjectExt.shape[2]).astype('int')] = 1
    syntheticObjectExt = gaussian_filter(syntheticObjectExt, sigma=(0,sigma,sigma)).astype('float32')
    # Save it to disk under a timestamped filename, for reference
    fn = datetime.datetime.now().strftime("%Y-%m-%d %H.%M.%S syntheticInput.npy")
    np.save(fn, syntheticObjectExt[0])

plt.imshow(syntheticObjectExt[0])

In [None]:
# Choose the model used to describe the flow (and, for synthetic data, to generate it):
#   'piv-local'     : flow can be different at different xy locations. Note: allowing an x search range
#                     is more open-minded, but it makes little difference to the outcome, in the case of
#                     vertical flow
#   'piv-zero-edge' : slightly simplified model in which we assume flow will be zero at all boundaries
#                     of the image
#   'uniform'       : very simple model in which there is a uniform shift across the whole field of view
flowModel = 'piv-local'
if flowModel == 'piv-local':
    shifter = flow.PIVShifter(controlPointSpacing, syntheticImageExtendSize, xMotionPermitted=False)
    source, searchRangeXY = 'synthetic', (0,10)
elif flowModel == 'piv-zero-edge':
    shifter = flow.PIVZeroEdgeShifter(controlPointSpacing, 0, xMotionPermitted=True)
    source, searchRangeXY = 'piv', (8,8)
else:
    shifter = flow.UniformSKShifter(0, 0, True)
    source, searchRangeXY = 'synthetic', (0,10)

shifterClassName = type(shifter).__name__
print("** We will use a model for the shifts defined by class {0} **".format(shifterClassName))
if source == 'synthetic':
    print("** The synthetic data will move in a way that is compatible with that model **")

In [None]:
if source == 'synthetic':
    # Generate a synthetic shift in the B image.
    # It is the shifter class that provides the motion profile we will use.
    # Thus, if we chose a simpler shifter (above), we will use an appropriately simple flow profile
    dualObject = np.tile(syntheticObjectExt[:,np.newaxis,:,:], (1,2,1,1)) * 1e3   # (scale factor was previously 1e7)
    if False:
        warnings.warn('Loading previously-saved dualObject')
        dualObject = np.load('dualObject5.npy')

    trueShiftDescription = shifter.ExampleShiftDescriptionForObject(dualObject)
    dualObject[:,1,:,:] = shifter.ShiftObject(dualObject[:,1,:,:], trueShiftDescription)

    # Since I am only using a local minimizer, we need to start with a decent guess as to the flow.
    # I think that's ok though: we should have that from a PIV estimate on the with-artefacts AB images.
    # Here that is simulated by perturbing the true shift by uniform random amounts in [0, 4).
    #initialShiftGuess = np.zeros(trueShiftDescription.shape)
    initialShiftGuess = trueShiftDescription + np.random.random(trueShiftDescription.shape) * 4.0
else:
    assert(source == 'piv')
    # tifffile is only needed for real PIV data, so import it here rather than at the top of the notebook
    import tifffile
    pivImagePair = tifffile.imread('piv-raw-data/038298.tif')[24:26,:15*20,:15*16].astype('float32')
    # Note: frames 57-58 (wrong pair) would be an option to investigate bigger motion (~16px) with imperfect AB matches
    #              64-65 (correct pair) are another example of small movement (0-3px)
    dualObject = pivImagePair[np.newaxis]
    # For now, I just guess an initial shift of zero.
    # NOTE(review): the zero array takes its shape from the shifter's own example shift description for
    # this object, assuming that has the shape the shifter expects - confirm.
    trueShiftDescription = np.zeros(shifter.ExampleShiftDescriptionForObject(dualObject).shape, dtype='float32')
    initialShiftGuess = trueShiftDescription.copy()


plt.subplot(1, 2, 1)
plt.imshow(dualObject[0,0])
plt.subplot(1, 2, 2)
plt.imshow(dualObject[0,1])
plt.show()

In [None]:
# Choose where the optimizer starts. Setting this to True gives the algorithm the best possible
# starting point (the actual true shift values) - but it still may iterate away from that...
startFromTrueShift = False
if startFromTrueShift:
    warnings.warn("WARNING: starting guess is actually the correct flow description")
    startShiftForOptimizer = trueShiftDescription.copy()
else:
    startShiftForOptimizer = initialShiftGuess.copy()

In [None]:
def DeconvRL_PIV_OLD(hMatrix, imageAB, maxIter, Xguess, shifter, shiftDescription):
    """Older multiplicative (Richardson-Lucy style) deconvolution of an AB image pair into one object.

    I believed this to be the RL algorithm in the way I have written it in the past.
    However, this gives different results to Prevedel's implementation
    (mine seems to converge more slowly).
    TODO: I should look into this and see if I've just made a mistake or if they are actually different.

    Parameters
    ----------
    hMatrix : PSF matrix passed through to the projectors.
    imageAB : the measured AB camera image pair.
    maxIter : number of update iterations to run.
    Xguess : initial guess of the single combined object (not modified; a copy is updated).
    shifter, shiftDescription : flow model and its parameters, passed through to the projectors.

    Returns
    -------
    The updated copy of Xguess after maxIter iterations.
    """
    # We update the guess in place, and the caller may not be expecting that, so work on a copy
    Xguess = Xguess.copy()
    for _ in tqdm(range(maxIter), desc='RL deconv'):
        relativeBlurDual = imageAB / lfdeconv_piv.ForwardProjectACC_PIV(hMatrix, Xguess, shifter, shiftDescription)
        # NOTE(review): assumes lfdeconv_piv provides FusedBackwardProjectACC_PIV with the same
        # (hMatrix, data, shifter, shiftDescription) signature as ForwardProjectACC_PIV - confirm.
        Xguess *= lfdeconv_piv.FusedBackwardProjectACC_PIV(hMatrix, relativeBlurDual, shifter, shiftDescription)
        # 0/0 in the ratio above can produce NaNs; reset those voxels to zero
        Xguess[np.isnan(Xguess)] = 0
    return Xguess

# Some replacement functions to use for testing (effective PSF is a delta function, 1:1 mapping from image to object)
def ForwardProjectTrivial(hMatrix, obj, shifter, shiftDescription):
    """Trivial forward projection: build the A and B images directly from obj (hMatrix is unused).

    The B image is obj shifted by shifter.ShiftObject(..., shiftDescription). Each of the two images
    gets half of obj's intensity, so that energy is conserved.

    Returns the A/B pair for index 0 along obj's leading axis, as an array whose first axis
    selects A (0) or B (1).
    """
    halfIntensity = obj[:, np.newaxis, :, :] / 2.0
    imagePair = np.tile(halfIntensity, (1, 2, 1, 1))
    shiftedB = shifter.ShiftObject(imagePair[:, 1, :, :], shiftDescription)
    imagePair[:, 1, :, :] = shiftedB
    return imagePair[0]

def DualBackwardProjectTrivial(hMatrix, dualProjection, shifter, shiftDescription):
    """Trivial back projection of an A/B image pair into a pair of objects (hMatrix is unused).

    The B image is shifted by the negated shift description, undoing the forward shift.
    dualProjection itself is not modified. Returns a new array with an extra leading axis
    of length 1 prepended to dualProjection's shape.
    """
    backProjected = dualProjection[np.newaxis, ...].copy()
    unshiftedB = shifter.ShiftObject(backProjected[:, 1, :, :], -shiftDescription)
    backProjected[:, 1, :, :] = unshiftedB
    return backProjected

def FusedBackwardProjectTrivial(hMatrix, dualProjection, shifter, shiftDescription):
    """Trivial back projection of an A/B image pair, merged into a single object.

    Back-projects both images with DualBackwardProjectTrivial (which un-shifts B), then
    averages the two back-projections along the A/B axis.
    """
    # shifter must be passed through: DualBackwardProjectTrivial takes
    # (hMatrix, dualProjection, shifter, shiftDescription)
    dualObject = DualBackwardProjectTrivial(hMatrix, dualProjection, shifter, shiftDescription)
    # Merge the two back-projections (sum / 2 = mean of A and B)
    result = np.sum(dualObject, axis=1) / 2.0
    return result

def DeconvRLTrivial(hMatrix, imageAB, maxIter, shifter, shiftDescription):
    """Stand-in for RL deconvolution under the trivial (delta-function PSF) model.

    No iterations are performed: the result is simply the fused back-projection of imageAB.
    maxIter is accepted but ignored (presumably so the signature mirrors the real
    deconvolution routines - confirm).
    """
    return FusedBackwardProjectTrivial(hMatrix, imageAB, shifter, shiftDescription)

# Calculate flow by using direct shift-matching of the raw input images
*Since there is no light field involved at all in this approach, it is of course expected to work very well!*

In [None]:
if True:
    shiftHistoryRaw = recovery.OptimizeToRecoverFlowField('naive', dualObject[0], None, \
                                        shifter, trueShiftDescription, startShiftForOptimizer, searchRangeXY)

# Note: if continuing a previously-interrupted run then we can do this to pick up roughly where we left off,
# i.e. provide BestShift() for the two shift-related input parameters, and pass the existing shift history
# as the final (optional) parameter:
#    shiftHistoryRaw = recovery.OptimizeToRecoverFlowField('naive', dualObject[0], None, shifter, \
#                                        shiftHistoryRaw.BestShift(), shiftHistoryRaw.BestShift(), searchRangeXY, shiftHistoryRaw)

In [None]:
# The history variables are initialised to None at the top of the notebook, so "not available"
# normally shows up as None rather than as a NameError; handle both.
try:
    if shiftHistoryRaw is None:
        warnings.warn('History probably not available')
    else:
        bestShift = recovery.ReportOnOptimizerConvergence(shiftHistoryRaw, shifter, 'naive', dualObject[0])
except NameError:
    warnings.warn('History probably not available')

# Calculate flow by using direct shift-matching of the light-field-deconvolved images
*This is not expected to work particularly well, due to the presence of the artefacts in the recovered A and B images*

In [None]:
# Produce light-field-recovered A and B images by running the imaging cycle (forward projection,
# back projection, then RL deconvolution) on each of the two views separately, which introduces
# artefacts into them.
# These recovered volumes represent the inputs for the 'naive' method of recovering the flow profile,
# but are not used as inputs for my new algorithm.
dualObjectRecovered = dualObject.copy()
for view in range(2):
    lightFieldImage = lfdeconv.ForwardProjectACC(pivHMatrix, dualObject[:, view, :, :], logPrint=False)
    startVolume = lfdeconv.BackwardProjectACC(pivHMatrix, lightFieldImage, logPrint=False)

    # Shifted images can contain true zeroes in regions that have no features remaining.
    # Clamp to a very small nonzero background (relative to the peak) so that the deconvolution doesn't fail.
    backgroundFloor = 1e-5 * np.max(startVolume)
    startVolume = np.maximum(startVolume, backgroundFloor)

    dualObjectRecovered[:, view, :, :] = lfdeconv.DeconvRL(pivHMatrix, startVolume, maxIter=8,
                                                          Xguess=startVolume, logPrint=False)

In [None]:
# Display the true object and the light-field-recovered object, both overlaid with the true flow
print('Original object (and true flow)')
iwPos = shifter.IWCentresForObject(dualObject)
flow.ShowDualObjectAndFlow(dualObject, shifter, trueShiftDescription)
print('Recovered from light field images (plane %d) - showing true flow' % zPlaneToModel)
flow.ShowDualObjectAndFlow(dualObjectRecovered, shifter, trueShiftDescription)

# Save 2D snapshots of the true A/B images and the light-field-recovered A/B images
for imageFileName, imageToSave in [('syntheticInput.tif', dualObject[0,0]),
                                   ('syntheticInputB.tif', dualObject[0,1]),
                                   ('syntheticA.tif', dualObjectRecovered[0,0]),
                                   ('syntheticB.tif', dualObjectRecovered[0,1])]:
    plt.imsave(imageFileName, imageToSave)

In [None]:
# Run the 'naive' optimizer on the light-field-recovered A and B images.
# The existing history (None on a first run) is passed in as the final parameter.
naiveInputPair = dualObjectRecovered[0]
shiftHistoryNaive = recovery.OptimizeToRecoverFlowField(
    'naive', naiveInputPair, None,
    shifter, trueShiftDescription, startShiftForOptimizer, searchRangeXY, shiftHistoryNaive)

In [None]:
# The history variables are initialised to None at the top of the notebook, so "not available"
# normally shows up as None rather than as a NameError; handle both.
try:
    if shiftHistoryNaive is None:
        warnings.warn('History probably not available')
    else:
        recovery.ReportOnOptimizerConvergence(shiftHistoryNaive, shifter, 'naive', dualObjectRecovered[0])
except NameError:
    warnings.warn('History probably not available')

# Calculate flow by using my new (joint) algorithm directly on the raw light field camera images
*Hopefully this should work better!*

Note that this cell will take a *very* long time to run.

In [None]:
if True:
    # Generate a camera image pair from the object.
    if source == 'piv':
        # The camera AB images are determined by separate forward projection of the AB spim images in dualObject.
        # NOTE(review): elsewhere lfdeconv.ForwardProjectACC is called with a single view (dualObject[:,n,:,:]);
        # confirm that it accepts the full (1,2,H,W) dualObject here.
        imageAB = lfdeconv.ForwardProjectACC(pivHMatrix, dualObject)
    else:
        # The synthetic B image is determined with the help of the chosen shift transform.
        imageAB = lfdeconv_piv.ForwardProjectACC_PIV(pivHMatrix, dualObject[:,0], shifter, trueShiftDescription)

    # Run the joint optimizer to find the shift value for an input frame pair
    # (the existing history, None on a first run, is passed in as the final parameter)
    shiftHistoryJoint = recovery.OptimizeToRecoverFlowField('joint', imageAB, pivHMatrix, \
                                        shifter, trueShiftDescription, startShiftForOptimizer, searchRangeXY, shiftHistoryJoint)

In [None]:
# The history variables are initialised to None at the top of the notebook, so "not available"
# normally shows up as None rather than as a NameError; handle both.
try:
    if shiftHistoryJoint is None:
        warnings.warn('History probably not available')
    else:
        recovery.ReportOnOptimizerConvergence(shiftHistoryJoint, shifter, 'joint', imageAB, pivHMatrix)
except NameError:
    warnings.warn('History probably not available')

In [None]:
if False:
    # Optional: save the joint optimizer's shift and score histories, plus the object they were computed for
    for outputName, outputData in (('shiftHistoryJoint.npy', shiftHistoryJoint.shiftHistory),
                                   ('scoreHistoryJoint.npy', shiftHistoryJoint.scoreHistory),
                                   ('dualObjectJoint.npy', dualObject)):
        np.save(outputName, outputData)

In [None]:
# Optional code to look at how the scores are evolving during the powell iterations.
# For a chosen control point, consecutive history entries are grouped into runs in which that control
# point's first shift component differs from the next entry's; each run with more than 2 points is
# scatter-plotted as score against that shift component.

if True:
    vals = np.array(shiftHistoryJoint.shiftHistory)
    scores = np.array(shiftHistoryJoint.scoreHistory)
else:
    # Load from files previously saved using:
    #   np.save('scoreHistory.npy', np.array(shiftHistoryJoint.scoreHistory))
    #   np.save('shiftHistory.npy', np.array(shiftHistoryJoint.shiftHistory))
    vals = np.load('/Users/jonny/Desktop/shiftHistory.npy')
    scores = np.load('/Users/jonny/Desktop/scoreHistory.npy')

# Control points are indexed on a grid this many points wide (presumably row-major; 7x7 = 49 matches
# the commented-out all-control-points loop below).
controlPointsPerRow = 7
#iwOfInterest = 5*controlPointsPerRow+3
iwOfInterest = 5*controlPointsPerRow+6   # Border control point (originally examined for shiftHistoryNaive)

if False:
    # Compare the results from different control points
    # NOTE(review): objectToUse is not defined anywhere in this notebook and must be set by hand before
    # enabling this snippet. Also, ScoreShiftDetailed receives objectToUse but ScoreShift receives
    # objectToUse[0] - confirm which is intended.
    sh = vals[3020].flatten()
    sh[iwOfInterest] = 5.011
    (_,_,_,_,comp) = recovery.ScoreShiftDetailed(sh, shifter, 'naive', objectToUse, log=False)
    for d in [5.012, 5.0135, 5.0145]:
        sh[iwOfInterest] = d
        sc = recovery.ScoreShift(sh, shifter, 'naive', objectToUse[0], log=False, comparator=comp)
        print('score', sc)

def _PlotScoreRun(iw, n, runX, runY):
    """Scatter-plot one run of (shift component, score) points, titled 'iw//width,iw%width n(numPoints)'."""
    plt.plot(runX, runY, 'x')
    plt.title('%d,%d %d(%d)' % (iw // controlPointsPerRow, iw % controlPointsPerRow, n, len(runX)))
    plt.show()

if True:
    #for iw in range(controlPointsPerRow**2):   # (all control points)
    for iw in [iwOfInterest]:
        # Start afresh for each control point, so that a run never mixes data from different control points
        x = []
        y = []
        for n in range(vals.shape[0]-1):
            if vals[n,iw,0] != vals[n+1,iw,0]:
                x.append(vals[n,iw,0])
                y.append(scores[n])
            else:
                # End of a run: plot it if it is long enough to be interesting
                if len(x) > 2:
                    _PlotScoreRun(iw, n, x, y)
                x = []
                y = []
        # Plot any run that was still in progress when the history ended
        if len(x) > 2:
            _PlotScoreRun(iw, n, x, y)

