Performance vs FFT #84
Currently my
Repeating the 0th dimension only increases the runtime. So I think what I'm doing looks OK?
Hello @roflmaostc, I think this is expected.

The first factor is

The second factor is that the default grid size uses 2x oversampling, so it's not an apples-to-apples comparison. This gives some pretty massive 1024 x 1024 2D FFTs inside the NUFFT vs. the 512 1D FFT that you're using for the FFT. You could be more conservative and use the more standard 1.5x oversampling by setting

The third factor is that, because the interpolation is implemented in high-level Python, it is quite a bit slower. There are several mitigations for this, such as broadcasting across sensitivity coils, but it will never be as fast as a compiled implementation. The advantage is that you never have to worry about compiling

The last item is that I've always had a bit of trouble squeezing performance out of the adjoint on the CPU, and have generally observed the adjoint on the GPU to be much closer to forward performance. One possible mitigation: if your problem has a parametrization that's amenable to rewriting in terms of A'A, you can use the Toeplitz NUFFT for the forward-backward pair, which uses only FFTs and no interpolation.
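Why the Toeplitz approach needs no interpolation can be seen in a small dense 1D NumPy sketch. This is an illustration of the general technique, not torchkbnufft's implementation, and the function names are hypothetical. For a non-uniform DFT A with A[m, n] = exp(-i k_m n), the entry (A'A)[p, n] depends only on p - n, so A'A is Toeplitz. Embedding it in a 2N circulant matrix lets you apply it with one zero-padded FFT, a pointwise multiply, and one inverse FFT.

```python
import numpy as np

def nudft_matrix(k, n_pixels):
    """Dense 1D non-uniform DFT: A[m, n] = exp(-1j * k[m] * n)."""
    n = np.arange(n_pixels)
    return np.exp(-1j * np.outer(k, n))

def toeplitz_kernel(k, n_pixels):
    """FFT of the length-2N circulant embedding of the Toeplitz matrix A^H A."""
    d = np.arange(-(n_pixels - 1), n_pixels)        # lags -(N-1) .. (N-1)
    t = np.exp(1j * np.outer(d, k)).sum(axis=1)      # t[d] = sum_m exp(i * k_m * d)
    c = np.zeros(2 * n_pixels, dtype=complex)
    c[:n_pixels] = t[n_pixels - 1:]                  # non-negative lags 0 .. N-1
    c[n_pixels + 1:] = t[:n_pixels - 1]              # negative lags wrap to the end
    return np.fft.fft(c)                             # c[N] stays 0 (never used)

def toeplitz_apply(x, kernel_fft):
    """Apply A^H A via zero-pad -> FFT -> multiply -> inverse FFT -> crop."""
    n_pixels = x.shape[0]
    x_pad = np.zeros(2 * n_pixels, dtype=complex)
    x_pad[:n_pixels] = x
    return np.fft.ifft(np.fft.fft(x_pad) * kernel_fft)[:n_pixels]

rng = np.random.default_rng(0)
N, M = 32, 100                                       # pixels, non-uniform samples
k = rng.uniform(-np.pi, np.pi, M)
x = rng.standard_normal(N) + 1j * rng.standard_normal(N)

A = nudft_matrix(k, N)
direct = A.conj().T @ (A @ x)                        # explicit forward then adjoint
fast = toeplitz_apply(x, toeplitz_kernel(k, N))      # FFTs only, no interpolation
print(np.allclose(direct, fast))                     # True
```

The kernel depends only on the trajectory, so it is computed once; each subsequent application costs two FFTs on the padded grid regardless of the number of samples M.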
Thanks for your detailed reply! Yes, that's right. Your point about the Toeplitz NUFFT (ToepNufft) is interesting, and exactly what I'm looking for. I'm asking because we would expect a cost of 4 (padding) x 2 (forward and back) = 8 FFTs, i.e. roughly 8x the cost of a single unpadded FFT. But in my case, a naive application of

import torch
import torchkbnufft as tkbn
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"
N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
# Batch of N_z single-coil 2D slices, shape (N_z, 1, N, N), all zeros
voxels = torch.zeros(N_z, 1, N, N, dtype=torch.complex64, device=device)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    # Radial trajectory: N samples per spoke in [-pi, pi), N_angles spokes over [0, pi)
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    ktraj = ktraj.repeat(1, 1, 1)  # adds a leading batch dimension: (1, 2, K)
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    f = lambda x: toep_ob(x, kernel)
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    ktraj = ktraj.repeat(1, 1, 1)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)
    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])

    def apply_kernel(x):
        # Zero-pad each slice to the 2N x 2N kernel grid. Keep the complex dtype:
        # a default float32 buffer would discard the imaginary part of x.
        x_pad = torch.zeros(x.shape[0], 1, 2 * x.shape[2], 2 * x.shape[3], dtype=x.dtype, device=x.device)
        x_pad[:, :, 0:x.shape[2], 0:x.shape[3]] = x
        # FFT, multiply by the Toeplitz kernel, inverse FFT, crop back to N x N
        res = torch.fft.ifft2(torch.fft.fft2(x_pad) * kernel)[:, :, 0:x.shape[2], 0:x.shape[3]]
        return res

    return apply_kernel

toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)

%%time
arr2 = toeplitz_f(voxels)

CPU times: user 20.8 ms, sys: 547 µs, total: 21.3 ms
Wall time: 20.7 ms

%%time
arr3 = apply_kernel(voxels)

CPU times: user 928 µs, sys: 0 ns, total: 928 µs
Wall time: 524 µs
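The "4 x 2 = 8 FFTs" estimate can be sanity-checked with the textbook n log n FFT cost model. This is a rough sketch under that assumption: it counts n² log2(n²) operations for an n x n 2D FFT and ignores constants, memory traffic, and library overheads. Zero-padding from N x N to 2N x 2N quadruples the number of points and slightly lengthens the log factor, so a forward plus inverse FFT on the padded grid comes out at a bit more than 8 unpadded FFTs.

```python
import math

def fft2_cost(n):
    """Rough operation count for an n x n 2D FFT under the n log n model."""
    return n * n * math.log2(n * n)

N = 200                                   # image size used in the script above
# Forward + inverse FFT on the 2N x 2N padded grid, relative to one N x N FFT
ratio = 2 * fft2_cost(2 * N) / fft2_cost(N)
print(f"{ratio:.2f}")                     # 9.05
```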
Are you running on the GPU? You have to call

import time
import numpy as np
import torch
import torch.nn.functional as F
import torchkbnufft as tkbn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"
N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
voxels = torch.zeros(N_z, 1, N, N, dtype=torch.complex64, device=device)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    ktraj = ktraj.repeat(1, 1, 1)
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    f = lambda x: toep_ob(x, kernel)
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    ktraj = ktraj.repeat(1, 1, 1)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)
    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])

    def apply_kernel(x):
        im_size = torch.tensor(x.shape[2:])
        # Padded grid size is read from the trailing (spatial) dims of the kernel
        grid_size = torch.tensor(
            kernel.shape[-len(kernel.shape[2:]) :], dtype=torch.long, device=kernel.device
        )
        # F.pad takes (before, after) pairs starting from the last dim; pad only "after"
        pad_sizes = []
        for (gd, im) in zip(grid_size.flip((0,)), im_size.flip((0,))):
            pad_sizes.append(0)
            pad_sizes.append(int(gd - im))
        x_pad = F.pad(x, pad_sizes)
        # Forward FFT, multiply by the kernel, inverse FFT, then crop to the image size
        res = torch.fft.ifftn(
            torch.fft.fftn(x_pad, dim=[-2, -1], norm="ortho") * kernel,
            dim=[-2, -1], norm="ortho",
        )[:, :, :im_size[-2], :im_size[-1]]
        return res

    return apply_kernel

toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)

start = time.perf_counter()
arr2 = toeplitz_f(voxels)
end = time.perf_counter()
print(f"Toeplitz: {end-start}")

start = time.perf_counter()
arr3 = apply_kernel(voxels)
end = time.perf_counter()
print(f"simple pad: {end-start}")
You're right. Then I get ~20 ms and ~27 ms. A little overhead seems OK :) Thanks for helping me!
I think that might also explain the discrepancy in the initial post. |
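On the GPU, PyTorch launches CUDA kernels asynchronously, so a bare pair of `time.perf_counter()` calls around a function can end up measuring mostly launch overhead rather than the work itself, and the first call can also include one-time setup costs. A minimal, framework-agnostic timing helper that excludes warm-up calls and synchronizes before reading the clock might look like this (the name `time_call` and its parameters are hypothetical, not part of torchkbnufft):

```python
import time
from typing import Callable, Optional

def time_call(fn: Callable, *args, sync: Optional[Callable[[], None]] = None,
              warmup: int = 1, repeats: int = 10) -> float:
    """Return the mean wall-clock seconds per call of fn(*args).

    `sync` should block until all queued asynchronous work has finished
    (e.g. torch.cuda.synchronize on a CUDA device); pass None on the CPU.
    """
    for _ in range(warmup):      # keep one-time setup costs out of the measurement
        fn(*args)
    if sync is not None:
        sync()                   # ensure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if sync is not None:
        sync()                   # wait for the timed work itself to finish
    return (time.perf_counter() - start) / repeats
```

For the script above, `time_call(toeplitz_f, voxels, sync=torch.cuda.synchronize if device == "cuda" else None)`, and the same call with `apply_kernel`, would give comparable per-call times for the two approaches.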
Hi,
thanks for this package :)!
However, I was wondering about the performance. From my past experience with CPU-based NFFTs, I would have expected maybe a 10x difference compared with a standard FFT. Here, however, I observe a 100x difference, and the adjoint is almost 1000x slower.
Any thoughts on this?
Felix
Code: