In [1]:
import numpy as np
import torch
from models.resnet import *
def param_count(model: torch.nn.Module) -> int:
    """Count the trainable scalar parameters of ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen
    parameters are skipped. Returns 0 for a module with no parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Fix the RNG seeds so the example batch below, and the random weight
# initialisation of every model built in later cells, are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
# Dummy batch used only for shape checks: 16 single-channel 28x28 images,
# matching the (1, 28, 28) input shape the models below are built with.
eg_input = torch.randn(16, 1, 28, 28)

In [2]:
# Classifier: build it, show its architecture and trainable-parameter count,
# then push the example batch through it to check the output shape.
classifier = Classifier((1, 28, 28))
print(classifier)
print('Number of parameters: {}'.format(param_count(classifier)))
# Shape check only, so no autograd graph is needed.
# NOTE(review): no .eval() call is visible, so the model is presumably in train
# mode and its BatchNorm layers update their running stats from this random batch.
with torch.no_grad():
    eg_output = classifier(eg_input)
print('Classifier: {} -> {}'.format(eg_input.shape, eg_output.shape))

Classifier(
  (feature_extractor): FeatureExtractor(
    (input_transform): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (model): Sequential(
      (0): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
          (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (skip_connection): Sequential(
          (0): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2))
        )
      )
      (1): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): 

In [3]:
# Discriminator: build it, show its architecture and trainable-parameter count,
# then check the shapes produced by its feature extractor and both heads
# (realism and label classification), which consume the extracted features.
discriminator = Discriminator((1, 28, 28))
print(discriminator)
print('Number of parameters: {}'.format(param_count(discriminator)))
# Shape checks only, so no autograd graph is needed.
with torch.no_grad():
    eg_features = discriminator.get_features(eg_input)
    eg_realism_classification = discriminator.classify_realism(eg_features)
    eg_label_classification = discriminator.classify_label(eg_features)
print('Discriminator feature extractor: {} -> {}'.format(eg_input.shape, eg_features.shape))
print('Discriminator realism classifier: {} -> {}'.format(eg_features.shape, eg_realism_classification.shape))
print('Discriminator label classifier: {} -> {}'.format(eg_features.shape, eg_label_classification.shape))

Discriminator(
  (feature_extractor): FeatureExtractor(
    (input_transform): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (model): Sequential(
      (0): ResidualBlock(
        (residual_connection): Sequential(
          (0): LeakyReLU(negative_slope=0.1)
          (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): LeakyReLU(negative_slope=0.1)
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (skip_connection): Sequential(
          (0): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2))
        )
      )
      (1): ResidualBlock(
        (residual_connection): Sequential(
          (0): LeakyReLU(negative_slope=0.1)
          (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): LeakyReLU(negative_slope=0.1)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (skip_connection): Sequential(
        

In [4]:
# Generator: build it, show its architecture and trainable-parameter count,
# then map a random latent batch of size (16, generator.latent_features)
# to an output batch and check its shape.
generator = Generator((1, 28, 28))
print(generator)
print('Number of parameters: {}'.format(param_count(generator)))
eg_latent = torch.randn(16, generator.latent_features)
# Shape check only, so no autograd graph is needed. Note that this rebinds
# eg_output, which previously held the classifier's output.
# NOTE(review): no .eval() call is visible, so the model is presumably in train
# mode and its BatchNorm layers update their running stats from this random batch.
with torch.no_grad():
    eg_output = generator(eg_latent)
print('Generator: {} -> {}'.format(eg_latent.shape, eg_output.shape))

Generator(
  (feature_reconstructor): FeatureReconstructor(
    (input_transform): ConvTranspose2d(64, 64, kernel_size=(7, 7), stride=(1, 1))
    (model): Sequential(
      (0): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
          (5): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
        )
        (skip_connection): Sequential(
          (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
        )
      )
      (1): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2

In [5]:
# Autoencoder: build it, show its architecture and trainable-parameter count,
# then check the shapes produced by its feature extractor, its feature
# reconstructor, and its label classifier (the latter two consume the features).
autoencoder = Autoencoder((1, 28, 28))
print(autoencoder)
print('Number of parameters: {}'.format(param_count(autoencoder)))
# Shape checks only, so no autograd graph is needed. This rebinds eg_features
# and eg_label_classification from the discriminator cell.
# NOTE(review): the discriminator cell calls classify_label (singular) while
# this calls classify_labels -- confirm which name Autoencoder actually defines.
with torch.no_grad():
    eg_features = autoencoder.get_features(eg_input)
    eg_reconstruction = autoencoder.reconstruct_features(eg_features)
    eg_label_classification = autoencoder.classify_labels(eg_features)
print('Autoencoder feature extractor: {} -> {}'.format(eg_input.shape, eg_features.shape))
print('Autoencoder feature reconstructor: {} -> {}'.format(eg_features.shape, eg_reconstruction.shape))
print('Autoencoder label classifier: {} -> {}'.format(eg_features.shape, eg_label_classification.shape))

Autoencoder(
  (feature_extractor): FeatureExtractor(
    (input_transform): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (model): Sequential(
      (0): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
          (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (skip_connection): Sequential(
          (0): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2))
        )
      )
      (1): ResidualBlock(
        (residual_connection): Sequential(
          (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2):