In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import transforms as T
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [2]:
def show_image_tensors(image_tensors, n_images=25, size=(1, 28, 28)):
  """Plot the leading images of a batch as a grid with 5 images per row.

  Args:
    image_tensors: batch of images (flattened or already shaped); it is
      detached, moved to CPU and reshaped to ``(-1, *size)`` before plotting.
    n_images: number of leading images to place in the grid.
    size: per-image ``(channels, height, width)`` used for the reshape.
  """
  cpu_batch = image_tensors.detach().cpu()
  shaped = cpu_batch.view(-1, *size)
  grid = make_grid(shaped[:n_images], nrow=5)
  # make_grid returns (C, H, W); matplotlib wants channels last.
  # NOTE(review): imshow clips float RGB data to [0, 1]; Tanh generator
  # outputs lie in [-1, 1] — consider rescaling before display.
  channels_last = grid.permute(1, 2, 0)
  plt.imshow(channels_last.squeeze())
  plt.show()

In [4]:
class Generator(nn.Module):
  """DCGAN-style generator mapping noise vectors to images.

  The fixed kernel/stride schedule grows the 1x1 spatial input to
  3x3 -> 6x6 -> 13x13 -> 28x28, so outputs have shape
  ``(batch, im_chan, 28, 28)`` with values in [-1, 1] (final ``Tanh``).

  Args:
    z_dim: length of each input noise vector.
    im_chan: number of channels in the generated image.
    hidden_dim: base channel width; inner blocks use 4x, 2x and 1x of it.
  """

  def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
    super(Generator, self).__init__()
    # unsqueeze_noise reshapes with self.z_dim, so it must be stored here.
    self.z_dim = z_dim
    self.gen = nn.Sequential(
        self.get_generator_block(z_dim, hidden_dim * 4),
        self.get_generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
        self.get_generator_block(hidden_dim * 2, hidden_dim),
        self.get_generator_block(hidden_dim, im_chan, kernel_size=4, final_layer=True)
    )

  def get_generator_block(self, in_channels, out_channels, kernel_size=3, stride=2, final_layer=False):
    """Build one upsampling stage of the generator.

    Args:
      in_channels: channels entering the transposed convolution.
      out_channels: channels produced by the transposed convolution.
      kernel_size: transposed-convolution kernel size.
      stride: transposed-convolution stride.
      final_layer: if True, follow the convolution with ``Tanh`` only;
        otherwise follow it with ``BatchNorm2d`` and ``ReLU``.

    Returns:
      An ``nn.Sequential`` implementing the stage.
    """
    if not final_layer:
      return nn.Sequential(
          nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True)
      )
    else:
      return nn.Sequential(
          nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
          nn.Tanh()
      )

  def unsqueeze_noise(self, noise):
    """Reshape ``(batch, z_dim)`` noise to ``(batch, z_dim, 1, 1)`` for the conv stack."""
    return noise.view(len(noise), self.z_dim, 1, 1)

  def forward(self, noise):
    """Generate a batch of images from a ``(batch, z_dim)`` noise tensor."""
    x = self.unsqueeze_noise(noise)
    return self.gen(x)

In [5]:
def get_noise(n_samples, z_dim, device='cpu'):
  """Draw a batch of standard-normal noise vectors.

  Args:
    n_samples: number of vectors (rows) to sample.
    z_dim: length of each vector.
    device: device on which the tensor is allocated.

  Returns:
    Tensor of shape ``(n_samples, z_dim)`` sampled from N(0, 1).
  """
  noise_shape = (n_samples, z_dim)
  return torch.randn(noise_shape, device=device)

In [6]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator returning raw (un-activated) scores per image.

  Args:
    im_chan: number of channels in the input images.
    hidden_dim: channel width of the first convolution block; the second
      block uses twice this width.
  """

  def __init__(self, im_chan=1, hidden_dim=16):
    super(Discriminator, self).__init__()
    # NOTE(review): every block uses the default stride=1, so a 28x28 input
    # shrinks only 28 -> 25 -> 22 -> 19 and forward() returns 19*19 = 361
    # scores per image. A classic DCGAN discriminator uses stride=2 here
    # (one score per image) — confirm which is intended.
    self.disc = nn.Sequential(
        self.get_discriminator_block(im_chan, hidden_dim),
        self.get_discriminator_block(hidden_dim, hidden_dim * 2),
        self.get_discriminator_block(hidden_dim * 2, 1, final_layer=True)
    )

  def get_discriminator_block(self, in_channels, out_channels, kernel_size=4, stride=1, final_layer=False):
    """Build one convolutional stage of the discriminator.

    Args:
      in_channels: channels entering the convolution.
      out_channels: channels produced by the convolution.
      kernel_size: convolution kernel size.
      stride: convolution stride.
      final_layer: if True, return the bare convolution (raw scores);
        otherwise follow it with ``BatchNorm2d`` and ``LeakyReLU(0.2)``.

    Returns:
      An ``nn.Sequential`` implementing the stage.
    """
    if not final_layer:
      return nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size, stride),
          nn.BatchNorm2d(out_channels),
          nn.LeakyReLU(0.2, inplace=True)
      )
    else:
      return nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size, stride)
      )

  def forward(self, image):
    """Score a batch of images.

    Args:
      image: tensor of shape ``(batch, im_chan, H, W)``.

    Returns:
      Tensor of shape ``(batch, k)`` where ``k`` is the flattened size of the
      final convolution's output for each image.
    """
    disc_pred = self.disc(image)
    return disc_pred.view(len(disc_pred), -1)