In [None]:
using Documenter, EllipsisNotation, FFTW, PaddedViews, SpecialFunctions, Base.Cartesian, Base.Threads

In [2]:
using LinearAlgebra, MacroTools, NFFT, BenchmarkTools
using PyCall, Libdl
Libdl.dlopen(ENV["HOME"]*"/.local/lib/python3.6/site-packages/llvmlite/binding/libllvmlite.so",
    Libdl.RTLD_DEEPBIND);
py"""
from math import ceil
import numpy as np
import sigpy as sp
from sigpy import util, interp
"""

In [3]:
@doc raw"""
Non-uniform Fast Fourier Transform.

**Arguments**:
- `input (ArrayType{T} or ArrayType{Complex{T}})`: input signal domain array of shape
    ``(n_k, \ldots, n_{ndim + 1}, n_{ndim}, \ldots, n_2, n_1)``,
    where ``ndim`` is specified by `size(coord)[end]`. The nufft
    is applied on the last ``ndim`` axes, and looped over
    the remaining axes. `ArrayType` can be any `AbstractArray`.
- `coord (ArrayType{T})`: Fourier domain coordinate array of shape ``(m_l, \ldots, m_1, ndim)``.
    ``ndim`` determines the number of dimensions to apply the nufft.
    `coord[..., i]` should be scaled to have its range ``[-n_i \div 2, n_i \div 2]``.
- `oversamp (Float32)`: oversampling factor (default: $1.25$)
- `width (Int64)`: interpolation kernel full-width in terms of
    oversampled grid. (default: $4$)
- `n (Int64)`: number of sampling points of the interpolation kernel. (default: $128$)

**Returns**:
- `ArrayType{Complex{T}}`: Fourier domain data of shape
    ``(n_k, \ldots, n_{ndim + 1}, m_l, \ldots, m_1)``.

**References**:
- Fessler, J. A., & Sutton, B. P. (2003).
  "Nonuniform fast Fourier transforms using min-max interpolation",
  *IEEE Transactions on Signal Processing*, 51(2), 560-574.
- Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
  "Rapid gridding reconstruction with a minimal oversampling ratio,"
  *IEEE transactions on medical imaging*, 24(6), 799-808.

"""
function nufft(
        input::AbstractArray{Complex{T}},
        coord::AbstractArray{T},
        oversamp::Float64 = 1.25,
        width::Int64 = 4,
        n::Int64 = 128) where {T}
    
    ndim = size(coord)[end]
    @assert(ndims(input) ≥ ndim,
        "The size of coord along the last dimension should be greater or equal then the dimensionality of input.")
    
    β = π * √(((width / oversamp) * (oversamp - 0.5))^2 - 0.8)
    shape = size(input)
    oversampled_shape = _get_oversamp_shape(shape, ndim, oversamp)

    output = copy(input)

    # Apodize
    _apodize!(output, ndim, oversamp, width, β)

    # Zero-pad
    output /= √(prod(shape[end-ndim+1:end]))
    shift = oversampled_shape .÷ 2 .- shape .÷ 2 .+ 1
    output = PaddedView(0, output, oversampled_shape, shift)

    # FFT
    all_dims = ndims(input)
    output = centering_fft!(convert.(Complex, output), tuple((all_dims-ndim+1:all_dims)...))

    # Interpolate
    coord = _scale_coord(coord, size(input), oversamp)
    x = range(0, stop=n-1, step=1) ./ n
    kernel = window_kaiser_bessel.(x, width, β)
    return interpolate(output, width, kernel, coord)
    
end

nufft

In [4]:
function nufft(
        input::AbstractArray{T},
        coord::AbstractArray{T},
        oversamp::T = 1.25,
        width::Int64 = 4,
        n::Int64 = 128) where {
            T<:Real,
            FloatArray<:AbstractArray{T}}
    return nufft(convert.(Complex, img), coord, oversamp, width)
end

nufft (generic function with 8 methods)

In [79]:
@doc raw"""
Adjoint non-uniform Fast Fourier Transform.

**Arguments**:
- `input (ArrayType{T} or ArrayType{Complex{T}})`: input Fourier domain array of shape
    ``(n_k, \ldots, n_{l + 1}, m_l, \ldots, m_1)``,
    where ``ndim`` is specified by `size(coord)[end]`.
    That is, the last dimensions
    of input must match the first dimensions of coord.
    The nufft_adjoint is applied on the last coord.ndim - 1 axes,
    and looped over the remaining axes.
- `coord (ArrayType{T})`: Fourier domain coordinate array of shape ``(m_l, \ldots, m_1, ndim)``.
    ``ndim`` determines the number of dimensions to apply the nufft.
    `coord[..., i]` should be scaled to have its range ``[-n_i \div 2, n_i \div 2]``.
- `oshape (NTuple{N, Int})`: output shape of the form
            ``(o_l, \ldots, o_{ndim + 1}, n_{ndim}, \ldots, n_2, n_1)``. (optional)
- `oversamp (Float32)`: oversampling factor (default: $1.25$)
- `width (Int64)`: interpolation kernel full-width in terms of
    oversampled grid. (default: $4$)
- `n (Int64)`: number of sampling points of the interpolation kernel. (default: $128$)

**Returns**:
- `ArrayType{Complex{T}}`: Fourier domain data of shape
    ``(n_k, \ldots, n_{ndim + 1}, n_{ndim}, \ldots, n_2, n_1)`` or
    ``(o_l, \ldots, o_{ndim + 1}, n_{ndim}, \ldots, n_2, n_1)`` if `oshape` is given.

**References**:
- Fessler, J. A., & Sutton, B. P. (2003).
  "Nonuniform fast Fourier transforms using min-max interpolation",
  *IEEE Transactions on Signal Processing*, 51(2), 560-574.
- Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
  "Rapid gridding reconstruction with a minimal oversampling ratio,"
  *IEEE transactions on medical imaging*, 24(6), 799-808.

"""
function nufft_adjoint(
        input::AbstractArray{Complex{T}},
        coord::AbstractArray{T},
        oshape::Union{NTuple{N, Int}, Nothing} = nothing,
        oversamp::Float64 = 1.25,
        width::Int64 = 4,
        n::Int64 = 128) where {T, N}
    
    ndim = size(coord)[end]
    
    β = π * √(((width / oversamp) * (oversamp - 0.5))^2 - 0.8)
    (oshape isa Nothing) && (oshape = tuple(size(input)[1:end-ndims(coord)+1]..., estimate_shape(coord)...))
    oversampled_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding
    coord = _scale_coord(coord, oshape, oversamp)
    x = range(0, stop=n-1, step=1) ./ n
    kernel = window_kaiser_bessel.(x, width, β)
    output = gridding(input, oversampled_shape, width, kernel, coord)

    # IFFT
    all_dims = ndims(output)
    output = centering_ifft!(convert.(Complex, output), tuple((all_dims-ndim+1:all_dims)...))

    # Crop
    output = resize(output, oshape)
    output .*= prod(oversampled_shape[end-ndim+1:end]) ./ √(prod(oshape[end-ndim+1:end]))

    # Apodize
    _apodize!(output, ndim, oversamp, width, β)

    return output
end

nufft_adjoint

In [80]:
"""
Estimate array shape from coordinates.

Shape is estimated by the different between maximum and minimum of
coordinates in each axis.

Args:
    `coord (AbstractArray)`: Coordinates.
"""
function estimate_shape(coord)
    dims = tuple((1:ndims(coord)-1)...)
    return floor.(Int64, dropdims(maximum(coord, dims=dims), dims=dims) -
        dropdims(minimum(coord, dims=dims), dims=(dims)))
end

estimate_shape

In [81]:
function fftshift!(
        output::AbstractArray,
        input::AbstractArray,
        dims::NTuple{N,Int}) where {N}
    
    @assert input !== output "input and output must be two distinct arrays"
    @assert any(dims .> 0) "dims can contain only positive values!"
    @assert any(dims .<= ndims(input)) "dims cannot contain larger value than ndims(input) (=$(ndims(input)))"
    @assert size(output) == size(input) "input and output must have the same size"
    @assert eltype(output) == eltype(input) "input and output must have the same eltype"
    
    shifts = [dim in dims ? size(input, dim) ÷ 2 : 0 for dim in 1:ndims(input)]
    circshift!(output, input, shifts)
    
end

function ifftshift!(
        output::AbstractArray,
        input::AbstractArray,
        dims::NTuple{N,Int}) where {N}
    
    @assert input !== output "input and output must be two distinct arrays"
    @assert any(dims .> 0) "dims can contain only positive values!"
    @assert any(dims .<= ndims(input)) "dims cannot contain larger value than ndims(input) (=$(ndims(input)))"
    @assert size(output) == size(input) "input and output must have the same size"
    @assert eltype(output) == eltype(input) "input and output must have the same eltype"
    
    shifts = [dim in dims ? size(input, dim) ÷ 2 + size(input, dim) % 2 : 0 for dim in 1:ndims(input)]
    circshift!(output, input, shifts)
    
end

fftshift!(output::AbstractArray, input::AbstractArray, dims::Int) =
    fftshift!(output, input, (dims,))

ifftshift!(output::AbstractArray, input::AbstractArray, dims::Int) =
    ifftshift!(output, input, (dims,))

ifftshift! (generic function with 2 methods)

In [82]:
"""
FFT function that supports centering.

**Arguments**:
    input (AbstractArray{T,N}): input array.
    dims (NTuple{K,Int64}): Axes over which to compute the FFT (optional).

**Returns**:
    AbstractArray: FFT result.

"""
function centering_fft!(
        input::AbstractArray{Complex{T},N},
        dims::Union{NTuple{K,Int64},Nothing} = nothing) where {N,K,T}
    (dims isa Nothing) && (dims = tuple(collect(1:ndims(input))...))
    output = ifftshift(input, dims)
    fft!(output, dims)
    fftshift!(input, output, dims)
    input
end

centering_fft!

In [83]:
"""
inverse FFT function that supports centering.

**Arguments**:
    input (AbstractArray{T,N}): input array.
    dims (NTuple{K,Int64}): Axes over which to compute the inverse FFT (optional).

**Returns**:
    AbstractArray: inverse FFT result.

"""
function centering_ifft!(
        input::AbstractArray{Complex{T},N},
        dims::Union{NTuple{K,Int64},Nothing} = nothing) where {N,K,T}
    (dims isa Nothing) && (dims = tuple(collect(1:ndims(input))...))
    output = ifftshift(input, dims)
    ifft!(output, dims)
    fftshift!(input, output, dims)
    input
end

centering_ifft!

In [84]:
"""
Resize with zero-padding or cropping.

**Arguments**:
    `input (AbstractArray{T,N})`: Input array.
    `oshape (NTuple{N,Int64})`: Output shape.
    `ishift (NTuple{N,Int64})`: Input shift (optional).
    `oshift (NTuple{N,Int64})`: Output shift (optional).

**Returns**:
    `array`: Zero-padded or cropped result.
"""
function resize(
        input::AbstractArray{T,N},
        oshape::NTuple{N,Int64},
        ishift::Union{NTuple{N,Int64}, Nothing} = nothing,
        oshift::Union{NTuple{N,Int64}, Nothing} = nothing) where {N,T}

    ishape1, oshape1 = _expand_shapes(size(input), oshape)

    if ishape1 == oshape1
        return reshape(input, oshape)
    end

    if ishift isa Nothing
        ishift = [max(i ÷ 2 - o ÷ 2, 0) for (i, o) in zip(collect(ishape1), collect(oshape1))]
    end

    if oshift isa Nothing
        oshift = [max(o ÷ 2 - i ÷ 2, 0) for (i, o) in zip(collect(ishape1), collect(oshape1))]
    end

    copy_shape = [min(i - si, o - so)
                  for (i, si, o, so) in zip(collect(ishape1), ishift, collect(oshape1), oshift)]
    islice = collect(si+1:si+c for (si, c) in zip(ishift, copy_shape))
    oslice = collect(so+1:so+c for (so, c) in zip(oshift, copy_shape))

    output = zeros(eltype(input), oshape1)
    input = reshape(input, ishape1)
    output[oslice...] = input[islice...]

    return reshape(output, oshape)
end

resize

In [85]:
function _get_oversamp_shape(shape, ndim, oversamp)
    return tuple(vcat(shape[1:end-ndim]..., [ceil(Int64, oversamp * i) for i in shape[end-ndim+1:end]]...)...)
end

_get_oversamp_shape (generic function with 1 method)

In [86]:
function _apodize!(signal, apodized_dims, oversamp, width, β)
    
    all_dims = ndims(signal)
    untouched_dims = all_dims - apodized_dims
    for axis in range(untouched_dims + 1, all_dims, step=1)
        axis_size = size(signal, axis)
        oversampled_size = ceil(oversamp * axis_size)
        
        # Calculate apodization window
        recip_iFFT_Kaiser_Bessel_kernel(x) = begin
            tmp = √(β^2 - (π * width * (x - axis_size ÷ 2) / oversampled_size)^2)
            tmp /= sinh(tmp)
            tmp
        end
        window = recip_iFFT_Kaiser_Bessel_kernel.(0:axis_size-1)
        
        # Apply point-wise along selected axis, broadcast along all other dimensions
        broadcast_shape = ones(Int64, all_dims)
        broadcast_shape[axis] = axis_size
        signal .*= reshape(window, broadcast_shape...)
    end
    return signal
end

_apodize! (generic function with 1 method)

In [87]:
function _scale_coord(coord, shape, oversamp)
    ndim = size(coord)[end]
    scale = reshape([ceil(oversamp * i) / i for i in shape[end-ndim+1:end]], 1, ndim)
    shift = reshape([ceil(oversamp * i) ÷ 2 for i in shape[end-ndim+1:end]], 1, ndim)
    return scale .* coord .+ shift
end

_scale_coord (generic function with 1 method)

In [88]:
window_kaiser_bessel(x::Real, m::Int64, β::Real)::Real = 1 / m * besseli(0, β * √(1 - x^2))

window_kaiser_bessel (generic function with 1 method)

In [89]:
function _expand_shapes(shapes...)
    max_ndim = maximum(length, shapes)
    return map(shape -> tuple(vcat(repeat([1], max_ndim - length(shape)), collect(shape))...), shapes)
end

_expand_shapes (generic function with 1 method)

In [90]:
@doc raw"""
    Interpolation from array to points specified by coordinates.

    Let ``x`` be the input, ``y`` be the output,
    ``c`` be the coordinates, ``W`` be the kernel width,
    and ``K`` be the interpolation kernel, then the function computes,

    ```math
        y[j] = \sum_{i : \| i - c[j] \|_\infty \leq W / 2}
               K\left(\frac{i - c[j]}{W / 2}\right) x[i]
    ```

    **Arguments**:
        `input (AbstractArray)`: Input array of shape
            ``(n_1, n_2, \ldots, n_{ndim})``.
        `width (Int)`: Interpolation kernel full-width.
        `kernel (AbstractArray{T, 1})`: Interpolation kernel.
        `coord (AbstractArray)`: Coordinate array of shape ``(m_l, \ldots, m_1, ndim)``

    **Returns**:
        `output (AbstractArray)`: Output array of shape ``(m_l, \ldots, m_1)``
    """
function interpolate(
        input::AbstractArray,
        width::Int,
        kernel::AbstractArray{T, 1},
        coord::AbstractArray{T}) where {T<:Real}
    ndim = size(coord, 2)
    npts = size(coord, 1)
    @assert(ndims(input) ≥ ndim, "The size of coord along the last dimension
        should be greater or equal then the dimensionality of input.")
    
    is_complex(x::AbstractArray{Complex{T}}) where T = true
    is_complex(x::AbstractArray{T}) where T = false
    get_complex_subtype(x::AbstractArray{Complex{T}}) where T = T
    if is_complex(input)
        @assert(get_complex_subtype(input) == eltype(coord),
            "Precision of eltype of input and coord should match: $(get_complex_subtype(input)) in $(eltype(input)) vs $(eltype(coord))")
    else
        @assert(eltype(input) == eltype(coord),
            "Precision of eltype of input and coord should match: $(eltype(input)) vs $(eltype(coord))")
    end
    
    batch_shape = size(input)[1:end-ndim]
    batch_size = prod(batch_shape)

    pts_shape = size(coord)[1:end-1]
    npts = prod(pts_shape)

    input = reshape(input, tuple(batch_size, size(input)[end-ndim+1:end]...))
    coord = reshape(coord, (npts, ndim))
    output = zeros(eltype(input), (batch_size, npts))

    _interpolate!(output, input, coord, kernel, width)

    return reshape(output, tuple(batch_shape..., pts_shape...))
end

interpolate

In [91]:
@doc raw"""
    Gridding of points specified by coordinates to array.

    Let ``x`` be the input, ``y`` be the output,
    ``c`` be the coordinates, ``W`` be the kernel width,
    and ``K`` be the interpolation kernel, then the function computes,

    ```math
        y[j] = \sum_{i : \| i - c[j] \|_\infty \leq W / 2}
               K\left(\frac{i - c[j]}{W / 2}\right) x[i]
    ```

    **Arguments**:
        `input (AbstractArray)`: Input array of shape
            ``(m_l, \ldots, m_1)``.
        `oshape (Ntuple{N, Int})`: Shape of output
        `width (Real or NTuple{N, Real})`: Interpolation kernel full-width.
        `kernel (AbstractArray{T, 1})`: Interpolation kernel.
        `coord (AbstractArray)`: Coordinate array of shape ````(m_l, \ldots, m_1, ndim)``

    **Returns**:
        `output (AbstractArray)`: Output array.
    """
function gridding(
        input::AbstractArray,
        oshape::NTuple{N, Int},
        width::Int64,
        kernel::AbstractArray{T, 1},
        coord::AbstractArray{T}) where {T<:Real, N}
    
    ndim = size(coord, 2)
    batch_shape = oshape[1:end-ndim]
    batch_size = prod(batch_shape)
    
    pts_shape = size(coord)[1:end-1]
    npts = prod(pts_shape)
    
    @assert(size(input)[length(batch_shape)+1:end] == pts_shape,
        "The size of input must match the size of coord, excluding the last dimension.")
    
    is_complex(x::AbstractArray{Complex{T}}) where T = true
    is_complex(x::AbstractArray{T}) where T = false
    get_complex_subtype(x::AbstractArray{Complex{T}}) where T = T
    if is_complex(input)
        @assert(get_complex_subtype(input) == eltype(coord),
            "Precision of eltype of input and coord should match: $(get_complex_subtype(input)) in $(eltype(input)) vs $(eltype(coord))")
    else
        @assert(eltype(input) == eltype(coord),
            "Precision of eltype of input and coord should match: $(eltype(input)) vs $(eltype(coord))")
    end

    input = reshape(input, tuple(batch_size, npts))
    coord = reshape(coord, (npts, ndim))
    output = zeros(eltype(input), tuple(batch_size, oshape[end-ndim+1:end]...))

    _gridding!(output, input, coord, kernel, width)

    return reshape(output, oshape)
end

gridding

In [18]:
function lin_interpolate(kernel::AbstractArray{T,1}, x::T)::T where T
    x >= 1 && return zero(x)
    
    n = length(kernel)
    idx = floor(Int64, x * n)
    frac = x * n - idx

    left = kernel[idx + 1]
    right = idx < n - 1 ? kernel[idx + 2] : zero(x)
    
    return (1 - frac) * left + frac * right
end

lin_interpolate (generic function with 1 method)

In [19]:
macro interpolate_point(ndim)
    batch_plus_ndim = ndim + 1
    esc(quote
        input_shape = size(input)
        batch_size = input_shape[1]
        @nextract($ndim, n, d -> input_shape[$batch_plus_ndim-d+1])
        @nextract($ndim, interval_middle,
                d -> coord[point_index, $batch_plus_ndim - d])
        @nextract($ndim, interval_start,
                d -> ceil(Int64, interval_middle_d - width / 2))
        @nextract($ndim, interval_end,
                d -> floor(Int64, interval_middle_d + width / 2))
        $(Symbol(:w_, batch_plus_ndim)) = 1
        @nloops($ndim, i, d -> interval_start_d:interval_end_d,
            d -> w_d = w_{d+1} *
                lin_interpolate(kernel, abs(i_d - interval_middle_d) / (width / 2)),
            for b in 1:batch_size
                output[b, point_index] += w_1 * input[b,
                    @ntuple($ndim, d -> (i_{$ndim-d+1} + n_{$ndim-d+1}) % n_{$ndim-d+1} + 1)...]
            end)
    end)
end

@interpolate_point (macro with 1 method)

In [195]:
macro gridding_point(ndim)
    batch_plus_ndim = ndim + 1
    esc(quote
        output_shape = size(output[1])
        batch_size = output_shape[1]
        @nextract($ndim, n, d -> output_shape[$batch_plus_ndim-d+1])
        @nextract($ndim, interval_middle,
                d -> coord[point_index, $batch_plus_ndim - d])
        @nextract($ndim, interval_start,
                d -> ceil(Int64, interval_middle_d - width / 2))
        @nextract($ndim, interval_end,
                d -> floor(Int64, interval_middle_d + width / 2))
        $(Symbol(:w_, batch_plus_ndim)) = 1
        @nloops($ndim, i, d -> interval_start_d:interval_end_d,
            d -> w_d = w_{d+1} *
                lin_interpolate(kernel, abs(i_d - interval_middle_d) / (width / 2)),
            for b in 1:batch_size
                output[thread_id][b, @ntuple($ndim, d -> (i_{$ndim-d+1} + n_{$ndim-d+1}) % n_{$ndim-d+1} + 1)...] +=
                    w_1 * input[b, point_index] 
            end)
    end)
end

@gridding_point (macro with 1 method)

In [196]:
# It is not a necessary function, but might help to debug transformed code
# Add line numbers matching the lines of transformed code (the call stack will be more 
# informative then if an error occures inside the transformed code)
function addLineNumbers(expr)
    # Remove previous line numbers and add correct ones instead
    result = Meta.parse(string(MacroTools.striplines(expr)))
    # fix LineNumberNodes to be more informative
    MacroTools.postwalk(result) do x
        x isa LineNumberNode ?
            LineNumberNode(x.line, Symbol("generated_code_of_interpolate")) :
            x
    end
end

addLineNumbers (generic function with 1 method)

In [197]:
function pretty_print_expression(expr; withLineNumbers=false)
    expr = MacroTools.prewalk(x -> MacroTools.isgensym(x) ? Symbol(MacroTools.gensymname(x)) : x, expr) 
    expr = MacroTools.prewalk(unblock, expr)
    expr = addLineNumbers(expr)
    print(withLineNumbers ? expr : MacroTools.striplines(expr))
end

pretty_print_expression (generic function with 1 method)

In [198]:
pretty_print_expression(@macroexpand(@gridding_point(3)), withLineNumbers=false)

begin
    output_shape = size(output[1])
    batch_size = output_shape[1]
    begin
        n_1 = output_shape[4]
        n_2 = output_shape[3]
        n_3 = output_shape[2]
    end
    begin
        interval_middle_1 = coord[point_index, 3]
        interval_middle_2 = coord[point_index, 2]
        interval_middle_3 = coord[point_index, 1]
    end
    begin
        interval_start_1 = ceil(Int64, interval_middle_1 - width / 2)
        interval_start_2 = ceil(Int64, interval_middle_2 - width / 2)
        interval_start_3 = ceil(Int64, interval_middle_3 - width / 2)
    end
    begin
        interval_end_1 = floor(Int64, interval_middle_1 + width / 2)
        interval_end_2 = floor(Int64, interval_middle_2 + width / 2)
        interval_end_3 = floor(Int64, interval_middle_3 + width / 2)
    end
    w_4 = 1
    for i_3 = interval_start_3:interval_end_3
        w_3 = w_4 * lin_interpolate(kernel, abs(i_3 - interval_middle_3) / (width / 2))
        for i_2 = interval_start_2:interval_end_2
 

In [199]:
@generated function _interpolate_point!(
        output::AbstractArray{Complex{T}, 2},
        input::AbstractArray{Complex{T}, D},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64,
        point_index::Int64) where {T, D}
    quote
        @interpolate_point $(D-1)
    end
end

_interpolate_point! (generic function with 2 methods)

In [200]:
@generated function _interpolate_point!(
        output::AbstractArray{T, 2},
        input::AbstractArray{T, D},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64,
        point_index::Int64) where {T, D}
    quote
        @interpolate_point $(D-1)
    end
end

_interpolate_point! (generic function with 2 methods)

In [201]:
function _interpolate!(
        output::AbstractArray{Complex{T}, 2},
        input::AbstractArray{Complex{T}, D},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64) where {T, D}
    Threads.@threads for point_index in 1:size(coord, 1)
        _interpolate_point!(output, input, coord, kernel, width, point_index)
    end
end

_interpolate! (generic function with 2 methods)

In [202]:
function _interpolate!(
        output::AbstractArray{T, 2},
        input::AbstractArray{T, D},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64) where {T, D}
    Threads.@threads for point_index in 1:size(coord, 1)
        _interpolate_point!(output, input, coord, kernel, width, point_index)
    end
end

_interpolate! (generic function with 2 methods)

In [203]:
@generated function _gridding_point!(
        output::NTuple{N, AbstractArray{Complex{T}, D}},
        input::AbstractArray{Complex{T}, 2},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int,
        point_index::Int,
        thread_id::Int) where {T, D, N}
    quote
        @gridding_point $(D-1)
    end
end

_gridding_point! (generic function with 6 methods)

In [204]:
@generated function _gridding_point!(
        output::NTuple{N, AbstractArray{T, D}},
        input::AbstractArray{T, 2},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64,
        point_index::Int,
        thread_id::Int) where {T, D, N}
    quote
        @gridding_point $(D-1)
    end
end

_gridding_point! (generic function with 6 methods)

In [265]:
function __gridding!(
        output::AbstractArray,
        input::AbstractArray{Complex{T}, 2},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64) where {T, D}
    if Threads.nthreads() > 1 && # if threading enabled
            size(input, 2) > 10000 && # if the problem large enough
            Sys.free_memory() * .8 > sizeof(output) * (Threads.nthreads()-1) # if we have enough memory
        threaded_output = tuple(output, [copy(output) for _ in 2:Threads.nthreads()]...)
        Threads.@threads for point_index in 1:size(coord, 1)
            _gridding_point!(threaded_output, input, coord, kernel, width, point_index, Threads.threadid())
        end
        for i in 2:length(threaded_output)
            output .+= threaded_output[i]
        end
    else
        for point_index in 1:size(coord, 1)
            _gridding_point!((output,), input, coord, kernel, width, point_index, 1)
        end
    end
end

__gridding! (generic function with 1 method)

In [240]:
function _gridding!(
        output::AbstractArray{Complex{T}, D},
        input::AbstractArray{Complex{T}, 2},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64) where {T, D}
    __gridding!(output, input, coord, kernel, width)
end

_gridding! (generic function with 2 methods)

In [241]:
function _gridding!(
        output::AbstractArray{T, D},
        input::AbstractArray{T, 2},
        coord::AbstractArray{T, 2},
        kernel::AbstractArray{T, 1},
        width::Int64) where {T, D}
    __gridding!(output, input, coord, kernel, width)
end

_gridding! (generic function with 2 methods)

## Compare with reference implementation

### Comparing output

In [178]:
sp = pyimport("sigpy")
interp = pyimport("sigpy.interp")
util = pyimport("sigpy.util");

In [36]:
py"""
def nufft(input, coord, oversamp=1.25, width=4.0, n=128):
    ndim = coord.shape[-1]
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    output = input.copy()

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    # Zero-pad
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # FFT
    output = sp.fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate
    coord = _scale_coord(coord, input.shape, oversamp)
    kernel = _get_kaiser_bessel_kernel(n, width, beta)
    output = interp.interpolate(output, width, kernel, coord)

    return output

def _get_kaiser_bessel_kernel(n, width, beta):
    x = np.arange(n) / n
    kernel = 1 / width * np.i0(beta * (1 - x**2)**0.5)
    return kernel

def _scale_coord(coord, shape, oversamp):
    ndim = coord.shape[-1]
    scale = [ceil(oversamp * i) / i for i in shape[-ndim:]]
    shift = [ceil(oversamp * i) // 2 for i in shape[-ndim:]]

    coord = scale * coord + shift

    return coord

def _get_oversamp_shape(shape, ndim, oversamp):
    return list(shape)[:-ndim] + [ceil(oversamp * i) for i in shape[-ndim:]]

def estimate_shape(coord):
    ndim = coord.shape[-1]
    return [int(coord[..., i].max() - coord[..., i].min()) for i in range(ndim)]

def _apodize(input, ndim, oversamp, width, beta):

    output = input
    for a in range(-ndim, 0):
        i = output.shape[a]
        os_i = ceil(oversamp * i)
        idx = np.arange(i)

        # Calculate apodization
        apod = (beta**2 - (np.pi * width * (idx - i // 2) / os_i)**2)**0.5
        apod /= np.sinh(apod)
        output *= apod.reshape([i] + [1] * (-a - 1))

    return output

def interpolate(input, width, kernel, coord):
    ndim = coord.shape[-1]

    batch_shape = input.shape[:-ndim]
    batch_size = util.prod(batch_shape)

    pts_shape = coord.shape[:-1]
    npts = util.prod(pts_shape)

    isreal = np.issubdtype(input.dtype, np.floating)

    input = input.reshape([batch_size] + list(input.shape[-ndim:]))
    coord = coord.reshape([npts, ndim])
    output = np.zeros([batch_size, npts], dtype=input.dtype)

    _interpolate3(output, input, width, kernel, coord)

    return output.reshape(batch_shape + pts_shape)

def _interpolate3(output, input, width, kernel, coord):
    batch_size, nz, ny, nx = input.shape
    npts = coord.shape[0]

    for i in range(npts):

        kx, ky, kz = coord[i, -1], coord[i, -2], coord[i, -3]

        x0, y0, z0 = (np.ceil(kx - width / 2).astype(int),
                      np.ceil(ky - width / 2).astype(int),
                      np.ceil(kz - width / 2).astype(int))

        x1, y1, z1 = (np.floor(kx + width / 2).astype(int),
                      np.floor(ky + width / 2).astype(int),
                      np.floor(kz + width / 2).astype(int))

        for z in range(z0, z1 + 1):
            wz = lin_interpolate(kernel, abs(z - kz) / (width / 2))

            for y in range(y0, y1 + 1):
                wy = wz * lin_interpolate(kernel, abs(y - ky) / (width / 2))

                for x in range(x0, x1 + 1):
                    w = wy * lin_interpolate(kernel, abs(x - kx) / (width / 2))

                    for b in range(batch_size):
                        output[b, i] += w * input[b, z % nz, y % ny, x % nx]

    return output

def lin_interpolate(kernel, x):
    if x >= 1:
        return 0.0
    n = len(kernel)
    idx = int(x * n)
    frac = x * n - idx

    left = kernel[idx]
    if idx == n - 1:
        right = 0.0
    else:
        right = kernel[idx + 1]
    return (1.0 - frac) * left + frac * right
"""

In [242]:
M, shape = 1000, (34, 30, 68) #4186100, (34, 30, 68)
img = rand(Float64, shape)
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2;

In [243]:
width, n, oversamp = 4, 128, 1.25
ndim = ndims(img)
β = π * √(((width / oversamp) * (oversamp - 0.5))^2 - 0.8)

6.996659047674343

Apodization implementations are identical:

In [244]:
signal = copy(img)
output_j = _apodize!(signal, ndim, oversamp, width, β)
output_py = py"_apodize"(signal, ndim, oversamp, width, β)
print("absolute error: ", norm(output_j - output_py, Inf), "\n")

absolute error: 0.0


But there is a significant difference between FFTW and numpy's FFT:

In [245]:
output_j = centering_fft!(convert.(Complex, img), (1,2,3))
output_py = py"sp.fft($img, axes=(1,2,3), norm=None)"
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 7.275957614183426e-12
relative error: 8.707166080035383e-14

Also, there is some (but magnitudes smaller) difference between Julia's and numpy's Bessel function:

In [246]:
kernel_py = py"_get_kaiser_bessel_kernel"(n, width, β)
x = range(0, stop=n-1, step=1) ./ n
kernel_j = window_kaiser_bessel.(x, width, β)
print("absolute error: ", norm(kernel_j - kernel_py, Inf), "\n",
    "relative error: ", norm((kernel_j - kernel_py) ./ kernel_py, Inf))

absolute error: 1.4210854715202004e-14
relative error: 4.753447469672048e-16

On the other hand, the interpolation function also appears to be identical:

In [247]:
output_j = interpolate(img, width, kernel_j, coord);
output_py = interp.interpolate(img, width, kernel_j, coord);
print("absolute error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 0.0

Altogether:

In [248]:
ksp_j = nufft(img, coord)
ksp_py = py"nufft"(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 6.793258640404412e-15
relative error: 3.986266999957607e-14

In [249]:
output_j = nufft_adjoint(ksp_j, coord)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 6.859422151143303e-16
relative error: 5.611038030124373e-13

### Compare running time and output in multiple cases

In [211]:
py"""
import timeit
from math import log10
from statistics import median, mean

def benchmark(cmd_str, setup_str=''):
    t = timeit.Timer(cmd_str, setup=setup_str, globals=globals())
    approx = t.timeit(number=1)
    number = 1
    if approx > 60:
        measurements = [approx]
    elif approx > 30:
        measurements = [approx] + t.repeat(repeat=3, number=1)
    else:
        how_many = 30 / approx
        number = int(max(how_many // 10**(max(3,log10(how_many)-3)), 1))
        repeat = int(ceil(how_many / number))
        measurements = list(map(lambda x: x / number, t.repeat(repeat=repeat, number=number)))
    
    def time_format(sec):
        return f"{sec:.3f} s" if sec > 1 else f"{sec*1000:.3f} ms"

    return f'''
Python benchmark:
  --------------
  minimum time:     {time_format(min(measurements))}
  median time:      {time_format(median(measurements))}
  mean time:        {time_format(mean(measurements))}
  maximum time:     {time_format(max(measurements))}
  --------------
  samples:          {len(measurements)}
  evals/sample:     {number}
    '''
"""

#### Small sized 2D problem

In [250]:
M, shape = 1024, (16, 16)
img = rand(Float64, shape)
coord = rand(Float64, M, 2) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [251]:
ksp_j = nufft(img, coord)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.612092524874097e-15
relative error: 2.51831417063076e-14

In [252]:
output_j = nufft_adjoint(ksp_j, coord)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 4.9164205652118054e-15
relative error: 2.6977978159394108e-14

In [120]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     1.019 ms
  median time:      1.022 ms
  mean time:        1.026 ms
  maximum time:     1.833 ms
  --------------
  samples:          1001
  evals/sample:     19
    

In [113]:
@benchmark nufft(img, coord)

BenchmarkTools.Trial: 
  memory estimate:  94.89 KiB
  allocs estimate:  419
  --------------
  minimum time:     491.981 μs (0.00% GC)
  median time:      512.722 μs (0.00% GC)
  mean time:        537.037 μs (2.06% GC)
  maximum time:     7.040 ms (76.90% GC)
  --------------
  samples:          9173
  evals/sample:     1

In [111]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     1.154 ms
  median time:      1.168 ms
  mean time:        1.170 ms
  maximum time:     1.508 ms
  --------------
  samples:          1080
  evals/sample:     5
    

In [253]:
ksp = nufft(img, coord)
@benchmark nufft_adjoint(ksp, coord)

BenchmarkTools.Trial: 
  memory estimate:  69.36 KiB
  allocs estimate:  1275
  --------------
  minimum time:     838.099 μs (0.00% GC)
  median time:      875.301 μs (0.00% GC)
  mean time:        891.544 μs (0.86% GC)
  maximum time:     13.805 ms (65.58% GC)
  --------------
  samples:          5570
  evals/sample:     1

#### Moderate sized 3D problem with batch

In [254]:
M, batch, shape = 16384, 12, (128, 128, 128)
img = rand(Float64, (batch, shape...))
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [255]:
ksp_j = nufft(img, coord)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 2.7056661287843684e-14
relative error: 1.3199294899990275e-12

In [256]:
output_j = nufft_adjoint(ksp_j, coord)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 1.1995471724637809e-15
relative error: 7.427065570392663e-12

In [124]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     4.976 s
  median time:      5.911 s
  mean time:        5.692 s
  maximum time:     5.979 s
  --------------
  samples:          6
  evals/sample:     1
    

In [134]:
FFTW.set_num_threads(40)

In [130]:
@benchmark nufft(img, coord)

BenchmarkTools.Trial: 
  memory estimate:  2.59 GiB
  allocs estimate:  476
  --------------
  minimum time:     1.758 s (0.38% GC)
  median time:      2.013 s (11.74% GC)
  mean time:        1.999 s (8.49% GC)
  maximum time:     2.226 s (11.96% GC)
  --------------
  samples:          3
  evals/sample:     1

In [126]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     7.875 s
  median time:      7.880 s
  mean time:        7.881 s
  maximum time:     7.888 s
  --------------
  samples:          4
  evals/sample:     1
    

In [266]:
ksp = nufft(img, coord)
@benchmark nufft_adjoint(ksp, coord)

BenchmarkTools.Trial: 
  memory estimate:  2.89 GiB
  allocs estimate:  16762
  --------------
  minimum time:     3.263 s (3.36% GC)
  median time:      3.432 s (8.57% GC)
  mean time:        3.432 s (8.57% GC)
  maximum time:     3.602 s (13.29% GC)
  --------------
  samples:          2
  evals/sample:     1

#### Large 3D problem with batch

In [268]:
M, batch, shape = 4186100, 12, (34, 30, 68)
img = rand(Float64, (batch, shape...))
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [270]:
ksp_j = nufft(img, coord)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 1.1370851969750854e-13
relative error: 2.1542510252213323e-11

In [271]:
output_j = nufft_adjoint(ksp_j, coord)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 2.5760726640001485e-12
relative error: 3.5198830919164485e-12

In [135]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     18.943 s
  median time:      19.020 s
  mean time:        19.020 s
  maximum time:     19.096 s
  --------------
  samples:          2
  evals/sample:     1
    

In [136]:
@benchmark nufft(img, coord)

BenchmarkTools.Trial: 
  memory estimate:  951.31 MiB
  allocs estimate:  472
  --------------
  minimum time:     3.060 s (0.00% GC)
  median time:      3.096 s (0.04% GC)
  mean time:        3.096 s (0.04% GC)
  maximum time:     3.132 s (0.08% GC)
  --------------
  samples:          2
  evals/sample:     1

In [132]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     15.532 s
  median time:      15.559 s
  mean time:        15.559 s
  maximum time:     15.586 s
  --------------
  samples:          2
  evals/sample:     1
    

In [272]:
ksp = nufft(img, coord)
@benchmark nufft_adjoint(ksp, coord)

BenchmarkTools.Trial: 
  memory estimate:  1.10 GiB
  allocs estimate:  764
  --------------
  minimum time:     4.598 s (22.56% GC)
  median time:      4.718 s (23.15% GC)
  mean time:        4.718 s (23.15% GC)
  maximum time:     4.839 s (23.71% GC)
  --------------
  samples:          2
  evals/sample:     1

## Unused code

In [122]:
nufft(img, coord);

  0.000013 seconds (7 allocations: 464 bytes)
  0.002394 seconds (2 allocations: 15.259 MiB)
  0.003585 seconds (24 allocations: 3.563 KiB)
  0.004763 seconds (5 allocations: 15.259 MiB)
  0.000002 seconds (2 allocations: 112 bytes)
  0.065851 seconds (90 allocations: 59.609 MiB)
  0.000118 seconds (1 allocation: 1.141 KiB)
  0.002683 seconds (52 allocations: 18.313 KiB)


In [15]:
function _spline_kernel(x::T, order::T)::T where {T<:Real}
    abs(x) > 1 && return zero(x)

    if order == 0
        return one(x)
    elseif order == 1
        return 1 - abs(x)
    elseif order == 2
        if abs(x) > 1 / 3
            return 9 / 8 * (1 - abs(x))^2
        else
            return 3 / 4 * (1 - 3 * x^2)
        end
    else
        @assert "Only {0,1,2}-order spline kernel is supported"
    end
end

_spline_kernel (generic function with 1 method)

In [16]:
function _kaiser_bessel_kernel(x::T, β::T)::T where {T<:Real}
    abs(x) > 1 && return zero(x)

    x = β * √(1 - x^2)
    t = x / 3.75
    if x < 3.75
        return 1 + 3.5156229 * t^2 + 3.0899424 * t^4 +
            1.2067492 * t^6 + 0.2659732 * t^8 +
            0.0360768 * t^10 + 0.0045813 * t^12
    else
        return x^-0.5 * exp(x) * (
            0.39894228 + 0.01328592 * t^-1 +
            0.00225319 * t^-2 - 0.00157565 * t^-3 +
            0.00916281 * t^-4 - 0.02057706 * t^-5 +
            0.02635537 * t^-6 - 0.01647633 * t^-7 +
            0.00392377 * t^-8)
    end
end

_kaiser_bessel_kernel (generic function with 1 method)

In [85]:
@time nufft(img, coord);

  7.342226 seconds (190 allocations: 358.736 MiB, 0.06% gc time)


In [26]:
?nufft

search: [0m[1mn[22m[0m[1mu[22m[0m[1mf[22m[0m[1mf[22m[0m[1mt[22m



Non-uniform Fast Fourier Transform.

**Arguments**:

  * `input (ArrayType{T} or ArrayType{Complex{T}})`: input signal domain array of shape   $(n_k, \ldots, n_{ndim + 1}, n_{ndim}, \ldots, n_2, n_1)$,   where $ndim$ is specified by `size(coord)[end]`. The nufft   is applied on the last $ndim$ axes, and looped over   the remaining axes. `ArrayType` can be any `AbstractArray`.
  * `coord (ArrayType{T})`: Fourier domain coordinate array of shape $(m_l, \ldots, m_1, ndim)$.   $ndim$ determines the number of dimensions to apply the nufft.   `coord[..., i]` should be scaled to have its range $[-n_i \div 2, n_i \div 2]$.
  * `oversamp (Float32)`: oversampling factor.
  * `width (Int64)`: interpolation kernel full-width in terms of   oversampled grid.

**Returns**:

  * `ArrayType{Complex{T}}`: Fourier domain data of shape   $(n_k, \ldots, n_{ndim + 1}, m_l, \ldots, m_1)$.

**References**:

  * Fessler, J. A., & Sutton, B. P. (2003). "Nonuniform fast Fourier transforms using min-max interpolation", *IEEE Transactions on Signal Processing*, 51(2), 560-574.
  * Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005). "Rapid gridding reconstruction with a minimal oversampling ratio," *IEEE transactions on medical imaging*, 24(6), 799-808.
