In [1]:
import colorcet as cc
import DeconTools.operators as op
import DeconTools.viz as viz
import numpy as np
import scipy.ndimage as ndi
import torch

%load_ext autoreload
%autoreload 2
%matplotlib qt

In [2]:
# create test object: a thin spherical shell centred in the volume,
# Gaussian-blurred so its edges are smooth
Nz, Ny, Nx = 80, 128, 128
zc, yc, xc = Nz // 2, Ny // 2, Nx // 2  # shell centre, derived from the grid shape
SHELL_R_INNER, SHELL_R_OUTER = 25, 27  # shell radii in voxel units (inclusive)
SHELL_SIGMA = 1.25  # Gaussian blur sigma in voxel units

# full (Nz, Ny, Nx) index grids; indexing="ij" keeps the (z, y, x) axis order
Z, Y, X = np.meshgrid(
    np.arange(Nz), np.arange(Ny), np.arange(Nx), indexing="ij"
)

# Euclidean distance of every voxel from the centre
R = np.sqrt((Z - zc) ** 2 + (Y - yc) ** 2 + (X - xc) ** 2)
shell_mask = (R >= SHELL_R_INNER) & (R <= SHELL_R_OUTER)  # boolean shell
sphere = ndi.gaussian_filter(shell_mask.astype(float), SHELL_SIGMA)  # float64 volume

In [3]:
# Display the blurred test volume with the project's SimpleOrthoViewer
# (presumably orthogonal slices through the volume — see DeconTools.viz).
viewer = viz.SimpleOrthoViewer(sphere)
viewer.show()

2024-12-10 15:46:02.003 python[12672:2787149] +[IMKClient subclass]: chose IMKClient_Modern
2024-12-10 15:46:02.003 python[12672:2787149] +[IMKInputSession subclass]: chose IMKInputSession_Modern


In [4]:
# Compute device: Apple-silicon GPU (MPS) when available, otherwise CPU, so the
# notebook also runs on machines without Metal. The variable keeps the name
# `mps` because later cells move tensors with `.to(mps)`.
mps = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Cast to float32 before moving to the device (MPS does not support float64).
torch_sphere = torch.from_numpy(sphere.astype(np.float32)).to(mps)

In [33]:
# Create the second-derivative finite-difference filters (numpy arrays).
# NOTE(review): the argument 3 is presumably the number of spatial dimensions —
# confirm in op.derivatives.compute_second_derivative_filters.
d2ops = op.derivatives.compute_second_derivative_filters(3)
# Move them to the torch device as float32 (MPS does not support float64).
torch_d2ops = [torch.from_numpy(h.astype(np.float32)).to(mps) for h in d2ops]

# One FFT-based linear operator per filter, on the (Nz, Ny, Nx) grid.
linops = [op.fftops.LinearOperator(h, (Nz, Ny, Nx)) for h in torch_d2ops]

# Weights: 1 for the first three filters, sqrt(2) for the last three.
# Presumably the first three are the pure second derivatives and the last three
# the mixed ones. The sqrt(2) would then account for each mixed derivative
# appearing twice in the symmetric Hessian — TODO confirm the filter order in
# op.derivatives. Use an exact Python float rather than the truncated 1.414213.
SQRT2 = 2.0 ** 0.5
d2_weights = [1.0, 1.0, 1.0, SQRT2, SQRT2, SQRT2]
assert len(linops) == len(d2_weights), "expected one weight per second-derivative filter"

# Define the composite linear operator ℝ^N -> ℝ^6N.
H = op.fftops.CompositeLinearOperator(linops, weights=d2_weights)

In [34]:
# Linear operators built from d1ops[0], d1ops[1], d1ops[2], named Dz, Dy, Dx.
# FIXME(review): `d1ops` is not defined in any cell of this notebook. It leaked
# from a deleted cell, so Restart & Run All fails here. This cell (In [34]) also
# ran after the cell that uses Dz/Dy/Dx (In [30]). Define d1ops explicitly above.
# Cast to float32 before `.to(mps)`, matching the other operator cells
# (MPS does not support float64).
Dz, Dy, Dx = (
    op.fftops.LinearOperator(
        torch.from_numpy(h.astype(np.float32)).to(mps), (Nz, Ny, Nx)
    )
    for h in d1ops[:3]
)

# Apply the composite second-derivative operator H to the test volume.
gg = H.dot(torch_sphere)

In [30]:
# Apply the Dz, Dy, Dx operators (in that order) to the test volume.
gz, gy, gx = (D.dot(torch_sphere) for D in (Dz, Dy, Dx))

# Apply Dx and then Dz to gy, giving the yx and yz cross terms.
gyx, gyz = Dx.dot(gy), Dz.dot(gy)

In [42]:
# View the sixth output of H (index 5, the last weight-≈√2 component).
# Copy it off the torch device into numpy before handing it to the viewer.
# NOTE(review): the first viewer cell calls `viewer.show()`; this one does not —
# confirm whether that is needed.
gg_5 = gg[5].cpu().numpy()
viewer = viz.SimpleOrthoViewer(gg_5, cmap=cc.m_gouldian)

In [43]:
# Sanity check: H returns one output volume per operator in `linops`
# (6 for the 3-D second-derivative set), then display the count.
assert len(gg) == len(linops), "H.dot should return one output per operator"
len(gg)

6