In [1]:
using FFTW

In [None]:
include("SigJl.jl")

In [None]:
using LinearAlgebra, MacroTools, NFFT, BenchmarkTools
using PyCall, Libdl
Libdl.dlopen(ENV["HOME"]*"/.local/lib/python3.6/site-packages/llvmlite/binding/libllvmlite.so",
    Libdl.RTLD_DEEPBIND);
py"""
from math import ceil
import numpy as np
import SigJl as sp
from SigJl import util, interp
"""

## Compare with reference implementation

### Comparing output

In [None]:
sp = pyimport("SigJl")
interp = pyimport("SigJl.interp")
util = pyimport("SigJl.util");

In [None]:
py"""
def nufft(input, coord, oversamp=1.25, width=4.0, n=128):
    ndim = coord.shape[-1]
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    output = input.copy()

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    # Zero-pad
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # FFT
    output = sp.fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate
    coord = _scale_coord(coord, input.shape, oversamp)
    kernel = _get_kaiser_bessel_kernel(n, width, beta)
    output = interp.interpolate(output, width, kernel, coord)

    return output

def _get_kaiser_bessel_kernel(n, width, beta):
    x = np.arange(n) / n
    kernel = 1 / width * np.i0(beta * (1 - x**2)**0.5)
    return kernel

def _scale_coord(coord, shape, oversamp):
    ndim = coord.shape[-1]
    scale = [ceil(oversamp * i) / i for i in shape[-ndim:]]
    shift = [ceil(oversamp * i) // 2 for i in shape[-ndim:]]

    coord = scale * coord + shift

    return coord

def _get_oversamp_shape(shape, ndim, oversamp):
    return list(shape)[:-ndim] + [ceil(oversamp * i) for i in shape[-ndim:]]

def estimate_shape(coord):
    ndim = coord.shape[-1]
    return [int(coord[..., i].max() - coord[..., i].min()) for i in range(ndim)]

def _apodize(input, ndim, oversamp, width, beta):

    output = input
    for a in range(-ndim, 0):
        i = output.shape[a]
        os_i = ceil(oversamp * i)
        idx = np.arange(i)

        # Calculate apodization
        apod = (beta**2 - (np.pi * width * (idx - i // 2) / os_i)**2)**0.5
        apod /= np.sinh(apod)
        output *= apod.reshape([i] + [1] * (-a - 1))

    return output

def interpolate(input, width, kernel, coord):
    ndim = coord.shape[-1]

    batch_shape = input.shape[:-ndim]
    batch_size = util.prod(batch_shape)

    pts_shape = coord.shape[:-1]
    npts = util.prod(pts_shape)

    isreal = np.issubdtype(input.dtype, np.floating)

    input = input.reshape([batch_size] + list(input.shape[-ndim:]))
    coord = coord.reshape([npts, ndim])
    output = np.zeros([batch_size, npts], dtype=input.dtype)

    _interpolate3(output, input, width, kernel, coord)

    return output.reshape(batch_shape + pts_shape)

def _interpolate3(output, input, width, kernel, coord):
    batch_size, nz, ny, nx = input.shape
    npts = coord.shape[0]

    for i in range(npts):

        kx, ky, kz = coord[i, -1], coord[i, -2], coord[i, -3]

        x0, y0, z0 = (np.ceil(kx - width / 2).astype(int),
                      np.ceil(ky - width / 2).astype(int),
                      np.ceil(kz - width / 2).astype(int))

        x1, y1, z1 = (np.floor(kx + width / 2).astype(int),
                      np.floor(ky + width / 2).astype(int),
                      np.floor(kz + width / 2).astype(int))

        for z in range(z0, z1 + 1):
            wz = lin_interpolate(kernel, abs(z - kz) / (width / 2))

            for y in range(y0, y1 + 1):
                wy = wz * lin_interpolate(kernel, abs(y - ky) / (width / 2))

                for x in range(x0, x1 + 1):
                    w = wy * lin_interpolate(kernel, abs(x - kx) / (width / 2))

                    for b in range(batch_size):
                        output[b, i] += w * input[b, z % nz, y % ny, x % nx]

    return output

def lin_interpolate(kernel, x):
    if x >= 1:
        return 0.0
    n = len(kernel)
    idx = int(x * n)
    frac = x * n - idx

    left = kernel[idx]
    if idx == n - 1:
        right = 0.0
    else:
        right = kernel[idx + 1]
    return (1.0 - frac) * left + frac * right
"""

In [11]:
M, shape = 1000, (34, 30, 68)
img = rand(Float64, shape)
coord_py = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
coord_j = permutedims(coord_py, (2,1));

In [12]:
width, n, oversamp = 4, 128, 1.25
ndim = ndims(img)
β = π * √(((width / oversamp) * (oversamp - 0.5))^2 - 0.8)

6.996659047674343

Apodization implementations are identical:

In [14]:
signal = copy(img)
output_j = SigJl._apodize!(signal, ndim, oversamp, width, β)
output_py = py"_apodize"(signal, ndim, oversamp, width, β)
print("absolute error: ", norm(output_j - output_py, Inf), "\n")

absolute error: 0.0


But there is a significant difference between FFTW and numpy's FFT:

In [15]:
output_j = SigJl.centering_fft!(convert.(Complex, img), (1,2,3))
output_py = py"sp.fft($img, axes=(1,2,3), norm=None)"
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 1.2069779387208717e-12
relative error: 1.5021574004802928e-13

Also, there is some (but magnitudes smaller) difference between Julia's and numpy's Bessel function:

In [16]:
kernel_py = py"_get_kaiser_bessel_kernel"(n, width, β)
x = range(0, stop=n-1, step=1) ./ n
kernel_j = SigJl.window_kaiser_bessel.(x, width, β)
print("absolute error: ", norm(kernel_j - kernel_py, Inf), "\n",
    "relative error: ", norm((kernel_j - kernel_py) ./ kernel_py, Inf))

absolute error: 1.4210854715202004e-14
relative error: 4.753447469672048e-16

Altogether:

In [17]:
ksp_j = SigJl.nufft(coord_j, img)
ksp_py = py"nufft"(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 1.0223701381680905e-14
relative error: 1.0132851590096376e-13

In [18]:
plan = SigJl.nufft_plan(coord_j, convert.(ComplexF64, img))
ksp_j = plan * convert.(ComplexF64, img)
ksp_py = py"nufft"(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 1.0223701381680905e-14
relative error: 1.0132851590096376e-13

In [19]:
plan = SigJl.nufft_plan(coord_j, size(img))
ksp_j = plan * convert.(ComplexF64, img)
ksp_py = py"nufft"(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 1.0223701381680905e-14
relative error: 1.0132851590096376e-13

In [20]:
output_py = sp.nufft_adjoint(ksp_j, coord_py)
plan = SigJl.nufft_plan(coord_j, output_py)
output_j = plan' * ksp_j
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.67456129498071e-16
relative error: 9.534559559989506e-13

In [21]:
output_py = sp.nufft_adjoint(ksp_j, coord_py)
plan = SigJl.nufft_plan(coord_j)
output_j = plan' * ksp_j
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.67456129498071e-16
relative error: 9.534559559989506e-13

In [22]:
output_j = SigJl.nufft_adjoint(coord_j, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord_py)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 5.67456129498071e-16
relative error: 9.534559559989506e-13

### Compare running time and output in multiple cases

In [23]:
py"""
import timeit
from math import log10
from statistics import median, mean

def benchmark(cmd_str, setup_str=''):
    t = timeit.Timer(cmd_str, setup=setup_str, globals=globals())
    approx = t.timeit(number=1)
    number = 1
    if approx > 60:
        measurements = [approx]
    elif approx > 30:
        measurements = [approx] + t.repeat(repeat=3, number=1)
    else:
        how_many = 30 / approx
        number = int(max(how_many // 10**(max(3,log10(how_many)-3)), 1))
        repeat = int(ceil(how_many / number))
        measurements = list(map(lambda x: x / number, t.repeat(repeat=repeat, number=number)))
    
    def time_format(sec):
        return f"{sec:.3f} s" if sec > 1 else f"{sec*1000:.3f} ms"

    return f'''
Python benchmark:
  --------------
  minimum time:     {time_format(min(measurements))}
  median time:      {time_format(median(measurements))}
  mean time:        {time_format(mean(measurements))}
  maximum time:     {time_format(max(measurements))}
  --------------
  samples:          {len(measurements)}
  evals/sample:     {number}
    '''
"""

#### Small sized 2D problem

In [25]:
M, shape = 1024, (16, 16)
img = rand(Float64, shape)
coord_py = rand(Float64, M, 2) .* collect(shape)' .- collect(shape)' ./2
coord_j = permutedims(coord_py, (2,1))
py"""
img = $img
coord = $coord_py
"""

In [26]:
ksp_j = SigJl.nufft(coord_j, img)
ksp_py = sp.nufft(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 5.341271894614334e-15
relative error: 2.4400971837401695e-14

In [27]:
output_j = SigJl.nufft_adjoint(coord_j, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord_py)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 6.218007833105156e-15
relative error: 1.453958579923098e-14

In [28]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     1.041 ms
  median time:      1.046 ms
  mean time:        1.047 ms
  maximum time:     1.233 ms
  --------------
  samples:          1003
  evals/sample:     17
    

In [29]:
@benchmark SigJl.nufft(coord_j, img)

BenchmarkTools.Trial: 
  memory estimate:  88.39 KiB
  allocs estimate:  462
  --------------
  minimum time:     528.508 μs (0.00% GC)
  median time:      578.802 μs (0.00% GC)
  mean time:        612.044 μs (1.60% GC)
  maximum time:     8.596 ms (82.34% GC)
  --------------
  samples:          8046
  evals/sample:     1

In [30]:
complexImg = convert.(ComplexF64, img)
plan = SigJl.nufft_plan(coord_j, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  32.11 KiB
  allocs estimate:  337
  --------------
  minimum time:     239.127 μs (0.00% GC)
  median time:      261.550 μs (0.00% GC)
  mean time:        311.860 μs (1.59% GC)
  maximum time:     23.677 ms (29.05% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [31]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     1.084 ms
  median time:      1.092 ms
  mean time:        1.092 ms
  maximum time:     1.168 ms
  --------------
  samples:          1001
  evals/sample:     25
    

In [32]:
ksp = SigJl.nufft(coord_j, img)
@benchmark SigJl.nufft_adjoint(coord_j, ksp)

BenchmarkTools.Trial: 
  memory estimate:  47.67 KiB
  allocs estimate:  766
  --------------
  minimum time:     560.668 μs (0.00% GC)
  median time:      573.342 μs (0.00% GC)
  mean time:        580.079 μs (0.75% GC)
  maximum time:     5.674 ms (73.14% GC)
  --------------
  samples:          8531
  evals/sample:     1

In [33]:
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  10.75 KiB
  allocs estimate:  580
  --------------
  minimum time:     341.538 μs (0.00% GC)
  median time:      348.559 μs (0.00% GC)
  mean time:        349.487 μs (0.23% GC)
  maximum time:     4.487 ms (90.98% GC)
  --------------
  samples:          10000
  evals/sample:     1

#### Small sized 2D problem with shaped coord vector

In [34]:
M, shape = (4, 256), (16, 16)
D = length(shape)
img = rand(Float64, shape)
scale_and_shift = reshape(collect(shape), fill(1, length(M))..., D)
coord_py = rand(Float64, M..., D) .* scale_and_shift .- scale_and_shift ./2
coord_j = permutedims(coord_py, (3,1,2))
py"""
img = $img
coord = $coord_py
"""

In [35]:
ksp_j = SigJl.nufft(coord_j, img)
ksp_py = sp.nufft(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.1264164235997256e-15
relative error: 1.802739049650836e-14

In [36]:
output_j = SigJl.nufft_adjoint(coord_j, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord_py)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 6.6867899593804964e-15
relative error: 1.495324007899287e-14

#### Moderate sized 3D problem with batch

In [37]:
M, batch, shape = 16384, 12, (128, 128, 128)
img = rand(Float64, (batch, shape...))
coord_py = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
coord_j = permutedims(coord_py, (2,1))
py"""
img = $img
coord = $coord_py
"""

In [38]:
FFTW.set_num_threads(40)

In [39]:
ksp_j = SigJl.nufft(coord_j, img)
ksp_py = sp.nufft(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 3.130435208578516e-14
relative error: 4.1080348047135497e-13

In [40]:
output_j = SigJl.nufft_adjoint(coord_j, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord_py)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 1.004684985911257e-15
relative error: 1.147562021996739e-11

In [41]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     6.475 s
  median time:      6.512 s
  mean time:        6.539 s
  maximum time:     6.634 s
  --------------
  samples:          5
  evals/sample:     1
    

In [42]:
@benchmark SigJl.nufft(coord_j, img)

BenchmarkTools.Trial: 
  memory estimate:  1.84 GiB
  allocs estimate:  532
  --------------
  minimum time:     1.668 s (1.63% GC)
  median time:      1.670 s (1.63% GC)
  mean time:        1.733 s (4.72% GC)
  maximum time:     1.861 s (11.61% GC)
  --------------
  samples:          3
  evals/sample:     1

In [43]:
complexImg = convert.(ComplexF64, img)
plan = SigJl.nufft_plan(coord_j, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  37.50 KiB
  allocs estimate:  402
  --------------
  minimum time:     1.280 s (0.00% GC)
  median time:      1.300 s (0.00% GC)
  mean time:        1.306 s (0.00% GC)
  maximum time:     1.345 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1

In [44]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     7.189 s
  median time:      7.200 s
  mean time:        7.216 s
  maximum time:     7.252 s
  --------------
  samples:          5
  evals/sample:     1
    

In [45]:
ksp = SigJl.nufft(coord_j, img)
@benchmark SigJl.nufft_adjoint(coord_j, ksp)

BenchmarkTools.Trial: 
  memory estimate:  1.44 GiB
  allocs estimate:  16215
  --------------
  minimum time:     1.972 s (0.09% GC)
  median time:      2.528 s (1.88% GC)
  mean time:        2.350 s (2.62% GC)
  maximum time:     2.550 s (5.30% GC)
  --------------
  samples:          3
  evals/sample:     1

In [46]:
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  256.33 KiB
  allocs estimate:  16008
  --------------
  minimum time:     1.356 s (0.00% GC)
  median time:      1.387 s (0.00% GC)
  mean time:        1.393 s (0.00% GC)
  maximum time:     1.442 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1

#### Large 3D problem with 2D batch

In [47]:
M, batch, shape = 4186100, (3,4), (34, 30, 68)
img = rand(Float64, (batch..., shape...))
coord_py = rand(Float64, M, 3) .* collect(shape)' .- collect(shape)' ./2
coord_j = permutedims(coord_py, (2,1))
py"""
img = $img
coord = $coord_py
"""

In [48]:
ksp_j = SigJl.nufft(coord_j, img)
ksp_py = sp.nufft(img, coord_py)
print("absolute error: ", norm(ksp_j - ksp_py, Inf), "\n",
    "relative error: ", norm((ksp_j - ksp_py) ./ ksp_py, Inf))

absolute error: 9.963074863817178e-14
relative error: 8.17618951645943e-12

In [49]:
output_j = SigJl.nufft_adjoint(coord_j, ksp_j)
output_py = sp.nufft_adjoint(ksp_j, coord_py)
print("absolute error: ", norm(output_j - output_py, Inf), "\n",
    "relative error: ", norm((output_j - output_py) ./ output_py, Inf))

absolute error: 2.5898140618858957e-12
relative error: 6.982564500361215e-12

In [50]:
print(py"benchmark('sp.nufft(img, coord)')")


Python benchmark:
  --------------
  minimum time:     19.105 s
  median time:      19.711 s
  mean time:        19.711 s
  maximum time:     20.317 s
  --------------
  samples:          2
  evals/sample:     1
    

In [51]:
@benchmark SigJl.nufft(coord_j, img)

BenchmarkTools.Trial: 
  memory estimate:  925.92 MiB
  allocs estimate:  689
  --------------
  minimum time:     1.192 s (0.41% GC)
  median time:      1.522 s (9.10% GC)
  mean time:        1.457 s (7.63% GC)
  maximum time:     1.593 s (9.14% GC)
  --------------
  samples:          4
  evals/sample:     1

In [52]:
complexImg = convert.(ComplexF64, img)
plan = SigJl.nufft_plan(coord_j, complexImg)
output = plan * complexImg
@benchmark mul!(output, plan, complexImg)

BenchmarkTools.Trial: 
  memory estimate:  42.08 KiB
  allocs estimate:  544
  --------------
  minimum time:     1.025 s (0.00% GC)
  median time:      1.074 s (0.00% GC)
  mean time:        1.076 s (0.00% GC)
  maximum time:     1.106 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

In [53]:
py"""
ksp = sp.nufft(img, coord)
"""
print(py"benchmark('sp.nufft_adjoint(ksp, coord)')")


Python benchmark:
  --------------
  minimum time:     16.464 s
  median time:      17.101 s
  mean time:        17.101 s
  maximum time:     17.738 s
  --------------
  samples:          2
  evals/sample:     1
    

In [54]:
ksp = SigJl.nufft(coord_j, img)
@benchmark SigJl.nufft_adjoint(coord_j, ksp)

BenchmarkTools.Trial: 
  memory estimate:  1.11 GiB
  allocs estimate:  4186571
  --------------
  minimum time:     3.871 s (6.08% GC)
  median time:      3.914 s (7.64% GC)
  mean time:        3.914 s (7.64% GC)
  maximum time:     3.956 s (9.18% GC)
  --------------
  samples:          2
  evals/sample:     1

In [55]:
plan_adj = plan'
@benchmark mul!(complexImg, plan_adj, ksp)

BenchmarkTools.Trial: 
  memory estimate:  63.91 MiB
  allocs estimate:  4186243
  --------------
  minimum time:     3.301 s (0.00% GC)
  median time:      3.324 s (0.00% GC)
  mean time:        3.324 s (0.00% GC)
  maximum time:     3.346 s (0.00% GC)
  --------------
  samples:          2
  evals/sample:     1

## Unused code

In [15]:
function _spline_kernel(x::T, order::T)::T where {T<:Real}
    abs(x) > 1 && return zero(x)

    if order == 0
        return one(x)
    elseif order == 1
        return 1 - abs(x)
    elseif order == 2
        if abs(x) > 1 / 3
            return 9 / 8 * (1 - abs(x))^2
        else
            return 3 / 4 * (1 - 3 * x^2)
        end
    else
        @assert "Only {0,1,2}-order spline kernel is supported"
    end
end

_spline_kernel (generic function with 1 method)

In [16]:
function _kaiser_bessel_kernel(x::T, β::T)::T where {T<:Real}
    abs(x) > 1 && return zero(x)

    x = β * √(1 - x^2)
    t = x / 3.75
    if x < 3.75
        return 1 + 3.5156229 * t^2 + 3.0899424 * t^4 +
            1.2067492 * t^6 + 0.2659732 * t^8 +
            0.0360768 * t^10 + 0.0045813 * t^12
    else
        return x^-0.5 * exp(x) * (
            0.39894228 + 0.01328592 * t^-1 +
            0.00225319 * t^-2 - 0.00157565 * t^-3 +
            0.00916281 * t^-4 - 0.02057706 * t^-5 +
            0.02635537 * t^-6 - 0.01647633 * t^-7 +
            0.00392377 * t^-8)
    end
end

_kaiser_bessel_kernel (generic function with 1 method)