Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast Metal FFT for all N #981

Closed
wants to merge 16 commits into from
99 changes: 82 additions & 17 deletions benchmarks/python/fft_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime

matplotlib.use("Agg")
Expand All @@ -16,41 +18,104 @@ def bandwidth_gb(runtime_ms, system_size):
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size):
def run_bench(system_size, fft_sizes):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out

bandwidths = []
for k in range(4, 12):
n = 2**k
for n in fft_sizes:
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
bandwidth = bandwidth_gb(runtime_ms, system_size // n * n)
print("bandwidth", n, bandwidth)
bandwidths.append(bandwidth)

return bandwidths


def run_bench_mps(system_size, fft_sizes):
def fft(x):
out = torch.fft.fft(x)
torch.mps.synchronize()
return out

bandwidths = []
for n in fft_sizes:
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()

runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, system_size // n * n)
print("bandwidth", n, bandwidth)
bandwidths.append(bandwidth)

return bandwidths


def time_fft():
x = range(4, 1024)
system_size = int(2**24)

with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

np.save("gpu_bandwidths", gpu_bandwidths)

mps_bandwidths = run_bench_mps(system_size=system_size, fft_sizes=x)

np.save("mps_bandwidths", mps_bandwidths)

system_size = int(2**21)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))

# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
np.save("cpu_bandwidths", cpu_bandwidths)

# cpu_bandwidths = np.load("cpu_bandwidths.npy")
# gpu_bandwidths = np.load("gpu_bandwidths.npy")
# mps_bandwidths = np.load("mps_bandwidths.npy")

x = np.array(x)

all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)

for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()

av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)

portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)


if __name__ == "__main__":
Expand Down
94 changes: 94 additions & 0 deletions mlx/backend/accelerate/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

#include <cassert>
#include <cmath>
#include <numeric>

#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>

#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
Expand Down Expand Up @@ -34,6 +36,7 @@ DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
Expand Down Expand Up @@ -232,6 +235,97 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

void BluesteinFFTSetup::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// We need to calculate the Bluestein twiddle factors
// in double precision for the overall numerical stability
// of Bluestein's FFT algorithm to be acceptable.
//
// MLX currently support float64, so instead we
// manually implement the required operations using accelerate.
//
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
//
assert(inputs.size() == 0);

auto& w_q = outputs[0];
auto& w_k = outputs[1];

size_t fft_size = w_q.shape(0);

int length = 2 * n_ - 1;

std::vector<double> x(length);
std::vector<double> y(length);

std::iota(x.begin(), x.end(), -n_ + 1);
vDSP_vsqD(x.data(), 1, y.data(), 1, x.size());
double theta = (double)1.0 / (double)n_;
vDSP_vsmulD(y.data(), 1, &theta, x.data(), 1, x.size());

std::vector<double> real_part(length);
std::vector<double> imag_part(length);
vvcospi(real_part.data(), x.data(), &length);
vvsinpi(imag_part.data(), x.data(), &length);

double minus_1 = -1.0;
vDSP_vsmulD(x.data(), 1, &minus_1, y.data(), 1, x.size());

// compute w_k
std::vector<double> real_part_w_k(n_);
std::vector<double> imag_part_w_k(n_);
vvcospi(real_part_w_k.data(), y.data() + length - n_, &n_);
vvsinpi(imag_part_w_k.data(), y.data() + length - n_, &n_);

auto convert_float = [](double real, double imag) {
return std::complex<float>(real, imag);
};

// convert back to float now we've done the sincos
std::vector<std::complex<float>> w_k_input(n_, 0.0);
std::transform(
real_part_w_k.begin(),
real_part_w_k.end(),
imag_part_w_k.begin(),
w_k_input.begin(),
convert_float);

w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));

auto w_k_ptr =
reinterpret_cast<std::complex<float>*>(w_k.data<complex64_t>());
memcpy(w_k_ptr, w_k_input.data(), n_ * w_k.itemsize());

// convert back to float now we've done the sincos
std::vector<std::complex<float>> fft_input(fft_size, 0.0);
std::transform(
real_part.begin(),
real_part.end(),
imag_part.begin(),
fft_input.begin(),
convert_float);

w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());

std::ptrdiff_t item_size = w_q.itemsize();

pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ fft_input.data(),
/* data_out= */ w_q_ptr,
/* scale= */ 1.0f);
}

void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/common/default_primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(BluesteinFFTSetup)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
Expand Down
10 changes: 10 additions & 0 deletions mlx/backend/common/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}

void BluesteinFFTSetup::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("Bluestein is only implemented in accelerate.");
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Expand Down Expand Up @@ -203,6 +209,10 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
}
}

void Conjugate::eval(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[conjugate] conjugate not yet implemented on CPU.");
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
Expand Down
1 change: 0 additions & 1 deletion mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ MTL::Function* Device::get_function_(
}

mtl_func_consts->release();
desc->release();

return mtl_function;
}
Expand Down
Loading