Add deterministic field for cuda flash attention bwd.
liuliu committed Jun 7, 2024
1 parent ac3c3ba commit 3b85024
Showing 3 changed files with 46 additions and 28 deletions.
1 change: 1 addition & 0 deletions lib/nnc/ccv_nnc.h
@@ -256,6 +256,7 @@ typedef struct {
float scale; /**< [scaled_dot_product_attention.scale] The scale we multiply to the dot product of Q & K */
int is_causal; /**< [scaled_dot_product_attention.is_causal] Whether we have causal matrix associated with the attention. The attention mask will be cut to triangular if provided. */
int upcast; /**< [scaled_dot_product_attention.upcast] Whether we want to run the attention computation at higher precision (from FP16 to FP32). */
int deterministic; /**< [scaled_dot_product_attention.deterministic] Whether we want the attention computation to be deterministic (CUDA only). */
} scaled_dot_product_attention;
struct {
int type; /**< [pad.type] The type of pad, can be either zeros or replicating edge. */
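On the caller side, the new field is set directly on the command's parameter struct before execution, exactly as the updated test further below does. A minimal sketch (the tensor bindings are the ones used in cublas.tests.c and stand in for whatever setup surrounds the call):

// Sketch: request a deterministic flash attention backward pass on CUDA.
// Only the .deterministic assignment is new in this commit; the tensor list
// layout mirrors the existing backward call in cublas.tests.c.
ccv_nnc_cmd_t cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(scale, is_causal);
cmd.info.scaled_dot_product_attention.deterministic = 1;
ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0,
	TENSOR_LIST(gpu_do_tensor, 0, 0, gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, 0, 0, 0, gpu_o_tensor, gpu_softmax_lse),
	TENSOR_LIST(gpu_dq_tensor, gpu_dk_tensor, gpu_dv_tensor), 0);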
@@ -407,25 +407,38 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
params.do_row_stride = D * Hq;
params.do_head_stride = D;
params.do_batch_stride = R * Hq * D;
params.deterministic = false; // If it is deterministic, we need to zero out dq_accum.
params.dq_accum_split_stride = 0;
params.deterministic = cmd.info.scaled_dot_product_attention.deterministic;

size_t dq_accum_size;
if (params.deterministic)
{
const ccv_nnc_cuda_device_prop_t props = ccv_nnc_gpu_device_props();
const int nsplits = (props.multi_processor_count + batch_size * Hq - 1) / (batch_size * Hq);
dq_accum_size = sizeof(float) * nsplits * batch_size * params.seqlen_q_rounded * Hq * params.d_rounded;
params.dq_accum_split_stride = batch_size * params.seqlen_q_rounded * Hq * params.d_rounded;
} else {
dq_accum_size = sizeof(float) * batch_size * params.seqlen_q_rounded * Hq * params.d_rounded;
params.dq_accum_split_stride = 0;
}

cudaStream_t stream = ccv_nnc_stream_context_get_stream(stream_context);
params.softmax_lse_ptr = saved_softmax_lse->data.u8;
if (Hq != Hk)
{
unsigned char* const workspace = (unsigned char*)ccv_nnc_stream_context_get_workspace(stream_context, (batch_size * Hq * params.seqlen_q_rounded + batch_size * params.seqlen_q_rounded * Hq * params.d_rounded) * sizeof(float) + batch_size * Hq * C * D * 2 * 2, CCV_TENSOR_GPU_MEMORY);
unsigned char* const workspace = (unsigned char*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * batch_size * Hq * params.seqlen_q_rounded + dq_accum_size + sizeof(short) * batch_size * Hq * C * D * 2, CCV_TENSOR_GPU_MEMORY);
params.dsoftmax_sum = workspace;
params.dq_accum_ptr = workspace + batch_size * Hq * params.seqlen_q_rounded * sizeof(float);
params.dk_ptr = workspace + (batch_size * Hq * params.seqlen_q_rounded + batch_size * params.seqlen_q_rounded * Hq * params.d_rounded) * sizeof(float);
params.dv_ptr = workspace + (batch_size * Hq * params.seqlen_q_rounded + batch_size * params.seqlen_q_rounded * Hq * params.d_rounded) * sizeof(float) + batch_size * Hq * C * D * 2;
params.dq_accum_ptr = workspace + sizeof(float) * batch_size * Hq * params.seqlen_q_rounded;
params.dk_ptr = workspace + sizeof(float) * batch_size * Hq * params.seqlen_q_rounded + dq_accum_size;
params.dv_ptr = workspace + sizeof(float) * batch_size * Hq * params.seqlen_q_rounded + dq_accum_size + sizeof(short) * batch_size * Hq * C * D;
} else {
unsigned char* const workspace = (unsigned char*)ccv_nnc_stream_context_get_workspace(stream_context, (batch_size * Hq * params.seqlen_q_rounded + batch_size * params.seqlen_q_rounded * Hq * params.d_rounded) * sizeof(float), CCV_TENSOR_GPU_MEMORY);
unsigned char* const workspace = (unsigned char*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * batch_size * Hq * params.seqlen_q_rounded + dq_accum_size, CCV_TENSOR_GPU_MEMORY);
params.dsoftmax_sum = workspace;
params.dq_accum_ptr = workspace + batch_size * Hq * params.seqlen_q_rounded * sizeof(float);
params.dq_accum_ptr = workspace + sizeof(float) * batch_size * Hq * params.seqlen_q_rounded;
params.dk_accum_ptr = 0;
params.dv_accum_ptr = 0;
}
cudaStream_t stream = ccv_nnc_stream_context_get_stream(stream_context);
if (params.deterministic)
cudaMemsetAsync(params.dq_accum_ptr, 0, dq_accum_size, stream);
run_mha_bwd(params, stream);
CUDA_ENFORCE(cudaGetLastError());
if (Hq != Hk)
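The only new arithmetic in the hunk above is the deterministic dq_accum sizing: instead of a single float accumulator per (batch, head, rounded query position, rounded head dimension), the deterministic path keeps one such slab per split, with the split count derived from the SM count, and the whole buffer is zeroed with cudaMemsetAsync before run_mha_bwd because each split accumulates partial dQ into its own slab. A standalone sketch of that sizing, with the SM count and problem shape chosen purely for illustration:

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
	// Illustrative values only: 108 SMs (an A100-class count) and one of the
	// shapes from the test below. seqlen_q_rounded and d_rounded stand in for
	// params.seqlen_q_rounded / params.d_rounded computed elsewhere.
	const int multi_processor_count = 108;
	const int batch_size = 2, Hq = 8;
	const int seqlen_q_rounded = 128, d_rounded = 192;
	// Same formula as the commit: spread the (batch, head) grid over the SMs.
	const int nsplits = (multi_processor_count + batch_size * Hq - 1) / (batch_size * Hq);
	const size_t per_split = (size_t)batch_size * seqlen_q_rounded * Hq * d_rounded;
	const size_t dq_accum_size = sizeof(float) * nsplits * per_split;
	const size_t dq_accum_split_stride = per_split; // per-split extent, matching the stride set in params
	printf("nsplits = %d\n", nsplits); // (108 + 15) / 16 = 7
	printf("dq_accum bytes = %zu\n", dq_accum_size);
	printf("split stride = %zu\n", dq_accum_split_stride);
	// In the backend this buffer is carved out of the stream workspace and
	// must start at zero, since the deterministic kernel accumulates into it.
	return 0;
}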
42 changes: 23 additions & 19 deletions test/int/nnc/cublas.tests.c
@@ -2836,29 +2836,31 @@ TEST_CASE("scaled dot product attention gradient with flash_attn")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF) &&
ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_GPU_REF));
#define num_long_trials 4
#define num_short_trials 2
#define num_long_trials 8
#define num_short_trials 4
#define num_trials (num_long_trials + num_short_trials)

dsfmt_t dsfmt;
dsfmt_init_gen_rand(&dsfmt, 10);
for (int trial = 0; trial < num_trials; ++trial) {
int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 1 };
int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 5 };
int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 5 };
int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 32 };
int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 8 };
int D_candidates[num_trials] = { 64, 40, 160, 192, 192, 128 };
int is_causal_candidates[num_trials] = { 1, 0, 1, 1, 0, 1 };

int B = B_candidates[trial];
int R = R_candidates[trial];
int C = C_candidates[trial];
int Hq = Hq_candidates[trial];
int Hk = Hk_candidates[trial];
int D = D_candidates[trial];
int is_causal = is_causal_candidates[trial];
float scale = 1.0 / sqrt((float)D);
const int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 1, 32, 12, 16, 1, 2, 1 };
const int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 5, 160, 256, 128, 77, 77, 5 };
const int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 5, 128, 128, 128, 128, 128, 5 };
const int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 32, 8, 8, 8, 8, 8, 32 };
const int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 8, 8, 8, 8, 8, 2, 8 };
const int D_candidates[num_trials] = { 64, 40, 160, 192, 192, 128, 64, 40, 160, 192, 192, 128 };
const int is_causal_candidates[num_trials] = { 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1 };
const int deterministic_candidates[num_trials] = { 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1 };

const int B = B_candidates[trial];
const int R = R_candidates[trial];
const int C = C_candidates[trial];
const int Hq = Hq_candidates[trial];
const int Hk = Hk_candidates[trial];
const int D = D_candidates[trial];
const int is_causal = is_causal_candidates[trial];
const int deterministic = deterministic_candidates[trial];
const float scale = 1.0 / sqrt((float)D);

ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
@@ -2901,7 +2903,9 @@ TEST_CASE("scaled dot product attention gradient with flash_attn")
ccv_nnc_tensor_t* const gpu_softmax_lse = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R), 0);
ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, gpu_softmax_lse), 0);

ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_do_tensor, 0, 0, gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, 0, 0, 0, gpu_o_tensor, gpu_softmax_lse), TENSOR_LIST(gpu_dq_tensor, gpu_dk_tensor, gpu_dv_tensor), 0);
ccv_nnc_cmd_t cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(scale, is_causal);
cmd.info.scaled_dot_product_attention.deterministic = deterministic;
ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_do_tensor, 0, 0, gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, 0, 0, 0, gpu_o_tensor, gpu_softmax_lse), TENSOR_LIST(gpu_dq_tensor, gpu_dk_tensor, gpu_dv_tensor), 0);

ccv_nnc_tensor_t* const copy_of_gpu_dq_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
ccv_nnc_tensor_t* const copy_of_gpu_dk_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
