# FLFM Forward and Backward Projector Generator
__Requirements__
- Dependencies
    - OpenCV, Numpy, Torch, Matplotlib, gc
- Required Information
    - Camera related information
    - An image of the lenslets (or an image that shows their outlines)
    
__Warning__
- The methods used in the following code may generate inaccurate results. Consider using experimentally measured PSFs.

In [2]:
import cv2 as cv
import numpy as np
import torch
import matplotlib.pyplot as plt
import gc

In [5]:
# Notebook plotting setup: render figures inline and raise the default figure DPI to 300.
%matplotlib inline
plt.rcParams['figure.dpi'] = 300

# 0. Camera Information
(In the future, this will be done by reading a .txt or a .csv file)

In [7]:
# Set-up specific constants.
# Lengths marked "microns" are stated in the original notes; the other lengths are
# presumably microns too (TODO confirm).

# Objective and intermediate optics
NA = 0.4                 # numerical aperture
fobj = 10_000            # objective focal length (presumably microns)
f1 = f2 = 1              # focal lengths f1 and f2
fm = np.array([47_000])  # presumably the microlens focal length(s), kept as an array

# Microlens array / sensor geometry
mla2sensor = 47_000      # MLA-to-sensor distance, microns
lenspitch = 2_520        # lens pitch, microns
pixel_size = 3.54        # pixel pitch, microns/px

# Medium and illumination
refractive_index = 1     # refractive index of the medium
wavelength = 0.553       # presumably microns (553 nm) - TODO confirm

# Lenslet grid layout
noLensHoriz = noLensVert = 3
spacingPixels = 777

In [None]:
def FLFM_setCameraParams(config):
    """Derive the optical parameters of the FLFM set-up from a flat config list.

    Parameters
    ----------
    config : sequence
        Indexed as ``[NA, fobj, f1, f2, fm, mla2sensor, lenspitch, pixel_size,
        wavelength, refractive_index, noLensHoriz, noLensVert, spacingPixels,
        horizOffset, vertOffset, shiftRow, gridRot]``. Only entries 0-4, 6, 8
        and 9 are read here. All lengths must share one unit.

    Returns
    -------
    list
        ``[objRad, k, M, d_refract, fsRad, fovRad]``:

        - objRad: objective radius, ``NA * fobj``
        - k: wave number, ``2*pi*refractive_index / wavelength``
        - M: magnification, ``fm * f1 / (f2 * fobj)``
        - d_refract: hard-coded ``3.5e-3`` (origin undocumented)
        - fsRad: field-stop radius, ``lenspitch * f2 / (2 * fm)``
        - fovRad: field-of-view radius in object space, ``fsRad * fobj / f1``
          (equivalently ``lenspitch / (2 * M)``)
    """
    NA, fobj, f1, f2, fm = config[0], config[1], config[2], config[3], config[4]
    lenspitch = config[6]
    wavelength, refractive_index = config[8], config[9]

    objRad = NA * fobj
    k = 2 * np.pi * refractive_index / wavelength
    # f2 sits next to the field stop (see fsRad below), so it divides the
    # magnification while f1 multiplies it. This keeps M, fsRad and fovRad
    # mutually consistent. The formulas coincide with the previous ones when f1 == f2.
    M = (fm * f1) / (f2 * fobj)
    # NOTE(review): previously labelled "Index of Refraction", but 3.5e-3 cannot be
    # one; its origin and meaning are unknown - confirm before relying on it.
    d_refract = 3.5e-3
    # Field stop sized so that its image through fm/f2 is half a lens pitch.
    fsRad = lenspitch * f2 / (2 * fm)
    # Back-project the field stop into object space through the fobj/f1 stage.
    # The former `fsRad / f1` was dimensionless, not a radius.
    fovRad = fsRad * fobj / f1
    return [objRad, k, M, d_refract, fsRad, fovRad]

# Example: camera_params = FLFM_setCameraParams(config)

In [None]:
def resolution(camera_params, depth_step, perspectives_data=None, pixel_pitch=None):
    """Compute sensor and object-space sampling for the reconstruction grid.

    Parameters
    ----------
    camera_params : sequence
        Output of ``FLFM_setCameraParams``. Index 2 (magnification) and
        index 5 (field-of-view radius) are read.
    depth_step : float
        Axial voxel size.
    perspectives_data : array-like, optional
        4-D stack indexed (frame, lenslet, row, col). Defaults to the
        module-level ``perspectives`` built later in this notebook.
    pixel_pitch : float, optional
        Sensor pixel size. Defaults to the module-level ``pixel_size``.

    Returns
    -------
    list
        ``[lenslet_pixels, sensor_res, object_res, fovRadVox]`` where
        ``lenslet_pixels`` is ``[columns, rows]`` of one lenslet window,
        ``sensor_res`` is the pixel pitch,
        ``object_res`` is ``[pitch / M, pitch / M, depth_step]`` and
        ``fovRadVox`` is a list with the FOV radius divided by each
        ``object_res`` entry.
    """
    if perspectives_data is None:
        perspectives_data = perspectives
    if pixel_pitch is None:
        pixel_pitch = pixel_size

    # Number of pixels behind one lenslet: [columns, rows].
    lenslet_pixels = [perspectives_data.shape[3], perspectives_data.shape[2]]
    sensor_res = pixel_pitch
    magnification = camera_params[2]
    object_res = [pixel_pitch / magnification, pixel_pitch / magnification, depth_step]
    # A list (not a one-shot generator) so callers can index and re-iterate it.
    fovRadVox = [camera_params[5] / step for step in object_res]
    return [lenslet_pixels, sensor_res, object_res, fovRadVox]

# example: res = resolution(camera_params, 10)

# 1. Hough Transform for Circle Detection 
(This step does not need to be rerun every time; rerun it only if the mask has changed for some ungodly reason.)

In [None]:
# Load the lenslet outline image as a single-channel grayscale image.
mask_path = r'./test_images/mask1.tif'
mask_img = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
if mask_img is None:
    # cv.imread reports a missing/unreadable file by returning None rather than raising.
    raise FileNotFoundError(f"Could not read mask image: {mask_path}")
# Smallest circle radius accepted by the Hough search: one ninth of the image width.
min_radius = int(mask_img.shape[1] / 9)

In [None]:
# Initial Hough circle detection, rounded to the nearest integer/pixel.
# dp=1.2, minDist=1.5*min_radius; for cv.HOUGH_GRADIENT, param1 is the upper Canny
# threshold and param2 the accumulator threshold.
circles = cv.HoughCircles(mask_img, cv.HOUGH_GRADIENT, 1.2, 1.5*min_radius, param1=70, param2=25, minRadius=min_radius, maxRadius=3*min_radius)
if circles is None:
    # HoughCircles returns None (not an empty array) when nothing is detected.
    raise RuntimeError("cv.HoughCircles found no circles in the mask image; check the image or the detection parameters.")
circles = np.uint16(np.around(circles))
# Drop the leading singleton axis -> one (x, y, r) row per detected circle.
circles = np.reshape(circles, (circles.shape[1], circles.shape[2]))

# Centroid of all detected circles and their mean radius.
circle_x_avg = np.mean(circles[:,0])
circle_y_avg = np.mean(circles[:,1])
circle_r_avg = np.mean(circles[:,2])

# Offset of every circle from the centroid (float arithmetic, so uint16 cannot wrap).
offset_x = circles[:,0] - circle_x_avg
offset_y = circles[:,1] - circle_y_avg
distances = np.sqrt(offset_x**2 + offset_y**2)
theta = np.arctan2(offset_y, offset_x)

# The circle closest to the centroid is the center lenslet; keep only the outer ones.
center_index = np.argmin(distances)
theta = np.sort(np.delete(theta, center_index))  # outer-circle angles, ascending
distances = np.delete(distances, center_index)

circle_theta_avg = np.mean(theta)  # mean angle of the outer circles (not their spacing)
# Average angular spacing between consecutive outer circles: the angular span of the
# sorted angles divided by the number of gaps (uses the angle values, not argmax/argmin indices).
circle_theta_avg2 = (theta[-1] - theta[0]) / (len(theta) - 1) if len(theta) > 1 else np.nan
circle_dist_avg = np.mean(distances)

# Regularised layout, always 2-D with shape (1 + n_outer, 3):
# row 0 is the center lenslet placed at the centroid; the outer lenslets sit on a circle
# of radius circle_dist_avg at their measured angles. Every row uses the mean radius.
circles_corrected = np.vstack((
    [circle_x_avg, circle_y_avg, circle_r_avg],
    np.column_stack((
        circle_x_avg + circle_dist_avg*np.cos(theta),
        circle_y_avg + circle_dist_avg*np.sin(theta),
        np.full(len(theta), circle_r_avg),
    )),
))

In [None]:
center_lenslet = 0  # row of circles_corrected that holds the center lenslet

In [None]:
# Pairwise geometry of the corrected lenslet centers, built by broadcasting:
# element [i, j] compares center i with center j.
pair_offsets = circles_corrected[:, None, :2] - circles_corrected[None, :, :2]
lenslet_distances_og = np.sqrt(pair_offsets[..., 0]**2 + pair_offsets[..., 1]**2)  # Euclidean distance
lenslet_angles = np.arctan2(pair_offsets[..., 1], pair_offsets[..., 0])  # angle of (center i - center j)
del pair_offsets

# Offset-1 diagonal: each row of circles_corrected paired with the following row.
lenslet_distances = np.diagonal(lenslet_distances_og, offset=1)
lenslet_angles = np.diagonal(lenslet_angles, offset=1)
lenslet_distance_avg = np.mean(lenslet_distances)

In [None]:
# Disk mask (1.0 inside, 0.0 outside) covering one lenslet window. The radius of the first
# row of circles_corrected is used for every lenslet; pixel (row, col) is inside when its
# distance to (radius, radius) is at most radius.
mask_r = circles_corrected[0, 2]
mask_size = int(2 * mask_r + 1)
pixel_idx = np.arange(mask_size)
grid_rows, grid_cols = np.meshgrid(pixel_idx, pixel_idx, indexing='ij')
lenslet_circle_mask = (np.sqrt((grid_rows - mask_r)**2 + (grid_cols - mask_r)**2) <= mask_r).astype(np.float64)
del mask_r, mask_size, pixel_idx, grid_rows, grid_cols

mask_tensor = torch.from_numpy(lenslet_circle_mask)

In [None]:
def _lenslet_window(frame, center_x, center_y, radius, size):
    """Return a ``size`` x ``size`` float32 patch of ``frame`` with its top-left corner
    at (floor(center_y - radius), floor(center_x - radius)).

    Parts of the window that fall outside ``frame`` stay zero, so lenslets touching the
    image border keep their alignment with the circular mask. ``frame`` is indexed
    [row, col] = [y, x], whereas the circle centers are given as (x, y); that is why
    y selects rows and x selects columns.
    """
    top = int(np.floor(center_y - radius))
    left = int(np.floor(center_x - radius))
    window = torch.zeros(size, size)
    # Intersection of the window with the frame, in frame coordinates (clamped so a
    # negative start can no longer wrap around to the far edge of the frame).
    row0, row1 = max(top, 0), min(top + size, frame.shape[0])
    col0, col1 = max(left, 0), min(left + size, frame.shape[1])
    if row0 < row1 and col0 < col1:
        patch = np.asarray(frame[row0:row1, col0:col1])
        window[row0 - top:row1 - top, col0 - left:col1 - left] = torch.as_tensor(patch, dtype=torch.float32)
    return window


# NOTE(review): `frames` is not defined in the cells shown here; presumably an earlier step
# provides a sequence of 2-D grayscale images - TODO confirm.
window_size = int(2*circles_corrected[0,2]+1)
# perspectives[frame, lenslet, row, col]: the lenslets act as the "channels" of each frame.
perspectives = torch.zeros(int(len(frames)), int(len(circles_corrected[:,0])), window_size, window_size)

for i in range(len(frames)):
    for j in range(len(circles_corrected[:,0])):
        center_x, center_y, radius = circles_corrected[j]
        placeholder = _lenslet_window(frames[i], center_x, center_y, radius, window_size)
        # Zero everything outside the lenslet's circular aperture.
        perspectives[i, j, :, :] = torch.mul(placeholder, mask_tensor)

In [None]:
# Persist the perspective stack. Create the output folder first because torch.save does not
# create missing directories; the folder name suggests it is kept out of version control.
from pathlib import Path

perspectives_path = Path('./models_DONOTCOMMIT/perspectives.pt')
perspectives_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(perspectives, perspectives_path)