In [1]:
import torch
from torch import nn

In [2]:
config = torch.load('style_vangogh.pth')

In [3]:
x = sum(t.numel() for t in config.values())
print(f"Numparams is {x:,}")

Numparams is 11,388,698


In [4]:
class DiscConvBlock(nn.Module):
    def __init__(self, channels_in, channels_out, stride=2, is_first=False):
        super(DiscConvBlock, self).__init__()
        block = (
            nn.Conv2d(channels_in, channels_out, kernel_size=4, stride=stride, padding=1),
            nn.InstanceNorm2d(channels_out),
            nn.LeakyReLU(0.2, True),
        )
        if is_first: # remove the second element
            block = block[0], block[1]
        self.block = nn.Sequential(*block)
        
    def forward(self, x):
        return self.block(x)
    
x = torch.rand(2, 3, 256, 256)
"""
    Note that there is instance norm for first block
    and the stride is 2 for first 3 blocks and 1 for the last block.
    This is followed by a huge conv layer with again stride=1
"""
model = nn.Sequential(
    DiscConvBlock(3, 64, is_first=True),
    DiscConvBlock(64, 128),
    DiscConvBlock(128, 256),
    DiscConvBlock(256, 512, stride=1),
    # last block uses 1 channel conv
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
)

print(model)

Sequential(
  (0): DiscConvBlock(
    (block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
  )
  (1): DiscConvBlock(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (2): DiscConvBlock(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (3): DiscConvBlock(
    (block): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, af

In [5]:
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        use_bias = True

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
    
test = NLayerDiscriminator(input_nc=3)

In [6]:
x = torch.rand(2, 3, 256, 256)
with torch.no_grad():
    print(model(x).shape)
    
y = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{y:,}")

torch.Size([2, 1, 30, 30])
2,764,737


In [7]:
x = torch.rand(2, 3, 256, 256)
with torch.no_grad():
    print(test(x).shape)
    
y = sum(p.numel() for p in test.parameters() if p.requires_grad)
print(f"{y:,}")

torch.Size([2, 1, 30, 30])
2,764,737
