Skip to content

Commit

Permalink
[CI/Test] fix swap test for multi gpu (vllm-project#4689)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and tjohnson31415 committed May 16, 2024
1 parent 2563537 commit a696be1
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ def test_reshape_and_cache_flash(
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.set_default_device(device)

# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)

qkv = torch.randn(num_tokens,
3,
Expand All @@ -245,6 +246,7 @@ def test_reshape_and_cache_flash(
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]

Expand Down

0 comments on commit a696be1

Please sign in to comment.