Skip to content

Commit

Permalink
Merge pull request #36 from mikgroup/restructure-interp
Browse files Browse the repository at this point in the history
Restructure interpolation
  • Loading branch information
frankong committed Feb 2, 2020
2 parents 01c6320 + f232c28 commit ac8f0f2
Show file tree
Hide file tree
Showing 6 changed files with 660 additions and 592 deletions.
40 changes: 8 additions & 32 deletions sigpy/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
return output


def nufft(input, coord, oversamp=1.25, width=4.0, n=128):
def nufft(input, coord, oversamp=1.25, width=4):
"""Non-uniform Fast Fourier Transform.
Args:
Expand Down Expand Up @@ -125,9 +125,9 @@ def nufft(input, coord, oversamp=1.25, width=4.0, n=128):

# Interpolate
coord = _scale_coord(coord, input.shape, oversamp)
kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype,
backend.get_device(input))
output = interp.interpolate(output, width, kernel, coord)
output = interp.interpolate(
output, coord, kernel='kaiser_bessel', width=width, param=beta)
output /= width**ndim

return output

Expand All @@ -149,7 +149,7 @@ def estimate_shape(coord):
return shape


def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4.0, n=128):
def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
"""Adjoint non-uniform Fast Fourier Transform.
Args:
Expand Down Expand Up @@ -187,9 +187,9 @@ def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4.0, n=128):

# Gridding
coord = _scale_coord(coord, oshape, oversamp)
kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype,
backend.get_device(input))
output = interp.gridding(input, os_shape, width, kernel, coord)
output = interp.gridding(input, coord, os_shape,
kernel='kaiser_bessel', width=width, param=beta)
output /= width**ndim

# IFFT
output = ifft(output, axes=range(-ndim, 0), norm=None)
Expand Down Expand Up @@ -235,30 +235,6 @@ def _ifftc(input, oshape=None, axes=None, norm='ortho'):
return output


def _get_kaiser_bessel_kernel(n, width, beta, dtype, device):
"""Precompute Kaiser Bessel kernel.
Precomputes Kaiser-Bessel kernel with n points.
Args:
n (int): number of sampling points.
width (float): kernel width.
beta (float): kaiser bessel parameter.
dtype (dtype): output data type.
device (Device): output device.
Returns:
array: Kaiser-Bessel kernel table.
"""
device = backend.Device(device)
xp = device.xp
with device:
x = xp.arange(n, dtype=dtype) / n
kernel = 1 / width * xp.i0(beta * (1 - x**2)**0.5).astype(dtype)
return kernel


def _scale_coord(coord, shape, oversamp):
ndim = coord.shape[-1]
output = coord.copy()
Expand Down

0 comments on commit ac8f0f2

Please sign in to comment.