**AUTHORIZATIONS**

In [None]:
# Mount Google Drive at ./mount; every dataset, image and checkpoint path used
# in this notebook is written relative to "mount/My Drive/...".
from google.colab import drive
drive.mount('./mount')

**IMPORT LIBRARIES**

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import h5py
import random
from time import strftime, localtime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import sys
import zipfile
import imageio
from tqdm.notebook import tqdm

**CONVERT TO HDF5**

In [None]:
def convert_to_hdf5_file(index_):
  """
  Package the first `index_` JPEG images of the CelebA zip archive into one HDF5 file.

  Every image is decoded and stored as its own gzip-compressed dataset named
  "img_align_celeba/<n>.jpg", where <n> counts up from 0 in archive order.
  The output file name is derived from `index_`; an existing file is overwritten.

  Args:
    index_: how many of the archive's images to extract (the full archive is
            expected to hold 202,599 images).
  """
  archive_file = "mount/My Drive/Colab Notebooks/celebA/img_align_celeba.zip"
  hdf5_file = f"mount/My Drive/Colab Notebooks/celebA/{index_}_img_align_celeba.h5py"

  # how many of the 202,599 images to extract and package into HDF5
  total_images = int(index_)

  # Open the archive first so a missing zip does not truncate an existing HDF5 file
  with zipfile.ZipFile(archive_file, "r") as zf, h5py.File(hdf5_file, "w") as hf:
    # Keep only image members so directory entries neither get processed
    # nor distort the progress bar.
    jpg_names = [name for name in zf.namelist() if name.endswith(".jpg")]
    selected = jpg_names[:total_images]

    for count, name in enumerate(tqdm(selected, total=len(selected))):
      # Decode straight from the archive bytes: nothing is extracted to disk,
      # so no temporary files or folders are left behind if decoding fails.
      img = imageio.imread(zf.read(name))

      # add image data to HDF5 file with new name
      hf.create_dataset(f"img_align_celeba/{count}.jpg", data=img,
                        compression="gzip", compression_opts=9)

# Build the 20,096-image subset; the dataset cell further down opens the
# resulting "20096_img_align_celeba.h5py" file.
convert_to_hdf5_file(20_096)

**NETWORKS**

In [3]:
class CelebADataset(Dataset):
    """
    Dataset of face images stored in an HDF5 file.

    Images are looked up as "<index>.jpg" inside the "img_align_celeba" group,
    centre-cropped to 128x128 pixels and returned as float tensors of shape
    (1, 3, 128, 128) with pixel values divided by 255.
    """

    def __init__(self, file, device=None):
      """
      Open the HDF5 file read-only and keep it open for the dataset's lifetime.

      Args:
        file: path of the HDF5 file to read.
        device: device the image tensors are created on. Defaults to CUDA when
                it is available and to the CPU otherwise.
      """
      self.file_object = h5py.File(file, "r")
      self.dataset = self.file_object["img_align_celeba"]
      if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
      self.device = torch.device(device)

    def __len__(self):
      """Return the number of images stored in the group."""
      return len(self.dataset)

    def __getitem__(self, index):
      """
      Return image `index` as a (1, 3, 128, 128) float tensor.

      Negative indices count from the end, as for a regular sequence.

      Raises:
        IndexError: if `index` is out of range. This also ends plain
                    `for item in dataset` iteration.
      """
      size = len(self.dataset)
      if index < 0:
        index += size
      if not 0 <= index < size:
        raise IndexError(f"image index {index} out of range for {size} images")
      img = self._load_cropped(index)
      # Device-agnostic replacement for torch.cuda.FloatTensor(img), which
      # fails on CPU-only runtimes.
      tensor = torch.as_tensor(img, dtype=torch.float32, device=self.device)
      # (height, width, channel) -> (1, channel, height, width)
      return tensor.permute(2, 0, 1).unsqueeze(0) / 255.0

    def crop_to_centre_img(self, img, width, height):
      """Return the centred `height` x `width` window of an (H, W, C) image array."""
      img_height, img_width, _ = img.shape
      x = img_width // 2 - width // 2
      y = img_height // 2 - height // 2
      return img[y: y + height, x: x + width, :]

    def plot_image(self, index):
      """Show the centre-cropped image `index` with matplotlib."""
      plt.imshow(self._load_cropped(index), interpolation='nearest')

    def _load_cropped(self, index):
      """Read image `index` from the HDF5 group and centre-crop it to 128x128."""
      img = np.array(self.dataset[str(index) + ".jpg"])
      return self.crop_to_centre_img(img, 128, 128)

class View(nn.Module):
    """
    Layer that reshapes its input to a fixed target shape, e.g. to flatten
    feature maps into a vector or to turn a vector back into feature maps.
    """

    def __init__(self, shape):
      super().__init__()
      # Kept as a 1-tuple so forward() can unpack it the same way whether
      # `shape` is a single int or a tuple of dimensions.
      self.shape = (shape,)

    def forward(self, x):
      """Return `x` reshaped to the shape given at construction."""
      (target_shape,) = self.shape
      return x.reshape(target_shape)

class Discriminator(nn.Module):
    """
    Convolutional discriminator that scores a single 128x128 RGB image with a
    sigmoid output.

    The class bundles the network with its loss function, its optimiser and a
    sparse loss history used by plot_progress().
    """

    def __init__(self):
      """Create the layers, the loss function, the optimiser and the loss history."""
      super().__init__()

      # Neural Network layers
      self.model = nn.Sequential(
            # Inputs are always reshaped to a batch holding exactly one image
            View((1, 3, 128, 128)),

            # Unpadded stride-2 convolutions: spatial size 128 -> 61 -> 27 -> 10
            nn.Conv2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            nn.Conv2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            nn.Conv2d(256, 3, kernel_size=8, stride=2),
            nn.GELU(),

            # Flatten the 3x10x10 feature maps and reduce them to one score
            View(3*10*10),
            nn.Linear(3*10*10, 1),
            nn.Sigmoid()
          )

      # Mean square loss
      self.loss_function = nn.MSELoss()

      # Adam optimiser
      self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

      # Step counter and sampled loss values for plotting progress
      self.counter = 0
      self.progress = []


    def forward(self, inputs):
      """
      Pass `inputs` through the network and return its score.
      """
      return self.model(inputs)


    def train(self, inputs=True, targets=None):
      """
      Run one training step, or switch between training and evaluation mode.

      With `targets` given, `inputs` is scored, the loss against `targets` is
      backpropagated and the weights are updated; nothing is returned.

      With `targets` omitted, the call is forwarded to nn.Module.train(mode)
      with `inputs` as the mode flag. This method shadows nn.Module.train, and
      nn.Module.eval() calls self.train(False); without the forwarding those
      calls fail with a TypeError.

      Args:
        inputs: image tensor for a training step, or a bool mode flag.
        targets: target score tensor; None selects the mode-switch behaviour.

      Returns:
        None after a training step, `self` after a mode switch.
      """
      if targets is None:
        return super().train(inputs)

      outputs = self.forward(inputs)
      loss = self.loss_function(outputs, targets)

      # Sample the loss every 10000 steps for plotting
      self.counter += 1
      if (self.counter % 10000 == 0):
          self.progress.append(loss.item())

      # Backpropagation -> zero gradients, perform a backward pass, update weights
      self.optimiser.zero_grad()
      loss.backward()
      self.optimiser.step()


    def plot_progress(self):
      """
      Plot the sampled loss values (one per 10000 training steps).
      """
      if not self.progress:
        # DataFrame.plot raises on an empty frame; report instead of crashing
        print("Discriminator: no loss recorded yet (one value is logged every 10000 training steps)")
        return
      df = pd.DataFrame(self.progress, columns=['loss'])
      df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0), title="Discriminator Loss")

class Generator(nn.Module):
    """
    Transposed-convolution generator that maps a 100-value seed to a single
    128x128 RGB image with a sigmoid output.

    It has no loss function of its own; train() borrows the discriminator's.
    """

    def __init__(self):
      """Create the layers, the optimiser and the loss history."""
      super().__init__()

      self.model = nn.Sequential(
            # Expand the seed and reshape it to one 3-channel 11x11 feature map
            nn.Linear(100, 3*11*11),
            nn.GELU(),
            View((1, 3, 11, 11)),

            # Stride-2 transposed convolutions: spatial size 11 -> 28 -> 62 -> 128
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(3),

            # output is (1,3,128,128)
            nn.Sigmoid()
          )

      # No loss function; will use one from discriminator to calculate error

      self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

      # counter and accumulator for progress
      self.counter = 0
      self.progress = []

    def forward(self, inputs):
      """
      Pass `inputs` through the network and return the generated image tensor.
      """
      return self.model(inputs)


    def train(self, D=True, inputs=None, targets=None):
      """
      Run one training step, or switch between training and evaluation mode.

      When `D` is a discriminator, an image is generated from `inputs`, scored
      by `D`, and D's loss against `targets` is backpropagated. Only the
      generator's optimiser takes a step; nothing is returned.

      When `D` is a bool, the call is forwarded to nn.Module.train(mode). This
      method shadows nn.Module.train, and nn.Module.eval() calls
      self.train(False); without the forwarding those calls fail with a
      TypeError.

      Args:
        D: discriminator providing forward() and loss_function, or a bool mode flag.
        inputs: seed tensor fed to the generator.
        targets: score the discriminator should assign to the generated image.

      Returns:
        None after a training step, `self` after a mode switch.
      """
      if isinstance(D, bool):
        return super().train(D)

      g_output = self.forward(inputs)
      d_output = D.forward(g_output)

      loss = D.loss_function(d_output, targets)

      # Sample the loss every 10000 steps for plotting
      self.counter += 1
      if (self.counter % 10000 == 0):
          self.progress.append(loss.item())

      # Backpropagation
      self.optimiser.zero_grad()
      loss.backward()
      self.optimiser.step()

    def plot_progress(self):
      """
      Plot the sampled loss values (one per 10000 training steps).
      """
      if not self.progress:
        # DataFrame.plot raises on an empty frame; report instead of crashing
        print("Generator: no loss recorded yet (one value is logged every 10000 training steps)")
        return
      df = pd.DataFrame(self.progress, columns=['loss'])
      df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True,
              yticks=(0, 0.25, 0.5, 1.0, 5.0), title="Generator Loss")

**FUNCTIONS**

In [4]:
# functions to generate random data
def generate_random_image(size):
    """Return a tensor of the given size filled with uniform random values from torch.rand."""
    return torch.rand(size)

def generate_random_seed(size):
    """Return a tensor of the given size filled with standard-normal values from torch.randn."""
    return torch.randn(size)

# Functions to plot
def plot_loss():
    """Plot the recorded loss history of the global networks, D first and then G."""
    for network in (D, G):
        network.plot_progress()

def save_plot(G, epoch):
    """
    Generate one image from a fresh random seed and save the figure as a PNG.

    The file is written to the Drive "generated images" folder under a
    timestamped "single_face_..." name.

    Args:
      G: generator whose forward() maps a 100-value seed to a (1, 3, 128, 128) tensor.
      epoch: accepted but not used by this function.
    """
    fig, ax = plt.subplots(figsize=(16, 8))

    latent = generate_random_seed(100)
    generated = G.forward(latent).detach()

    # (1, 3, 128, 128) -> (128, 128, 3) so matplotlib can draw it
    channels_last = generated.permute(0, 2, 3, 1)
    img = channels_last.reshape(128, 128, 3).cpu().numpy()
    ax.imshow(img, interpolation="none", cmap="Blues")

    timestamp = strftime("%Y-%m-%d %H-%M-%S", localtime())
    plt.savefig(f"mount/My Drive/Colab Notebooks/celebA/generated images/single_face_{timestamp}.png")

def plot_results(G, seed1, seed2, epoch=None, save_img=False):
    """
    Plot a 2x3 grid of generated images and optionally save the figure.

    The last two columns show images built from the fixed seeds; the remaining
    cells use fresh random seeds:

        row 0:  random | seed1 | seed1 - seed2
        row 1:  random | seed2 | seed1 + seed2

    Args:
      G: generator whose forward() maps a 100-value seed to a (1, 3, 128, 128) tensor.
      seed1, seed2: fixed seed tensors, shown alone and combined.
      epoch: label inserted into the saved file name.
      save_img: when true, save the figure to the Drive "generated images" folder.
    """
    # Plot 6 images
    rows, columns = 2, 3
    fig, ax = plt.subplots(rows, columns, figsize=(16, 8))

    # Zero-based (row, column) cells that show a fixed seed. The previous
    # conditions compared against `rows` and `columns` themselves, values that
    # range() never yields, so most of these cells were unreachable.
    fixed_seeds = {
        (rows - 2, columns - 2): seed1,            # Penultimate in penultimate row
        (rows - 1, columns - 2): seed2,            # Penultimate in last row
        (rows - 2, columns - 1): seed1 - seed2,    # Last in penultimate row
        (rows - 1, columns - 1): seed1 + seed2,    # Last in last row
    }

    for i in range(rows):
        for j in range(columns):
            latent = fixed_seeds.get((i, j))
            if latent is None:                     # Everything else is random
                latent = generate_random_seed(100)
            output = G.forward(latent)
            # (1, 3, 128, 128) -> (128, 128, 3); reshape also copes with the
            # non-contiguous result of permute
            img = output.detach().permute(0, 2, 3, 1).reshape(128, 128, 3).cpu().numpy()
            ax[i, j].imshow(img, interpolation="none", cmap="Blues")

    if save_img:
      timestamp = strftime("%Y-%m-%d %H-%M-%S", localtime())
      plt.savefig(f"mount/My Drive/Colab Notebooks/celebA/generated images/epoch_{epoch}_{timestamp}.png")
  
# Save
def save_model(G, D, state_dict=True):
    """
    Save the generator and the discriminator to the Drive "models" folder.

    The folder is created when it does not exist yet, so a finished epoch is
    not lost to a missing directory.

    Args:
      G: generator network, written to "generator.pt".
      D: discriminator network, written to "discriminator.pt".
      state_dict: when true (default) only each network's state_dict() is
                  stored; otherwise the whole module object is pickled.
    """
    models_dir = "mount/My Drive/Colab Notebooks/celebA/models"
    os.makedirs(models_dir, exist_ok=True)

    for network, file_name in ((G, "generator.pt"), (D, "discriminator.pt")):
        payload = network.state_dict() if state_dict else network
        torch.save(payload, f"{models_dir}/{file_name}")

**CUDA**

In [None]:
# Select the compute device: use the GPU when one is available, else the CPU.
if torch.cuda.is_available():
  # Make CUDA float tensors the default tensor type from here on.
  # NOTE(review): recent PyTorch releases deprecate set_default_tensor_type - confirm against the installed version.
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
  print("using cuda:", torch.cuda.get_device_name(0))
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

**PREPARE NETWORKS**

In [6]:
# Open the 20,096-image HDF5 subset
dataset = CelebADataset("mount/My Drive/Colab Notebooks/celebA/20096_img_align_celeba.h5py")

# Build both networks on the selected device (discriminator first, then generator)
D, G = Discriminator().to(device), Generator().to(device)

# Two fixed seeds, reused by plot_results() so the plotted grids stay comparable
seed1, seed2 = generate_random_seed(100), generate_random_seed(100)

**LOAD NETWORKS** (In case of disconnection in training)

In [None]:
# Restore both networks from the checkpoints written by save_model().
# map_location remaps the stored tensors onto the current device, so weights
# saved on a GPU runtime can still be loaded on a CPU-only runtime.
PATH = "mount/My Drive/Colab Notebooks/celebA/models/discriminator.pt"
D.load_state_dict(torch.load(PATH, map_location=device))
PATH = "mount/My Drive/Colab Notebooks/celebA/models/generator.pt"
G.load_state_dict(torch.load(PATH, map_location=device))

**TEST NETWORKS OUTPUT SHAPE**

In [None]:
# Smoke test on the first sample only: print the shape of the real image,
# of the discriminator's score for it, and of one generated image.
for index, img_tensor in enumerate(dataset):
  print(img_tensor.shape)

  # Real image -> discriminator score
  d_output = D.forward(img_tensor)
  print(d_output.shape)

  # Random seed -> generated image
  latent = generate_random_seed(100)
  g_output = G.forward(latent)
  print(g_output.shape)

  # One sample is enough to check the shapes
  break

**TRAIN**

In [None]:
# Train
epochs = 20

# The target labels never change, so build them once on the active device
# instead of allocating a CUDA-only tensor on every training step.
real_label = torch.ones(1, device=device)
fake_label = torch.zeros(1, device=device)

for e in range(epochs):
    print(f"\nTraining in {e+1} / {epochs} epochs...")
    for img_tensor in tqdm(dataset, total=len(dataset)):
        # Train Discriminator -> Real data
        D.train(img_tensor, real_label)

        # Train Discriminator -> Fake data (detached so the generator is not updated)
        D.train(G.forward(generate_random_seed(100)).detach(), fake_label)

        # Train Generator: it is rewarded when D scores its image as real
        G.train(D, generate_random_seed(100), real_label)

    # Save models
    save_model(G, D)

    # Save progress and plot; pass the epoch so the grid file is no longer
    # named "epoch_None_..."
    save_plot(G, e + 1)
    plot_results(G, seed1, seed2, epoch=e + 1, save_img=True)

# PLOT
plot_loss()
plot_results(G, seed1, seed2, epoch="final", save_img=True)

**MEMORY**

In [None]:
# memory_summary() only accepts CUDA devices, so skip it on a CPU-only runtime
if device.type == "cuda":
  print(torch.cuda.memory_summary(device, abbreviated=True))
else:
  print("CUDA is not in use; no GPU memory statistics to report")