In [1]:
import sys
sys.path.append('..')
import torch
from simple_gan_1d import LeNet5Generator1d, LeNet5Discriminator1d
from common import get_param_count

In [2]:
# Configuration for the 1-D LeNet5 generator smoke test.
batch_size = 16
output_classes = 256    # labels are sampled from [0, output_classes)
latent_dims = 64
label_dims = 64
feature_dims = 64
signal_length = 20000   # samples per 1-D input/output signal
latent_shape = (batch_size, latent_dims)
label_shape = (batch_size,)
output_shape = (batch_size, 1, signal_length)

gen = LeNet5Generator1d(
    latent_dims=latent_dims,
    label_dims=label_dims,
    output_shape=output_shape,
    feature_dims=feature_dims,
    output_classes=output_classes
)

eg_latent = torch.randn(latent_shape)
eg_label = torch.randint(0, output_classes, label_shape)
eg_image = torch.randn(output_shape)
# This forward pass only checks shapes. no_grad avoids building (and, through the
# module-level `eg_output`, keeping alive) an autograd graph over the full-length
# activations.
with torch.no_grad():
    eg_output = gen(eg_latent, eg_label, eg_image)
# The generator is configured with `output_shape` and should reproduce it exactly.
assert eg_output.shape == output_shape, f'unexpected generator output shape {tuple(eg_output.shape)}'
print('Gen:', eg_latent.shape, 'x', eg_label.shape, 'x', eg_image.shape, '->', eg_output.shape)
print('Generator parameters:', get_param_count(gen))
print(gen)

Gen: torch.Size([16, 64]) x torch.Size([16]) x torch.Size([16, 1, 20000]) -> torch.Size([16, 1, 20000])
Generator parameters: 636265
LeNet5Generator1d(
  (feature_encoder): Sequential(
    (0): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (1): Conv1d(1, 16, kernel_size=(5,), stride=(5,), padding=(2,), bias=False)
    (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv1d(16, 32, kernel_size=(5,), stride=(5,), padding=(2,), bias=False)
    (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Conv1d(32, 64, kernel_size=(5,), stride=(5,), padding=(2,))
  )
  (latent_encoder): Sequential(
    (0): Conv1d(64, 512, kernel_size=(1,), stride=(1,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(512, 64, kernel_size=(1,), stride=(1,))
    (4): ConvTranspose1d(

In [3]:
disc = LeNet5Discriminator1d(output_shape)

eg_input = torch.randn(output_shape)
# Shape-only smoke test: skip autograd bookkeeping. The result goes into its own name
# so the generator's `eg_output` from the previous cell is not overwritten.
with torch.no_grad():
    disc_output = disc(eg_input)
# One prediction row per input example.
assert disc_output.shape[0] == output_shape[0], 'discriminator must preserve the batch dimension'
# Report the measured input shape (as a plain tuple), not the configured constant.
print('Disc:', tuple(eg_input.shape), '->', disc_output.shape)
print('Discriminator param count:', get_param_count(disc))
print(disc)

Disc: (16, 1, 20000) -> torch.Size([16, 256])
Discriminator param count: 379526
LeNet5Discriminator1d(
  (output_transform): Identity()
  (feature_encoder): Sequential(
    (0): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (1): Conv1d(1, 2, kernel_size=(5,), stride=(3,), bias=False)
    (2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv1d(2, 4, kernel_size=(5,), stride=(3,), bias=False)
    (5): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.2)
    (7): Conv1d(4, 8, kernel_size=(5,), stride=(3,))
    (8): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): LeakyReLU(negative_slope=0.2)
    (10): Conv1d(8, 16, kernel_size=(5,), stride=(3,))
  )
  (mlp_probe): Sequential(
    (0): Linear(in_features=960, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05

import sys
sys.path.append('..')
import torch
from resnet1d_gan import ResNet1dDiscriminator, ResNet1dGenerator
from common import get_param_count

input_shape = (64, 1, 20000)
filters = 8  # second constructor argument of ResNet1dDiscriminator, named so it can be tuned here
disc = ResNet1dDiscriminator(input_shape, filters)
print(disc)

eg_input = torch.randn(input_shape)
# Shape-only forward pass. no_grad keeps the module-level `eg_output` from holding an
# autograd graph for the whole 64 x 1 x 20000 batch.
with torch.no_grad():
    eg_output = disc(eg_input)
print('Input shape:', eg_input.shape, '-> Output shape:', eg_output.shape)
print('Param count:', get_param_count(disc))

label_dims = 256
latent_dims = 100
output_shape = (64, 1, 20000)
# Exclusive upper bound for the example labels sampled below.
# NOTE(review): presumably the same quantity as label_dims for this generator — confirm.
num_label_classes = 256

gen = ResNet1dGenerator(latent_dims, label_dims, output_shape)
print(gen)

# Size every example input from output_shape so batch size and trace length stay in
# sync with the configuration above.
eg_batch = output_shape[0]
eg_label = torch.randint(0, num_label_classes, (eg_batch,))
eg_latent = torch.randn(eg_batch, latent_dims)
eg_trace = torch.randn(output_shape)
# Shape-only forward pass; no autograd graph is needed.
with torch.no_grad():
    eg_output = gen(eg_latent, eg_label, eg_trace)
print('Label shape:', f'{eg_label.shape!r},', 'Latent shape:', eg_latent.shape, '-> Output shape:', eg_output.shape)
print('Param count:', get_param_count(gen))

import numpy as np
import torch
import sys
sys.path.append('..')
from lenet_gan import LeNet5Gen, LeNet5Disc
from common import get_param_count

# 2-D LeNet5 discriminator on a batch of 64 single-channel 28x28 inputs.
input_shape = (64, 1, 28, 28)
disc = LeNet5Disc(input_shape)
print(disc)
print('Disc param count:', get_param_count(disc))

# Push one random batch through to show how the discriminator maps shapes.
eg_input = torch.randn(input_shape)
eg_output = disc(eg_input)
print(f'Map dimensions: ({eg_input.shape}) -> ({eg_output.shape})')

latent_dims = 100
label_dims = 10
feature_dims = 0  # presumably disables the auxiliary feature input: gen() is called below without one
output_shape = (64, 1, 28, 28)
# Example labels are sampled from [0, num_classes).
# NOTE(review): presumably equals label_dims for this generator — confirm.
num_classes = 10
gen = LeNet5Gen(latent_dims, label_dims, feature_dims, output_shape)
print(gen)
print('Gen param count:', get_param_count(gen))

# Derive the example batch size from output_shape instead of repeating the literal 64.
eg_batch = output_shape[0]
eg_latent = torch.randn((eg_batch, latent_dims))
eg_label = torch.randint(0, num_classes, (eg_batch,))
# Shape-only forward pass; no autograd graph is needed.
with torch.no_grad():
    eg_output = gen(eg_latent, eg_label)
print('Map dimensions: ({}) x ({}) -> ({})'.format(eg_latent.shape, eg_label.shape, eg_output.shape))