Equivalent, faster (?) formulation #22
@tmassingham-ont Thank you for raising the issue. I just plotted the faster approximation you have presented, and I can confirm the graph looks the same. I will try running it on a model, do a speed comparison today, and post updates. That's an interesting observation; I hadn't noticed it. |
I tried a very naive implementation, taking the original mish and changing it to the code/equation @tmassingham-ont has posted (def f_mish_fast(input): ...), operating on torch.rand(50000) for 1000 batches each. For the original mish, the profiler's dominant op is softplus at 80.69% self CPU (807.900us per call) and 31.55% CUDA time (1.888us) (similar batches omitted), with a self CPU time total of 933.216ms. For mish fast, the dominant op is neg at 22.72% self CPU (148.500us per call) and 8.55% CUDA time (1.856us) (similar batches omitted). So the "fast" formulation reduces CPU time but increases CUDA time by much more. I also believe that using torch internal functions might accelerate the implementation further. I'll post my notebook if anyone is interested. |
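For reference, a minimal sketch of the kind of comparison described above (the function and variable names here are my own, not the notebook's exact code):

    import torch
    import torch.nn.functional as F

    def f_mish(x):
        # original formulation: x * tanh(softplus(x))
        return x * torch.tanh(F.softplus(x))

    def f_mish_fast(x):
        # reformulation posted by @tmassingham-ont
        y = torch.exp(-x)
        return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)

    x = torch.rand(50000)
    for fn in (f_mish, f_mish_fast):
        with torch.autograd.profiler.profile() as prof:
            for _ in range(1000):
                fn(x)
        print(fn.__name__)
        print(prof.key_averages().table(sort_by="self_cpu_time_total"))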
@vrandme Would be great if you could post your notebook. |
Again, this is a very naive implementation, but it seemed pythonic. @digantamisra98 how would I output "nvidia-smi"? I am looking for ways to optimize my implementation by using PyTorch methods and functions and/or numpy. |
Note that the outputs "look" the same but are not exactly the same (the last two lines illustrate this). |
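For illustration, one quick way to see the size of that discrepancy (a sketch, reusing the f_mish/f_mish_fast names from the sketch above):

    x = torch.rand(50000)
    print((f_mish(x) - f_mish_fast(x)).abs().max())  # tiny, but not exactly zero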
I think I should first profile the code purely on CPU for comparison, since it seems the CUDA mish would be the fastest under CUDA. |
I refactored my code to make the tables prettier and the CPU vs GPU comparison clearer. CPU execution, on the other hand, shows a different story. Please look over my notebook for errors and benchmark it on other systems. mish vs fast_mish-refactor.zip P.S. (any and all code I upload in this thread is public domain and/or under whatever license @digantamisra98 sees fit) |
Hello; nice improvement with the CPU there. Just to be clear, the formula for fast mish is an exact rewrite of x * tanh(softplus(x)), not an approximation; any differences in output are floating-point rounding. I suspect the poor CUDA performance of "fast mish" is due to each of the cheap operations (add, mul, etc.) being executed as a separate CUDA kernel, rather than being fused together. What do you get if you decorate the function with @torch.jit.script? |
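For concreteness, the suggested decoration would look something like this (a sketch; the function name is mine, and the body is the fast formulation from earlier in the thread):

    import torch

    @torch.jit.script
    def mish_fast_jit(x):
        # TorchScript can fuse these pointwise ops into fewer CUDA kernels
        y = torch.exp(-x)
        return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)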
@tmassingham-ont Thank you for all the information. I'll do a clean profiling run for the JIT and original implementations today. @vrandme Thanks for all the work and for uploading the notebooks. I will try to optimize it further and see if a combination of CUDA Mish and Fast Mish, selected by device type, can be made possible. |
Since @digantamisra98 did not close this issue, I decided to further refine my notebook using @tmassingham-ont's suggestion of decorating functions with @torch.jit.script. All I can say is... wow. This definitely needs more testing and replication. I compared four configurations: mish original GPU, mish fast GPU JIT, mish original CPU, and mish fast CPU. (Again, this result was consistently replicated on my system.) The mish fast code was 7~8x faster than the original code, with the JITified code showing even further consistent gains. I think the next step should be to see if these results can be replicated on other platforms (other CPU architectures, past and future versions of the CUDA implementations, etc.). Also, all of my testing so far involved only the forward pass; investigating the alternate implementation (in both its JITified and plain versions) in the backward pass could also be useful. In conclusion, the very simple optimization of decorating with @torch.jit.script made the "fast" version of the code show non-inferior results in BOTH the GPU and CPU implementations. It is also possible that it is actually faster in all implementations, but that requires further investigation on other systems and with a more rigorous benchmarking setup. |
@vrandme Great work! If you can spare the time, would you mind benchmarking the following please?

    import torch

    def faster_mish(x):
        y = torch.exp(-x)
        z = 1 + 2 * y
        return x * z / (z + 2 * y * y)

This replaces a call to exp with a multiplication, reusing y = e^{-x} so that e^{-2x} is just y * y. |
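One minimal way to run the requested CPU comparison (a sketch; fast_mish here is my rendering of the earlier two-exp formulation, and faster_mish is the function above):

    import time
    import torch

    def fast_mish(x):
        # two exp calls, for contrast with faster_mish's single exp
        y = torch.exp(-x)
        return x * (1 + 2 * y) / (1 + 2 * y + 2 * torch.exp(-2 * x))

    x = torch.rand(50000)
    for f in (fast_mish, faster_mish):
        t0 = time.perf_counter()
        for _ in range(1000):
            f(x)
        print(f.__name__, time.perf_counter() - t0)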
@tmassingham-ont other people suggested the same optimization, but I assumed that JIT decoration would do that automagically for me. |
Thanks. I'd have thought the same, but the presence of the repeated exp kernels in the profile suggests the JIT doesn't do this automatically. |
Another day, another notebook. Anyway, I also added %timeit calls at the end of the notebook. My testing shows results competitive with even F.relu, comparing F.relu against the JITified original mish and the JITified fast mish. Anyway, I think most of the low-hanging fruit has been picked at this point. Any pointers would be helpful. |
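Roughly what such timing cells look like (IPython %timeit magic; a sketch with assumed names rather than the notebook's exact code; mish_fast_jit is from the earlier sketch):

    import torch
    import torch.nn.functional as F

    x = torch.rand(50000)
    %timeit F.relu(x)
    %timeit x * torch.tanh(F.softplus(x))  # original mish
    %timeit mish_fast_jit(x)               # JITified fast formulation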
Hello. I've had a chance to run things on a GPU machine (Nvidia Xavier, so an unusual balance of CPU & GPU); script and benchmarks attached. With GPU+JIT, the various fast versions of mish perform similarly to ReLU on this platform. Not bad. Tested:
The various faster versions of mish are much faster on the CPU (either with or without JIT) than standard mish. |
@tmassingham-ont wow, we posted almost simultaneously. I also replicated the competitive-with-ReLU performance, and we used the same tensor size (50000) and batch count (10000). My version looks closest to "faster2_mish". I've been thinking of comparisons with the CUDA implementation by Thomas Brandon: https://github.com/thomasbrandon/mish-cuda. There, mish also shows competitive performance with ReLU. Maybe the faster mish could replace the CUDA version on at least certain platforms? Also, I will try other datatypes such as float64, where the CUDA implementation showed inferior performance compared to relu. |
This commit resolves/works on the following issue:
- #22
Additional fixes/changes:
- Removed dedicated Keras implementation
- New JIT-optimized Torch code for Mish
- Device-based (CPU/GPU) JIT Mish implementation for Torch
[ci skip]
@tmassingham-ont @vrandme Please find the latest commit which addresses the discussion on this issue thread. |
Good work @digantamisra98! |
@vrandme yes, since Tom's code resulted in non-optimized performance for the double-precision format. Until it is successfully tested and verified for all data formats, I can't update the CUDA code for now. |
We also have another implementation on Echo for Mish here - https://github.com/digantamisra98/Echo/blob/master/echoAI/Activation/Torch/mish.py |
That's a reasonable approach. I benchmarked plain relu, the JITified original mish, and the JITified and inlined "fast mish" (the code below also includes output from autograd profiling). Thus, at least for my implementation (I am not aware of any meaningful differences between @tmassingham-ont's code and mine, but more eyes could help) and my specific code and hardware combination, float64 (double precision) shows competitive results. I find this encouraging, since even @thomasbrandon's code (https://github.com/thomasbrandon/mish-cuda) showed slowdowns for float64. Also, I have been mirroring this discussion on the Korean PyTorch community Facebook page, and there are people who are willing to benchmark/test code if you would like. If you have any code you want tested, particularly anything where you fear speed issues or floating-point stability issues on other platforms, let me know. |
Using %timeit doesn't properly profile CUDA code, as kernels are launched asynchronously. Note that for relu you find equivalent performance with float64, which is clearly wrong. It should be about twice as slow when properly measured, as it basically does 2x the 32-bit operations (GPUs don't have any 64-bit math hardware). You see this in my profiling results. |
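For anyone who wants to time the GPU properly, a sketch using CUDA events (the helper name is mine; f is any activation function):

    import torch

    def cuda_time_ms(f, x, iters=1000):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # drain any pending work first
        start.record()
        for _ in range(iters):
            f(x)
        end.record()
        torch.cuda.synchronize()  # wait for the timed kernels to finish
        return start.elapsed_time(end) / iters  # average milliseconds per call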
@thomasbrandon thank you for your input! I've been wanting to know what you thought about this endeavor, since you implemented a CUDA version.
I agree with you on that. Taking a closer look at the profiled code (covering F.relu, F.relu float64, mish original JIT GPU, mish fast GPU JIT, and mish fast GPU JIT float64), the only explanation I could think of is that the CUDA code is somehow truncating the float64 into float32. I don't know how to check whether this is the case or how to stop it. Also, do you think the alternative implementation laid out above provides you with insights or helps with your own CUDA implementation? |
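A quick way to probe the truncation hypothesis is to inspect the output dtype directly (a sketch; mish_fast_jit as in the earlier sketch):

    x = torch.rand(50000, dtype=torch.float64, device="cuda")
    out = mish_fast_jit(x)
    print(out.dtype)  # torch.float64 here would rule out a silent downcast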
Will have to have a closer look at the profiling results to try and figure out what's going on. I haven't seen anything to indicate torch will generally do any sort of data conversion. But in general there is pretty limited support for float64 in torch as it really isn't used outside of very limited cases. Many operations aren't implemented for float64. So maybe JIT is automatically converting things to float32. In terms of CUDA and alternate implementations, there isn't really much room for improvement as performance is heavily dominated by CUDA launch overhead and memory access. You can see this in the very similar performance between ReLU and the original Mish implementation on CUDA in spite of the much greater computation needed for Mish. I also think in many networks the latency matters more than the computational demands. So while synthetic tests show some performance difference between Mish and ReLU this doesn't translate into epoch times. I suspect overall the whole network is bottlenecked by memory access (and perhaps CUDA kernel scheduling which again depends a lot on latency) not computation. More computationally demanding architectures such as EfficientNet may show some advantage of reduced computation though. |
Small update. I have shifted the newer device-based implementation to Mish/Torch_dev. The reasons are concerns over the speed-profiling issues raised by multiple calls to torch.cuda.is_available(), and also that sometimes, even if my device is CUDA, I would want certain tensor operations to stay at the CPU level to save the memory/time of shifting the tensor to the GPU. Additionally, I will try to put up a check not for the device in general, but for which device the data is actually processed/stored on. |
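A sketch of what such device-based dispatch could look like (the names are mine and the choice of variant per device is illustrative; the point is that it checks where the tensor actually lives rather than calling torch.cuda.is_available() on every invocation):

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def _mish_jit(x):
        y = torch.exp(-x)
        return x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)

    def _mish_plain(x):
        return x * torch.tanh(F.softplus(x))

    def mish_auto(x):
        # dispatch on the tensor's device, not on global CUDA availability
        return _mish_jit(x) if x.is_cuda else _mish_plain(x)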
Okay so what is the current verdict? |
I think @thomasbrandon can provide a more concrete conclusion. |
Note that since the faster version makes use of temporary storage, it can result in out-of-memory (OOM) errors -- as I had been getting until I realized this* -- if you don't decrease your batch size. Running the original, slower version incurs no such issue. *This was using Keras 2.1.3 & TF 1.4.1, with CUDA. These are old versions of these packages; perhaps newer versions have some clever way of avoiding the issue I had. So for now I added a 'fast' flag and turned it off:

    from keras import backend as K

    def mish(x, fast=False):
        if fast:  # faster but requires extra storage
            y = K.exp(-x)  # old TF has no tf.math, but the Keras backend works
            z = 1 + 2 * y
            return x * z / (z + 2 * y * y)
        return x * K.tanh(K.softplus(x)) |
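A hypothetical usage example for the function above (my own model layout, assuming the standard Keras API of that era):

    from keras.layers import Dense, Lambda
    from keras.models import Sequential

    model = Sequential([
        Dense(128, input_shape=(64,)),
        Lambda(mish),  # fast=False by default: slower, but no extra storage
        Dense(10, activation="softmax"),
    ])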
@drscotthawley The faster version was more intended for CPU. Did you face OOM errors on CPU too? |
@digantamisra98 Ah, no. I only tried it on GPU -- and a fairly small one at that: GTX 1080, 8GB of VRAM. |
@drscotthawley ah, alright. Thanks for your inputs :) |
Sorry, haven't had much time to look at this recently. A concrete conclusion as to what exactly?
You can use a … Though, as noted in my last reply, I'm not really sure there's much scope for improvement in GPU performance, as that's limited more by memory access and CUDA kernel launch overhead than by computational complexity. This is shown by the pretty similar performance of my CUDA version of Mish and the JIT autograd version of Swish as compared to ReLU. The similar performance is in spite of Mish/Swish requiring fairly complex computations (exp, log, tanh) where ReLU needs only a max. |
@thomasbrandon so the conclusion would be that it's the CUDA kernel launches and related overheads, rather than the computational complexity of the native functional definition of each activation function (Swish and Mish), that prevent them from being cheaper than what you can achieve with Mish-CUDA. |
Hello, thanks for the great work. Using the exponential identity for tanh, you can remove two of the three transcendental operations (the log and the tanh, leaving only the exp) and get what, hopefully, should be a faster implementation.
Since

    tanh(y) = (e^{2y} - 1) / (e^{2y} + 1)

and softplus(x) = ln(1 + e^x), you can express Mish as:

    mish(x) = x * ((1 + e^x)^2 - 1) / ((1 + e^x)^2 + 1)

or equivalently (to avoid overflow when x is large):

    mish(x) = x * (1 + 2 e^{-x}) / (1 + 2 e^{-x} + 2 e^{-2x})
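A quick numerical sanity check of the identity (my addition, not part of the original post):

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-10, 10, 2001, dtype=torch.float64)
    lhs = x * torch.tanh(F.softplus(x))
    y = torch.exp(-x)
    rhs = x * (1 + 2 * y) / (1 + 2 * y + 2 * y * y)
    print(torch.allclose(lhs, rhs))  # True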
NB: With a little tweak, there is an interesting connection to the GELU approximated with a logistic distribution ("Logistic Error Linear Unit"?) (i.e. Swish)
cf. the approximation

    x \sigma(1.702 x)

from the GELU paper.