Skip to content

Commit

Permalink
Write more type hints for PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Jun 18, 2019
1 parent 118e290 commit c9b6871
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
14 changes: 13 additions & 1 deletion stubs/torch/cuda/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ _device_t = Union[_device, int]
def check_error(res: int) -> None: ...
def device_count() -> int: ...
def empty_cache() -> None: ...
def synchronize(device: _device_t) -> None: ...

#MODIFIED BY TORCHGPIPE
def synchronize(device: Optional[_device_t] = ...) -> None: ...
#END

def set_device(device: _device_t) -> None: ...
def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
def get_device_name(device: Optional[_device_t]=...) -> str: ...
Expand All @@ -40,3 +44,11 @@ def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
def cudart() -> ctypes.CDLL: ...
def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ...

#MODIFIED BY TORCHGPIPE
from typing import Any
class device:
def __init__(self, device: _device_t = ...) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
#END
9 changes: 9 additions & 0 deletions stubs/torch/nn/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ class Module:
def named_children(self) -> Iterable[Tuple[str, Module]]: ...
def add_module(self, name: str, module: Module) -> None: ...

def buffers(self) -> Iterator[Tensor]: ...
def parameters(self) -> Iterator[Parameter]: ...

def state_dict(self, destination: Optional[str] = ..., prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Tensor]: ...
def load_state_dict(self, state_dict: Dict[str, Tensor], strict: bool = ...) -> Tuple[List[str], List[str]]: ...

def train(self: TModule, mode: bool = ...) -> TModule: ...
def eval(self: TModule) -> TModule: ...


class Sequential(Module):
@overload
Expand Down

0 comments on commit c9b6871

Please sign in to comment.