In [2]:
import torch
import torch.nn as nn
import torchvision

import numpy as np

from typing import Any

from torchvision.models import MobileNet_V3_Small_Weights as mweights
from torchvision.models import mobilenet_v3_small as mnet

In [16]:
class Deconv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.deconv(x)
        x = self.bn(x)
        return torch.nn.functional.relu(x, inplace=True)

In [17]:
class UpsampleBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        self.deconv = Deconv2d
        
        self.branch1_1 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch2_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch2_2 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1)
        self.branch2_3 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.branch3_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        branch1 = self.branch1_1(x)
        
        print(branch1.shape)
        
        branch2 = self.branch2_1(x)
        branch2 = self.branch2_2(branch2)
        branch2 = self.branch2_3(branch2)
        
        print(branch2.shape)
        
        branch3 = self.branch3_1(x)
        branch3 = self.branch3_2(branch3)
        
        print(branch3.shape)
        
        outputs = [branch1, branch2, branch3]
        
        return torch.cat(outputs, 1)

In [24]:
class DeconvolutionBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        self.deconv = Deconv2d
        
        self.branch1_1 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
        self.branch2_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch2_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1)
        
        self.branch3_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch3_2 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1)
        self.branch3_3 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1)
        
        self.branch4_1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.branch4_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        branch1 = self.branch1_1(x)
        
        print(branch1.shape)
        
        branch2 = self.branch2_1(x)
        branch2 = self.branch2_2(branch2)
        
        print(branch2.shape)
        
        branch3 = self.branch3_1(x)
        branch3 = self.branch3_2(branch3)
        branch3 = self.branch3_3(branch3)
        
        print(branch3.shape)
        
        branch4 = self.branch4_1(x)
        branch4 = self.branch4_2(branch4)
        
        print(branch4.shape)
        
        outputs = [branch1, branch2, branch3, branch4]
        
        return torch.cat(outputs, 1)

In [25]:
x = torch.from_numpy(np.random.uniform(low=0.0, high=10.0, size=(3, 128, 128))).float()

In [26]:
dec = UpsampleBlock(in_channels=3, out_channels=1)
dec = DeconvolutionBlock(in_channels=3, out_channels=1)

In [27]:
dec(x.unsqueeze(0))

torch.Size([1, 1, 128, 128])
torch.Size([1, 1, 130, 130])
torch.Size([1, 1, 132, 132])
torch.Size([1, 1, 64, 64])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 128 but got size 130 for tensor number 1 in the list.