
PyTorch Mish - 1.5x slower training, 2.9X more memory usage vs LeakyReLU(0.1) #18

Closed
glenn-jocher opened this issue Nov 17, 2019 · 6 comments
Labels: question (Further information is requested)

Comments

@glenn-jocher commented Nov 17, 2019:

Hi, thanks for this interesting new activation function. I've tested it with YOLOv3-SPP on a V100 from https://github.com/ultralytics/yolov3, and my feedback is mixed: accuracy improves slightly, but training time is much slower and GPU memory requirements are much higher than with LeakyReLU(0.1). Any suggestions on how to improve speed/memory in PyTorch? Thanks!

From AlexeyAB/darknet#3114 (comment):

| Activation | mAP@0.5 | mAP@0.5:0.95 | GPU memory | Epoch time |
|---|---|---|---|---|
| LeakyReLU(0.1) | 48.9 | 29.6 | 4.0 GB | 31 min |
| Mish() | 50.9 | 31.2 | 11.1 GB | 46 min |
import torch
import torch.nn as nn
import torch.nn.functional as F


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # in-place multiply: x * sigmoid(x)
        return x.mul_(torch.sigmoid(x))


class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # in-place multiply: x * tanh(softplus(x))
        return x.mul_(F.softplus(x).tanh())
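
For context, a minimal sketch of how a per-activation speed/memory comparison could be reproduced in isolation (the tensor shape, iteration count, and the out-of-place PlainMish helper below are illustrative assumptions, not the YOLOv3-SPP training settings behind the table above):

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlainMish(nn.Module):
    # out-of-place Mish used only for this benchmark (hypothetical helper, not the code quoted above)
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

def profile(act, shape=(16, 256, 64, 64), iters=100, device="cuda"):
    # rough forward+backward timing and peak-memory probe for a single activation
    x = torch.randn(shape, device=device, requires_grad=True)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    t0 = time.time()
    for _ in range(iters):
        y = act(x)
        y.sum().backward()
        x.grad = None
    torch.cuda.synchronize(device)
    return time.time() - t0, torch.cuda.max_memory_allocated(device) / 1e9

if torch.cuda.is_available():
    for name, act in [("LeakyReLU(0.1)", nn.LeakyReLU(0.1)), ("Mish", PlainMish())]:
        t, mem = profile(act)
        print(f"{name}: {t:.3f}s / {mem:.2f} GB peak")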
@glenn-jocher changed the title from "PyTorch - 1.5x slower training, 2.9X more memory usage vs LeakyReLU(0.1)" to "PyTorch Mish - 1.5x slower training, 2.9X more memory usage vs LeakyReLU(0.1)" on Nov 17, 2019
@digantamisra98 added the question (Further information is requested) label on Nov 18, 2019
@digantamisra98 (Owner) commented:

Hey @glenn-jocher
Thanks for raising the issue. I'm aware that Mish is slower and more computationally expensive than other activation functions like Leaky ReLU. There is a CUDA-based implementation you can try; it has shown considerable speed improvements over the implementation you're using.

Following is a table showing the speed-profiling improvement for Mish:
[speed-profiling comparison image; not reproduced here]

The implementation can be found here - https://github.com/thomasbrandon/mish-cuda

Note that it doesn't work well with double precision, though. Let me know how the profiling looks using Mish-CUDA.
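
For anyone trying this, a minimal usage sketch, assuming the package installs as mish_cuda and exposes a MishCuda module as described in the mish-cuda README (both names are assumptions taken from that repository):

import torch.nn as nn
from mish_cuda import MishCuda  # assumed package/class names from the mish-cuda README

# A representative conv block: MishCuda() drops in where nn.LeakyReLU(0.1)
# or the pure-PyTorch Mish module would otherwise be used.
conv_block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(32),
    MishCuda(),
)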

I'm working on optimizing Mish further and will keep you posted on the progress.

@AlexeyAB commented:

@digantamisra98 Hi, why didn't you integrate this CUDA Mish implementation, https://github.com/thomasbrandon/mish-cuda, into your repository?

@digantamisra98 (Owner) commented:

@AlexeyAB I have added it to the README. To give the author credit, I kept his integration in his own repository and added only the baseline implementation to mine. His implementation was also built only recently, based on a Fast.ai forum discussion, long after Mish was released.

@AlexeyAB commented:

Just wondering how optimal it is to define only the forward function without a backward implementation? https://github.com/ultralytics/yolov3/blob/7ebb7d131078bd8357aeddf23fd68414d1593612/models.py#L124-L131

@digantamisra98 (Owner) commented:

@AlexeyAB It isn't optimal; it's more of a quick start. For optimal performance, I would go with both forward and backward pass implementations, which is what Mish-CUDA, deeplearning4j, and TensorFlow Addons do for Mish, as does your implementation in darknet.
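
To illustrate the point, here is a minimal sketch, in plain PyTorch rather than the Mish-CUDA kernel, of a Mish with an explicit backward pass: only the input is saved and the intermediates are recomputed during backward, which is one way to cut the memory overhead reported above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save only the input; recompute softplus/tanh during backward
        ctx.save_for_backward(x)
        return x * torch.tanh(F.softplus(x))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        tsp = torch.tanh(F.softplus(x))
        # d/dx [x * tanh(softplus(x))] = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
        return grad_output * (tsp + x * torch.sigmoid(x) * (1 - tsp * tsp))

class MemoryEfficientMish(nn.Module):
    def forward(self, x):
        return MishFunction.apply(x)

The hand-written backward can be checked against autograd with torch.autograd.gradcheck on a small double-precision input.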

@digantamisra98 (Owner) commented:

Closing this issue due to inactivity. Feel free to re-open if not resolved.
