In [3]:
import torch
import torch.nn as nn

class MergeLayer(nn.Module):
    """merge input height with image"""
    
    def __init__(self, in_channels, out_channels):
        super(MergeLayer, self).__init__()
        self.h_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(1e-2)
        )
        self.i_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(1e-2)
        )
        
    def forward(self, x_h, x_i):
        output_h = self.h_layer(x_h)
        output_i = self.i_layer(x_i)
        output = torch.cat((output_h, output_i), 1)
        return output
    
ml = MergeLayer(1, 64)
xh = torch.randn(1, 1, 256, 256)
xi = torch.randn(1, 1, 256, 256)
output = ml(xh, xi)
output.shape

torch.Size([1, 128, 256, 256])

In [4]:
class ConvBlock_(nn.Module):
    """Conv(3x3)-BN-LReLU"""
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(ConvBlock_, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(1e-2)
        )
        
    def forward(self, x):
        output = self.conv(x)
        return output
    
input = torch.randn(1, 64, 256, 256)
conv = ConvBlock_(64, 64, stride=2)
output = conv(input)
output.shape

torch.Size([1, 64, 128, 128])

In [16]:
class Discriminator(nn.Module):
    """Height Layer + Image Layer-BN"""
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.merge = MergeLayer(1, 64)
        self.convs = nn.Sequential(
            ConvBlock_(128, 64),
            ConvBlock_(64, 64, 2),
            ConvBlock_(64, 128),
            ConvBlock_(128, 128, 2),
            ConvBlock_(128, 256),
            ConvBlock_(256, 256, 2),
            ConvBlock_(256, 512),
            ConvBlock_(512, 512, 2)
        )
        self.fc1 = nn.Linear(512 * 16 * 16, 200)
        self.lrelu = nn.LeakyReLU(1e-2)
        self.fc2 = nn.Linear(200, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x_h, x_i):
        merged = self.merge(x_h, x_i)
        print(merged.shape)
        output = self.convs(merged)
        
        output = output.view(-1, 512 * 16 * 16)
        print(output.shape)
        vector1 = self.fc1(output)
        relu = self.lrelu(vector1)
        vector2 = self.fc2(relu)
        scalar = self.sigmoid(vector2)
        return scalar
    
input1 = torch.randn(16, 1, 256, 256)
input2 = torch.randn(16, 1, 256, 256)
dis = Discriminator()
output = dis(input1, input2)
output

torch.Size([16, 128, 256, 256])
torch.Size([16, 131072])


tensor([[0.5288],
        [0.5114],
        [0.5800],
        [0.5399],
        [0.5423],
        [0.4914],
        [0.4976],
        [0.5388],
        [0.5264],
        [0.5245],
        [0.5017],
        [0.5020],
        [0.5283],
        [0.5315],
        [0.5310],
        [0.5033]], grad_fn=<SigmoidBackward0>)