### GauGAN (Without Encoder)

Generator with SPADE Resblocks (modulated by binary thresholded versions of images), Patch Discriminator conditioned on binary versions of images, Hinge Loss. SpecNorm in both generator and discriminator.

GauGAN transforms a noise vector, modulated by a segmented image, into a filled-in, realistic version of that segmented image. This modulation is called Spatially Adaptive Normalization (SPADE).

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import os
import cv2
from random import shuffle
import matplotlib.pyplot as plt
from models.gaugan import SPADE_Generator, Patch_Discriminator

# Mount Google Drive at /content/drive; the dataset paths used when loading
# the training data live under this mount point.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
def load_data(path, image_size, block_size = None, thresholded = False):
    """Load the images of a directory into a channel-first float32 array.

    Args:
        path: Directory containing the image files.
        image_size: Target size as ``(height, width)``.
        block_size: Maximum number of images to load. ``None``, or a value
            larger than the number of directory entries, loads everything.
            A value of 0 (or less) loads nothing.
        thresholded: If True, each image is converted to grayscale, smoothed
            with a 3x3 Gaussian blur and binarized with Otsu's threshold,
            giving a single-channel image. Otherwise the image is converted
            from OpenCV's BGR order to RGB.

    Returns:
        ``np.ndarray`` of shape ``(N, C, height, width)`` and dtype float32
        holding raw pixel values (0..255), where C is 1 when ``thresholded``
        is True and 3 otherwise. An empty array if nothing was loaded.
    """
    # os.listdir gives entries in arbitrary order. Sorting makes the result
    # deterministic, so separate calls on the same directory (e.g. colour
    # images and their thresholded versions) stay aligned index by index.
    files = sorted(os.listdir(path))
    if block_size is None or block_size > len(files):
        block_size = len(files)

    height, width = image_size[0], image_size[1]
    x_train = []

    for file in files:
        # Checked before reading so that block_size == 0 loads no image.
        if len(x_train) >= block_size:
            break

        img = cv2.imread(os.path.join(path, file))
        if img is None:
            # cv2.imread signals an unreadable / non-image entry with None;
            # skip it instead of crashing later inside cv2.resize.
            print("Skipping unreadable file:", file)
            continue

        # cv2.resize expects the target size as (width, height).
        img = cv2.resize(img, (width, height))

        if thresholded:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = cv2.GaussianBlur(img, (3, 3), 0)
            # With THRESH_OTSU the threshold argument (0) is ignored and the
            # threshold is chosen automatically from the image histogram.
            ret, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            # Restore an explicit channel axis so both branches are HWC.
            img = img.reshape(height, width, 1)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # HWC -> CHW, the layout the rest of the notebook feeds to PyTorch.
        img = np.transpose(img, (2, 0, 1))
        img = np.float32(img)
        x_train.append(img)
        print(len(x_train) - 1, "/", block_size)

    return np.array(x_train)


In [None]:
# The photos (x_train) and their binarized counterparts (y_train) are both
# produced from the same folder; only the `thresholded` flag differs.
dataset_dir = "/content/drive/My Drive/Datasets/vanGogh2Phots/trainB"
target_size = (128, 128)

x_train = load_data(path=dataset_dir, image_size=target_size)
y_train = load_data(path=dataset_dir, image_size=target_size, thresholded=True)

# Rescale pixel values from 0..255 to -1..1 and report range and shape.
x_train = (x_train / 255) * 2 - 1
print(x_train.max(), x_train.min())
print(x_train.shape)

y_train = (y_train / 255) * 2 - 1
print(y_train.max(), y_train.min())
print(y_train.shape)

In [None]:
# Preview one training pair (index 10): the colour image and its mask.
# Both are stored channel-first in -1..1, so move channels last and map the
# values back to 0..1 for display.
img = np.transpose(x_train[10], (1, 2, 0))
plt.imshow((img + 1) / 2)
plt.show()

img = np.transpose(y_train[10], (1, 2, 0))
# The mask has a single channel: drop it by indexing rather than reshaping
# to a hard-coded 128x128, and draw it in grayscale with a fixed 0..1 range
# (a 2-D array would otherwise be shown with the default colormap and
# auto-scaled limits).
plt.imshow((img[..., 0] + 1) / 2, cmap="gray", vmin=0, vmax=1)
plt.show()

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

noise_dim = 128
netG = SPADE_Generator(
    image_size=(128, 128),
    noise_dim=noise_dim,
    seg_channels=1,
    specnorm=True,
).to(device)
# The training cell feeds the discriminator an image concatenated with its
# single-channel threshold map along the channel axis, hence in_channels = 4.
netD = Patch_Discriminator(image_size=(128, 128), in_channels=4).to(device)

# Spread both networks over every visible GPU when there is more than one.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    netG = nn.DataParallel(netG, list(range(n_gpus)))
    netD = nn.DataParallel(netD, list(range(n_gpus)))

# Optional orthogonal initialization of weights, does not work with
# Spectral Normalization! Left disabled:
#
# for m in netG.modules():
#     if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
#         nn.init.orthogonal_(m.weight)
#
# for m in netD.modules():
#     if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
#         nn.init.orthogonal_(m.weight)

# Two Timescale Update Rule: the discriminator uses twice the generator's
# learning rate.
adam_betas = (0.0, 0.9)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=adam_betas)

In [None]:
# Training configuration.
epochs = 500
disc_steps = 1  # NOTE: not referenced anywhere in this cell
batch_size = 16
iterations = len(x_train)//batch_size
path = "./saved_models/"  # only used by the commented-out checkpointing below

# Adversarial training with the hinge loss.
# NOTE: both ranges include their upper bound, so the loops make
# epochs + 1 passes of iterations + 1 steps each.
for epoch in range(epochs + 1):
    for i in range(iterations + 1):

        #Dealing with the discriminator################################
        netD.zero_grad()

        # Draw a random batch (indices sampled with replacement) of images
        # together with their matching threshold maps.
        next_batch = np.random.randint(0, len(x_train), size = batch_size)
        data = torch.Tensor(x_train[next_batch])
        thresh_data = torch.Tensor(y_train[next_batch]).to(device)

        real_images = data.to(device)

        # Condition the discriminator by stacking the threshold map onto the
        # image along the channel axis.
        real_images = torch.cat([real_images, thresh_data], dim = 1)
        output = netD(real_images).view(-1)
        # Hinge loss, real side: penalize scores below +1.
        errD_real = torch.mean(F.relu(1 - output))
        D_x = output.mean().item()

        noise = torch.randn(batch_size, noise_dim, device = device)
        fake = netG(noise, thresh_data)

        fake = torch.cat([fake, thresh_data], dim = 1)
        # detach() keeps the discriminator update from backpropagating into
        # the generator.
        output = netD(fake.detach()).view(-1)
        # Hinge loss, fake side: penalize scores above -1.
        errD_fake = torch.mean(F.relu(1 + output))
        D_G_z1 = output.mean().item()

        errD = errD_fake + errD_real
        errD.backward()
        optimizerD.step()

        #Dealing with the generator###################################
        netG.zero_grad()

        # Score the same fakes again, now with the graph attached and with
        # the freshly updated discriminator. The gradients this leaves on
        # netD are cleared by netD.zero_grad() at the start of the next step.
        output = netD(fake).view(-1)
        errG = -torch.mean(output)

        D_G_z2 = output.mean().item()
        errG.backward()
        optimizerG.step()

        if i%100 == 0:
            # The step counter runs up to `iterations`, so report that as
            # the denominator (not the number of training samples).
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, epochs, i, iterations,
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    #if epoch%100 == 0:
        #torch.save(netG.state_dict(), path + "gaugan_G.pth")
        #torch.save(netD.state_dict(), path + "gaugan_D.pth")

In [None]:
# Draw 20 random training pairs and generate one image per threshold map.
noise = torch.randn(20, noise_dim, device=device)
next_batch = np.random.randint(0, len(x_train), size=20)
thresh_data = torch.Tensor(y_train[next_batch]).to(device)
real_images = torch.Tensor(x_train[next_batch])

# Inference only: no gradients are needed for sampling.
with torch.no_grad():
    fake = netG(noise, thresh_data).cpu()

# Show the originals, their binarized maps and the generated images as
# 5-column grids, mapping values from -1..1 back to 0..1 for display.
for title, batch in (
    ("Actual", real_images),
    ("Binarized", thresh_data.cpu()),
    ("Noise produced", fake),
):
    print(title)
    grid = torchvision.utils.make_grid(batch, nrow=5, padding=1, pad_value=0.15)
    f = plt.figure(figsize=(15, 15))
    plt.imshow((grid.permute(1, 2, 0) + 1) / 2)
    plt.axis('off')
    plt.show()