Deny moving parameters and buffers
sublee authored and GitHub Enterprise committed Jun 11, 2019
1 parent b1c1a35 commit 38248d0
Showing 2 changed files with 71 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_gpipe.py
@@ -502,3 +502,41 @@ def test_partitions():
    assert isinstance(model.partitions[1], Partition)

    assert 'partitions.0.module.0.weight' in model.state_dict()


def test_deny_moving():
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'])

    # Moving is denied.
    with pytest.raises(TypeError):
        model.cuda()

    with pytest.raises(TypeError):
        model.cpu()

    with pytest.raises(TypeError):
        model.to(torch.device('cuda'))

    with pytest.raises(TypeError):
        model.to(0)

    with pytest.raises(TypeError):
        model.to('cuda')

    with pytest.raises(TypeError):
        model.to(device=0)

    with pytest.raises(TypeError):
        model.to(torch.rand(1))

    with pytest.raises(TypeError):
        model.to(tensor=torch.rand(1))

    # Casting is allowed.
    model.half()
    model.to(torch.double)
    model.to(dtype=torch.float)
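
For context, a minimal usage sketch of the behavior this commit enforces, assuming the public torchgpipe.GPipe import: device placement is fixed once through the constructor's devices argument (CPU devices here, as in the test above), and only dtype casts remain usable on the wrapper afterwards.

import torch
from torch import nn
from torchgpipe import GPipe

# Partition a two-layer Sequential across two (CPU) devices at construction time.
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
model = GPipe(model, [1, 1], devices=['cpu', 'cpu'])

model.half()           # allowed: a pure dtype cast
model.to(torch.float)  # allowed: to(dtype) carries no device

try:
    model.cuda()       # denied: device placement stays under GPipe's control
except TypeError as exc:
    print(exc)         # 'denied to move parameters and buffers, ...'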
33 changes: 33 additions & 0 deletions torchgpipe/gpipe.py
@@ -59,6 +59,10 @@ def current_microbatch() -> Optional[Tensor]:
    return None


MOVING_DENIED = TypeError('denied to move parameters and buffers, '
                          'because GPipe should manage device placement')


class GPipe(nn.Module):
"""Wraps an arbitrary :class:`~torch.nn.Sequential` module to train on
GPipe_. If the module requires lots of memory, GPipe will be very
@@ -165,6 +169,35 @@ def __getitem__(self, index: int) -> nn.Module:

        raise IndexError

    # GPipe should manage the device of each partition.
    # Deny cuda(), cpu(), and to() with device, by TypeError.
    def cuda(self, device: Any = None) -> nn.Module:
        raise MOVING_DENIED

    def cpu(self) -> nn.Module:
        raise MOVING_DENIED

    def to(self, *args: Any, **kwargs: Any) -> nn.Module:
        # Deny these usages:
        #
        # - to(device[, dtype, non_blocking])
        # - to(tensor[, non_blocking])
        #
        # But allow this:
        #
        # - to(dtype[, non_blocking])
        #
        if 'device' in kwargs or 'tensor' in kwargs:
            raise MOVING_DENIED

        if args:
            if isinstance(args[0], (torch.device, int, str)):
                raise MOVING_DENIED
            if isinstance(args[0], Tensor):
                raise MOVING_DENIED

        return super().to(*args, **kwargs)

    @staticmethod
    def partition(module: nn.Sequential,
                  balance: Iterable[int],
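
As an aside, the dispatch rule in to() above can be restated as a standalone sketch; the helper name is_moving below is hypothetical, introduced only for illustration. A 'device' or 'tensor' keyword, or a first positional argument that is a device, int, str, or Tensor, signals a device move and is denied, while a bare dtype falls through to the ordinary cast.

import torch

def is_moving(*args, **kwargs):
    # Mirrors the checks in GPipe.to(): keywords first, then the leading positional.
    if 'device' in kwargs or 'tensor' in kwargs:
        return True
    if args and isinstance(args[0], (torch.device, int, str, torch.Tensor)):
        return True
    return False

assert is_moving('cuda')
assert is_moving(torch.device('cuda'))
assert is_moving(0)
assert is_moving(torch.rand(1))
assert is_moving(device=0)
assert not is_moving(torch.float16)
assert not is_moving(dtype=torch.double)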
