Skip to content

Commit

Permalink
Config mypy correctly for torchgpipe_balancing
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Oct 8, 2019
1 parent 1129755 commit 622025c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ mypy_path = ./stubs/
follow_imports = normal

# This project must be strictly typed.
[mypy-torchgpipe.*,mypy-torchgpipe_balancing.*]
[mypy-torchgpipe.*,torchgpipe_balancing.*]
check_untyped_defs = true
disallow_untyped_defs = true
disallow_untyped_calls = true
Expand Down
4 changes: 2 additions & 2 deletions torchgpipe_balancing/blockpartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper: https://arxiv.org/pdf/1308.2452.pdf
"""
from typing import List
from typing import Iterator, List, Tuple

__all__ = ['solve']

Expand Down Expand Up @@ -40,7 +40,7 @@ def block_size(i: int) -> float:
stop = splits[i]
return sum(normal_sequence[start:stop])

def leaderboard():
def leaderboard() -> Iterator[Tuple[float, int]]:
return ((block_size(i), i) for i in range(partitions))

while True:
Expand Down
2 changes: 1 addition & 1 deletion torchgpipe_balancing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def training_sandbox(module: nn.Sequential) -> Generator[None, None, None]:
module.train(training)


def synchronize_device(device: torch.device):
def synchronize_device(device: torch.device) -> None:
if device.type == 'cpu':
return
torch.cuda.synchronize(device)
Expand Down

0 comments on commit 622025c

Please sign in to comment.