In [2]:
import torch
import torch.nn as nn
import torchvision

import numpy as np

from typing import Any, Optional, Callable, List

from torchvision.models import MobileNet_V3_Small_Weights as mweights
from torchvision.models import mobilenet_v3_small as mnet

In [16]:
class Deconv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.deconv(x)
        x = self.bn(x)
        return torch.nn.functional.relu(x, inplace=True)

In [18]:
class UpsampleBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        self.deconv = Deconv2d
        
        self.branch1_1 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch2_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch2_2 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1)
        self.branch2_3 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        
        self.branch3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.branch3_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        branch1 = self.branch1_1(x)
        
        print(branch1.shape)
        
        branch2 = self.branch2_1(x)
        branch2 = self.branch2_2(branch2)
        branch2 = self.branch2_3(branch2)
        
        print(branch2.shape)
        
        branch3 = self.branch3_1(x)
        branch3 = self.branch3_2(branch3)
        
        print(branch3.shape)
        
        outputs = [branch1, branch2, branch3]
        
        return torch.cat(outputs, 1)

In [5]:
class DeconvolutionBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        self.deconv = Deconv2d
        
        self.branch1_1 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
        self.branch2_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch2_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1)
        
        self.branch3_1 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.branch3_2 = self.deconv(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1)
        self.branch3_3 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1)
        
        self.branch4_1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.branch4_2 = self.deconv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        branch1 = self.branch1_1(x)
        
        print(branch1.shape)
        
        branch2 = self.branch2_1(x)
        branch2 = self.branch2_2(branch2)
        
        print(branch2.shape)
        
        branch3 = self.branch3_1(x)
        branch3 = self.branch3_2(branch3)
        branch3 = self.branch3_3(branch3)
        
        print(branch3.shape)
        
        branch4 = self.branch4_1(x)
        branch4 = self.branch4_2(branch4)
        
        print(branch4.shape)
        
        outputs = [branch1, branch2, branch3, branch4]
        
        return torch.cat(outputs, 1)

In [None]:
class DecoderBlock(nn.Module):
    def __init__( self, in_channels: int, out_channels: int ) -> None:
        super().__init__()
        

In [9]:
x = torch.from_numpy(np.random.uniform(low=0.0, high=10.0, size=(3, 128, 128))).float()

In [19]:
dec = UpsampleBlock(in_channels=3, out_channels=1)
# dec = DeconvolutionBlock(in_channels=3, out_channels=1)

In [20]:
dec(x.unsqueeze(0))

ValueError: expected 4D input (got 3D input)

In [38]:
class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

In [39]:
class InceptionA(nn.Module):
    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

Мейби надо просто поменять InceptionUpsample на простой ConvTarnspose2D... Сейчас там не совпадает размерность слоёв (в первой ветке 128б во второй 130, в третьей 64)

In [40]:
dec = InceptionA(1, 2)

In [42]:
dec(x)

RuntimeError: Given groups=1, weight of size [64, 1, 1, 1], expected input[1, 3, 128, 128] to have 1 channels, but got 3 channels instead