In [3]:
import torch
import torch.nn as nn
import torchvision

import numpy as np

from typing import Any

from torchvision.models import MobileNet_V3_Small_Weights as mweights
from torchvision.models import mobilenet_v3_small as mnet

In [5]:
class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

In [39]:
class DecoderBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        self.deconv = nn.ConvTranspose2d
        
        self.branch1_1 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch2_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch2_2 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1)
        self.branch2_3 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.branch3_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        branch1 = self.branch1_1(x)
        
        print(branch1.shape)
        
        branch2 = self.branch2_1(x)
        branch2 = self.branch2_2(branch2)
        branch2 = self.branch2_3(branch2)
        
        print(branch2.shape)
        
        branch3 = self.branch3_1(x)
        branch3 = self.branch3_2(branch3)
        
        print(branch3.shape)
        
        outputs = [branch1, branch2, branch3]
        
        return torch.cat(outputs, 1)

In [40]:
x = torch.from_numpy(np.random.uniform(low=0.0, high=10.0, size=(3, 128, 128))).float()

In [41]:
dec = DecoderBlock(in_channels=3, out_channels=1)

In [42]:
dec(x)

torch.Size([1, 256, 256])
torch.Size([1, 260, 260])
torch.Size([1, 64, 64])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 256 but got size 260 for tensor number 1 in the list.