In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
from models.cycgan import Star_Generator, Star_Patch_Discriminator
from utils import load_data

#from google.colab import drive
#drive.mount("/content/drive")

In [None]:
# Each folder is one translation domain; a folder's position in this list is
# the integer class label assigned to every image loaded from it.
pathnames = [
    "./wikiart/Minimalism",
    "./wikiart/Impressionism",
    "./wikiart/Pop_Art",
]
no_of_domains = len(pathnames)

x_train = np.array([])
y_train = []

for domain_idx, domain_dir in enumerate(pathnames):
    print("Loading images from", domain_dir)
    domain_images = load_data(domain_dir, image_size=(128, 128), block_size=1500)
    # Rescale so that 0 -> -1 and 255 -> 1
    # (assumes load_data returns 0-255 pixel values -- TODO confirm).
    domain_images = (domain_images / 255) * 2 - 1
    # One label per loaded image, all equal to this domain's index.
    y_train.extend([domain_idx] * len(domain_images))
    # Stack onto what has been loaded so far; the first non-empty block
    # simply becomes the array.
    if x_train.size:
        x_train = np.vstack([x_train, domain_images])
    else:
        x_train = domain_images

# Convert to tensors: float images, integer (long) labels.
x_train = torch.Tensor(x_train)
y_train = torch.LongTensor(np.array(y_train, dtype=np.int32))

# Quick sanity check of shapes and value ranges.
print(x_train.shape, y_train.shape)
print(x_train.max(), x_train.min(), y_train.max(), y_train.min())


In [None]:
# Sanity check: draw ten random indices (with replacement), report the batch
# statistics and display the first image of the batch.
batch_id = np.random.choice(len(x_train), size=10)
data = x_train[batch_id]

print("data")
print(data.shape)
print(data.max(), data.min())

# Move axis 0 to the end for matplotlib (presumably channels-first ->
# channels-last), then map the [-1, 1] range back to [0, 1] for display.
img = data[0].permute(1, 2, 0).numpy()
fig, ax = plt.subplots()
ax.imshow((img + 1) / 2)
plt.show()

In [None]:
def GradientPenalty(discriminator_model, real_data, generated_data, gp_weight = 10):
    """Gradient penalty on random interpolations between real and generated data.

    For each example a random point on the segment between the real and the
    generated sample is drawn, the discriminator is evaluated there, and the
    squared deviation of the input-gradient norm from 1 is penalised.

    Args:
        discriminator_model: callable returning a 2-tuple; only the first
            element (the critic score) is differentiated, the second is ignored.
        real_data: tensor whose first dimension is the batch dimension.
        generated_data: tensor broadcast-compatible with ``real_data`` and on
            the same device. No gradient flows back into it.
        gp_weight: multiplier applied to the mean penalty.

    Returns:
        Scalar tensor ``gp_weight * mean((||grad||_2 - 1) ** 2)``, differentiable
        with respect to the discriminator's parameters.
    """
    batch_size = real_data.size(0)

    # One mixing coefficient per example, broadcast over all remaining dims.
    # It is created on the data's own device (rather than "CUDA if available")
    # so CPU tensors keep working on a machine that happens to have a GPU.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)

    # Detach so the interpolation is a fresh leaf: we need d(score)/d(input)
    # only, not gradients flowing back into whatever produced the inputs.
    interpolated = (alpha * real_data + (1 - alpha) * generated_data).detach()
    interpolated.requires_grad_(True)

    # Critic score at the interpolated points
    prob_interpolated, _ = discriminator_model(interpolated)

    # create_graph=True keeps the gradient itself differentiable, so the
    # penalty can be back-propagated into the discriminator weights.
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True, retain_graph=True)[0]

    # Flatten per example; reshape (not view) also handles gradients that
    # come back as non-contiguous tensors.
    gradients = gradients.reshape(batch_size, -1)

    # The derivative of sqrt blows up at 0, so compute the norm manually
    # with a small epsilon under the root.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    return gp_weight * ((gradients_norm - 1) ** 2).mean()


In [None]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Generator and discriminator, both conditioned on `no_of_domains` labels.
G = Star_Generator(img_size = (128,128), cond_length = no_of_domains).to(device)
D = Star_Patch_Discriminator(img_size = (128,128), in_channels = 3, cond_length = no_of_domains).to(device)

# Kaiming-normal initialisation for every conv / transposed-conv / linear
# weight in both networks (biases keep their default initialisation).
# `kaiming_normal_` is the in-place replacement for the deprecated
# `kaiming_normal`; G is initialised before D so the random draws happen in
# the same order as before.
for model in (G, D):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight)

#Optimizers (Adam with beta1 = 0, i.e. no first-moment momentum)
optimizerD = optim.Adam(D.parameters(), lr = 0.0002, betas = (0.0, 0.9))
optimizerG = optim.Adam(G.parameters(), lr = 0.0002, betas = (0.0, 0.9))

In [None]:
epochs = 300
_lambda = 10       # weight of the cycle loss; the identity loss uses 0.1 * _lambda
batch_size = 16
disc_steps = 3     # discriminator updates per generator update
cls_criterion = nn.CrossEntropyLoss()

# The generator step reuses the batch (and the fake images) produced by the
# last discriminator step, so at least one discriminator step is required.
if disc_steps < 1:
    raise ValueError("disc_steps must be >= 1")

for epoch in range(epochs):
    # Batches are drawn at random with replacement, so an "epoch" here is
    # len(x_train)//batch_size steps rather than one exact pass over the data.
    for i in range(len(x_train)//batch_size):
        for k in range(disc_steps):
            #Dealing with the discriminators################################
            D.zero_grad()

            batch_id = np.random.choice(len(x_train), size = batch_size)
            real_images = x_train[batch_id].to(device)
            targets = y_train[batch_id].to(device)

            #Random target domain for each image to be translated into
            fake_targets = torch.from_numpy(
                np.random.randint(0, no_of_domains, size = batch_size)
            ).long().to(device)

            # Real batch: critic score (maximised) and domain classification.
            # CrossEntropyLoss already averages over the batch.
            output, pred_cls = D(real_images)
            errD_real = -torch.mean(output)
            errD_cls = cls_criterion(pred_cls, targets)
            # Predicted vs. true domain, kept for the accuracy read-out below
            # (argmax of the logits equals argmax of their softmax).
            a1 = pred_cls.argmax(dim = 1)
            a2 = targets

            fake_images = G(real_images, fake_targets)
            # Detached copy: the discriminator update must not push
            # gradients into the generator.
            fake_detached = fake_images.detach()

            output, _ = D(fake_detached)
            errD_fake = torch.mean(output)

            GP = GradientPenalty(D, real_images, fake_detached, gp_weight = 10)

            errD = errD_fake + errD_real + errD_cls + GP
            errD.backward()

            optimizerD.step()

        #Dealing with the generators###################################
        # Uses real_images / targets / fake_targets / fake_images from the last
        # discriminator step; G's weights have not changed since then, so
        # fake_images' graph is still valid. backward() below also accumulates
        # gradients in D, which D.zero_grad() clears at the next critic step.
        G.zero_grad()

        output, pred_cls = D(fake_images)
        b1 = pred_cls.argmax(dim = 1)
        b2 = fake_targets

        # Adversarial term plus "fool the classifier into the target domain"
        errG_adv = -torch.mean(output)
        errG_cls = cls_criterion(pred_cls, fake_targets)

        # Cycle consistency: translating back to the source domain should
        # recover the input (L1 distance).
        cycled_images = G(fake_images, targets)
        errG_cyc = _lambda * torch.mean(torch.abs(cycled_images - real_images))

        # Identity: translating into the image's own domain should be a no-op.
        id_images = G(real_images, targets)
        errG_id = 0.1 * _lambda * torch.mean(torch.abs(id_images - real_images))

        errG = errG_adv + errG_cls + errG_cyc + errG_id
        errG.backward()

        optimizerG.step()

        if i%100 == 0:
            # .item() turns the graph-attached scalars into plain floats
            acc_real = (a1 == a2).float().mean().item()
            acc_fake = (b1 == b2).float().mean().item()
            print("Epoch %i Step %i --> Disc_Loss : %f (Cls %f)  Gen_Loss : %f (Cls %f Cyc %f Id %f) Acc R: %f, F: %f"
                  % (epoch, i, errD.item(), errD_cls.item(), errG.item(), errG_cls.item(),
                     errG_cyc.item(), errG_id.item(), acc_real, acc_fake))

In [None]:
path = "./saved_models/"
# torch.save does not create missing directories; make sure the target
# exists so a finished training run is not lost to a failed save.
os.makedirs(path, exist_ok=True)
torch.save(G.state_dict(), os.path.join(path, "stargan_G.pth"))
torch.save(D.state_dict(), os.path.join(path, "stargan_D.pth"))

In [None]:
# Pick ten random training images (with replacement) together with their labels.
batch_idx = np.random.choice(len(x_train), size = 10)
data_x, data_y = x_train[batch_idx], y_train[batch_idx]


def _show_image_row(images):
    """Draw the first ten images of `images` side by side without axes.

    Each image has axis 0 moved to the end for matplotlib and is mapped from
    the [-1, 1] range back to [0, 1] before display.
    """
    f, a = plt.subplots(1, 10, figsize=(30, 30))
    for col, ax in enumerate(a):
        img = images[col].permute(1, 2, 0).numpy()
        ax.imshow((img + 1) / 2)
        ax.axis("off")
    plt.show()


print("Actual images")
_show_image_row(data_x)

# Translate the same ten images into every domain in turn.
for k in range(no_of_domains):
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        real_images = torch.Tensor(data_x).to(device)
        # Every image gets the same target domain k.
        fake_targets = np.full(10, k)
        print(fake_targets[0])
        fake_targets = torch.LongTensor(fake_targets).to(device)
        fake_images = G(real_images, fake_targets)

    print("Translated images (Target", pathnames[k].split("/")[-1],")")
    # Bring the results back to host memory before plotting.
    _show_image_row(fake_images.cpu())
