In [16]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

from modeling.base import BaseNetwork
from modules.blocks import ResBlock, ConvBlock, PCBlock


class MPN(BaseNetwork):
    """Encoder-decoder network that outputs a single-channel map in [0, 1].

    Given the name and the one-channel sigmoid head, this is presumably a
    mask prediction network -- TODO confirm against the training code.

    Encoder: ``rb1`` and ``rb2`` use stride 2; ``rb3`` and ``rb4`` are
    stride-1 blocks with dilation 2 and 4 respectively. ``rb4`` produces the
    bottleneck ("neck") with ``neck_n_channels`` channels.

    Decoder: ``rb5`` -> 2x nearest upsample -> ``rb6`` -> 2x nearest upsample
    -> ``rb7`` -> ``cb1`` -> ``conv1`` -> sigmoid.

    Args:
        base_n_channels: Output width of ``rb1``. Must be at least 4 because
            the head narrows to ``base_n_channels // 4`` channels (``cb1``),
            which would otherwise be zero.
        neck_n_channels: Width of the bottleneck produced by ``rb4`` and
            consumed by ``rb5``. Must be at least 16.
    """

    def __init__(self, base_n_channels, neck_n_channels):
        super().__init__()
        assert base_n_channels >= 4, "Base num channels should be at least 4"
        assert neck_n_channels >= 16, "Neck num channels should be at least 16"

        # Encoder: two stride-2 blocks, then two dilated stride-1 blocks.
        self.rb1 = ResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=2, padding=2, dilation=1)
        self.rb2 = ResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2)
        self.rb3 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=2, dilation=2)
        self.rb4 = ResBlock(channels_in=base_n_channels * 2, channels_out=neck_n_channels, kernel_size=3, stride=1, padding=4, dilation=4)

        # Parameter-free 2x nearest-neighbour upsampling, reused twice in the decoder.
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)

        # Decoder. rb5 consumes the neck (rb4's output), so its input width is
        # neck_n_channels -- not base_n_channels * 2, which only coincides with
        # it when neck_n_channels == 2 * base_n_channels.
        self.rb5 = ResBlock(channels_in=neck_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=1)
        self.rb6 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1)
        self.rb7 = ResBlock(channels_in=base_n_channels, channels_out=base_n_channels // 2, kernel_size=3, stride=1)

        # Head: narrow to base_n_channels // 4, then project to one channel.
        self.cb1 = ConvBlock(channels_in=base_n_channels // 2, channels_out=base_n_channels // 4, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(base_n_channels // 4, 1, kernel_size=3, stride=1, padding=1)

        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, x):
        """Run the encoder-decoder.

        Args:
            x: Input batch whose channel dimension has size 3 (``rb1``
                expects ``channels_in=3``). Assumes NCHW layout with H and W
                divisible by 4, so the two 2x upsamples undo the two stride-2
                blocks -- TODO confirm.

        Returns:
            A tuple ``(mask, neck, mid_out)``:
              * ``mask``: one-channel sigmoid output with values in [0, 1].
              * ``neck``: bottleneck features from ``rb4``
                (``neck_n_channels`` channels).
              * ``mid_out``: features from ``rb3``
                (``2 * base_n_channels`` channels).
        """
        out = self.rb1(x)
        out = self.rb2(out)
        out = self.rb3(out)
        mid_out = out
        neck = self.rb4(out)  # bottleneck

        out = self.rb5(neck)
        out = self.upsample(out)
        out = self.rb6(out)
        out = self.upsample(out)
        out = self.rb7(out)

        out = self.cb1(out)
        out = self.conv1(out)

        return torch.sigmoid(out), neck, mid_out

In [17]:
# Smoke test: push a random batch through MPN and check the output shapes.
torch.manual_seed(0)  # reproducible weight init and random input
mpn = MPN(64, 128)
inp = torch.rand((2, 3, 256, 256))

# Shape inspection only -- no need to build an autograd graph.
with torch.no_grad():
    out, neck, mid_out = mpn(inp)

# conv1 projects to a single channel and the result is passed through a sigmoid.
assert out.shape[:2] == (2, 1), out.shape
assert 0.0 <= out.min().item() and out.max().item() <= 1.0
# rb4 emits neck_n_channels (128) channels.
assert neck.shape[:2] == (2, 128), neck.shape
mid_out.shape

torch.Size([2, 128, 64, 64])
