In [None]:
import os
import numpy as np
import time
import torch
from torch import nn
from torch import optim
from tensorboardX import SummaryWriter
import torch.nn.functional as F

In [None]:
class Discriminator(nn.Module):
    """WGAN-GP critic: a DCGAN-style conv net mapping each image to one score.

    Four 4x4, stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
    and a final 4x4 valid convolution reduces it to 1x1, so a (N, C, 64, 64) batch
    yields scores of shape (N, 1, 1, 1).

    The score is deliberately left unbounded (no output activation): a Wasserstein
    critic estimates a difference of expectations, and squashing it into [-1, 1]
    with a Tanh saturates both the critic loss and the gradient penalty.

    Args:
        input_shape: (channels, height, width) of the inputs. Only the channel
            count is used to build the layers; 64x64 inputs give 1x1 scores.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # NOTE(review): the WGAN-GP paper advises against batch norm in the critic
        # because the gradient penalty is applied per sample (layer norm is the
        # suggested replacement). Kept so existing discriminator checkpoints load.
        self.dis = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Final linear score: no activation (see class docstring).
            nn.Conv2d(512, 1, kernel_size=4, stride=1, bias=False),
        )

    def forward(self, x):
        """Return the critic scores for a batch ``x`` of images."""
        return self.dis(x)

    def _get_conv_out(self, shape):
        """Return how many values the critic outputs for a single input of ``shape``.

        The probe forward pass runs in eval mode and without autograd, so the
        BatchNorm running statistics are not overwritten by the all-zeros input.
        The original training/eval mode is restored afterwards.
        """
        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                o = self.dis(torch.zeros(1, *shape, device=device))
        finally:
            self.train(was_training)
        return int(np.prod(o.size()))

class Generator(nn.Module):
    """DCGAN generator turning latent vectors into 3-channel images in [-1, 1].

    For a latent batch of shape (N, z_size, 1, 1) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, giving images of shape (N, 3, 64, 64).
    Every hidden up-sampling stage is ConvTranspose2d -> BatchNorm2d -> ReLU;
    the last stage is ConvTranspose2d -> Tanh.

    Args:
        z_size: number of channels of the latent input.
    """

    def __init__(self, z_size):
        super().__init__()
        # Project the 1x1 latent to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(in_channels=z_size, out_channels=512, kernel_size=4, stride=1, bias=False),
            nn.BatchNorm2d(512, momentum=0.1),
            nn.ReLU(inplace=True),
        ]
        # Three doubling stages, halving the channel count each time.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch, momentum=0.1),
                nn.ReLU(inplace=True),
            ])
        # Final doubling to RGB, squashed into [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of images from latent tensor ``x``."""
        return self.gen(x)
    

In [None]:
latent_size = 128
n_critic = 5
labmda = 10
adam_betas = (0.5, 0.9)
image_shape = (3, 64, 64)
batch_size = 64
lrD = 1e-4
lrG = 1e-4
epochs = 100
DEVICE='cpu'

In [None]:
from PIL import Image
import numpy as np
import tarfile
from os import listdir
from os.path import isfile, join
from keras.preprocessing.image import load_img, img_to_array
from torchvision import transforms
import functools

class ArtLoader:
    """Iterate over class-balanced batches of paintings stored one folder per style.

    Images are read from ``location/<style>/`` for every style in ``self.styles``.
    Each image is resized to 64x64, put in channel-first order and scaled from
    [0, 255] to [-1, 1] as float32.

    Every batch contains ``batch_size // len(styles)`` images per style. The
    ``batch_size % len(styles)`` leftover slots rotate between styles from batch
    to batch, so every yielded batch has exactly ``batch_size`` images.

    A per-style cursor lives on the instance and is never reset by ``__iter__``.
    Consecutive epochs therefore continue where the previous one stopped, and
    each cursor wraps at the end of its style's file list.

    Args:
        location: root directory containing one sub-directory per style.
        batch_size: number of images per yielded batch; must be positive.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, location = './data/', batch_size=20):
        if batch_size <= 0:
            raise ValueError("batch size must be positive")

        self.location = location
        self.batch_size = batch_size
        # Keep the trailing commas: adjacent string literals silently concatenate.
        self.styles = [
            'abstract',
            'animal-painting',
            'cityscape',
            'figurative',
            'flower-painting',
            'genre-painting',
            'landscape',
            'marina',
            'mythological-painting',
            'nude-painting-nu',
            'portrait',
            'religious-painting',
            'still-life',
            'symbolic-painting',
        ]
        # Guaranteed per-style share of each batch; the remainder is rotated in
        # _class_quotas instead of being rejected.
        self.class_batch = batch_size // len(self.styles)

        self.curPos = {style: 0 for style in self.styles}
        self.style_files = {}
        for style in self.styles:
            style_dir = join(location, style)
            self.style_files[style] = [f for f in listdir(style_dir) if isfile(join(style_dir, f))]
        self.file_count = sum(len(files) for files in self.style_files.values())
        self.n_batches = self.file_count // batch_size

    def _process_image(self, x):
        """Scale a uint8 array from [0, 255] to float32 in [-1, 1]."""
        x = x.astype(np.float32) / 255.0 * 2 - 1
        return x

    def _class_quotas(self, batch_index):
        """Return how many images each style contributes to batch ``batch_index``.

        Every style gets ``self.class_batch``; the ``batch_size % len(styles)``
        extra slots start at a position that advances every batch, so over time
        no style is systematically over-represented. The quotas sum to batch_size.
        """
        n_styles = len(self.styles)
        base, extra = divmod(self.batch_size, n_styles)
        start = (batch_index * extra) % n_styles
        return [base + (1 if (i - start) % n_styles < extra else 0) for i in range(n_styles)]

    def _next_image(self, style):
        """Load the next readable image of ``style`` as a (3, 64, 64) uint8 array.

        The style's cursor advances (and wraps) past every file tried, including
        unreadable ones, which are reported and skipped.

        Raises:
            ValueError: if the style directory contains no files.
            RuntimeError: if none of the style's files can be read (instead of
                looping forever).
        """
        files = self.style_files[style]
        if not files:
            raise ValueError("no image files found for style " + repr(style))
        for _ in range(len(files)):
            file = files[self.curPos[style]]
            self.curPos[style] = (self.curPos[style] + 1) % len(files)
            path = join(self.location, style, file)
            try:
                img = load_img(path)  # this is a PIL image
                # resize stays inside the try: decoding may be deferred until here
                x = np.array(img.resize((64, 64)))
            except OSError:
                print('Error getting file! Skipping...', path)
                continue
            return np.rollaxis(x, 2, 0)  # HWC -> CHW
        raise RuntimeError("could not read any image file for style " + repr(style))

    def __iter__(self):
        """Yield ``self.n_batches`` float32 batches of shape (batch_size, 3, 64, 64)."""
        self.current_batch = 0
        while self.current_batch < self.n_batches:
            dataset = []
            quotas = self._class_quotas(self.current_batch)
            for style, quota in zip(self.styles, quotas):
                for _ in range(quota):
                    dataset.append(self._next_image(style))
            self.current_batch += 1
            # A generator ends with `return`; `raise StopIteration` here would be
            # turned into a RuntimeError (PEP 479).
            yield self._process_image(np.array(dataset))
            
from IPython.display import display
def showX(X, name='test.jpeg', rows=1, size=64, out_dir='./output/'):
    """Tile a batch of images into a grid, save it to disk and display it inline.

    Args:
        X: array of N images with values in [-1, 1] that can be reshaped to
            (N, 3, size, size), e.g. generator output of that shape.
        name: file name of the saved grid inside ``out_dir``.
        rows: number of grid rows; N must be divisible by it.
        size: height and width of each (square) image, in pixels.
        out_dir: directory to save into; created if it does not exist.
    """
    assert X.shape[0] % rows == 0
    # [-1, 1] floats -> [0, 255] bytes
    int_X = ((X + 1) / 2 * 255).clip(0, 255).astype('uint8')
    # N*3*size*size -> (N, size, size, 3)
    int_X = np.moveaxis(int_X.reshape(-1, 3, size, size), 1, 3)
    # (rows, cols, size, size, 3) -> (rows, size, cols, size, 3) -> (rows*size, cols*size, 3)
    int_X = int_X.reshape(rows, -1, size, size, 3).swapaxes(1, 2).reshape(rows * size, -1, 3)
    grid = Image.fromarray(int_X)
    os.makedirs(out_dir, exist_ok=True)
    grid.save(os.path.join(out_dir, name))
    display(grid)
    

In [None]:
def _set_requires_grad(net, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad_(flag)


def enable_gradients(net):
    """Make autograd track all parameters of ``net`` again."""
    _set_requires_grad(net, True)


def disable_gradients(net):
    """Stop autograd from tracking any parameter of ``net`` (freeze it)."""
    _set_requires_grad(net, False)


# WGAN-GP training loop.
# Schedule: the critic is updated on every batch. The generator is updated once
# every `_n_critic` critic updates. During the first `warmup_gen_steps` generator
# updates the critic gets `warmup_n_critic` updates per generator update (the
# warm-up of the reference WGAN loop); after that it gets `n_critic`.
warmup_gen_steps = 25    # generator updates that use the longer critic schedule
warmup_n_critic = 100    # critic updates per generator update during warm-up
output_dir = './output/'


def gradient_penalty(critic, real, fake, alpha):
    """Return the WGAN-GP penalty mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2).

    ``x_hat = alpha * real + (1 - alpha) * fake`` is a random interpolation of
    paired real and generated samples. The gradient norm is taken per sample over
    ALL non-batch dimensions (flattened), not over a single axis.

    Args:
        critic: module returning one score per sample.
        real: batch of real images, e.g. (N, C, H, W).
        fake: batch of generated images with the same shape as ``real``.
        alpha: interpolation weights broadcastable to ``real``, e.g. (N, 1, 1, 1).

    Returns:
        Scalar tensor built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.
    """
    interpolation = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    critic_out = critic(interpolation)
    gradients = torch.autograd.grad(outputs=critic_out.sum(), inputs=interpolation, create_graph=True)[0]
    grad_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


os.makedirs(output_dir, exist_ok=True)

art_data = ArtLoader('./data/', batch_size)
gen_net = Generator(latent_size).to(DEVICE)
dis_net = Discriminator(image_shape).to(DEVICE)
print(dis_net)
dis_optimizer = optim.Adam(dis_net.parameters(), lr=lrD, betas=adam_betas)
gen_optimizer = optim.Adam(gen_net.parameters(), lr=lrG, betas=adam_betas)
# Reusable buffers, refilled in place on every step.
noise = torch.empty(batch_size, latent_size, 1, 1, device=DEVICE)
alpha = torch.empty(batch_size, 1, 1, 1, device=DEVICE)
# Fixed latent vectors so the per-epoch preview grids are comparable.
fixed_noise = torch.randn(batch_size, latent_size, 1, 1, device=DEVICE)
dis_loss_list, gen_loss_list = [], []
gen_count = 0         # generator updates so far (global: NOT reset per epoch)
critic_count = 0      # critic updates since the last generator update
seen_batch_count = 0  # real batches consumed so far (TensorBoard x-axis)
writer = SummaryWriter(comment="gantoken-main")

try:
    for epoch in range(epochs):
        for art_batch in art_data:
            # ---- critic update (every batch) ----
            disable_gradients(gen_net)
            enable_gradients(dis_net)
            image_data_v = torch.from_numpy(art_batch).to(DEVICE)
            noise.normal_()
            with torch.no_grad():
                fake_data_v = gen_net(noise)
            # Wasserstein term: E[critic(fake)] - E[critic(real)]
            critic_loss = dis_net(fake_data_v).mean() - dis_net(image_data_v).mean()
            alpha.uniform_()
            critic_loss = critic_loss + labmda * gradient_penalty(dis_net, image_data_v, fake_data_v, alpha)
            dis_optimizer.zero_grad()
            critic_loss.backward()
            dis_optimizer.step()
            dis_loss_list.append(critic_loss.item())
            critic_count += 1
            writer.add_scalar("dis_loss_average", np.mean(dis_loss_list[-100:]), seen_batch_count)
            writer.add_scalar("dis_loss", dis_loss_list[-1], seen_batch_count)

            # ---- generator update every `_n_critic` critic updates ----
            _n_critic = warmup_n_critic if gen_count < warmup_gen_steps else n_critic
            if critic_count >= _n_critic:
                critic_count = 0
                # freeze the critic so only the generator receives gradients
                disable_gradients(dis_net)
                enable_gradients(gen_net)
                gen_optimizer.zero_grad()
                noise.normal_()
                gen_loss = -dis_net(gen_net(noise)).mean()
                gen_loss.backward()
                gen_optimizer.step()
                gen_count += 1
                gen_loss_list.append(gen_loss.item())
                writer.add_scalar("gen_loss_average", np.mean(gen_loss_list[-100:]), seen_batch_count)
                writer.add_scalar("gen_loss", gen_loss_list[-1], seen_batch_count)

            seen_batch_count += 1

        # ---- end of epoch: preview grid + checkpoints ----
        gen_net.eval()
        with torch.no_grad():
            fake = gen_net(fixed_noise)
        gen_net.train()  # back to batch statistics for the next epoch
        showX(fake.cpu().numpy(), str(epoch) + '-epoch.jpeg', 4)
        torch.save(gen_net.state_dict(), os.path.join(output_dir, 'generator.h5'))
        torch.save(dis_net.state_dict(), os.path.join(output_dir, 'discriminator.h5'))
finally:
    writer.close()
