diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index babe590..215479e 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -249,24 +249,11 @@ void set_params_dgrad( } void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device - ); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.has_mask, Has_mask, [&] { BOOL_SWITCH(params.has_bias, Has_bias, [&] { - // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits - bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024); - params.num_splits = splitkv_forbidden ? 1 : params.num_splits; if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else {