In [1]:
using Documenter, EllipsisNotation, FFTW, PaddedViews, SpecialFunctions, Base.Cartesian, Base.Threads

┌ Info: Precompiling Documenter [e30172f5-a6a5-5a46-863b-614d45cd2de4]
└ @ Base loading.jl:1273


In [49]:
include("SigPy.jl")



Main.SigPy

In [3]:
using LinearAlgebra, MacroTools, NFFT, BenchmarkTools
using PyCall, Libdl
Libdl.dlopen(ENV["HOME"]*"/.local/lib/python3.6/site-packages/llvmlite/binding/libllvmlite.so",
    Libdl.RTLD_DEEPBIND);
py"""
from math import ceil
import numpy as np
import sigpy as sp
from sigpy import util, interp
"""

## Compare with reference implementation

### Comparing output

In [4]:
sp = pyimport("sigpy")
interp = pyimport("sigpy.interp")
util = pyimport("sigpy.util");

In [5]:
py"""
def nufft(input, coord, oversamp=1.25, width=4.0, n=128):
    # Reference non-uniform FFT used to cross-check the Julia port.
    #
    # Pipeline: apodize -> normalize and resize to the oversampled grid ->
    # FFT over the last ndim axes -> Kaiser-Bessel interpolation at coord.
    #
    # input    : array whose last ndim axes are transformed
    # coord    : array whose last axis has length ndim (one coordinate per point)
    # oversamp : grid oversampling ratio
    # width    : interpolation kernel width
    # n        : number of samples in the tabulated kernel
    #
    # NOTE: this code lives inside a Julia py-string, so only hash comments
    # are used here (no triple-double-quote docstrings).
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter derived from width and oversampling ratio.
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    # Work on a copy: the steps below modify `output` in place.
    output = input.copy()

    # Apodize
    # (_apodize multiplies `output` in place; its return value is not needed.)
    _apodize(output, ndim, oversamp, width, beta)

    # Zero-pad
    # Scale by 1/sqrt(number of samples in the transformed axes) first,
    # then resize to the oversampled grid shape.
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # FFT
    output = sp.fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate
    # NOTE: this calls sigpy's own interp.interpolate, not the pure-Python
    # `interpolate` defined further down in this same snippet.
    coord = _scale_coord(coord, input.shape, oversamp)
    kernel = _get_kaiser_bessel_kernel(n, width, beta)
    output = interp.interpolate(output, width, kernel, coord)

    return output

def _get_kaiser_bessel_kernel(n, width, beta):
    x = np.arange(n) / n
    kernel = 1 / width * np.i0(beta * (1 - x**2)**0.5)
    return kernel

def _scale_coord(coord, shape, oversamp):
    ndim = coord.shape[-1]
    scale = [ceil(oversamp * i) / i for i in shape[-ndim:]]
    shift = [ceil(oversamp * i) // 2 for i in shape[-ndim:]]

    coord = scale * coord + shift

    return coord

def _get_oversamp_shape(shape, ndim, oversamp):
    return list(shape)[:-ndim] + [ceil(oversamp * i) for i in shape[-ndim:]]

def estimate_shape(coord):
    # For every coordinate axis, return the integer-truncated extent (max - min).
    extents = []
    for axis in range(coord.shape[-1]):
        values = coord[..., axis]
        extents.append(int(values.max() - values.min()))
    return extents

def _apodize(input, ndim, oversamp, width, beta):

    output = input
    for a in range(-ndim, 0):
        i = output.shape[a]
        os_i = ceil(oversamp * i)
        idx = np.arange(i)

        # Calculate apodization
        apod = (beta**2 - (np.pi * width * (idx - i // 2) / os_i)**2)**0.5
        apod /= np.sinh(apod)
        output *= apod.reshape([i] + [1] * (-a - 1))

    return output

def interpolate(input, width, kernel, coord):
    # Pure-Python reference interpolation of a gridded array at the
    # non-uniform points in coord.
    #
    # input  : array whose last ndim axes form the grid; leading axes are batch
    # width  : interpolation kernel width
    # kernel : tabulated kernel, looked up through lin_interpolate
    # coord  : array whose last axis has length ndim
    #
    # Returns an array of shape batch_shape + pts_shape.
    # Raises ValueError when ndim != 3, because only the 3-D kernel
    # (_interpolate3) is implemented here.
    ndim = coord.shape[-1]
    if ndim != 3:
        # Previously this surfaced as an obscure tuple-unpacking ValueError
        # inside _interpolate3; fail early with a clear message instead.
        raise ValueError(
            f"only 3-dimensional coordinates are supported, got ndim={ndim}")

    batch_shape = input.shape[:-ndim]
    batch_size = util.prod(batch_shape)

    pts_shape = coord.shape[:-1]
    npts = util.prod(pts_shape)

    # Flatten batch axes and point axes so the kernel loop sees
    # (batch, z, y, x) and (npts, ndim).
    input = input.reshape([batch_size] + list(input.shape[-ndim:]))
    coord = coord.reshape([npts, ndim])
    output = np.zeros([batch_size, npts], dtype=input.dtype)

    _interpolate3(output, input, width, kernel, coord)

    return output.reshape(batch_shape + pts_shape)

def _interpolate3(output, input, width, kernel, coord):
    # 3-D interpolation kernel: for every point, accumulate the weighted sum
    # of the grid samples that fall inside a window of `width` around it.
    #
    # output : (batch_size, npts) array, accumulated into IN PLACE
    # input  : (batch_size, nz, ny, nx) gridded data
    # coord  : (npts, 3) array; the last axis is read as (z, y, x)
    batch_size, nz, ny, nx = input.shape
    npts = coord.shape[0]

    for i in range(npts):

        # The last coordinate component is x, the first is z.
        kx, ky, kz = coord[i, -1], coord[i, -2], coord[i, -3]

        # First integer grid index inside [k - width/2, k + width/2] per axis.
        x0, y0, z0 = (np.ceil(kx - width / 2).astype(int),
                      np.ceil(ky - width / 2).astype(int),
                      np.ceil(kz - width / 2).astype(int))

        # Last integer grid index inside the same window.
        x1, y1, z1 = (np.floor(kx + width / 2).astype(int),
                      np.floor(ky + width / 2).astype(int),
                      np.floor(kz + width / 2).astype(int))

        # Separable weights: the distance to the point is normalized by the
        # half-width before the kernel table lookup.
        for z in range(z0, z1 + 1):
            wz = lin_interpolate(kernel, abs(z - kz) / (width / 2))

            for y in range(y0, y1 + 1):
                wy = wz * lin_interpolate(kernel, abs(y - ky) / (width / 2))

                for x in range(x0, x1 + 1):
                    w = wy * lin_interpolate(kernel, abs(x - kx) / (width / 2))

                    for b in range(batch_size):
                        # Modulo indexing wraps out-of-range indices around
                        # the grid (periodic boundary).
                        output[b, i] += w * input[b, z % nz, y % ny, x % nx]

    return output

def lin_interpolate(kernel, x):
    # Linearly interpolate the tabulated kernel at normalized position x.
    # Table entry j corresponds to position j / len(kernel); positions at or
    # beyond 1 evaluate to 0, and the table is treated as dropping to 0 after
    # its last entry.
    if x >= 1:
        return 0.0

    n = len(kernel)
    pos = x * n
    idx = int(pos)
    frac = pos - idx

    right = kernel[idx + 1] if idx != n - 1 else 0.0
    return (1.0 - frac) * kernel[idx] + frac * right
"""

In [295]:
M, shape = 1000, (34, 30, 68)
img = rand(Float64, shape)
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2;

In [296]:
width, n, oversamp = 4, 128, 1.25
ndim = ndims(img)
β = π * √(((width / oversamp) * (oversamp - 0.5))^2 - 0.8)

6.996659047674343

Apodization implementations are identical:

In [34]:
signal = copy(img)
output_j = SigPy._apodize!(signal, ndim, oversamp, width, β)
output_py = py"_apodize"(signal, ndim, oversamp, width, β)
print("absolute error: ", norm(output_j - output_py, Inf), "\n")

absolute error: 0.0


However, the outputs of FFTW and numpy's FFT differ measurably (on the order of 1e-12):

In [35]:
output_j = SigPy.centering_fft!(convert.(Complex, img), (1,2,3))
output_py = py"sp.fft($img, axes=(1,2,3), norm=None)"
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 1.2448763988069971e-12
relative error: 1.6832693194493325e-12

There is also a difference (though orders of magnitude smaller) between Julia's and numpy's Bessel functions:

In [36]:
kernel_py = py"_get_kaiser_bessel_kernel"(n, width, β)
x = range(0, stop=n-1, step=1) ./ n
kernel_j = SigPy.window_kaiser_bessel.(x, width, β)
print("absolute error: ", norm(kernel_j - kernel_py, Inf), "\n",
    "relative error: ", norm((kernel_j - kernel_py) ./ kernel_py, Inf))

absolute error: 1.4210854715202004e-14
relative error: 4.753447469672048e-16

The interpolation functions, on the other hand, appear to produce identical results:

In [37]:
# Interpolate the same image with both implementations, feeding both the
# kernel table computed on the Julia side so only the interpolation differs.
output_j = SigPy.interpolate(img, width, kernel_j, coord);
output_py = interp.interpolate(img, width, kernel_j, coord);
# The label reads "absolute error", so report the absolute max-norm difference
# (the previous expression divided by output_py, i.e. it was a relative error).
print("absolute error: ", norm(output_j - output_py, Inf))

absolute error: 0.0

Altogether:

In [297]:
ksp_j = SigPy.nufft(coord, img)
ksp_py = py"nufft"(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.7258694339903e-15
relative error: 1.0397644736622794e-13

In [298]:
plan = SigPy.nufft_plan(coord, convert.(ComplexF64, img))
ksp_j = plan * convert.(ComplexF64, img)
ksp_py = py"nufft"(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.7258694339903e-15
relative error: 1.0397644736622794e-13

In [299]:
plan = SigPy.nufft_plan(coord, size(img))
ksp_j = plan * convert.(ComplexF64, img)
ksp_py = py"nufft"(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.7258694339903e-15
relative error: 1.0397644736622794e-13

In [300]:
output_py = sp.nufft_adjoint(ksp_j, coord)
plan = SigPy.nufft_plan(coord, output_py)
output_j = plan' * ksp_j
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.622217108246132e-16
relative error: 1.914576449649916e-13

In [301]:
output_py = sp.nufft_adjoint(ksp_j, coord)
plan = SigPy.nufft_plan(coord)
output_j = plan' * ksp_j
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.622217108246132e-16
relative error: 1.914576449649916e-13

In [302]:
output_j = SigPy.nufft_adjoint(coord, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.622217108246132e-16
relative error: 1.914576449649916e-13

### Compare running time and output in multiple cases

In [237]:
py"""
import timeit
from math import log10
from statistics import median, mean

def benchmark(cmd_str, setup_str=''):
    # Time the statement cmd_str with timeit and return a text report laid
    # out like a BenchmarkTools trial (min / median / mean / max, samples,
    # evals per sample).
    #
    # cmd_str   : statement to time; evaluated against this module's globals
    # setup_str : optional setup statement passed to timeit.Timer
    #
    # NOTE(review): `ceil` is not imported next to this function; it is
    # presumably still in scope from an earlier `from math import ceil`
    # executed in the same Python namespace -- confirm before reusing this
    # function elsewhere.
    t = timeit.Timer(cmd_str, setup=setup_str, globals=globals())
    # One trial run decides how much sampling is affordable.
    approx = t.timeit(number=1)
    number = 1
    if approx > 60:
        # Very slow statement: keep only the trial run.
        measurements = [approx]
    elif approx > 30:
        # Slow statement: the trial run plus three more single runs.
        measurements = [approx] + t.repeat(repeat=3, number=1)
    else:
        # Fast statement: budget roughly 30 s of total measuring time.
        # how_many is the number of runs that fit in that budget; it is split
        # into `repeat` samples of `number` evaluations each, where `number`
        # is how_many floor-divided by at least 1000 (and never below 1).
        # The trial run itself is not included in the measurements here.
        how_many = 30 / approx
        number = int(max(how_many // 10**(max(3,log10(how_many)-3)), 1))
        repeat = int(ceil(how_many / number))
        # t.repeat returns the total time per sample; convert to per-evaluation.
        measurements = list(map(lambda x: x / number, t.repeat(repeat=repeat, number=number)))
    
    # Render seconds above 1 s, milliseconds otherwise.
    def time_format(sec):
        return f"{sec:.3f} s" if sec > 1 else f"{sec*1000:.3f} ms"

    return f'''
Python benchmark:
  --------------
  minimum time:     {time_format(min(measurements))}
  median time:      {time_format(median(measurements))}
  mean time:        {time_format(mean(measurements))}
  maximum time:     {time_format(max(measurements))}
  --------------
  samples:          {len(measurements)}
  evals/sample:     {number}
    '''
"""

#### Small sized 2D problem

In [398]:
M, shape = 1024, (16, 16)
img = rand(Float64, shape)
coord = rand(Float64, M, 2) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [399]:
ksp_j = SigPy.nufft(coord, img)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 5.3579045697976356e-15
relative error: 3.403288354090141e-14

In [400]:
output_j = SigPy.nufft_adjoint(coord, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 6.530527749616485e-15
relative error: 6.858262785592868e-15

In [243]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     1.022 ms
  median time:      1.039 ms
  mean time:        1.054 ms
  maximum time:     1.841 ms
  --------------
  samples:          1104
  evals/sample:     9
    

In [255]:
@benchmark SigPy.nufft(coord, img)

BenchmarkTools.Trial: 
  memory estimate:  88.33 KiB
  allocs estimate:  482
  --------------
  minimum time:     551.340 μs (0.00% GC)
  median time:      571.018 μs (0.00% GC)
  mean time:        598.590 μs (2.72% GC)
  maximum time:     18.441 ms (86.62% GC)
  --------------
  samples:          8231
  evals/sample:     1

In [271]:
complexImg = convert.(ComplexF64, img)
plan = SigPy.nufft_plan(coord, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  32.77 KiB
  allocs estimate:  356
  --------------
  minimum time:     247.829 μs (0.00% GC)
  median time:      268.574 μs (0.00% GC)
  mean time:        292.375 μs (4.16% GC)
  maximum time:     37.911 ms (98.40% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [272]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     1.070 ms
  median time:      1.079 ms
  mean time:        1.080 ms
  maximum time:     1.335 ms
  --------------
  samples:          1013
  evals/sample:     21
    

In [276]:
ksp = SigPy.nufft(coord, img)
@benchmark SigPy.nufft_adjoint(coord, ksp)

BenchmarkTools.Trial: 
  memory estimate:  55.22 KiB
  allocs estimate:  1272
  --------------
  minimum time:     510.868 μs (0.00% GC)
  median time:      524.598 μs (0.00% GC)
  mean time:        543.652 μs (1.37% GC)
  maximum time:     11.699 ms (82.88% GC)
  --------------
  samples:          9100
  evals/sample:     1

In [277]:
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  19.36 KiB
  allocs estimate:  1109
  --------------
  minimum time:     328.356 μs (0.00% GC)
  median time:      336.836 μs (0.00% GC)
  mean time:        345.142 μs (0.89% GC)
  maximum time:     10.659 ms (95.98% GC)
  --------------
  samples:          10000
  evals/sample:     1

#### Moderate sized 3D problem with batch

In [401]:
M, batch, shape = 16384, 12, (128, 128, 128)
img = rand(Float64, (batch, shape...))
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [402]:
FFTW.set_num_threads(40)

In [403]:
ksp_j = SigPy.nufft(coord, img)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 2.5011331388797805e-14
relative error: 3.305386837857348e-13

In [416]:
output_j = SigPy.nufft_adjoint(coord, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 1.1811088073830869e-15
relative error: 5.8175565400131025e-12

In [365]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     7.145 s
  median time:      7.147 s
  mean time:        7.147 s
  maximum time:     7.150 s
  --------------
  samples:          5
  evals/sample:     1
    

In [373]:
@benchmark SigPy.nufft(coord, img)

BenchmarkTools.Trial: 
  memory estimate:  1.84 GiB
  allocs estimate:  589
  --------------
  minimum time:     2.197 s (2.96% GC)
  median time:      2.315 s (11.30% GC)
  mean time:        2.280 s (9.29% GC)
  maximum time:     2.328 s (11.24% GC)
  --------------
  samples:          3
  evals/sample:     1

In [367]:
complexImg = convert.(ComplexF64, img)
plan = SigPy.nufft_plan(coord, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  38.13 KiB
  allocs estimate:  414
  --------------
  minimum time:     1.536 s (0.00% GC)
  median time:      1.537 s (0.00% GC)
  mean time:        1.553 s (0.40% GC)
  maximum time:     1.601 s (1.56% GC)
  --------------
  samples:          4
  evals/sample:     1

In [368]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     7.295 s
  median time:      7.301 s
  mean time:        7.304 s
  maximum time:     7.320 s
  --------------
  samples:          5
  evals/sample:     1
    

In [374]:
ksp = SigPy.nufft(coord, img)
@benchmark SigPy.nufft_adjoint(coord, ksp)

BenchmarkTools.Trial: 
  memory estimate:  1.44 GiB
  allocs estimate:  16764
  --------------
  minimum time:     2.091 s (0.19% GC)
  median time:      2.214 s (4.60% GC)
  mean time:        2.229 s (5.15% GC)
  maximum time:     2.381 s (10.01% GC)
  --------------
  samples:          3
  evals/sample:     1

In [375]:
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  264.81 KiB
  allocs estimate:  16533
  --------------
  minimum time:     1.356 s (0.00% GC)
  median time:      1.419 s (0.00% GC)
  mean time:        1.418 s (0.00% GC)
  maximum time:     1.479 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1

#### Large 3D problem with 2D batch

In [433]:
include("SigPy.jl")



Main.SigPy

In [434]:
M, batch, shape = 4186100, (3,4), (34, 30, 68)
img = rand(Float64, (batch..., shape...))
coord = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
py"""
img = $img
coord = $coord
"""

In [435]:
ksp_j = SigPy.nufft(coord, img)
ksp_py = sp.nufft(img, coord)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 1.7053155762006797e-13
relative error: 1.9703935309171424e-11

In [436]:
output_j = SigPy.nufft_adjoint(coord, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 4.023484415233131e-12
relative error: 5.510712981421691e-12

In [421]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     18.049 s
  median time:      18.125 s
  mean time:        18.125 s
  maximum time:     18.201 s
  --------------
  samples:          2
  evals/sample:     1
    

In [427]:
@benchmark SigPy.nufft(coord, img)

BenchmarkTools.Trial: 
  memory estimate:  925.92 MiB
  allocs estimate:  718
  --------------
  minimum time:     1.539 s (0.79% GC)
  median time:      1.732 s (14.15% GC)
  mean time:        1.682 s (11.14% GC)
  maximum time:     1.773 s (17.18% GC)
  --------------
  samples:          3
  evals/sample:     1

In [442]:
complexImg = convert.(ComplexF64, img)
plan = SigPy.nufft_plan(coord, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  42.66 KiB
  allocs estimate:  559
  --------------
  minimum time:     1.037 s (0.00% GC)
  median time:      1.038 s (0.00% GC)
  mean time:        1.039 s (0.12% GC)
  maximum time:     1.040 s (0.58% GC)
  --------------
  samples:          5
  evals/sample:     1

In [424]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     19.828 s
  median time:      19.856 s
  mean time:        19.856 s
  maximum time:     19.885 s
  --------------
  samples:          2
  evals/sample:     1
    

In [438]:
ksp = SigPy.nufft(coord, img)
@benchmark SigPy.nufft_adjoint(coord, ksp)

BenchmarkTools.Trial: 
  memory estimate:  1.05 GiB
  allocs estimate:  907
  --------------
  minimum time:     3.174 s (14.79% GC)
  median time:      3.950 s (17.86% GC)
  mean time:        3.950 s (17.86% GC)
  maximum time:     4.725 s (19.93% GC)
  --------------
  samples:          2
  evals/sample:     1

In [443]:
complexImg = convert.(ComplexF64, img)
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  991.88 MiB
  allocs estimate:  696
  --------------
  minimum time:     3.517 s (13.10% GC)
  median time:      3.980 s (14.52% GC)
  mean time:        3.980 s (14.52% GC)
  maximum time:     4.443 s (15.65% GC)
  --------------
  samples:          2
  evals/sample:     1

## Unused code

In [15]:
"""
    _spline_kernel(x, order)

Evaluate the spline interpolation kernel of the given `order` at `x`.

The kernel is zero for `abs(x) > 1`. Inside that support it is

- `order == 0`: the constant `1`,
- `order == 1`: the linear ramp `1 - abs(x)`,
- `order == 2`: a piecewise quadratic that switches pieces at `abs(x) == 1/3`.

The result is converted to the type of `x`.

# Throws
- `ArgumentError` if `order` is not 0, 1 or 2 (only reached inside the support).
"""
function _spline_kernel(x::T, order::Real)::T where {T<:Real}
    abs(x) > 1 && return zero(x)

    if order == 0
        return one(x)
    elseif order == 1
        return 1 - abs(x)
    elseif order == 2
        if abs(x) > 1 / 3
            return 9 / 8 * (1 - abs(x))^2
        else
            return 3 / 4 * (1 - 3 * x^2)
        end
    else
        # `@assert "message"` treats the string as the condition and raises a
        # TypeError; throw a proper exception for the unsupported order instead.
        throw(ArgumentError("Only {0,1,2}-order spline kernel is supported"))
    end
end

_spline_kernel (generic function with 1 method)

In [16]:
"""
    _kaiser_bessel_kernel(x, β)

Evaluate the Kaiser-Bessel kernel with shape parameter `β` at `x`.

Returns zero for `abs(x) > 1`. Otherwise the argument `β * √(1 - x^2)` is fed
to one of two polynomial approximations, chosen by whether that argument is
below `3.75`. The result is converted to the type of `x`.
"""
function _kaiser_bessel_kernel(x::T, β::T)::T where {T<:Real}
    if abs(x) > 1
        return zero(x)
    end

    arg = β * √(1 - x^2)
    t = arg / 3.75

    # Small-argument branch: even polynomial in t.
    if arg < 3.75
        return 1 + 3.5156229 * t^2 + 3.0899424 * t^4 +
            1.2067492 * t^6 + 0.2659732 * t^8 +
            0.0360768 * t^10 + 0.0045813 * t^12
    end

    # Large-argument branch: exp(arg) / sqrt(arg) times a polynomial in 1/t.
    tail = 0.39894228 + 0.01328592 * t^-1 +
        0.00225319 * t^-2 - 0.00157565 * t^-3 +
        0.00916281 * t^-4 - 0.02057706 * t^-5 +
        0.02635537 * t^-6 - 0.01647633 * t^-7 +
        0.00392377 * t^-8
    return arg^-0.5 * exp(arg) * tail
end

_kaiser_bessel_kernel (generic function with 1 method)