MemEff: Set CUDA stream properly #491

Merged: 20 commits, Nov 10, 2022
Changes from 3 commits
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -184,7 +184,7 @@ run_coverage: &run_coverage
when: always
command: |
source $BASH_ENV
CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
#Uploading test coverage for Python code
bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python

@@ -194,7 +194,7 @@ run_unittests: &run_unittests
when: always
command: |
source $BASH_ENV
CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose tests
$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose tests

run_experimental_unittests: &run_experimental_unittests
- run:
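With the kernels now launched on the caller's current CUDA stream, the CI config above also drops CUDA_LAUNCH_BLOCKING=1, presumably because forcing synchronous kernel launches would serialize work across streams and mask the stream-ordering race that the new test_cuda_streams test below is meant to expose.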
127 changes: 105 additions & 22 deletions tests/test_mem_eff_attention.py
@@ -5,6 +5,7 @@

import math
import random
from typing import Sequence, Type

import pytest
import torch
@@ -73,36 +74,54 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op)
return shapes


def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(**kwargs):
for op in [
xformers.ops.MemoryEfficientAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
]:
for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op, **kwargs):
ALL_OPS: Sequence[Type[xformers.ops.AttentionOpBase]] = [
xformers.ops.MemoryEfficientAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
]


def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(one_shape_per_op: bool = False):
for op in ALL_OPS:
for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
has_one = False
for device in _devices:
if device not in op.SUPPORTED_DEVICES:
continue
for dtype in op.SUPPORTED_DTYPES:
yield (op, device, dtype, *shape)
has_one = True
if has_one and one_shape_per_op:
break


def _gen_ids(op_device_dtype_B_Mq_Mkv_H_K_Kv):
return [
f"{op.NAME}-{device}-{str(dtype)}-{batch_size},{q_len},{kv_len},{h},{k},{kv}"
for (
op,
device,
dtype,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) in op_device_dtype_B_Mq_Mkv_H_K_Kv
]


_op_device_dtype_B_Mq_Mkv_H_K_Kv = list(_generate_op_device_dtype_B_Mq_Mkv_H_K_Kv())
_op_device_dtype_B_Mq_Mkv_H_K_Kv_ids = [
f"{op.NAME}-{device}-{str(dtype)}-{batch_size},{q_len},{kv_len},{h},{k},{kv}"
for (
op,
device,
dtype,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) in _op_device_dtype_B_Mq_Mkv_H_K_Kv
]
_op_device_dtype_B_Mq_Mkv_H_K_Kv_ids = _gen_ids(_op_device_dtype_B_Mq_Mkv_H_K_Kv)

_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs = list(
_generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(one_shape_per_op=True)
)
_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids = _gen_ids(
_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs
)


def assert_allclose(
@@ -818,3 +837,67 @@ def test_slice(s):

# tensors[::2]
test_slice(slice(None, None, 2))


@pytest.mark.parametrize(
"op_device_dtype_B_Mq_Mkv_H_K_Kv",
_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs,
ids=_op_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids,
)
def test_cuda_streams(
op_device_dtype_B_Mq_Mkv_H_K_Kv,
):
(
op,
device,
dtype,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = op_device_dtype_B_Mq_Mkv_H_K_Kv
if device != "cuda":
pytest.skip("Not CUDA")
# Needs to be big enough so the kernels take some time,
# as we are trying to trigger a race condition here
q_len = 1024
kv_len = 1024
op_device_dtype_B_Mq_Mkv_H_K_Kv = [
op,
device,
dtype,
batch_size,
q_len,
kv_len,
h,
k,
kv,
]
s_hipri = torch.cuda.Stream(priority=-1)
s_lopri = torch.cuda.Stream(priority=0)
with torch.cuda.stream(s_lopri):
query, key, value, attn_bias = create_tensors(
*op_device_dtype_B_Mq_Mkv_H_K_Kv, attn_bias_type=None, fmt="BMHK"
)
# Queue a lot of kernels
for i in range(20):
query = query.relu()
query = query * 2
s_hipri.wait_stream(s_lopri)
with torch.cuda.stream(s_hipri):
out = xformers.ops.memory_efficient_attention(query, key, value, op=op)
# This will run on the high-priority stream AFTER the attention kernel,
# provided that kernel was launched on the correct stream
out = out / 2
torch.cuda.synchronize()
ref = ref_attention(query, key, value) / 2
assert out.shape == ref.shape, out.shape

assert_allclose(
out.float(),
ref.float(),
atol=op.FORWARD_ERROR_ATOL[dtype],
rtol=op.FORWARD_ERROR_RTOL.get(dtype, 1e-5),
)
@@ -92,6 +92,7 @@ mem_efficient_attention_backward_cutlass(
CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value);

at::cuda::CUDAGuard device_guard(query.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int64_t B = query.size(0);
int64_t M = query.size(1);
@@ -223,7 +224,7 @@ mem_efficient_attention_backward_cutlass(
checkBinaryArchMatches(), "Something went wrong in the build process");
#endif

kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

DISPATCH_KERNEL(
@@ -211,7 +211,7 @@ std::tuple<at::Tensor, at::Tensor> efficient_attention_forward_cutlass(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
}
Kernel::check_supported(p);
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};
// Dispatch to the right kernel
DISPATCH_KERNEL(query, key, value, ([&]() {
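
For reference, below is a minimal, hypothetical sketch (the kernel, wrapper name, and parameters are illustrative, not part of this PR) of the launch pattern the two CUDA changes above adopt: the extension asks PyTorch for the current stream with at::cuda::getCurrentCUDAStream() and passes it as the fourth <<<...>>> launch argument. If that argument is omitted, the kernel runs on the legacy default stream and ignores torch.cuda.stream(...) contexts set on the Python side, which is exactly the kind of mis-ordering test_cuda_streams is designed to catch.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel: scales a contiguous float tensor in place.
__global__ void scale_kernel(float* data, float factor, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}

void scale_(at::Tensor t, float factor) {
  TORCH_CHECK(t.is_cuda() && t.is_contiguous() && t.scalar_type() == at::kFloat);
  at::cuda::CUDAGuard device_guard(t.device());
  // Whatever stream is "current" for this device, e.g. the one selected by
  // `with torch.cuda.stream(s):` on the Python side.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int64_t n = t.numel();
  dim3 block(256);
  dim3 grid(static_cast<unsigned int>((n + block.x - 1) / block.x));
  // The fourth <<<...>>> argument is the stream; leaving it out would launch
  // on the legacy default stream, as the kernels above did before this PR.
  scale_kernel<<<grid, block, /*smem_bytes=*/0, stream>>>(
      t.data_ptr<float>(), factor, n);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

On the Python side, test_cuda_streams above exercises this end to end: it queues work on a low-priority stream, makes a high-priority stream wait on it, and then runs memory_efficient_attention on the high-priority stream, so a kernel that ignores the current stream would read stale inputs.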