# Mars StyleGAN

The [Basic Gan](https://github.com/kheyer/MarsGAN/blob/master/Basic%20Gan.ipynb) notebook showed that a very simple generator/discriminator setup can generate fake Mars landscapes that are poor quality but still recognizable. The basic GAN suffered from several flaws. The generated images were blurry, pixelated and full of artifacts. They were also small, at 64x64 pixels, and training a model of the same architecture to create larger images proved unstable. In this notebook, we use a more sophisticated generator to create higher quality images at larger sizes: the well-known [StyleGAN](https://arxiv.org/pdf/1812.04948.pdf).

StyleGAN is noteworthy because it was the first GAN model to disentangle the latent factors used to generate an image. Most GAN models start with a random vector that is used to generate an image. StyleGAN does things differently. A random latent vector is first sent through a mapping network consisting of several fully connected layers. This allows the generator to learn an intermediate latent space that is best for generating images. The intermediate latent vector is then injected into the generator at different layers.

![](media/stylegan.png)
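
As a rough sketch of the mapping step described above, the snippet below builds a stack of fully connected layers that turns a random latent vector into an intermediate latent vector. The function name `mapping_network`, the latent size of 512 and the depth of 8 layers are assumptions for illustration (8 is the depth used in the StyleGAN paper), and plain `nn.Linear` layers stand in for whatever layers the full model uses.

```python
import torch
import torch.nn as nn

def mapping_network(latent_dim=512, n_layers=8):
    # Hypothetical helper: a stack of fully connected layers with LeakyReLU
    layers = []
    for _ in range(n_layers):
        layers.append(nn.Linear(latent_dim, latent_dim))
        layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)

mapper = mapping_network()
z = torch.randn(4, 512)   # batch of random latent vectors
w = mapper(z)             # intermediate latent vectors, same shape as z
```

The generator then consumes `w` rather than `z`, which is what lets the network learn a latent space better suited to image generation.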

The key finding of StyleGAN was that adding the intermediate latent vector at different layers of the generator caused different effects on the output image. Injection at early layers of the generator influenced major details of the image, while injection at later layers influenced finer details. The result is that the StyleGAN generator can be used to blend different latent vectors.

![](media/stylegan_mixing.jpg)
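
One way to picture the blending is as a per-layer choice between two intermediate latent vectors. The sketch below is only an illustration: `n_layers` and `crossover` are hypothetical names and values, and the generator that would consume the `styles` list is not shown.

```python
import torch

n_layers = 6     # hypothetical number of points where a style is injected
crossover = 2    # layers before this index use w1, the rest use w2

w1 = torch.randn(1, 512)   # intermediate latent for the first image
w2 = torch.randn(1, 512)   # intermediate latent for the second image

# Early layers (major details) take w1, later layers (finer details) take w2
styles = [w1 if i < crossover else w2 for i in range(n_layers)]
```

Moving `crossover` toward the last layer lets `w1` control more of the image, leaving only the finest details to `w2`.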

StyleGAN can be trained stably on images as large as 1024x1024 by using [progressive growing](https://arxiv.org/pdf/1710.10196.pdf). We start training on 4x4 images. After some time, we grow to 8x8, then 16x16 and so on. Growing the model introduces a small issue: each new stage adds a new layer that converts a tensor of activations to a 3-channel RGB image. Since this new layer is untrained, it likely produces poor quality images. To prevent this from disrupting training, the new RGB layer is phased in gradually.

![](media/rgb.png)
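
The phase-in can be sketched as a linear blend between the upsampled output of the old RGB layer and the output of the new one. This is a minimal illustration rather than the code used in this notebook; the function name `fade_in` and the blend weight `alpha` are assumptions for the example.

```python
import torch
import torch.nn.functional as F

def fade_in(old_rgb, new_rgb, alpha):
    """Blend the previous resolution's RGB output with the new layer's output.

    old_rgb: image from the already-trained RGB layer, at half the resolution
    new_rgb: image from the newly added, untrained RGB layer
    alpha:   blend weight, ramped from 0 to 1 over the phase-in period
    """
    # Upsample the old image so both tensors have the same spatial size
    upsampled = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    return (1 - alpha) * upsampled + alpha * new_rgb

old_rgb = torch.zeros(1, 3, 4, 4)   # stand-in for the 4x4 output
new_rgb = torch.ones(1, 3, 8, 8)    # stand-in for the new 8x8 output

start = fade_in(old_rgb, new_rgb, alpha=0.0)   # only the old layer contributes
half = fade_in(old_rgb, new_rgb, alpha=0.5)    # equal mix of both
end = fade_in(old_rgb, new_rgb, alpha=1.0)     # only the new layer contributes
```

At `alpha=0` the untrained layer has no influence on the image, so the discriminator never sees its initial poor output at full strength.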

So that's a quick overview. Let's start looking at the model.

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai import *
from fastai.vision import *
from fastai.vision.gan import *
from fastai.callbacks import *

In [3]:
path = Path('G:/Mars/')

In [4]:
def get_data(bs, size, path, num_workers, noise_size=512):
    return (GANItemList.from_folder(path, noise_sz=noise_size)
               .split_none()
               .label_from_func(noop)
               .transform(tfms=[[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], []], size=size, tfm_y=True)
               .databunch(bs=bs, num_workers=num_workers)
               .normalize(stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])], do_x=False, do_y=True))

In [5]:
# I resized the main dataset into individual sets for each image size
# This has a noticeable effect on training speed
data_4 = get_data(256, 4, path/'patches_4', 8)
data_8 = get_data(256, 8, path/'patches_8', 8)
data_16 = get_data(128, 16, path/'patches_16', 8)
data_32 = get_data(128, 32, path/'patches_32', 8)
data_64 = get_data(48, 64, path/'patches_64', 8)
data_128 = get_data(22, 128, path/'patches_128', 8)
data_256 = get_data(10, 256, path/'patches_256', 8)
data_512 = get_data(4, 512, path/'image_patches', 8)

## StyleGAN Fundamentals

First we start with the basic building blocks of the StyleGAN. StyleGAN uses Equalized Learning Rates for all layers. This is a technique from the Progressive GAN paper.

Layers are typically initialized from a carefully chosen distribution so that the initial output of the untrained layer has approximately mean zero and standard deviation one. This avoids having activations or gradients vanish or explode early in training. Kaiming initialization scales all values by a factor of $c = \sqrt{\frac{2}{n}}$, where $n$ is the fan-in of the layer: the number of input values that feed each output unit.

The equalized learning rate technique applies the Kaiming constant dynamically at runtime, rather than just at initialization. The reason for doing this is, to quote the Progressive GAN paper, "somewhat subtle". Many optimization algorithms normalize gradients in such a way that the scale of the gradient update is independent of the scale of the parameter being updated. This leads to some small parameters receiving updates that are too large, and large parameters receiving updates that are too small. The effect is that the learning rate is too high for some parameters and too low for others, leading to instability. With equalized learning rates, the weights are stored unscaled and the Kaiming scaling is applied on every forward pass, so all the stored weights in the model have a similar range.
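
The difference between the two approaches can be seen in a small standalone example, separate from the hook-based implementation that follows. The layer sizes (512 inputs, 256 outputs) and the helper name `effective_weight` are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)

fan_in = 512
c = math.sqrt(2 / fan_in)   # Kaiming constant for this layer

# Standard approach: bake the constant into the stored weights at initialization
w_init = torch.randn(256, fan_in) * c

# Equalized learning rate: store N(0, 1) weights and apply the constant at runtime
w_raw = torch.randn(256, fan_in)

def effective_weight(w):
    # Scaling applied on every forward pass
    return w * c

print(w_init.std().item())                    # close to c
print(w_raw.std().item())                     # close to 1
print(effective_weight(w_raw).std().item())   # close to c
```

Both approaches give the layer the same effective weights at the start of training. What differs is the tensor the optimizer sees: with runtime scaling, every layer stores weights with a standard deviation near 1, regardless of its fan-in.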

In [None]:
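class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        # The unscaled parameter is stored under '<name>_orig'
        weight = getattr(module, self.name + '_orig')
        # fan_in = input channels * kernel elements (kernel elements = 1 for linear layers)
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        # Apply the Kaiming constant sqrt(2 / fan_in)
        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        # Replace the original parameter with an unscaled copy named '<name>_orig'
        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        # Recompute the scaled weight before every forward pass
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        # Forward pre-hook: set the scaled weight as a plain attribute on the module
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)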


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)

    return module

In [None]:
class EqualConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        linear = nn.Linear(in_dim, out_dim)
        linear.weight.data.normal_()
        linear.bias.data.zero_()

        self.linear = equal_lr(linear)

    def forward(self, input):
        return self.linear(input)

In [None]:
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        padding,
        kernel_size2=None,
        padding2=None,
        pixel_norm=True,      # accepted but not used in this block
        spectral_norm=False,  # accepted but not used in this block
    ):
        super().__init__()

        pad1 = padding
        pad2 = padding
        if padding2 is not None:
            pad2 = padding2

        kernel1 = kernel_size
        kernel2 = kernel_size
        if kernel_size2 is not None:
            kernel2 = kernel_size2

        self.conv = nn.Sequential(
            EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
            nn.LeakyReLU(0.2),
            EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
            nn.LeakyReLU(0.2),
        )

    def forward(self, input):
        out = self.conv(input)

        return out

In [None]:
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
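
A quick check of what this normalization does, written as a standalone function so the snippet runs on its own (the name `pixel_norm` and the input shape are only for this example): after normalization, the mean squared activation across channels at every pixel is approximately 1.

```python
import torch

def pixel_norm(x, eps=1e-8):
    # Same computation as PixelNorm.forward: normalize each pixel's feature vector
    return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(2, 16, 4, 4) * 5.0   # activations with a large scale
out = pixel_norm(x)

# Mean of squares over the channel dimension, one value per pixel
mean_sq = out.pow(2).mean(dim=1)
```

Every entry of `mean_sq` is close to 1 no matter how large the input activations were.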

In [None]:
class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.norm = nn.InstanceNorm2d(in_channel)
        self.style = EqualLinear(style_dim, in_channel * 2)

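        # First half of the output is the scale (gamma), second half is the shift (beta).
        # Initialize the biases so the layer starts near gamma = 1, beta = 0.
        self.style.linear.bias.data[:in_channel] = 1
        self.style.linear.bias.data[in_channel:] = 0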

    def forward(self, input, style):
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)

        out = self.norm(input)
        out = gamma * out + beta

        return out
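
The effect of the scale and shift can be checked with plain tensors. In the class above, `gamma` and `beta` come from a linear layer applied to the style vector; in this sketch they are fixed values chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16) * 3.0 + 1.0   # feature maps with arbitrary statistics

# Illustrative stand-ins for the style-derived scale and shift
gamma = torch.full((2, 8, 1, 1), 2.0)
beta = torch.full((2, 8, 1, 1), 0.5)

normed = F.instance_norm(x)   # zero mean, unit variance per sample and channel
out = gamma * normed + beta

# Per-sample, per-channel statistics of the output
mean = out.mean(dim=(2, 3))
std = ((out - out.mean(dim=(2, 3), keepdim=True)) ** 2).mean(dim=(2, 3)).sqrt()
```

After the operation, each channel's mean is approximately `beta` and its standard deviation is approximately `gamma`, so the style fully determines these statistics regardless of what the input looked like.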

In [None]:
class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out
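
The broadcasting in `NoiseInjection` can be illustrated with plain tensors. The noise shape used here, one single-channel map per sample, is an assumption for this sketch, as is the value 0.1 standing in for a learned weight.

```python
import torch

torch.manual_seed(0)
batch, channel, size = 4, 8, 16
image = torch.randn(batch, channel, size, size)
noise = torch.randn(batch, 1, size, size)      # assumed shape: one noise map per sample

weight = torch.zeros(1, channel, 1, 1)         # same initialization as NoiseInjection
out_start = image + weight * noise             # zero weight: noise has no effect yet

weight = torch.full((1, channel, 1, 1), 0.1)   # stand-in for a learned value
out_later = image + weight * noise             # each channel receives scaled noise
```

Because the weight starts at zero, the layer initially passes the image through unchanged, and the network learns per channel how much noise to add.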