Matmul refactor using only cuBLASLt + GELU Fusion #653
Conversation
Force-pushed from 0073f5f to 6e29dd0
@@ -121,6 +121,7 @@ int main(int argc, char *argv[]) {
    if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash
    if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }
    else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); }
    else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); }
Mildly scary indexing into [2] here.
Should we add
if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x[y] (one dash, one or two letters)
like in the train code?
Good catch, I hadn't noticed that was missing in the test code! Adding it now.
(It's a bit non-obvious that this only works because strlen==2 means argv[i][2] is '\0', and that we'd need a better check if we wanted to use argv[i][3], but it's probably good enough for now.)
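For anyone skimming, here's a tiny standalone illustration (hypothetical flag strings, not code from this PR) of why the strlen check makes the `[2]` access safe:

```c
// Standalone illustration (hypothetical flags, not from this PR): C strings
// are NUL-terminated, so for a two-character flag like "-g", index [2] is the
// terminating '\0' and reading it is well-defined. Index [3] would be past
// the terminator of a 3-character flag, hence the need for a stricter check
// before ever touching argv[i][3].
#include <stdio.h>
#include <string.h>

int main(void) {
    const char* two = "-g";    // strlen == 2, so two[2] is '\0'
    const char* three = "-ge"; // strlen == 3, so three[2] is 'e', three[3] is '\0'
    if (strlen(two) == 2 || strlen(two) == 3) { // must be -x[y]
        // compares '\0' to 'e', well-defined and false:
        printf("two[2] == 'e'? %d\n", two[2] == 'e');
    }
    printf("three[2] == 'e'? %d\n", three[2] == 'e');
    return 0;
}
```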
@@ -1167,13 +1165,11 @@ void common_start(bool override_enable_tf32 = true, bool print_device_info = true) {
    nvtxNameCudaStreamA(main_stream, "main stream");

    // set up cuBLAS and cuBLASLt
    cublasCheck(cublasCreate(&cublas_handle));
🎉
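(For context, a minimal sketch of what the setup reduces to once the cuBLAS handle is gone; the identifiers and workspace size below are assumptions, not necessarily the PR's exact code. Error checking elided for brevity.)

```c
// Hedged sketch (names and size assumed): after the refactor, setup only
// creates a cuBLASLt handle plus a workspace buffer for cublasLtMatmul;
// no cublasCreate()/cublas_handle remains anywhere.
#include <cublasLt.h>
#include <cuda_runtime.h>

static cublasLtHandle_t cublaslt_handle;
static void* cublaslt_workspace = NULL;
static const size_t cublaslt_workspace_size = 32 * 1024 * 1024; // assumed 32 MiB

void common_start_sketch(void) {
    cublasLtCreate(&cublaslt_handle);
    cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size);
}
```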
In preparation for FP8, this replaces all cuBLAS calls with cuBLASLt, now wrapped in a single matmul_cublaslt() function. A hedged sketch of what such a cuBLASLt-only wrapper can look like (including the optional GELU epilogue discussed below) follows; the shapes, layouts, and names are illustrative assumptions, not the PR's exact code:
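```c
// Hedged sketch (assumed shapes/types; the real matmul_cublaslt() is more
// general): out = a @ b in BF16, column-major, no transposes. With fusion
// enabled, CUBLASLT_EPILOGUE_GELU_AUX applies GELU inside the matmul epilogue
// and stores the pre-GELU values in `pre_gelu` so the backward pass can
// compute dGELU from the same inputs.
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define cublasCheck(call) do { cublasStatus_t s_ = (call); if (s_ != CUBLAS_STATUS_SUCCESS) { \
    fprintf(stderr, "cuBLASLt error %d at %s:%d\n", s_, __FILE__, __LINE__); exit(EXIT_FAILURE); } } while (0)

void matmul_gelu_sketch(cublasLtHandle_t handle, cudaStream_t stream,
                        __nv_bfloat16* out, __nv_bfloat16* pre_gelu,
                        const __nv_bfloat16* a, const __nv_bfloat16* b,
                        int m, int n, int k, int gelu_fusion /* like -ge */) {
    cublasLtMatmulDesc_t desc;
    cublasCheck(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));

    if (gelu_fusion >= 1) { // forward fusion requested
        cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
        int64_t aux_ld = m; // leading dimension of the aux (pre-GELU) buffer
        cublasCheck(cublasLtMatmulDescSetAttribute(desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
        cublasCheck(cublasLtMatmulDescSetAttribute(desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu, sizeof(pre_gelu)));
        cublasCheck(cublasLtMatmulDescSetAttribute(desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld, sizeof(aux_ld)));
    }

    cublasLtMatrixLayout_t a_layout, b_layout, out_layout;
    cublasCheck(cublasLtMatrixLayoutCreate(&a_layout, CUDA_R_16BF, m, k, m));
    cublasCheck(cublasLtMatrixLayoutCreate(&b_layout, CUDA_R_16BF, k, n, k));
    cublasCheck(cublasLtMatrixLayoutCreate(&out_layout, CUDA_R_16BF, m, n, m));

    float alpha = 1.0f, beta = 0.0f;
    // algo=NULL lets cuBLASLt run an implicit heuristic; a real wrapper would
    // query cublasLtMatmulAlgoGetHeuristic() and pass a workspace buffer.
    cublasCheck(cublasLtMatmul(handle, desc, &alpha, a, a_layout, b, b_layout,
                               &beta, out, out_layout, out, out_layout,
                               NULL, NULL, 0, stream));

    cublasCheck(cublasLtMatrixLayoutDestroy(a_layout));
    cublasCheck(cublasLtMatrixLayoutDestroy(b_layout));
    cublasCheck(cublasLtMatrixLayoutDestroy(out_layout));
    cublasCheck(cublasLtMatmulDescDestroy(desc));
}
```

The backward side would analogously use a dGELU epilogue (CUBLASLT_EPILOGUE_DGELU) fed with the saved pre-GELU buffer, which is presumably what the forward+backward setting toggles.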
It also adds support for GELU fusion, controlled on the command line with "-ge": 0 for disabled, 1 for forward only, 2 for forward+backward. The default is 2 for H100+ and 0 for older GPUs, based on previously seeing regressions on an RTX 4090, but you might want to consider disabling it by default before merging, for the following reasons:
In terms of accuracy and validation loss, the fused GELU seems very slightly worse than ours (how/why?!), which is not ideal, especially when combined with GELU recomputation, since it means the activations used for the backward pass won't be bit-for-bit identical to the ones used in the forward pass.
It's hard to tell how much of this is just noise, because the loss is only slightly worse for fused GELU (and the tensor thresholds are still too aggressive by default, so out of sheer luck the fused GELU passes on my system while the non-fused doesn't!). Based on this data, I think the difference is probably real. But then again, with the non-deterministic cuDNN performance runs below (before my other PR), the "best" val loss is seen with "-ge 1" and the worst with "-ge 0", so it's very much within the noise threshold in that case... so who knows?
Performance is noticeably improved (H100 with cuDNN enabled). I did two runs of each configuration since the results weren't deterministic due to cuDNN:
==> +3.4%! (for the 1st run's results)
-r 0 -ge 0 (BF16 cuDNN disabled so fully deterministic):
-r 0 -ge 1:
-r 0 -ge 2:
-r 2 -ge 0:
-r 2 -ge 1:
-r 2 -ge 2: