Skip to content

Commit

Permalink
Raise ValueError on <1 numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed May 21, 2019
1 parent f1ea97a commit 2a403bc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/test_gpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ def test_balance_wrong_length():
GPipe(model, balance=[3])


def test_balance_less_than_1():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)

model = nn.Sequential(a, b)

with pytest.raises(ValueError):
GPipe(model, balance=[0, 2])

with pytest.raises(ValueError):
GPipe(model, balance=[-1, 3])


def test_chunks_less_than_1():
model = nn.Sequential(nn.Linear(1, 1))

with pytest.raises(ValueError):
GPipe(model, balance=[1], devices=['cpu'], chunks=0)

with pytest.raises(ValueError):
GPipe(model, balance=[1], devices=['cpu'], chunks=-1)


def test_too_few_devices():
x = nn.Linear(1, 1)
model = nn.Sequential(x, x, x, x)
Expand Down
5 changes: 5 additions & 0 deletions torchgpipe/gpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(self,
self._partitions, self.balance, self.in_device, self.out_device = \
self.partition(module, balance, devices)

if chunks <= 0:
raise ValueError('number of chunks must be positive integer')
self.chunks = chunks

if checkpoint not in ['always', 'except_last', 'never']:
Expand Down Expand Up @@ -159,6 +161,9 @@ def partition(module: nn.Sequential,
if len(module) != sum(balance):
raise ValueError('module and sum of balance have different length '
'(module: %d, sum of balance: %d)' % (len(module), sum(balance)))
if any(x <= 0 for x in balance):
raise ValueError('all balance numbers must be positive integer '
'(balance: %r)' % balance)

if devices is None:
devices = [torch.device(d) for d in range(torch.cuda.device_count())]
Expand Down

0 comments on commit 2a403bc

Please sign in to comment.