From 735eee0d0f746c8cb05ad1843a7ddf894e983111 Mon Sep 17 00:00:00 2001 From: Emanuele Ballarin Date: Thu, 9 May 2024 15:05:33 +0200 Subject: [PATCH] Allow even more flexible WideResNet/PreActResNet architectures (DeepMind style) Signed-off-by: Emanuele Ballarin --- ebtorch/nn/architectures_resnets_dm.py | 14 +++++++++++--- setup.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ebtorch/nn/architectures_resnets_dm.py b/ebtorch/nn/architectures_resnets_dm.py index 3b37735..997a7b8 100644 --- a/ebtorch/nn/architectures_resnets_dm.py +++ b/ebtorch/nn/architectures_resnets_dm.py @@ -236,6 +236,7 @@ def __init__( padding: int = 0, num_input_channels: int = 3, bn_momentum: float = 0.1, + autopool: bool = False, ) -> None: super().__init__() self.mean: torch.Tensor = torch.tensor(mean).view(num_input_channels, 1, 1) @@ -243,7 +244,7 @@ def __init__( self.mean_cuda: Optional[torch.Tensor] = None self.std_cuda: Optional[torch.Tensor] = None self.padding: int = padding - num_channels: List[int, int, int, int] = [ + num_channels: List[int] = [ 16, 16 * width, 32 * width, @@ -294,6 +295,9 @@ def __init__( self.relu: nn.Module = activation_fn() self.logits: nn.Module = nn.Linear(num_channels[3], num_classes) self.num_channels: int = num_channels[3] + self.pooling: nn.Module = ( + nn.AdaptiveAvgPool2d((1, 1)) if autopool else nn.AvgPool2d(8) + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x: torch.Tensor = _auto_pad_to_same(self, x) @@ -301,7 +305,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = self.init_conv(out) out: torch.Tensor = self.layer(out) out: torch.Tensor = self.relu(self.batchnorm(out)) - out: torch.Tensor = F.avg_pool2d(out, 8) + out: torch.Tensor = self.pooling(out) out: torch.Tensor = out.view(-1, self.num_channels) return self.logits(out) @@ -320,6 +324,7 @@ def __init__( padding: int = 0, num_input_channels: int = 3, bn_momentum: float = 0.1, + autopool: bool = False, ) -> None: super().__init__() if width != 0: @@ -380,6 +385,9 @@ def __init__( ) self.relu: nn.Module = activation_fn() self.logits: nn.Module = nn.Linear(in_features=512, out_features=num_classes) + self.pooling: nn.Module = ( + nn.AdaptiveAvgPool2d((1, 1)) if autopool else nn.AvgPool2d(4) + ) def _make_layer( # Do not make static. self, @@ -412,6 +420,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = self.layer_2(out) out: torch.Tensor = self.layer_3(out) out: torch.Tensor = self.relu(self.batchnorm(out)) - out: torch.Tensor = F.avg_pool2d(out, 4) + out: torch.Tensor = self.pooling(out) out: torch.Tensor = out.view(out.size(0), -1) return self.logits(out) diff --git a/setup.py b/setup.py index fc82f82..72290e6 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def read(fname): setup( name=PACKAGENAME, - version="0.24.1", + version="0.24.2", author="Emanuele Ballarin", author_email="emanuele@ballarin.cc", url="https://github.com/emaballarin/ebtorch",