In [2]:
import torch
import os
import gc
import numpy as np
import cv2 as cv
import math

# Auxiliary Functions

## Images to PSF

In [20]:
def img_to_pt(path):
    """Load every image in a folder as grayscale and stack them into a tensor.

    Files are visited in sorted (lexicographic) file-name order so the stack
    order is reproducible; note that this places ``10.tif`` before ``2.tif``
    unless the names are zero-padded.

    Args:
        path: Folder containing the images. A trailing path separator is
            optional.

    Returns:
        A ``torch.float32`` tensor holding one grayscale image per index along
        the first axis. If the folder holds no file with an accepted
        extension, an empty tensor of shape ``[0]`` is returned.

    Raises:
        OSError: If a file with an accepted extension cannot be decoded.
    """
    accepted_filetypes = ('tif', 'tiff', 'png', 'jpg', 'jpeg', 'bmp')

    frames = []
    for name in sorted(os.listdir(path)):
        ext = os.path.splitext(name)[1][1:]
        if ext.lower() not in accepted_filetypes:
            continue

        # os.path.join inserts the separator when `path` lacks a trailing one.
        full_path = os.path.join(path, name)
        img = cv.imread(full_path, cv.IMREAD_GRAYSCALE)
        if img is None:
            # cv.imread signals an unreadable file by returning None, not by raising.
            raise OSError(f"Could not read image file: {full_path}")
        frames.append(img.astype(np.float32))

    # np.array copies the frames into one contiguous float32 block, so the
    # tensor does not alias the per-image arrays; the list itself is released
    # when the function returns.
    return torch.from_numpy(np.array(frames, dtype=np.float32))

## Zero-padding to $2^n$ for FFTs

In [4]:
def two_n_squareify(tensor):
    """Zero-pad the last two dimensions to a square whose side is a power of 2.

    The side length is the smallest power of two that is >= the larger of the
    two trailing dimensions. Padding is split as evenly as possible around the
    original data; when the amount is odd, the extra element goes on the
    trailing (right / bottom) side.

    Args:
        tensor: Tensor with at least two dimensions, e.g. ``[depth, x, y]``.
            Leading dimensions are left untouched.

    Returns:
        A new tensor whose last two dimensions are both ``2**n``.

    Raises:
        ValueError: If ``tensor`` has fewer than two dimensions or one of the
            last two dimensions is empty.
    """
    if tensor.dim() < 2:
        raise ValueError("two_n_squareify expects a tensor with at least 2 dimensions")

    n_x, n_y = tensor.shape[-2], tensor.shape[-1]
    if n_x == 0 or n_y == 0:
        raise ValueError("two_n_squareify cannot pad an empty dimension")

    # Next power of two >= the larger side, computed with integer arithmetic
    # only: (n - 1).bit_length() equals ceil(log2(n)) for every n >= 1.
    side = 1 << (max(n_x, n_y) - 1).bit_length()

    x_pad = side - n_x
    y_pad = side - n_y

    # Floor half goes in front, the remainder behind; this covers both the
    # even and the odd case.
    l_x_pad = x_pad // 2
    r_x_pad = x_pad - l_x_pad
    l_y_pad = y_pad // 2
    r_y_pad = y_pad - l_y_pad

    # F.pad lists pad widths starting from the LAST dimension, hence y before x.
    return torch.nn.functional.pad(
        tensor, (l_y_pad, r_y_pad, l_x_pad, r_x_pad), "constant", 0
    )

## Load PSF

In [None]:
def load_psf(path, normalize):
    """Load a PSF stack from a saved ``.pt`` file or from a folder of images.

    Args:
        path: Either a path ending in ``.pt`` (loaded with ``torch.load``) or
            a folder of image files, one image per slice.
        normalize: If truthy, every slice is divided by its own sum so that
            each slice sums to 1. A slice whose sum is zero is not
            special-cased.

    Returns:
        The PSF stack. For a folder, a ``torch.float32`` tensor of shape
        ``[depth, rows, cols]`` with slices ordered by sorted file name.

    Raises:
        ValueError: If ``path`` is neither a ``.pt`` path nor a folder, or the
            folder contains no image files.
        OSError: If an image file in the folder cannot be decoded.
    """
    path = os.fspath(path)

    if path.endswith('.pt'):
        # NOTE(security): torch.load unpickles the file; only load PSF stacks
        # from sources you trust.
        psf = torch.load(path)

    elif os.path.isdir(path):
        image_exts = ('.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp')
        # List the folder once, keep only image files (skips e.g. hidden
        # metadata files) and sort so the slice order is reproducible.
        names = sorted(
            f for f in os.listdir(path)
            if os.path.splitext(f)[1].lower() in image_exts
        )
        if not names:
            raise ValueError(f"No image files found in folder: {path}")

        slices = []
        for name in names:
            full_path = os.path.join(path, name)
            img = cv.imread(full_path, cv.IMREAD_GRAYSCALE)
            if img is None:
                # cv.imread signals an unreadable file by returning None.
                raise OSError(f"Could not read image file: {full_path}")
            slices.append(img.astype(np.float32))

        # Each image array is already (rows, cols), so stacking yields
        # [depth, rows, cols] without any reshape.
        psf = torch.from_numpy(np.stack(slices))

    else:
        raise ValueError(f"PSF path must be a .pt file or a folder of images: {path}")

    if normalize:
        # Sum of each axial slice should be 1.
        psf = psf / psf.sum(dim=(-2, -1), keepdim=True)

    return psf