<a href="https://colab.research.google.com/github/jahelsantiago/MINST-GANs/blob/main/Dc_MinstGan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn

## GENERATOR

In [20]:
class Generator(nn.Module):
    '''
    Generator Class
    Maps a batch of noise vectors to images through a stack of transposed
    convolutions. With the default kernel/stride choices below, the output
    is 28x28 (the MNIST image size).
    Values:
        z_dim (int): the dimension of the noise vector.
        img_chanel (int): the number of channels in the generated images
              (MNIST is black-and-white, so 1 channel is the default)
        hidden_dim (int): the inner dimension.
    '''

    def __init__(self, z_dim=10, img_chanel=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Without padding a transposed convolution yields
        #   out = (in - 1) * stride + kernel
        # so starting from a 1x1 input the spatial size grows as
        #   1 -> 3 -> 6 -> 13 -> 28.
        # Using kernel=3, stride=2 in every block would instead give
        #   1 -> 3 -> 7 -> 15 -> 31, which does not match 28x28 real images.
        self.gen = nn.Sequential(
            self.get_gen_blok(z_dim, hidden_dim * 4),
            self.get_gen_blok(hidden_dim * 4, hidden_dim * 2, kernel=4, stride=1),
            self.get_gen_blok(hidden_dim * 2, hidden_dim),
            self.get_gen_blok(hidden_dim, img_chanel, kernel=4, final_layer=True),
        )

    def get_gen_blok(self, input_dim, output_dim, kernel=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN:
        a transposed convolution, a batchnorm (except for in the last layer), and an activation.
        Parameters:
            input_dim: how many channels the input feature representation has
            output_dim: how many channels the output feature representation should have
            kernel: the size of each convolutional filter, equivalent to (kernel, kernel)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                      (affects activation and batchnorm)
        Returns:
            an nn.Sequential holding the layers of the block
        '''
        upsample = nn.ConvTranspose2d(input_dim, output_dim, kernel, stride)
        if final_layer:
            # Final layer: no batchnorm, tanh squashes the output into [-1, 1]
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
        )

    def unsqueeze_noise(self, noise):
        '''
        Function for reshaping the noise before the forward pass: Given a noise tensor,
        returns a view of that noise with width and height = 1 and channels = z_dim.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor,
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        Returns:
            a tensor with dimensions (n_samples, img_chanel, 28, 28) and values in [-1, 1]
        '''
        x = self.unsqueeze_noise(noise)
        return self.gen(x)


def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Draw a batch of noise vectors from the standard normal distribution.
    Parameters:
        n_samples: how many noise vectors to draw, a scalar
        z_dim: the length of each noise vector, a scalar
        device: the device the tensor is created on
    Returns:
        a tensor with dimensions (n_samples, z_dim)
    '''
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)


## DISCRIMINATOR

In [27]:
class Discriminator(nn.Module):
    def __init__(self, im_chanel=1, hidden_dim=16):
        '''
        Discriminator Class
        Values:
            im_chanel: the number of channels in the images, fitted for the dataset used, a scalar
                  (MNIST is black-and-white, so 1 channel is the default)
            hidden_dim: the inner dimension, a scalar
        '''
        # The docstring must be the first statement of the function to be
        # registered as __doc__, so it comes before the super().__init__() call.
        super().__init__()

        self.disc = nn.Sequential(
            self.get_disc_block(im_chanel, hidden_dim),
            self.get_disc_block(hidden_dim, hidden_dim * 2),
            self.get_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def get_disc_block(self, inpud_dim, output_dim, kernel=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of DCGAN:
        a convolution, a batchnorm (except for in the last layer), and an activation
        (except for in the last layer, which returns raw scores).
        Parameters:
          inpud_dim: how many channels the input feature representation has
          output_dim: how many channels the output feature representation should have
          kernel: the size of each convolutional filter, equivalent to (kernel, kernel)
          stride: the stride of the convolution
          final_layer: a boolean, true if it is the final layer and false otherwise
                    (affects activation and batchnorm)
        Returns:
          an nn.Sequential holding the layers of the block
        '''
        conv = nn.Conv2d(inpud_dim, output_dim, kernel, stride)
        if final_layer:
            # No batchnorm and no activation: the block outputs unnormalized scores
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_dim),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given a batch of images,
        returns one row of raw (pre-sigmoid) fake/real scores per image.
        Parameters:
            image: an image tensor with dimensions (n_samples, im_chanel, height, width)
        Returns:
            a tensor with dimensions (n_samples, k), where k is the number of
            values left after the last convolution (k = 1 for 28x28 inputs
            with the default kernel and stride)
        '''
        disc_pred = self.disc(image)
        # Flatten everything except the batch dimension
        return disc_pred.view(len(disc_pred), -1)

# Training

In [25]:
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

# Loss applied to the discriminator's raw outputs (its final block has no
# sigmoid, so the "with logits" variant of binary cross-entropy is used).
criterion = nn.BCEWithLogitsLoss()
# Dimension of the noise vector fed to the generator
z_dim = 64
# Number of training steps between loss reports / image previews
display_step = 500
# Number of images per batch served by the dataloader
batch_size = 128
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

# These parameters are passed to Adam as its `betas` and control the optimizer's
# momentum, which you can read more about here: https://distill.pub/2017/momentum/
beta_1 = 0.5 
beta_2 = 0.999
# Device on which the models, the noise and the batches are placed
device = 'cpu'

# Transform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST is downloaded into the current directory if it is not already there;
# batches are reshuffled every epoch.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [30]:
# Build both networks on the target device, each with its own Adam optimizer
# sharing the same learning rate and betas. The discriminator keeps its
# default arguments (1 image channel, hidden_dim = 16).
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Weights are drawn from a normal distribution
# with mean 0 and standard deviation 0.02
def weights_init(m):
    '''
    In-place initializer for a single layer.
    Convolution and transposed-convolution layers get weights drawn from
    N(0, 0.02); their biases are left untouched. Batchnorm layers get
    weights drawn from N(0, 0.02) and biases set to 0. Any other layer type
    is left unchanged.
    Parameters:
        m: the layer (nn.Module) to initialize
    '''
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, 0)
# Run weights_init over both networks; the initializer modifies layers in
# place, so the optimizers created from these parameters stay valid.
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: plots the first `num_images` images of a
    batch in a grid with 5 images per row.
    Parameters:
        image_tensor: a batch of images; values are shifted with (x + 1) / 2
              before plotting (assumes inputs lie in [-1, 1] - confirm against caller)
        num_images: how many images from the start of the batch to show
        size: accepted for compatibility; not used by the function body
    '''
    rescaled = (image_tensor + 1) / 2
    on_cpu = rescaled.detach().cpu()
    first_images = on_cpu[:num_images]
    grid = make_grid(first_images, nrow=5)
    # Move the first axis to the end before handing the grid to imshow
    grid_for_plot = grid.permute(1, 2, 0).squeeze()
    plt.imshow(grid_for_plot)
    plt.show()

In [31]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!


n_epochs = 50
cur_step = 0
# Running means over the current window of `display_step` batches
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
    # Dataloader returns the batches (labels are not needed)
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        # detach() keeps the discriminator loss from backpropagating into the generator
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Update gradients. The graph is not reused afterwards (the generator
        # step below runs a fresh forward pass), so it does not need to be retained.
        disc_loss.backward()
        # Update optimizer
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        # The generator is rewarded when the discriminator labels its images as real
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        # Count the batch before the check so that every reported mean covers
        # exactly `display_step` batches (the sums above are divided by display_step).
        cur_step += 1

        ## Visualization code ##
        if cur_step % display_step == 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

KeyboardInterrupt: ignored