Performance vs FFT #84

Closed
roflmaostc opened this issue Apr 13, 2023 · 6 comments

@roflmaostc

Hi,

thanks for this package :)!

However, I was wondering about the performance. From my past experience with CPU-based NFFTs, I would have expected maybe a 10x difference compared to a standard FFT. But here I observe a difference of roughly a factor of 100, and the adjoint is almost a factor of 1000 slower.

Any thoughts on this?

Felix

Code:

import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom
import matplotlib.pyplot as plt
import napari
import glob
import imageio.v3 as iio
import os
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

N = 512
N_z = 100
N_angles = int(np.ceil(np.pi * N / 2))
voxels = (torch.zeros(1, N_z, N, N, dtype=torch.float32) + 1j * 0).to(device).to(torch.complex64)


def prepare_nufft(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-0.5, 0.5, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    nufft = tkbn.KbNufft(im_size=(N, N), numpoints=6).to(device)

    adjnufft = tkbn.KbNufftAdjoint(im_size=(N, N), numpoints=6).to(device)
    
    return nufft, adjnufft, ktraj


nufft, adjnufft, ktraj = prepare_nufft(voxels, N_angles)

%%time
kdata = nufft(voxels, ktraj)
CPU times: user 30.1 ms, sys: 0 ns, total: 30.1 ms
Wall time: 29.6 ms

%%time
#kdata = nufft(voxels, ktraj)
img_filter = adjnufft(kdata, ktraj)
#torch.cuda.synchronize()
CPU times: user 268 ms, sys: 172 ms, total: 440 ms
Wall time: 394 ms



%%time
torch.fft.fft(voxels)
CPU times: user 337 µs, sys: 180 µs, total: 517 µs
Wall time: 283 µs
@roflmaostc
Author

Currently my ktraj shape is:

ktraj.shape

torch.Size([1, 2, 412160])

Repeating the 0th dimension only increases the runtime, so I think what I'm doing looks ok?

@mmuckley
Owner

Hello @roflmaostc, I think this is expected.

The first factor is that torch.fft.fft only does a 1D FFT, so it is doing a factor of N fewer FFTs to begin with.
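
A minimal sketch of a closer comparison, reusing the voxels tensor from your snippet:

# Sketch only: a batched 2D FFT over the two image dimensions is a closer
# (though still not oversampled) comparison than a single 1D torch.fft.fft.
baseline = torch.fft.fft2(voxels, dim=(-2, -1))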

The second factor is that the default grid size uses a 2x oversampled grid, so it's not an apples-to-apples comparison. This gives some pretty massive 1024 x 1024 2D FFTs inside the NUFFT vs. the length-512 1D FFT that you're using for the FFT baseline. You could be more conservative and use the more standard 1.5x oversampling by setting grid_size to (768, 768), as sketched below.
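
For example (a sketch based on the constructor calls in your snippet; im_size and numpoints are just carried over):

# Sketch: 1.5x oversampling instead of the default 2x oversampled grid.
nufft = tkbn.KbNufft(im_size=(512, 512), grid_size=(768, 768), numpoints=6).to(device)
adjnufft = tkbn.KbNufftAdjoint(im_size=(512, 512), grid_size=(768, 768), numpoints=6).to(device)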

The third is that due to the high-level implementation in Python, our interpolation is quite a bit slower. There are several mitigations for this, such as broadcasting across sensitivity coils, but it will never be as fast as a compiled implementation. The advantage is that you never have to worry about compiling torchkbnufft, but the disadvantage is speed.

The last item is that I've always had a bit of trouble squeezing out performance on the CPU for the adjoint, and I have generally observed the GPU to be much closer to forward performance.

One possible mitigation: if you have a parametrization of your problem that's amenable to rewriting in terms of A'A, you can use the Toeplitz NUFFT for the forward-backward, which uses only FFTs and no interpolation (see the sketch below).
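
Roughly (a sketch, with ktraj, voxels, N, and device as defined in your snippet):

# Sketch of the Toeplitz forward-backward A'A: precompute an FFT-domain kernel
# once, then every application is just padded FFTs, no interpolation.
toep_ob = tkbn.ToepNufft().to(device)
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N)).to(device)
image_aha = toep_ob(voxels, kernel)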

@roflmaostc
Author

roflmaostc commented Apr 13, 2023

Thanks for your detailed reply!

Yeah, that's right.

Your point regarding the Toeplitz NUFFT is interesting, and exactly what I'm looking for.
Is the kernel it spits out applied as ifft(fft(pad(arr)) * kernel), where pad applies the zero-padding?

I'm asking since we would expect the cost of 4 (from padding) x 2 (back and forth) = 8 FFTs.

But in my case, a naive application of ifft(fft(pad(arr)) * kernel) with a precalculated kernel is still ~40 times faster.

import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom
import matplotlib.pyplot as plt
import napari
import glob
import imageio.v3 as iio
import os
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"

N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
voxels = (torch.zeros(N_z, 1, N, N, dtype=torch.float32) + 1j * 0).to(device).to(torch.complex64)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    
    f = lambda x: toep_ob(x, kernel)
    
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)
    

    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])
    def apply_kernel(x):
        x_pad = torch.zeros(x.shape[0], 1, 2 * x.shape[2], 2 * x.shape[3], dtype=x.dtype, device=device)
        x_pad[:, :, 0:x.shape[2], 0:x.shape[3]] = x
        
        res = torch.fft.ifft2(torch.fft.fft2(x_pad) * kernel)[:, :, 0:x.shape[2], 0:x.shape[3]]
        return res
    
    return apply_kernel


toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)



%%time
arr2 = toeplitz_f(voxels)
CPU times: user 20.8 ms, sys: 547 µs, total: 21.3 ms
Wall time: 20.7 ms

%%time
arr3 = apply_kernel(voxels)
CPU times: user 928 µs, sys: 0 ns, total: 928 µs
Wall time: 524 µs

@mmuckley
Owner

Are you running on the GPU? You have to call torch.cuda.synchronize(). I get almost the exact same times on the CPU for this code.

import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom
import matplotlib.pyplot as plt
import glob
import os
import torch.nn.functional as F
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"

N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
voxels = (torch.zeros(N_z, 1, N, N, dtype=torch.float32) + 1j * 0).to(device).to(torch.complex64)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    
    f = lambda x: toep_ob(x, kernel)
    
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)

    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])
    def apply_kernel(x):
        im_size = torch.tensor(x.shape[2:])

        grid_size = torch.tensor(
            kernel.shape[-len(kernel.shape[2:]) :], dtype=torch.long, device=kernel.device
        )
        pad_sizes = []
        for (gd, im) in zip(grid_size.flip((0,)), im_size.flip((0,))):
            pad_sizes.append(0)
            pad_sizes.append(int(gd - im))
        x_pad = F.pad(x, pad_sizes)
        print(x_pad.shape)
        print(kernel.shape)
        
        res = torch.fft.ifftn(torch.fft.fftn(x_pad, dim=[-2, -1], norm="ortho") * kernel, dim=[-2, -1], norm="ortho")[:, :, :im_size[-2], :im_size[-1]]
        return res
    
    return apply_kernel


toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)


import time

start = time.perf_counter()
arr2 = toeplitz_f(voxels)
end = time.perf_counter()
print(f"Toeplitz: {end-start}")

start = time.perf_counter()
arr3 = apply_kernel(voxels)
end = time.perf_counter()
print(f"simple pad: {end-start}")

@roflmaostc
Author

You're right, torch.cuda.synchronize() fixed it.

Then I get ~20 ms and ~27 ms. A little overhead seems ok :)

Thanks for helping me!

@roflmaostc
Author

I think that might also explain the discrepancy in the initial post.
So all fine!
