[cleanup] fix pre-commit mypy issues (#87)
msbaines committed Sep 16, 2020
1 parent d16e9f6 commit 4a874a6
Showing 5 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions benchmarks/oss.py
@@ -5,7 +5,7 @@
 import math
 import os
 import time
-from typing import Any, List, Union, cast
+from typing import Any, List, cast
 
 import torch
 import torch.distributed as dist
@@ -62,7 +62,7 @@ def collate(inputs: List[Any]):
     torch.cuda.reset_peak_memory_stats(rank)
 
     # Shard the optimizer
-    optimizer: Union[OSS, OPTIM] = OSS(
+    optimizer: torch.optim.Optimizer = OSS(
         params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
     ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
 
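A note on this hunk: OPTIM in the benchmark is presumably a module-level variable holding an optimizer class, and mypy rejects a variable inside Union[...] ("Variable ... is not valid as a type"). Annotating with the shared base class type-checks both branches of the conditional expression, assuming OSS subclasses torch.optim.Optimizer. A minimal sketch of the same pattern, with illustrative names:

import torch

# A class stored in a variable, standing in for the benchmark's OPTIM global.
# mypy will not accept `Union[SomeClass, OPTIM]` because OPTIM is a value, not a type.
OPTIM = torch.optim.SGD

model = torch.nn.Linear(4, 2)

# The common base class covers either branch of a conditional expression.
optimizer: torch.optim.Optimizer = OPTIM(model.parameters(), lr=1e-4, momentum=0.9)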
6 changes: 2 additions & 4 deletions docs/source/conf.py
@@ -1,3 +1,4 @@
+# type: ignore
 # Configuration file for the Sphinx documentation builder.
 #
 # This file only contains a selection of the most common options. For a full
@@ -73,10 +74,7 @@
 def setup(app):
     app.add_config_value(
         "recommonmark_config",
-        {
-            "url_resolver": lambda url: github_doc_root + url,
-            "auto_toc_tree_section": "Contents",
-        },
+        {"url_resolver": lambda url: github_doc_root + url, "auto_toc_tree_section": "Contents"},
         True,
     )
     app.add_transform(AutoStructify)
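For context on the +# type: ignore line: when this comment is the first line of a file, mypy suppresses type checking for the entire module, a common escape hatch for a Sphinx conf.py full of untyped globals. The collapsed dict below it is only a formatter-style change with no behavior difference. A tiny illustration of the file-level ignore, with hypothetical contents:

# type: ignore
# Because the comment above is the file's first line, mypy skips this module,
# so even a mismatched annotation like the one below is not reported.
html_theme: int = "sphinx_rtd_theme"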
1 change: 1 addition & 0 deletions stubs/torch/cuda/__init__.pyi
@@ -39,6 +39,7 @@ def manual_seed(seed: int) -> None: ...
 def memory_allocated(device: Optional[_device_t]=...) -> int: ...
 def max_memory_allocated(device: Optional[_device_t]=...) -> int: ...
 def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ...
+def reset_peak_memory_stats(device: Union[_device_t, int] = None) -> None: ...
 def memory_cached(device: Optional[_device_t]=...) -> int: ...
 def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
 def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
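This stub pairs with the benchmark change above: benchmarks/oss.py calls torch.cuda.reset_peak_memory_stats(rank), and mypy resolves torch.cuda against the repository's local stubs/ tree rather than the installed package, so the call only type-checks once the .pyi declares it. (The = None default on a non-Optional parameter relies on mypy's implicit-Optional behavior, which was still the default in 2020.) A sketch of the kind of configuration that points mypy at local stubs; the repo's actual setting may differ:

[mypy]
mypy_path = ./stubs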
2 changes: 1 addition & 1 deletion stubs/torch/utils/data/dataset.pyi
@@ -8,7 +8,7 @@ T = TypeVar('T')
 class Dataset(Generic[T_co]):
     def __getitem__(self, index: int) -> T_co: ...
     def __len__(self) -> int: ...
-    def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ...
+    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
 
 class IterableDataset(Dataset[T_co]):
     def __iter__(self) -> Iterable[T_co]: ...
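Why this hunk is needed: mypy forbids a covariant type variable in parameter position ("Cannot use a covariant type variable as a parameter"), and other: T_co was also semantically off, since __add__ concatenates two datasets rather than appending a single element. Referencing the generic class itself, 'Dataset[T_co]', fixes both. A self-contained sketch of the rule, using an illustrative Box class:

from typing import Generic, TypeVar

T_co = TypeVar("T_co", covariant=True)

class Box(Generic[T_co]):
    # def merge(self, other: T_co) -> "Box[T_co]": ...
    #   ^ mypy: Cannot use a covariant type variable as a parameter
    def merge(self, other: "Box[T_co]") -> "Box[T_co]":  # accepted
        return self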
2 changes: 1 addition & 1 deletion stubs/torch/utils/data/distributed.pyi
@@ -6,6 +6,6 @@ from . import Sampler, Dataset
 T_co = TypeVar('T_co', covariant=True)
 class DistributedSampler(Sampler[T_co]):
     def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ...
-    def __iter__(self) -> Iterator[int]: ...
+    def __iter__(self) -> Iterator[T_co]: ...
     def __len__(self) -> int: ...
     def set_epoch(self, epoch: int) -> None: ...
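Rationale for the last hunk: DistributedSampler is declared as Sampler[T_co], and assuming the Sampler stub declares __iter__ as returning Iterator[T_co], an override returning Iterator[int] is an incompatible signature that mypy flags; returning Iterator[T_co] matches the supertype. A minimal sketch with a hypothetical subclass:

from typing import Generic, Iterator, TypeVar

T_co = TypeVar("T_co", covariant=True)

class Sampler(Generic[T_co]):
    def __iter__(self) -> Iterator[T_co]: ...

class MySampler(Sampler[T_co]):
    # def __iter__(self) -> Iterator[int]: ...
    #   ^ mypy: return type incompatible with "__iter__" in supertype "Sampler"
    def __iter__(self) -> Iterator[T_co]:  # matches the declared supertype
        return iter([])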
