### Code Authors
Özgür Aslan 2236958 aslan.ozgur@metu.edu.tr
Burak Bolat 2237097 burak.bolat@metu.edu.tr

### Paper Information
The paper we selected to implement is [HistoGAN: Controlling Colors of GAN-Generated and Real Images via Color Histograms](https://arxiv.org/abs/2011.11731).
The main idea of the paper is to use the color histograms of target images to control the colors of a generated image without changing its high-level features (gender, glasses, beard, hairstyle, objects in the background, ...).  
To accomplish this, the authors modify the StyleGAN2 architecture:
- In the last 2 style blocks, instead of an affine transformation of the w vector, they use the color histogram projected by a neural network.
- Instead of mixing regularization, they use a histogram-based loss.
- The histogram loss computes the color histograms of two different target images and interpolates them to obtain a new one. The interpolated histogram is given to the generator, which generates an image whose colors are controlled by it. This way, the authors try to prevent the generator from overfitting to the color histograms of the training dataset.
- Due to hardware limitations, the network generates 256x256 images instead of 1024x1024.
- Also due to hardware limitations, they use a batch size of 2 with gradient accumulation.
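The histogram interpolation in the list above can be sketched as follows. This is a minimal sketch under our own assumptions: the histograms are already normalized, each sample is paired with a randomly chosen partner from the same batch, and the mixing weight is drawn uniformly from [0, 1). `interpolate_histograms` is a hypothetical name and not necessarily how the repository's `random_interpolate_hists` works.

```python
import numpy as np


def interpolate_histograms(hists, rng):
    """Mix each histogram in a batch with the histogram of a random partner.

    hists: (B, 3, bins, bins) array of histograms that each sum to 1.
    Returns an array of the same shape; convex combinations stay normalized.
    """
    b = hists.shape[0]
    # Pick a second target image for every sample (it may be the sample itself)
    partner = rng.permutation(b)
    # One random mixing weight per sample, broadcast over channels and bins
    t = rng.uniform(size=(b, 1, 1, 1))
    return t * hists + (1.0 - t) * hists[partner]
```

Because the result is a convex combination of two normalized histograms, it is itself a valid histogram; when a sample is paired with itself, its histogram is returned unchanged.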

![arch](materials/arch.png)

#### Histogram Computation
Computing the histogram is critical, since it directly affects the style of the last 2 blocks. The authors use a log-chroma (logarithmic chrominance) space, which normalizes each color channel with respect to the other two channels in logarithmic space. For each channel, this space has two axes, u and v. For example, in the red channel's chroma space, u is the red channel normalized with respect to green and v is the red channel normalized with respect to blue. The same holds for the green and blue channels.
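Written out, the projection described above looks as follows. The small constant ε (to avoid log 0) and the assignment of u and v for the green and blue channels are our assumptions, following the pattern given for red:

$$
\begin{aligned}
I_{uR} &= \log(I_R + \epsilon) - \log(I_G + \epsilon), & I_{vR} &= \log(I_R + \epsilon) - \log(I_B + \epsilon),\\
I_{uG} &= \log(I_G + \epsilon) - \log(I_R + \epsilon), & I_{vG} &= \log(I_G + \epsilon) - \log(I_B + \epsilon),\\
I_{uB} &= \log(I_B + \epsilon) - \log(I_R + \epsilon), & I_{vB} &= \log(I_B + \epsilon) - \log(I_G + \epsilon).
\end{aligned}
$$

For example, a gray pixel (R = G = B) maps to u = v = 0 in all three channel spaces, regardless of its brightness; brightness only enters through the intensity weighting.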

After projecting RGB values to the RGB-uv space, the histogram is computed there, since this is computationally efficient and more stable. The authors use 64 bins, which results in a 64x64 histogram over the u and v axes. Since there are 3 channels (red, green and blue), the full histogram is 3x64x64. The histogram is weighted by pixel intensity, i.e. a pixel with high RGB values contributes more to the histogram bins. The last difference from the histograms of previous works is the kernel used for binning. The authors do not use hard bin assignment; instead, each normalized pixel is distributed over the bins by a soft kernel. That is, after normalizing the red channel with respect to green and blue, a pixel has some (u, v) values. Instead of adding 1 (chosen here for simplicity; recall the intensity weighting) only to the bin H(u, v), the values of an inverse-quadratic kernel are added to the bins around (u, v).
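The soft, intensity-weighted binning described above can be sketched in NumPy. Several details are assumptions not stated in the text: the intensity weight is the Euclidean norm sqrt(R² + G² + B²), the inverse-quadratic kernel is k(x, c) = 1 / (1 + ((x - c) / τ)²) with τ = 0.02, the bin centers span [-3, 3], and ε = 1e-6 guards against log(0). `rgb_uv_histogram` is a hypothetical helper, not the repository's `histogram_feature_v2`.

```python
import numpy as np


def rgb_uv_histogram(img, bins=64, tau=0.02, eps=1e-6, uv_range=(-3.0, 3.0)):
    """Soft RGB-uv histogram of an (H, W, 3) float image with values in [0, 1].

    Returns an array of shape (3, bins, bins) normalized to sum to 1.
    """
    pixels = img.reshape(-1, 3).astype(np.float64)
    # Per-pixel intensity weight (assumed: Euclidean norm of the RGB vector)
    intensity = np.sqrt(np.sum(pixels ** 2, axis=1))
    log_rgb = np.log(pixels + eps)
    centers = np.linspace(uv_range[0], uv_range[1], bins)
    hist = np.zeros((3, bins, bins))
    for c in range(3):
        # The two other channels, e.g. (G, B) for red: u = log R - log G, v = log R - log B
        others = [k for k in range(3) if k != c]
        u = log_rgb[:, c] - log_rgb[:, others[0]]
        v = log_rgb[:, c] - log_rgb[:, others[1]]
        # Inverse-quadratic kernel between every pixel and every bin center: (N, bins)
        k_u = 1.0 / (1.0 + ((u[:, None] - centers[None, :]) / tau) ** 2)
        k_v = 1.0 / (1.0 + ((v[:, None] - centers[None, :]) / tau) ** 2)
        # Intensity-weighted outer products summed over all pixels: (bins, bins)
        hist[c] = (k_u * intensity[:, None]).T @ k_v
    return hist / (hist.sum() + eps)
```

Each pixel contributes to all bins, with the contribution decaying quadratically with distance from its (u, v) position; this keeps the histogram differentiable with respect to the pixel values. The sketch is only meant for small images, since memory grows with pixels × bins.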

The histogram feature is computed like the style vector (w): the histogram passes through the same neural network architecture with different parameters. More precisely, it passes through an 8-layer MLP that outputs a latent histogram vector of size 512.
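A forward pass of such a projection network can be sketched in NumPy with random weights. The hidden width of 512, the leaky ReLU (slope 0.2) after every layer and the He-style initialization are our assumptions; `init_hist_mlp` and `project_histogram` are hypothetical names.

```python
import numpy as np


def init_hist_mlp(in_dim=3 * 64 * 64, width=512, depth=8, seed=0):
    """Random weights and zero biases for a depth-layer MLP: in_dim -> width."""
    rng = np.random.default_rng(seed)
    dims = [in_dim] + [width] * depth
    return [
        (rng.standard_normal((d_in, d_out)) * np.sqrt(2.0 / d_in), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]


def project_histogram(hist, layers, slope=0.2):
    """Map a batch of (B, 3, 64, 64) histograms to (B, 512) latent vectors."""
    x = hist.reshape(hist.shape[0], -1)  # flatten to (B, 12288)
    for weight, bias in layers:
        x = x @ weight + bias
        x = np.where(x > 0, x, slope * x)  # leaky ReLU
    return x
```

The latent histogram vector then takes the place of w in the last 2 style blocks, so it plays the same role as the output of the mapping network.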

Below are some histograms we computed. The images were collected by crawling the internet.  

![asd](materials/gresized.png) ![asc](materials/ghist1.png)  
![asb](materials/rresized.png) ![asj](materials/rhist1.png)

#### Loss for Training with Histogram
Since the paper generates images from a target histogram, the generated image should have a histogram close to the target. Thus, the Hellinger distance between the histograms of the generated and target images is computed as a closeness measure and minimized. The losses are shown below.

Difference between histograms  
![l1](materials/hloss.png)

Total loss for generator  
![lt](materials/total_loss.png)
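The histogram difference can be sketched as a Hellinger distance in NumPy, assuming both histograms are normalized to sum to 1 and using the form (1/√2)·‖√H_t - √H_g‖₂ (our reading of the loss; the figure above is authoritative). `hellinger_histogram_loss` is a hypothetical name.

```python
import numpy as np


def hellinger_histogram_loss(hist_generated, hist_target):
    """Hellinger distance between two normalized histograms of any shape."""
    # Difference of element-wise square roots, over all channels and bins
    diff = np.sqrt(hist_target) - np.sqrt(hist_generated)
    return np.sqrt(np.sum(diff ** 2)) / np.sqrt(2.0)
```

The distance is 0 for identical histograms and 1 for histograms with disjoint support. In an autograd framework, the square root has an infinite gradient at 0, so a small epsilon is usually added to the histograms before taking it.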

### Discriminator

The discriminator consists of residual blocks. There are log_2(N)-1 such blocks, where N is the image resolution (here 256), so the discriminator has 7 blocks. The first block takes a 3-channel image as input and outputs m-channel features. Each subsequent block outputs twice as many channels as the previous one. After the residual blocks, a fully connected layer outputs a scalar.
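The block count and channel schedule above can be computed directly. This sketch assumes m equals the notebook's `network_capacity` (16); the exact relation in the repository's `Discriminator` may differ, and `discriminator_channels` is a hypothetical helper.

```python
import math


def discriminator_channels(resolution=256, m=16):
    """Number of residual blocks and output channels of each block."""
    num_blocks = int(math.log2(resolution)) - 1
    # First block outputs m channels; each later block doubles the previous count
    channels = [m * 2 ** i for i in range(num_blocks)]
    return num_blocks, channels


num_blocks, channels = discriminator_channels(256, 16)
```

For 256x256 images this gives 7 blocks with 16, 32, 64, 128, 256, 512 and 1024 output channels, after which the fully connected layer maps the features to a single score.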

![res](materials/residual.png)

### Important Note on Dataset
The Anime Face Dataset is a Kaggle dataset, so downloading it requires a Kaggle account. With an account, it can be downloaded from:
https://www.kaggle.com/datasets/splcher/animefacedataset


### Faced Challenges

#### Architecture Challenges

As stated above, HistoGAN is built on StyleGAN2. StyleGAN2 scales the weights of the convolution filters directly, whereas the first version of StyleGAN performs nearly the same operations (the counterpart of mod-demod) on the convolved feature maps rather than on the filter weights. The original StyleGAN2 is implemented in TensorFlow, which allows multiplying the weights directly, i.e. in-place operations on variables. However, PyTorch does not allow such in-place operations on the weights of built-in modules like torch.nn.Conv2d. Therefore, we implemented our own conv2d: the model parameters are PyTorch parameters, and the convolution is computed with PyTorch's fold and unfold operations. This way, we can apply the convolution after scaling the weights of the convolution filters.
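The idea of scaling the filter weights and then convolving via unfolded patches can be sketched in NumPy for a single sample. Assumptions: stride 1, zero padding, a per-input-channel style vector, and `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20+) standing in for PyTorch's unfold. `modulate_demodulate` and `conv2d_unfold` are hypothetical names, not the repository's code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def modulate_demodulate(weight, style, eps=1e-8):
    """weight: (C_out, C_in, k, k); style: (C_in,) scales for one sample."""
    # Modulation: scale each input channel of every filter by the style
    w = weight * style[None, :, None, None]
    # Demodulation: rescale each output filter to (approximately) unit L2 norm
    sigma = np.sqrt(np.sum(w ** 2, axis=(1, 2, 3), keepdims=True) + eps)
    return w / sigma


def conv2d_unfold(x, weight, padding=1):
    """Stride-1 convolution (cross-correlation) of x: (C_in, H, W) with weight."""
    k = weight.shape[-1]
    x_pad = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    # Unfold into patches of shape (C_in, H_out, W_out, k, k)
    patches = sliding_window_view(x_pad, (k, k), axis=(1, 2))
    # Contract over input channels and kernel positions -> (C_out, H_out, W_out)
    return np.einsum("chwij,ocij->ohw", patches, weight)
```

Because the style is different for every sample, a batch needs one set of modulated weights per sample; this per-sample weight scaling is what makes a plain shared-weight convolution module insufficient here.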

#### Saturation of Generator

After implementing StyleGAN2 and HistoGAN, we tried to train the models. We saw that the HistoGAN generator does not learn, so we tried to train StyleGAN alone (in the shallower form that HistoGAN uses, which produces 256x256 images). However, the generator saturates rapidly, and we have not solved this yet. Here we present some images generated during training.

![m1](materials/fake_0_199_0.png) ![m2](materials/fake_0_199_1.png) ![m3](materials/fake_0_399_0.png) ![m4](materials/fake_0_599_0.png) 

#### Training Challenges

During training, we found that the paper does not mention how the generator outputs images (i.e. its output activation). We tried different assumptions, such as using sigmoid or tanh to produce pixels in a bounded range, or ReLU / leaky ReLU, which we saw in other generator implementations.  
StyleGAN2 states that the authors used the non-saturating loss for some datasets and the WGAN-GP loss for others, and the HistoGAN paper does not clearly specify which one it uses. Consequently, we implemented both, but the non-saturating loss led to numerical issues (NaNs or infs). On the other hand, WGAN-GP produces very large loss values and results in rapid saturation (see the figures above). These issues may be caused by our hand-implemented convolution operations.
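One common source of NaN or inf values with the non-saturating loss is evaluating log(sigmoid(·)) directly, which hits log(0) for large-magnitude logits. A minimal sketch writes the loss in terms of raw discriminator logits with a stable softplus (`np.logaddexp(0, x)` = log(1 + eˣ)); whether this is the cause of the issues described above is not known, and the function names are hypothetical.

```python
import numpy as np


def softplus(x):
    # log(1 + exp(x)) computed without overflow
    return np.logaddexp(0.0, x)


def ns_disc_loss(real_logits, fake_logits):
    # -log(sigmoid(D(x))) - log(1 - sigmoid(D(G(z)))), written with softplus
    return np.mean(softplus(-real_logits)) + np.mean(softplus(fake_logits))


def ns_gen_loss(fake_logits):
    # -log(sigmoid(D(G(z)))), written with softplus
    return np.mean(softplus(-fake_logits))
```

The identities used are -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x). In PyTorch, `torch.nn.functional.softplus` plays the same role.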

#### Implementation Info
- Python 3.7.13
- PyTorch 1.11.0 with CUDA 10.2

We used a conda environment for a clean library setup and included an environment.yml file.

```python
# imports
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

from data import AnimeFacesDataset
from model import Discriminator, HistoGAN
from loss import compute_gradient_penalty, pl_reg, r1_reg, wgan_gp_disc_loss, wgan_gp_gen_loss
from utils import random_interpolate_hists, histogram_feature_v2

# for debugging: locate the forward operation that produces NaN gradients (slows training)
torch.autograd.set_detect_anomaly(True)
# for faster training: let cuDNN select the fastest convolution algorithms
torch.backends.cudnn.benchmark = True
```

```python
# parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
real_image_dir = "images"
transform = transforms.Compose(
    [transforms.Resize((256, 256)),
     transforms.RandomHorizontalFlip(0.5)])
# due to hardware limitations, similar to the paper's authors, we kept the batch size small
batch_size = 2
# the dataset contains 63,632 images; since we could not make the network generate
# meaningful images, we kept the number of epochs small and experimented
num_epochs = 2
# number of discriminator updates per generator update
g_update_iter = 5
# number of gradient accumulation steps before optimizing parameters (unused in this notebook)
acc_gradient_iter = 1
# scale of R1 regularization
r1_factor = 10
# variables for path length regularization
# see the StyleGAN2 paper, Appendix B (Implementation Details), path length regularization
ema_decay_coeff = 0.99
# running average of path lengths (float, since it is averaged with float values)
target_scale = torch.tensor([0.0], device=device)
plr_factor = np.log(2) / (256**2 * (np.log(256) - np.log(2)))
# number of iterations between saving network parameters and generated images
save_iter = 200
# directory to save generated images
fake_image_dir = "generated_images"
os.makedirs(fake_image_dir, exist_ok=True)
# number of residual blocks in the discriminator
num_res_blocks = 7
# network capacity: decides the intermediate channel sizes of the discriminator
# and the channel size of the generator's learnable constant
network_capacity = 16
# number of histogram bins per axis
bin_size = 64
# the generator has log2(image_res) - 1 blocks, i.e. 7 for 256x256 images; one channel size per block
generator_channel_sizes = [1024, 512, 512, 512, 256, 128, 64]
learning_rate = 2e-4
# coefficient of the gradient penalty
coeff_penalty = 10  # same as the StyleGAN2 paper
```
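The path length regularization variables above can be illustrated with a minimal NumPy sketch of the penalty and its moving average. The per-sample path lengths ‖J_wᵀy‖₂ are assumed to be computed elsewhere with autograd; `path_length_penalty` is a hypothetical helper, not the repository's `pl_reg`.

```python
import numpy as np


def path_length_penalty(path_lengths, running_mean, decay=0.99):
    """path_lengths: per-sample ||J_w^T y||_2 values for the current batch."""
    # Exponential moving average of path lengths, used as the target length a
    new_mean = decay * running_mean + (1.0 - decay) * np.mean(path_lengths)
    # Penalize each sample's deviation from the target length
    penalty = np.mean((path_lengths - new_mean) ** 2)
    return penalty, new_mean


# StyleGAN2's weight gamma_pl = ln 2 / (r^2 (ln r - ln 2)) for r x r images
gamma_pl = np.log(2) / (256 ** 2 * (np.log(256) - np.log(2)))
```

The updated moving average has to be carried over to the next call, so the caller should store the returned `new_mean`. Since ln 256 = 8 ln 2, the weight for 256x256 images simplifies to 1 / (7 · 256²) ≈ 2.18e-6, which matches `plr_factor`.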

```python
# Initialize dataset, dataloader, discriminator and generator
dataset = AnimeFacesDataset(real_image_dir, transform, device)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
generator = HistoGAN(network_capacity, bin_size, generator_channel_sizes)
discriminator = Discriminator(num_res_blocks, network_capacity)

# If pretrained networks exist, load their parameters to continue training
# (map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine)
if os.path.exists("generator.pt"):
    generator.load_state_dict(torch.load("generator.pt", map_location=device))
if os.path.exists("discriminator.pt"):
    discriminator.load_state_dict(torch.load("discriminator.pt", map_location=device))

discriminator = discriminator.to(device)
generator = generator.to(device)

# Initialize optimizers
gene_optim = torch.optim.Adam(generator.parameters(), lr=learning_rate)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
```

```python
# Training loop without gradient accumulation
# Gradient accumulation is implemented and tried in train2.py, but it had performance
# and memory consumption issues, so it is not included here
for epoch in range(num_epochs):
    for step, batch_data in enumerate(dataloader):
        training_percent = 100 * step * batch_data.size(0) / len(dataset)
        batch_data = batch_data.to(device)
        # Sample random Gaussian noise
        z = torch.randn(batch_data.size(0), 512, device=device)
        # Interpolate between target image histograms
        # to prevent overfitting to the dataset images
        target_hist = random_interpolate_hists(batch_data)
        # Generate fake images
        fake_data, w = generator(z, target_hist)

        # Detach fake data so that no gradients flow
        # to the generator while training only the discriminator
        fake_data = fake_data.detach()

        # Critic scores of fake and real images
        fake_scores = discriminator(fake_data)
        real_scores = discriminator(batch_data)
        gradient_penalty = compute_gradient_penalty(fake_data, batch_data, discriminator)
        d_loss = wgan_gp_disc_loss(real_scores, fake_scores, gradient_penalty, coeff_penalty)
        # alternative non-saturating loss: d_loss = disc_loss(fake_scores, real_scores)
        # the StyleGAN2 paper argues that regularizing only every 16 iterations does not hurt performance
        if (step + 1) % 16 == 0:
            # R1 regularization
            d_loss = d_loss + r1_reg(batch_data, discriminator, r1_factor)

        print("%", training_percent, " Disc loss:", d_loss.item())
        # Clear gradients left over from the generator step, which also backpropagates through the discriminator
        disc_optim.zero_grad()
        d_loss.backward()
        disc_optim.step()

        if (step + 1) % g_update_iter == 0:
            z = torch.randn(batch_data.size(0), 512, device=device)
            fake_data, w = generator(z, target_hist)

            disc_score = discriminator(fake_data)
            # only the adversarial term is used here; the histogram (Hellinger) loss is not added
            g_loss = wgan_gp_gen_loss(disc_score)
            if (step + 1) % (8 * g_update_iter) == 0:
                # pl_reg's second return value is stored in ema_decay_coeff; if it is the
                # updated path-length average, it belongs in target_scale instead
                plr, ema_decay_coeff = pl_reg(fake_data, w, target_scale, plr_factor, ema_decay_coeff)
                g_loss = g_loss + plr

            print("%", training_percent, "Gen loss:", g_loss.item())
            gene_optim.zero_grad()
            g_loss.backward()
            gene_optim.step()

        if (step + 1) % save_iter == 0:
            for i in range(fake_data.size(0)):
                save_image(fake_data[i].detach(),
                           os.path.join(fake_image_dir, "fake_{}_{}_{}.png".format(epoch, step, i)))
            torch.save(generator.state_dict(), "generator.pt")
            torch.save(discriminator.state_dict(), "discriminator.pt")
```

```
% 0.0  Disc loss: 947.0718994140625
% 0.003146385589554  Disc loss: -300.9945373535156
% 0.006292771179108  Disc loss: -2331.33837890625
% 0.009439156768662  Disc loss: -9841.419921875
% 0.012585542358216  Disc loss: -93025.453125
% 0.012585542358216 Gen loss: -108935.6875
% 0.015731927947769998  Disc loss: -127629.859375
% 0.018878313537324  Disc loss: -575892.1875
% 0.022024699126878  Disc loss: -2343000.25
% 0.025171084716432  Disc loss: -9243621.0
% 0.028317470305985998  Disc loss: -22108960.0
% 0.028317470305985998 Gen loss: -39207744.0
% 0.031463855895539995  Disc loss: 38732028.0
% 0.034610241485094  Disc loss: -11762508.0
% 0.037756627074648  Disc loss: -19876096.0
% 0.040903012664201994  Disc loss: -10089273.0
% 0.044049398253756  Disc loss: -19209162.0
% 0.044049398253756 Gen loss: -28446576.0
% 0.047195783843309996  Disc loss: 7060448.5
% 0.050342169432864  Disc loss: -10064836.0
% 0.053488555022418  Disc loss: -1532740.0
% 0.056634940611971996  Disc loss: -8062132.0
% 0.059

KeyboardInterrupt: 
```