MemEff: Set CUDA stream properly
ghstack-source-id: f01ce1f7e4d4b2e85caf17c7a786b0d27dbc8559
Pull Request resolved: #491
danthe3rd committed Oct 25, 2022
1 parent faa88b1 commit bd076f7
Showing 4 changed files with 108 additions and 26 deletions.
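
For context on the pattern being fixed: a CUDA kernel launch takes an optional fourth launch-configuration argument, the stream on which to enqueue the kernel. When it is omitted, the kernel runs on the default stream instead of the stream the caller made current (for example via torch.cuda.stream(...) on the Python side), so it is not ordered with the rest of the caller's work. The commit threads at::cuda::getCurrentCUDAStream() into the cutlass forward and backward launches. The sketch below shows that launch pattern in isolation; it is only an illustration, and scale_kernel / scale_inplace are invented names, not xformers code.

// Minimal sketch of launching a kernel on PyTorch's current CUDA stream.
// Assumes a PyTorch C++/CUDA extension build; scale_kernel and scale_inplace
// are toy names used only for illustration.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>

__global__ void scale_kernel(float* data, float factor, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}

void scale_inplace(at::Tensor t, float factor) {
  // Pin the right device...
  at::cuda::CUDAGuard device_guard(t.device());
  // ...and pick up whatever stream the caller made current on that device,
  // e.g. with "with torch.cuda.stream(s):" on the Python side.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int64_t n = t.numel();
  dim3 block(256);
  dim3 grid(static_cast<unsigned int>((n + block.x - 1) / block.x));
  // The 4th launch-configuration argument is the stream; leaving it out would
  // enqueue the kernel on the default stream instead.
  scale_kernel<<<grid, block, /*smem_bytes=*/0, stream>>>(
      t.data_ptr<float>(), factor, n);
}
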
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -184,7 +184,7 @@ run_coverage: &run_coverage
when: always
command: |
source $BASH_ENV
-CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
+$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
#Uploading test coverage for Python code
bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python
@@ -194,7 +194,7 @@ run_unittests: &run_unittests
when: always
command: |
source $BASH_ENV
-CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose tests
+$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose tests
run_experimental_unittests: &run_experimental_unittests
- run:
125 changes: 103 additions & 22 deletions tests/test_mem_eff_attention.py
@@ -5,6 +5,7 @@

import math
import random
+from typing import Sequence, Type

import pytest
import torch
@@ -73,36 +74,54 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
    return shapes


-def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(**kwargs):
-    for op in [
-        xformers.ops.MemoryEfficientAttentionOp,
-        xformers.ops.MemoryEfficientAttentionCutlassOp,
-        xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
-        xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
-    ]:
-        for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op, **kwargs):
+ALL_OPS: Sequence[Type[xformers.ops.AttentionOpBase]] = [
+    xformers.ops.MemoryEfficientAttentionOp,
+    xformers.ops.MemoryEfficientAttentionCutlassOp,
+    xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
+    xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
+]
+
+
+def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(one_shape_per_op: bool = False):
+    for op in ALL_OPS:
+        for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
+            has_one = False
            for device in _devices:
                if device not in op.SUPPORTED_DEVICES:
                    continue
                for dtype in op.SUPPORTED_DTYPES:
                    yield (op, device, dtype, *shape)
+                    has_one = True
+            if has_one and one_shape_per_op:
+                break
+
+
+def _gen_ids(op_device_dtype_B_Mq_Mkv_H_K_Kv):
+    return [
+        f"{op.NAME}-{device}-{str(dtype)}-{batch_size},{q_len},{kv_len},{h},{k},{kv}"
+        for (
+            op,
+            device,
+            dtype,
+            batch_size,
+            q_len,
+            kv_len,
+            h,
+            k,
+            kv,
+        ) in op_device_dtype_B_Mq_Mkv_H_K_Kv
+    ]


_op_device_dtype_B_Mq_Mkv_H_K_Kv = list(_generate_op_device_dtype_B_Mq_Mkv_H_K_Kv())
-_op_device_dtype_B_Mq_Mkv_H_K_Kv_ids = [
-    f"{op.NAME}-{device}-{str(dtype)}-{batch_size},{q_len},{kv_len},{h},{k},{kv}"
-    for (
-        op,
-        device,
-        dtype,
-        batch_size,
-        q_len,
-        kv_len,
-        h,
-        k,
-        kv,
-    ) in _op_device_dtype_B_Mq_Mkv_H_K_Kv
-]
+_op_device_dtype_B_Mq_Mkv_H_K_Kv_ids = _gen_ids(_op_device_dtype_B_Mq_Mkv_H_K_Kv)
+
+_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs = list(
+    _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(one_shape_per_op=True)
+)
+_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids = _gen_ids(
+    _op_device_dtype_B_Mq_Mkv_H_K_Kv__xs
+)


def assert_allclose(
@@ -818,3 +837,65 @@ def test_slice(s):

    # tensors[::2]
    test_slice(slice(None, None, 2))
+
+
+@pytest.mark.parametrize(
+    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
+    _op_device_dtype_B_Mq_Mkv_H_K_Kv__xs,
+    ids=_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids,
+)
+def test_cuda_streams(
+    op_device_dtype_B_Mq_Mkv_H_K_Kv,
+):
+    (
+        op,
+        device,
+        dtype,
+        batch_size,
+        q_len,
+        kv_len,
+        h,
+        k,
+        kv,
+    ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
+    # Needs to be big enough so kernels take some time
+    # as we are trying to do a race-condition here
+    q_len = 1024
+    kv_len = 1024
+    op_device_dtype_B_Mq_Mkv_H_K_Kv = [
+        op,
+        device,
+        dtype,
+        batch_size,
+        q_len,
+        kv_len,
+        h,
+        k,
+        kv,
+    ]
+    s_hipri = torch.cuda.Stream(priority=-1)
+    s_lopri = torch.cuda.Stream(priority=0)
+    with torch.cuda.stream(s_lopri):
+        query, key, value, attn_bias = create_tensors(
+            *op_device_dtype_B_Mq_Mkv_H_K_Kv, attn_bias_type=None, fmt="BMHK"
+        )
+        # Queue a lot of kernels
+        for i in range(20):
+            query = query.relu()
+        query = query * 2
+    s_hipri.wait_stream(s_lopri)
+    with torch.cuda.stream(s_hipri):
+        out = xformers.ops.memory_efficient_attention(query, key, value, op=op)
+        # This will run in hi-pri AFTER the kernel if it
+        # runs on the correct stream
+        out = out / 2
+    torch.cuda.synchronize()
+    ref = ref_attention(query, key, value) / 2
+    assert out.shape == ref.shape, out.shape
+
+    assert_allclose(
+        out.float(),
+        ref.float(),
+        atol=op.FORWARD_ERROR_ATOL[dtype],
+        rtol=op.FORWARD_ERROR_RTOL.get(dtype, 1e-5),
+    )
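
The test above is built around exactly the property this commit fixes: twenty elementwise kernels are queued on a low-priority stream to produce the final query, the high-priority stream is told to wait on that work, and the attention op plus the trailing out / 2 are issued while the high-priority stream is current. If the attention kernel were launched on the default stream rather than the current one, it would not necessarily be ordered after the queued kernels or before out / 2, and the comparison against ref_attention could fail. A rough LibTorch C++ analogue of that caller-side setup is sketched below; the shapes, the q * 2 stand-in for the attention op, and all variable names are illustrative assumptions, while getStreamFromPool, CUDAStreamGuard and CUDAEvent are standard c10/ATen utilities.

// Sketch only: the caller-side stream setup that makes getCurrentCUDAStream()
// inside an op return a non-default, high-priority stream.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

int main() {
  at::Tensor q = at::randn({1, 1024, 1, 64},
                           at::TensorOptions().device(at::kCUDA));

  // Two streams from PyTorch's stream pool, with different priorities.
  c10::cuda::CUDAStream lo =
      c10::cuda::getStreamFromPool(/*isHighPriority=*/false);
  c10::cuda::CUDAStream hi =
      c10::cuda::getStreamFromPool(/*isHighPriority=*/true);

  {
    // Queue a chain of cheap kernels on the low-priority stream.
    c10::cuda::CUDAStreamGuard guard(lo);
    for (int i = 0; i < 20; ++i) {
      q = q.relu();
    }
  }

  // Order the high-priority stream after that work (the C++ counterpart of
  // s_hipri.wait_stream(s_lopri)), then make it the current stream.
  at::cuda::CUDAEvent queued_work_done;
  queued_work_done.record(lo);
  queued_work_done.block(hi);

  c10::cuda::CUDAStreamGuard guard(hi);
  // An op that launches on getCurrentCUDAStream() runs here, after the relu
  // chain; q * 2 stands in for the real attention call.
  at::Tensor out = q * 2;
  hi.synchronize();
  return 0;
}
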
@@ -92,6 +92,7 @@ mem_efficient_attention_backward_cutlass(
CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value);

at::cuda::CUDAGuard device_guard(query.device());
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int64_t B = query.size(0);
int64_t M = query.size(1);
@@ -223,7 +224,7 @@
checkBinaryArchMatches(), "Something went wrong in the build process");
#endif

-kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

DISPATCH_KERNEL(
@@ -211,7 +211,7 @@ std::tuple<at::Tensor, at::Tensor> efficient_attention_forward_cutlass(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
}
Kernel::check_supported(p);
-kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};
// Dispatch to the right kernel
DISPATCH_KERNEL(query, key, value, ([&]() {
