Skip to content

Commit

Permalink
Allow even more flexible WideResNet/PreActResNet architectures (DeepM…
Browse files Browse the repository at this point in the history
…ind style)

Signed-off-by: Emanuele Ballarin <emanuele@ballarin.cc>
  • Loading branch information
emaballarin committed May 9, 2024
1 parent d86c0f4 commit 735eee0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
14 changes: 11 additions & 3 deletions ebtorch/nn/architectures_resnets_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,15 @@ def __init__(
padding: int = 0,
num_input_channels: int = 3,
bn_momentum: float = 0.1,
autopool: bool = False,
) -> None:
super().__init__()
self.mean: torch.Tensor = torch.tensor(mean).view(num_input_channels, 1, 1)
self.std: torch.Tensor = torch.tensor(std).view(num_input_channels, 1, 1)
self.mean_cuda: Optional[torch.Tensor] = None
self.std_cuda: Optional[torch.Tensor] = None
self.padding: int = padding
num_channels: List[int, int, int, int] = [
num_channels: List[int] = [
16,
16 * width,
32 * width,
Expand Down Expand Up @@ -294,14 +295,17 @@ def __init__(
self.relu: nn.Module = activation_fn()
self.logits: nn.Module = nn.Linear(num_channels[3], num_classes)
self.num_channels: int = num_channels[3]
self.pooling: nn.Module = (
nn.AdaptiveAvgPool2d((1, 1)) if autopool else nn.AvgPool2d(8)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x: torch.Tensor = _auto_pad_to_same(self, x)
out: torch.Tensor = _auto_cuda_mean_handler(self, x)
out: torch.Tensor = self.init_conv(out)
out: torch.Tensor = self.layer(out)
out: torch.Tensor = self.relu(self.batchnorm(out))
out: torch.Tensor = F.avg_pool2d(out, 8)
out: torch.Tensor = self.pooling(out)
out: torch.Tensor = out.view(-1, self.num_channels)
return self.logits(out)

Expand All @@ -320,6 +324,7 @@ def __init__(
padding: int = 0,
num_input_channels: int = 3,
bn_momentum: float = 0.1,
autopool: bool = False,
) -> None:
super().__init__()
if width != 0:
Expand Down Expand Up @@ -380,6 +385,9 @@ def __init__(
)
self.relu: nn.Module = activation_fn()
self.logits: nn.Module = nn.Linear(in_features=512, out_features=num_classes)
self.pooling: nn.Module = (
nn.AdaptiveAvgPool2d((1, 1)) if autopool else nn.AvgPool2d(4)
)

def _make_layer( # Do not make static.
self,
Expand Down Expand Up @@ -412,6 +420,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = self.layer_2(out)
out: torch.Tensor = self.layer_3(out)
out: torch.Tensor = self.relu(self.batchnorm(out))
out: torch.Tensor = F.avg_pool2d(out, 4)
out: torch.Tensor = self.pooling(out)
out: torch.Tensor = out.view(out.size(0), -1)
return self.logits(out)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.24.1",
version="0.24.2",
author="Emanuele Ballarin",
author_email="emanuele@ballarin.cc",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 735eee0

Please sign in to comment.