### pix2pix

The generator is a U-Net with residual blocks in the middle. The discriminator is a basic discriminator by default, with a PatchGAN option available. Training uses a least-squares adversarial loss plus an L1 reconstruction loss.
The discriminator is conditioned by concatenating the image from domain x with the image from domain y and feeding the result to the discriminator.

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from models.cycgan import UNet_Generator, Basic_Discriminator
from utils import load_data

#from google.colab import drive
#drive.mount("/content/drive")

In [None]:
# Both domains are read from the same folder with the same settings; the only
# difference is that the X domain is loaded with as_grayscale=True.
fruits_dir = "/content/drive/My Drive/Datasets/Fruits"
load_kwargs = dict(path=fruits_dir, image_size=(128, 128), block_size=1500)

print("Coloured images (Y)....")
y_train = load_data(**load_kwargs)
print("BW images (X)....")
x_train = load_data(**load_kwargs, as_grayscale=True)

# Linear rescale that maps 0 -> -1 and 255 -> 1
# (assumes load_data returns 0-255 pixel values -- TODO confirm).
y_train = (y_train / 255) * 2 - 1
x_train = (x_train / 255) * 2 - 1

# Quick check of the value range and array shapes after rescaling.
print(x_train.max(), x_train.min(), y_train.max(), y_train.min())
print(x_train.shape, y_train.shape)

In [None]:
# One loader per domain, built with identical batching settings.
# NOTE: each loader shuffles on its own, so a batch drawn from train_loader_x
# and a batch drawn from train_loader_y are not index-aligned with each other.
loader_kwargs = dict(batch_size=64, num_workers=2, shuffle=True)
train_loader_x = torch.utils.data.DataLoader(x_train, **loader_kwargs)
train_loader_y = torch.utils.data.DataLoader(y_train, **loader_kwargs)

In [None]:
# Sanity check: pull one batch from each loader, print its shape and value
# range, and display the first image of the batch.

# --- Domain X ---
x_batch = next(iter(train_loader_x))

print("x_data")
print(x_batch.shape)
print(x_batch.max(), x_batch.min())
# Move axis 0 (presumably the channel axis) to the end for matplotlib.
x_img = np.transpose(x_batch[0], (1, 2, 0))
# Flatten to a 128x128 plane and map [-1, 1] back to [0, 1] for display.
plt.imshow((x_img.reshape(128, 128) + 1) / 2, cmap="gray")
plt.show()

# --- Domain Y ---
y_batch = next(iter(train_loader_y))

print("y_data")
print(y_batch.shape)
print(y_batch.max(), y_batch.min())
y_img = np.transpose(y_batch[0], (1, 2, 0))
# Map [-1, 1] back to [0, 1] for display.
plt.imshow((y_img + 1) / 2)
plt.show()

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The generator consumes the 1-channel x image. The discriminator consumes a
# 4-channel input: x concatenated with a y image along dim 1 in the training
# loop (presumably 1 grayscale + 3 colour channels).
G = UNet_Generator(in_channels=1).to(device)
D = Basic_Discriminator(in_channels=4).to(device)

# Orthogonal initialization is king: apply it to the weight of every conv,
# transposed-conv and linear layer, first in G and then in D.
weighted_layers = (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)
for net in (G, D):
    for layer in net.modules():
        if isinstance(layer, weighted_layers):
            nn.init.orthogonal_(layer.weight)

# Optimizers: one Adam instance per network, same hyper-parameters for both.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), **adam_settings)
optimizerG = optim.Adam(G.parameters(), **adam_settings)

In [None]:
epochs = 300
_lambda = 100  # weight of the L1 reconstruction term relative to the adversarial term

# The conditional discriminator and the L1 reconstruction loss both need
# ALIGNED (x, y) pairs. train_loader_x and train_loader_y each shuffle on their
# own, so zipping them would pair every grayscale batch with unrelated colour
# images. Draw both domains from one shuffled dataset instead, so that index i
# of x always travels with index i of y.
# (Assumes x_train[i] and y_train[i] are the same source image, i.e. load_data
# returns images in the same order for both calls -- TODO confirm.)
paired_dataset = torch.utils.data.TensorDataset(
    torch.as_tensor(x_train), torch.as_tensor(y_train)
)
paired_loader = torch.utils.data.DataLoader(
    paired_dataset,
    # Reuse the batching settings of the existing per-domain loaders.
    batch_size=train_loader_x.batch_size,
    num_workers=train_loader_x.num_workers,
    shuffle=True,
)

for epoch in range(epochs):
    for i, (data_x, data_y) in enumerate(paired_loader):

        real_images_x = data_x.to(device)
        real_images_y = data_y.to(device)

        # ---- Discriminator step -------------------------------------------
        D.zero_grad()

        # Condition D on x by concatenating x with the real y along dim 1.
        real_pair = torch.cat([real_images_x, real_images_y], dim=1)
        output = D(real_pair).view(-1)
        # Least-squares GAN loss: real pairs are pushed towards 1.
        errD_real = torch.mean((output - 1) ** 2)

        # Same conditioning for the generated image.
        fake_images_y = G(real_images_x)
        fake_pair = torch.cat([real_images_x, fake_images_y], dim=1)

        # detach(): the discriminator update must not back-propagate into G.
        output = D(fake_pair.detach()).view(-1)
        # Least-squares GAN loss: fake pairs are pushed towards 0.
        errD_fake = torch.mean(output ** 2)

        errD = errD_fake + errD_real
        errD.backward()
        optimizerD.step()

        # ---- Generator step -----------------------------------------------
        G.zero_grad()

        # Re-score the (non-detached) fake pair with the freshly updated D so
        # that gradients flow back into G.
        output = D(fake_pair).view(-1)

        # G wants D to score its output as real (target 1).
        errG_adv = torch.mean((output - 1) ** 2)
        # L1 reconstruction loss against the colour image paired with this x.
        errG_l1 = _lambda * torch.mean(torch.abs(fake_images_y - real_images_y))

        errG = errG_adv + errG_l1
        errG.backward()

        optimizerG.step()

        if i % 100 == 0:
            print("Epoch %i Step %i --> Disc_Loss : %f   Gen_Loss : %f"
                  % (epoch, i, errD.item(), errG.item()))

In [None]:
# Draw 10 random indices into the training set (np.random.choice samples with
# replacement by default) and fetch the corresponding x images.
batch_idx = np.random.choice(len(x_train), size=10)
data_x = x_train[batch_idx]

print("Actual images")

# Top row: the input images, shown as 128x128 grayscale planes.
fig, axes = plt.subplots(1, 10, figsize=(20, 20))
for ax, sample in zip(axes, data_x):
    plane = np.transpose(sample, (1, 2, 0))
    # Map [-1, 1] back to [0, 1] for display.
    plane = (plane.reshape(128, 128) + 1) / 2
    ax.imshow(plane, cmap="gray")
    ax.axis("off")

plt.show()

# Run the generator on the sampled inputs without tracking gradients.
with torch.no_grad():
    real_images_x = torch.Tensor(data_x).to(device)
    fake_images_y = G(real_images_x)

print("Translated images")

# Bottom row: the generator outputs for the same inputs.
fig, axes = plt.subplots(1, 10, figsize=(30, 30))
for idx, ax in enumerate(axes):
    translated = np.transpose(fake_images_y[idx].cpu(), (1, 2, 0))
    # Map [-1, 1] back to [0, 1] for display.
    ax.imshow((translated + 1) / 2)
    ax.axis("off")

plt.show()