Skip to content

Commit

Permalink
Merge pull request #92 from mikgroup/dev
Browse files Browse the repository at this point in the history
Normal Operator API with Toeplitz NUFFT
  • Loading branch information
sidward committed Aug 18, 2021
2 parents dd77e16 + e84ca32 commit 15343d6
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 52 deletions.
17 changes: 8 additions & 9 deletions sigpy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _get_alg(self):

def _get_ConjugateGradient(self):
I = linop.Identity(self.x.shape)
AHA = self.A.H * self.A
AHA = self.A.N
AHy = self.A.H(self.y)

if self.lamda != 0:
Expand All @@ -275,13 +275,12 @@ def _get_ConjugateGradient(self):
tol=self.tol)

def _get_GradientMethod(self):
def gradf(x):
with self.y_device:
r = self.A(x)
r -= self.y
with self.y_device:
AHy = self.A.H(self.y)

def gradf(x):
with self.x_device:
gradf_x = self.A.H(r)
gradf_x = self.A.N(x) - AHy
if self.lamda != 0:
if self.z is None:
util.axpy(gradf_x, self.lamda, x)
Expand All @@ -292,7 +291,7 @@ def gradf(x):

if self.alpha is None:
I = linop.Identity(self.x.shape)
AHA = self.A.H * self.A
AHA = self.A.N
if self.lamda != 0:
AHA += self.lamda * I

Expand Down Expand Up @@ -410,7 +409,7 @@ def minL_x():
if self.z is not None:
AHy += self.lamda * self.z

AHA = self.A.H * self.A
AHA = self.A.N
I = linop.Identity(self.x.shape)
if self.G is None:
AHA += (self.lamda + self.rho) * I
Expand Down Expand Up @@ -500,7 +499,7 @@ def __init__(self, A, y, proxg, eps, x=None, G=None,
self.x_device = backend.get_device(self.x)
if G is None:
self.max_eig_app = MaxEig(
A.H * A, dtype=self.x.dtype, device=self.x_device,
A.N, dtype=self.x.dtype, device=self.x_device,
show_pbar=show_pbar)

proxfc = prox.Conj(prox.L2Proj(A.oshape, eps, y=y))
Expand Down
56 changes: 51 additions & 5 deletions sigpy/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from sigpy import backend, interp, util


__all__ = ['fft', 'ifft', 'nufft', 'nufft_adjoint', 'estimate_shape']
__all__ = ['fft', 'ifft', 'nufft', 'nufft_adjoint', 'estimate_shape',
'toeplitz_psf']


def fft(input, oshape=None, axes=None, center=True, norm='ortho'):
Expand All @@ -29,7 +30,7 @@ def fft(input, oshape=None, axes=None, center=True, norm='ortho'):
"""
xp = backend.get_array_module(input)
if not np.issubdtype(input.dtype, np.complexfloating):
input = input.astype(np.complex)
input = input.astype(np.complex64)

if center:
output = _fftc(input, oshape=oshape, axes=axes, norm=norm)
Expand Down Expand Up @@ -62,7 +63,7 @@ def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
"""
xp = backend.get_array_module(input)
if not np.issubdtype(input.dtype, np.complexfloating):
input = input.astype(np.complex)
input = input.astype(np.complex64)

if center:
output = _ifftc(input, oshape=oshape, axes=axes, norm=norm)
Expand Down Expand Up @@ -92,7 +93,6 @@ def nufft(input, coord, oversamp=1.25, width=4):
oversamp (float): oversampling factor.
width (float): interpolation kernel full-width in terms of
oversampled grid.
n (int): number of sampling points of the interpolation kernel.
Returns:
array: Fourier domain data of shape
Expand Down Expand Up @@ -167,7 +167,6 @@ def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
oversamp (float): oversampling factor.
width (float): interpolation kernel full-width in terms of
oversampled grid.
n (int): number of sampling points of the interpolation kernel.
Returns:
array: signal domain array with shape specified by oshape.
Expand Down Expand Up @@ -204,6 +203,53 @@ def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
return output


def toeplitz_psf(coord, shape, oversamp=1.25, width=4):
"""Toeplitz PSF for fast Normal non-uniform Fast Fourier Transform.
While fast, this is more computationally expensive.
Args:
coord (array): Fourier domain coordinate array of shape (..., ndim).
ndim determines the number of dimension to apply nufft adjoint.
coord[..., i] should be scaled to have its range between
-n_i // 2, and n_i // 2.
shape (tuple of ints): shape of the form
(..., n_{ndim - 1}, ..., n_1, n_0).
This is the shape of the input array of the forward nufft.
oversamp (float): oversampling factor.
width (float): interpolation kernel full-width in terms of
oversampled grid.
Returns:
array: PSF to be used by the normal operator defined in
`sigpy.linop.NUFFT`
See Also:
:func:`sigpy.linop.NUFFT`
"""
xp = backend.get_array_module(coord)
with backend.get_device(coord):
ndim = coord.shape[-1]

new_shape = _get_oversamp_shape(shape, ndim, 2)
new_coord = _scale_coord(coord, new_shape, 2)

idx = [slice(None)]*len(new_shape)
for k in range(-1, -(ndim + 1), -1):
idx[k] = new_shape[k]//2

d = xp.zeros(new_shape, dtype=xp.complex64)
d[tuple(idx)] = 1

psf = nufft(d, new_coord, oversamp, width)
psf = nufft_adjoint(psf, new_coord, d.shape, oversamp, width)
fft_axes = tuple(range(-1, -(ndim + 1), -1))
psf = fft(psf, axes=fft_axes, norm=None) * (2**ndim)

return psf


def _fftc(input, oshape=None, axes=None, norm='ortho'):

ndim = input.ndim
Expand Down
73 changes: 69 additions & 4 deletions sigpy/linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Linop():
oshape: output shape.
ishape: input shape.
H: adjoint linear operator.
N: normal linear operator.
"""
def __init__(self, oshape, ishape, repr_str=None):
Expand All @@ -60,6 +61,9 @@ def __init__(self, oshape, ishape, repr_str=None):
else:
self.repr_str = repr_str

self.adj = None
self.normal = None

def _check_ishape(self, input):
for i1, i2 in zip(input.shape, self.ishape):
if i2 != -1 and i1 != i2:
Expand Down Expand Up @@ -102,6 +106,9 @@ def apply(self, input):
def _adjoint_linop(self):
raise NotImplementedError

def _normal_linop(self):
return self.H * self

@property
def H(self):
r"""Return adjoint linear operator.
Expand All @@ -116,7 +123,24 @@ def H(self):
Linop: adjoint linear operator.
"""
return self._adjoint_linop()
if self.adj is None:
self.adj = self._adjoint_linop()
return self.adj

@property
def N(self):
r"""Return normal linear operator.
A normal linear operator :math:`A^HA` for
a linear operator :math:`A`.
Returns:
Linop: adjoint linear operator.
"""
if self.normal is None:
self.normal = self._normal_linop()
return self.normal

def __call__(self, input):
return self.__mul__(input)
Expand Down Expand Up @@ -175,6 +199,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return self

def _normal_linop(self):
return self


class ToDevice(Linop):
"""Move input between devices.
Expand Down Expand Up @@ -652,6 +679,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return Reshape(self.ishape, self.oshape)

def _normal_linop(self):
return Identity(self.ishape)


class Transpose(Linop):
"""Tranpose input with the given axes.
Expand Down Expand Up @@ -687,6 +717,9 @@ def _adjoint_linop(self):

return Transpose(oshape, axes=iaxes)

def _normal_linop(self):
return Identity(self.ishape)


class FFT(Linop):
"""FFT linear operator.
Expand All @@ -712,6 +745,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return IFFT(self.ishape, axes=self.axes, center=self.center)

def _normal_linop(self):
return Identity(self.ishape)


class IFFT(Linop):
"""IFFT linear operator.
Expand All @@ -738,6 +774,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return FFT(self.ishape, axes=self.axes, center=self.center)

def _normal_linop(self):
return Identity(self.ishape)


def _get_matmul_oshape(ishape, mshape, adjoint):
ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
Expand Down Expand Up @@ -1169,6 +1208,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return Circshift(self.ishape, [-s for s in self.shift], axes=self.axes)

def _normal_linop(self):
return Identity(self.ishape)


class Wavelet(Linop):
"""Wavelet transform linear operator.
Expand Down Expand Up @@ -1262,7 +1304,6 @@ def _apply(self, input):
return xp.sum(input, axis=self.axes)

def _adjoint_linop(self):

return Tile(self.ishape, self.axes)


Expand Down Expand Up @@ -1333,6 +1374,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return BlocksToArray(self.ishape, self.blk_shape, self.blk_strides)

def _normal_linop(self):
return Identity(self.ishape)


class BlocksToArray(Linop):
"""Accumulate blocks into an array in a sliding window manner.
Expand Down Expand Up @@ -1365,6 +1409,9 @@ def _apply(self, input):
def _adjoint_linop(self):
return ArrayToBlocks(self.oshape, self.blk_shape, self.blk_strides)

def _normal_linop(self):
return Identity(self.ishape)


def Gradient(ishape, axes=None):
import warnings
Expand Down Expand Up @@ -1407,13 +1454,14 @@ class NUFFT(Linop):
coord (array): Coordinates, with values [-ishape / 2, ishape / 2]
oversamp (float): Oversampling factor.
width (float): Kernel width.
n (int): Kernel sampling number.
toeplitz (bool): Use toeplitz PSF to evaluate normal operator.
"""
def __init__(self, ishape, coord, oversamp=1.25, width=4):
def __init__(self, ishape, coord, oversamp=1.25, width=4, toeplitz=False):
self.coord = coord
self.oversamp = oversamp
self.width = width
self.toeplitz = toeplitz

ndim = coord.shape[-1]

Expand All @@ -1433,6 +1481,23 @@ def _adjoint_linop(self):
return NUFFTAdjoint(self.ishape, self.coord,
oversamp=self.oversamp, width=self.width)

def _normal_linop(self):
if self.toeplitz is False:
return self.H * self

ndim = self.coord.shape[-1]
psf = fourier.toeplitz_psf(self.coord, self.ishape, self.oversamp,
self.width)

fft_axes = tuple(range(-1, -(ndim + 1), -1))

R = Resize(psf.shape, self.ishape)
F = FFT(psf.shape, axes=fft_axes)
P = Multiply(psf.shape, psf)
T = R.H * F.H * P * F * R

return T


class NUFFTAdjoint(Linop):
"""NUFFT adjoint linear operator.
Expand Down
2 changes: 1 addition & 1 deletion sigpy/mri/rf/adiabatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def bir4(n, beta, kappa, theta, dw0):
a3 = np.tanh(beta * (3 - 4 * t[n // 2:3 * n // 4]))
a4 = np.tanh(beta * (4 * t[3 * n // 4:] - 3))

a = np.concatenate((a1, a2, a3, a4)).astype(complex)
a = np.concatenate((a1, a2, a3, a4)).astype(np.complex64)
a[n // 4:3 * n // 4] = a[n // 4:3 * n // 4] * np.exp(1j * dphi)

om1 = dw0 * np.tan(kappa * 4 * t[:n // 4]) / np.tan(kappa)
Expand Down
8 changes: 4 additions & 4 deletions sigpy/mri/rf/b1sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def dz_b1_rf(dt=2e-6, tb=4, ptype='st', flip=np.pi / 6, pbw=0.3,
pulse_len = tb / b

# calculate number of samples in pulse
n = np.int(np.ceil(pulse_len / dt / 2) * 2)
n = int(np.ceil(pulse_len / dt / 2) * 2)

if pbc == 0:
# we want passband as close to zero as possible.
Expand All @@ -72,7 +72,7 @@ def dz_b1_rf(dt=2e-6, tb=4, ptype='st', flip=np.pi / 6, pbw=0.3,
ds = np.double(np.abs(ii) > f[2])

# shift the target pattern to minimum center position
pbc = np.int(np.ceil((f[2] - f[1]) * n * os / 2 + f[1] * n * os / 2))
pbc = int(np.ceil((f[2] - f[1]) * n * os / 2 + f[1] * n * os / 2))
dl = np.roll(d, pbc)
dr = np.roll(d, -pbc)
dsl = np.roll(ds, pbc)
Expand Down Expand Up @@ -155,7 +155,7 @@ def dz_b1_gslider_rf(dt=2e-6, g=5, tb=12, ptype='st', flip=np.pi / 6,
pulse_len = tb / b

# calculate number of samples in pulse
n = np.int(np.ceil(pulse_len / dt / 2) * 2)
n = int(np.ceil(pulse_len / dt / 2) * 2)

om = 2 * np.pi * 4257 * pbc # modulation freq to center profile at pbc
t = np.arange(0, n) * pulse_len / n - pulse_len / 2
Expand Down Expand Up @@ -228,7 +228,7 @@ def dz_b1_hadamard_rf(dt=2e-6, g=8, tb=16, ptype='st', flip=np.pi / 6,
pulse_len = tb / b

# calculate number of samples in pulse
n = np.int(np.ceil(pulse_len / dt / 2) * 2)
n = int(np.ceil(pulse_len / dt / 2) * 2)

# modulation frequency to center profile at pbc gauss
om = 2 * np.pi * 4257 * pbc
Expand Down

0 comments on commit 15343d6

Please sign in to comment.