Status: my new code (which eliminates some of the FFTs) takes 16 seconds instead of 63. 

Without exploiting symmetries it takes 23 seconds. I had to write C code to speed up the generation of the reflected/transposed FFT matrices - without that it was barely any faster than just computing the rfftn for all aa,bb.

Now two thirds of the time is spent in fft2, so that is the bottleneck. If I insist on a square array then I could halve that (the FFT of the transposed PSF is then just the transpose of the FFT I already have), but that would be a limitation. I could, I suppose, make it a *recommendation* that allows the code to run faster. I could certainly take advantage of that for my own work.


NOTE: my C code can't cope with an array that has been transposed (why...? Maybe because of a <= condition on the loop?). I should probably fix that, though I doubt it's a performance issue to just .copy() the transposed array.
However, it looks as if a decent chunk of the fft time is actually being spent in the other ffts (for the reduced arrays) anyway!
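
One plausible explanation for the transposed-array problem in the NOTE above (an assumption, not a diagnosis of the actual C code): `.transpose()` returns a view with swapped strides rather than rearranged memory, so C code that walks the buffer as if it were row-major contiguous would read the elements in the wrong order. The sketch below checks the layout of a made-up array `a` and makes a contiguous copy.

```python
import numpy as np

# Hypothetical small array standing in for one of the FFT matrices
a = np.arange(6, dtype='complex64').reshape(2, 3)
t = a.transpose()              # a view: same memory, swapped strides

t_is_contiguous = t.flags['C_CONTIGUOUS']
shares_memory = np.shares_memory(a, t)

# A contiguous copy is safe to hand to code that assumes a row-major layout
tc = np.ascontiguousarray(t)
tc_is_contiguous = tc.flags['C_CONTIGUOUS']
print(t.strides, tc.strides)
```

If this is the cause, checking `flags['C_CONTIGUOUS']` (or the strides) on the C side would be an alternative to always copying.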

In [None]:
import numpy as np
import numexpr as ne
import scipy.ndimage, scipy.optimize, scipy.io, scipy.fftpack
from scipy.ndimage import convolve
from scipy.signal import convolve2d, fftconvolve
import os, sys, time, warnings
import matplotlib.pyplot as plt
%matplotlib inline
import tifffile
import h5py
from multiprocessing import Pool
from functools import partial
from joblib import Parallel, delayed
import cProfile, pstats
from tqdm import tqdm_notebook as tqdm

In [None]:
matPath = 'PSFmatrix/PSFmatrix_M22.2NA0.5MLPitch125fml3125from-110to110zspacing4Nnum19lambda520n1.33.mat'

warnings.warn('WARNING: Switched to faster matrix for testing')
matPath = 'PSFmatrix/PSFmatrix_M40NA0.95MLPitch150fml3000from-26to0zspacing2Nnum15lambda520n1.0.mat'

In [None]:
# Load the matrices from the .mat file.
# This is slow since they must be decompressed and are rather large! (9.5GB each, in single-precision FP)
with h5py.File(matPath, 'r') as f:
    print('Load H')
    sys.stdout.flush()
    # Note: dataset[()] reads the whole dataset (the older .value attribute is deprecated in h5py)
    H = f['H'][()].astype('float32')
    print('Load Ht')
    sys.stdout.flush()
    Ht = f['Ht'][()].astype('float32')
    print('Load misc')
    sys.stdout.flush()
    CAindex = f['CAindex'][()].astype('int')

print(H.dtype, H.shape, Ht.shape, CAindex.shape)
print(CAindex.transpose())

## Objects stored in the .mat file

### Optical parameters from GUI: [? means I am not sure if or where it is stored]

M<br>
NA<br>
d    "fml" in GUI (stored here in units of m)<br>
pixelPitch is "ML pitch" / "Nnum" (stored here in units of m)<br>
? n<br>
? wavelength<br>

### User parameters from GUI:

OSR<br>
zspacing<br>
? z-min<br>
? z-max<br>
Nnum<br>


### Misc parameter:

fobj (can presumably be deduced from mag, NA etc?)<br>

### The actual arrays:

H:             shape (56, 19, 19, 343, 343), type "f4"<br>
Ht:            shape (56, 19, 19, 343, 343), type "f4"<br>

### Information about object space:

x1objspace:    x pixel positions in object space (19 elements across one lenslet)<br>
x2objspace:    y pixel positions in object space (19 elements across one lenslet)<br>
x3objspace:    z pixel positions in object space (56 z planes)<br>
x1space:       x pixel positions in lenslet space (19 elements across one lenslet)<br>
x2space:       y pixel positions in lenslet space (19 elements across one lenslet)<br>

### Not sure what these are exactly:

CAindex:       shape (2, 56) - something about the start and end index of the PSF array, for each z plane.<br>
CP:            shape (343, 1)<br>
MLARRAY:       shape (1141, 1141), type "|V16"<br>
objspace:      shape (56, 1, 1)<br>
settingPSF:    You would think this contains the GUI parameters, but e.g. print(f['settingPSF']['M'].value) gives a strange 3x1 array [50, 50, 46, 50] etc...?<br>
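
A possible reading of the strange `settingPSF` values: 50, 50, 46, 50 are the character codes for "22.2", which matches the `M22.2` in the first matrix file name above. That would fit an assumption that the .mat file stores text fields as arrays of integer character codes. The helper `decode_matlab_chars` below is hypothetical, and the input array is typed in by hand from the values quoted above rather than read from the file.

```python
import numpy as np

def decode_matlab_chars(codes):
    """Interpret an array of integer character codes as a string."""
    return ''.join(chr(int(c)) for c in np.asarray(codes).ravel())

# Values quoted above for settingPSF['M'], entered by hand (not read from the file)
codes = np.array([[50], [50], [46], [50]], dtype='uint16')
decoded = decode_matlab_chars(codes)
print(decoded, float(decoded))
```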


In [None]:
# Note: I am a little unsure how to interpret the arrays I have loaded from the .mat.
# From looking at how H and CAindex are accessed, it looks as if the shapes I have loaded
# are the reversal of the shape ordering as expected in Matlab.
# I suppose that makes sense given that matlab is column-major in its array accesses.
# The data has been loaded from disk in the order it is *stored*,
# and I therefore need to flip around all the matlab array index ordering
# (e.g. matlabArray(1,2,3) becomes pythonArray[2,1,0]: reversed order, and 0-based rather than 1-based)
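
The index reversal described above can be checked on a toy array. This is a minimal sketch, assuming the data is written in Matlab's column-major order and read back by a row-major reader; the 2x3x4 shape and the names `m`, `stored` and `p` are made up for illustration.

```python
import numpy as np

# A hypothetical Matlab array of size 2x3x4, filled in column-major order
matlab_shape = (2, 3, 4)
m = np.arange(24).reshape(matlab_shape, order='F')

# The values on disk follow column-major order...
stored = m.ravel(order='F')
# ...and a row-major reader sees them with the shape reversed
p = stored.reshape(matlab_shape[::-1])

# Matlab's 1-based A(i,j,k) is then p[k-1, j-1, i-1]
i, j, k = 1, 2, 3
same_element = bool(p[k-1, j-1, i-1] == m[i-1, j-1, k-1])
print(p.shape, same_element)
```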

In [None]:
from scipy._lib._version import NumpyVersion
from numpy.fft import fft, fftn, rfft, rfftn, irfftn
_rfft_mt_safe = (NumpyVersion(np.__version__) >= '1.9.0.dev-e24486e')

def _next_regular(target):
    """
    Find the next regular number greater than or equal to target.
    Regular numbers are composites of the prime factors 2, 3, and 5.
    Also known as 5-smooth numbers or Hamming numbers, these are the optimal
    size for inputs to FFTPACK.

    Target must be a positive integer.
    """
    if target <= 6:
        return target

    # Quickly check if it's already a power of 2
    if not (target & (target-1)):
        return target

    match = float('inf')  # Anything found will be smaller
    p5 = 1
    while p5 < target:
        p35 = p5
        while p35 < target:
            # Ceiling integer division, avoiding conversion to float
            # (quotient = ceil(target / p35))
            quotient = -(-target // p35)

            # Quickly find next power of 2 >= quotient
            try:
                p2 = 2**((quotient - 1).bit_length())
            except AttributeError:
                # Fallback for Python <2.7
                p2 = 2**(len(bin(quotient - 1)) - 2)

            N = p2 * p35
            if N == target:
                return N
            elif N < match:
                match = N
            p35 *= 3
            if p35 == target:
                return p35
        if p35 < match:
            match = p35
        p5 *= 5
        if p5 == target:
            return p5
    if p5 < match:
        match = p5
    return match

def _centered(arr, newsize):
    # Return the center newsize portion of the array.
    newsize = np.asarray(newsize)
    currsize = np.array(arr.shape)
    startind = (currsize - newsize) // 2
    endind = startind + newsize
    myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
    return arr[tuple(myslice)]

def tempMul(bb,fshape,result):
    result *= np.exp(-1j * bb * 2*np.pi / fshape[0] * np.arange(result.shape[0],dtype='complex64'))[:,np.newaxis]
    return result

def expand2(result, bb, aa, Nnum, fshape):
    return np.tile(result, (Nnum,1))

def expand(reducedF, bb, aa, Nnum, fshape):
    result = np.tile(reducedF, (1,int(Nnum/2+1)))
    result = result[:,:int(fshape[1]/2+1)]
    result *= np.exp(-1j * aa * 2*np.pi / fshape[1] * np.arange(result.shape[1],dtype='complex64'))
    result = expand2(result, bb, aa, Nnum, fshape)
    return tempMul(bb,fshape,result)


def special_rfftn(in1, bb, aa, Nnum, fshape):
    # Compute the fft of elements in1[bb::Nnum,aa::Nnum], after in1 has been zero-padded out to fshape
    # We exploit the fact that fft(masked-in1) is fft(arr[::Nnum,::Nnum]) replicated Nnum times.
    reducedShape = ()
    for d in fshape:
        assert((d % Nnum) == 0)
        reducedShape = reducedShape + (int(d/Nnum),)
        
    assert(in1.ndim == 2)
    reduced = in1[bb::Nnum,aa::Nnum]

    if True:
        # Compute an array giving rfft(mask(in1))
        reducedF = scipy.fftpack.fft2(reduced, reducedShape)
        return expand(reducedF, bb, aa, Nnum, fshape)
        # Old code for reference (now commented out, since it was unreachable after the return).
        # I broke it out into expand() to aid in profiling.
        #   result = np.tile(reducedF, (1,int(Nnum/2+1)))
        #   result = result[:,:int(fshape[1]/2+1)]
        #   result *= np.exp(-1j * aa * 2*np.pi / fshape[1] * np.arange(result.shape[1]))
        #   result = np.tile(result, (Nnum,1))
        #   result *= np.exp(-1j * bb * 2*np.pi / fshape[0] * np.arange(result.shape[0]))[:,np.newaxis]
        #   targetLength = int(fshape[1]/2)+1
        #   return result[:,:targetLength]
        #   (or: return tempMul(bb,fshape,result))
    else:
        # Compute an array giving fft(mask(in1)) (instead of rfft) - much more straightforward!
        # Note: this omits the (bb, aa) phase ramps, so as written it is only exact for bb = aa = 0
        reducedF = scipy.fftpack.fft2(reduced, reducedShape)
        reps = (Nnum,) * len(fshape)
        return np.tile(reducedF, reps)

def convolutionShape(in1, in2, Nnum):
    # Logic copied from fftconvolve source code
    s1 = np.array(in1.shape)
    s2 = np.array(in2.shape)
    shape = s1 + s2 - 1
    if False:
        # TODO: I haven't worked out if/how I can do this yet.
        # This is the original code in fftconvolve, which says:
        # Speed up FFT by padding to optimal size for FFTPACK
        fshape = [_next_regular(int(d)) for d in shape]
    else:
        fshape = [int(np.ceil(d/float(Nnum)))*Nnum for d in shape]
    fslice = tuple([slice(0, int(sz)) for sz in shape])
    return (fshape, fslice, s1)
    
def special_fftconvolve_part1(in1, bb, aa, Nnum, in2, newMethod=True):
    (fshape, fslice, s1) = convolutionShape(in1, in2, Nnum)
    # Pre-1.9 NumPy FFT routines are not threadsafe - this code requires numpy 1.9 or greater
    assert(_rfft_mt_safe)
    if newMethod:
        fa = special_rfftn(in1, bb, aa, Nnum, fshape)
    else:
        tempSlice = np.zeros(in1.shape, dtype=in1.dtype)
        tempSlice[bb::Nnum, aa::Nnum] = in1[bb::Nnum, aa::Nnum]
        fa = rfftn(tempSlice, fshape)
    return (fa, fshape, fslice, s1)

def special_fftconvolve_part3(fab, fshape, fslice, s1):
    ret = irfftn(fab, fshape)[fslice].copy()
    return _centered(ret, s1)

def special_fftconvolve(in1, bb, aa, Nnum, in2, newMethod=True, fb=None):
    '''
    in1 consists of subapertures of size Nnum x Nnum pixels.
    We are being asked to convolve only pixel (bb,aa) within each subaperture, i.e.
        tempSlice = np.zeros(in1.shape, dtype=in1.dtype)
        tempSlice[bb::Nnum, aa::Nnum] = in1[bb::Nnum, aa::Nnum]
    This allows us to take a significant shortcut in computing the FFT for in1.
    '''
    (fa, fshape, fslice, s1) = special_fftconvolve_part1(in1, bb, aa, Nnum, in2, newMethod=newMethod)
    if fb is None:
        fb = rfftn(in2, fshape)
    if newMethod:
        return (fa*fb, fshape, fslice, s1)
    else:
        return special_fftconvolve_part3(fa*fb, fshape, fslice, s1)

testHtCC = np.random.random((5,5,30,30)).astype(np.float32)
testHtCC = H[13,int(H.shape[1]/2)-2:int(H.shape[1]/2)+3,int(H.shape[2]/2)-2:int(H.shape[2]/2)+3,CAindex[0,13]-1:CAindex[1,13], CAindex[0,13]-1:CAindex[1,13]]
for shape in [(200,200), (200,300)]:
    testProjection = np.random.random(shape).astype(np.float32)
    testResultOld = backwardProjectForZ_old(testHtCC, testProjection)
    testResultNew = backwardProjectForZ(testHtCC, testProjection)
    print('test result (should be <<1):', np.max(np.abs(testResultOld - testResultNew)))
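
The identity that `special_rfftn` relies on can be verified on a small example: the FFT of an array that is zero everywhere except at `[bb::Nnum, aa::Nnum]` equals the FFT of the reduced array, tiled `Nnum` times along each axis and multiplied by a phase ramp for the `(bb, aa)` offset. This is a sketch using the full (not real) FFT and an array whose shape is already a multiple of `Nnum`; the function name `masked_fft_via_tiling` and the sizes are made up for illustration.

```python
import numpy as np

def masked_fft_via_tiling(x, bb, aa, Nnum):
    """fft2 of x masked to [bb::Nnum, aa::Nnum], computed from the reduced array."""
    n0, n1 = x.shape
    assert n0 % Nnum == 0 and n1 % Nnum == 0
    reduced = x[bb::Nnum, aa::Nnum]
    R = np.fft.fft2(reduced)                      # small FFT only
    tiled = np.tile(R, (Nnum, Nnum))              # periodic replication
    # Phase ramps accounting for the (bb, aa) shift of the sampled pixels
    phase0 = np.exp(-2j * np.pi * bb * np.arange(n0) / n0)
    phase1 = np.exp(-2j * np.pi * aa * np.arange(n1) / n1)
    return tiled * phase0[:, np.newaxis] * phase1[np.newaxis, :]

rng = np.random.default_rng(0)
Nnum = 3
x = rng.random((12, 9))
bb, aa = 1, 2

# Reference: explicitly mask, then take the full-size FFT
masked = np.zeros_like(x)
masked[bb::Nnum, aa::Nnum] = x[bb::Nnum, aa::Nnum]
direct = np.fft.fft2(masked)

fast = masked_fft_via_tiling(x, bb, aa, Nnum)
max_err = np.max(np.abs(direct - fast))
print('max error (should be <<1):', max_err)
```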

In [None]:
# Note: H.shape in python is (<num z planes>, Nnum, Nnum, <psf size>, <psf size>),
#                       e.g. (56, 19, 19, 343, 343)
useSymmetries = True
from numba import jit

sys.path.insert(0, 'py_symmetry')
import py_symmetry as jps


class Projector(object):
    def __init__(self, projection, HtCCBB, Nnum):
        # Note: H and Hts are not class variables - they are global to make multithreading easier
        
        # Nnum: number of pixels across a lenslet array (after rectification)
        self.Nnum = Nnum
        
        # This next chunk of logic copied from fftconvolve source code.
        # s1, s2: shapes of the input arrays
        # fshape: shape of the (full, possibly padded) result array in Fourier space
        # fslice: slicing tuple specifying the actual result size that should be returned
        self.s1 = np.array(projection.shape)
        self.s2 = np.array(HtCCBB[0].shape)
        shape = self.s1 + self.s2 - 1
        if False:
            # TODO: I haven't worked out if/how I can do this yet.
            # This is the original code in fftconvolve, which says:
            # Speed up FFT by padding to optimal size for FFTPACK
            self.fshape = [_next_regular(int(d)) for d in shape]
        else:
            self.fshape = [int(np.ceil(d/float(Nnum)))*Nnum for d in shape]
        self.fslice = tuple([slice(0, int(sz)) for sz in shape])
        
        # rfslice: slicing tuple to crop down full fft array to the shape that would be output from rfftn
        self.rfslice = (slice(0,self.fshape[0]), slice(0,int(self.fshape[1]/2)+1))
        return
    
    def MirrorXArray(self, Hts, fHtsFull):
        padLength = self.fshape[0] - Hts.shape[0]
        if False:
            fHtsFull = fHtsFull.conj() * np.exp((1j * (1+padLength) * 2*np.pi / self.fshape[0]) * np.arange(self.fshape[0],dtype='complex64')[:,np.newaxis])
            fHtsFull[:,1::] = fHtsFull[:,1::][:,::-1]
            return fHtsFull
        else:
            temp = np.exp((1j * (1+padLength) * 2*np.pi / self.fshape[0]) * np.arange(self.fshape[0])).astype('complex64')
            if True:
                result = jps.mirrorX(fHtsFull, temp)
            else:
                result = np.empty(fHtsFull.shape, dtype=fHtsFull.dtype)
                result[:,0] = fHtsFull[:,0].conj()*temp
                for i in range(1,fHtsFull.shape[1]):
                    result[:,i] = (fHtsFull[:,fHtsFull.shape[1]-i].conj()*temp)
            return result

    def MirrorYArray(self, Hts, fHtsFull):
        padLength = self.fshape[1] - Hts.shape[1]
        if False:
            fHtsFull = fHtsFull.conj() * np.exp(1j * (1+padLength) * 2*np.pi / self.fshape[1] * np.arange(self.fshape[1],dtype='complex64'))
            fHtsFull[1::] = fHtsFull[1::][::-1]
            return fHtsFull
        else:
            temp = np.exp((1j * (1+padLength) * 2*np.pi / self.fshape[1]) * np.arange(self.fshape[1])).astype('complex64')
            if True:
                result = jps.mirrorY(fHtsFull, temp)
            else:
                result = np.empty(fHtsFull.shape, dtype=fHtsFull.dtype)
                result[0] = fHtsFull[0].conj()*temp
                for i in range(1,fHtsFull.shape[0]):
                    result[i] = (fHtsFull[fHtsFull.shape[0]-i].conj()*temp)
            return result


        
    def convolvePart3(self, projection, bb, aa, Hts, fHtsFull, mirrorX):
        # TODO: to make this work, I need the full matrix for fHts and then I need to slice it 
        # to the correct shape when I call through to special_fftconvolve here. Is fshape what I need?
        (result1,_,_,_) = special_fftconvolve(projection,bb,aa,self.Nnum,Hts,newMethod=True,fb=fHtsFull[self.rfslice])
        if mirrorX:
            # Compute the mirror FFT array.
            # Note: I experimented with rewriting this code to speed it up,
            # but didn't seem to be able to improve on it much
            fHtsFull = self.MirrorXArray(Hts, fHtsFull)
            #padLength = self.fshape[0] - Hts.shape[0]
            #fHtsFull = fHtsFull.conj() * np.exp((1j * (1+padLength) * 2*np.pi / self.fshape[0]) * np.arange(self.fshape[0])[:,np.newaxis])
            #fHtsFull[:,1::] = fHtsFull[:,1::][:,::-1]

            
            #fHtsFull2 = fftn(Hts[::-1,:], self.fshape)
            #print('part 3', np.max(np.abs(fHtsFull2-fHtsFull)))
            (result2,_,_,_) = special_fftconvolve(projection,self.Nnum-bb-1,aa,self.Nnum,Hts[::-1,:],newMethod=True,fb=fHtsFull[self.rfslice]) 
            return result1+result2
        else:
            return result1

    def convolvePart2(self, projection, bb, aa, Hts, fHtsFull, mirrorY, mirrorX):
        result1 = self.convolvePart3(projection,bb,aa,Hts,fHtsFull,mirrorX)
        if mirrorY:
            fHtsFull = self.MirrorYArray(Hts, fHtsFull)
            #padLength = self.fshape[1] - Hts.shape[1]
            #fHtsFull = fHtsFull.conj() * np.exp(1j * (1+padLength) * 2*np.pi / self.fshape[1] * np.arange(self.fshape[1]))
            #fHtsFull[1::] = fHtsFull[1::][::-1]
            #fHtsFull2 = fftn(Hts[:,::-1], self.fshape)
            #print('part 2', np.max(np.abs(fHtsFull2-fHtsFull)))
            result2 = self.convolvePart3(projection,bb,self.Nnum-aa-1,Hts[:,::-1],fHtsFull,mirrorX)
            return result1+result2
        else:
            return result1

    def convolve(self, projection, bb, aa, Hts):
        cent = int(self.Nnum/2)

        if useSymmetries:
            # Full symmetry
            mirrorX = (bb != cent)
            mirrorY = (aa != cent)
            transpose = ((aa != bb) and (aa != (self.Nnum-bb-1)))
        else:
            mirrorX = False
            mirrorY = False
            transpose = False
            # Simplifying the transpose condition because we are not mirroring yet
            # Note: if we disable mirroring but keep transposing, the condition is:
            #  transpose = ((aa != bb))
            
        # TODO: it would speed things up if I could avoid computing the full fft for Hts.
        # However, it's not immediately clear to me how to fill out the full fftn array from rfftn
        # in the case of a 2D transform.
        # For 1D it's the reversed conjugate, but for 2D it's more complicated than that.
        # It's possible that it's actually nontrivial, in spite of the fact that
        # you can get away without it when only computing fft/ifft for real arrays)
        if useSymmetries:
            fHtsFull = scipy.fftpack.fft2(Hts, self.fshape)
        else:
            # scipy.fftpack has no rfft2, so use numpy's (its output already has the rfslice shape)
            fHtsFull = np.fft.rfft2(Hts, self.fshape)
        result1 = self.convolvePart2(projection,bb,aa,Hts,fHtsFull,mirrorY,mirrorX)
        if transpose:
            if (self.fshape[0] == self.fshape[1]):
                fHtsFull = fHtsFull.transpose().copy()
            else:
                fHtsFull = scipy.fftpack.fft2(Hts.transpose(), self.fshape)
            # Note that mx,my need to be swapped following the transpose
            result2 = self.convolvePart2(projection,aa,bb,Hts.transpose(),fHtsFull,mirrorX,mirrorY) 
            return result1+result2
        else:
            return result1
    
def backwardProjectForZY(HtCCBB, bb, projection):
    tempSliceBack = None
    Nnum = HtCCBB.shape[0]
    if useSymmetries:
        projector = Projector(projection, HtCCBB, Nnum)
        fshape = projector.fshape    # TODO: these three are just for back-compatibility - should tidy up once all is settled
        fslice = projector.fslice
        s1 = projector.s1
        for aa in range(bb,int((Nnum+1)/2)):
            # Extract the part of Ht that represents this lenslet pixel
            Hts = HtCCBB[aa]
            fab = projector.convolve(projection, bb, aa, Hts)
            if (tempSliceBack is None):
                tempSliceBack = fab
            else:
                tempSliceBack += fab
    else:
        for aa in range(Nnum):
            # Extract the part of Ht that represents this lenslet pixel
            Hts = HtCCBB[aa]
            (fab, fshape, fslice, s1) = special_fftconvolve(projection, bb, aa, Nnum, Hts, newMethod=True)
            if (tempSliceBack is None):
                tempSliceBack = fab
            else:
                tempSliceBack += fab
    return (tempSliceBack, fshape, fslice, s1)
    
def backwardProjectForZ(HtCC, projection):
    tempSliceBack = None
    Nnum = HtCC.shape[1]
    if useSymmetries:
        r = range(int((Nnum+1)/2))
    else:
        r = range(Nnum)
    for bb in tqdm(r, leave=False, desc='Backward-project - y'):
        (result, fshape, fslice, s1) = backwardProjectForZY(HtCC[bb], bb, projection)
        if (tempSliceBack is None):
            tempSliceBack = result
        else:
            tempSliceBack += result
    return special_fftconvolve_part3(tempSliceBack, fshape, fslice, s1)


def forwardProjectForZ(HCC, realspaceCC):
    TOTALprojection = None
    Nnum = HCC.shape[1]
    # Symmetries are not exploited here yet (see TODO below), so every (bb, aa) must be visited
    r = range(Nnum)
    for bb in tqdm(r, leave=False, desc='Forward-project - y'):
        for aa in tqdm(range(Nnum), leave=False, desc='Forward-project - x'):
            # ******
            # ****** TODO: I am here, with this bit ready to update to match current back-projection code
            # ******
            # Extract the part of H that represents this lenslet pixel
            Hs = HCC[bb, aa]
            # Convolve just the voxels (bb, aa) behind each lenslet with Hs, staying in Fourier space
            # so that the inverse FFT only needs to be done once, after accumulating.
            # (This mirrors the non-symmetry branch of backwardProjectForZY.)
            (fab, fshape, fslice, s1) = special_fftconvolve(realspaceCC, bb, aa, Nnum, Hs, newMethod=True)

            if (TOTALprojection is None):
                TOTALprojection = fab
            else:
                TOTALprojection += fab
    return special_fftconvolve_part3(TOTALprojection, fshape, fslice, s1)
    


print('Done')
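
The phase-ramp trick used in `MirrorXArray` can be checked in isolation. For a real array `h` zero-padded to `fshape`, the FFT of `h[::-1, :]` equals the conjugate of the FFT of `h`, with the second axis index negated (modulo its length) and a phase ramp of `(1 + padLength)` along the first axis. The sketch below is a pure-numpy stand-in (it does not call `jps.mirrorX`), and the function name and array sizes are made up for illustration.

```python
import numpy as np

def mirror_axis0_in_fourier_space(F, n0, fshape):
    """Given F = fft2(h, fshape) for real h with n0 rows, return fft2(h[::-1, :], fshape)."""
    pad = fshape[0] - n0
    phase = np.exp(2j * np.pi * (1 + pad) * np.arange(fshape[0]) / fshape[0])
    out = F.conj() * phase[:, np.newaxis]
    # Conjugation flips the sign of both frequencies; undo that on axis 1
    # by mapping column k1 -> (-k1) mod N1 (column 0 stays put)
    out[:, 1:] = out[:, 1:][:, ::-1].copy()
    return out

rng = np.random.default_rng(0)
h = rng.random((5, 7))
fshape = (12, 10)

F = np.fft.fft2(h, s=fshape)
direct = np.fft.fft2(h[::-1, :], s=fshape)
via_symmetry = mirror_axis0_in_fourier_space(F, h.shape[0], fshape)
max_err = np.max(np.abs(direct - via_symmetry))
print('max error (should be <<1):', max_err)
```

The same reasoning with the axes swapped gives the `MirrorYArray` case.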

In [None]:
global gHt

def backwardProjectACC(Ht, projection, CAindex, planes=None):
    global gHt
    Backprojection = np.zeros((Ht.shape[0], projection.shape[0], projection.shape[1]), dtype='float32')
    # Iterate over each z plane
    if planes is None:
        planes = range(Ht.shape[0])
    if True:
        # Single-threaded over z
        for cc in tqdm(planes, desc='Back-project - z'):
            HtCC =  Ht[cc, :, :, CAindex[0,cc]-1:CAindex[1,cc], CAindex[0,cc]-1:CAindex[1,cc]]
            Backprojection[cc] = backwardProjectForZ(HtCC, projection)
    else:
        # Multithreaded over z
        # NOTE: backwardProjectForZ_global is not defined in this notebook. It must be supplied
        # (taking (cc, CAindex, projection) and reading Ht from the global gHt) before enabling this branch.
        gHt = Ht
        work = []
        for cc in planes:
#            HtCC =  Ht[cc, :, :, CAindex[0,cc]-1:CAindex[1,cc], CAindex[0,cc]-1:CAindex[1,cc]]
#            work.append((HtCC, projection))
            work.append((cc, CAindex, projection))

        results = Parallel(n_jobs=10) \
                    (delayed(backwardProjectForZ_global)(*args) for args in tqdm(work, desc='Back-project - z'))
        # Results come back in the same order as the work list
        for (i, cc) in enumerate(planes):
            Backprojection[cc] = results[i]
    return Backprojection

def forwardProjectACC(H, realspace, CAindex, planes=None):
    TOTALprojection = np.zeros((realspace.shape[1], realspace.shape[2]), dtype='float32')
    # Iterate over each z plane
    if planes is None:
        planes = range(realspace.shape[0])
    for cc in tqdm(planes, desc='Forward-project - z'):
        HCC = H[cc, :, :, CAindex[0,cc]-1:CAindex[1,cc], CAindex[0,cc]-1:CAindex[1,cc]]
        TOTALprojection += forwardProjectForZ(HCC, realspace[cc])
    return TOTALprojection

In [None]:
def forwardProjectForZ_old(HCC, realspaceCC):
    # Iterate over each lenslet pixel
    Nnum = HCC.shape[1]
    TOTALprojection = np.zeros((realspaceCC.shape[0], realspaceCC.shape[1]), dtype='float32')
    for bb in tqdm(range(Nnum), leave=False, desc='Forward-project - y'):
        for aa in tqdm(range(Nnum), leave=False, desc='Forward-project - x'):
            # Extract the part of H that represents this lenslet pixel
            Hs = HCC[bb, aa]
            # Create a workspace representing just the voxels cc,bb,aa behind each lenslet (the rest is 0)
            tempspace = np.zeros((realspaceCC.shape[0], realspaceCC.shape[1]), dtype='float32')
            tempspace[bb::Nnum, aa::Nnum] = realspaceCC[bb::Nnum, aa::Nnum]  # ???? what to do about index ordering?
            # Compute how those voxels project onto the sensor, and accumulate.
            # (conv2 is the Matlab name; fftconvolve is used here, as in backwardProjectForZ_old)
            TOTALprojection += fftconvolve(tempspace, Hs, 'same')
    return TOTALprojection
    
def backwardProjectForZ_old(HtCC, projection):
    tempSliceBack = np.zeros(projection.shape, dtype='float32')
    # Iterate over each lenslet pixel
    Nnum = HtCC.shape[1]
    for aa in tqdm(range(Nnum), leave=False, desc='y'):
        for bb in range(Nnum):
            # Extract the part of Ht that represents this lenslet pixel
            Hts = HtCC[bb, aa]
            # Create a workspace representing just the voxels cc,bb,aa behind each lenslet (the rest is 0)
            tempSlice = np.zeros(projection.shape, dtype='float32')
            tempSlice[bb::Nnum, aa::Nnum] = projection[bb::Nnum, aa::Nnum]
            # Compute how those voxels back-project from the sensor
            tempSliceBack += fftconvolve(tempSlice, Hts, 'same')
    return tempSliceBack
    
def backwardProjectACC_original(Ht, projection, CAindex, planes=None):
    Backprojection = np.zeros((Ht.shape[0], projection.shape[0], projection.shape[1]), dtype='float32')
    # Iterate over each z plane
    if planes is None:
        planes = range(Ht.shape[0])
    for cc in tqdm(planes, desc='Back-project - z'):
        HtCC =  Ht[cc, :, :, CAindex[0,cc]-1:CAindex[1,cc], CAindex[0,cc]-1:CAindex[1,cc]]
        Backprojection[cc] = backwardProjectForZ_old(HtCC, projection)
    return Backprojection

In [None]:
def deconvRL(Htf, maxIter, Xguess):
    for i in range(maxIter):
        t0 = time.time()
        HXguess = forwardProjectACC(H, Xguess, CAindex)

        
        
        #print(HXguess.shape, HXguess.dtype)
        #tifffile.imsave('HXguess.tif', np.transpose(HXguess))        
        #return HXguess
        
        
        
        HXguessBack = backwardProjectACC(Ht, HXguess, CAindex)
        errorBack = Htf / HXguessBack
        Xguess = Xguess * errorBack
        Xguess[np.isnan(Xguess)] = 0
        ttime = time.time() - t0
        print('iter %d | %d, took %.1f secs' % (i+1, maxIter, ttime))
    return Xguess
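
The structure of the `deconvRL` update can be seen on a toy problem where the forward and backward projections are plain matrix products. This is only an illustrative sketch: the dense matrix `A` stands in for `forwardProjectACC`, `A.T` for `backwardProjectACC`, and all names and sizes are made up. It also shows that the true solution is a fixed point of the update when the data is noise-free.

```python
import numpy as np

def rl_iterate(A, b, x, n_iter):
    """Multiplicative update x <- x * (A^T b) / (A^T A x), as in deconvRL."""
    Atb = A.T @ b                         # plays the role of Htf
    for _ in range(n_iter):
        back = A.T @ (A @ x)              # forward-project, then back-project
        with np.errstate(divide='ignore', invalid='ignore'):
            x = x * (Atb / back)
        x[np.isnan(x)] = 0                # 0/0 entries are reset, as in deconvRL
    return x

rng = np.random.default_rng(1)
A = rng.random((8, 5))
x_true = rng.random(5)
b = A @ x_true                            # noise-free measurement

# Starting at the true solution, one update leaves it unchanged
x_fixed = rl_iterate(A, b, x_true.copy(), 1)

# Starting from the back-projected data, as the notebook does with Xguess = Htf
x_est = rl_iterate(A, b, A.T @ b, 50)
print(np.max(np.abs(x_fixed - x_true)))
```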

In [None]:
# Load the input image
LFmovie = tifffile.imread('Data/02_Rectified/exampleData/20131219WORM2_small_full_neg_X1_N15_cropped_uncompressed.tif')
LFmovie = LFmovie.transpose()[np.newaxis,:,:]
print(LFmovie.shape)

In [None]:
maxIter = 8
#for frame in range(1):
frame = 0

LFIMG = LFmovie[frame].astype('float32')
t0 = time.time()
if True:
#    Htf = backwardProjectACC_original(Ht, np.tile(LFIMG,(2,2)), CAindex)
    Htf = backwardProjectACC_original(Ht, LFIMG, CAindex, planes=[4])
else:
#    myStats = cProfile.run('Htf = backwardProjectACC(Ht, np.tile(LFIMG,(2,2))[:900,:900], CAindex)', 'mystats')
    myStats = cProfile.run('Htf = backwardProjectACC(Ht, LFIMG, CAindex, planes=[4])', 'mystats')
    p = pstats.Stats('mystats')
    p.strip_dirs().sort_stats('cumulative').print_stats(40)

print('iter 0 | %d, took %.1f secs' % (maxIter, time.time()-t0))
# Used to take 6:08, 18.19s/it
# Now takes 2:16, 5.83s/it
# With multithreading it takes a total of 38.2s (i.e. speedup of 3.5x). So still significant scope for improvement.
# When I make Ht into a global variable it speeds up to 35.4s. 
# I think much of the problem is just that it's not distributing the work evenly enough across threads

# On macbook, took 2:34 single-threaded

# Macbook timings testing with back-projecting plane 4:
#  Original version took 20.0s
#  With reduced work, but no symmetries, takes 7.8s. I have eliminated the two most expensive FFTs so I
#  might have hoped for about 6s runtime, but there are still overheads in handling the replacement maths:
#   6.7s in special_fftconvolve, breaking down as:
#     1.35 in part1
#     4.3 in rfftn
#     1.0 on fa*fb
#   So, I can speed things up by using symmetries in rfftn, but there are still other overheads
#   The time in part1 is mostly taken up with expanding the second dimension (as you'd expect).
#   I suspect I could speed that up by doing the multiplication on-the-fly without ever tiling the fa matrix.
#   That would be a bit messy, but I should try it.


# TODO: since I am profiling, I should do so for the full z range to see what difference it makes,
# and also maybe consider doing it for a larger sensor size just to check whether that changes the relative importances.
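
One way the "multiply on-the-fly without ever tiling" idea in the notes above might look: view the large array as `Nnum x Nnum` blocks and let broadcasting apply the small array to every block. This is a sketch for the full-FFT case only (the rfft case, where the second axis is truncated, would need extra handling, and the phase ramps are left out); `multiply_tiled` and the sizes are made up for illustration.

```python
import numpy as np

def multiply_tiled(R, fb, Nnum):
    """Compute np.tile(R, (Nnum, Nnum)) * fb without building the tiled array."""
    m0, m1 = R.shape
    assert fb.shape == (Nnum * m0, Nnum * m1)
    # blocks[a, i, b, j] is fb[a*m0 + i, b*m1 + j]
    blocks = fb.reshape(Nnum, m0, Nnum, m1)
    return (blocks * R[np.newaxis, :, np.newaxis, :]).reshape(fb.shape)

rng = np.random.default_rng(2)
Nnum = 3
R = rng.random((4, 5)) + 1j * rng.random((4, 5))
fb = rng.random((12, 15)) + 1j * rng.random((12, 15))

reference = np.tile(R, (Nnum, Nnum)) * fb
fast = multiply_tiled(R, fb, Nnum)
print(np.max(np.abs(reference - fast)))
```

Whether this is actually faster than tiling would need profiling on the real array sizes.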


In [None]:
tifffile.imsave('Htf_plane4_slow.tif', np.transpose(Htf*1e3, axes=(0,2,1)))

In [None]:
Xguess = Htf.copy();    # Or do it differently if !indepIter
# TODO: Probably don't need to copy it, but I am doing that now for safety

Xguess = deconvRL(Htf, maxIter, Xguess)
    
    
    
'''
Matlab (with multithreading) takes:
  iter 0 | 8, took 28.1 secs
    z=1 took 3.9252 secs
    z=2 took 3.7995 secs
    z=3 took 2.9881 secs
    z=4 took 2.7214 secs
    z=5 took 2.4903 secs
    z=6 took 2.2715 secs
    z=7 took 1.9398 secs
    z=8 took 1.7239 secs
    z=9 took 1.4613 secs
    z=10 took 1.2395 secs
    z=11 took 0.97674 secs
    z=12 took 0.83704 secs
    z=13 took 0.83821 secs
    z=14 took 0.85321 secs
 
 
My python code took 340 secs.
That's about 10x slower, which means I'm not too far off with multithreading.
With threading at the z level it took 83 seconds (~4x speedup compared to single-threading),
with much of the slowness due to not all cores being in use for the whole of the calculation time.
With threading at the b level it took 228 seconds (not much speedup compared to single-threading; not sure why)

I think parallelism ought to be a bit better in the case where I have more planes to reconstruct,
although I should probably sort them by matrix size so I start on the slowest ones first.
'''
0
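
The idea of starting on the slowest planes first could be sketched as below, under the assumption that the cost of a plane grows with its PSF width, taken here as `CAindex[1,cc] - CAindex[0,cc] + 1` to match the slicing used in `backwardProjectACC`. The helper name and the toy `CAindex` values are made up for illustration.

```python
import numpy as np

def planes_largest_first(CAindex, planes=None):
    """Return the plane indices ordered by decreasing PSF width."""
    if planes is None:
        planes = range(CAindex.shape[1])
    planes = list(planes)
    widths = [int(CAindex[1, cc] - CAindex[0, cc] + 1) for cc in planes]
    order = sorted(range(len(planes)), key=lambda i: -widths[i])
    return [planes[i] for i in order]

# Toy CAindex with three planes of PSF width 145, 343 and 245
toy_CAindex = np.array([[100,   1,  50],
                        [244, 343, 294]])
ordered = planes_largest_first(toy_CAindex)
print(ordered)
```

The resulting list could be used to build the work list for `Parallel`, with results written back by plane index.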

In [None]:
Xguess_0 = Xguess.copy()
tifffile.imsave('iter8.tif', np.transpose(Xguess_0*1e3, axes=(0,2,1)))

In [None]:
# Xguess_1 - parallelized at z level
# Xguess_2 - parallelized at b level
# Xguess_3 - no parallelization

In [None]:
plt.imshow(Xguess_0[13])
plt.show()

plt.imshow(Xguess_1[13])
plt.show()
plt.imshow(Xguess_2[13])
plt.show()
plt.imshow(Xguess_3[13])
plt.show()

In [None]:
HXguess = forwardProjectACC(H, Xguess_1, CAindex)

In [None]:
a = np.random.random((300,300))
t1 = time.time()
for n in range(10):
    b = np.fft.rfftn(a, s=(1000,1000))
print(time.time()-t1)
t1 = time.time()
for n in range(10):
    b = np.fft.rfftn(a, s=(1000,1000),axes=(0,1))
print(time.time()-t1)
t1 = time.time()
for n in range(10):
    b = np.fft.rfftn(a, s=(1000,1000),axes=(1,0))
print(time.time()-t1)
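
On the TODO in `Projector.convolve` about filling out the full 2D FFT from `rfftn` output: for real input the full transform satisfies F[k0, k1] = conj(F[(-k0) mod N0, (-k1) mod N1]), so the missing columns can be rebuilt from the stored half by conjugating and reversing both axes. A minimal sketch, with a made-up function name and sizes, checked against `np.fft.fft2`:

```python
import numpy as np

def full_fft2_from_rfft2(half, fshape):
    """Rebuild fft2 of a real array from its rfft2 output of shape (N0, N1//2 + 1)."""
    n0, n1 = fshape
    ncols = half.shape[1]
    full = np.empty((n0, n1), dtype=half.dtype)
    full[:, :ncols] = half
    # Missing columns k1 = ncols .. n1-1 come from column n1-k1 of the stored half,
    # with the row index negated modulo n0, and conjugated
    k1 = np.arange(ncols, n1)
    rows = (-np.arange(n0)) % n0
    full[:, ncols:] = np.conj(half[rows][:, n1 - k1])
    return full

rng = np.random.default_rng(3)
errors = []
for shape in [(6, 9), (8, 10)]:
    a = rng.random(shape)
    rebuilt = full_fft2_from_rfft2(np.fft.rfft2(a), shape)
    errors.append(np.max(np.abs(rebuilt - np.fft.fft2(a))))
print(errors)
```

Whether the rebuild is cheaper than simply calling `fft2` would need timing; the fancy indexing here makes copies.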