Difference in performance for STFT vs PyTorch #14
You have to discount the initial run, because it also includes XLA client initialization as well as JIT compilation time. If your timings are proportional to mine, you should see a considerable reduction in execution time. That being said, it seems that the majority of the time is spent calculating the FFT itself. I wonder if we'll have to do something in Nx similar to what we did with SVD. I'll look into it.
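For illustration, a measurement that discounts the warm-up could look like the sketch below (my own example, not from the original comment; the tensor shape is arbitrary):

```elixir
# JIT-compile once up front; the first call pays XLA client
# initialization plus compilation, later calls do not.
fun = Nx.Defn.jit(fn t -> Nx.fft(t) end, compiler: EXLA)
t = Nx.iota({3000, 512}, type: :f32)

# Warm-up run: includes compilation time.
fun.(t)

# Steady-state timing only.
{micros, _result} = :timer.tc(fn -> fun.(t) end)
IO.puts("fft took #{micros}µs")
```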
Ah, also, make sure to set Nx.default_backend(EXLA.Backend) too :)
Ok, I found the main problem: EXLA doesn't deal with radix-N FFT. If you use window_size: 512 or some other power of 2, the FFT is a lot faster. Anyway, I got a speedup from 1s to 0.1s = 100ms by using window_size: 512.
Yes, sorry I wasn't clear: that 1.5s runtime was after the JIT had been "warmed up". The first run was 1.7s. I also left out my config where I set EXLA as the default backend, but I confirmed it is using EXLA:

```elixir
import Config

config :nx, :default_backend, EXLA.Backend
```

I'm seeing numbers similar to yours after the changes, thanks! Although, FWIW, I wonder why it's still 10x slower? Typically EXLA is faster than PyTorch in my experience.
The great majority of the slowdown is due to the window size not being a power of 2. I'm gonna experiment with the fix I mentioned in the Nx issue; let's see where that goes. To go into a bit more detail: the most straightforward/intuitive implementation of the Fast Fourier Transform algorithm is only Fast™ for vectors whose length is a power of 2. After all factors of 2 are exhausted (400 = 2⁴ · 25), the remainder is performed as a plain DFT instead (of length 25 in this case), so this amounts to a significant difference. But there are radix-N algorithms that can handle more prime factors in the recursion step, and thus are faster.
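To see the size effect in isolation, you can compare 512 = 2⁹ against 400 = 2⁴ · 25 directly on Nx.fft (a minimal sketch of mine, not from the original reply; shapes are arbitrary):

```elixir
Nx.default_backend(EXLA.Backend)

# 512 = 2^9: the FFT recursion bottoms out in radix-2 steps only.
pow2 = Nx.iota({3000, 512}, type: :f32)

# 400 = 2^4 * 25: after the radix-2 stages a length-25 DFT remains,
# costing O(n^2) on that factor instead of O(n log n).
non_pow2 = Nx.iota({3000, 400}, type: :f32)

{t512, _} = :timer.tc(fn -> Nx.fft(pow2) end)
{t400, _} = :timer.tc(fn -> Nx.fft(non_pow2) end)
IO.puts("512: #{t512}µs, 400: #{t400}µs")
```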
And even with comparable implementations, it might be the case that either XLA doesn't optimize complex numbers all that well, or the core FFT implementation isn't as fast as possible (which I doubt). Given the tensor size, we might be seeing issues with memory allocation time too, on the EXLA side of things. Lots of factors to explore.
By the way, have you tried running the FFT with a random (or at least nonzero) tensor in both libs? Edit: I just ran it myself, so it isn't any zero check that makes it 5x faster.
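In Nx, a nonzero input for that check could be generated like this (a sketch; assumes a Nx version that ships the Nx.Random API):

```elixir
key = Nx.Random.key(42)

# Uniform noise instead of zeros rules out any zero-input fast path.
{t, _key} = Nx.Random.uniform(key, shape: {3000, 512}, type: :f32)

Nx.fft(t)
```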
As mentioned in the Nx issue, this is an upstream issue with no immediate fix on our side.
@mortont I think we can close this issue for now. Maybe we can open a separate one for pad reflect itself. Here's a benchmark I ran in Benchee using CUDA for both EXLA and Torchx:

Code

```elixir
Mix.install(
  [
    :benchee,
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla"},
    {:torchx, github: "elixir-nx/nx", sparse: "torchx"}
  ],
  system_env: %{"XLA_TARGET" => "cuda114", "LIBTORCH_TARGET" => "cu117"}
)

Application.put_env(:exla, :clients,
  host: [platform: :host],
  cuda: [platform: :cuda, preallocate: true, memory_fraction: 0.5]
)

defmodule TorchxFFT do
  import Nx.Defn

  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

Nx.Defn.default_options(compiler: EXLA)

defmodule EXLAFFT do
  import Nx.Defn

  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

inputs = %{
  "1x400" => %{
    exla: Nx.iota({1, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}),
    torchx: Nx.iota({1, 400}, backend: {Torchx.Backend, device: :cuda})
  },
  "1x512" => %{
    exla: Nx.iota({1, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}),
    torchx: Nx.iota({1, 512}, backend: {Torchx.Backend, device: :cuda})
  },
  "3kx400" => %{
    exla: Nx.iota({3000, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}),
    torchx: Nx.iota({3000, 400}, backend: {Torchx.Backend, device: :cuda})
  },
  "3kx512" => %{
    exla: Nx.iota({3000, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}),
    torchx: Nx.iota({3000, 512}, backend: {Torchx.Backend, device: :cuda})
  },
  "4096x512" => %{
    exla: Nx.iota({4096, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}),
    torchx: Nx.iota({4096, 512}, backend: {Torchx.Backend, device: :cuda})
  }
}

Benchee.run(
  %{
    "EXLA while" => fn input -> EXLAFFT.fft(input.exla) end,
    "EXLA fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.exla) end) end,
    "Torchx fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.torchx) end) end,
    "Torchx while" => fn input -> TorchxFFT.fft(input.torchx) end
  },
  inputs: inputs,
  warmup: 5,
  time: 10
)

nil
```

The idea was to check whether while had a positive or negative impact in the "let's run this computation 10 times on the GPU" scenario.

Results

(The per-input Benchee tables for 1x400, 1x512, 3kx400, 3kx512, and 4096x512, with columns Name / ips / average / deviation / median / 99th % plus a Comparison section, did not survive and are omitted here.)

Conclusion

I don't quite understand how Torchx with while ends up being more performant than with for, since the Defn evaluator is also "pure" Elixir (cc @josevalim). However, the results were consistent on every input kind. It's worth noting that while the relative differences are significant, we're talking about differences of less than 5ms between "EXLA fft" and "Torchx fft", which are directly equivalent, and of less than 10ms between "Torchx while" and "EXLA while", which are probably the more natural use cases. The point I want to make is not that 10ms is an insignificant difference, but that it is miles away from the many tens of ms in the best-case scenario we had with CPU execution. It's important to ensure that [cuFFT is installed](https://docs.nvidia.com/cuda/cufft/index.html). Also note that I benchmarked only the FFT because that's where we have the least control.
Only the function dispatch is pure Elixir; it still calls the backend functions, so I assume the FFT code in XLA is not the fastest? Does Jax do anything special to speed up the FFT code?
@josevalim I was referring specifically to the fact that the Torchx code using while was somehow faster than the Elixir code using for. Probably an artifact of the benchmark, though.
Great, thanks for the exhaustive benchmark! From the upstream thread you mentioned, it looks like Jax has worked around the slow FFT implementation by using PocketFFT rather than XLA (Eigen) specifically for FFTs on the CPU (google/jax#2952). Is that something EXLA could do? With that said, changing the FFT size to a power of 2 and changing the STFT window padding works well enough for my use case.
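To make the power-of-2 workaround concrete, here is a minimal sketch (my own, not from the thread); it assumes NxSignal.stft/3 accepts :sampling_rate and :fft_length options and that NxSignal.Windows.hann/1 takes an :n option, so double-check against your NxSignal version:

```elixir
Nx.default_backend(EXLA.Backend)

samples = Nx.iota({480_000}, type: :f32)

# Assumed API: a 400-sample Hann window, as in the original test.
window = NxSignal.Windows.hann(n: 400)

# Round the FFT length up from the 400-sample window to the next
# power of two so the radix-2 FFT path is used throughout.
result = NxSignal.stft(samples, window, sampling_rate: 16_000, fft_length: 512)
```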
I've noticed the speed of NxSignal.stft is much slower than in PyTorch (on the order of ~500x). Is this expected? I'm using EXLA.Backend, but I feel like I'm missing something. In my tests I see an STFT of zeros with length 480000 taking ~3ms in PyTorch and ~1.5 seconds (after the initial JIT compilation) in NxSignal.

The PyTorch version returns in ~3000µs; (what I think is) the corresponding Elixir returns in 1596447µs after the initial run. (The code snippets originally attached here are missing.)
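Since the snippets are missing, here is a hedged reconstruction of what the Elixir side of the measurement might have looked like (window size, sampling rate, and the NxSignal.Windows.hann/1 :n option are my guesses, not the original code):

```elixir
Nx.default_backend(EXLA.Backend)

zeros = Nx.broadcast(0.0, {480_000})
window = NxSignal.Windows.hann(n: 400)

# Warm-up run so JIT compilation is excluded, matching the
# "after the initial run" note above.
NxSignal.stft(zeros, window, sampling_rate: 16_000)

{micros, _result} =
  :timer.tc(fn -> NxSignal.stft(zeros, window, sampling_rate: 16_000) end)

IO.puts("STFT took #{micros}µs")
```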