[RFC] Add modeling_xxx_fusion.py to support kernel fusion #13845
Comments
Looks like an awesome plan, @hyunwoongko! So far your RFC looks excellent to me. I'd just suggest …
@stas00 You're right. It's because I'm not good at English (I didn't know there was a word …
In software we often create new words anyway, so as long as the composite of 2 words makes sense it works for our purposes. Latin-based languages all use a combination of a root with prefix/postfix, so if the word you want is not there already - create one ;)
Review of Fused Kernels for transformers
If you find other fused kernels, please let me know here. I'll test and record them. :)
1. Module-level Kernels
Module-level kernels are fused kernels for independent operation sets like …
I will also review layer-level kernels this week. ;)
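As a concrete illustration of the module-level idea in the comment above, here is a minimal sketch (editorial, not from the review itself) of a fused bias + GELU in the style of Megatron-LM's fused kernels: the whole pointwise chain is expressed in one scripted function so the TorchScript fuser can combine it into fewer CUDA kernel launches.

```python
import torch

# Illustrative sketch only: a "module-level" fused op covering an independent
# operation set (bias add + GELU). Scripting the chain of pointwise ops lets
# the TorchScript fuser emit fewer CUDA kernels than eager mode would launch.
@torch.jit.script
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    y = x + bias
    # tanh approximation of GELU
    return y * 0.5 * (1.0 + torch.tanh(0.7978845608 * (y + 0.044715 * y * y * y)))
```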
The tricky part is that it's tied to torch's version, e.g. I had to use pytorch-nightly to get it to work. Otherwise you can only use pt-1.10.0, which is not the latest release (1.10.1 is). In other words, this would be quite complex for users to set up. To keep up with the details, please subscribe to: #15264
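Because the fused kernels are tied to the PyTorch version, a library would likely have to gate them at import time. A minimal sketch of such a guard (the flag name is made up for illustration):

```python
import torch
from packaging import version

# Hypothetical guard: only enable the fused code path on a new-enough PyTorch
# and fall back to the stock modules otherwise.
_TORCH_VERSION = version.parse(torch.__version__.split("+")[0])
USE_FUSED_KERNELS = _TORCH_VERSION >= version.parse("1.11.0")
```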
@stas00 if you're interested, we have other fused layers in https://github.com/facebookresearch/xformers/tree/main/xformers/triton. The only dependency is triton, which is one pip install away (but limited to CUDA and recent enough GPUs). Just FYI, feel free to discard.
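For readers unfamiliar with Triton, the appeal is that a fused kernel is just Python plus a pip-installable JIT. The snippet below is a generic sketch of a fused bias + ReLU written directly in Triton; it is not the xformers API, and all names are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_bias_relu_kernel(x_ptr, bias_ptr, out_ptr, n_elements, hidden, BLOCK: tl.constexpr):
    # One program instance handles one BLOCK-sized chunk of the flattened input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    b = tl.load(bias_ptr + (offsets % hidden), mask=mask)
    y = x + b
    # Bias add and ReLU happen in one kernel instead of two.
    tl.store(out_ptr + offsets, tl.where(y > 0, y, 0.0), mask=mask)


def fused_bias_relu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Assumes contiguous CUDA tensors; bias is broadcast over the last dimension.
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    fused_bias_relu_kernel[grid](x, bias, out, n, bias.numel(), BLOCK=1024)
    return out
```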
@stas00 We'll be cutting a branch that works with PyTorch 1.11.0, and to be honest, I don't think it'd be that hard to cut a release for 1.10.1 now either. So, I think the issues with user setup are not that difficult to resolve.
Thank you very much, Benjamin! I will tag @hyunwoongko - who is currently researching various fused kernels - so he can see if these fit! He has probably already looked there / adopted some.
Sounds good, Horace - let's work with pt-nightly for now, and by the time we have something to show users we will make sure they have an easy path to follow. Most likely pt-1.11.0 will be out by then, as you're saying. Thank you!
This looks really useful. Is there a more recent update on this somewhere?
Introduction
I am an engineer currently working on 3D model parallelism for transformers. When the tensor model parallelism (#13726) is done, I am going to introduce a kernel fusion feature to transformers.
For this, I want to create a new modeling file called `modeling_xxx_fusion.py`. This work is currently being discussed with @stas00 and @RezaYazdaniAminabadi (DeepSpeed team).
Kernel fusion API
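The concrete mechanism is described in the Implementation section below; purely as an illustrative assumption, user-facing usage could look roughly like this (the method name `fuse_modules()` is taken from the text, everything else is hypothetical):

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased").cuda()

# Proposed call from this RFC: replace supported submodules
# (e.g. BertOutput -> FusedBertOutput) with their fused versions in place.
model.fuse_modules()
```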
Implementation
The internal module of each model will be re-implemented using the kernel fusion method, and the existing module will be replaced with the fused module. The following example is based on `BertOutput(nn.Module)`. When the user calls the `fuse_modules()` method, the kernel fusion engine finds `BertOutput` and replaces it with `FusedBertOutput`. Likewise, when the user calls the `fused_layers` method, the engine finds `BertLayer` and replaces it with `FusedBertLayer`. This is the method that `parallelformers` used to parallelize transformers models flexibly, and `deepspeed` also supports kernel fusion in this way.

However, the current version of `deepspeed` fuses the entire transformer layer, so the supported models are very limited. For example, BigBird requires a random attention mechanism; in that case, random attention must be implemented in the custom CUDA kernel. Because the number of models is so large, it is impossible to implement them all. So I propose a flexible way to fuse kernels on a per-function basis. This is a triage strategy: the areas that can be fused are fused, and the areas that cannot be fused use torch's default modules.

This is a draft. The API can be changed at any time. I look forward to feedback. I'm going to show you this soon with a framework I'm building. (Like parallelformers, we will pre-open the repositories on our side and merge them into transformers and deepspeed later.)
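The `BertOutput` example referenced above is not included in this extract. Purely as a reconstruction sketch (assumed names and bodies, not the RFC's actual code), the per-module replacement could look like this: a fused drop-in module plus an engine that swaps it in by class name.

```python
import torch
from torch import nn


class FusedBertOutput(nn.Module):
    """Sketch of a drop-in replacement for transformers' BertOutput.

    The dense projection stays a regular nn.Linear; the dropout + residual add
    + LayerNorm block is grouped so it could be dispatched to a single fused
    kernel. Here it is left as plain PyTorch so the swap semantics stay obvious.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        # In a real fused version, the next two lines would be one custom kernel.
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


def fuse_modules(model, config, mapping=None):
    """Replace known submodules with their fused counterparts, module by module."""
    mapping = mapping or {"BertOutput": FusedBertOutput}
    for _, module in model.named_modules():
        for child_name, child in module.named_children():
            fused_cls = mapping.get(type(child).__name__)
            if fused_cls is not None:
                fused = fused_cls(config)
                # Keep the pretrained weights; only the implementation changes.
                fused.load_state_dict(child.state_dict())
                setattr(module, child_name, fused)
    return model
```

Because the mapping is keyed per module class rather than per whole transformer layer, models with unsupported pieces (e.g. BigBird's random attention) can still fuse everything else and fall back to torch's default modules for the rest.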
cc. @stas00 @RezaYazdaniAminabadi @Sylvain