
EXLA FFT much slower than Torchx on CPU #1234

Closed
mortont opened this issue May 30, 2023 · 7 comments
Labels
area:exla Applies to EXLA

Comments

@mortont

mortont commented May 30, 2023

From the discussion in elixir-nx/nx_signal#14, it looks like the EXLA FFT implementation is ~7x slower than Torchx when run on CPU. tf.signal has run into the same limitations (tensorflow/tensorflow#6541), and JAX worked around them by using PocketFFT (google/jax#2952). Would EXLA be able to do something similar?

Here's the (modified) benchmark:

Mix.install([
  :benchee,
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla"},
  {:torchx, github: "elixir-nx/nx", sparse: "torchx"}
])

Benchee.run(
  %{
    # Explicit fns instead of `&` captures: `&1` inside a nested `fn` is
    # ambiguous at best, and Benchee passes each input map as the argument.
    "EXLA fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.exla) end) end,
    "Torchx fft" => fn input -> Enum.each(1..10, fn _ -> Nx.fft(input.torchx) end) end
  },
  inputs: %{
    "1x400" => %{
      exla: Nx.iota({1, 400}, backend: {EXLA.Backend, client: :host}),
      torchx: Nx.iota({1, 400}, backend: {Torchx.Backend, device: :cpu})
    },
    "1x512" => %{
      exla: Nx.iota({1, 512}, backend: {EXLA.Backend, client: :host}),
      torchx: Nx.iota({1, 512}, backend: {Torchx.Backend, device: :cpu})
    },
    "3kx400" => %{
      exla: Nx.iota({3000, 400}, backend: {EXLA.Backend, client: :host}),
      torchx: Nx.iota({3000, 400}, backend: {Torchx.Backend, device: :cpu})
    },
    "3kx512" => %{
      exla: Nx.iota({3000, 512}, backend: {EXLA.Backend, client: :host}),
      torchx: Nx.iota({3000, 512}, backend: {Torchx.Backend, device: :cpu})
    },
    "4096x512" => %{
      exla: Nx.iota({4096, 512}, backend: {EXLA.Backend, client: :host}),
      torchx: Nx.iota({4096, 512}, backend: {Torchx.Backend, device: :cpu})
    }
  },
  warmup: 5,
  time: 10
)


results:

##### With input 1x400 #####
Name                 ips        average  deviation         median         99th %
Torchx fft        9.28 K      107.76 μs    ±24.21%       98.73 μs      219.44 μs
EXLA fft          1.91 K      524.48 μs     ±3.29%      522.91 μs      567.09 μs

Comparison:
Torchx fft        9.28 K
EXLA fft          1.91 K - 4.87x slower +416.72 μs

##### With input 1x512 #####
Name                 ips        average  deviation         median         99th %
Torchx fft        8.83 K      113.31 μs    ±19.13%      107.08 μs      218.46 μs
EXLA fft          5.41 K      185.00 μs    ±11.96%      183.52 μs      242.78 μs

Comparison:
Torchx fft        8.83 K
EXLA fft          5.41 K - 1.63x slower +71.69 μs

##### With input 3kx400 #####
Name                 ips        average  deviation         median         99th %
Torchx fft         40.66       0.0246 s    ±14.40%       0.0242 s       0.0340 s
EXLA fft            0.95         1.06 s     ±0.32%         1.06 s         1.06 s

Comparison:
Torchx fft         40.66
EXLA fft            0.95 - 42.91x slower +1.03 s

##### With input 3kx512 #####
Name                 ips        average  deviation         median         99th %
Torchx fft         36.06       27.73 ms    ±14.71%       27.41 ms       43.30 ms
EXLA fft            5.22      191.66 ms     ±1.29%      191.50 ms      198.61 ms

Comparison:
Torchx fft         36.06
EXLA fft            5.22 - 6.91x slower +163.93 ms

##### With input 4096x512 #####
Name                 ips        average  deviation         median         99th %
Torchx fft         23.34       42.84 ms    ±11.43%       41.46 ms       61.73 ms
EXLA fft            3.60      277.86 ms     ±6.52%      273.19 ms      328.88 ms

Comparison:
Torchx fft         23.34
EXLA fft            3.60 - 6.49x slower +235.02 ms
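As a point of comparison outside the BEAM, NumPy's `np.fft` has been backed by a C++ port of pocketfft since NumPy 1.17, the same family of FFT libraries JAX adopted to work around slow XLA CPU FFTs. A minimal sketch (not part of the benchmark above) of using it as a correctness baseline for any backend's FFT output:

```python
import numpy as np

# Same data as the smallest benchmark input: Nx.iota({1, 400}) flattened.
x = np.arange(400, dtype=np.float64)

# np.fft.fft is pocketfft-backed in NumPy >= 1.17.
spectrum = np.fft.fft(x)

# Sanity checks any backend's unnormalized FFT should satisfy:
# the DC bin equals the sum of the input,
assert np.isclose(spectrum[0].real, x.sum())
# and Parseval's theorem holds: sum|X[k]|^2 == N * sum|x[n]|^2.
assert np.isclose(np.sum(np.abs(spectrum) ** 2), len(x) * np.sum(x ** 2))
```

Checks like these make it easy to confirm that a faster FFT path still matches the reference numerically.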
@seanmor5
Collaborator

Yes, we can do something similar; I will look into it.

@seanmor5
Collaborator

seanmor5 commented Jun 2, 2023

JAX now uses ducc fft per: google/jax#12122

ducc is the successor to pocketfft, so we can mirror this as well. The challenge is determining where and how to build the third-party repo into EXLA. My opinion is that we can add the ducc-specific library code upstream to the precompiled XLA library. It will require some changes to the build process (and I might have to patch tensorflow :/)

Implementing the custom operator itself is easy once we figure out the dependency story.

@polvalente
Contributor

@seanmor5 I agree with adding the custom FFT code to the precompiled XLA library.

Why do you think we'd have to patch TF? Can't we use an XLA::CustomCall to do the job?

@josevalim josevalim added the area:exla Applies to EXLA label Jun 20, 2023
@josevalim
Collaborator

Good news: google/jax@390022a

This will be solved in future XLA versions. :)

@jonatanklosko
Member

jonatanklosko commented Nov 10, 2023

Here are the results with EXLA main and the new XLA (run on M1):

##### With input 1x400 #####
Name                 ips        average  deviation         median         99th %
Torchx fft        5.65 K      177.00 μs    ±11.12%      174.91 μs      240.33 μs
EXLA fft          4.16 K      240.23 μs    ±13.35%      234.08 μs      367.75 μs

Comparison: 
Torchx fft        5.65 K
EXLA fft          4.16 K - 1.36x slower +63.23 μs

##### With input 1x512 #####
Name                 ips        average  deviation         median         99th %
Torchx fft        5.50 K      181.83 μs    ±12.76%      178.79 μs      274.83 μs
EXLA fft          4.05 K      246.90 μs    ±13.46%      238.71 μs      380.71 μs

Comparison: 
Torchx fft        5.50 K
EXLA fft          4.05 K - 1.36x slower +65.07 μs

##### With input 3kx400 #####
Name                 ips        average  deviation         median         99th %
EXLA fft           82.69       12.09 ms     ±8.30%       12.14 ms       15.20 ms
Torchx fft         46.54       21.49 ms     ±3.80%       21.31 ms       23.72 ms

Comparison: 
EXLA fft           82.69
Torchx fft         46.54 - 1.78x slower +9.39 ms

##### With input 3kx512 #####
Name                 ips        average  deviation         median         99th %
EXLA fft           60.02       16.66 ms     ±6.51%       16.65 ms       19.42 ms
Torchx fft         34.56       28.94 ms     ±5.84%       28.79 ms       34.60 ms

Comparison: 
EXLA fft           60.02
Torchx fft         34.56 - 1.74x slower +12.28 ms

##### With input 4096x512 #####
Name                 ips        average  deviation         median         99th %
EXLA fft           44.45       22.50 ms     ±2.54%       22.45 ms       24.88 ms
Torchx fft         25.69       38.93 ms     ±7.75%       38.64 ms       52.45 ms

Comparison: 
EXLA fft           44.45
Torchx fft         25.69 - 1.73x slower +16.43 ms

@josevalim
Collaborator

YAAAAAAAAAS

@mortont
Author

mortont commented Nov 10, 2023

Sweet! This should speed up the Bumblebee Whisper model too; last I checked, it was using size-400 FFTs.
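For context on why the size-400 shapes matter: Whisper's feature extractor computes short-time FFTs over 25 ms windows of 16 kHz audio, i.e. 400-sample frames. A rough NumPy sketch of that workload (window and hop parameters are assumed typical values, not taken from this thread):

```python
import numpy as np

# Assumed Whisper-style STFT parameters: 16 kHz audio,
# 25 ms window (400 samples), 10 ms hop (160 samples).
sr, n_fft, hop = 16_000, 400, 160

audio = np.zeros(sr)  # one second of silence as dummy input

# Each frame needs one size-400 FFT, so even a 1 s clip issues ~100 of them;
# that is why the small-FFT benchmarks in this issue matter for Whisper.
n_frames = 1 + (len(audio) - n_fft) // hop
frames = np.stack([audio[i * hop : i * hop + n_fft] for i in range(n_frames)])
spectrogram = np.abs(np.fft.rfft(frames * np.hanning(n_fft), axis=-1))

print(spectrogram.shape)  # (98, 201): 98 frames, 201 rfft bins for n_fft=400
```

A 30 s clip (Whisper's standard input length) multiplies the frame count by 30, so a faster size-400 FFT compounds quickly.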
