In [1]:
import numpy as np
import torch
from models.unet_v2 import Discriminator

# Build the discriminator for single-channel 28x28 inputs and report its size.
IMAGE_SHAPE = (1, 28, 28)  # (channels, height, width) passed to the constructor
BATCH_SIZE = 16            # batch size of the random example used for the shape check

discriminator = Discriminator(IMAGE_SHAPE)
print('Discriminator with {} parameters:'.format(sum(p.numel() for p in discriminator.parameters() if p.requires_grad)))
print(discriminator)

# Shape check: push a random batch through each discriminator stage and print
# the tensor shapes. This pass is for inspection only, so autograd is disabled
# and the module is put in eval mode. That avoids building a graph and avoids
# updating any running statistics the model may track (e.g. BatchNorm) with
# random data. The previous training mode is restored afterwards.
eg_input = torch.randn(BATCH_SIZE, *IMAGE_SHAPE)
was_training = discriminator.training
discriminator.eval()
try:
    with torch.no_grad():
        eg_features = discriminator.extract_features(eg_input)
        print('Disc. feature extractor: {} -> {}'.format(eg_input.shape, eg_features.shape))
        eg_features_realism = discriminator.get_realism_features(eg_features)
        eg_output_realism = discriminator.classify_realism(eg_features_realism)
        print('Disc. realism classifier: {} -> {} -> {}'.format(eg_features.shape, eg_features_realism.shape, eg_output_realism.shape))
        eg_features_leakage = discriminator.get_leakage_features(eg_features)
        eg_output_leakage = discriminator.classify_leakage(eg_features_leakage)
        print('Disc. leakage classifier: {} -> {} -> {}'.format(eg_features.shape, eg_features_leakage.shape, eg_output_leakage.shape))
finally:
    discriminator.train(was_training)

Discriminator with 43503 parameters:
Discriminator(
  (input_transform): Sequential(
    (0): Conv2d(1, 12, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)
  )
  (feature_extractor): Sequential(
    (0): DiscriminatorBlock(
      (residual_connection): Sequential(
        (0): LeakyReLU(negative_slope=0.1)
        (1): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)
        (2): LeakyReLU(negative_slope=0.1)
        (3): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=3)
      )
      (skip_connection): Sequential(
        (0): Conv2d(12, 24, kernel_size=(2, 2), stride=(2, 2), groups=3)
      )
    )
    (1): DiscriminatorBlock(
      (residual_connection): Sequential(
        (0): LeakyReLU(negative_slope=0.1)
        (1): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)
        (2): LeakyReLU(negative_slope=0.1)
        (3): Conv2d(48

In [2]:
import numpy as np
import torch
from models.unet_v2 import Generator

# Build the generator for single-channel 28x28 inputs and report its size.
IMAGE_SHAPE = (1, 28, 28)  # (channels, height, width) passed to the constructor
BATCH_SIZE = 16            # batch size of the random example used for the shape check

generator = Generator(IMAGE_SHAPE)
print('Generator with {} parameters:'.format(sum(p.numel() for p in generator.parameters() if p.requires_grad)))
print(generator)

# Shape check: run one random batch through the generator. The generator
# contains BatchNorm2d layers with track_running_stats=True. A forward pass in
# training mode would overwrite their running mean/var (and
# num_batches_tracked) with statistics of random noise. So this pass runs in
# eval mode with autograd disabled, and the previous training mode is restored
# afterwards.
eg_input = torch.randn(BATCH_SIZE, *IMAGE_SHAPE)
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        eg_output = generator(eg_input)
finally:
    generator.train(was_training)
print('Generator: {} -> {}'.format(eg_input.shape, eg_output.shape))

Generator with 50054 parameters:
Generator(
  (input_transform): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (resampler): WrapWithResampler(
    (resample_path): Sequential(
      (0): GeneratorBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
          (5): Conv2d(16, 16, kernel_size=(2, 2), stride=(2, 2))
        )
        (skip_connection): Sequential(
          (0): Conv2d(8, 16, kernel_size=(2, 2), stride=(2, 2))
        )
      )
      (1): WrapWithResampler(
        (resample_path): Sequential(
          (0): GeneratorBlock(
            (residual_connection): Sequential(
              (0): BatchNorm2d(16, eps=1e-05, momentum=0