In [None]:
import torch
import numpy as np

# Filters

## Wiener-Butterworth

In [None]:
def psf_WB_filter(psf, alpha, beta, pixelSize, depth_step, butterworth_order):
    # Assuming that there is only one peak in the lenslet
    # Frames is the input frames, given as a 

    if len(psf.shape) == 3:
        psf = psf.unsqueeze(0) # Add a dimension for the lenslets, since only one lenslet was given

    # input PSF size and center
    _, Sz, Sx, Sy = psf.shape
    Scx = (Sx+1)/2;
    Scy = (Sy+1)/2
    Scz = (Sz+1)/2
    Sox = round((Sx+1)/2)
    Soy = round((Sy+1)/2)
    Soz = round((Sz+1)/2)
    
    # Pixel size in Fourier domain
    px = 1/Sx
    py = 1/Sy
    pz = 1/Sz
    
    # Remember that the psf is in the format of [lenslet, depth, x, y]
    # Permute the PSF to the format of [lenslet, x, y, depth]
    psf = torch.permute(psf, (0, 2, 3, 1)) # [lenslet, x, y, depth]
    
    PSF_bp = torch.empty(psf.shape)
    
    psf_flip = torch.flip(psf, (1,2,3))
    OTF_flip = torch.fft.fftn(psf_flip, dim = (1,2,3))
    OTF_abs = torch.fft.fftshift(torch.abs(OTF_flip), dim = (1,2,3))
    M = torch.amax(OTF_abs, (1,2), keepdim = True) # find maximum value and position
    OTF_abs_norm = torch.div(OTF_abs, M)
    
    # Create Wiener filter
    OTF_flip_norm = OTF_flip/M
    OTF_Wiener = torch.div(OTF_flip_norm,(torch.pow(torch.abs(OTF_flip_norm),2) + alpha))
    
    # Calculate Cut-off Gain for Wiener Filter
    OTF_Wiener_abs = torch.fft.fftshift(torch.abs(OTF_Wiener))
    
    # tplane = abs(squeeze(OTF_Wiener_a   bs(:,:,Soz))); % central slice
    tplane = torch.abs(OTF_Wiener_abs[:,:,Soz]) # central slice
    tline, _ = torch.max(tplane, dim=0, keepdim=False) # Should return maximums along each row, call each depth
    
    w = np.power((np.tile(np.arange(0,Sx), (Sy,1)).T - Scx),2)+ np.power((np.tile(np.arange(0,Sy),(Sx,1)) - Scy),2)
    w = np.broadcast_to(w[...,None],w.shape+(Sz,)) + np.power(np.tile(np.arange(0,Sz).reshape(1,1,Sz),(Sx,Sy,1)) - Scz, 2) # repeat Sz in the 3rd dimension
    w = torch.from_numpy(w)
            
    if len(psf.shape) == 3:
        psf = psf.unsqueeze(0) # Add a dimension for the lenslets, since only one lenslet was given
    
    PSF_bp = torch.empty(psf.shape)
    mask = torch.empty(Sx,Sy,Sz)

    # Loop through all the lenslets
    for j in range(int(psf.shape[0])):
        # Loop through all the depths
        for i in range(int(psf.shape[3])):
            psf_numpy = psf[j,:,:,i].numpy()
            # Grab the x and y coordinates of the maximum value of the PSF
            [x,y] = np.where(psf_numpy == np.max(psf_numpy))
            
            # Grab the x and y slices of the PSF that contain the maximum value
            psf_x = psf_numpy[:,y]
            psf_y = psf_numpy[x,:]
            
            # Find the indices that are near the half-maximum of the PSF
            idx_x = (np.abs(psf_x - (psf_numpy[x,y]/2))).argmin()
            idx_y = (np.abs(psf_y - (psf_numpy[x,y]/2))).argmin()
            
            # Set the resolution cutoff by setting the resolution to the FWHM
            resx = np.abs(x - idx_x) * pixelSize
            resy = np.abs(y - idx_y) * pixelSize
            resz = depth_step # The MatLab code assumes that the resolution is the same in all directions
            
            # Frequency cutoff in terms of pixels
            tx = 1/resx/px
            ty = 1/resy/py
            tz = 1/resz/pz
            
            # to1 = max(round(Scx -tx), 1); to2 = min(round(Scx+tx), Sx);
            to1 = max(np.round(Scx - tx), 1)
            to2 = min(np.round(Scx + tx), Sx)
            
            # beta_wienerx = (tline(to1) + tline(to2))/2; % OTF frequency intensity at cutoff:x
            beta_wienerx = (tline[:,i][to1] + tline[:,i][to2])/2 # OTF frequency intensity at cutoff:x

            ee = beta_wienerx/(beta**2) - 1
            
            mask[:,:,i] = torch.div(1, torch.sqrt(1 + torch.mul(ee,(torch.pow(w[:,:,i],butterworth_order))))) # w^n = (kx/kcx)^pn
            mask[:,:,i] = torch.fft.ifftshift(mask[:,:,i]) # Butterworth Filter

        # Create Wiener-Butteworth Filter
        OTF_bp = torch.mul(mask,OTF_Wiener) # Final OTF_bp cutfoff gain: beta
            
        PSF_bp[j,:,:,:] = torch.fft.fftshift(torch.real(torch.fft.ifftn(OTF_bp)))# final OTF_bp cutfoff gain: beta
        
    return PSF_bp