# Deep Convolutional Generative Adversarial Networks (DCGAN) using PyTorch

Author: [Markus Enzweiler](https://markus-enzweiler.de), markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

**See `gan_torch_mnist.ipynb` for a more in-depth notebook on GANs.**

## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [None]:
# Notebook id: folder of this notebook relative to the repository root
nb_id = "gan/torch"

# Standard-library imports
import os
import sys

In [None]:
# Package Path (folder of this notebook)

#####################
# Local environment #
#####################

package_path = "./"


#########
# Colab #
#########


def check_for_colab():
    """Return True if ``google.colab`` can be imported, False otherwise.

    The module is expected to be importable only inside a Google Colab
    runtime, so this serves as a "running on Colab?" probe. Only
    ImportError is treated as "not on Colab"; any other error propagates.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    return True

# running on Colab?
on_colab = check_for_colab()

if on_colab:
    # Assume this notebook is run from Google Drive and that the whole
    # cv-ml-lecture-notebooks repo has been set up via setupOnColab.ipynb.
    from google.colab import drive

    # Google Drive mount point
    gdrive_mnt = '/content/drive'

    ##########################################################################
    # Ensure that this is the same as gdrive_repo_root in setupOnColab.ipynb #
    ##########################################################################
    # Path on Google Drive to the cv-ml-lecture-notebooks repo
    gdrive_repo_root = f"{gdrive_mnt}/MyDrive/cv-ml-lecture-notebooks"

    # Mount Drive (force_remount=True remounts even if already mounted), then
    # point the package path at this notebook's folder inside the repo.
    drive.mount(gdrive_mnt, force_remount=True)
    package_path = f"{gdrive_repo_root}/{nb_id}"

# Fail early with a clear message if the package path is wrong.
if not os.path.isdir(package_path):
    raise FileNotFoundError(f"Package path does not exist: {package_path}")

print(f"Package path: {package_path}")


In [None]:
# Additional imports

# Repository root: two directory levels above this notebook's folder
# (nb_id has two path components), made absolute for a stable sys.path entry.
repo_root = os.path.abspath(os.path.join(package_path, os.pardir, os.pardir))

# Make the repository root importable (for the nbutils package below) without
# adding duplicate entries when this cell is re-run.
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Package Imports
from nbutils import requirements as nb_reqs
from nbutils import colab as nb_colab
from nbutils import git as nb_git
from nbutils import exec as nb_exec
from nbutils import data as nb_data

In [None]:
# Clone git repository

# Absolute path of the directory the torch-gan repository is cloned into.
# package_path may be relative (e.g. "./" locally), so abspath is applied to
# keep the path valid even if the working directory changes later on.
repo_dir = os.path.abspath(os.path.join(package_path, "torch-gan"))
repo_url = "https://github.com/menzHSE/torch-gan.git"

nb_git.clone(repo_url, repo_dir, on_colab)

In [None]:
# Install requirements into the current Jupyter kernel: first those of the
# cloned torch-gan repository, then the additional ones of this notebook.
for req_file in (
    os.path.join(repo_dir, "requirements.txt"),
    os.path.join(package_path, "requirements.txt"),
):
    nb_reqs.pip_install_reqs(req_file, on_colab)

# Sample CelebA faces

If the dataset cannot be downloaded automatically by PyTorch because the **daily quota has been exceeded**, you can download it manually and put it in the `data/celeba` folder (see the code cell below).

The following files are necessary:
- img_align_celeba.zip
- list_attr_celeba.txt
- list_bbox_celeba.txt
- list_eval_partition.txt
- list_landmarks_align_celeba.txt
- list_landmarks_celeba.txt

In [None]:
# To get around the problem of "Quota Exceeded" on the "official" CelebA
# download via torchvision (https://github.com/pytorch/vision/issues/1920),
# we use an alternative data source.

# Set this to True to download the dataset from an alternative source
use_alternative_data_source = True

if use_alternative_data_source:
    # nb_data.download_celeba (project helper) receives the target folder,
    # the Colab flag and a verbosity switch.
    print("Downloading CelebA ... this will take a while")
    nb_data.download_celeba("./data/celeba", on_colab, verbose=False)


## Load Model and Data

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

# random seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Make the cloned torch-gan repository (providing the model, dataset and device
# modules imported below) importable. It is inserted at the front of sys.path
# so these generically named modules take precedence over any identically
# named packages installed in site-packages, and only once so that re-running
# this cell does not accumulate duplicate entries.
torch_gan_dir = os.path.join(package_path, 'torch-gan')
if torch_gan_dir not in sys.path:
    sys.path.insert(0, torch_gan_dir)


# Now we can import the model and dataset
import model
import dataset
import device

# parameters
dataset_id       = "celeb-a"
num_latent_dims  = 100
max_num_filters  = 512
img_size         = (64, 64)
batch_size       = 32
model_id         = f"G_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
gen_fname        = os.path.join(torch_gan_dir, "pretrained", dataset_id, model_id)
dev              = device.autoselectDevice()

# Fail early (before the potentially slow dataset loading) with a clear
# message if the pretrained generator weights are missing.
if not os.path.isfile(gen_fname):
    raise FileNotFoundError(f"Pretrained generator not found: {gen_fname}")

# load dataset
celeba_train_loader, celeba_test_loader, _, celeba_num_img_channels = dataset.get_loaders(
    dataset_id, img_size=img_size, batch_size=batch_size)

# load the generator
G = model.Generator(num_latent_dims, celeba_num_img_channels, max_num_filters, dev)
# NOTE(review): G.load is assumed to raise on failure -- TODO confirm. The
# generator object itself is always truthy, so it cannot serve as a success
# check; reaching the next line is the success signal.
G.load(gen_fname, dev)

print(f"Model {gen_fname} loaded successfully!")
print(f"Device used: {dev}")
G.to(dev)
# Inference mode (e.g. BatchNorm uses its running statistics).
G.eval()



## Show some Training Images

In [None]:
def normalizeForDisplay(images):
    """Linearly map values from [-1, 1] to [0, 1] for display.

    Works element-wise on anything supporting ``+`` and ``/`` with floats,
    such as torch tensors or numpy arrays. Values outside [-1, 1] are mapped
    by the same linear rule and are not clipped.
    """
    shifted = images + 1.0
    return shifted / 2.0

In [None]:
import matplotlib.pyplot as plt

# Get one batch of images from the training set and display it as a grid
# created with torchvision.utils.make_grid.
images, labels = next(iter(celeba_train_loader))

# Normalize before building the grid so the padding between tiles stays black
# (make_grid pads with 0 by default). max(1, ...) guards against nrow == 0
# for very small batch sizes.
grid_img = torchvision.utils.make_grid(normalizeForDisplay(images), nrow=max(1, batch_size // 4))

# make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C) numpy
# data. Convert explicitly instead of relying on np.transpose's implicit
# tensor conversion, and clip to [0, 1], the valid range for float RGB data.
grid_np = grid_img.permute(1, 2, 0).cpu().numpy().clip(0.0, 1.0)

plt.figure(figsize=(10,10))
plt.imshow(grid_np)
plt.axis('off')
plt.title('Batch from the Training Set')
plt.show()

## Generate CelebA-like Samples

In [None]:
import utils

def sampleAndPlot(G, num_latent_dims, num_samples=batch_size):
    """Generate images from random latent vectors with G and show them as a grid.

    Args:
        G: Generator, called as ``G(z)`` on a batch of latent vectors; its
            output is assumed to lie in [-1, 1] (TODO confirm).
        num_latent_dims: Dimensionality of the latent space.
        num_samples: Number of images to generate (defaults to the notebook's
            ``batch_size``). Must be at least 1.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    with torch.no_grad():
        # Draw all latent vectors at once (the first argument of
        # utils.sample_latent_vectors is the number of vectors) and run one
        # batched forward pass instead of one pass per sample followed by
        # repeated torch.cat, which re-copies the growing tensor every time.
        # NOTE(review): with G in eval mode, per-sample outputs do not depend
        # on the batch size (BatchNorm uses running statistics).
        z = utils.sample_latent_vectors(num_samples, num_latent_dims, dev)
        pics = G(z)

        # Map to [0, 1] before building the grid so the padding between tiles
        # stays black, like in the training-set grid. The grid width follows
        # num_samples rather than the global batch size.
        grid_img = torchvision.utils.make_grid(normalizeForDisplay(pics), nrow=max(1, num_samples // 4))

        # (C, H, W) tensor -> (H, W, C) numpy array for matplotlib
        grid_np = grid_img.permute(1, 2, 0).cpu().numpy()

    # Plotting; clip to [0, 1], the valid range for float RGB data in imshow
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_np.clip(0.0, 1.0))
    plt.axis('off')
    plt.title(f'Randomly Generated Images from the Generator with {num_latent_dims} Latent Dimensions')
    plt.show()

In [None]:
sampleAndPlot(G, num_latent_dims=num_latent_dims)