[hotfix][ragged attention] Minor, smaller workload to fit a T4 (#186)
blefaudeux committed Jan 19, 2022
1 parent deddf93 commit b24f222
Showing 1 changed file with 14 additions and 6 deletions.
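For context on the "fit a T4" motivation: below is a rough, hypothetical back-of-the-envelope estimate (not part of the commit) of the keys-plus-values footprint in test_index_select_throughput, before and after this change. It assumes fp16 storage and one keys tensor plus one values tensor spanning n_ctx_per_seq * n_seqs rows of width d_model_per_gpu; the test's actual shapes and dtype may differ, so treat the numbers as order-of-magnitude only.

# Hypothetical estimate, not from the commit: memory for a keys tensor plus a
# values tensor covering all sequences in the throughput test.
d_model_per_gpu = 12 * 1024 // 8  # 1536, as set in test_index_select_throughput

def kv_gib(n_ctx_per_seq, n_seqs, bytes_per_elem=2):  # 2 bytes/elem assumes fp16
    elems = n_ctx_per_seq * n_seqs * d_model_per_gpu
    return 2 * elems * bytes_per_elem / 2**30  # x2 for keys and values

print(f"before: {kv_gib(8192, 100):.2f} GiB")  # ~4.69 GiB in fp16, ~9.4 GiB in fp32
print(f"after:  {kv_gib(4096, 20):.2f} GiB")   # ~0.47 GiB

With padding buffers and index_select temporaries on top of the raw caches, the old workload could plausibly exhaust a 16 GB T4, while the reduced one leaves ample headroom.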
tests/ragged_inference/test_seq_kv_cache.py (20 changes: 14 additions & 6 deletions)
@@ -153,7 +153,9 @@ def _create_indices(n_ctx_per_kv_cache):
     return torch.tensor(indices_list, device="cuda")


-@pytest.mark.n_gpus(2)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="This test requires a CUDA device"
+)
 def test_garbage_pad_seq_kv_cache_correctness():
     seq_kv_cache = [
         _single_seq_kv_cache(n_ctx=1, value=33, d_model=2),
@@ -173,7 +175,9 @@ def test_garbage_pad_seq_kv_cache_correctness():
     assert_eq(padded_values[2, :7, :], seq_kv_cache[2].values)


-@pytest.mark.n_gpus(2)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="This test requires a CUDA device"
+)
 def test_extend_kv_caches_correctness():
     d_model = 6
     seq_kv_cache = [
@@ -210,10 +214,12 @@ def test_extend_kv_caches_correctness():
     assert_eq(new_cache[2].values[:, 0].cpu(), [55, 55, 55, 55, 55, 55, 55, 2])


-@pytest.mark.n_gpus(2)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="This test requires a CUDA device"
+)
 def test_index_select_throughput():
-    n_ctx_per_seq = 8192
-    n_seqs = 100
+    n_ctx_per_seq = 4096
+    n_seqs = 20
     d_model_per_gpu = 12 * 1024 // 8

     keys = _single_seq_kv_cache(
@@ -263,7 +269,9 @@ def do_the_op():
     )


-@pytest.mark.n_gpus(2)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="This test requires a CUDA device"
+)
 def test_garbage_pad_seq_kv_cache_throughput(n_ctx_per_seq=1024):
     n_seqs = 20
     d_model_per_gpu = 12 * 1024 // 8
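Side note: the same skipif guard now appears verbatim on four tests. A minimal, hypothetical refactor (not part of this commit) would bind the marker to a name once and reuse it, which pytest supports directly:

import pytest
import torch

# Define the guard once; any test decorated with it is skipped on CPU-only hosts.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)

@requires_cuda
def test_garbage_pad_seq_kv_cache_correctness():
    ...

With that in place, each of the four decorators in this diff would collapse to a single @requires_cuda line, and the tests would still skip cleanly instead of failing on machines without a CUDA device.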
