# XYZ API
Basic library usage for analyzing a single diSPIM acquisition collected by stage scanning.

In [None]:
import os
import math
from importlib import reload

import cupy
import numpy
import tifffile
import matplotlib.pyplot as plt
%matplotlib widget

## data setup
Input the paths to the data and PSFs, and the acquisition parameters.

The code will then load the data and automatically crop it based on image thresholding.

In [None]:
# data folder: `root_fldr` contains one sub-folder per acquisition, `acq` selects one.
# (a previously-analyzed dataset: root "/scratch/gpfs/mb46/ext_spin/2024-06-18",
#  acquisition "spindle005")
root_fldr = os.path.join(
    "/projects/SHAEVITZ/mb46/fb_dispim",
    "13hr/2025-06-06"
)
acq = "fruiting_body001"
assert os.path.exists(root_fldr), f"data root folder not found: {root_fldr}"
data_path = os.path.join(root_fldr, acq)
assert os.path.exists(data_path), f"acquisition folder not found: {data_path}"

# psfs -- one .npy file per head ('A' and 'B')
psf_dir = "/scratch/gpfs/SHAEVITZ/dispim/extract_spindles"
psf_a = numpy.load(os.path.join(psf_dir, "PSFA_500.npy"))
psf_b = numpy.load(os.path.join(psf_dir, "PSFB_500.npy"))

### input acquisition parameters

In [None]:
# acquisition geometry -- edit these to match the dataset being processed
step_size = 0.5       # microns; distance between successive image planes
pixel_size = 0.1625   # microns; size of one camera pixel
theta = math.pi / 4   # radians; angle between objective and coverslip

In [None]:
# derived quantities, computed from the parameters above -- do not edit
cos_theta = math.cos(theta)
step_pix = step_size / pixel_size        # plane step expressed in pixels
step_size_lat = step_size / cos_theta    # plane step divided by cos(theta), in microns
step_pix_lat = step_pix / cos_theta      # same quantity, in pixels
[step_pix, step_pix_lat]

### load data

In [None]:
from pyspim.data import dispim as data

In [None]:
# open the micro-manager acquisition and read out both heads
# NOTE: the context variable is called `acquisition` so that it does not
# overwrite the acquisition-name string `acq` defined in the data-setup cell
# NOTE(review): the meaning of the `False` argument and of the (0, 0) indices
# passed to `get` is defined in pyspim.data.dispim -- confirm there
with data.uManagerAcquisition(data_path, False, numpy) as acquisition:
    a_raw = acquisition.get('a', 0, 0)
    b_raw = acquisition.get('b', 0, 0)

In [None]:
# the pco.edge camera adds a constant offset of 100 counts -- remove it.
# a plain `- 100` is not safe: the volumes are kept as uint16, so any value
# that would go negative wraps around. use the pyspim helper instead.
# NOTE: this cell is not idempotent; re-running it subtracts the offset again.
camera_offset = 100
a_raw, b_raw = (data.subtract_constant_uint16arr(vol, camera_offset)
                for vol in (a_raw, b_raw))

## automated ROI detection

In [None]:
from pyspim import roi

In [None]:
# per-view 3D ROI from 'otsu' thresholding, then combine them into a
# single ROI that is applied to both views
roia, roib = (roi.detect_roi_3d(vol, 'otsu') for vol in (a_raw, b_raw))
roic = roi.combine_rois(roia, roib)

In [None]:
# crop both views to the combined ROI; roic[d] holds (start, stop) along axis d
crop = tuple(slice(roic[d][0], roic[d][1]) for d in range(3))
a_raw = a_raw[crop]
b_raw = b_raw[crop]

In [None]:
# max-intensity projection along axis 0 of each cropped view, side by side
_, ax = plt.subplots(1, 2)
for panel, (label, vol) in zip(ax, (('A', a_raw), ('B', b_raw))):
    panel.imshow(numpy.amax(vol, 0), cmap='binary_r')
    panel.set_title(label)
    panel.axis('off')

## deskewing
Deskew the input volumes so that they are in the standard 'XYZ' coordinate system (arrays are indexed ZYX), where Z is normal to the coverslip and the X and Y axes lie in the coverslip plane.

In [None]:
from pyspim import deskew as dsk
## OPTIONAL: re-crop the deskewed output
# deskewing leaves large 'black' margins around the data; for really big
# volumes, set this to True to trim them right after deskewing and save memory
RECROP = False

### head 'A'

In [None]:
# deskew head A (scan direction +1)
a_dsk = dsk.deskew_stage_scan(
    a_raw, pixel_size, step_size_lat, 1, method='orthogonal'
)
a_dsk.shape

In [None]:
# optionally trim the empty margins of view A, then convert to float32
if not RECROP:
    a_dsk = a_dsk.astype(numpy.float32)
else:
    roia = roi.detect_roi_3d(a_dsk, 'triangle')
    crop_a = tuple(slice(roia[d][0], roia[d][1]) for d in range(3))
    a_dsk = a_dsk[crop_a].astype(numpy.float32)

# the raw volume is no longer needed -- release it
del a_raw

### head 'B'
Note that this head is typically scanned in the opposite direction from 'A', so the `direction` parameter flips to `-1`.

In [None]:
# deskew head B; it is scanned opposite to A, hence direction -1
b_dsk = dsk.deskew_stage_scan(
    b_raw, pixel_size, step_size_lat, -1, method='orthogonal'
)

In [None]:
# optionally trim the empty margins of view B, then convert to float32
if RECROP:
    roib = roi.detect_roi_3d(b_dsk, 'triangle')
    # crop B's own deskewed volume with B's ROI (not a_dsk)
    b_dsk = b_dsk[roib[0][0]:roib[0][1],
                  roib[1][0]:roib[1][1],
                  roib[2][0]:roib[2][1]].astype(numpy.float32)
else:
    b_dsk = b_dsk.astype(numpy.float32)

# the raw volume is no longer needed -- release it
del b_raw

### look at deskewed outputs

In [None]:
# compare the deskewed views: max projection along axis 1 (the z-x plane)
_, ax = plt.subplots(2, 1, sharex=True, sharey=True)
for panel, vol, title in zip(ax, (a_dsk, b_dsk), ('A - zx', 'B - zx')):
    panel.imshow(numpy.amax(vol, 1), cmap='binary_r')
    panel.set_title(title)
    panel.axis('off')

In [None]:
# overlay of the two deskewed views (max projection along axis 2):
# A in grayscale underneath, B in semi-transparent viridis on top
_, ax = plt.subplots(1, 1)
ax.imshow(numpy.amax(a_dsk, 2), cmap='binary_r')
ax.imshow(numpy.amax(b_dsk, 2), alpha=0.4, cmap='viridis')

In [None]:
# OPTIONAL: keep only the first half of both volumes along y (axis 1).
# disabled by default; set to True to enable instead of editing commented-out code
CROP_HALF_Y = False
_, ysze, _ = a_dsk.shape
if CROP_HALF_Y:
    a_dsk = a_dsk[:, :ysze//2, :]
    b_dsk = b_dsk[:, :ysze//2, :]

## registration
Now the two views must be registered to each other so that they can be jointly deconvolved. To do this, we first (optionally) run phase cross correlation on the maximum projections along each plane. This generates an initial guess for the translation between the two views; rotation and scaling start from the identity. We then feed this initial guess to an optimizer that transforms 'B' so that it lines up with the (static) 'A', maximizing the correlation ratio between the two images.

### phase cross correlation

In [None]:
## phase cross correlation needs both volumes to have the same shape;
## pyspim ships a helper that pads them to a common size
from pyspim.util import pad_to_same_size

a_dsk, b_dsk = pad_to_same_size(a_dsk, b_dsk)

In [None]:
from pyspim.reg import pcc

In [None]:
# preliminary translation estimate from phase cross correlation (PCC).
# NOTE: PCC gives a preliminary translation for all 3 axes, but by default
# the optimization below starts from zero translation. set USE_PCC_INIT = True
# to seed the optimizer with the PCC estimate instead.
USE_PCC_INIT = False
t_pcc = pcc.translation_for_volumes(a_dsk, b_dsk, upsample_factor=1)
t0 = t_pcc if USE_PCC_INIT else [0, 0, 0]
t_pcc

### optimization

In [None]:
from pyspim.reg import powell

In [None]:
# formulate initial parameters and bounds for the optimizer.
# NOTE: there are other options for the types of transforms that the code
# can (try) to compute -- see pyspim.reg.powell.
# NOTE: bounds can also be specified as just the margin (+/-) from the initial parameter
# parameter order: translations (from t0), then 3 rotations if 'r', then 3 scales if 's'
transform_string = 't+r+s'
trans_margin = 20          # +/- bound around each initial translation
rot_bounds = (-5, 5)       # bounds for each rotation parameter
scale_bounds = (0.9, 1.1)  # bounds for each scale parameter
trans_bounds = [(t - trans_margin, t + trans_margin) for t in t0]
if transform_string == 't':
    par0 = t0
    bounds = trans_bounds
elif transform_string == 't+r':
    par0 = numpy.concatenate([t0, numpy.asarray([0, 0, 0])])
    bounds = trans_bounds + [rot_bounds] * 3
elif transform_string == 't+r+s':
    par0 = numpy.concatenate([t0, numpy.asarray([0, 0, 0]), numpy.asarray([1, 1, 1])])
    bounds = trans_bounds + [rot_bounds] * 3 + [scale_bounds] * 3
else:
    # fail here rather than with a NameError on `par0` in the optimization cell
    raise ValueError(f"no initial parameters defined for transform {transform_string!r}")

In [None]:
# GPU launch parameters sized to the volume so the GPU is ~saturated.
# TODO: the block sizes (8 per axis) are currently half-guessed; eventually
# these should come from the CUDA occupancy API
from pyspim.util import launch_params_for_volume
block_z = block_y = block_x = 8
launch_par = launch_params_for_volume(a_dsk.shape, block_z, block_y, block_x)

In [None]:
# estimate the transform that maps B onto (static) A.
# `powell.optimize_affine` fits the full transform in one shot from `par0`;
# `powell.optimize_affine_piecewise` (used here) builds it up in stages, each
# stage seeding the next. e.g. for 't+r+s' it fits 't' (translation), then
# 't+r' (translation + rotation), then 't+r+s' (translation + rotation + scaling).
T, res = powell.optimize_affine_piecewise(
    cupy.asarray(a_dsk),
    cupy.asarray(b_dsk),
    metric='cr',
    transform=transform_string,
    interp_method='cubspl',
    par0=par0,
    bounds=bounds,
    kernel_launch_params=launch_par,
    verbose=True,
)

# recover the correlation ratio from the optimizer's objective value;
# for metric='cr', values above ~0.9 are generally ok
cr = 1 - res.fun
print(f'Optimized Metric {cr:.2f}')

### transformation
Use the optimized transform to register 'B' to 'A'.

In [None]:
# affine transformation utilities used to apply the registration to view B
# (the development-time `reload(affine)` was dropped: a finished notebook
# should not depend on re-importing modules mid-run)
from pyspim.interp import affine

In [None]:
# display the transform estimated by the optimization
T

In [None]:
# apply the optimized transform T to view B so it lines up with A,
# then copy the result from GPU to host memory (`.get()`).
# NOTE: `launch_par` is not passed here; the block sizes (8 per axis) are given
# explicitly and match those used to compute `launch_par` above. if the volume
# sizes change, revisit these.
# NOTE(review): with out_shp=None the output shape is chosen by
# affine.transform -- confirm how it is determined
b_reg = affine.transform(cupy.asarray(b_dsk), T,
                         interp_method='cubspl',
                         preserve_dtype=True, out_shp=None,
                         block_size_z=8, block_size_y=8, block_size_x=8).get()

In [None]:
# a scaling component in the registration can make `b_reg` a different size
# than `b_dsk`, but deconvolution needs both inputs to be the same size.
# keep only the overlapping extent -- decon needs content in both views there anyway.
common = tuple(slice(0, min(n_a, n_b)) for n_a, n_b in zip(a_dsk.shape, b_reg.shape))
a_dsk = a_dsk[common]
b_reg = b_reg[common]

In [None]:
## check results: overlay axis-0 max projections of A and registered B
_, ax = plt.subplots(1, 1)
ax.imshow(numpy.amax(a_dsk, 0), cmap='binary_r')
ax.imshow(numpy.amax(b_reg, 0), alpha=0.5, cmap='viridis')

## intermediate outputs (zarr saving)
For large datasets, we can save memory by deconvolving in chunks.
To do this, save the data as zarr arrays, which can then be read in small chunks; each chunk is deconvolved and written back out to the corresponding location in the output zarr array.

*NOTE*: for small datasets this isn't necessary; you can just call `pyspim.decon.rl.dualview_fft.deconvolve` instead.

In [None]:
## save intermediate outputs
import zarr

interm_path = "/scratch/gpfs/mb46/tmp"  # a tmp path to put outputs in
# makedirs also creates missing parent folders and is a no-op if the folder exists
os.makedirs(interm_path, exist_ok=True)

In [None]:
# write view A to an on-disk uint16 zarr array for chunk-wise deconvolution.
# `a_dsk` holds float32 values, so round and clip to the uint16 range first:
# a bare float -> uint16 cast truncates fractions and is undefined for negatives.
a_zarr = zarr.creation.open_array(
    os.path.join(interm_path, 'a.zarr'),
    mode='w',
    shape=a_dsk.shape,
    dtype=numpy.uint16,
    fill_value=0
)
a_zarr[:] = numpy.clip(numpy.rint(a_dsk), 0, numpy.iinfo(numpy.uint16).max).astype(numpy.uint16)

In [None]:
# write registered view B to an on-disk uint16 zarr array.
# `b_reg` comes from cubic-spline resampling and may contain small negative
# values, so round and clip to the uint16 range before the cast.
b_zarr = zarr.creation.open_array(
    os.path.join(interm_path, 'b.zarr'),
    mode='w',
    shape=b_reg.shape,
    dtype=numpy.uint16,
    fill_value=0
)
b_zarr[:] = numpy.clip(numpy.rint(b_reg), 0, numpy.iinfo(numpy.uint16).max).astype(numpy.uint16)

## deconvolution

In [None]:
from pyspim.decon.rl.dualview_fft import deconvolve_chunkwise

In [None]:
# on-disk float32 array that will receive the deconvolved volume;
# same shape as the (cropped) input views
out_shape = b_reg.shape
out = zarr.creation.open_array(
    os.path.join(interm_path, 'out.zarr'),
    mode='w', shape=out_shape, dtype=numpy.float32, fill_value=0,
)

In [None]:
# sanity-check the PSF shape before deconvolution
psf_a.shape

In [None]:
# chunk-wise dual-view deconvolution (pyspim.decon.rl.dualview_fft):
# reads the two views from `a_zarr` / `b_zarr` and writes the result into `out`.
# NOTE(review): the positional arguments are annotated only from their values --
# confirm each against the `deconvolve_chunkwise` signature.
deconvolve_chunkwise(
    a_zarr, b_zarr, out,
    [128, 512, 512], [40,40,40],   # presumably chunk size and per-axis overlap (z, y, x)
    cupy.asarray(psf_a), cupy.asarray(psf_b),
    # the same PSFs flipped along all three axes (presumably the "backward" PSFs)
    cupy.asarray(psf_a[::-1,::-1,::-1]), cupy.asarray(psf_b[::-1,::-1,::-1]),
    'additive', 20, 1e-6,          # presumably update mode, iteration count, epsilon
    False, None, 0, 0,
    True                           # presumably a verbose/progress flag
)

In [None]:
# read the whole deconvolved volume back into memory
out_zarr_path = os.path.join(interm_path, 'out.zarr')
decon = zarr.load(out_zarr_path)
decon.shape

In [None]:
## check results: max projections of the deconvolved volume
# [0,0] along axis 0 (transposed), [1,0] along axis 2, [0,1] along axis 1
# (transposed); [1,1] is left empty. all panels have their axes hidden.
_, ax = plt.subplots(2, 2)
for (row, col), proj_axis, transpose in (((0, 0), 0, True),
                                         ((1, 0), 2, False),
                                         ((0, 1), 1, True)):
    mip = numpy.amax(decon, proj_axis)
    ax[row, col].imshow(mip.T if transpose else mip, cmap='binary_r', vmax=700)
for panel in ax.flat:
    panel.axis('off')
plt.tight_layout()

In [None]:
# size of the deconvolved volume in gigabytes.
# `decon.size` is the number of voxels, not bytes: `nbytes` gives the in-memory
# (float32) size, and the uint16 TIFF written below needs 2 bytes per voxel.
size_gb = decon.nbytes / 1e9
tif_size_gb = decon.size * numpy.dtype(numpy.uint16).itemsize / 1e9
size_gb, tif_size_gb

In [None]:
# optional: write to a TIF file for viewing in Fiji/ImageJ/etc.
# values are rounded and clipped to the uint16 range [0, 65535] before the cast;
# clipping at 2**16 would let 65536 through, which wraps to 0 as uint16.
# NOTE(review): the file is named '.ome.tif' but written as an ImageJ hyperstack
# (imagej=True); for very large volumes, check the output size before writing.
# NOTE(review): the z 'spacing' assumes the deskewed volume is sampled at
# `pixel_size` along z as well -- confirm against the deskew output
u16_max = numpy.iinfo(numpy.uint16).max
tifffile.imwrite(os.path.join(data_path, 'decon.ome.tif'),
                 numpy.round(decon).clip(0, u16_max).astype(numpy.uint16),
                 imagej=True,
                 resolution=(1/pixel_size, 1/pixel_size),
                 metadata={
                     'unit' : 'um',
                     'axes' : 'ZYX',
                     'spacing' : pixel_size,
                 })