In [1]:
import torch.nn as nn
import torch

In [14]:
class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    The two inputs given to ``forward`` are concatenated along the channel
    dimension and passed through a stack of downsampling blocks
    (Conv2d 4x4 / stride 2 / padding 1 -> optional BatchNorm2d -> LeakyReLU(0.2)),
    then a final Conv2d 4x4 / stride 1 / padding 1 down to a single channel and
    a Sigmoid. The result is a spatial map of scores in [0, 1] rather than one
    scalar per sample. With the default ``features`` and 256x256 inputs the map
    is ``(N, 1, 15, 15)``.

    Args:
        num_channels: Channels in *each* of the two inputs; the first
            convolution therefore receives ``2 * num_channels`` channels.
        features: Output widths of the downsampling blocks, in order. The first
            block is built without BatchNorm. The default ``(64, 128, 256, 512)``
            reproduces the original hard-coded layout (same layer order, so the
            same ``state_dict`` keys).
    """

    def __init__(self, num_channels, features=(64, 128, 256, 512)):
        super().__init__()
        layers = []
        in_channels = 2 * num_channels  # x and y are concatenated on dim 1
        for index, out_channels in enumerate(features):
            # The first block sees the raw concatenated inputs and gets no BatchNorm.
            layers.extend(
                Discriminator._conv_block(in_channels, out_channels, batch_norm=index > 0)
            )
            in_channels = out_channels
        # Final projection to one channel. Its input width follows the last block
        # instead of being hard-coded, so custom ``features`` stay consistent.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _conv_block(in_channels, out_channels, batch_norm=True):
        """Build one downsampling block as a list of layers.

        Conv2d(kernel_size=4, stride=2, padding=1), then BatchNorm2d when
        ``batch_norm`` is true, then an in-place LeakyReLU with slope 0.2.

        Args:
            in_channels: Channels entering the convolution.
            out_channels: Channels produced by the convolution (and normalized
                by BatchNorm2d when enabled).
            batch_norm: Whether to insert BatchNorm2d after the convolution.

        Returns:
            A list of ``nn.Module`` layers, meant to be unpacked into ``nn.Sequential``.
        """
        block = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
            )
        ]
        if batch_norm:
            block.append(nn.BatchNorm2d(out_channels))
        block.append(nn.LeakyReLU(0.2, inplace=True))
        return block

    @staticmethod
    def __get_layer__(input, out, batch_norm=True):
        """Backward-compatible alias for ``_conv_block``.

        Kept with its original name and parameter names so existing callers keep
        working. Prefer ``_conv_block`` in new code; dunder names are reserved
        for Python's own special methods.
        """
        return Discriminator._conv_block(input, out, batch_norm=batch_norm)

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        Args:
            x: Tensor with ``num_channels`` channels in Conv2d's (N, C, H, W) layout.
            y: Tensor with the same batch and spatial size as ``x`` and
                ``num_channels`` channels; concatenated after ``x`` on dim 1.

        Returns:
            Tensor of shape (N, 1, H', W') holding Sigmoid scores in [0, 1].
        """
        return self.model(torch.cat((x, y), dim=1))

In [15]:
# Seed the global torch RNG before drawing inputs and building the model, so the
# random inputs, the initial weights and the displayed output are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Random placeholder image pair with batch 1, 3 channels and 256x256 resolution.
img1 = torch.randn((1, 3, 256, 256))
img2 = torch.randn((1, 3, 256, 256))

patch_gan = Discriminator(num_channels=3)
output = patch_gan(img1, img2)

# Four stride-2 blocks take 256 -> 16; the final 4x4 / stride-1 / padding-1 conv gives 15.
assert output.shape == (1, 1, 15, 15)

In [16]:
# Display the discriminator's score map: Sigmoid outputs, one per spatial position of the final conv.
output

tensor([[[[0.4036, 0.6132, 0.4998, 0.5300, 0.4537, 0.4514, 0.5188, 0.5621,
           0.4457, 0.5408, 0.5318, 0.6400, 0.5848, 0.5604, 0.4199],
          [0.4678, 0.4430, 0.4779, 0.6826, 0.5238, 0.5069, 0.5166, 0.7046,
           0.5179, 0.5952, 0.4643, 0.4126, 0.5965, 0.5129, 0.5570],
          [0.6604, 0.6185, 0.3650, 0.6708, 0.5695, 0.5298, 0.3956, 0.4805,
           0.6111, 0.6504, 0.6385, 0.5921, 0.6212, 0.4925, 0.6241],
          [0.4215, 0.4674, 0.4605, 0.5745, 0.6112, 0.5112, 0.6797, 0.6332,
           0.3593, 0.5927, 0.7339, 0.6391, 0.5455, 0.3463, 0.6161],
          [0.4309, 0.4570, 0.3709, 0.5365, 0.3193, 0.4391, 0.5782, 0.5844,
           0.6096, 0.5631, 0.6432, 0.5907, 0.5537, 0.5342, 0.5293],
          [0.5860, 0.5246, 0.5245, 0.3016, 0.4829, 0.5040, 0.4897, 0.6082,
           0.6044, 0.5403, 0.4087, 0.3807, 0.4591, 0.6457, 0.5147],
          [0.4443, 0.3773, 0.5587, 0.5792, 0.5193, 0.5986, 0.5971, 0.4880,
           0.6743, 0.4909, 0.6089, 0.4673, 0.5567, 0.6224, 0.5181],