### Evolutionary GAN

Based on [this paper](https://arxiv.org/abs/1803.00657), which trains the generator with an evolutionary setup: each child generator is mutated according to a different loss function, and the fittest child becomes the next parent. The generators in this notebook currently reproduce from a single parent generator, but the setup can be extended to work with multiple parents. [The official repo](https://github.com/WANG-Chaoyue/EvolutionaryGAN-pytorch) was consulted to figure out the gradient calculation used for the diversity fitness.

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

In [None]:
# CIFAR-10 training data. ToTensor yields values in [0, 1]; Normalize with
# mean = std = 0.5 rescales them to [-1, 1], the same range as the generator's
# tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Sanity check: pull one batch, report its shape and value range, and show the first image.
data_loader = iter(trainloader)
data, target = next(data_loader)

print(data.shape)
print(data.max(), data.min())
img = np.transpose(data[0], (1, 2, 0))  # CHW -> HWC for matplotlib
plt.imshow((img + 1) / 2)               # undo the [-1, 1] normalisation for display
plt.show()

In [None]:
class Basic_Gen(nn.Module):
  """DCGAN-style generator mapping a noise vector to a 3-channel image.

  A linear projection produces a 256-channel map at 1/8 of ``image_size``.
  Three conv + BatchNorm + leaky-ReLU stages, each followed by a 2x
  nearest-neighbour upsample, restore full resolution. A final conv with
  tanh yields pixels in [-1, 1].

  Args:
    image_size: (height, width) of the generated images. Both are
      floor-divided by 8 for the initial feature map, so multiples of 8 are
      reproduced exactly.
    noise_dim: length of the input noise vector.
  """

  def __init__(self, image_size = (32, 32), noise_dim = 128):
    super().__init__()
    # Spatial size of the projected feature map (upsampled 2x three times in forward).
    self.h = image_size[0] // 8
    self.w = image_size[1] // 8

    self.Proj = nn.Linear(noise_dim, 256 * self.h * self.w)

    # Conv stack 256 -> 128 -> 64 -> 32 -> 3; 3x3, stride 1, padding 1 keeps the spatial size.
    self.C1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
    self.C2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
    self.C3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
    self.C4 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)

    # B0 normalises the projection; B1..B3 follow C1..C3.
    self.B0 = nn.BatchNorm2d(256)
    self.B1 = nn.BatchNorm2d(128)
    self.B2 = nn.BatchNorm2d(64)
    self.B3 = nn.BatchNorm2d(32)

    # N(0, 0.02) init for conv weights only; biases, Proj and BatchNorm keep PyTorch defaults.
    for conv in (self.C1, self.C2, self.C3, self.C4):
      conv.weight.data.normal_(0.0, 0.02)

  def forward(self, z):
    """Map noise ``z`` (assumed (batch, noise_dim)) to images of shape (batch, 3, 8*h, 8*w)."""
    x = self.Proj(z).view(-1, 256, self.h, self.w)
    x = F.leaky_relu(self.B0(x))
    for conv, norm in ((self.C1, self.B1), (self.C2, self.B2), (self.C3, self.B3)):
      x = F.interpolate(F.leaky_relu(norm(conv(x))), scale_factor = 2)
    return torch.tanh(self.C4(x))

class Basic_Disc(nn.Module):
  """Spectrally-normalised convolutional discriminator returning one score per image.

  Three 4x4 stride-2 convs (3 -> 64 -> 128 -> 256 channels) with leaky ReLU
  halve each spatial side three times. A linear layer then maps the
  flattened features to a single unnormalised score. Every layer is wrapped
  in ``nn.utils.spectral_norm``.

  Args:
    image_size: (height, width) of input images; used to size the final linear layer.
  """

  def __init__(self, image_size = (32, 32)):
    super().__init__()
    self.h = image_size[0] // 8
    self.w = image_size[1] // 8

    channel_pairs = ((3, 64), (64, 128), (128, 256))
    convs = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1) for c_in, c_out in channel_pairs]
    # N(0, 0.02) weight init is applied before spectral norm reparameterises the weights.
    for conv in convs:
      conv.weight.data.normal_(0.0, 0.02)
    self.C1, self.C2, self.C3 = [nn.utils.spectral_norm(conv) for conv in convs]

    self.D = nn.utils.spectral_norm(nn.Linear(256 * self.h * self.w, 1))

  def forward(self, x):
    """Return raw (pre-sigmoid) scores of shape (batch, 1)."""
    for conv in (self.C1, self.C2, self.C3):
      x = F.leaky_relu(conv(x))
    return self.D(x.view(-1, 256 * self.h * self.w))

bce_loss = nn.BCEWithLogitsLoss()

def loss_func(index, output):
  """Return the generator mutation loss selected by ``index`` (E-GAN mutations).

  Args:
    index: which mutation to apply. 0 -> minimax, 1 -> heuristic
      (non-saturating), anything else -> least-squares.
    output: raw discriminator logits for generated images, e.g. shape (batch, 1).

  Returns:
    Scalar loss tensor that the child generator minimises.
  """
  # Labels are built from `output` itself so the function needs no global batch size/device.
  if index == 0:
    # Minimax: BCE against 0 is -mean(log(1 - sigmoid)), so this is 0.5 * E[log(1 - D(G(z)))].
    return -0.5*bce_loss(output, torch.zeros_like(output))
  elif index == 1:
    # Heuristic: BCE against 1 is -mean(log(sigmoid)), so this is -0.5 * E[log D(G(z))].
    return 0.5*bce_loss(output, torch.ones_like(output))
  else:
    # Least-squares, applied directly to the logits (no sigmoid).
    return torch.mean((output - 1)**2)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

noise_dim = 128
netG = Basic_Gen(image_size = (32,32), noise_dim = noise_dim).to(device)
netD = Basic_Disc(image_size = (32, 32)).to(device)

# Number of child generators mutated from the parent netG each iteration.
child_count = 3

child_Gs = [Basic_Gen(image_size = (32,32), noise_dim = noise_dim).to(device) for _ in range(child_count)]
child_opts = [optim.Adam(x.parameters(), lr = 0.0002, betas = (0.5, 0.999)) for x in child_Gs]


if torch.cuda.device_count() > 1:
    gpu_ids = list(range(torch.cuda.device_count()))
    netG = nn.DataParallel(netG, gpu_ids)
    netD = nn.DataParallel(netD, gpu_ids)
    # Weights move between netG and the children via state_dict()/load_state_dict().
    # DataParallel prefixes every state_dict key with "module.", so the children are
    # wrapped the same way to keep the keys compatible. child_opts still hold the same
    # parameter objects, because wrapping does not copy parameters.
    child_Gs = [nn.DataParallel(g, gpu_ids) for g in child_Gs]

# Both networks use Adam with lr = 2e-4 and betas = (0.5, 0.999). The learning rates
# are equal, so this is not the Two Timescale Update Rule (which gives D and G
# different learning rates).
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

In [None]:
epochs = 500
disc_steps = 2

path = "./saved_models/"

for epoch in range(epochs):
    for i, (data, target) in enumerate(trainloader):

        #Dealing with the discriminator################################
        #Specify number of disc updates above##############
        for k in range(disc_steps):
            netD.zero_grad()
        
            b_size = data.size(0)//disc_steps
            real_images = data[k*b_size:(k+1)*b_size].to(device)
            real_label, fake_label = torch.ones(b_size, 1).to(device), torch.zeros(b_size, 1).to(device)

            output = netD(real_images).view(b_size, -1)
            errD_real = bce_loss(output, real_label)

            noise = torch.randn(b_size, noise_dim, device = device)
            fake = netG(noise)

            output = netD(fake.detach()).view(b_size, -1)
            errD_fake = bce_loss(output, fake_label)

            errD = errD_fake + errD_real
            errD.backward()
            optimizerD.step()
            
        #Dealing with the generator###################################
        netG.zero_grad()
        for g, o in zip(child_Gs, child_opts):
          g.zero_grad()
          g.load_state_dict(netG.state_dict())
          o.load_state_dict(optimizerG.state_dict())
        F_scores, errGs = [], []

        for k in range(child_count):
            #Mutate the child
            netD.zero_grad()
            noise = torch.randn(b_size, noise_dim, device = device)
            fake = child_Gs[k](noise)
            fake_output = netD(fake).view(b_size, -1)

            errG = loss_func(k, fake_output)
            errGs.append(errG.data.cpu().numpy().item())
            errG.backward()
            child_opts[k].step()

            #Evaluate post mutation
            netD.zero_grad()
            noise = torch.randn(b_size, noise_dim, device = device)
            real_output = netD(real_images).view(b_size, -1)
            fake_output = netD(child_Gs[k](noise)).view(b_size, -1)

            Fq = torch.sigmoid(fake_output).data.mean().cpu().numpy()

            div_loss = bce_loss(real_output, real_label) + bce_loss(fake_output, fake_label)
            gradients = torch.autograd.grad(outputs=div_loss, inputs=netD.parameters(),
                                            grad_outputs=torch.ones(div_loss.size()).to(device),
                                            create_graph=True, retain_graph=True, only_inputs=True)
            with torch.no_grad():
                for p, grad in enumerate(gradients):
                    grad = grad.view(-1)
                    allgrad = grad if p == 0 else torch.cat([allgrad,grad]) 

            Fd = -torch.log(torch.norm(allgrad)).data.cpu().numpy()
            F_scores.append(Fq + 0.001*Fd)

        #Figure out best child
        best_index = np.argsort(F_scores)[-1]
        netG.load_state_dict(child_Gs[best_index].state_dict())
        optimizerG.load_state_dict(child_opts[best_index].state_dict())

        if i%100 == 0:
            print(epoch, epochs, i, len(trainloader), "D: ", errD.item(), "Gs: ", errGs, F_scores, best_index)
            
    if epoch%2 == 0:
        !nvidia-smi;

In [None]:
# Draw num_samples images from the parent generator in chunks of row_len and show them as a grid.
num_samples, row_len = 40, 10
noise = torch.randn(num_samples, noise_dim, device = device)

# NOTE(review): netG is never switched to eval() in this notebook, so its BatchNorm
# layers normalise with each chunk's own statistics; consider netG.eval() for sampling.
with torch.no_grad():
  fake = []
  # Step through the noise in row_len-sized chunks so every slice is non-empty and all samples are used.
  for start in range(0, num_samples, row_len):
    f = netG(noise[start:start + row_len])
    fake.append(f.cpu())
  fake = torch.cat(fake)

print(fake.shape)
grid = torchvision.utils.make_grid(fake, nrow = row_len, padding = 1, pad_value = 0.15)
f = plt.figure(figsize=(15,15))
# Generator output is in [-1, 1] (tanh); rescale to [0, 1] for imshow.
plt.imshow((grid.permute(1, 2, 0)+1)/2)
plt.axis('off')
plt.show()