diff --git a/captum/testing/helpers/classification_models.py b/captum/testing/helpers/classification_models.py
index 298cd2a29..a0db96901 100644
--- a/captum/testing/helpers/classification_models.py
+++ b/captum/testing/helpers/classification_models.py
@@ -12,23 +12,17 @@ class SigmoidModel(nn.Module):
     -pytorch-and-make-your-life-simpler-ec5367895199
     """
 
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, num_in, num_hidden, num_out) -> None:
+    def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_in = num_in
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_hidden = num_hidden
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_out = num_out
         self.lin1 = nn.Linear(num_in, num_hidden)
         self.lin2 = nn.Linear(num_hidden, num_out)
         self.relu1 = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         lin1 = self.lin1(input)
         lin2 = self.lin2(self.relu1(lin1))
         return self.sigmoid(lin2)
@@ -40,14 +34,12 @@ class SoftmaxModel(nn.Module):
     https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
     """
 
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, num_in, num_hidden, num_out, inplace: bool = False) -> None:
+    def __init__(
+        self, num_in: int, num_hidden: int, num_out: int, inplace: bool = False
+    ) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_in = num_in
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_hidden = num_hidden
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_out = num_out
         self.lin1 = nn.Linear(num_in, num_hidden)
         self.lin2 = nn.Linear(num_hidden, num_hidden)
@@ -56,9 +48,7 @@ def __init__(self, num_in, num_hidden, num_out, inplace: bool = False) -> None:
         self.relu2 = nn.ReLU(inplace=inplace)
         self.softmax = nn.Softmax(dim=1)
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         lin1 = self.relu1(self.lin1(input))
         lin2 = self.relu2(self.lin2(lin1))
         lin3 = self.lin3(lin2)
@@ -72,14 +62,10 @@ class SigmoidDeepLiftModel(nn.Module):
     -pytorch-and-make-your-life-simpler-ec5367895199
     """
 
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, num_in, num_hidden, num_out) -> None:
+    def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_in = num_in
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_hidden = num_hidden
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_out = num_out
         self.lin1 = nn.Linear(num_in, num_hidden, bias=False)
         self.lin2 = nn.Linear(num_hidden, num_out, bias=False)
@@ -88,9 +74,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
         self.relu1 = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         lin1 = self.lin1(input)
         lin2 = self.lin2(self.relu1(lin1))
         return self.sigmoid(lin2)
@@ -102,14 +86,10 @@ class SoftmaxDeepLiftModel(nn.Module):
     https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
     """
 
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, num_in, num_hidden, num_out) -> None:
+    def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_in = num_in
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_hidden = num_hidden
-        # pyre-fixme[4]: Attribute must be annotated.
         self.num_out = num_out
         self.lin1 = nn.Linear(num_in, num_hidden)
         self.lin2 = nn.Linear(num_hidden, num_hidden)
@@ -121,9 +101,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
         self.relu2 = nn.ReLU()
         self.softmax = nn.Softmax(dim=1)
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         lin1 = self.relu1(self.lin1(input))
         lin2 = self.relu2(self.lin2(lin1))
         lin3 = self.lin3(lin2)
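For reference, a minimal smoke-test sketch showing the newly annotated signatures in use. The import path comes from the file header above; the layer sizes are hypothetical, and the final `return self.softmax(lin3)` in `SoftmaxModel.forward` is assumed, since that line falls outside the hunk boundaries shown here.

```python
import torch

from captum.testing.helpers.classification_models import SigmoidModel, SoftmaxModel

# Hypothetical sizes; any positive ints work.
sig = SigmoidModel(num_in=10, num_hidden=20, num_out=2)
out = sig(torch.randn(4, 10))  # forward is now typed Tensor -> Tensor
assert out.shape == (4, 2)

soft = SoftmaxModel(num_in=10, num_hidden=20, num_out=3)
probs = soft(torch.randn(4, 10))
# Assumes forward ends with `return self.softmax(lin3)` (the return line is
# outside the hunk above), so each row should sum to 1.
assert torch.allclose(probs.sum(dim=1), torch.ones(4))
```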