In [1]:
import cupy as cp
from skimage.io import imread, imshow
from typing import List

In [2]:
from napari_tools_menu import register_function

@register_function(menu="Image math > FFT (2D)")
def forward_fft_2d(image:"napari.types.ImageData", viewer:"napari.Viewer"=None):
    """Compute the zero-frequency-centred 2D Fourier transform of an image.

    The computation runs through cupy when it can be imported and falls back
    to numpy otherwise. In both cases the results are handed back as host
    (numpy) arrays.

    Parameters
    ----------
    image : array-like
        Image to transform.
    viewer : napari.Viewer, optional
        When given, the real and imaginary parts are added to this viewer as
        two image layers. The layers are named after the layer that holds
        ``image``, or "Image" if no such layer is found.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(real, imaginary)`` when ``viewer`` is None. None when the results
        were added to ``viewer`` instead.
    """
    # Pick the array backend together with a matching device-to-host
    # conversion. Calling `.get()` unconditionally only works on cupy arrays
    # and made the numpy fallback raise AttributeError.
    try:
        import cupy as xp
        to_host = xp.asnumpy
    except Exception:
        import numpy as xp
        to_host = xp.asarray

    spectrum = xp.fft.fftshift(xp.fft.fft2(xp.asarray(image)))
    real_part = to_host(spectrum.real)
    imaginary_part = to_host(spectrum.imag)

    if viewer is None:
        return real_part, imaginary_part

    # Look the layer up with the untouched input object so that the new
    # layers can inherit its name.
    from napari_workflows._workflow import _get_layer_from_data
    layer = _get_layer_from_data(viewer, image)
    name = layer.name if layer is not None else "Image"

    viewer.add_image(real_part, name=name + " (real)")
    viewer.add_image(imaginary_part, name=name + " (imaginary)")

In [3]:
@register_function(menu="Image math > inverse FFT (2D)")
def inverse_fft_2d(real_image:"napari.types.ImageData", imaginary_image:"napari.types.ImageData") -> "napari.types.ImageData":
    """Reconstruct an image from the real and imaginary parts of its centred 2D spectrum.

    The two parts are combined into one complex128 array, the frequency shift
    is undone and the inverse 2D FFT is applied. The computation runs through
    cupy when it can be imported and falls back to numpy otherwise.

    Parameters
    ----------
    real_image : array-like
        Real part of the shifted spectrum.
    imaginary_image : array-like
        Imaginary part of the shifted spectrum.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse transform, as a host array.
    """
    # Pick the array backend together with a matching device-to-host
    # conversion. Calling `.get()` unconditionally only works on cupy arrays
    # and made the numpy fallback raise AttributeError.
    try:
        import cupy as xp
        to_host = xp.asnumpy
    except Exception:
        import numpy as xp
        to_host = xp.asarray

    # Convert first so that the shape is read from a real array; this also
    # lets plain array-likes without a `.shape` attribute through.
    real_part = xp.asarray(real_image)

    # create complex image, source: https://github.com/numpy/numpy/issues/16039
    spectrum = xp.empty(real_part.shape, dtype=xp.complex128)
    spectrum.real = real_part
    spectrum.imag = xp.asarray(imaginary_image)

    image = xp.fft.ifft2(xp.fft.ifftshift(spectrum))

    return to_host(image.real)

In [4]:
@register_function(menu="Image math > Set masked pixels")
def set_masked_pixels(image:"napari.types.ImageData", mask:"napari.types.LabelsData", new_value:float = 0) -> "napari.types.ImageData":
    """Return a copy of ``image`` in which every pixel under the mask is overwritten.

    The computation runs through cupy when it can be imported and falls back
    to numpy otherwise.

    Parameters
    ----------
    image : array-like
        Source image; it is not modified.
    mask : array-like
        Pixels where ``mask`` is non-zero are replaced.
    new_value : float, optional
        Value written into the masked pixels (default 0).

    Returns
    -------
    numpy.ndarray
        The modified copy, as a host array.
    """
    # Pick the array backend together with a matching device-to-host
    # conversion. Calling `.get()` unconditionally only works on cupy arrays
    # and made the numpy fallback raise AttributeError.
    try:
        import cupy as xp
        to_host = xp.asnumpy
    except Exception:
        import numpy as xp
        to_host = xp.asarray

    # `array` copies by default, so the caller's data is never touched and
    # only a single copy is made on either backend.
    result = xp.array(image)
    result[xp.asarray(mask) != 0] = new_value

    return to_host(result)

In [5]:
# Load the example image from disk with skimage's imread.
# NOTE(review): the path is absolute and machine-specific; the notebook will
# not run elsewhere without adjusting it.
image = imread("C:/structure/data/zfish_nucl_env.tif")

In [6]:
# Open the napari viewer that the following cells add layers to.
import napari
viewer = napari.Viewer()

In [7]:
# Show the loaded image as a layer; per the cell output the layer is named
# 'image' (presumably taken from the variable name - do not rename casually).
viewer.add_image(image)

<Image layer 'image' at 0x1c0b5591b80>

In [8]:
#viewer.add_image(cp.asarray(image))

In [9]:
# Forward transform. No viewer is passed, so the function returns the
# (real, imaginary) pair instead of adding layers to a viewer.
real, imag = forward_fft_2d(image)

In [10]:
# Round trip: rebuild the image from the two spectrum parts.
result = inverse_fft_2d(real, imag)

In [11]:
# Inspect the dtype of the reconstructed image (displayed as the cell output).
result.dtype

dtype('float64')

In [12]:
# Show the reconstructed image next to the original for visual comparison;
# per the cell output the layer is named 'result'.
viewer.add_image(result)

<Image layer 'result' at 0x1c0b98e1250>