First, preprocess data and define first steps

In [9]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision 
from torchvision import transforms 


print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

# Always hold a torch.device (never a bare string on the CPU path) so that
# `device` has one consistent type wherever it is used later on.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_path = './'
# Normalize expects per-channel sequences. MNIST has a single channel, hence
# the one-element tuples; `(0.5)` without the trailing comma is just the
# float 0.5, not a tuple.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
mnist_dataset = torchvision.datasets.MNIST(root=image_path, 
                                           train=True, 
                                           transform=transform, 
                                           download=True)

batch_size = 64

# Seed both torch and numpy before the shuffling DataLoader is created.
torch.manual_seed(1)
np.random.seed(1)

## Set up the dataset
from torch.utils.data import DataLoader
# drop_last=True keeps every batch at exactly `batch_size` samples.
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size, 
                      shuffle=True, drop_last=True)



2.7.0+cpu
GPU Available: False


Now define the starting GAN, based on the sample code provided by the professor.

In [10]:
def make_generator_network(input_size, n_filters):
    """Build the baseline generator as an ``nn.Sequential``.

    Args:
        input_size: Number of channels of the latent input (the noise
            vector is fed as a ``(N, input_size, 1, 1)`` tensor).
        n_filters: Base channel width; hidden stages use 4x, 2x and 1x
            this value.

    Returns:
        nn.Sequential: three ConvTranspose2d/BatchNorm2d/LeakyReLU(0.2)
        stages followed by a ConvTranspose2d to one channel and a Tanh.
        For a 1x1 latent input the spatial size grows 4 -> 7 -> 14 -> 28.
    """
    # (in_channels, out_channels, kernel, stride, padding) of each hidden stage.
    hidden_specs = [
        (input_size, n_filters * 4, 4, 1, 0),
        (n_filters * 4, n_filters * 2, 3, 2, 1),
        (n_filters * 2, n_filters, 4, 2, 1),
    ]

    layers = []
    for c_in, c_out, kernel, stride, pad in hidden_specs:
        layers.append(
            nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False))
        layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.LeakyReLU(0.2))

    # Output stage: single channel squashed into [-1, 1] by Tanh.
    layers.append(nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)

class Discriminator(nn.Module):
    """Baseline convolutional discriminator for 1-channel 28x28 images.

    Three strided Conv2d stages reduce the spatial size 28 -> 14 -> 7 -> 4
    (BatchNorm2d on the second and third), then a 4x4 Conv2d collapses the
    map to a single value that a Sigmoid turns into a probability.

    Args:
        n_filters: Base channel width; stages use 1x, 2x and 4x this value.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # (1, 28, 28) -> (n_filters, 14, 14)
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            # (n_filters, 14, 14) -> (n_filters*2, 7, 7)
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            # (n_filters*2, 7, 7) -> (n_filters*4, 4, 4)
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            # (n_filters*4, 4, 4) -> (1, 1, 1)
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())
        
    def forward(self, input):
        """Return one probability per image as a ``(batch, 1)`` tensor."""
        output = self.network(input)
        # No squeeze(0) here: it would turn a batch of one into shape (1,),
        # which no longer matches (batch, 1) target labels.
        return output.view(-1, 1)

Remove one convolutional layer in both the generator and the discriminator.

In [11]:
class ReducedGenerator(nn.Module):
    """Generator with one transposed-convolution stage fewer than the baseline.

    Maps a ``(N, input_size, 1, 1)`` latent tensor to ``(N, 1, 28, 28)``
    images in [-1, 1] (Tanh output).

    Args:
        input_size: Number of channels of the latent input.
        n_filters: Base channel width; hidden stages use 4x and 1x this value.
    """

    def __init__(self, input_size, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # z -> (n_filters*4, 4, 4)
            nn.ConvTranspose2d(input_size, n_filters*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            # (n_filters*4, 4, 4) -> (n_filters, 14, 14)
            # out = (4 - 1) * 3 - 2 * 0 + 5 = 14. With padding=1 this stage
            # produced 12x12 and the final image came out 24x24, not 28x28.
            nn.ConvTranspose2d(n_filters*4, n_filters, 5, 3, 0, bias=False),
            nn.BatchNorm2d(n_filters),
            nn.LeakyReLU(0.2),

            # (n_filters, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from the latent batch ``x``."""
        return self.network(x)


class ReducedDiscriminator(nn.Module):
    """Discriminator with one convolution stage fewer than the baseline.

    Two strided Conv2d stages reduce a 1-channel 28x28 image to 7x7, then a
    7x7 Conv2d collapses it to one value passed through a Sigmoid.

    Args:
        n_filters: Base channel width; stages use 1x and 4x this value.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # (1, 28, 28) -> (n_filters, 14, 14)
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            # (n_filters, 14, 14) -> (n_filters*4, 7, 7)
            nn.Conv2d(n_filters, n_filters*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            # (n_filters*4, 7, 7) -> (1, 1, 1)
            nn.Conv2d(n_filters*4, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return one probability per image as a ``(batch, 1)`` tensor."""
        output = self.network(input)
        # No squeeze(0) here: it would turn a batch of one into shape (1,),
        # which no longer matches (batch, 1) target labels.
        return output.view(-1, 1)


Add one convolutional layer in both the generator and the discriminator.

In [12]:
class ExtendedGenerator(nn.Module):
    """Generator with one transposed-convolution stage more than the baseline.

    Maps a ``(N, input_size, 1, 1)`` latent tensor to ``(N, 1, 28, 28)``
    images in [-1, 1] (Tanh output). The spatial size grows
    2 -> 4 -> 7 -> 14 -> 28.

    Args:
        input_size: Number of channels of the latent input.
        n_filters: Base channel width; hidden stages use 8x, 4x, 2x and 1x
            this value.
    """

    def __init__(self, input_size, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # z -> (n_filters*8, 2, 2)
            nn.ConvTranspose2d(input_size, n_filters * 8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(n_filters * 8),
            nn.LeakyReLU(0.2),

            # (n_filters*8, 2, 2) -> (n_filters*4, 4, 4)
            nn.ConvTranspose2d(n_filters * 8, n_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 4),
            nn.LeakyReLU(0.2),

            # (n_filters*4, 4, 4) -> (n_filters*2, 7, 7)
            # Kernel 3 is required here: (4 - 1) * 2 - 2 * 1 + 3 = 7.
            # Kernel 4 gave 8x8 and therefore 32x32 final images.
            nn.ConvTranspose2d(n_filters * 4, n_filters * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            # (n_filters*2, 7, 7) -> (n_filters, 14, 14)
            nn.ConvTranspose2d(n_filters * 2, n_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters),
            nn.LeakyReLU(0.2),

            # (n_filters, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from the latent batch ``x``."""
        return self.network(x)


class ExtendedDiscriminator(nn.Module):
    """Discriminator with one convolution stage more than the baseline.

    A stride-1 Conv2d is prepended to the baseline stack; the spatial size
    then shrinks 28 -> 14 -> 7 -> 4 before a 4x4 Conv2d and a Sigmoid
    produce one probability per image.

    Args:
        n_filters: Base channel width; stages use n_filters // 2, then 1x,
            2x and 4x this value.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # (1, 28, 28) -> (n_filters//2, 28, 28)  (extra layer)
            nn.Conv2d(1, n_filters // 2, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2),

            # (n_filters//2, 28, 28) -> (n_filters, 14, 14)
            nn.Conv2d(n_filters // 2, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            # (n_filters, 14, 14) -> (n_filters*2, 7, 7)
            nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            # (n_filters*2, 7, 7) -> (n_filters*4, 4, 4)
            # Kernel 3 is required here: floor((7 + 2 - 3) / 2) + 1 = 4.
            # Kernel 4 gave 3x3, which the final 4x4 convolution cannot process.
            nn.Conv2d(n_filters * 2, n_filters * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 4),
            nn.LeakyReLU(0.2),

            # (n_filters*4, 4, 4) -> (1, 1, 1)
            nn.Conv2d(n_filters * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one probability per image as a ``(batch, 1)`` tensor."""
        # No squeeze(0) here: it would turn a batch of one into shape (1,),
        # which no longer matches (batch, 1) target labels.
        return self.network(x).view(-1, 1)

Functions for training

In [13]:
# Notebook-level configuration.
# NOTE: train_gan declares parameters named z_size, n_filters, mode_z,
# num_epochs and batch_size with these same default values, so editing the
# globals below does not change a train_gan call that relies on its defaults.
z_size = 100            # number of channels of the latent noise tensor
image_size = (28, 28)   # (height, width) that create_samples reshapes images to
n_filters = 32          # base channel width of the networks
batch_size = 64         # same value as the batch_size used to build mnist_dl
num_epochs = 10
mode_z = 'uniform'      # create_noise handles 'uniform' and 'normal'

def create_noise(batch_size, z_size, mode_z):
    """Sample a batch of latent noise tensors for the generator.

    Args:
        batch_size: Number of noise vectors to draw.
        z_size: Number of channels of each noise vector.
        mode_z: ``'uniform'`` for samples in [-1, 1) or ``'normal'`` for
            standard-normal samples.

    Returns:
        torch.Tensor: noise of shape ``(batch_size, z_size, 1, 1)``.

    Raises:
        ValueError: if ``mode_z`` is neither ``'uniform'`` nor ``'normal'``.
    """
    if mode_z == 'uniform':
        # torch.rand is in [0, 1); rescale to [-1, 1).
        return torch.rand(batch_size, z_size, 1, 1) * 2 - 1
    if mode_z == 'normal':
        return torch.randn(batch_size, z_size, 1, 1)
    # Fail with a clear message instead of an UnboundLocalError.
    raise ValueError(
        f"Unknown mode_z {mode_z!r}; expected 'uniform' or 'normal'")

def create_samples(g_model, input_z):
    """Run the generator and return its images rescaled for display.

    Args:
        g_model: Callable generator; applied to ``input_z``.
        input_z: Latent noise batch passed to ``g_model``.

    Returns:
        torch.Tensor: images of shape ``(batch, *image_size)``, mapped from
        the generator's output range via ``(x + 1) / 2``.
    """
    g_output = g_model(input_z)
    # Use the batch dimension of the generator output rather than the
    # module-level `batch_size`, so any number of noise vectors works.
    images = torch.reshape(g_output, (g_output.size(0), *image_size))
    return (images + 1) / 2.0


def train_gan(gen_class, disc_class, label, z_size=100, n_filters=32, 
              mode_z='uniform', num_epochs=10, batch_size=64):
    """Train one generator/discriminator pair on ``mnist_dl``.

    Args:
        gen_class: Callable ``(z_size, n_filters) -> generator module``.
        disc_class: Callable ``(n_filters) -> discriminator module``.
        label: Name printed in the progress messages.
        z_size: Number of channels of the latent noise.
        n_filters: Base channel width handed to both networks.
        mode_z: Noise distribution name forwarded to ``create_noise``.
        num_epochs: Number of passes over ``mnist_dl``.
        batch_size: Number of fixed noise vectors used for the per-epoch
            sample images (training batches come from ``mnist_dl``).

    Returns:
        list: one numpy array of generated samples per epoch, all produced
        from the same fixed noise batch.
    """
    print(f"Training {label} GAN")

    generator = gen_class(z_size, n_filters).to(device)
    discriminator = disc_class(n_filters).to(device)

    opt_g = torch.optim.Adam(generator.parameters(), 0.0003)
    opt_d = torch.optim.Adam(discriminator.parameters(), 0.0002)
    criterion = nn.BCELoss()
    # Fixed noise so samples from different epochs are directly comparable.
    fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

    def sample_latent(n):
        """Draw ``n`` fresh noise vectors on the training device."""
        return create_noise(n, z_size, mode_z).to(device)

    def d_train(x):
        """One discriminator step on real batch ``x`` plus a fake batch."""
        discriminator.zero_grad()
        n = x.size(0)
        real_images = x.to(device)

        real_targets = torch.ones(n, 1, device=device)
        loss_real = criterion(discriminator(real_images), real_targets)

        # detach() keeps discriminator gradients out of the generator.
        fake_images = generator(sample_latent(n)).detach()
        proba_fake = discriminator(fake_images)
        fake_targets = torch.zeros(n, 1, device=device)
        loss_fake = criterion(proba_fake, fake_targets)

        total = loss_real + loss_fake
        total.backward()
        opt_d.step()
        return total.item()

    def g_train(x):
        """One generator step; ``x`` only supplies the batch size."""
        generator.zero_grad()
        n = x.size(0)
        latent = sample_latent(n)
        # The generator is rewarded when fakes are classified as real.
        real_targets = torch.ones((n, 1), device=device)

        proba_fake = discriminator(generator(latent))
        loss = criterion(proba_fake, real_targets)

        loss.backward()
        opt_g.step()
        return loss.item()

    epoch_samples = []
    for epoch in range(1, num_epochs + 1):
        generator.train()
        g_losses, d_losses = [], []
        for x, _ in mnist_dl:
            # Discriminator first, then generator, on every batch.
            d_losses.append(d_train(x))
            g_losses.append(g_train(x))

        print(f"{label} Epoch {epoch:03d} | Avg Losses >> G: {np.mean(g_losses):.4f} / D: {np.mean(d_losses):.4f}")

        # Snapshot the generator in eval mode, without tracking gradients.
        generator.eval()
        with torch.no_grad():
            snapshot = create_samples(generator, fixed_z)
            epoch_samples.append(snapshot.cpu().numpy())
    return epoch_samples

Train the three models and compare the results: the architecture unchanged from the sample code, a variant with one convolutional layer removed, and a variant with one convolutional layer added.

In [14]:
samples_original = train_gan(make_generator_network, Discriminator, "Original")
samples_reduced  = train_gan(ReducedGenerator, ReducedDiscriminator, "Reduced")
samples_extended = train_gan(ExtendedGenerator, ExtendedDiscriminator, "Extended")

#visualization

model_titles = ["Original", "Reduced", "Extended"]
model_samples = [samples_original, samples_reduced, samples_extended]
images_per_model = 5
n_cols = images_per_model * len(model_samples)

# train_gan returns one sample batch per trained epoch, so only epochs that
# every run actually reached can be plotted. Without this filter, epoch 20
# raises IndexError whenever fewer than 20 epochs were trained.
n_trained_epochs = min(len(samples) for samples in model_samples)
selected_epochs = [e for e in [1, 2, 4, 10, 20] if e <= n_trained_epochs]
fig = plt.figure(figsize=(15, 10))

# Grid layout: one row per selected epoch, five images per architecture.
for i, e in enumerate(selected_epochs):
    for j, samples in enumerate(model_samples):
        for k in range(images_per_model):
            ax = fig.add_subplot(len(selected_epochs), n_cols,
                                 i*n_cols + j*images_per_model + k + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            if k == 0 and j == 0:
                ax.set_ylabel(f"Epoch {e}", fontsize=12)
            if i == 0 and k == 2:
                ax.set_title(model_titles[j], fontsize=12)
            # Epochs are 1-based, the sample list is 0-based.
            image = samples[e-1][k]
            ax.imshow(image, cmap='gray_r')

plt.tight_layout()
plt.suptitle("GAN Output Comparison by Architecture", fontsize=16, y=1.02)
plt.show()

Training Original GAN


KeyboardInterrupt: 