Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions captum/optim/_models/inception_v1.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -25,10 +27,15 @@ def googlenet(
training. Default: 1008 when pretrained is True.
transform_input (bool): If True, preprocesses the input according to
the method with which it was trained on ImageNet. Default: *False*
bgr_transform (bool): If True and transform_input is True, perform an
RGB to BGR transform in the internal preprocessing.
Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "bgr_transform" not in kwargs:
kwargs["bgr_transform"] = False
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if "out_features" not in kwargs:
Expand Down Expand Up @@ -56,10 +63,12 @@ def __init__(
out_features: int = 1008,
aux_logits: bool = False,
transform_input: bool = False,
bgr_transform: bool = False,
) -> None:
super(InceptionV1, self).__init__()
self.aux_logits = aux_logits
self.transform_input = transform_input
self.bgr_transform = bgr_transform
lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)

self.conv1 = nn.Conv2d(
Expand Down Expand Up @@ -125,10 +134,12 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
assert x.min() >= 0.0 and x.max() <= 1.0
x = x.unsqueeze(0) if x.dim() == 3 else x
x = x * 255 - 117
x = x.clone()[:, [2, 1, 0]] # RGB to BGR
x = x[:, [2, 1, 0]] if self.bgr_transform else x
return x

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
x = self._transform_input(x)
x = F.pad(x, (2, 3, 2, 3))
x = self.conv1(x)
Expand Down Expand Up @@ -173,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.drop(x)
x = self.fc(x)
if not self.aux_logits:
return x
return cast(torch.Tensor, x)
else:
return x, aux1_output, aux2_output

Expand Down
12 changes: 12 additions & 0 deletions tests/optim/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ def test_transform_inceptionv1(self) -> None:
x = torch.randn(1, 3, 224, 224).clamp(0, 1)
model = googlenet(pretrained=True)
output = model._transform_input(x)
expected_output = x * 255 - 117
assertTensorAlmostEqual(self, output, expected_output, 0)

def test_transform_bgr_inceptionv1(self) -> None:
if torch.__version__ <= "1.2.0":
raise unittest.SkipTest(
"Skipping inceptionV1 internal transform"
+ " BGR due to insufficient Torch version."
)
x = torch.randn(1, 3, 224, 224).clamp(0, 1)
model = googlenet(pretrained=True, bgr_transform=True)
output = model._transform_input(x)
expected_output = x[:, [2, 1, 0]] * 255 - 117
assertTensorAlmostEqual(self, output, expected_output, 0)

Expand Down