Skip to content

Commit

Permalink
Test default CUDA index detection of gather/scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed May 29, 2019
1 parent 3031b19 commit 634174d
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/test_microbatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
import torch
import torch.cuda

from torchgpipe.microbatch import gather, scatter

Expand Down Expand Up @@ -43,3 +45,17 @@ def test_scatter_tuple():
assert b[0].size() == (1, 1)
assert a[1].size() == (2, 2)
assert b[1].size() == (2, 2)


@pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda required')
def test_default_device_index():
default_cuda = torch.device('cuda')
assert default_cuda.index is None

x = torch.rand(2, 1)
a, b = scatter(x, chunks=2, device=default_cuda)
y = gather([a, b], device=default_cuda)

assert a.is_cuda
assert b.is_cuda
assert y.is_cuda

0 comments on commit 634174d

Please sign in to comment.