Skip to content

Commit

Permalink
Fix wrong type hint for register_parameter()
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed May 31, 2019
1 parent d35d7ec commit 4240b3d
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions stubs/torch/nn/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#MODIFIED BY TORCHGPIPE
from typing import Any, Callable, Iterable, Iterator, Tuple, TypeVar
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

from torch import Tensor, device

Expand All @@ -17,6 +17,13 @@ class __RemovableHandle:
def remove(self) -> None: ...


class Parameter(Tensor):
def __new__(cls,
data: Optional[Tensor] = None,
requires_grad: bool = True,
) -> Parameter: ...


class Module:
training: bool

Expand All @@ -27,7 +34,7 @@ class Module:
def apply(self, fn: Callable[[Module], None]) -> Module: ...

def register_buffer(self, name: str, tensor: Tensor) -> None: ...
def register_parameter(self, name: str, tensor: Tensor) -> None: ...
def register_parameter(self, name: str, param: Parameter) -> None: ...

def register_backward_hook(self, hook: __Hook2) -> __RemovableHandle: ...
def register_forward_pre_hook(self, hook: __Hook1) -> __RemovableHandle: ...
Expand Down

0 comments on commit 4240b3d

Please sign in to comment.