Difference in performance for STFT vs PyTorch #14

Closed · mortont opened this issue May 26, 2023 · 12 comments
@mortont commented May 26, 2023

I've noticed that NxSignal.stft is much slower than its PyTorch equivalent (on the order of ~500x). Is this expected? I'm using EXLA.Backend, but I feel like I'm missing something. In my tests, an STFT of zeros with length 480000 takes ~3ms in PyTorch and ~1.5 seconds (after the initial JIT compilation) in NxSignal:

import time
import torch


class STFT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fft_length = 400
        self.hop_length = 160

        # Periodic Hann window, stored as a non-trainable parameter
        self.window = torch.nn.Parameter(
            torch.hann_window(self.fft_length, periodic=True), requires_grad=False
        )

    def forward(self, sample):
        stft = torch.stft(
            sample,
            self.fft_length,
            self.hop_length,
            window=self.window,
            return_complex=True,
        )

        return stft


if __name__ == "__main__":
    s = STFT().eval()
    i = torch.zeros([480000])  # 30s of silence at 16 kHz

    start = time.time()
    result = s(i)
    print(f"took {(time.time() - start) * 1000000}us")

This returns ~3000us. And here is (what I think is) the corresponding Elixir:

defmodule StftTest do
  import Nx.Defn

  defn test(sample) do
    fft_length = 400
    sample_rate = 16000
    hop_length = 160

    window = NxSignal.Windows.hann(n: fft_length, is_periodic: true)

    {stft, _, _} =
      NxSignal.stft(sample, window,
        sampling_rate: sample_rate,
        fft_length: fft_length,
        # NxSignal takes the overlap rather than the hop directly
        overlap_length: fft_length - hop_length,
        window_padding: :reflect
      )

    stft
  end

  def test_jit() do
    sample = Nx.broadcast(0.0, {480_000})
    {time, _} = :timer.tc(fn -> Nx.Defn.jit(&test/1, compiler: EXLA).(sample) end)
    IO.puts("STFT took #{time}us")
  end
end

This returns 1596447us after the initial run.

@polvalente (Collaborator)

You have to discount the initial run because it also includes XLA client initialization as well as JIT compilation time. If your timings are proportional to mine, you should see a considerable reduction in execution time.
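
For reference, a minimal sketch of separating compilation from execution (reusing the StftTest module above):

stft = Nx.Defn.jit(&StftTest.test/1, compiler: EXLA)
sample = Nx.broadcast(0.0, {480_000})

# The first call pays the XLA client initialization and JIT compilation cost.
stft.(sample)

# Subsequent calls measure only execution time.
{time, _} = :timer.tc(fn -> stft.(sample) end)
IO.puts("warm STFT took #{time}us")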

That being said, it seems that the majority of the time is spent calculating the FFT itself. I wonder if we'll have to do something in Nx similar to what we did with SVD. I'll look into it.

@polvalente (Collaborator)

Ah, also, make sure to set Nx.default_backend(EXLA.Backend) too :)

@polvalente (Collaborator)

OK, I found the main problem: EXLA doesn't handle radix-N (mixed-radix) FFT. If you use fft_length: 512 or some other power of 2, the FFT is a lot faster. Also, window_padding: :reflect will slow down your calculations. I just noticed that your PyTorch code doesn't set pad_mode='reflect' anywhere, so maybe that also contributes.

Anyway, I got a speedup from ~1s to ~100ms by using window_padding: :valid (which is the default), and then to ~30ms by changing the window size from 400 to 512 (same hop length), or ~20ms with a window size of 256. All times were taken after a few warmup runs, to keep memory allocation and JIT compilation effects out of the measurement.
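
Putting both changes together, the adjusted defn would look something like this (a sketch based on the StftTest module above; defaults may differ between NxSignal versions):

defn test(sample) do
  # Power of 2 instead of 400, so EXLA's FFT stays on the fast path
  fft_length = 512
  hop_length = 160

  window = NxSignal.Windows.hann(n: fft_length, is_periodic: true)

  {stft, _, _} =
    NxSignal.stft(sample, window,
      sampling_rate: 16_000,
      fft_length: fft_length,
      overlap_length: fft_length - hop_length,
      # :valid (the default) avoids the costly :reflect padding
      window_padding: :valid
    )

  stft
end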

@mortont (Author) commented May 26, 2023

Yes, sorry I wasn't clear: that 1.5s runtime was after the JIT had been "warmed up"; the first run was 1.7s. I also left out the config where I set EXLA as the default backend, but I confirmed it is using EXLA.

config/config.exs:

import Config
config :nx, :default_backend, EXLA.Backend

I'm seeing similar numbers to yours after the changes, thanks! Although FWIW, pad_mode='reflect' is PyTorch's default, and explicitly setting it doesn't seem to alter the speed.

I wonder why it's still 10x slower? Typically EXLA is faster than PyTorch in my experience.

@polvalente (Collaborator)

The great majority of the slowdown is due to the window size not being a power of 2. I'm gonna experiment with the fix I mentioned in the Nx issue; let's see where that goes.

To go into a bit more detail: the most straightforward/intuitive implementation of the Fast Fourier Transform algorithm is only Fast™ for vectors of power-of-2 length. After all factors of 2 are exhausted (400 = 2⁴ · 25), the remainder is performed as a plain DFT instead (of length 25 in this case), and a plain DFT is O(n²) while the radix-2 recursion is O(n log n). So this amounts to a significant difference.

But there are radix-N algorithms that can handle more prime factors in the recursion step, and thus are faster.
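
To see that effect in isolation, you could run something like the following quick check (a sketch; exact numbers depend on hardware and EXLA version):

# Compare a non-power-of-2 length (400 = 2^4 * 25) against a
# power-of-2 length (512 = 2^9) on otherwise similar inputs.
fft = Nx.Defn.jit(&Nx.fft/1, compiler: EXLA)

for n <- [400, 512] do
  t = Nx.iota({3000, n}) |> Nx.as_type(:c64)
  # Warmup call so each shape is JIT-compiled before timing
  fft.(t)
  {us, _} = :timer.tc(fn -> fft.(t) end)
  IO.puts("fft of length #{n}: #{us}us")
end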

@polvalente (Collaborator)

And even after getting them to comparable implementations, it might be the case that either XLA doesn't optimize complex numbers all that well, or its core FFT implementation isn't as fast as possible (which I doubt).

Given the tensor size, we might also be seeing memory allocation overhead on the EXLA side of things.

Lots of factors to explore.

@polvalente (Collaborator) commented May 26, 2023

By the way, have you tried running the FFT with a random (or at least nonzero) tensor in both libs?

edit: I just ran it myself, and it isn't some zero-input shortcut that makes it 5x faster.
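
For reference, a sketch of feeding random data into the same measurement (assuming an Nx version that ships Nx.Random):

# Random input instead of zeros, to rule out any zero-input fast path
key = Nx.Random.key(42)
{sample, _key} = Nx.Random.uniform(key, shape: {480_000})

{time, _} = :timer.tc(fn -> Nx.Defn.jit(&StftTest.test/1, compiler: EXLA).(sample) end)
IO.puts("STFT took #{time}us")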

@polvalente (Collaborator)

As mentioned in the Nx issue, this is an upstream issue with no immediate fix on our side.
tensorflow/tensorflow#6541

@polvalente (Collaborator) commented May 30, 2023

@mortont I think we can close this issue for now. Maybe we can open a separate one for window_padding: :reflect itself.

Here's a benchmark I ran in Benchee using CUDA for both EXLA and Torchx:

Code

Mix.install [:benchee, {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, {:exla, github: "elixir-nx/nx", sparse: "exla"}, {:torchx, github: "elixir-nx/nx", sparse: "torchx"}], system_env: %{"XLA_TARGET" => "cuda114", "LIBTORCH_TARGET" => "cu117"}
Application.put_env(:exla, :clients, [host: [platform: :host], cuda: [platform: :cuda, preallocate: true, memory_fraction: 0.5]])

defmodule TorchxFFT do
  import Nx.Defn
    
  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

Nx.Defn.default_options(compiler: EXLA)
defmodule EXLAFFT do
  import Nx.Defn
    
  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

Benchee.run(
  %{
    "EXLA while" => fn input -> EXLAFFT.fft(input.exla) end,
    "EXLA fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.exla) end) end,
    "Torchx fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.torchx) end) end,
    "Torchx while" => fn input -> TorchxFFT.fft(input.torchx) end
  },
  inputs: %{
    "1x400" => %{exla: Nx.iota({1, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({1, 400}, backend: {Torchx.Backend, device: :cuda})},
    "1x512" => %{exla: Nx.iota({1, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({1, 512}, backend: {Torchx.Backend, device: :cuda})},
    "3kx400" => %{exla: Nx.iota({3000, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({3000, 400}, backend: {Torchx.Backend, device: :cuda})},
    "3kx512" => %{exla: Nx.iota({3000, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({3000, 512}, backend: {Torchx.Backend, device: :cuda})},
    "4096x512" => %{exla: Nx.iota({4096, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({4096, 512}, backend: {Torchx.Backend, device: :cuda})}
  },
  warmup: 5,
  time: 10
); nil

The idea was to check whether while had a positive or negative impact in the "let's run this computation 10 times on the GPU" scenario.

Results

With input 1x400

Name ips average deviation median 99th %
Torchx fft 1599.15 0.63 ms ±13.99% 0.61 ms 0.97 ms
Torchx while 893.97 1.12 ms ±10.95% 1.10 ms 1.55 ms
EXLA fft 855.54 1.17 ms ±12.23% 1.14 ms 1.67 ms
EXLA while 268.87 3.72 ms ±10.61% 3.63 ms 4.96 ms

Comparison:
Torchx fft 1599.15
Torchx while 893.97 - 1.79x slower +0.49 ms
EXLA fft 855.54 - 1.87x slower +0.54 ms
EXLA while 268.87 - 5.95x slower +3.09 ms

With input 1x512

Name ips average deviation median 99th %
Torchx fft 1472.02 0.68 ms ±15.55% 0.65 ms 0.97 ms
Torchx while 906.34 1.10 ms ±14.54% 1.07 ms 1.63 ms
EXLA fft 861.77 1.16 ms ±11.67% 1.13 ms 1.62 ms
EXLA while 277.07 3.61 ms ±8.88% 3.54 ms 4.61 ms

Comparison:
Torchx fft 1472.02
Torchx while 906.34 - 1.62x slower +0.42 ms
EXLA fft 861.77 - 1.71x slower +0.48 ms
EXLA while 277.07 - 5.31x slower +2.93 ms

With input 3kx400

Name ips average deviation median 99th %
Torchx while 231.23 4.32 ms ±2.77% 4.32 ms 4.57 ms
Torchx fft 224.89 4.45 ms ±0.97% 4.45 ms 4.49 ms
EXLA fft 143.81 6.95 ms ±6.81% 6.98 ms 8.00 ms
EXLA while 94.47 10.58 ms ±5.62% 10.53 ms 12.01 ms

Comparison:
Torchx while 231.23
Torchx fft 224.89 - 1.03x slower +0.122 ms
EXLA fft 143.81 - 1.61x slower +2.63 ms
EXLA while 94.47 - 2.45x slower +6.26 ms

With input 3kx512

Name ips average deviation median 99th %
Torchx while 231.56 4.32 ms ±1.61% 4.30 ms 4.55 ms
Torchx fft 181.71 5.50 ms ±0.79% 5.50 ms 5.57 ms
EXLA fft 137.46 7.27 ms ±7.17% 7.28 ms 8.46 ms
EXLA while 94.11 10.63 ms ±4.05% 10.62 ms 11.61 ms

Comparison:
Torchx while 231.56
Torchx fft 181.71 - 1.27x slower +1.18 ms
EXLA fft 137.46 - 1.68x slower +2.96 ms
EXLA while 94.11 - 2.46x slower +6.31 ms

With input 4096x512

Name ips average deviation median 99th %
Torchx while 174.26 5.74 ms ±1.66% 5.74 ms 5.97 ms
Torchx fft 134.62 7.43 ms ±0.48% 7.43 ms 7.48 ms
EXLA fft 112.89 8.86 ms ±5.64% 8.90 ms 10.00 ms
EXLA while 82.10 12.18 ms ±5.32% 12.08 ms 13.90 ms

Comparison:
Torchx while 174.26
Torchx fft 134.62 - 1.29x slower +1.69 ms
EXLA fft 112.89 - 1.54x slower +3.12 ms
EXLA while 82.10 - 2.12x slower +6.44 ms

Conclusion

I don't quite understand how Torchx with while ends up being more performant than with the plain Elixir loop (Enum.each), since the Defn evaluator is also "pure" Elixir (cc @josevalim). However, the results were consistent across every input kind.

It's worth noting that while the relative differences are significant, we're talking about differences of less than 5ms between EXLA fft and Torchx fft, which are directly equivalent, and of less than 10ms between "Torchx while" and "EXLA while", which are probably the more natural use cases.

The point I want to make is not that 10ms is an insignificant difference, but that it's miles away from the many tens of ms we saw in the best-case scenario with CPU execution. It's important to ensure that [cuFFT is installed](https://docs.nvidia.com/cuda/cufft/index.html).
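
As a quick sanity check (a sketch, assuming EXLA was built with a CUDA XLA_TARGET and a :cuda client is configured), you can run a small FFT directly on the GPU client:

# If the :cuda client is unavailable this raises, which points at the
# CUDA/cuFFT setup rather than at Nx/NxSignal itself.
t =
  Nx.iota({1, 512}, backend: {EXLA.Backend, client: :cuda})
  |> Nx.as_type(:c64)

Nx.fft(t) |> Nx.backend_transfer(Nx.BinaryBackend)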

Also note that I benchmarked only the FFT because that's where we have the least control.
It might be worth opening a separate issue if, even with the assurance that cuFFT is being used, you still see less-than-desirable STFT performance. window_padding: :reflect was a problematic option to implement, so bear that in mind when using it.

@josevalim (Contributor)

Only the function dispatch is pure Elixir; it still calls the backend functions, so I assume the FFT/IFFT code in XLA is not the fastest? Does Jax do anything special to speed up the FFT code?

@polvalente (Collaborator)

@josevalim I was referring specifically to the fact that the Torchx code using while was somehow faster than the plain Elixir loop. Probably an artifact of the benchmark, though.

@mortont (Author) commented May 30, 2023

Great, thanks for the exhaustive benchmark!

From the upstream thread you mentioned, it looks like Jax has worked around the slow FFT implementation by using PocketFFT instead of XLA's Eigen-based implementation, specifically for FFTs on the CPU (google/jax#2952). Is that something EXLA could do?

With that said, changing the FFT size to a power of 2 and changing the STFT window_padding to :same has fixed the performance issue for my application, so I'll close this issue.
