### CycleGAN

U-Net generators with residual blocks in the middle, basic discriminators (a PatchGAN option is also available), and a least-squares adversarial loss combined with a cycle-consistency loss and an identity loss.
CycleGAN learns to translate images between two domains in both directions. For example, given a set of photos and a set of paintings, it learns to turn photos into paintings and paintings into photos.

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from models.cycgan import UNet_Generator, Basic_Discriminator
from utils import load_data

#from google.colab import drive
#drive.mount("/content/drive")

`load_data()` is defined in utils.py. Use the `image_size` parameter to choose the size the images are resized to, and `block_size` to choose how many images are read. By default, all images are read.

In [None]:
def _to_unit_range(images):
    """Rescale pixel values from [0, 255] to [-1, 1] as a float32 array.

    Assumes ``load_data`` returns pixel values in [0, 255] (the original /255
    scaling implies this) — TODO confirm against utils.py.  Casting to float32
    first matters: true-dividing an integer array by 255 yields float64, which
    doubles memory and makes the DataLoader emit float64 batches that do not
    match float32 model weights.
    """
    return (np.asarray(images, dtype=np.float32) / 255) * 2 - 1


print("Fruits (X)....")
x_train = load_data(path="/content/drive/My Drive/Datasets/Fruits", image_size=(128, 128), block_size=1500)
print("Ukiyo_e (Y)....")
y_train = load_data(path="/content/drive/My Drive/Datasets/Ukiyo_e", image_size=(128, 128), block_size=1500)
x_train = _to_unit_range(x_train)
y_train = _to_unit_range(y_train)
# Sanity check: both domains should now lie in [-1, 1].
print(x_train.max(), x_train.min(), y_train.max(), y_train.min())
print(x_train.shape, y_train.shape)

In [None]:
# NOTE(review): `transform` is never applied — the loaders below iterate the
# already-normalised arrays directly, and the default collate function turns
# each batch into a tensor. Kept only for reference.
transform = transforms.Compose([transforms.ToTensor()])

# Both domains use identical loader settings and are shuffled independently,
# so X and Y batches are unpaired.
loader_kwargs = dict(batch_size=64, num_workers=2, shuffle=True)
train_loader_x = torch.utils.data.DataLoader(x_train, **loader_kwargs)
train_loader_y = torch.utils.data.DataLoader(y_train, **loader_kwargs)

In [None]:
def _preview_first_batch(loader, name):
    """Print the shape and value range of *loader*'s first batch and show its first image.

    Images are assumed channels-first (C, H, W) with values in [-1, 1]; they
    are permuted to (H, W, C) and mapped to [0, 1] for matplotlib.
    """
    batch = next(iter(loader))

    print(name)
    print(batch.shape)
    print(batch.max(), batch.min())
    # Tensor.permute instead of np.transpose: the latter only works on a
    # tensor through NumPy's TypeError fallback to an array conversion.
    img = batch[0].permute(1, 2, 0).numpy()
    plt.imshow((img + 1) / 2)
    plt.show()


_preview_first_batch(train_loader_x, "x_data")
_preview_first_batch(train_loader_y, "y_data")

There are two generator–discriminator pairs. All four networks use the same learning rate here, but you could experiment with different ones.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# G and H are the two generators, D_x and D_y the two discriminators.
# (The training cell uses G(real_images_x) as fake Y and H(real_images_y) as fake X.)
G = UNet_Generator().to(device)
H = UNet_Generator().to(device)
D_x = Basic_Discriminator().to(device)
D_y = Basic_Discriminator().to(device)


def _init_orthogonal(model):
    """Orthogonally re-initialise the weight of every Conv2d, Linear and ConvTranspose2d in *model*.

    Only ``weight`` is touched; biases keep their default initialisation.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            nn.init.orthogonal_(m.weight)


# Orthogonal initialization is king
for net in (G, H, D_x, D_y):
    _init_orthogonal(net)

# Optimizers: one Adam per network, all sharing the same hyper-parameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizerD_x = optim.Adam(D_x.parameters(), **adam_kwargs)
optimizerG = optim.Adam(G.parameters(), **adam_kwargs)
optimizerD_y = optim.Adam(D_y.parameters(), **adam_kwargs)
optimizerH = optim.Adam(H.parameters(), **adam_kwargs)

In [None]:
epochs = 300
_lambda = 10  # weight of the cycle-consistency term


def _ls_loss(pred, target):
    """Least-squares GAN loss: mean of (pred - target)**2 over all elements."""
    return torch.mean((pred - target) ** 2)


def _set_requires_grad(nets, flag):
    """Turn gradient tracking on/off for every parameter of each network in *nets*."""
    for net in nets:
        for p in net.parameters():
            p.requires_grad_(flag)


# Select train mode explicitly so re-running this cell after any eval-mode
# code still trains with training-time layer behaviour.
for net in (G, H, D_x, D_y):
    net.train()

for epoch in range(epochs):
    # zip() stops at the shorter loader; the two domains are sampled independently.
    for i, (data_x, data_y) in enumerate(zip(train_loader_x, train_loader_y)):
        real_images_x = data_x.to(device)
        real_images_y = data_y.to(device)

        # Dealing with the discriminators ################################
        _set_requires_grad((D_x, D_y), True)
        D_x.zero_grad()
        D_y.zero_grad()

        errD_real = _ls_loss(D_x(real_images_x), 1) + _ls_loss(D_y(real_images_y), 1)

        fake_images_y = G(real_images_x)
        fake_images_x = H(real_images_y)
        # detach(): the discriminator loss must not backpropagate into G/H.
        errD_fake = _ls_loss(D_x(fake_images_x.detach()), 0) + _ls_loss(D_y(fake_images_y.detach()), 0)

        errD = errD_fake + errD_real
        errD.backward()
        optimizerD_x.step()
        optimizerD_y.step()

        # Dealing with the generators ###################################
        # The discriminators act as fixed critics here: freezing them still
        # lets gradients flow through to the fake images, but skips computing
        # (and leaving behind) gradients for the discriminator weights.
        _set_requires_grad((D_x, D_y), False)
        G.zero_grad()
        H.zero_grad()

        cycled_images_x = H(fake_images_y)  # X -> Y -> X
        cycled_images_y = G(fake_images_x)  # Y -> X -> Y
        identity_x = H(real_images_x)       # H applied to its own output domain
        identity_y = G(real_images_y)       # G applied to its own output domain

        errG_adv = _ls_loss(D_x(fake_images_x), 1) + _ls_loss(D_y(fake_images_y), 1)
        errG_cyc = _lambda * (
            torch.mean(torch.abs(cycled_images_x - real_images_x))
            + torch.mean(torch.abs(cycled_images_y - real_images_y))
        )
        # Identity term is weighted at one tenth of the cycle term.
        errG_id = 0.1 * _lambda * (
            torch.mean(torch.abs(identity_x - real_images_x))
            + torch.mean(torch.abs(identity_y - real_images_y))
        )

        errG = errG_adv + errG_id + errG_cyc
        errG.backward()
        optimizerG.step()
        optimizerH.step()

        if i % 100 == 0:
            print("Epoch %i Step %i --> Disc_Loss : %f   Gen_Loss : %f" % (epoch, i, errD.item(), errG.item()))

    # Periodic checkpointing (disabled). `path` must be defined before enabling.
    # if epoch % 100 == 0:
    #     torch.save(G.state_dict(), path + "cycgan_G.pth")
    #     torch.save(H.state_dict(), path + "cycgan_H.pth")
    #     torch.save(D_x.state_dict(), path + "cycgan_D_x.pth")
    #     torch.save(D_y.state_dict(), path + "cycgan_D_y.pth")

# Leave the discriminators trainable again once training finishes.
_set_requires_grad((D_x, D_y), True)

In [None]:
n_show = 10  # images per preview row


def _show_row(images, figsize):
    """Draw *images* side by side in a new figure, with axes hidden.

    Each image is assumed channels-first (C, H, W) with values in [-1, 1]; it
    is permuted to (H, W, C) and mapped to [0, 1] for imshow.  Accepts NumPy
    arrays or tensors on any device.
    """
    # squeeze=False keeps `axes` 2-D even for a single image.
    _, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for ax, img in zip(axes[0], images):
        img = torch.as_tensor(img).cpu().permute(1, 2, 0).numpy()
        # clip() mirrors what imshow does to out-of-range floats, minus the warning.
        ax.imshow(((img + 1) / 2).clip(0, 1))
        ax.axis("off")


# Sample without replacement so a row never repeats an image (replacement is
# used only if a domain has fewer than n_show images).
batch_idx = np.random.choice(len(x_train), size=n_show, replace=len(x_train) < n_show)
batch_idy = np.random.choice(len(y_train), size=n_show, replace=len(y_train) < n_show)
data_x = x_train[batch_idx]
data_y = y_train[batch_idy]

print("Actual images")
_show_row(data_x, figsize=(20, 20))
_show_row(data_y, figsize=(20, 20))
plt.show()

# Run inference in eval mode. If the generators contain BatchNorm/Dropout,
# train mode would use batch statistics and update BatchNorm running stats
# from this preview batch. The previous mode is restored afterwards.
was_training = (G.training, H.training)
G.eval()
H.eval()
try:
    with torch.no_grad():
        real_images_x = torch.as_tensor(data_x, dtype=torch.float32).to(device)
        real_images_y = torch.as_tensor(data_y, dtype=torch.float32).to(device)

        fake_images_y = G(real_images_x)
        fake_images_x = H(real_images_y)
finally:
    G.train(was_training[0])
    H.train(was_training[1])

print("Translated images")
_show_row(fake_images_y, figsize=(30, 30))
_show_row(fake_images_x, figsize=(30, 30))
plt.show()