In [78]:
from math import sqrt

import colorcet as cc
import DeconTools.core.proximals as prox
import DeconTools.core.PSF as p
# Second, non-colliding alias for the PSF module: the name `p` is rebound to the
# ADMM solver further down, so PSF-building code can use `psf_tools` and stay re-runnable.
import DeconTools.core.PSF as psf_tools
import DeconTools.utils as utils
import DeconTools.viz as viz
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
import torch
from DeconTools.operators.derivatives import (
    compute_first_derivative_filters,
    compute_second_derivative_filters,
)
from DeconTools.operators.fftops import CompositeLinearOperator, LinearOperator
from matplotlib.colors import PowerNorm
from tifffile import imread

from ADMM_torch import DeconADMM
from FourierFilters import second_order_diffops_3d

mps = torch.device("mps")
%load_ext autoreload
%autoreload 2
%matplotlib qt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# setup data and PSF
# Dataset-specific preprocessing constants.
DAPI_BACKGROUND = 100.0  # constant offset subtracted from raw counts — presumably camera offset; confirm
CROP_Y0, CROP_X0, CROP_SIZE = 230, 250, 128  # lateral ROI: top-left (y, x) corner and edge length, in pixels
Z_STEP = 0.580  # axial spacing passed as `dz`; presumably µm like the wavelengths below — confirm

# Drop the last z-plane and crop the lateral ROI *before* converting, so only the
# ROI is cast to float32 (elementwise ops commute with slicing, so the result is
# unchanged). Casting before subtracting the offset keeps values below it from
# wrapping around if the TIFF is stored as unsigned integers.
yeast_dapi = imread("../data/dapi_img_crop.tif")[
    :-1, CROP_Y0:CROP_Y0 + CROP_SIZE, CROP_X0:CROP_X0 + CROP_SIZE
]
yeast_dapi = np.maximum(yeast_dapi.astype(np.float32) - DAPI_BACKGROUND, 0.0)
# Slicing silently truncates when the ROI runs past the image edge; fail loudly instead.
assert yeast_dapi.shape[1:] == (CROP_SIZE, CROP_SIZE), "crop ROI exceeds image bounds"
Nz, Ny, Nx = yeast_dapi.shape

# Use `psf_tools` (DeconTools.core.PSF) rather than `p`: the later solver cell
# rebinds `p` to a DeconADMM instance, which would break re-running this cell.
params = psf_tools.MicroscopeParameters(
    excitation_wavelength=0.485,
    emission_wavelength=0.529,
    numerical_aperture=0.6,
    immersion_refractive_index=1.0,
    sample_refractive_index=1.36,
    pixel_size=0.1565,
)

pu = psf_tools.PupilFunction(params=params, Nx=Nx, Ny=Ny)
zslices = psf_tools.compute_z_planes(Nz, dz=Z_STEP)
psf3d = pu.calculate_confocal_3D_psf(zslices)
psf3d /= psf3d.sum()  # normalize the PSF to unit sum

In [15]:
# Second-order difference operators for the regularizer, from FourierFilters.
# Presumably Fourier-domain transfer functions (they are cast to complex) — confirm.
# NOTE(review): 0.269 ≈ pixel_size / dz = 0.1565 / 0.580 ≈ 0.2698 from the setup cell.
# It is presumably meant to be that lateral/axial sampling ratio; consider deriving it
# so it tracks the acquisition parameters.
LATERAL_TO_AXIAL_RATIO = 0.269
L2 = [
    # complex64 keeps the operators in single precision, like the float32 data.
    op.astype(np.complex64)
    for op in second_order_diffops_3d(
        Nz, Ny, Nx, lateral_to_axial_ratio=LATERAL_TO_AXIAL_RATIO
    )
]

In [79]:
# problem setup
# Gaussian pre-smoothing of the measurement. A scalar sigma is applied to every axis
# in voxel units, so with the anisotropic sampling used here (0.1565 lateral vs 0.580
# axial spacing) the physical blur is much wider along z than laterally.
PRESMOOTH_SIGMA = 1.2
# asarray avoids an extra copy when yeast_dapi is already float32;
# gaussian_filter returns a new array and leaves its input untouched.
b = ndi.gaussian_filter(np.asarray(yeast_dapi, dtype=np.float32), PRESMOOTH_SIGMA)
# NOTE(review): H models only psf3d, not the extra Gaussian applied to b above, so
# the forward model and the data are mismatched. Confirm that this is intended
# (e.g. as denoising) or fold the Gaussian into the PSF.
H = LinearOperator(psf3d, (Nz, Ny, Nx))

In [80]:
otf = H.ft_kernel  # Fourier-domain kernel of H, handed to the solver as its OTF
# NOTE(review): `p` is also the alias for DeconTools.core.PSF from the imports cell.
# Rebinding it here shadows that module, so any cell calling `p.<PSF function>`
# fails if re-run after this one unless the imports cell is re-run first.
p = DeconADMM(
    b,
    otf,
    L2,
    device="mps",  # hard-coded Apple-silicon backend; presumably needs an MPS-capable machine
)

In [81]:
# Run ADMM with a SCAD prior.
# NOTE(review): reset() presumably re-initializes the solver state so this cell can be re-run; confirm.
N_ADMM_ITERS = 100
p.reset()
lam = 1e-2
# Penalty parameter scaled by the data maximum. float() keeps rho a Python float:
# under NumPy >= 2 (NEP 50), `python_float / np.float32` would stay float32, so the
# value and dtype of rho would otherwise depend on the installed NumPy version.
rho = 6000 * lam / float(b.max())
print(rho)
for _ in range(N_ADMM_ITERS):
    # a=3.7 is the SCAD shape parameter suggested by Fan & Li (2001). The positional
    # 1.0 is a solver-specific argument defined by ADMM_torch.DeconADMM.step.
    p.step(rho, lam, 1.0, regularization="SCAD", a=3.7)

0.5090504898764309
Iteration  100, RMSE =     0.7831

In [83]:
# Copy the solution tensor to host memory once, then reuse it for both the
# contrast limits and the viewer (the original copied it a second time).
sol = p.x.detach().cpu().numpy()
# Robust display range: ignore the lowest 0.1% and the highest 0.01% of voxels.
clo, chi = np.percentile(sol, (0.1, 99.99))

viewer = viz.SimpleOrthoViewer(
    sol,
    cmap=cc.m_gouldian,
    norm=plt.Normalize(vmin=clo, vmax=chi),
)