In [57]:
import imagej
import matplotlib.pyplot as plt
import scyjava as sj
import numpy as np
import xarray as xr
import os

# Configure the JVM before the ImageJ2 gateway is created.
sj.config.add_option('-Xmx6g')  # Java maximum heap size option
ij = imagej.init(add_legacy=False)
print(f"ImageJ2 version: {ij.getVersion()}")

# Import ImageJ2 and ImgLib2 Java classes.
# Every class goes through the `sj` alias imported at the top of the notebook,
# instead of mixing it with the `imagej.sj` attribute of the imagej package.
CreateNamespace = sj.jimport('net.imagej.ops.create.CreateNamespace')
FinalDimensions = sj.jimport('net.imglib2.FinalDimensions')
FloatType = sj.jimport('net.imglib2.type.numeric.real.FloatType')
Views = sj.jimport('net.imglib2.view.Views')

ImageJ2 version: 2.16.0


In [2]:
import psutil

# Report physical core count and memory; byte counts are divided by 1024**3 for display.
BYTES_PER_GIB = 1024 ** 3
mem = psutil.virtual_memory()
physical_cores = psutil.cpu_count(logical=False)
total_mem = mem.total / BYTES_PER_GIB
available_mem = np.round(mem.available / BYTES_PER_GIB, 1)

print(f'Cores            : {physical_cores}')
print(f'Total memory     : {total_mem} GB')
print(f'Available memory : {available_mem}')

Cores            : 10
Total memory     : 16.0 GB
Available memory : 5.4


In [47]:
def deconvolve(img, wv, iterations, reg, na, ri_sample, ri_immersion, lat_res, ax_res, pz):
    """
    Deconvolve an image using the Richardson-Lucy algorithm with total variation regularization.

    A synthetic point spread function (PSF) with the same dimensions as the image is
    created with the ImageJ Ops ``kernelDiffraction`` op and passed, together with the
    image, to the ``richardsonLucyTV`` op.

    Parameters
    ----------
    img : numpy.ndarray, xarray.DataArray or Java image
        The input image to be deconvolved. It is handed to ``ij.py.to_java`` first,
        so any object that call accepts can be used.
    wv : float
        The emission wavelength in nanometers.
    iterations : int
        The number of iterations for the Richardson-Lucy algorithm.
    reg : float
        The regularization factor for total variation.
    na : float
        The numerical aperture of the objective lens.
    ri_sample : float
        The refractive index of the sample.
    ri_immersion : float
        The refractive index of the immersion medium.
    lat_res : float
        The lateral resolution in micrometers.
    ax_res : float
        The axial resolution in micrometers.
    pz : float
        The distance from the coverslip in micrometers.

    Returns
    -------
    Java image
        The deconvolved image exactly as returned by the ``richardsonLucyTV`` op.
        It is NOT converted back to Python; call ``ij.py.from_java`` on the result
        to obtain an array.
    """

    # Convert the input image to its Java representation.
    img_f = ij.py.to_java(img)

    # Convert the input parameters into meters.
    wv = wv * 1E-9
    lat_res = lat_res * 1E-6
    ax_res = ax_res * 1E-6
    pz = pz * 1E-6

    # Size the PSF from the Java image that is actually deconvolved. A Python array's
    # axes are reversed when it is converted, so the Python-side `img.shape` would
    # describe the PSF in the wrong dimension order for non-cubic inputs.
    psf_dims = FinalDimensions(img_f.shape)

    # Create the synthetic PSF.
    psf = ij.op().namespace(CreateNamespace).kernelDiffraction(
        psf_dims, na, wv, ri_sample, ri_immersion, lat_res, ax_res, pz, FloatType())

    img_decon = ij.op().deconvolve().richardsonLucyTV(img_f, psf, iterations, reg)

    return img_decon


def plot_image_4channel(image, title=None):
    """
    Show the middle plane of every channel of a 4D image side by side.

    Parameters
    ----------
    image : array-like
        Image indexed as ``image[plane, channel, row, col]``.
    title : str, optional
        Figure title. When omitted, the image's ``name`` attribute is used if it
        has one; otherwise no figure title is drawn.
    """
    middle_slice = image.shape[0] // 2
    n_channels = image.shape[1]
    # squeeze=False keeps `axes` two-dimensional even for a single channel.
    fig, axes = plt.subplots(1, n_channels, figsize=(5 * n_channels, 5), squeeze=False)
    for ch, ax in enumerate(axes[0]):
        ax.imshow(image[middle_slice, ch, :, :], cmap='gray')
        ax.set_title(f'Channel {ch + 1}, middle_slice')
        ax.axis('off')
    # The previous title sliced the image array as if it were a file name, which
    # rendered the array's repr as the figure title.
    if title is None:
        title = getattr(image, 'name', None)
    if title is not None:
        fig.suptitle(str(title), fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_image_3channel(image, title=None):
    """
    Show every channel of a 3D image side by side.

    Parameters
    ----------
    image : array-like
        Image indexed as ``image[channel, row, col]``.
    title : str, optional
        Figure title. When omitted, the image's ``name`` attribute is used if it
        has one; otherwise no figure title is drawn.
    """
    n_channels = image.shape[0]
    # squeeze=False keeps `axes` two-dimensional even for a single channel.
    fig, axes = plt.subplots(1, n_channels, figsize=(5 * n_channels, 5), squeeze=False)
    for ch, ax in enumerate(axes[0]):
        ax.imshow(image[ch, :, :], cmap='gray')
        ax.set_title(f'Channel {ch + 1}')
        ax.axis('off')
    # The previous title sliced the image array as if it were a file name, which
    # rendered the array's repr as the figure title.
    if title is None:
        title = getattr(image, 'name', None)
    if title is not None:
        fig.suptitle(str(title), fontsize=16)
    plt.tight_layout()
    plt.show()


def dump_info(image):
    """Print the name, type, dtype, shape and dims of an image object.

    The name is looked up from the ``name`` attribute first, then from
    ``getName()`` and finally ``getTitle()``; each fallback is tried only while
    the name is still ``None``. Missing ``dtype``/``dims`` are reported as
    ``N/A``; ``shape`` is required.
    """
    label = getattr(image, 'name', None)  # xarray
    # Dataset exposes getName(), ImagePlus exposes getTitle().
    for accessor in ('getName', 'getTitle'):
        if label is None and hasattr(image, accessor):
            label = getattr(image, accessor)()

    dtype = getattr(image, 'dtype', 'N/A')
    dims = getattr(image, 'dims', 'N/A')

    print(f" name: {label or 'N/A'}")
    print(f" type: {type(image)}")
    print(f"dtype: {dtype}")
    print(f"shape: {image.shape}")
    print(f" dims: {dims}")

In [41]:
# Working directory of the experiment and its '32tif/' sub-folder.
# NOTE(review): `dir` shadows the builtin and `img` is rebound to the loaded image in a
# later cell; both names are kept here because other cells reference them.
wd = '/Users/jkellerm/Library/CloudStorage/OneDrive-MichiganMedicine/0-active-projects/expansion/expansion_20250605_nikon/'
dir = os.path.join(wd, '32tif/')

# Alphabetically sorted names of the .tif files found in that folder.
img_list = sorted(entry for entry in os.listdir(dir) if entry.endswith('.tif'))

for img in img_list:
    print(img)

nd010-1-ch1.tif
nd010-1-ch2.tif
nd010-1-ch3.tif
nd010-1-ch4.tif
nd010-1-decon.tif
nd010-1-decon_ch1.tif
nd010-1-decon_ch2.tif
nd010-1-decon_ch3.tif
nd010-1-decon_ch4.tif
nd010-1.tif
nd010.tif
nd012-1-decon.tif
nd012-1-decon_ch1.tif
nd012-1-decon_ch2.tif
nd012-1-decon_ch3.tif
nd012-1-decon_ch4.tif
nd012-1.tif
nd012.tif
nd013-1-decon.tif
nd013-1-decon_ch1.tif
nd013-1-decon_ch2.tif
nd013-1-decon_ch3.tif
nd013-1-decon_ch4.tif
nd013-1.tif
nd013.tif


In [62]:
# Parameters for Richardson-Lucy TV deconvolution and the synthetic PSF.
# Units follow the `deconvolve` docstring: wavelength in nanometers, distances in micrometers.
iterations = 1
reg = 0.002
na = 1.5 # numerical aperture
wavelength = [647, 594, 488, 405] # emission wavelength, one entry per channel
ri_immersion = 1 # refractive index (immersion)
ri_sample = 1 # refractive index (sample)
# NOTE(review): na (1.5) is larger than ri_immersion (1); a numerical aperture cannot
# exceed the refractive index of the immersion medium - confirm these values.
lat_res = 0.07 # lateral resolution (i.e. xy)
ax_res = 1 # axial resolution (i.e. z)
pz = 0 # distance away from coverslip
iterstring = f'{iterations:02d}'  # zero-padded iteration count used in the output file name

# change this to the desired image file
img_name = 'nd012-1.tif'
# open the image and convert to float32
img = ij.io().open(os.path.join(dir, img_name))
img = ij.op().convert().float32(img)

print('img')
dump_info(img)

# The Java image has the channel on axis 2. After conversion to Python the remaining
# axes come back reversed, so a 3-D input yields (row, col) arrays per channel and a
# 4-D input yields (pln, row, col) arrays per channel.
if img.ndim == 3:
    stack_axis, out_dims = 0, ('ch', 'row', 'col')
elif img.ndim == 4:
    stack_axis, out_dims = 1, ('pln', 'ch', 'row', 'col')
else:
    raise ValueError(f'Expected a 3D or 4D image, got {img.ndim} dimensions')

n_channels = img.shape[2]
if n_channels > len(wavelength):
    raise ValueError(
        f'Image has {n_channels} channels but only {len(wavelength)} wavelengths are defined')

decon_channels = []  # list to hold deconvolved channels (as Python arrays)
for channel in range(n_channels):
    channel_img = img[:, :, channel] if img.ndim == 3 else img[:, :, channel, :]
    decon = deconvolve(channel_img, wavelength[channel], iterations, reg, na,
                       ri_sample, ri_immersion, lat_res, ax_res, pz)
    decon = ij.py.from_java(decon)
    decon_channels.append(decon)

decon_ximg = np.stack(decon_channels, axis=stack_axis)
decon_ximg = xr.DataArray(decon_ximg, name='decon_ximg', dims=out_dims)
decon_img = ij.py.to_java(decon_ximg)

print('decon_ximg')
dump_info(decon_ximg)
print('decon_img')
dump_info(decon_img)

# Build the output name once; an existing file is removed first, then the result is saved.
out_name = f'{os.path.splitext(img_name)[0]}-decon-{iterstring}-iteration-{reg}-reg.tif'
out_path = os.path.join(dir, out_name)
if os.path.exists(out_path):
    os.remove(out_path)
    print(f'Removed existing file: {out_name}')
ij.io().save(decon_img, out_path)
print(f'Saved deconvolved image: {out_name}')

img
 name: N/A
 type: <java class 'net.imglib2.img.planar.PlanarImg'>
dtype: <java class 'net.imglib2.type.numeric.real.FloatType'>
shape: (123, 177, 4, 31)
 dims: N/A
decon_ximg
 name: decon_ximg
 type: <class 'xarray.core.dataarray.DataArray'>
dtype: float32
shape: (31, 4, 177, 123)
 dims: ('pln', 'ch', 'row', 'col')
decon_img
 name: decon_ximg
 type: <java class 'net.imagej.DefaultDataset'>
dtype: <java class 'net.imglib2.type.numeric.real.FloatType'>
shape: (123, 177, 4, 31)
 dims: ('X', 'Y', 'Channel', 'Z')
Removed existing file: nd012-1-decon-01-iteration-0.002-reg.tif
Saved deconvolved image: nd012-1-decon-01-iteration-0.002-reg.tif


In [None]:

# Convert the loaded image here so this cell does not depend on a later cell
# having defined `ximg`.
ximg = ij.py.from_java(img)
if ximg.ndim == 4:
    plot_image_4channel(ximg)
elif ximg.ndim == 3:
    plot_image_3channel(ximg)

# Show one plane of every deconvolved channel.
# Each entry of `decon_channels` is ordered (pln, row, col) - they stack to the
# ('pln', 'ch', 'row', 'col') array above - so the plane is selected on the FIRST axis.
slice_idx = 15
n_channels = len(decon_channels)
fig, axes = plt.subplots(1, n_channels, figsize=(5 * n_channels, 5), squeeze=False)
for ch, ax in enumerate(axes[0]):
    planes = ij.py.from_java(decon_channels[ch])
    ax.imshow(planes[slice_idx, :, :], cmap='gray')
    ax.set_title(f'Deconvolved Channel {ch + 1}, Slice {slice_idx}')
    ax.axis('off')
fig.suptitle(f'Deconvolved Image: {img_name}, Slice {slice_idx}', fontsize=16)
plt.tight_layout()
plt.show()


In [None]:

# NOTE(review): this cell indexes the converted arrays as [channel, row, col], i.e. it
# assumes `img` is a 3-D image with the channel on axis 2 - confirm before running it on
# a 4-D stack.

# Start from an empty list so the cell works on a fresh kernel and does not keep
# results from a previous run.
decon_slices = []
for i in range(img.shape[2]):
    slice_img = img[:, :, i]
    channel_decon = deconvolve(slice_img, wavelength[i], iterations, reg, na,
                               ri_sample, ri_immersion, lat_res, ax_res, pz)
    decon_slices.append(channel_decon)  # accumulate slices

# Now stack all slices at once
img_decon = Views.stack(decon_slices)

ximg = ij.py.from_java(img)
ximg_decon = ij.py.from_java(img_decon)

# One column per channel: originals on the top row, deconvolved results below.
n_channels = ximg.shape[0]
fig, ax = plt.subplots(2, n_channels, figsize=(3 * n_channels, 9), squeeze=False)

for i in range(n_channels):
    ax[0, i].imshow(ximg[i, :, :], cmap='gray')
    ax[0, i].set_title(f"Original Channel {i+1}")
    ax[0, i].axis('off')
    ax[1, i].imshow(ximg_decon[i, :, :], cmap='gray')
    ax[1, i].set_title(f"Deconvolved Channel {i+1}")
    ax[1, i].axis('off')

# Adjust layout and display the plot
plt.tight_layout()
plt.show()

In [326]:
# Wrap the deconvolved array as a dataset, labelling its axes as channel, row, column.
img_decon = ij.py.to_dataset(ximg_decon, dim_order=['ch', 'row', 'col'])

# Remove any previous output first, then write the new file to the current directory.
decon_tif = 'deconvolved_image.tif'
if os.path.exists(decon_tif):
    os.remove(decon_tif)
ij.io().save(img_decon, decon_tif)