Fix mul-mat error for older GPUs #669

Merged: 4 commits from fix_mul_mat_cublas_for_cp_lt_70 merged into ggerganov:master on Dec 29, 2023

Conversation

bssrdf (Contributor) commented Dec 28, 2023:

This PR fixes issue #668. The two test cases, test-conv1d and test-conv2d, pass with this PR on a GTX 1070 with CUDA v12.1. It also fixes downstream problems in other projects that use ggml.
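
For context, a minimal sketch of the shape of the fix, assuming the g_compute_capabilities table and the CC_VOLTA constant (700) that ggml-cuda.cu used at the time; the exact condition in the merged commit may differ:

    // Sketch only: route mul-mat through plain FP32 cuBLAS on pre-Volta GPUs,
    // which lack fast FP16 tensor-core paths.
    const int compute_capability = g_compute_capabilities[id];

    if (compute_capability >= CC_VOLTA && src0->type == GGML_TYPE_F16) {
        // fast path: convert src1 to FP16, multiply with an FP16 GEMM,
        // then convert dst back to FP32
    } else {
        // fallback for older GPUs: convert src0/src1 to FP32 as needed
        // and call cublasSgemm
    }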

bssrdf changed the title from "Fix mul-mat error for old GPUs" to "Fix mul-mat error for older GPUs" on Dec 28, 2023
src/ggml-cuda.cu (Outdated)

@@ -7615,16 +7615,25 @@ static void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
     }
+    else {
Contributor:

please fix this trailing whitespace

src/ggml-cuda.cu Outdated
Comment on lines 7643 to 7647
             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha, src0_ddf_i,  ne00,
-                            src1_ddf_i,  ne10,
+                            src1_ddf1_i, ne10,
                     &beta,  dst_dd_i,   ldc));
Contributor:

Suggested change:

-            cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha, src0_ddf_i,  ne00,
-                            src1_ddf1_i, ne10,
-                    &beta,  dst_dd_i,   ldc));
+            CUBLAS_CHECK(
+                cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha, src0_ddf_i,  CUDA_R_32F, ne00,
+                            src1_ddf1_i, CUDA_R_32F, ne10,
+                    &beta,  dst_ddf1,    CUDA_R_32F, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

Contributor:

@Green-Sky Can you test the modifications I suggested above? I am not a professional CUDA developer, but in my testing these modifications greatly reduce the probability of bad images. However, I am not sure whether cublasGemmEx has compatibility problems on older GPUs.
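
On the compatibility question, a hedged sketch (hypothetical, not part of this PR): the cuBLAS documentation lists compute capability 5.0 as the minimum for cublasGemmEx, so the call could be gated on that, using ggml-cuda.cu's 100*major + 10*minor capability encoding. Note also that on CUDA 11+ the *_TENSOR_OP algorithm values are deprecated aliases of CUBLAS_GEMM_DEFAULT.

    // Hypothetical guard, not from this PR: only take the cublasGemmEx path on
    // devices where the cuBLAS docs say it is supported (compute capability >= 5.0).
    if (g_compute_capabilities[id] >= 500) {
        CUBLAS_CHECK(
            cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                row_diff, src1_ncols, ne10,
                &alpha, src0_ddf_i,  CUDA_R_32F, ne00,
                        src1_ddf1_i, CUDA_R_32F, ne10,
                &beta,  dst_dd_i,    CUDA_R_32F, ldc,
                CUBLAS_COMPUTE_32F,
                CUBLAS_GEMM_DEFAULT)); // tensor ops are still used where available
    } else {
        // keep the plain cublasSgemm call
    }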

Contributor:

Where does dst_ddf1 come from?

/build/xxx-source/ggml/src/ggml-cuda.cu(7456): error: identifier "dst_ddf1" is undefined

Contributor:

Sorry, it's dst_dd_i.

Collaborator:

I don't think this is necessary.

Contributor:

I can't really see any significant correlation with the issues I observe. Since they look like synchronization issues, a slightly different invocation can accidentally make them go away or become less likely.
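
If these are synchronization issues, one way to localize them (a debugging sketch, not something from this PR) is to force a synchronization and drain any deferred error right after the suspect call; running the failing test under compute-sanitizer is another option. CUDA_CHECK below is ggml-cuda.cu's existing error-checking macro.

    // Debugging sketch: asynchronous CUDA errors surface on a later, unrelated
    // call, which makes a failure appear to move around with harmless changes.
    // Syncing immediately after the suspect operation pins it down.
    CUDA_CHECK(cudaStreamSynchronize(stream)); // wait for all work queued on this stream
    CUDA_CHECK(cudaGetLastError());            // surface any deferred launch error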

Contributor:

You can use stable-diffusion.cpp. With the same seed a run can either fail or succeed, so use save_tensor_to_file to export the tensor in sample for both a failing and a succeeding run, and then export the failing and succeeding tensors at decode as well.

Cyberhan123 (Contributor) commented Dec 28, 2023:

When I tested, the sample-stage tensors of the successful and the failed run were equal.

Contributor:

But at decode they show a large deviation. After I added the cublasGemmEx correction, the deviation became smaller, but I am even more confused as to why such a deviation occurs.
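
To quantify such a deviation, a small self-contained sketch; the file names are hypothetical and it assumes the tensors were dumped as raw FP32, which may not match what stable-diffusion.cpp's save_tensor_to_file actually writes:

    // compare_dumps.cpp: element-wise comparison of two raw FP32 tensor dumps,
    // reporting the maximum absolute deviation between them.
    #include <cmath>
    #include <cstdio>

    int main() {
        FILE * fa = fopen("decode_good.bin", "rb");
        FILE * fb = fopen("decode_bad.bin",  "rb");
        if (!fa || !fb) { fprintf(stderr, "cannot open dumps\n"); return 1; }
        float a, b, max_dev = 0.0f;
        size_t n = 0;
        while (fread(&a, sizeof a, 1, fa) == 1 && fread(&b, sizeof b, 1, fb) == 1) {
            max_dev = std::fmax(max_dev, std::fabs(a - b));
            n++;
        }
        printf("compared %zu elements, max abs deviation = %g\n", n, max_dev);
        fclose(fa); fclose(fb);
        return 0;
    }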

bssrdf (Contributor, Author) commented Dec 28, 2023:

Thanks, @slaren, for the style fix and other updates.
@Cyberhan123, I'll leave it to @slaren and others to decide whether replacing cublasSgemm with cublasGemmEx is a good idea. I am not an expert on CUDA/cuBLAS.

slaren force-pushed the fix_mul_mat_cublas_for_cp_lt_70 branch from 55ba78b to 8f137dd on December 28, 2023 12:29
FSSRepo (Collaborator) commented Dec 28, 2023:

@bssrdf I am going to review and test this fix on my GPU, even though it works fine for me, just to ensure there is no impact on the current performance.

bssrdf (Contributor, Author) commented Dec 28, 2023:

> @bssrdf I am going to review and test this fix on my GPU, even though it works fine for me, just to ensure there is no impact on the current performance.

@FSSRepo, sounds good. Thanks for the test.

ggerganov (Owner) left a comment:

Thank you for digging into this and resolving the issues. I should have used more GGML_ASSERTs to avoid this kind of issue; I will try to do so in the future.
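
For illustration, the kind of guard meant here, as a sketch (the specific conditions are illustrative, not the asserts later added to ggml): GGML_ASSERT aborts with file and line when a precondition fails, turning silent wrong results into an immediate, diagnosable failure.

    // Illustrative preconditions for the cuBLAS mul-mat path; failing fast here
    // would have flagged the older-GPU issue instead of producing bad output.
    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ne10 == ne00); // inner dimensions of the mat-mul must match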

@ggerganov ggerganov merged commit dbd0295 into ggerganov:master Dec 29, 2023
4 checks passed
@bssrdf bssrdf deleted the fix_mul_mat_cublas_for_cp_lt_70 branch December 29, 2023 13:36

Successfully merging this pull request may close these issues.

- test-conv1d and test-conv2d failed on GPUs with computation capability <= 6.1
- CUDA cannot generate images