Derived from the notebook found here: https://github.com/Sachin-Wani/deeplearning.ai-GANs-Specialization/blob/master/Course%201%20-%20Build%20Basic%20Generative%20Adversarial%20Networks%20(GANs)/Week%201/C1W1_Your_First_GAN.ipynb

In [17]:
import matplotlib.pyplot as plt

import torch
from torch.nn import (
    BatchNorm1d, Linear, Module, ReLU, Sequential, Sigmoid)
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [3]:
# Seed torch's global RNG so weight init and torch.randn draws in later cells
# are reproducible; trailing ';' hides the returned Generator repr.
torch.manual_seed(0);

<torch._C.Generator at 0x7fbb9e092250>

In [6]:
def plot_tensor_imgs(img_tensor, n_img=25, size=(1, 28, 28), nrow=5):
    '''
    Show the first ``n_img`` images of a batch as a single grid image.
    Args:
        img_tensor: tensor holding a batch of images; it is reshaped to
          (-1, *size), so its element count must be a multiple of prod(size)
        n_img (int): how many images from the start of the batch to draw
        size (tuple): per-image shape, (channels, height, width)
        nrow (int): number of images per grid row
    '''
    # detach + cpu so tensors that require grad or live on an accelerator can be plotted
    img = img_tensor.detach().cpu().view(-1, *size)
    img_grid = make_grid(img[:n_img], nrow=nrow)
    # make_grid returns (C, H, W); imshow expects channels last
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())

### Generator

In [12]:
def get_generator_block(in_dim, out_dim):
    '''
    Build one hidden block of the generator: an affine projection followed
    by batch normalisation and a ReLU activation.
    Args:
        in_dim (int): size of each input vector
        out_dim (int): size of each output vector
    Returns:
        a Sequential of (Linear, BatchNorm1d, ReLU)
    '''
    layers = [
        Linear(in_dim, out_dim),
        BatchNorm1d(out_dim),
        # inplace ReLU overwrites the BatchNorm output instead of allocating a copy
        ReLU(inplace=True),
    ]
    return Sequential(*layers)

In [13]:
def test_gen_block(in_features, out_features, num_test=1000):
    '''
    Sanity-check ``get_generator_block``: layer order, output shape, and the
    spread of its outputs on standard-normal input.
    Args:
        in_features (int): input dim passed to the block
        out_features (int): output dim passed to the block
        num_test (int): number of random input rows pushed through the block
    Raises:
        AssertionError: if any check fails
    '''
    block = get_generator_block(in_features, out_features)
    assert len(block) == 3, 'block has wrong len'
    # Use the layer classes imported from torch.nn directly; the bare name
    # ``nn`` is not imported in this notebook.
    assert isinstance(block[0], Linear), 'block[0] not Linear'
    assert isinstance(block[1], BatchNorm1d), 'block[1] not BatchNorm'
    assert isinstance(block[2], ReLU), 'block[2] not ReLU'
    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)
    expected_shape = (num_test, out_features)
    assert tuple(test_output.shape) == expected_shape, \
        (f'output_shape = {tuple(test_output.shape)}, '
         f'expected {expected_shape}')
    # Assuming BatchNorm1d's default affine init (weight=1, bias=0), the
    # normalised values are ~N(0, 1) and ReLU of that has std ~0.58,
    # hence this window.
    output_std = test_output.std()
    assert output_std > 0.55, 'output sd <= 0.55'
    assert output_std < 0.65, 'output sd >= 0.65'

test_gen_block(in_features=25, out_features=12)
test_gen_block(in_features=15, out_features=28)
print("Success!")

Success!


In [16]:
# Scratch check: width of the generator's last hidden block with the
# default hidden_dim of 128 (8 * hidden_dim).
8 * 128

1024

In [22]:
class Generator(Module):
    '''
    Generator Class
    Maps a batch of noise vectors to a batch of flattened images.
    Vals:
        z_dim (int): dim of the noise vector
        img_dim (int): the dimension of the flattened images, fitted for the
          dataset used (Defaults to MNIST images: 28 x 28 = 784)
        hidden_dim (int): dim of the first inner layer; the following blocks
          widen it to 2x, 4x and 8x this value
    '''
    def __init__(self, z_dim=10, img_dim=28*28, hidden_dim=128):
        super().__init__()
        self.gen = Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, 2*hidden_dim),
            get_generator_block(2*hidden_dim, 4*hidden_dim),
            get_generator_block(4*hidden_dim, 8*hidden_dim),
            Linear(8*hidden_dim, img_dim),
            # Sigmoid squashes pixel values into (0, 1)
            Sigmoid())

    def forward(self, noise):
        '''
        Complete a forward pass of the generator: Given a noise tensor,
        returns generated images.
        Args:
            noise: a noise tensor with dims (n_samples, z_dim)
        Returns:
            the output of ``self.gen``: a tensor with dims (n_samples, img_dim)
        '''
        # Delegate to the Sequential stack built in __init__.
        return self.gen(noise)

    def get_gen(self):
        '''Return the underlying Sequential model (used by the test cells).'''
        return self.gen

In [23]:
def test_generator(z_dim, im_dim, hidden_dim, num_test=10000):
    '''
    Sanity-check the Generator's Sequential stack: module count, the final
    Linear/Sigmoid head, output shape, output range and output spread.
    Args:
        z_dim (int): noise dim passed to Generator
        im_dim (int): image dim passed to Generator
        hidden_dim (int): hidden dim passed to Generator
        num_test (int): number of random noise rows pushed through the model
    Raises:
        AssertionError: if any check fails
    '''
    gen = Generator(z_dim, im_dim, hidden_dim).get_gen()

    # Check there are six modules in the sequential part
    assert len(gen) == 6, f'expected 6 modules, got {len(gen)}'
    # Inspect attributes rather than comparing repr strings, which are not a
    # stable API across torch versions.
    head = gen[4]
    assert isinstance(head, Linear), 'gen[4] not Linear'
    assert head.in_features == hidden_dim * 8, \
        f'head in_features = {head.in_features}, expected {hidden_dim * 8}'
    assert head.out_features == im_dim, \
        f'head out_features = {head.out_features}, expected {im_dim}'
    assert head.bias is not None, 'head Linear should have a bias'
    assert isinstance(gen[5], Sigmoid), 'gen[5] not Sigmoid'

    test_input = torch.randn(num_test, z_dim)
    test_output = gen(test_input)

    # Check that the output shape is correct
    assert tuple(test_output.shape) == (num_test, im_dim), \
        f'output_shape = {tuple(test_output.shape)}, expected {(num_test, im_dim)}'
    assert test_output.max() < 1, "Make sure to use a sigmoid"
    assert test_output.min() > 0, "Make sure to use a sigmoid"
    output_std = test_output.std()
    assert output_std > 0.05, "Don't use batchnorm here"
    assert output_std < 0.15, "Don't use batchnorm here"

test_generator(z_dim=5, im_dim=10, hidden_dim=20)
test_generator(z_dim=20, im_dim=8, hidden_dim=24)
print("Success!")

Success!
