In [4]:
import torch
import torch.nn as nn
import torchvision.models as models

from torchinfo import summary

In [33]:
class DeconvNet(nn.Module):
    def __init__(self, convnet_features: nn.Module):
        super().__init__()
        self.deconv = nn.Sequential()

        for layer in reversed(convnet_features):
            # Deconvnet counterpart of a MaxPool2d : MaxUnpool2d
            if isinstance(layer, nn.MaxPool2d):
                self.deconv.append(
                    nn.MaxUnpool2d(
                        kernel_size=layer.kernel_size,
                        stride=layer.stride,
                        padding=layer.padding
                    )
                )
            # Deconvnet counterpart of a ReLU : ReLU
            elif isinstance(layer, nn.ReLU):
                 self.deconv.append(nn.ReLU())
            # Deconvnet counterpart of a Conv2d : ConvTranspose2d
            elif isinstance(layer, nn.Conv2d):
                t_conv = nn.ConvTranspose2d(
                    in_channels=layer.in_channels,
                    out_channels=layer.out_channels,
                    kernel_size=layer.kernel_size,
                    stride=layer.stride,
                    padding=layer.padding,
                    dilation=layer.dilation
                )
                weights = layer.weight
                with torch.no_grad():
                    t_conv.weight = weights
                self.deconv.append(t_conv)


    def assert_correct_layer_idx(self, idx: int) -> None:
        assert self.is_correct_idx(idx), f"Layer index (idx = {idx}) must be in range [0, {len(self.deconv)})"


    def get_normalized_idx(self, idx: int=-1):
        if idx < 0:
            idx = len(self.deconv) + idx
        
        return idx


    def forward(self, y: torch.tensor, maxpool_indices: list[torch.Tensor], from_layer_idx: int=-1, verbose: bool=True) -> torch.Tensor:
        from_layer_idx = self.get_normalized_idx(from_layer_idx)
        self.assert_correct_layer_idx(from_layer_idx)

        idx_indices = -1
        for idx, layer in enumerate(self.deconv[:from_layer_idx+1]):
            if isinstance(layer, nn.MaxPool2d):
                indices = maxpool_indices[idx_indices]
                idx_indices -= 1
                y = layer(y, indices)
            else:
                y = layer(y)

            if verbose:
                print(y.size(), "after {idx} :", layer)

        return y

    def summary(self, input_batch_size: torch.Size, verbose=True):
        """
        Display summary of the deconv model
        """
        print("input size:", input_batch_size)
        self.model_summary = summary(self, input_size=input_batch_size)
        if verbose:
            print(self.model_summary)


In [24]:
model = models.alexnet(weights='IMAGENET1K_V1')
model_features = model.features

In [25]:
summary(model_features, input_size=torch.Size([1, 3, 224, 224]))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 256, 6, 6]            --
├─Conv2d: 1-1                            [1, 64, 55, 55]           23,296
├─ReLU: 1-2                              [1, 64, 55, 55]           --
├─MaxPool2d: 1-3                         [1, 64, 27, 27]           --
├─Conv2d: 1-4                            [1, 192, 27, 27]          307,392
├─ReLU: 1-5                              [1, 192, 27, 27]          --
├─MaxPool2d: 1-6                         [1, 192, 13, 13]          --
├─Conv2d: 1-7                            [1, 384, 13, 13]          663,936
├─ReLU: 1-8                              [1, 384, 13, 13]          --
├─Conv2d: 1-9                            [1, 256, 13, 13]          884,992
├─ReLU: 1-10                             [1, 256, 13, 13]          --
├─Conv2d: 1-11                           [1, 256, 13, 13]          590,080
├─ReLU: 1-12                             [1, 256, 13, 13]    

In [34]:
deconvnet_model = DeconvNet(model_features)
deconvnet_model

DeconvNet(
  (deconv): Sequential(
    (0): MaxUnpool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))
    (1): ReLU()
    (2): ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): MaxUnpool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))
    (8): ReLU()
    (9): ConvTranspose2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (10): MaxUnpool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))
    (11): ReLU()
    (12): ConvTranspose2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  )
)

In [38]:
class Vgg16Conv(nn.Module):
    """
    vgg16 convolution network architecture
    """

    def __init__(self, num_cls=1000):
        """
        Input
            number of class, default is 1k.
        """
        super(Vgg16Conv, self).__init__()
    
        self.features = nn.Sequential(
            # conv1
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
            
            # conv2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),

            # conv3
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),

            # conv4
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),

            # conv5
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True)
        )

    
        # index of conv
        self.conv_layer_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]

In [18]:
class Vgg16Deconv(nn.Module):
    """
    vgg16 transpose convolution network architecture
    """
    def __init__(self):
        super(Vgg16Deconv, self).__init__()

        self.features = nn.Sequential(
            # deconv1
            nn.MaxUnpool2d(2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, padding=1),

            # deconv2
            nn.MaxUnpool2d(2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 3, padding=1),
            
            # deconv3
            nn.MaxUnpool2d(2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 3, padding=1),
            
            # deconv4
            nn.MaxUnpool2d(2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, padding=1),
            
            # deconv5
            nn.MaxUnpool2d(2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 3, padding=1)    
        )

In [19]:
try_model = Vgg16Deconv()

In [20]:
summary(try_model)

Layer (type:depth-idx)                   Param #
Vgg16Deconv                              --
├─Sequential: 1-1                        --
│    └─MaxUnpool2d: 2-1                  --
│    └─ReLU: 2-2                         --
│    └─ConvTranspose2d: 2-3              2,359,808
│    └─ReLU: 2-4                         --
│    └─ConvTranspose2d: 2-5              2,359,808
│    └─ReLU: 2-6                         --
│    └─ConvTranspose2d: 2-7              2,359,808
│    └─MaxUnpool2d: 2-8                  --
│    └─ReLU: 2-9                         --
│    └─ConvTranspose2d: 2-10             2,359,808
│    └─ReLU: 2-11                        --
│    └─ConvTranspose2d: 2-12             2,359,808
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             1,179,904
│    └─MaxUnpool2d: 2-15                 --
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             590,080
│    └─ReLU: 2-18                        --
│    └─ConvTranspose2d: 