In [None]:
import imagej
import matplotlib.pyplot as plt
import scyjava as sj
import numpy as np
import os

# JVM options only take effect if set before the JVM starts, i.e. before imagej.init().
sj.config.add_option('-Xmx6g')  # maximum Java heap size: 6 GB
# add_legacy=False: initialize ImageJ2 without the original ImageJ (1.x) layer
ij = imagej.init(add_legacy=False)
print(f"ImageJ2 version: {ij.getVersion()}")

# import ImageJ2 and ImgLib2 Java classes (via the scyjava alias imported above,
# rather than pyimagej's internal `imagej.sj` re-export)
CreateNamespace = sj.jimport('net.imagej.ops.create.CreateNamespace')
FinalDimensions = sj.jimport('net.imglib2.FinalDimensions')
FloatType = sj.jimport('net.imglib2.type.numeric.real.FloatType')
Views = sj.jimport('net.imglib2.view.Views')

ImageJ2 version: 2.16.0


In [349]:
import psutil

# Report physical cores and memory. Sizes are divided by 1024**3 (GiB) and shown as "GB".
mem = psutil.virtual_memory()  # one snapshot so total and available are consistent
gib = 1024 ** 3
print(f'Cores            : {psutil.cpu_count(logical=False)}')  # physical CPU count
print(f'Total memory     : {mem.total / gib:.1f} GB')
print(f'Available memory : {mem.available / gib:.1f} GB')

Cores            : 10
Total memory     : 16.0 GB
Available memory : 4.5


In [303]:
def deconvolve(img, wv, iterations, reg, na, ri_sample, ri_immersion, lat_res, ax_res, pz):
    """
    Deconvolve an image with Richardson-Lucy plus total-variation regularization,
    using a synthetic diffraction PSF generated by ImageJ Ops.

    The PSF is created with ``create.kernelDiffraction`` at the same size as the
    input image. The deconvolution itself is ``deconvolve.richardsonLucyTV``.

    Parameters
    ----------
    img : numpy.ndarray or Java (ImgLib2) image
        The input image. It goes through ``ij.py.to_java``, so a NumPy array is
        converted and an image that is already a Java object is used as-is.
        In this notebook it is called with single channel slices.
    wv : float
        The emission wavelength in nanometers.
    iterations : int
        The number of Richardson-Lucy iterations.
    reg : float
        The total-variation regularization factor.
    na : float
        The numerical aperture of the objective lens.
    ri_sample : float
        The refractive index of the sample.
    ri_immersion : float
        The refractive index of the immersion medium.
    lat_res : float
        The lateral (xy) resolution in micrometers (presumably the pixel
        spacing; TODO confirm against acquisition metadata).
    ax_res : float
        The axial (z) resolution in micrometers.
    pz : float
        The distance from the coverslip in micrometers.

    Returns
    -------
    Java (ImgLib2) image
        The deconvolved image exactly as returned by Ops -- a Java object,
        not a NumPy array. Use ``ij.py.from_java`` to get NumPy/xarray data.
    """
    # Convert NumPy input to an ImgLib2 image; Java input passes through unchanged.
    img_java = ij.py.to_java(img)

    # kernelDiffraction takes SI units, so convert nm / um to meters.
    wv_m = wv * 1E-9
    lat_res_m = lat_res * 1E-6
    ax_res_m = ax_res * 1E-6
    pz_m = pz * 1E-6

    # Size the PSF from the Java image (ImgLib2 axis order). A NumPy shape is
    # in reversed (row, col, ...) order, which would mismatch non-square inputs.
    psf_dims = FinalDimensions(tuple(img_java.dimension(d) for d in range(img_java.numDimensions())))

    # create synthetic PSF
    psf = ij.op().namespace(CreateNamespace).kernelDiffraction(
        psf_dims, na, wv_m, ri_sample, ri_immersion, lat_res_m, ax_res_m, pz_m, FloatType())

    return ij.op().deconvolve().richardsonLucyTV(img_java, psf, iterations, reg)

In [332]:
# Directory holding the projected .tif images. Set MEROZOITE_PRJ_DIR to point
# the notebook elsewhere; the default is the original author's OneDrive path.
wd = os.environ.get(
    'MEROZOITE_PRJ_DIR',
    '/Users/jkellerm/Library/CloudStorage/OneDrive-MichiganMedicine/0-active-projects/merozoite/2025-05-13_merozoite-pretreat/prj/',
)

# sorted names of the .tif files in wd
img_list = sorted(f for f in os.listdir(wd) if f.endswith('.tif'))
if not img_list:
    raise FileNotFoundError(f"no .tif files found in {wd}")

# first image name without its extension
print(os.path.splitext(img_list[0])[0])

20hpi_extra-01-slice-8


In [None]:


# Deconvolve every channel of one image with its own synthetic PSF, then show
# original and deconvolved channels side by side.

# open the image from the project directory listed above (wd)
img = ij.io().open(os.path.join(wd, '20hpi_extra-06-slice-8.tif'))
img = ij.op().convert().float32(img)  # convert to 32-bit

# Richardson-Lucy TV settings
iterations = 20
reg = 0.002

# PSF parameters
na = 1.4  # numerical aperture
wavelength = [617, 508, 461, 550]  # emission wavelength (nm), one per channel
ri_immersion = 1.5  # refractive index (immersion)
ri_sample = 1.4  # refractive index (sample)
lat_res = 0.07  # lateral resolution (i.e. xy), micrometers
ax_res = 0.24  # axial resolution (i.e. z), micrometers
pz = 0  # distance away from coverslip, micrometers

# Channels are sliced along the third axis and each needs a wavelength. Check up
# front rather than failing part-way through the (slow) deconvolution loop.
n_channels = img.shape[2]
if len(wavelength) < n_channels:
    raise ValueError(
        f"image has {n_channels} channels but only {len(wavelength)} emission wavelengths are set")

decon_slices = [
    deconvolve(img[:, :, i], wavelength[i], iterations, reg, na,
               ri_sample, ri_immersion, lat_res, ax_res, pz)
    for i in range(n_channels)
]

# Now stack all slices at once
img_decon = Views.stack(decon_slices)

ximg = ij.py.from_java(img)
ximg_decon = ij.py.from_java(img_decon)

# One column per channel: originals on top, deconvolved below.
# squeeze=False keeps `ax` 2-D even for a single channel.
n_cols = ximg.shape[0]
fig, ax = plt.subplots(2, n_cols, figsize=(3 * n_cols, 9), squeeze=False)

for i in range(n_cols):
    ax[0, i].imshow(ximg[i, :, :], cmap='gray')
    ax[0, i].set_title(f"Original Channel {i+1}")
    ax[0, i].axis('off')
    ax[1, i].imshow(ximg_decon[i, :, :], cmap='gray')
    ax[1, i].set_title(f"Deconvolved Channel {i+1}")
    ax[1, i].axis('off')

# Adjust layout and display the plot
fig.tight_layout()
plt.show()

In [326]:
# Wrap the deconvolved channels back into an ImageJ Dataset with named axes.
img_decon = ij.py.to_dataset(ximg_decon, dim_order=['ch', 'row', 'col'])

out_path = 'deconvolved_image.tif'
# Delete any previous output so the save writes a fresh file.
# (Presumably the writer would not overwrite it cleanly; confirm.)
if os.path.exists(out_path):
    os.remove(out_path)
ij.io().save(img_decon, out_path)