# Deep Convolutional Generative Adversarial Networks (DCGAN) using PyTorch

Author: [Markus Enzweiler](https://markus-enzweiler.de), markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

**See `gan_torch_mnist.ipynb` for a more in-depth notebook on GANs.**

## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [None]:
# Imports
import sys
import os
import subprocess
import numpy as np

In [None]:
# Package Path: the directory that contains this notebook
package_path = "./" # local
print(f"Package path: {package_path}")


def check_for_colab():
    """Report whether this notebook runs inside Google Colab.

    The check simply probes whether the ``google.colab`` module can be
    imported.

    Returns:
        True if ``google.colab`` is importable, False otherwise.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True


# Running on Colab?
on_colab = check_for_colab()


In [None]:
# Clone the git repository, or update an existing clone to the remote HEAD.

# Path of the repository directory (relative to the current working
# directory unless package_path is absolute)
repo_dir = os.path.join(package_path, "torch-gan")
repo_url = "https://github.com/menzHSE/torch-gan.git"

# Store the original working directory
original_cwd = os.getcwd()

# Check if the repository directory already exists
if os.path.exists(os.path.join(original_cwd, repo_dir)):
    print("Repository exists. Resetting to HEAD...")
    # Run git inside the repository via cwd= instead of os.chdir(), so the
    # notebook's working directory can never be left changed by a failure.
    for git_cmd in (["git", "fetch", "origin"],
                    ["git", "reset", "--hard", "origin/HEAD"]):
        result = subprocess.run(git_cmd, cwd=repo_dir)
        if result.returncode != 0:
            # Best effort (e.g. offline): keep working with the existing clone.
            print(f"Warning: '{' '.join(git_cmd)}' failed with exit code {result.returncode}")
else:
    print("Cloning repository...")
    # Later cells import modules from this repository, so fail loudly here
    # instead of with a confusing import error further down.
    subprocess.run(["git", "clone", repo_url, repo_dir], check=True)


In [None]:
# Install requirements in the current Jupyter kernel
req_file = os.path.join(repo_dir, "requirements.txt")
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")


# Additional requirements for this notebook
req_file = "requirements.txt"
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")
    

# Sample CelebA faces

If PyTorch cannot download the dataset automatically (e.g. because the **daily quota is exceeded**), you can download it manually and put it in the `data/celeba` folder; the code cell below does this from an alternative source.

The following files are necessary:
- img_align_celeba.zip
- list_attr_celeba.txt
- list_bbox_celeba.txt
- list_eval_partition.txt
- list_landmarks_align_celeba.txt
- list_landmarks_celeba.txt

In [None]:
import shutil
import urllib.request

# Set this to True to download the dataset from an alternative source
use_alternative_data_source = True

# Directory to store the CelebA files
dest_dir = "./data/celeba"


def _download_if_missing(url, dest_dir, timeout=60):
    """Download ``url`` into ``dest_dir`` unless the target file already exists.

    The data is streamed into a temporary ``<name>.part`` file, which is
    renamed to its final name only after the transfer has completed. An
    interrupted download therefore never leaves a truncated file under the
    final name, and it is retried on the next run.

    Args:
        url: URL of the file; its last path component becomes the file name.
        dest_dir: existing directory to store the file in.
        timeout: socket timeout in seconds for the connection and each read.

    Returns:
        Path of the downloaded (or already present) file.

    Raises:
        urllib.error.URLError / OSError: if the download or the write fails.
    """
    target = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
    if os.path.exists(target):
        print(f"{target} already exists, skipping")
        return target

    partial = target + ".part"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response, \
                open(partial, "wb") as out:
            shutil.copyfileobj(response, out)
    except BaseException:
        # Do not leave a half-written file behind (also on KeyboardInterrupt)
        if os.path.exists(partial):
            os.remove(partial)
        raise
    # Atomic rename: the final name only ever refers to a complete file
    os.replace(partial, target)
    return target


if use_alternative_data_source:

    # Base URL for data
    data_url = "https://graal.ift.ulaval.ca/public/celeba/"

    # Ensure the destination directory exists
    os.makedirs(dest_dir, exist_ok=True)

    # List of files to download
    celeba_files = [
        "img_align_celeba.zip",
        "list_attr_celeba.txt",
        "list_bbox_celeba.txt",
        "list_eval_partition.txt",
        "list_landmarks_align_celeba.txt",
        "list_landmarks_celeba.txt"
    ]

    # Download each file that is not present yet; errors propagate, so
    # "done" is only printed after a successful transfer.
    for celeba_file in celeba_files:
        file_url = data_url + celeba_file
        print(f"Downloading {file_url} ... ")
        _download_if_missing(file_url, dest_dir)
        print("done")

    print("Download complete.")


## Load Model and Data

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt

# random seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Add the directory containing model.py, dataset.py and device.py to the
# system path -- only once, so re-running this cell does not grow sys.path.
torch_gan_dir = os.path.join(package_path, "torch-gan")
if torch_gan_dir not in sys.path:
    sys.path.append(torch_gan_dir)


# Now we can import the model and dataset
import model
import dataset
import device

# parameters
dataset_id       = "celeb-a"
num_latent_dims  = 100
max_num_filters  = 512
img_size         = (64, 64)
batch_size       = 32
model_id         = f"G_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
gen_fname        = os.path.join(torch_gan_dir, "pretrained", dataset_id, model_id)
dev              = device.autoselectDevice()

# load dataset
celeba_train_loader, celeba_test_loader, _, celeba_num_img_channels = dataset.get_loaders(
    dataset_id, img_size=img_size, batch_size=batch_size)

# load the generator
G = model.Generator(num_latent_dims, celeba_num_img_channels, max_num_filters, dev)
# NOTE(review): assumes G.load() raises on failure -- confirm in model.py.
# The former `if G:` check did not verify the load: a module object without
# __bool__/__len__ is always truthy.
G.load(gen_fname, dev)

print(f"Model {gen_fname} loaded successfully!")
print(f"Device used: {dev}")
G.to(dev)
G.eval()  # inference mode for sampling



## Show some Training Images

In [None]:
def normalizeForDisplay(images):
    """Rescale values from [-1, 1] to [0, 1] so they can be shown with imshow.

    Works element-wise on anything that supports ``+`` and ``/`` with floats
    (e.g. torch tensors, numpy arrays, Python floats).
    """
    shifted = images + 1.0
    return shifted / 2.0

In [None]:
import matplotlib.pyplot as plt

# Get a batch of images from the training set and display them.
# The second element of the batch is not needed here.
images, _ = next(iter(celeba_train_loader))

# torchvision.utils.make_grid tiles the batch into a single channels-first image
grid_img = torchvision.utils.make_grid(normalizeForDisplay(images), nrow=batch_size//4)

plt.figure(figsize=(10,10))
# matplotlib expects channels last: (C, H, W) -> (H, W, C) as a numpy array
plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title('Batch from the Training Set')
plt.show()

## Generate CelebA-like Samples

In [None]:
import utils

def sampleAndPlot(G, num_latent_dims, num_samples=batch_size):
    """Generate ``num_samples`` images with generator ``G`` and show them in a grid.

    Args:
        G: the generator, called as ``G(z)`` on a batch of latent vectors.
        num_latent_dims: dimensionality of the latent space.
        num_samples: number of images to generate (default: ``batch_size``).

    The grid shows ``num_samples // 4`` images per row (at least one), i.e.
    about four rows; for the default this equals the training-batch layout.
    Nothing is plotted when ``num_samples < 1``.
    """
    if num_samples < 1:
        print("Nothing to plot: num_samples must be at least 1.")
        return

    with torch.no_grad():
        # Draw the latent vectors with the same sequence of calls as before
        # (one vector per call), then stack them along the batch dimension
        # so the generator runs a single forward pass instead of num_samples.
        # NOTE(review): assumes G is in eval mode (set when it was loaded),
        # so batching does not change the per-sample outputs.
        z = torch.cat(
            [utils.sample_latent_vectors(1, num_latent_dims, dev) for _ in range(num_samples)],
            dim=0)
        pics = G(z)

    # Normalize before tiling (as for the training batch) so the grid padding
    # stays black instead of being shifted to gray by the normalization.
    grid_img = torchvision.utils.make_grid(
        normalizeForDisplay(pics), nrow=max(1, num_samples // 4))

    # (C, H, W) tensor -> (H, W, C) numpy array for matplotlib
    grid_np = grid_img.permute(1, 2, 0).cpu().numpy()

    # Plotting
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_np)
    plt.axis('off')
    plt.title(f'Randomly Generated Images from the Generator with {num_latent_dims} Latent Dimensions')
    plt.show()

In [None]:
# Show a grid of freshly generated samples (default number of samples)
sampleAndPlot(G=G, num_latent_dims=num_latent_dims)