In [1]:
from dataclasses import dataclass, field
from typing import Tuple

import DeconTools.viz as viz
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import scipy.fft as fft
import tifffile
from DeconTools.core.filters import compute_lanczos_filter
from DeconTools.core.Models import TransformParams
from DeconTools.core.PSF import MicroscopeParameters, psf3d

%matplotlib qt

In [87]:
# get the 'best' or in-focus plane @ z=10
data = tifffile.imread(
    "/Users/delnatan/BurgessLab/Seans spreads Daniel decon study/dxy_90nm/confocal 90 nm pixel .czi - C=1.tif"
).astype(np.float32)
data = np.maximum(data, 0.0)[3]/ 100.0
Ny, Nx = data.shape
plt.imshow(data)

<matplotlib.image.AxesImage at 0x307539250>

In [88]:
mpars = MicroscopeParameters(
    excitation_wavelength=0.488,
    emission_wavelength=0.530,
    numerical_aperture=1.40,
    sample_refractive_index=1.40,
    immersion_refractive_index=1.515,
    pixel_size=0.090,
    confocal=False,
)

P = TransformParams(
    data_shape=(Ny, Nx),
    data_padding=(40, 40),
    zoom_factor=(3, 3),
    freq_cutoff=0.2,
    microscope_parameters=mpars,
)

In [89]:
cutoff_mask = np.zeros(P.extended_object_rfft_shape)
cutoff_mask[P.extended_data_rfft_indices] = 1.0

In [90]:
fig, ax = plt.subplots(ncols=2, figsize=(8, 3.5))
ax[0].imshow(P.ftICF, norm=mcolors.PowerNorm(0.2))
ax[0].imshow(cutoff_mask, alpha=0.2, cmap="gray")
ax[1].imshow(np.abs(P.OTF), norm=mcolors.PowerNorm(0.2))

<matplotlib.image.AxesImage at 0x303fd3e50>

In [91]:
def forward(x: np.ndarray, T: TransformParams) -> np.ndarray:
    X = fft.rfftn(x)
    lpX = T.OTF * T.ftICF * X
    # fourier crop
    lpXcrop = lpX[T.extended_data_rfft_indices]
    lpx = fft.irfftn(lpXcrop, s=T.extended_data_shape)
    return lpx[T.data_slices]


def adjoint(y: np.ndarray, T: TransformParams) -> np.ndarray:
    # zero pad data
    ypad = np.zeros(T.extended_data_shape)
    ypad[T.data_slices] = y
    Y = fft.rfftn(ypad)
    # pad in Fourier domain
    ftY = np.zeros(T.extended_object_rfft_shape, dtype=np.complex64)
    ftY[T.extended_data_rfft_indices] = Y[T.extended_data_rfft_indices]
    lpY = np.conj(T.ftICF) * np.conj(T.OTF) * ftY
    return fft.irfftn(lpY * T.adjoint_iscale, s=T.extended_object_shape)

In [92]:
# RL prototype
f = np.ones(P.extended_object_shape, dtype=np.float32)
hnorm = adjoint(np.ones(P.data_shape, dtype=np.float32), P)

for i in range(100):
    model = forward(f, P)
    ratio = data / model
    logL = np.sum(data * np.log(ratio + 1e-6) + model - data)
    print(f"\rIteration = {i+1}, logL ==={logL:12.4E}", end="")
    update = adjoint(ratio, P)
    f *= update
    f = np.where(hnorm > 1e-3, f / hnorm, f)

Iteration = 13, logL ===         NAN

  logL = np.sum(data * np.log(ratio + 1e-6) + model - data)


Iteration = 100, logL ===         NAN

In [93]:
funpad = f[P.object_slices]

In [94]:
fig, ax = plt.subplots(ncols=2, figsize=(10, 4.5))
ax[0].imshow(data, cmap="magma")
ax[1].imshow(funpad, cmap="magma")
ax[0].set_title("input image")
ax[1].set_title("3x Zoom-deconvolved")

Text(0.5, 1.0, '3x Zoom-deconvolved')

In [11]:
tifffile.imwrite("3x_airyscan_120nm_decon.tif", funpad)
tifffile.imwrite("3x_airyscan_120nm.tif", data)