# Generate a .mat file representing the light field PSF
This code is very closely modelled on Prevedel's original Matlab code (but with a bug fix for the z=0 plane)

## Before running
This code relies on the small `light_field_integrands` module. This can be installed by going to the `light-field-integrands` subfolder and running `python setup.py build install`

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.special, scipy.integrate, scipy.signal, scipy.misc
import h5py, sys, time, os, h5py, warnings, cProfile, pstats
from tqdm.notebook import tqdm as tqdm
import light_field_integrands

In [None]:
# User-adjustable parameters describing the optical system and the z range of the PSF.
# NOTE(review): lengths look like metres (MatrixFileString scales them by 1e6 / 1e9
# when it builds the file name) - confirm.
M = 40            # used below as fobj = ftl/M (presumably the objective magnification)
NA = 0.95         # enters the PSF functions as alpha = arcsin(NA/n)
MLPitch = 150e-6  # microlens pitch; pixelPitch is later derived as MLPitch/Nnum
Nnum = 15         # virtual pixels per microlens pitch; must be odd (asserted below)
# JT: Presumably this stands for 'oversampling ratio'? Although it doesn't appear in the PSF filename,
# this is the value that is stored in the .mat files I have been using (and is the default value for their code)
OSR = 3           # must be odd (asserted below)

n = 1.0           # enters the wavenumber as k = 2*pi*n/lam (presumably the refractive index)
fml = 3000e-6     # microlens focal length (see fnum_ml = fml/MLPitch below)
lam = 520e-9;     # wavelength
zmin = -26e-6     # first z plane
zmax = 0          # last z plane (x3objspace is built so that zmax itself is included)
zspacing = 2e-6   # spacing between successive z planes

In [None]:
# Threshold used later when selecting the non-zero entries of objspace
eqtol = 1e-10;

k = 2*np.pi*n/lam     # wavenumber including the factor n
k0 = 2*np.pi*1/lam    # same expression with n replaced by 1; this is the value handed to calcML
d = fml               # passed on as the fml / propagation-distance argument of the PSF functions
ftl = 200e-3          #focal length of tube lens
fobj = ftl/M          # focal length of objective lens
fnum_obj = M/(2*NA)   # f-number of objective lens (imaging-side)
fnum_ml = fml/MLPitch # f-number of microlens

# The index arithmetic further down assumes a well-defined central pixel in both samplings
assert((Nnum%2)==1), 'Nnum must be an odd number'
assert((OSR%2)==1), 'OSR must be an odd number'

In [None]:
# JT: Load a .mat file generated by the actual Matlab code, for comparison with ours
def MatrixFileString():
    # Build the parameter-encoding part of the PSF matrix file name, in the same layout
    # that the Matlab code uses.
    # The PSF parameters are read from the module-level globals rather than passed in.
    toMicrons = 1e6
    toNanometres = 1e9
    fields = (
        M,
        NA,
        MLPitch*toMicrons,
        fml*toMicrons,
        zmin*toMicrons,
        zmax*toMicrons,
        zspacing*toMicrons,
        Nnum,
        lam*toNanometres,
        n,
    )
    return 'M%gNA%gMLPitch%gfml%gfrom%gto%gzspacing%gNnum%glambda%gn%g' % fields

def LoadRawMatrixData(matPath):
    # Load the matrices from the .mat (HDF5) file at matPath.
    # This is slow since they must be decompressed and are rather large! (9.5GB each, in single-precision FP)
    #
    # Returns (_H, _Ht, hReducedShape, htReducedShape, _CAindex), where _H and _Ht are the full
    # float32 arrays exactly as stored in the file, _CAindex is the integer CAindex array, and the
    # two *ReducedShape lists hold, for each entry along the first axis, the shape of the sub-block
    # selected by the first (Nnum+1)/2 indices of axes 1 and 2 and the 1-based, inclusive
    # CAindex[0,cc]..CAindex[1,cc] range of axes 3 and 4.
    def _reducedShapes(mat, CAindex, aabbRange):
        # Slicing only creates views, so this is cheap even for the multi-GB arrays
        return [mat[cc, :aabbRange, :aabbRange,
                    CAindex[0,cc]-1:CAindex[1,cc],
                    CAindex[0,cc]-1:CAindex[1,cc]].shape
                for cc in range(mat.shape[0])]

    with h5py.File(matPath, 'r') as f:
        # Dataset.value was removed in h5py 3.0; indexing with an empty tuple
        # reads the whole dataset and works on old and new h5py versions alike.
        print('Load CAindex')
        sys.stdout.flush()
        _CAindex = f['CAindex'][()].astype('int')

        print('Load H')
        sys.stdout.flush()
        _H = f['H'][()].astype('float32')
        Nnum = _H.shape[2]
        aabbRange = int((Nnum+1)/2)
        hReducedShape = _reducedShapes(_H, _CAindex, aabbRange)

        print('Load Ht')
        sys.stdout.flush()
        _Ht = f['Ht'][()].astype('float32')
        htReducedShape = _reducedShapes(_Ht, _CAindex, aabbRange)

    return (_H, _Ht, hReducedShape, htReducedShape, _CAindex)

In [None]:
def calcPSFFT(p3, fobj, NA, x1space, scale, lam, fml, M, n):     #√
    # Compute a normalised 1D intensity profile of the PSF along x1 (with x2 = 0) for an
    # on-axis point source (p1 = p2 = 0) at axial position p3.
    #
    # For each sample the integral evaluated (by the compiled light_field_integrands module) is,
    # in the notation of the original Matlab code:
    #   integrand = @(theta) sqrt(cos(theta)) .* (1+cos(theta)) ...
    #                        .* exp(-(i*u/2) * sin(theta/2).^2 / sin(alpha/2)^2) ...
    #                        .* besselj(0, sin(theta)/sin(alpha)*v) .* sin(theta);
    #   I0 = integral(integrand, 0, alpha);
    #
    # Parameters:
    #   p3      - axial position of the point source
    #   fobj    - focal length of the objective
    #   NA, n   - combined as alpha = arcsin(NA/n); n also enters k = 2*pi*n/lam
    #   x1space - sequence of x1 positions at which to evaluate the profile
    #   scale   - unused (kept so the signature matches the Matlab code)
    #   lam     - wavelength
    #   fml     - unused (kept so the signature matches the Matlab code)
    #   M       - magnification
    # Returns: array of len(x1space) values of |Koi*I0|^2, divided by their maximum.
    k = 2*np.pi*n/lam
    alpha = np.arcsin(NA/n)
    integrandCython_r = scipy.LowLevelCallable.from_cython(light_field_integrands, 'integrandPSF_r')
    integrandCython_i = scipy.LowLevelCallable.from_cython(light_field_integrands, 'integrandPSF_i')

    # Everything below depends only on p3 and the optics, not on x1, so compute it once
    # instead of once per sample.
    u = 4*k*p3*(np.sin(alpha/2)**2)
    Koi = M/((fobj*lam)**2)*np.exp(-1j*u/(4*(np.sin(alpha/2)**2)))
    alphaFactor = np.sin(alpha/2)**(-2)
    uOver2 = u/2

    psfLine = np.zeros((len(x1space)))
    for a in tqdm(range(len(x1space))):
        x1 = x1space[a]
        # Radial distance of the sample referred back to object space.
        # (With p1 = p2 = x2 = 0 the general expression reduces to this.)
        radius = ((x1**2)**0.5)/M
        v = k*radius*np.sin(alpha)
        vFactor = v/np.sin(alpha)

        # Real and imaginary parts are integrated separately because quad() is real-valued.
        # JT: I have bumped up the subdivision limit in order to silence warnings about problems
        # in the integration. However, I suspect it probably doesn't need this level of detail...
        I0_r,err_r = scipy.integrate.quad(integrandCython_r, 0, alpha, limit=180, args=(alphaFactor, uOver2, vFactor))
        I0_i,err_i = scipy.integrate.quad(integrandCython_i, 0, alpha, limit=180, args=(alphaFactor, uOver2, vFactor))

        I0 = (I0_r + 1j*I0_i)
        psfLine[a] =  np.abs((Koi*I0)**2)
    return psfLine / np.max(psfLine)

In [None]:
# Set up the object-space sample points and compute a 1D test profile of the PSF at the largest |z|.
# The profile (psfLine) is used afterwards to estimate how large the PSF images need to be.
pixelPitch = MLPitch/Nnum # pitch of virtual pixels

# JT: not sure why these first two are created as single-element arrays 
#     - maybe a feature that they never implemented?
x1objspace = np.array([0])
x2objspace = np.array([0])
# The stop value is padded by 0.1*zspacing so that zmax itself is included despite float rounding
x3objspace = np.arange(zmin, zmax+0.1*zspacing, zspacing)
objspace = np.ones((len(x1objspace),len(x2objspace),len(x3objspace)))
# JT: I am not completely sure why, but the code has to work with at least two different z coordinates.
# Having only one ultimately leads to division-by-zero in calculating IMGSIZE_REF_IL (because p3max==0),
# but I don't follow what IMGSIZE has to do with the total of z planes we have.
# I wonder if it might be something to do with having a generous estimate of how rapidly the PSF will spread
# as a function of z coordinate...
assert(len(x3objspace) > 1)

# Largest axial distance from z=0 among the requested planes
p3max = np.max(np.abs(x3objspace))
# Oversampled x1 positions covering 20 microlens pitches, starting at the optical axis
x1testspace = (pixelPitch/OSR) * np.arange(0, Nnum*OSR*20 +1) #√  [Matlab really does start at 0]
x2testspace = [0]   # not used by any code in this notebook
psfLine = calcPSFFT(p3max, fobj, NA, x1testspace, pixelPitch/OSR, lam, d, M, n)

In [None]:
# Find where the normalised test profile first drops below 4% of its peak, and convert that
# distance (in oversampled pixels) into a whole number of microlens pitches.
outArea = np.where(psfLine<0.04)[0]
if len(outArea) == 0:  #√ [checked that this logic works]
    # Raising a bare string is itself a TypeError in Python 3, which would hide this message
    raise RuntimeError('Estimated PSF size exceeds the limit')
IMGSIZE_REF = int(np.ceil(outArea[0]/(OSR*Nnum)))

In [None]:
def calcML(fml, k, x1MLspace, x2MLspace, x1space, x2space):  #√
    # Build the complex transmission function of the whole microlens array on the
    # (x1space, x2space) grid.
    #
    # A single-lenslet phase pattern exp(-1j*k/(2*fml)*(x1^2 + x2^2)) is evaluated on
    # (x1MLspace, x2MLspace) and then replicated by convolving it with a grid of unit impulses
    # placed at the lenslet centres.
    #
    # Parameters:
    #   fml                  - microlens focal length
    #   k                    - wavenumber used in the lenslet phase
    #   x1MLspace, x2MLspace - coordinates covering one lenslet; their lengths set the lenslet pitch in samples
    #   x1space, x2space     - coordinates of the full grid; each must contain an exact 0
    # Returns: complex array of shape (len(x1space), len(x2space)).
    # Raises: IndexError if x1space or x2space contains no element equal to 0.
    x1MLspace = np.asarray(x1MLspace)
    x2MLspace = np.asarray(x2MLspace)

    def _lensletCenters(space, pitch):
        # JT: the Matlab here is a very strange code construction, but its aim appears to be to identify
        # one (any) index in the space that is ==0, and take that as one 'center'.
        # It then constructs a list of indices that represents *all* the 'centers'.
        # original Matlab:
        #   x1center = find(x1space==0);
        #   x1centerALL = [  (x1center: -x1MLdist:1)  (x1center + x1MLdist: x1MLdist :x1length)];
        #   x1centerALL = sort(x1centerALL);
        space = np.asarray(space)
        center = np.where(space == 0)[0][0]
        centers = np.append(np.arange(center, -1, -pitch),
                            np.arange(center+pitch, len(space), pitch)) #√ for python array indexing
        # np.sort returns a new array (it does not sort in place), so the result must be kept
        return np.sort(centers)

    x1centerALL_p = _lensletCenters(x1space, len(x1MLspace))
    x2centerALL_p = _lensletCenters(x2space, len(x2MLspace))

    # Squared distance from the lenslet centre, for every sample of one lenslet
    radiusSq = x1MLspace[:, np.newaxis]**2 + x2MLspace[np.newaxis, :]**2
    patternML = np.exp(-1j*k/(2*fml)*radiusSq)

    # Unit impulse at every combination of an x1 centre and an x2 centre
    MLcenters = np.zeros((len(x1space), len(x2space)))
    MLcenters[np.ix_(x1centerALL_p, x2centerALL_p)] = 1
    MLARRAY = scipy.signal.fftconvolve(MLcenters.astype('complex128'), patternML, 'same')
    return MLARRAY

In [None]:
# Define the oversampled image-plane grid, the single-lenslet grid, the list of object points
# to process, the microlens array transmission function, and the output stacks.
print('Size of PSF ~= {0} [microlens pitch]'.format(IMGSIZE_REF))
# Half-width of the image in virtual pixels: one pitch more than the estimated PSF size,
# and never less than two microlens pitches
IMG_HALFWIDTH = np.maximum(Nnum*(IMGSIZE_REF + 1), 2*Nnum)
print('Size of IMAGE = {0}x{1}'.format(IMG_HALFWIDTH*2*OSR+1, IMG_HALFWIDTH*2*OSR+1))
# Symmetric grids about 0; the +0.1 on the stop value makes the final point inclusive
x1space = (pixelPitch/OSR)*np.arange(-IMG_HALFWIDTH*OSR, IMG_HALFWIDTH*OSR+0.1, 1);   #√? not sure if this is array indexing
x2space = (pixelPitch/OSR)*np.arange(-IMG_HALFWIDTH*OSR, IMG_HALFWIDTH*OSR+0.1, 1); 
x1length = len(x1space)
x2length = len(x2space)

# Coordinates covering a single microlens (Nnum*OSR samples, centred on 0)
x1MLspace = (pixelPitch/OSR)* np.arange(-(Nnum*OSR-1)/2 , (Nnum*OSR-1)/2+0.1, 1)
x2MLspace = (pixelPitch/OSR)* np.arange(-(Nnum*OSR-1)/2 , (Nnum*OSR-1)/2+0.1, 1)
x1MLdist = len(x1MLspace)
x2MLdist = len(x2MLspace)

#%%%%%%%%%%%%%%%%%% FIND NON-ZERO POINTS %%%%%%%%%%%%%%%%%%%%%%%%%%
# objspace is all ones here, so every (x1, x2, x3) combination is selected
validpts = np.where(objspace>eqtol)
numpts = len(validpts[0])
# Matlab code:
#  [p1indALL p2indALL p3indALL] = ind2sub( size(objspace), validpts);
#  p1ALL = x1objspace(p1indALL)';
# np.where already returns one index array per axis, so no ind2sub equivalent is needed
(p1indALL, p2indALL, p3indALL) = validpts
p1ALL = x1objspace[p1indALL]
p2ALL = x2objspace[p2indALL]
p3ALL = x3objspace[p3indALL]

#%%%%%%%%%%%%%%%%%%%%%%%% DEFINE ML ARRAY %%%%%%%%%%%%%%%%%%%%%%%%% 
# Note that k0 (not k) is passed as the wavenumber here
MLARRAY = calcML(fml, k0, x1MLspace, x2MLspace, x1space, x2space)

#%%%%%%%%%%%%%%%%%%%%%% Allocate memory for storing PSFs %%%%%%%%%%%   
# One complex image per object point: the field after the microlens array (LFpsfWAVE_STACK)
# and the field before it (psfWAVE_STACK)
LFpsfWAVE_STACK = np.zeros((x1length, x2length, numpts), dtype='complex128')
psfWAVE_STACK = np.zeros((x1length, x2length, numpts), dtype='complex128')

# Note: if, when this cell is run, a warning appears about multidimensional indexing,
# this is due to an internal issue in scipy (which I think can be fixed by upgrading to the latest scipy).

# Part 1: "Projection from single point"

In [None]:
def fresnel2D(f0,dx0,z,lam):  #√
    # Propagate the 2D complex field f0 (sample spacing dx0) over a distance z at wavelength lam,
    # by multiplying its 2D FFT with a quadratic-phase transfer function.
    # Returns (f1, dx1, x1): the propagated field, its sample spacing (equal to dx0), and a
    # coordinate vector for the first axis.
    rows, cols = f0.shape
    wavenumber = 2*np.pi/lam

    # Spatial frequencies in FFT order (zero first, then positive, then negative)
    freqRows = np.fft.fftfreq(rows, d=dx0)
    freqCols = np.fft.fftfreq(cols, d=dx0)
    freqSq = freqRows[:, np.newaxis]**2 + freqCols[np.newaxis, :]**2

    transfer = np.exp(-1j*2*np.pi**2 * freqSq*z/wavenumber)
    spectrum = np.fft.fft2(f0)
    f1 = np.exp(1j*wavenumber*z)*np.fft.ifft2(spectrum * transfer)

    x1 = np.arange(-rows/2, rows/2)*dx0
    return f1, dx0, x1

In [None]:
def calcPSF_p(p1, p2, p3, fobj, NA, x1space, x2space, scale, lam, MLARRAY, fml, M, n, centerArea_p):  #√
    # Compute the complex field ("pattern") on the (x1space, x2space) grid for a point source at
    # (p1, p2, p3), then multiply it by MLARRAY and propagate it a distance fml using fresnel2D.
    #
    # Only a triangular wedge of the grid is integrated explicitly (row index a from centerArea_p[0]
    # up to the centre, column index b from a up to the centre). The remainder of the grid is filled
    # in afterwards by mirroring that wedge and by taking 90-degree rotations.
    # NOTE(review): that fill-in looks like it relies on the field being symmetric about the grid
    # centre, which would require p1 == p2 == 0 (as in the only call made in this notebook) -
    # confirm before calling with other values.
    #
    # Only the first element of centerArea_p is read (as the first row to integrate).
    # Returns (pattern, f1, pattern3D): the field before the microlens array, the field after
    # propagation through it, and the stack of the four rotated copies used to assemble pattern.
    k = 2*np.pi*n/lam
    alpha = np.arcsin(NA/n)
    x1length = len(x1space)
    x2length = len(x2space)
    zeroline = np.zeros(len(x2space), dtype='complex128')

    pattern = np.zeros((x1length, x2length), dtype='complex128')
    # 1-based (Matlab-style) index of the central sample; used below as an exclusive python stop index
    centerPT_m = int(np.ceil(len(x1space)/2))     #√Matlab indexing
    integrandCython_r = scipy.LowLevelCallable.from_cython(light_field_integrands, 'integrandPSF_r')
    integrandCython_i = scipy.LowLevelCallable.from_cython(light_field_integrands, 'integrandPSF_i')
    
    for a in tqdm(range(centerArea_p[0],centerPT_m), leave=False):   #√
        patternLine = zeroline.copy()
        # Starting b at a restricts the work to one side of the diagonal
        for b in range(a,centerPT_m):  #√
            x1 = x1space[a]
            x2 = x2space[b]
            # Despite its name this is a distance (not a squared distance), referred back to object space
            xL2normsq = (((x1+M*p1)**2+(x2+M*p2)**2)**0.5)/M

            v = k*xL2normsq*np.sin(alpha)
            u = 4*k*(p3*1)*(np.sin(alpha/2)**2)
            Koi = M/((fobj*lam)**2)*np.exp(-1j*u/(4*(np.sin(alpha/2)**2)))
            # Both the absolute and relative tolerance passed to quad() below
            tol = 1e-15
            if False:
                # Old, slow pure python code for reference
                #intgrand = @(theta) (sqrt(cos(theta))) .* (1+cos(theta))  .*  (exp(-(i*u/2)* (sin(theta/2).^2) / (sin(alpha/2)^2)))  .*  (besselj(0, sin(theta)/sin(alpha)*v))  .*  (sin(theta));
                #I0 = integral(@(theta)intgrand (theta),0,alpha);  
                def integrand(theta, alpha, u, v):
                    return (np.sqrt(np.cos(theta))) * (1+np.cos(theta))  \
                                        *  (np.exp(-(1j*u/2)* (np.sin(theta/2)**2) / (np.sin(alpha/2)**2))) \
                                        *  (scipy.special.jn(0, np.sin(theta)/np.sin(alpha)*v)) \
                                        *  (np.sin(theta))
                integrand_r = lambda theta: np.real(integrand(theta, alpha, u, v))
                integrand_i = lambda theta: np.imag(integrand(theta, alpha, u, v))
                I0_r2,err_r = scipy.integrate.quad(integrand_r, 0, alpha, limit=180,epsabs=tol,epsrel=tol)
                I0_i2,err_i = scipy.integrate.quad(integrand_i, 0, alpha, limit=180,epsabs=tol,epsrel=tol)
#                print(I0_r2, I0_i2, err_r)
            if True:
                # New fast code
                # JT: note that I see strangely incorrect results with tol=1e-10 (see "digression" cell below).
                # The three values below are the extra arguments handed to the compiled integrands
                alphaFactor = np.sin(alpha/2)**(-2)
                uOver2 = u/2
                vFactor = v/np.sin(alpha)
                # Real and imaginary parts are integrated separately
                I0_r,err_r = scipy.integrate.quad(integrandCython_r, 0, alpha, args=(alphaFactor, uOver2, vFactor),limit=180,epsabs=tol,epsrel=tol)
                I0_i,err_i = scipy.integrate.quad(integrandCython_i, 0, alpha, args=(alphaFactor, uOver2, vFactor),limit=180,epsabs=tol,epsrel=tol)
#                print(I0_r, I0_i, err_r)

            I0 = (I0_r + 1j*I0_i)
            # err is assembled here but not used afterwards
            err = (err_r + 1j*err_i)
            patternLine[b] = Koi*I0
        pattern[a,:] = patternLine

    # Mirror the computed quadrant left-to-right to fill the neighbouring quadrant
    patternA = pattern[0:centerPT_m, 0:centerPT_m];   #√
    patternAt = np.fliplr(patternA)

    pattern3D = np.zeros((pattern.shape[0], pattern.shape[1], 4), dtype='complex128');
    pattern3D[:,:,0] = pattern;
    pattern3D[:centerPT_m, centerPT_m-1:,0] = patternAt   #√
    # JT: empirically, this does rotate in the same direction as matlab (when the indexing order
    # is identical in both cases). However, it shouldn't matter because we consider all four rotations
    # and take the maximum!
    pattern3D[:,:,1] = np.rot90( pattern3D[:,:,0] , -1)
    pattern3D[:,:,2] = np.rot90( pattern3D[:,:,0] , -2)
    pattern3D[:,:,3] = np.rot90( pattern3D[:,:,0] , -3)
    # JT: unfortunately it is a pain to do the simple 'max' in python.
    # Matlab takes the maximum abs(z), whereas python silently takes the maximum real(z).
    # I can't see an obvious and tidy way to code what I need in python.
    # I think the following should work as a quick bodge, and this shouldn't be a bottleneck
    # Each step fills the entries that are still exactly zero from the next rotated copy
    pattern = pattern3D[:,:,0].copy()
    pattern[pattern == 0] = pattern3D[:,:,1][pattern == 0]
    pattern[pattern == 0] = pattern3D[:,:,2][pattern == 0]
    pattern[pattern == 0] = pattern3D[:,:,3][pattern == 0]

#    %%%%%%%%%%%%%%%%%%% CALCULATED LF PSF %%%%%%%%%%%%%%%%%%%%%%%%%%%
    # Apply the microlens array and propagate by fml; the grid spacing and coordinates returned
    # by fresnel2D are not needed here
    f1,dx1,x1 = fresnel2D(pattern*MLARRAY, scale, 1*fml, lam)

    return pattern, f1, pattern3D

In [None]:
# Compute the wave-optics PSF for every object point (here: one per z plane) and store the
# field before and after the microlens array.
centerPT_m = int(np.ceil(len(x1space)/2)) #√
halfWidth =  Nnum*(IMGSIZE_REF + 0 )*OSR
# Index range of the full central area.
# NOTE(review): this uses len(x1space)-1 as its upper limit whereas the per-plane version below
# uses len(x1space); centerArea_p itself is not read by any later code in this notebook.
centerArea_p = np.arange(np.maximum((centerPT_m - halfWidth),1)-1,          #√
                       np.minimum((centerPT_m + halfWidth),len(x1space)-1))

warnings.resetwarnings()
for eachpt in tqdm(range(numpts), desc='Computing PSFs'):
    p1 = p1ALL[eachpt]
    p2 = p2ALL[eachpt]
    p3 = p3ALL[eachpt]

    # Scale the area that is actually integrated with |z|, since the PSF is smaller closer to
    # the z=0 plane, but never go below two microlens pitches.
    IMGSIZE_REF_IL = np.ceil(IMGSIZE_REF*( np.abs(p3)/p3max))
    halfWidth_IL =  np.maximum(Nnum*(IMGSIZE_REF_IL + 0 )*OSR, 2*Nnum*OSR)
    # The np.int alias was removed in NumPy 1.24; the builtin int is the documented replacement
    centerArea_IL_p = np.arange(np.maximum((centerPT_m - halfWidth_IL),1)-1,
                              np.minimum((centerPT_m + halfWidth_IL),len(x1space)), dtype=int)   #√
    print('Plane {0}: size of center area = {1}x{2}'.format(eachpt, len(centerArea_IL_p), len(centerArea_IL_p)))

    # execute PSF computing function (switch to False to skip the slow computation while debugging)
    if True:
        t1 = time.time()
        psfWAVE, LFpsfWAVE, pattern3D = calcPSF_p(p1, p2, p3, fobj, NA, x1space, x2space, pixelPitch/OSR, lam, MLARRAY, d, M, n,  centerArea_IL_p)
        psfWAVE_STACK[:,:,eachpt]  = psfWAVE
        LFpsfWAVE_STACK[:,:,eachpt]= LFpsfWAVE
        print('Plane {0} took {1}'.format(eachpt, time.time()-t1))
    else:
        warnings.warn('Not actually computing PSF!')

### Digression: accuracy of integration
Strangely, and rather worryingly, scipy.integrate.quad seems to misbehave with certain very specific inputs. As demonstrated below, if the requested tolerance is 1e-10 the returned result can be wrong by 2% (despite quad reporting that the error is ~1e-10). I don't understand enough about what it is doing to know why on earth this might be happening! I have tightened the tolerance to 1e-15 and that seems to have made the problem go away, but it is still a little worrying not to understand why it is happening (and whether 1e-15 is definitely safe under all circumstances...).

In [None]:
# Disabled demonstration (change to "if True" to run it): integrates one specific, problematic
# integrand at a range of tolerances and plots which theta values the integrator actually sampled.
if False:
    def DemonstrateProblem(tol, thetas, vals):
        # Integrate the real part of the PSF integrand for one fixed (alpha, u, v) using tolerance
        # tol for both epsabs and epsrel, and print the tolerance, the result and the reported error.
        # Every theta the integrator evaluates, and the corresponding value, is appended to the
        # caller-supplied lists thetas and vals.
        alpha = 1.253235897503375
        u = -432.1261323834447
        v = 686.3628350636566
        def integrand(theta, alpha, u, v, thetas, vals):
            result = ((np.sqrt(np.cos(theta))) * (1+np.cos(theta))  \
                                *  (np.exp(-(1j*u/2)* (np.sin(theta/2)**2) / (np.sin(alpha/2)**2))) \
                                *  (scipy.special.jn(0, np.sin(theta)/np.sin(alpha)*v)) \
                                *  (np.sin(theta))).real
            thetas.append(theta)
            vals.append(result)
            return result
        I0_r,err_r = scipy.integrate.quad(lambda theta:integrand(theta, alpha, u, v, thetas, vals), 0, alpha, limit=180,epsabs=tol,epsrel=tol)
        print(tol, I0_r, err_r)

    # Sweep the tolerance. The same two lists are shared by every call in this sweep,
    # so they accumulate the samples from all of the tolerances.
    thetas = []
    vals = []
    for tol in [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]:
        DemonstrateProblem(tol, thetas, vals)

    # Record the sample points separately for the two tolerances that are compared in the plots
    thetas10 = []
    vals10 = []
    DemonstrateProblem(1e-10, thetas10, vals10)    
    thetas11 = []
    vals11 = []
    DemonstrateProblem(1e-11, thetas11, vals11)    

    def PlotForOrder(thetas, vals, line=True, dots=True, new=True):
        # Plot vals against thetas after sorting by theta, as a line and/or as dots,
        # optionally starting a new figure first.
        order = np.argsort(thetas)
        temp1 = np.array(thetas)[order]
        temp2 = (np.array(vals).real)[order]
        if new is True:
            plt.figure(figsize=(20,10))
        if line is True:
            plt.plot(temp1, temp2)
        if dots is True:
            plt.plot(temp1, temp2, '.')

    # There is a range from 0.6 to 0.8 where it really does not sample the function much.
    # I don't know how the algorithm is meant to work, but it seems rather implausible to me that
    # it could be possible to be that confident in the integral when there is so much going on in the function
    # that has not been sampled by the integrator at all!
    # One figure per 0.1-wide window in theta: the tol=1e-11 samples as a line,
    # with the tol=1e-10 samples overlaid as dots on the same figure
    for lim in np.arange(0, 1.2, 0.1):
        PlotForOrder(thetas11, vals11, True, False)
        PlotForOrder(thetas10, vals10, False, True, False)
        plt.xlim(lim, lim+0.1)
        plt.show()



# Part 2: Compute light field PSFs

In [None]:
def im_shift2(img, SHIFTX, SHIFTY):  #√
    # Return a copy of the 2D array img translated by SHIFTX samples along axis 0 and SHIFTY
    # samples along axis 1. Samples shifted in from outside the image are zero, and samples
    # shifted past the edge are discarded (no wrap-around).
    # Both shifts must be integer-valued (to within 1e-10); they may be negative.
    eqtol = 1e-10
    assert (np.abs(SHIFTX%1)<eqtol and np.abs(SHIFTY%1)<eqtol), 'SHIFTX and SHIFTY should be integer numbers'

    def _axisSlices(shift, length):
        # (destination slice, source slice) for one axis.
        # The source stop is written as length-shift rather than -shift so that shift == 0 works.
        if shift >= 0:
            return slice(shift, None), slice(None, length - shift)
        return slice(None, shift), slice(-shift, None)

    numRows, numCols = img.shape
    dstRows, srcRows = _axisSlices(int(round(SHIFTX)), numRows)
    dstCols, srcCols = _axisSlices(int(round(SHIFTY)), numCols)

    shifted = np.zeros_like(img)
    shifted[dstRows, dstCols] = img[srcRows, srcCols]
    return shifted

In [None]:
def pixelBinning(SIMG, OSR):  #√
    # Sum the oversampled 2D image SIMG in OSR x OSR bins, with the bins aligned so that one bin
    # is centred on the central pixel of SIMG.
    # Returns (OIMG, x1shift, x2shift): the binned image, plus a flag per axis that is 1 when the
    # first bin had to start one whole bin further in than the arithmetic first suggests, else 0.
    assert((OSR % 2) == 1)   # Should already be caught at top of script, but repeat the check here to be sure.
    numRows, numCols = SIMG.shape
    binMargin = int((OSR-1)/2)

    def _axisLayout(length):
        # All *_m values are 1-based (Matlab-style) indices.
        center_m = int((length-1)/2 + 1)
        # Start of the central bin, reduced modulo OSR to give the start of the first bin
        init_m = (center_m - binMargin) % OSR
        if init_m < 1:
            return center_m, init_m + OSR, 1
        return center_m, init_m, 0

    x1center_m, x1init_m, x1shift = _axisLayout(numRows)
    x2center_m, x2init_m, x2shift = _axisLayout(numCols)

    # Number of samples kept before the centre along axis 0; the same count is used for both axes
    halfWidth = max(x1center_m - x1init_m, 0)  #√
    # Matlab: SIMG_crop = SIMG( [ (x1init:x1center-1) x1center x1center+1:x1center+halfWidth ],  [ (x2init:x2center-1) x2center x2center+1:x2center+halfWidth ] );
    cropped = SIMG[x1init_m-1:x1center_m+halfWidth,  #√
                   x2init_m-1:x2center_m+halfWidth]

#    %%%%%%%%%%%%%%%%%% PIXEL BINNING  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    # JT: I am not totally certain I am doing the same as the matlab,
    # but this achieves what I think I would expect the matlab to do!
    # Split each axis into (bin index, position within bin) and sum over the within-bin axes.
    cropRows, cropCols = cropped.shape
    blocks = cropped.reshape(cropRows // OSR, OSR, cropCols // OSR, OSR)
    OIMG = blocks.sum(axis=(1,3))

    return OIMG, x1shift, x2shift

In [None]:
# For every lateral offset (a, b) of the point source within one microlens and every z plane c,
# shift the pre-microlens field, pass it through the microlens array, shift back, convert to
# intensity, bin down to sensor pixels and store the central region in H.
x1objspace = (pixelPitch/M)*np.arange(-np.floor(Nnum/2), np.floor(Nnum/2)+0.1, 1)
x2objspace = x1objspace.copy()
# 1-based index of the central lateral position
XREF = np.ceil(len(x1objspace)/2)
YREF = np.ceil(len(x1objspace)/2)
# Indices (in binned pixels) of the region that is kept.
# The np.int alias was removed in NumPy 1.24; the builtin int is the documented replacement.
CP_p = np.arange((centerPT_m-1)/OSR - halfWidth/OSR, (centerPT_m-1)/OSR + halfWidth/OSR +0.1, 1, dtype=int) #√For Python indexing
H = np.zeros((len(CP_p), len(CP_p), len(x1objspace), len(x2objspace), len(x3objspace) ))

if True:
    for i in tqdm(range(len(x1objspace)*len(x2objspace)*len(x3objspace)), desc='Computing LF PSFs'):
        (a, b, c) = np.unravel_index(i, (len(x1objspace), len(x2objspace), len(x3objspace)))
        psfREF = psfWAVE_STACK[:,:,c]
        # Move the point source by a whole number of virtual pixels (OSR oversampled samples each)
        psfSHIFT = im_shift2(psfREF, OSR*(a+1-XREF), OSR*(b+1-YREF) )   #√ switched to allow for a,b in python
        f1,dx1,x1 = fresnel2D(psfSHIFT*MLARRAY, pixelPitch/OSR, d,lam)
        # Undo the shift so that every PSF is stored relative to the same origin
        f1 = im_shift2(f1, -OSR*(a+1-XREF), -OSR*(b+1-YREF) )     #√ switched to allow for a,b in python

        xmin_p =  np.maximum( centerPT_m-1  - halfWidth, 0) #√
        xmax_p =  np.minimum( centerPT_m-1  + halfWidth, f1.shape[0]-1) #√
        ymin_p =  np.maximum( centerPT_m-1  - halfWidth, 0) #√
        ymax_p =  np.minimum( centerPT_m-1  + halfWidth, f1.shape[1]-1) #√

        # Zero everything outside the central window before binning
        f1_AP = np.zeros_like(f1)
        f1_AP[xmin_p:xmax_p+1,ymin_p:ymax_p+1] = f1[xmin_p:xmax_p+1,ymin_p:ymax_p+1]   #√
        [f1_AP_resize, x1shift, x2shift] = pixelBinning(np.abs(f1_AP**2), OSR)
        # JT: I had to split this up into two separate commands to make it work in Python
        temp = f1_AP_resize[ CP_p - x1shift, : ]   #√
        f1_CP = temp[ :, CP_p-x2shift ]   #√
        H[:,:,a,b,c] = f1_CP
    # Take a copy of H before the maximum was calculated
    H_premax = H.copy()
    Hmax = np.max(H)
    H = H/Hmax
else:
    warnings.warn('Skipping LF PSF calculation, and using H from .mat file')
    # NOTE(review): _H is not defined by any cell in this notebook; presumably it has to be
    # loaded by hand with LoadRawMatrixData before this branch is enabled - confirm.
    H = _H.T.copy()

In [None]:
# Redefine the coordinate grids at the (un-oversampled) virtual pixel pitch, restricted to the
# region that was kept in H, then zero out the weak tails of each z plane.
x1space = (pixelPitch/1)*np.arange(-IMG_HALFWIDTH*1, IMG_HALFWIDTH*1+0.1, 1);
x2space = x1space.copy()
x1space = x1space[CP_p]
x2space = x2space[CP_p]

if True:
    # Force very small values (in each separate plane) to zero
    # (threshold is relative to the maximum of that plane)
    tol = 0.005
    # JT TODO: I think this may be slow due to the copying back at the end.
    # I am almost certain that that is unnecessary in python, and could be removed.
    # (or actually, is it really any sort of bottleneck? I think not...)
    for i in tqdm(range(H.shape[4]), desc='clipping to zero'):
        # Basic slicing gives a view, so the masked assignment below already modifies H in place
        H4Dslice = H[:,:,:,:,i]
        H4Dslice[H4Dslice < (tol*np.max(H4Dslice))] = 0
        H[:,:,:,:,i] = H4Dslice
else:
    warnings.warn('Not clipping to zero')

In [None]:
# Store H in single precision
H = H.astype('float32')

#%%%%%%%%%%%%%%%%% Estimate PSF size again  %%%%%%%%%%%%%%%%%%%%%%%
# JT: I *DO* want to save CAindex in 1-based MATLAB indexing.
#     My python deconvolution code expects that (since we are just loading .mat files...),
#     and my deconvolution code will take that into account.
# For each z plane, CAindex holds the first and last (1-based, inclusive) pixel index of the
# central area: the PSF size estimate scaled by |z|/p3max, but at least two microlens pitches,
# clamped to the extent of H.
centerCP_m = np.ceil(len(CP_p)/2)
CAindex = np.zeros((2,len(x3objspace)), dtype='int')
for i in range(len(x3objspace)):
    IMGSIZE_REF_IL = np.ceil(IMGSIZE_REF*( np.abs(x3objspace[i])/p3max))
    halfWidth_IL =  np.maximum(Nnum*(IMGSIZE_REF_IL + 0 ), 2*Nnum)
    CAindex[0,i] = np.maximum( centerCP_m - halfWidth_IL , 1)
    CAindex[1,i] = np.minimum( centerCP_m + halfWidth_IL , H.shape[0])

In [None]:
# Free up memory from variables we have now finished with
# Change to "if True" to actually delete the large intermediate arrays;
# as written, only their sizes are reported.
if False:
    del f1
    del f1_AP
    del f1_AP_resize
    del f1_CP
    del psfREF
    del psfSHIFT
    del LFpsfWAVE_STACK
    del psfWAVE_STACK
else:
    warnings.warn('Not freeing up variables')
    # For small |z| these array sizes are not that big.
    # It's possible, though, that at larger |z| these really do take up a substantial amount of space
    def PrintSizes(*args):
        # Print the size in bytes (element count times bytes per element) of each array given
        for a in args:
            print(a.size*a.itemsize)
    PrintSizes(f1, f1_AP, f1_AP_resize, f1_CP, psfREF, psfSHIFT, LFpsfWAVE_STACK, psfWAVE_STACK)

# Part 3: calculate Ht
This appears to be correct now. If I run this code here, starting with H loaded from a real Matlab .mat file, the result I get for Ht (or, at least, for certain PSFs selected from it) is basically identical to the Ht stored in the original .mat file.

In [None]:
# This cell contains old code, now superseded by MUCH faster code I have written to do the same thing
# (there is no need to do the convolution - or indeed the longhand rotation - to get the elements we need for Ht)
def Rotate180(img):
    # Return a new array holding img rotated by 180 degrees,
    # i.e. with both of its first two axes reversed.
    return img[::-1, ::-1].copy()

def backwardProject(H, projection, Nnum, ccRange=None):
    # Back-project the 2D image `projection` through the 5D PSF array H.
    # For each lateral offset (aa, bb) and each z plane cc, the projection is convolved with the
    # 180-degree-rotated PSF H[:,:,aa,bb,cc], and only the output pixels on the sub-lattice
    # (aa::Nnum, bb::Nnum) are accumulated.
    # ccRange optionally restricts which z planes are processed (default: all of them).
    # The result has shape (projection.shape[0], projection.shape[0], H.shape[4]).
    x3length = H.shape[4]
    planes = range(x3length) if ccRange is None else ccRange
    side = projection.shape[0]
    Backprojection = np.zeros((side, side, x3length))
    for aa in tqdm(range(Nnum), leave=False):
        rowLattice = slice(aa, None, Nnum)
        for bb in range(Nnum):
            colLattice = slice(bb, None, Nnum)
            for cc in planes:
                flippedPSF = Rotate180(H[:,:,aa,bb,cc])
                convolved = scipy.signal.fftconvolve(projection, flippedPSF, 'same')
                Backprojection[rowLattice, colLattice, cc] += convolved[rowLattice, colLattice]
    return Backprojection

def calcHt(H):
    # Compute Ht from H the slow way (kept for reference): for every lateral offset (aa, bb),
    # back-project an image containing a single unit pixel near the image centre, cut out the
    # central region with the same extent as H, and shift it so that the result is centred.
    # H has shape (Hsize1, Hsize1, Nnum, Nnum, x3length); the returned Ht has the same shape and dtype.
    Hsize1,_,Nnum,_,x3length = H.shape
    tmpsize = int(np.ceil(Hsize1/Nnum))
    # Pad to an odd number of microlens pitches.
    # (There must be no trailing comma on these assignments: a trailing comma would turn
    # imgsize into a 1-tuple and break np.zeros below.)
    if (tmpsize % 2) == 1:
        imgsize = (tmpsize+2)*Nnum
    else:
        imgsize = (tmpsize+3)*Nnum

    zeroprojection = np.zeros((imgsize, imgsize))
    imcenter_m = int(np.ceil(imgsize/2))
    imcenterinit_m = imcenter_m - int(np.ceil(Nnum/2))
    # Bounds of the central cut-out, and the reference used to re-centre each offset
    cutStart = imcenter_m - int((Hsize1-1)/2) - 1
    cutStop = imcenter_m + int((Hsize1-1)/2)
    shiftRef = int(np.ceil(Nnum/2))

    Ht = np.zeros_like(H)
    for aa in tqdm(range(Nnum)):
        for bb in tqdm(range(Nnum), leave=False):
            temp = zeroprojection.copy()
            temp[imcenterinit_m+aa, imcenterinit_m+bb] = 1  #√
            tempback = backwardProject(H, temp, Nnum)
            tempback_cut = tempback[cutStart:cutStop, cutStart:cutStop]  #√
            for cc in range(x3length):
                Ht[:,:,aa,bb,cc] = im_shift2(tempback_cut[:,:,cc], shiftRef-aa-1, shiftRef-bb-1) #√
    return Ht

In [None]:
# This cell contains my new code (way faster)
def backwardProject_new(H, projection, Nnum, imcenterinit_mp, _aa, _bb):
    # Fast equivalent of back-projecting an image that contains a single unit pixel at
    # (imcenterinit_mp+_aa, imcenterinit_mp+_bb): instead of convolving, the required samples of
    # the rotated PSFs are copied straight to where the convolution would have put them.
    # `projection` is only consulted for its size.
    #
    # Note that with imcenterinit_mp the _mp suffix is a reminder that the same number can
    # work for both python and matlab. Since we add aa to it to get a pixel index, that works
    # in either language, since aa will start from 0 and 1 in the respective languages.
    x3length = H.shape[4]
    side = projection.shape[0]
    Backprojection2 = np.zeros((side, side, x3length))
    for aa in range(Nnum):
        # First row of the rotated PSF that the aa::Nnum sampling of the output would pick up
        rowPhase = (aa - _aa) % Nnum
        for bb in range(Nnum):
            colPhase = (bb - _bb) % Nnum
            # No need to call Rotate180 - we can just mirror the axes and that has the same effect
            flipped = H[::-1,::-1,aa,bb]
            # The central pixel of the PSF is expected to sit at a multiple of Nnum
            halfRows = int(flipped.shape[0]/2)
            halfCols = int(flipped.shape[1]/2)
            assert((halfRows % Nnum) == 0)
            kept = flipped[rowPhase::Nnum, colPhase::Nnum]
            # The *middle pixel* of the PSF lands on the unit pixel, so sample 'rowPhase' lands here:
            rowDest = imcenterinit_mp + _aa - halfRows + rowPhase
            colDest = imcenterinit_mp + _bb - halfCols + colPhase
            rowStop = rowDest + kept.shape[0]*Nnum
            colStop = colDest + kept.shape[1]*Nnum
            Backprojection2[rowDest:rowStop:Nnum, colDest:colStop:Nnum] = kept
    return Backprojection2

def calcHt_new(_H):
    # Compute Ht from _H using backwardProject_new, which avoids the convolutions that calcHt performs.
    # For every lateral offset (aa, bb) a single unit pixel near the image centre is back-projected,
    # the central region with the same extent as _H is cut out, and it is shifted so that the
    # result is centred.
    # _H has shape (Hsize1, Hsize1, Nnum, Nnum, x3length); the returned Ht has the same shape and dtype.
    # NOTE(review): an earlier header comment here said this was "not currently working", whereas
    # the notebook text for this section reports agreement with the Matlab Ht - confirm it is current.
    Hsize1,_,Nnum,_,x3length = _H.shape
    tmpsize = int(np.ceil(Hsize1/Nnum))
    # Pad to an odd number of microlens pitches.
    # (There must be no trailing comma on these assignments: a trailing comma would turn
    # imgsize into a 1-tuple and break np.zeros below.)
    if (tmpsize % 2) == 1:
        imgsize = (tmpsize+2)*Nnum
    else:
        imgsize = (tmpsize+3)*Nnum

    zeroprojection = np.zeros((imgsize, imgsize))
    imcenter_m = int(np.ceil(imgsize/2))
    imcenterinit_m = imcenter_m - int(np.ceil(Nnum/2))
    # Bounds of the central cut-out, and the reference used to re-centre each offset
    cutStart = imcenter_m - int((Hsize1-1)/2) - 1
    cutStop = imcenter_m + int((Hsize1-1)/2)
    shiftRef = int(np.ceil(Nnum/2))

    Ht = np.zeros_like(_H)
    for aa in tqdm(range(Nnum)):
        for bb in tqdm(range(Nnum), leave=False):
            temp = zeroprojection.copy().astype('float32')
            temp[imcenterinit_m+aa, imcenterinit_m+bb] = 1  #√ same arithmetic works for me, because aa starts at 0 instead of 1
            # Use the matrix that was passed in, not whatever the global H happens to be
            tempback = backwardProject_new(_H, temp, Nnum, imcenterinit_m, aa, bb)
            tempback_cut = tempback[cutStart:cutStop, cutStart:cutStop]  #√
            for cc in range(x3length):
                Ht[:,:,aa,bb,cc] = im_shift2(tempback_cut[:,:,cc], shiftRef-aa-1, shiftRef-bb-1) #√
    return Ht

In [None]:
#%%%%%%%%%%%% Calculate Ht (transpose for backprojection) %%%%%%%%%
# Compute Ht from H. The hard-coded conditions below select the implementation:
# the first branch is the slow reference version, the second is the fast version that is used,
# and the last one skips the calculation and just copies H.
print('Computing Transpose (3/3)')
if False:
    Ht = calcHt(H)
elif True:
    # JT: my new code, massively faster
    Ht = calcHt_new(H)
else:
    warnings.warn('Not computing transpose!')
    Ht = H.copy()
# Store Ht in single precision
Ht = Ht.astype('float32')

# Save the matrices we have generated

In [None]:
def SaveMatrices(_H, _Ht, _CAindex, Nnum, matPathStem):
    # Save a .mat file just like the Matlab code does.
    # In fact, my deconvolution code will take that and convert it to my own format,
    # but to avoid confusion(?) I just generate the .mat file here.
    # My deconvolution code will auto-generate the files it actually needs, when it sees they don't exist yet.
    #
    # _H, _Ht and _CAindex are the arrays to write and Nnum is the value stored under 'Nnum'.
    # The file is written to '<matPathStem>_<MatrixFileString()>.mat'.
    # All other stored parameters (M, NA, MLPitch, OSR, n, fml, lam, zmax, zmin, zspacing, fobj, d,
    # x3objspace, pixelPitch, Hmax) are still read from the module-level globals.
    matPath = '%s_%s.mat'%(matPathStem, MatrixFileString())
    print('Saving to', matPath)
    with h5py.File(matPath, 'w') as f:
        print('Write parameters')
        f['M'] = M
        f['NA'] = NA
        f['MLPitch'] = MLPitch
        f['Nnum'] = Nnum
        f['OSR'] = OSR
        f['n'] = n
        f['fml'] = fml
        f['lambda'] = lam
        f['zmax'] = zmax
        f['zmin'] = zmin
        f['zspacing'] = zspacing

        f['fobj'] = fobj
        f['d'] = d
        f['x3objspace'] = x3objspace
        f['pixelPitch'] = pixelPitch
        f['CAindex'] = _CAindex

        f['Hmax'] = Hmax   # JT: maximum value of H prior to normalization. Useful for fusing multiple z ranges.

        # Matlab works with Fortran-contiguous arrays, and writes them to disk as such.
        # Consequently, if I want my arrays (same indexing order but different contiguity)
        # to look exactly the same when saved to disk, I need to save their transpose.
        print('Write H'); sys.stdout.flush()
        print('H shape', _H.shape)
        f['H'] = _H.T
        print('Write Ht'); sys.stdout.flush()
        f['Ht'] = _Ht.T
        print('Ht shape', _Ht.shape)
        print('Done'); sys.stdout.flush()

SaveMatrices(H, Ht, CAindex, Nnum, 'PSFMatrix/expt')

# Check my results against the Matlab ones

In [None]:
# Theirs:
# Load both versions of the matrices for the current parameter set.
# NOTE(review): these are absolute paths on one particular machine, and the second one is not
# literally the relative path that SaveMatrices was given above - adjust both before running.
# Theirs:
(_H1, _Ht1, hReducedShape1, htReducedShape1, _CAindex1) = LoadRawMatrixData('/Volumes/Development/SIsoftware/PSFmatrix/PSFmatrix_%s.mat'%MatrixFileString())
# Mine:
(_H2, _Ht2, hReducedShape2, htReducedShape2, _CAindex2) = LoadRawMatrixData('/Volumes/Development/light-field-flow/PSFmatrix/expt_%s.mat'%MatrixFileString())

In [None]:
# Largest absolute element-wise difference between the two versions of H, and of Ht
print(np.max(np.abs(_H1 - _H2)))
print(np.max(np.abs(_Ht1 - _Ht2)))