
Equivalent, faster (?) formulation #22

Closed
tmassingham-ont opened this issue Jan 10, 2020 · 34 comments
Labels
enhancement New feature or request

Comments

@tmassingham-ont

tmassingham-ont commented Jan 10, 2020

Hello, thanks for the great work. Using the exponential identity for tanh, you can remove two of the transcendental operations (exp, log) and get what, hopefully, should be a faster implementation.

Since

$$\tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)$$

you can express Mish as:

$$y = e^{x}, \qquad \mathrm{mish}(x) = x\, y (y + 2) / (y^2 + 2 y + 2)$$

or equivalently (to avoid overflow when x is large)

$$y = e^{-x}, \qquad \mathrm{mish}(x) = x (1 + 2 y) / (1 + 2 y + 2 y^2)$$

NB: With a little tweak, there is an interesting connection to the GELU approximated with a logistic distribution ("Logistic Error Linear Unit"?) (i.e. Swish)

$$x \tanh(0.5 \log( 1 + e^x) ) = x \sigma(x - \log 2)$$

cf. the approximation $x \sigma(1.702 x)$ from the GELU paper.
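A quick numerical check of the identity (a sketch; the helper names are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def mish_reference(x):
    # Original definition: x * tanh(softplus(x))
    return x * torch.tanh(F.softplus(x))

def mish_exp_identity(x):
    # Overflow-safe reformulation with y = exp(-x)
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)

x = torch.linspace(-20, 20, 10001, dtype=torch.float64)
print(torch.max(torch.abs(mish_reference(x) - mish_exp_identity(x))))
# Differences should be at the level of floating-point rounding.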

@digantamisra98
Owner

@tmassingham-ont Thank you for raising the issue. I just plotted the faster approximation you have presented:

y = e^-x
mish(x) = x (1 + 2 y) / (1 + 2 y + 2 y^2) 

And I can confirm the graph looks the same:
[attached plot: mish_approx]

I will try running it on a model, do a speed comparison today, and post updates.

That's an interesting observation. I hadn't noticed that.

digantamisra98 added the enhancement label on Jan 12, 2020
@chris-ha458

I tried a very naive implementation taking the original Mish:

def f_mish(input):
    '''
    Applies the Mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    '''
    return input * torch.tanh(F.softplus(input))

and changing it to the code/equation @tmassingham-ont posted, to get:

def f_mish_fast(input):
    y = torch.exp(-input)
    return input * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

Operating on torch.rand(50000) for 1000 batches each, profiled with "with torch.autograd.profiler.profile(use_cuda=True) as prof", yielded the following results:

mish original

Name       Self CPU total %  Self CPU total  CPU total %  CPU total   CPU time avg  CUDA total %  CUDA total  CUDA time avg  Number of Calls  Input Shapes
softplus   80.69%            807.900us       80.69%       807.900us   807.900us     31.55%        1.888us     1.888us        1                []
tanh       15.76%            157.800us       15.76%       157.800us   157.800us     34.23%        2.048us     2.048us        1                []
mul         3.55%             35.500us        3.55%        35.500us    35.500us     34.23%        2.048us     2.048us        1                []

(similar batches omitted)

Self CPU time total: 933.216ms
CUDA time total: 5.607ms

and for mish fast

mish fast

Name       Self CPU total %  Self CPU total  CPU total %  CPU total   CPU time avg  CUDA total %  CUDA total  CUDA time avg  Number of Calls  Input Shapes
neg        22.72%            148.500us       22.72%       148.500us   148.500us      8.55%        1.856us     1.856us        1                []
exp        16.36%            106.900us       16.36%       106.900us   106.900us      8.55%        1.856us     1.856us        1                []
mul         5.05%             33.000us        5.05%        33.000us    33.000us      9.44%        2.048us     2.048us        1                []
add        11.32%             74.000us       11.32%        74.000us    74.000us      9.44%        2.048us     2.048us        1                []
mul         2.00%             13.100us        2.00%        13.100us    13.100us      9.29%        2.016us     2.016us        1                []
mul         3.34%             21.800us        3.34%        21.800us    21.800us      9.00%        1.952us     1.952us        1                []
add        13.56%             88.600us       13.56%        88.600us    88.600us      8.55%        1.856us     1.856us        1                []
pow         2.63%             17.200us        2.63%        17.200us    17.200us      8.85%        1.920us     1.920us        1                []
mul        19.05%            124.500us       19.05%       124.500us   124.500us      9.44%        2.048us     2.048us        1                []
add         1.99%             13.000us        1.99%        13.000us    13.000us      9.44%        2.048us     2.048us        1                []
div         1.97%             12.900us        1.97%        12.900us    12.900us      9.44%        2.048us     2.048us        1                []

(similar batches omitted)

Self CPU time total: 167.891ms
CUDA time total: 21.021ms

so the "fast" formulation reduces CPU time but increases CUDA time by much more.
I am using a 2080TI which seems to be able to compute transcendental functions very quickly.
On other platforms without hardware accelerated fast transcendental function this might be useful.

I also believe that using torch internal functions might accelerate the
(input * (1+2*y) / (1 + 2 * y + 2 * (y**2))) part could be helpful
In the end i think it'll be a trade off between transcendental functions vs multiply adds.

I'll post my notebook if anyone is interested
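For reference, a minimal sketch of the profiling harness described above (tensor size 50000, 1000 iterations, torch.autograd.profiler); the exact notebook code may differ:

import torch
import torch.nn.functional as F

def f_mish(x):
    return x * torch.tanh(F.softplus(x))

def f_mish_fast(x):
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

data = torch.rand(50000, device='cuda')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(1000):
        f_mish_fast(data)

print(prof.key_averages().table(sort_by='self_cpu_time_total'))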

@digantamisra98
Owner

@vrandme
What was the memory consumption on your GPU like? Can you post the output of nvidia-smi?
Additionally, based on your analysis, we could use a mixture of the faster approximation for CPU by @tmassingham-ont and the CUDA Mish by Tom.
For instance, it would be something like this:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    ...  # use CUDA Mish
else:
    ...  # use CPU Mish

It would be great if you could post your notebook.
Thanks!

@chris-ha458

Again, this is a very naive implementation, but it seemed Pythonic.

mish vs fast_mish.zip

@digantamisra98 How would I get the output of "nvidia-smi"?

I am looking for ways to optimize my implementation using PyTorch methods and functions and/or NumPy.

@chris-ha458

Note that the outputs "look" the same but are not exactly the same (the last two lines illustrate this).
I attribute this to floating-point issues, but I am not a floating-point expert.

@chris-ha458

I think I should first profile the code purely on the CPU for comparison, since it seems the CUDA Mish would be the fastest under CUDA.

@chris-ha458

chris-ha458 commented Jan 12, 2020

I refactored my code to make the tables prettier and the CPU vs. GPU comparison clearer.
On GPU, it is considerably slower, but I don't think there are many gains to be had there, since the original code is faster on GPU and I bet the CUDA implementation is faster still.
I would guess that the code as-is, with a balanced use of both transcendental functions and multiply-adds, utilizes GPU resources better.

CPU execution, on the other hand, tells a different story.
Even my naive implementation, without any optimizations, shows an almost 5x speedup:
Self CPU time total: 860.910ms vs. Self CPU time total: 147.185ms
I am sure that even simple optimizations to my code will yield further gains.

Please look over my notebook for errors and benchmark it on other systems.
My system: Ryzen 3950X, RTX 2080 Ti.

mish vs fast_mish-refactor.zip

P.S. Any and all code I upload in this thread is public domain and/or under whatever license @digantamisra98 sees fit.

@tmassingham-ont
Author

Hello; nice improvement with the CPU there. Just to be clear, the formulas for Mish I posted are mathematical identities, not approximations, so any difference from the original must be due to numerical issues.

I suspect the poor CUDA performance of "fast mish" is due to each of the cheap operations (add, mul, etc.) being executed as a separate CUDA kernel rather than being fused together. What do you get if you decorate Mish.Torch.functional.mish with @torch.jit.script?
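For concreteness, a JIT-decorated version of the fast formulation might look like the following (a sketch; the actual function in the repo may be organized differently):

import torch

@torch.jit.script
def mish_fast(x):
    # Exponential-identity formulation; these elementwise ops
    # can be fused by the JIT into a single kernel on CUDA.
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)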

@digantamisra98
Owner

@tmassingham-ont Thank you for all the information. I'll do clear profiling of the JIT and original implementations today. @vrandme Thanks for all the work and for uploading the notebooks. I will try to optimize it further and see whether a combination of CUDA Mish and fast Mish can be made possible based on device type.

@chris-ha458

Since @digantamisra98 did not close this issue, I decided to further refine my notebook using @tmassingham-ont's suggestion of decorating functions with "@torch.jit.script".

All I can say is... wow. This definitely needs more testing and replication.

mish original GPU
Self CPU time total: 207.180ms
CUDA time total: 397.442ms
(The original code and the JIT-decorated original code show similar results; sometimes one comes out on top, sometimes the other. Further research might show a more meaningful difference, as would profiling, but considering the existence of well-optimized CUDA code, I didn't think it was useful to look into it.)

mish fast GPU JIT
Self CPU time total: 214.289ms
CUDA time total: 368.485ms
Not significantly faster, but in my testing it consistently showed lower times than the original GPU code above.

mish original CPU
Self CPU time total: 868.308ms
(The above points regarding original vs. JITified original for the GPU code apply here too.)

mish fast CPU
Self CPU time total: 125.045ms
mish fast CPU JIT
Self CPU time total: 97.992ms

(Again, this result was consistently replicated on my system. The fast Mish code was thus 7-8x faster than the original code, with the JITified code showing further consistent gains.)

I think the next step should be to see whether these results can be replicated on other platforms (other CPU architectures, past and future versions of the CUDA implementations, etc.).
I do not have access to other bare-metal systems, but trying my notebook on Kaggle/Colab and/or other paid and free platforms could be a reachable short-term goal.

Also, all of my testing so far has involved the forward pass. Investigating the alternate implementation (in both its JITified and plain versions) in the backward pass could also be useful.

In conclusion, the very simple optimization of decorating with "@torch.jit.script" made the "fast" version of the code show non-inferior results in BOTH GPU and CPU implementations. It is also possible that it is actually faster across the board, but this requires further investigation on other systems and with a more rigorous benchmarking setup.

mish vs fast_mish-jit-script.zip

@tmassingham-ont
Author

tmassingham-ont commented Jan 13, 2020

@vrandme Great work! If you can spare the time, would you mind benchmarking the following please?

def faster_mish(x):
    y = torch.exp(-x)
    z = 1 + 2 * y
    return x * z / (z + 2 * y * y)

This replaces the call to pow with an explicit multiplication and factors out a common term. A quick test on my local CPU suggests it is about 20% faster than fast_mish.

@chris-ha458

@tmassingham-ont Other people suggested the same optimization, but I assumed that JIT decoration would do that automagically for me.
I'll try the change later today if nobody beats me to it.

@tmassingham-ont
Author

@tmassingham-ont Other people suggested the same optimization, but I assumed that JIT decoration would do that automagically for me.
I'll try the change later today if nobody beats me to it.

Thanks. I'd have thought the same, but the presence of the pow call in the profiling suggests the JIT isn't clever enough to make this optimisation. There's a further variant, x * (1 + 2 * y) / (1 + y * (2 + 2 * y)), that makes the three fused multiply-add operations explicit, but I'm not observing as much of an improvement using this.
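Written out as a function, that FMA-explicit variant would look roughly like this (a sketch; the name is illustrative):

import torch

def mish_fma(x):
    # Denominator written as 1 + y*(2 + 2*y) so each step is a
    # multiply-add; algebraically equal to 1 + 2*y + 2*y**2.
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + y * (2 + 2 * y))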

@chris-ha458

Another day, another notebook.
It turns out @tmassingham-ont was right: the refactored, inlined faster mish is faster still.
I removed all other versions except the JITified original mish and the JITified faster mish.
This made me think: why would you NOT want JITified code?

Anyway, I also added %timeit calls at the end of the notebook.

My testing shows results competitive even with F.relu.
I attribute this to the specific benchmarking conditions; ReLU and/or Mish might be bottlenecked by other parts of the code or computation (memory operations, etc.).

F.relu
54.8 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) (CPU)
50.6 µs ± 2.59 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

JITified original mish
896 µs ± 28.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) (CPU)
114 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
JITified and inlined "fast mish"
130 µs ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)(CPU)
52.7 µs ± 6.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Anyway, I think most of the low-hanging fruit has been picked at this point.
I would like to move on to testing backward passes, but I don't know how to approach this.
Would I need to build a full model? Could this be tested in isolation, as the forward pass was?

Any pointers would be helpful

mish vs fast_mish-jit_v2.zip
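One way to time the backward pass in isolation, without building a full model, would be something like the following (a sketch using the same tensor size as above; not taken from the attached notebook):

import torch

@torch.jit.script
def mish_fast(x):
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)

x = torch.rand(50000, requires_grad=True)
grad = torch.ones(50000)

with torch.autograd.profiler.profile() as prof:
    for _ in range(1000):
        out = mish_fast(x)
        out.backward(grad)   # backward on a non-scalar needs an explicit gradient
        x.grad = None        # clear accumulated gradients between iterations

print(prof.key_averages().table(sort_by='self_cpu_time_total'))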

@tmassingham-ont
Author

Hello. I've had a chance to run things on a GPU machine (an Nvidia Xavier, so an unusual balance of CPU & GPU); script and benchmarks attached. With GPU+JIT, the various fast versions of Mish perform similarly to ReLU on this platform. Not bad.

Tested:

  • fast_mish, original proposal
  • faster_mish, FMA made explicit
  • faster2_mish, common factor lifted + don't use pow to square

The various faster versions of Mish are much faster on the CPU (with or without JIT) than standard Mish. When run on the GPU, they are only faster when the JIT is enabled.
faster2_mish is quicker on the CPU than fast_mish but about the same speed for GPU+JIT.

mish.zip
mish_benchmark.txt

@chris-ha458

@tmassingham-ont wow we posted almost simultaneously.

I also replicated ReLU-competitive performance, and we used the same tensor size (50000) and batch count (10000).
But we have different platforms (RTX 2080 Ti vs. Xavier), so that's a meaningful data point.

My version looks closest to "faster2_mish".
I like it the most because it is the easiest for me to read.

I've been thinking about comparisons with the CUDA implementation by Thomas Brandon:

https://github.com/thomasbrandon/mish-cuda

There, Mish also shows performance competitive with ReLU.
(I tried looking over the profiling code to understand the benchmarking conditions, but I couldn't fully comprehend it.)
https://github.com/thomasbrandon/mish-cuda/blob/master/test/perftest.py

Maybe the faster Mish could replace the CUDA version, at least on certain platforms?

Also, I will try other data types such as float64, where the CUDA implementation showed inferior performance compared to ReLU.

digantamisra98 added a commit that referenced this issue Jan 15, 2020
This commit resolves/works on the following issue:
- #22

Additional fixes/changes:
- Removed dedicated Keras
- New JIT-optimized Torch code for Mish
- Device-based (CPU/GPU) JIT Mish implementation for Torch

[ci skip]
@digantamisra98
Owner

@tmassingham-ont @vrandme Please see the latest commit, which addresses the discussion in this issue thread.

@chris-ha458

Good work @digantamisra98!
However, considering that testing by me and @tmassingham-ont showed that the faster Mish is better even on GPUs, is there a reason you kept the original Mish for CUDA? Do you feel that more extensive testing should be done before completely changing the code?

@digantamisra98
Owner

@vrandme Yes, since Tom's code produced non-optimized output for the double-precision format. Until it is successfully tested and verified for all data formats, I can't update the CUDA code for now.

@digantamisra98
Owner

We also have another implementation on Echo for Mish here - https://github.com/digantamisra98/Echo/blob/master/echoAI/Activation/Torch/mish.py

@chris-ha458

Tom's code resulted in non optimized output for double precision format. Until it is successfully tested and verified for all data formats, I can't update the code for CUDA as of now.

That's a reasonable approach.
If this helps, I would like to show you my code where double precision wasn't an issue.

#Plain relu
%timeit F.relu(testdata_CPU)
%timeit F.relu(testdata)
%timeit F.relu(testdata_float64)
51.2 µs ± 934 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
53.1 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
56.8 µs ± 8.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

#JITified original mish
%timeit f_mish_JIT(testdata_CPU)
%timeit f_mish_JIT(testdata)
%timeit f_mish_JIT(testdata_float64)
904 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
98 µs ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
87.2 µs ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

#JITified and inlined "fast mish"
%timeit f_mish_fast_JIT(testdata_CPU)
%timeit f_mish_fast_JIT(testdata)
%timeit f_mish_fast_JIT(testdata_float64)
111 µs ± 2.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
64.5 µs ± 9.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48.1 µs ± 4.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(The attached notebook also includes output from autograd profiling.)

Thus, at least for my implementation (I am not aware of any meaningful differences between @tmassingham-ont's code and mine, but more eyes could help) and my specific code and hardware combination, float64 (double precision) shows competitive results.

I find this encouraging, since even @thomasbrandon 's code (https://github.com/thomasbrandon/mish-cuda) showed slowdowns for float64.

Also, I have been mirroring this discussion on the Korean PyTorch community Facebook page, and there are people who are willing to benchmark/test code if you would like.
One person even offered Titan X, Xp, RTX, Tesla P100, and V100 systems for testing.

If you have any code that you want tested, particularly code where you are worried about speed or floating-point stability issues on other platforms, let me know.

mish vs fast_mish-jit_v3.zip

@thomasbrandon

thomasbrandon commented Jan 15, 2020

Using %timeit doesn't properly profile CUDA code, as kernels are launched asynchronously. Note that for relu you find equivalent performance on float64, which is clearly wrong: it should be about twice as slow when properly measured, as it basically does 2x the 32-bit operations (GPUs don't have any 64-bit math hardware). You can see this in my profiling results.
Using %time and repeating the operation a few times actually tends to give fairly reasonable results (I guess the repeated operations make it wait for at least some to complete). For more accurate results you can use event-based timing as I do in perftest.py, call torch.cuda.synchronize() to wait for outstanding CUDA operations, or use torch.autograd.profiler as @tmassingham-ont did.
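For example, a synchronize-based timing helper along the lines described above might look like this (a sketch; the function name is illustrative):

import time
import torch

def cuda_timeit(fn, x, iters=1000):
    # Warm up, then time with explicit synchronization so that queued
    # CUDA kernels have actually finished before the clock stops.
    for _ in range(10):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters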

@chris-ha458

chris-ha458 commented Jan 15, 2020

@thomasbrandon Thank you for your input! I always wanted to know what you thought about this whole endeavor, since you implemented a CUDA version.

It should be about twice as slow when properly measured as it basically does 2x the 32-bit operations (GPUs don't have any 64-bit math hardware). You see this in my profiling results.

I agree with you on that. Taking a closer look at the profiled code:

F.relu
Self CPU time total: 83.589ms
CUDA time total: 147.860ms

F.relu float64
Self CPU time total: 43.811ms
CUDA time total: 101.819ms

mish original JIT GPU
Self CPU time total: 256.652ms
CUDA time total: 456.403ms

mish fast GPU JIT
Self CPU time total: 200.933ms
CUDA time total: 353.591ms

mish fast GPU JIT float64
Self CPU time total: 205.205ms
CUDA time total: 343.209ms

The only explanation I can think of is that the CUDA code is somehow truncating the float64 to float32. I don't know how to check whether this is the case or how to stop it.
I'd appreciate any pointers.

Also, do you think the alternative implementation laid out above provides you with any insights or helps with your own CUDA implementation?
mish vs fast_mish-jit_v4.zip

@chris-ha458

At least the outputs from the activations show float64 results, so if it is silently converting to float32, it is converting back to float64 at some point.
[attached screenshot: float64 outputs]

@thomasbrandon

I will have to take a closer look at the profiling results to try to figure out what's going on. I haven't seen anything to indicate that torch will generally do any sort of data conversion. But in general there is pretty limited support for float64 in torch, as it really isn't used outside of very limited cases; many operations aren't implemented for float64. So maybe the JIT is automatically converting things to float32.

In terms of CUDA and alternate implementations, there isn't really much room for improvement as performance is heavily dominated by CUDA launch overhead and memory access. You can see this in the very similar performance between ReLU and the original Mish implementation on CUDA in spite of the much greater computation needed for Mish. I also think in many networks the latency matters more than the computational demands. So while synthetic tests show some performance difference between Mish and ReLU this doesn't translate into epoch times. I suspect overall the whole network is bottlenecked by memory access (and perhaps CUDA kernel scheduling which again depends a lot on latency) not computation. More computationally demanding architectures such as EfficientNet may show some advantage of reduced computation though.
So alternative implementations are not likely to provide much advantage. More promising is allowing a fully fused JIT implementation and so avoiding the need for a custom CUDA version. The softplus in the original Mish implementation cannot be fused so the JIT version is a bit slower than the CUDA version (whereas for Swish, which I also did in CUDA, JIT performs the same as it fully fuses). So a fully fusable JIT version should eliminate the need for a custom CUDA kernel.

@digantamisra98
Owner

Small update: I have shifted the newer device-based implementation to Mish/Torch_dev. The reasons are concerns over the speed-profiling issues raised by multiple calls to torch.cuda.is_available, and also that sometimes, even if my device is CUDA, I may want certain tensor operations to stay on the CPU to save the memory/time of shifting the tensor to the GPU. Additionally, I will try to put up a check not for the device, but for which device the data is processed/stored on.
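A sketch of what such a per-tensor dispatch could look like (illustrative only; the actual Torch_dev code may differ):

import torch
import torch.nn.functional as F

@torch.jit.script
def _mish_fast_jit(x):
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)

def mish(x):
    # Dispatch on where the tensor actually lives, not on whether
    # CUDA is merely available on the machine.
    if x.is_cuda:
        return x * torch.tanh(F.softplus(x))  # original form (or a dedicated CUDA kernel)
    return _mish_fast_jit(x)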

@DonaldTsang

Okay so what is the current verdict?

@digantamisra98
Owner

I think @thomasbrandon can provide a more concrete conclusion.

@drscotthawley

drscotthawley commented Feb 15, 2020

Note that since the faster version makes use of temporary storage, it can result in out-of-memory (OOM) errors -- as I had been getting until I realized this* -- if you don't decrease your batch size. Running the original, slower version incurs no such issue.

*This was using Keras 2.1.3 & TF 1.4.1, with CUDA. These are old versions of these packages; perhaps newer versions have some clever way of avoiding the issue I had.

So currently I made a flag for 'fast' and turned it off:

def mish(x, fast=False):
    if fast:                    # faster but requires extra storage
        y = K.exp(-x)           # old TF has no tf.math, but the Keras backend works
        z = 1 + 2 * y
        return x * z / (z + 2 * y * y)
    return x * K.tanh(K.softplus(x))

@digantamisra98
Owner

@drscotthawley The faster version was intended more for CPU. Did you face OOM errors on CPU too?

@drscotthawley

@digantamisra98 Ah, no. I only tried it on GPU -- and a fairly small one at that: GTX 1080, 8GB of VRAM.
Thanks.

@digantamisra98
Owner

@drscotthawley Ah, alright. Thanks for your input :)

@thomasbrandon

thomasbrandon commented Feb 25, 2020

@digantamisra98:

I think @thomasbrandon can provide a more concrete conclusion.

Sorry, I haven't had much time to look at this recently. A concrete conclusion about what, exactly?

@drscotthawley :

Note that since the faster version makes use of temporary storage, it can result in out of memory (OOM) errors

You can use a torch.autograd.Function to eliminate the use of temporary variables, but you then need to take care of the backward calculations yourself. You might also try using @torch.jit.script, as that might eliminate the temporary storage while still doing an automatic backward (though I'm not at all sure it will eliminate temporaries). See here for a basic autograd Function implementation and memory comparisons, and here for how to combine JIT and autograd Functions.
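A sketch of that Autograd-Function approach for the fast formulation, with the derivative written out by hand (hypothetical code; only the input is saved, and exp(-x) is recomputed in the backward to keep temporary storage down):

import torch

class FastMishFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = torch.exp(-x)
        z = 1 + 2 * y
        return x * z / (z + 2 * y * y)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        y = torch.exp(-x)              # recomputed rather than stored
        d = 1 + 2 * y + 2 * y * y      # denominator
        # d/dx [ x * (1 + 2y) / d ] with y = exp(-x)
        #   = (1 + 2y)/d + 4*x*y^2*(1 + y)/d^2
        grad = (1 + 2 * y) / d + 4 * x * y * y * (1 + y) / (d * d)
        return grad_out * grad

def fast_mish(x):
    return FastMishFn.apply(x)

As a quick sanity check, fast_mish(torch.zeros(1, requires_grad=True)).sum().backward() should leave a gradient of 0.6, matching tanh(log 2), the known derivative of Mish at zero.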

Though, as noted in my last reply, I'm not really sure there's much scope for improvement in GPU performance, as that is limited more by memory access and CUDA kernel launch overhead than by computational complexity. This is shown by the pretty similar performance of my CUDA version of Mish and the JIT autograd version of Swish compared to ReLU, in spite of Mish/Swish requiring fairly complex computations (exps and logs) while ReLU is just x if x > 0 else 0.

@digantamisra98
Owner

@thomasbrandon So the conclusion would be that it's the kernel launch and CUDA overheads, rather than the computational complexity of the native functional definition of each activation function (Swish and Mish), that prevent it from being cheaper than what you can achieve with Mish-CUDA.
