Skip to content

Commit

Permalink
Revert "[Unity][BYOC] Integrate Flash attention v2 kernel into CUTLAS…
Browse files Browse the repository at this point in the history
…S BYOC (#15467)"

This reverts commit 38e2b88.
  • Loading branch information
MasterJH5574 authored and tqchen committed Sep 6, 2023
1 parent 8ddeed9 commit c14964d
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 184 deletions.
4 changes: 0 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -881,10 +881,6 @@ if(USE_CUDA AND USE_CUTLASS)
install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
target_link_libraries(tvm PRIVATE fpA_intB_gemm)
target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm)

install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
endif()

if(USE_CUDA AND USE_NCCL)
Expand Down
1 change: 0 additions & 1 deletion cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ if(USE_CUDA AND USE_CUTLASS)

set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)

message(STATUS "Build with CUTLASS")
Expand Down
97 changes: 0 additions & 97 deletions python/tvm/contrib/cutlass/attention_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,100 +159,3 @@ def instantiate_attention_template(attrs):
)

return substitute_template(template, attrs)


def instantiate_flash_attention_template(attrs):
"""Return host code for flash attention."""

template = """
int q_head_stride = ${head_dim};
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
int q_row_stride = q_head_stride * ${num_heads};
int k_row_stride = k_head_stride * ${num_heads};
int v_row_stride = v_head_stride * ${num_heads};
int o_row_stride = o_head_stride * ${num_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};
flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${query}->data),
static_cast<const cutlass::half_t*>(${key}->data),
static_cast<const cutlass::half_t*>(${value}->data),
static_cast<cutlass::half_t*>(out0->data),
${num_batches},
${num_queries},
${num_keys},
${num_heads},
${num_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
v_batch_stride,
o_batch_stride,
q_head_stride,
k_head_stride,
v_head_stride,
o_head_stride,
q_row_stride,
k_row_stride,
v_row_stride,
o_row_stride,
${scale},
${is_causal},
nullptr);
"""

template_stacked = """
int q_head_stride = ${head_dim};
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
int row_stride = q_head_stride * ${num_heads} +
k_head_stride * ${num_heads} +
v_head_stride * ${num_heads};
int q_row_stride = row_stride;
int k_row_stride = row_stride;
int v_row_stride = row_stride;
int o_row_stride = o_head_stride * ${num_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};
flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${qkv}->data),
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
static_cast<cutlass::half_t*>(out0->data),
${num_batches},
${num_queries},
${num_keys},
${num_heads},
${num_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
v_batch_stride,
o_batch_stride,
q_head_stride,
k_head_stride,
v_head_stride,
o_head_stride,
q_row_stride,
k_row_stride,
v_row_stride,
o_row_stride,
${scale},
${is_causal},
nullptr);
"""

if "qkv" in attrs:
return substitute_template(template_stacked, attrs)

return substitute_template(template, attrs)
2 changes: 0 additions & 2 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
cutlass_fpA_intB_gemm_include = os.path.join(cutlass_root, "../cutlass_fpA_intB_gemm")
flash_attn_include = os.path.join(cutlass_root, "../libflash_attn/include")

kwargs = {}
kwargs["cc"] = "nvcc"
Expand All @@ -78,7 +77,6 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
f"-I{cutlass_util_include}",
f"-I{cutlass_attention_include}",
f"-I{cutlass_fpA_intB_gemm_include}",
f"-I{flash_attn_include}",
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
Expand Down
132 changes: 53 additions & 79 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
from tvm.tir import IntImm

from . import _ffi_api as ffi
from .attention_operation import (
instantiate_attention_template,
instantiate_flash_attention_template,
)
from .attention_operation import instantiate_attention_template
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
from .layer_norm_operation import instantiate_layer_norm_template
Expand Down Expand Up @@ -723,6 +720,7 @@ def get_batch_on_arg(arg_name, arg_shape):
return CodegenResult(code, headers)

elif "attention" in func_name:
headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]

attrs["qkv_layout"] = annotations["qkv_layout"]
Expand All @@ -749,86 +747,62 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))

data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
elif (h % 4 == 0) and (h_v % 4 == 0):
attrs["kIsAligned"] = False
else:
raise NotImplementedError()
if h_v > 64:
attrs["kQueriesPerBlock"] = 32
attrs["kKeysPerBlock"] = 128
attrs["kSingleValueIteration"] = h_v <= 128
else:
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)

use_flash = (
annotations["ret_dtype"] == "float16"
and "bias" not in attrs
and int(attrs["head_dim"]) <= 256
and int(attrs["head_dim"]) % 8 == 0
and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
# We have not thoroughly validated flash with causal mask yet, so for now we support
# only non-causal cases.
and int(annotations["custom_mask_type"]) == 0
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)

if use_flash:
headers.append("flash.h")
attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")

data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
elif (h % 4 == 0) and (h_v % 4 == 0):
attrs["kIsAligned"] = False
else:
raise NotImplementedError()
if h_v > 64:
attrs["kQueriesPerBlock"] = 32
attrs["kKeysPerBlock"] = 128
attrs["kSingleValueIteration"] = h_v <= 128
else:
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True

assert (
attrs["scale"] > 0 or attrs["scale"] < 0
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False

attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"

attrs["custom_mask_type"] = annotations["custom_mask_type"]

for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
strides = "p.num_keys"
if annotations["bias_shape"][2] == 1:
attrs["bias_strideM"] = 0
else:
attrs["bias_strideM"] = strides
strides = f"p.num_queries * {strides}"
if annotations["bias_shape"][1] == 1:
attrs["bias_strideH"] = 0
else:
attrs["bias_strideH"] = strides
strides = f"p.num_heads * {strides}"
if annotations["bias_shape"][0] == 1:
attrs["bias_strideB"] = 0
else:
attrs["bias_strideB"] = strides
attrs["custom_mask_type"] = annotations["custom_mask_type"]

assert (
attrs["scale"] > 0 or attrs["scale"] < 0
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False

for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
strides = "p.num_keys"
if annotations["bias_shape"][2] == 1:
attrs["bias_strideM"] = 0
else:
attrs["bias_strideM"] = strides
strides = f"p.num_queries * {strides}"
if annotations["bias_shape"][1] == 1:
attrs["bias_strideH"] = 0
else:
attrs["bias_strideH"] = strides
strides = f"p.num_heads * {strides}"
if annotations["bias_shape"][0] == 1:
attrs["bias_strideB"] = 0
else:
raise NotImplementedError()
attrs["bias_strideB"] = strides
else:
# To support negative scale in current Cutlass implementation,
# kSupportsBias should be set true, or there are nan's as result.
attrs["kSupportsBias"] = attrs["scale"] < 0

code = instantiate_attention_template(attrs)

raise NotImplementedError()
else:
# To support negative scale in current Cutlass implementation,
# kSupportsBias should be set true, or there are nan's as result.
attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs)
return CodegenResult(code, headers)
elif "layer_norm" in func_name:
headers.append("cutlass/util/device_layernorm.h")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_codegen_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def stacked_attention_size(request):

def test_stacked_attention_split_offload(stacked_attention_size):
b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16")
qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32")
if scale == "none":
mod = get_relax_stacked_attention_module(
qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape
Expand Down

0 comments on commit c14964d

Please sign in to comment.