In [1]:
import torch.nn as nn
import torch

# Arhitektura diskriminatora

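C64-C128-C256-C512, gde `Ck` označava blok Conv 4×4 (korak 2) → BatchNorm → LeakyReLU(0.2) sa `k` filtera; BatchNorm se ne primenjuje u prvom bloku (C64). Nakon poslednjeg bloka, konvolucija 4×4 sa korakom 1 daje jednokanalnu mapu logita (po jedan za svaki patch).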

In [9]:
class Discriminator(nn.Module):
    """PatchGAN-style conditional discriminator (C64-C128-C256-C512).

    The two images given to ``forward`` are concatenated along dim 1, so the
    first convolution sees ``2 * num_channels`` input channels. Each ``Ck``
    block is a 4x4 stride-2 convolution (halving even spatial sizes),
    optionally followed by BatchNorm, then LeakyReLU(0.2). BatchNorm is
    skipped on the first block. A final 4x4 stride-1 convolution maps the
    features to a single channel of raw logits (no sigmoid is applied), one
    per patch; e.g. a 256x256 input with the default widths yields a 15x15 map.

    Args:
        num_channels: Number of channels in *each* of the two input images.
        features: Output channels of the successive ``Ck`` blocks. Defaults to
            ``(64, 128, 256, 512)``, i.e. C64-C128-C256-C512.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, num_channels, features=(64, 128, 256, 512)):
        super().__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        layers = []
        # forward() stacks x and y channel-wise, doubling the input channels.
        in_channels = 2 * num_channels
        for index, out_channels in enumerate(features):
            # No BatchNorm on the first block, BatchNorm on all later ones.
            layers += self.__get_layer__(
                input=in_channels, out=out_channels, batch_norm=index > 0
            )
            in_channels = out_channels
        # Output head: nothing normalizes after it, so it keeps its bias term
        # (otherwise the logits could not learn a global offset).
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    @staticmethod
    def __get_layer__(input, out, batch_norm=True):
        """Build one ``Ck`` block: Conv(4x4, stride 2) [-> BatchNorm2d] -> LeakyReLU(0.2).

        Parameter names (``input`` shadows the builtin) are kept unchanged for
        compatibility with existing keyword-argument callers.

        Args:
            input: Number of input channels.
            out: Number of output channels (the ``k`` in ``Ck``).
            batch_norm: Whether to insert ``BatchNorm2d`` after the convolution.

        Returns:
            list[nn.Module]: The block's modules in order, ready to be
            unpacked into ``nn.Sequential``.
        """
        # A conv bias is redundant right before BatchNorm (BN removes the
        # per-channel mean and adds its own learned shift); without BN it is
        # the layer's only learnable offset, so keep it in that case.
        layer = [
            nn.Conv2d(
                in_channels=input,
                out_channels=out,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=not batch_norm,
            )
        ]
        if batch_norm:
            layer.append(nn.BatchNorm2d(out))
        layer.append(nn.LeakyReLU(0.2, True))
        return layer

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` along dim 1 and return per-patch logits.

        Both tensors must match in every dim except 1 (``torch.cat``
        requirement); presumably each is (N, num_channels, H, W) as in the
        demo cell — TODO confirm for other callers.

        Returns:
            Tensor of shape (N, 1, H', W') with raw, un-sigmoided logits.
        """
        img = torch.cat((x, y), 1)
        return self.model(img)

In [10]:
# Smoke test: push two random 3-channel 256x256 images through the discriminator.
# Seed first so the random inputs and the weight initialization are reproducible.
torch.manual_seed(0)
img1 = torch.randn((1, 3, 256, 256))
img2 = torch.randn((1, 3, 256, 256))

patch_gan = Discriminator(num_channels=3)
output = patch_gan(img1, img2)
# Four 4x4/stride-2/pad-1 blocks take 256 -> 16; the final 4x4/stride-1/pad-1 conv gives 15.
assert output.shape == (1, 1, 15, 15), output.shape

In [11]:
# Raw per-patch logits from the discriminator (the model applies no sigmoid).
output

tensor([[[[ 1.9838e-01, -8.0028e-02,  1.8873e-02,  3.9306e-01,  8.0724e-01,
            3.6742e-01, -7.0147e-02,  4.4478e-01, -5.3935e-01,  3.9032e-01,
            1.7190e-01, -2.8061e-01,  1.7764e-01,  3.1836e-01,  1.5427e-01],
          [ 1.1143e-01,  2.7604e-01,  2.2894e-01, -9.1022e-02, -4.6438e-01,
           -2.6418e-01, -1.5430e-01,  3.1231e-01,  5.9658e-01,  5.0102e-02,
            2.2623e-01,  4.6633e-01, -2.0086e-01, -2.6046e-01, -2.4412e-01],
          [ 6.3771e-01, -2.8400e-01,  4.0167e-01,  5.5317e-01,  2.3813e-01,
            6.0196e-03,  2.9869e-01,  4.3105e-01,  4.7958e-01,  2.3751e-01,
           -3.8184e-01,  2.9304e-02,  1.9502e-01, -2.5480e-01, -2.5725e-01],
          [ 9.1617e-02, -2.0918e-01,  4.7797e-01,  6.8434e-01, -1.8365e-01,
           -1.0697e-01, -8.7468e-01,  4.2869e-01, -3.0148e-01,  1.0411e-01,
            1.2887e-01,  2.4301e-01,  3.0623e-01,  3.2560e-02, -3.7634e-01],
          [-3.6572e-01,  5.0002e-01,  1.4906e-01,  4.7910e-01, -5.0538e-01,
        