# Introduction to Variational Autoencoder and Generative Adversarial Networks in the context of Probabilistic Programming 

In this notebook we constrain ourselves to code sketches of the core ideas behind the integration of generative adversarial networks (GANs) and variational autoencoders (VAEs) with probabilistic programming frameworks.

We will keep using [Turing.jl](https://github.com/TuringLang/Turing.jl), which is now combined with the machine learning stack of [Flux.jl](https://github.com/fluxml/flux.jl) to provide support for layer abstractions. This notebook basically extends upon the initial ideas from the `BayesianNeuralNetworks` notebook.

In [None]:
using Turing, DifferentialEquations, DiffEqFlux

In [None]:
using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy
using Images
using MLDatasets
using Statistics
using Parameters: @with_kw
using Printf
using Random

We will use the MNIST dataset for both the GAN and the VAE examples. We start by creating a structure to hold the hyper-parameters.

In [None]:
# Hyper-parameters shared by the GAN helper and training functions in this notebook.
# `@with_kw` generates a keyword constructor, so any field can be overridden
# individually (e.g. `HyperParams(epochs = 5)`) while the others keep their defaults.
@with_kw struct HyperParams
    # Batch size used to partition the training data and to size the noise batches
    batch_size::Int = 128
    # Number of rows of the noise matrix fed to the generator
    latent_dim::Int = 100
    # Number of passes over the batched data
    epochs::Int = 20
    # Log the losses and save a sample image every `verbose_freq` training steps
    verbose_freq::Int = 1000
    # `output_x * output_y` fixed noise vectors are drawn for the sample images
    output_x::Int = 6
    # Group size used when the generated images are tiled in `create_output_image`
    output_y::Int = 6
    # Learning rate handed to the discriminator's ADAM optimizer
    lr_dscr::Float64 = 0.0002
    # Learning rate handed to the generator's ADAM optimizer
    lr_gen::Float64 = 0.0002
end

Then we define for ourselves a function to create output images, i.e. act in a generative fashion

In [None]:
"""
    create_output_image(gen, fixed_noise, hparams)

Run the generator `gen` on every element of `fixed_noise` and tile the generated
samples into a single grayscale image.

# Arguments
- `gen`: the generator network; it is called once per element of `fixed_noise`.
- `fixed_noise`: collection of noise inputs, one per tile of the output image.
- `hparams`: hyper-parameters; only `hparams.output_y` (tiles per group) is read.

# Returns
A 2-D array of `Gray` pixels. Generator outputs are shifted by `1` and halved
before display.
"""
function create_output_image(gen, fixed_noise, hparams)
    # Switch the generator to inference behaviour through Flux's public API rather
    # than redefining `Flux.istraining` globally with `@eval` inside a function.
    Flux.testmode!(gen)
    fake_images = try
        @. cpu(gen(fixed_noise))
    finally
        # Always hand the generator back in training mode, even if generation fails
        Flux.trainmode!(gen)
    end
    # Concatenate each group of `output_y` samples horizontally, stack the groups
    # vertically, then drop the singleton channel and batch dimensions
    tile_rows = reduce.(hcat, partition(fake_images, hparams.output_y))
    tiled = dropdims(reduce(vcat, tile_rows); dims=(3, 4))
    image_array = permutedims(tiled, (2, 1))
    image_array = @. Gray(image_array + 1f0) / 2f0
    return image_array
end

## 1. Sketch of GAN Inclusion in Probabilistic Programming Systems

We will now begin by defining for ourselves the normal GAN structure with a Discriminator and Generator. Beware that the functional syntax of Flux is quite similar to JAX.

In [None]:
"""
    Discriminator()

Build the convolutional discriminator network.

Two stride-2 convolutions (each followed by a leaky ReLU with slope `0.2` and
dropout) are followed by a flattening step to `7 * 7 * 128` features and a final
`Dense` layer with a single output. That last layer has no activation, so the
network returns raw logits.
"""
function Discriminator()
    return Chain(
        Conv((4, 4), 1 => 64; stride = 2, pad = 1),
        x -> leakyrelu.(x, 0.2f0),
        Dropout(0.25),
        Conv((4, 4), 64 => 128; stride = 2, pad = 1),
        x -> leakyrelu.(x, 0.2f0),
        # `Dropout` expects a probability; the previous `Dropout(σ)` passed the
        # sigmoid function instead of a number
        Dropout(0.25),
        x -> reshape(x, 7 * 7 * 128, :),
        Dense(7 * 7 * 128, 1),
    )
end

In [None]:
"""
    Generator(latent_dim = HyperParams().latent_dim)

Build the generator network mapping a `latent_dim`-row noise matrix to images.

A `Dense` layer expands the noise to `7 * 7 * 256` features, which are reshaped to
a `7 × 7 × 256` feature map and passed through three transposed convolutions. The
last layer applies `tanh`.

# Arguments
- `latent_dim`: number of rows of the noise input. It defaults to the default
  `latent_dim` of `HyperParams`, so the zero-argument call `Generator()` no longer
  depends on a global variable named `hparams` being defined.
"""
function Generator(latent_dim::Integer = HyperParams().latent_dim)
    return Chain(
        Dense(latent_dim, 7 * 7 * 256),
        BatchNorm(7 * 7 * 256, relu),
        x -> reshape(x, 7, 7, 256, :),
        ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2),
        BatchNorm(128, relu),
        ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1),
        BatchNorm(64, relu),
        ConvTranspose((4, 4), 64 => 1, tanh; stride = 2, pad = 1),
    )
end

Now defining our loss functions

In [None]:
"""
    discriminator_loss(real_output, fake_output)

Return the discriminator loss: the mean logit binary cross-entropy of
`real_output` against the target `1` plus that of `fake_output` against the
target `0`. Both arguments are treated as logits.
"""
function discriminator_loss(real_output, fake_output)
    # Mean element-wise cross-entropy of a batch of logits against a constant target
    mean_bce(logits, target) = mean(logitbinarycrossentropy.(logits, target))
    return mean_bce(real_output, 1f0) + mean_bce(fake_output, 0f0)
end

"""
    generator_loss(fake_output)

Return the generator loss: the mean logit binary cross-entropy of the
discriminator logits `fake_output` against the target `1`.
"""
function generator_loss(fake_output)
    per_sample = logitbinarycrossentropy.(fake_output, 1f0)
    return mean(per_sample)
end

And the training functions for the respective networks:

In [None]:
"""
    train_discriminator!(gen, dscr, x, opt_dscr, hparams)

Perform one optimisation step on the discriminator `dscr` and return the loss.

A batch of noise of size `(hparams.latent_dim, hparams.batch_size)` is allocated
with `similar(x, …)` and pushed through `gen` to obtain fake samples. Only the
parameters of `dscr` are differentiated and updated with `opt_dscr`.
"""
function train_discriminator!(gen, dscr, x, opt_dscr, hparams)
    latent = similar(x, (hparams.latent_dim, hparams.batch_size))
    randn!(latent)
    # Generated outside the pullback: the generator is not updated in this step
    fake_batch = gen(latent)
    dscr_params = Flux.params(dscr)
    loss, back = Flux.pullback(
        () -> discriminator_loss(dscr(x), dscr(fake_batch)),
        dscr_params,
    )
    update!(opt_dscr, dscr_params, back(1f0))
    return loss
end

In [None]:
"""
    train_generator!(gen, dscr, x, opt_gen, hparams)

Perform one optimisation step on the generator `gen` and return the loss.

`x` is only used as the template for allocating the noise batch of size
`(hparams.latent_dim, hparams.batch_size)`. The loss is evaluated through `dscr`,
but only the parameters of `gen` are differentiated and updated with `opt_gen`.
"""
function train_generator!(gen, dscr, x, opt_gen, hparams)
    latent = similar(x, (hparams.latent_dim, hparams.batch_size))
    randn!(latent)
    gen_params = Flux.params(gen)
    loss, back = Flux.pullback(() -> generator_loss(dscr(gen(latent))), gen_params)
    update!(opt_gen, gen_params, back(1f0))
    return loss
end

To train this GAN we would now usually define our training function, which loops over the epochs etc. as you would expect from your machine learning framework. We provide this function here, but abstain from rewriting it for the individual subcases.

In [None]:
"""
    train(; kws...)

Train the GAN on the MNIST training images.

All keyword arguments are forwarded to `HyperParams`. Every
`hparams.verbose_freq` training steps the current losses are logged and an image
tiled from a fixed set of noise vectors is written to the `output` directory; a
final image is written once training has finished.
"""
function train(; kws...)
    # Model Parameters
    hparams = HyperParams(; kws...)

    # Load MNIST dataset
    images, _ = MLDatasets.MNIST.traindata(Float32)
    # Normalize to [-1, 1]
    image_tensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
    # Partition into batches; the number of images is taken from the data itself
    # instead of being hard-coded
    n_images = size(image_tensor, 4)
    data = [image_tensor[:, :, :, r] |> gpu for r in partition(1:n_images, hparams.batch_size)]

    # Float32 noise so that the sample images use the same element type as the
    # Float32 training batches
    n_samples = hparams.output_x * hparams.output_y
    fixed_noise = [randn(Float32, hparams.latent_dim, 1) |> gpu for _ in 1:n_samples]

    # Discriminator
    dscr = Discriminator() |> gpu

    # Generator
    # NOTE(review): the generator is built without `hparams.latent_dim`; make sure
    # its input size agrees with the noise dimension used here.
    gen = Generator() |> gpu

    # Optimizers
    opt_dscr = ADAM(hparams.lr_dscr)
    opt_gen = ADAM(hparams.lr_gen)

    # `save` does not create missing directories, so make sure the target exists
    output_dir = "output"
    mkpath(output_dir)

    # Training
    train_steps = 0
    for ep in 1:hparams.epochs
        @info "Epoch $ep"
        for x in data
            # Update discriminator and generator
            loss_dscr = train_discriminator!(gen, dscr, x, opt_dscr, hparams)
            loss_gen = train_generator!(gen, dscr, x, opt_gen, hparams)

            if train_steps % hparams.verbose_freq == 0
                @info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")
                # Save generated fake image
                output_image = create_output_image(gen, fixed_noise, hparams)
                save(joinpath(output_dir, @sprintf("dcgan_steps_%06d.png", train_steps)), output_image)
            end
            train_steps += 1
        end
    end

    output_image = create_output_image(gen, fixed_noise, hparams)
    return save(joinpath(output_dir, @sprintf("dcgan_steps_%06d.png", train_steps)), output_image)
end

### 1.1 Construct a Bayesian GAN

The most straightforward idea here would be to build on our ideas from the earlier notebook and turn our GAN into a Bayesian GAN, as shown in the literature by [Saatchi and Wilson](https://arxiv.org/abs/1705.09558). Akin to the probabilistic model for the Bayesian neural network we want to abstract over our architecture, which requires us to custom-code an `unpack` function (a program writing the unpacking code would be a lot more efficient here).

In [None]:
# Sketch of the discriminator forward pass for the Bayesian GAN: rebuild the
# network from the flat parameter vector `nn_params` and apply it to `xs`.
#
# This cell is not runnable as written: `...` and `..` are placeholders for the
# unpacked values and layer sizes, and `unpack` is not defined in this notebook.
function Discriminator_forward(xs, nn_params::AbstractVector)
    # Split the flat parameter vector into the individual layer parameters
    ... = unpack(nn_params)
    # Same layer sequence as `Discriminator()`, with placeholder sizes.
    # NOTE(review): `σ` is presumably meant to be a dropout probability coming out
    # of `unpack`; confirm, since Flux also exports `σ` as the sigmoid function.
    nn = Chain(Conv((...), 1 => 64; stride = 2, pad = 1),
               x->leakyrelu.(x, 0.2f0),
               Dropout(σ),
               Conv((...), 64 => 128; stride = 2, pad = 1),
               x->leakyrelu.(x, 0.2f0),
               Dropout(σ), 
               x->reshape(x, ..., :),
               Dense(.., 1))
    # The final `Dense` layer has no activation, so the outputs are unbounded
    return nn(xs)
end

In [None]:
# Sketch of the generator forward pass for the Bayesian GAN: rebuild the network
# from the flat parameter vector `nn_params` and apply it to `xs`.
#
# This cell is not runnable as written: `...` is a placeholder for the unpacked
# values and `unpack` is not defined in this notebook.
function Generator_forward(xs, nn_params::AbstractVector)
    # Split the flat parameter vector into the individual layer parameters
    ... = unpack(nn_params)
    # Same layer sequence as `Generator()`.
    # NOTE(review): `hparams` is not an argument of this function, and as written
    # the layers below are constructed without using the unpacked parameters -
    # confirm how the sampled weights are meant to be injected.
    nn = Chain(Dense(hparams.latent_dim, 7 * 7 * 256),
               BatchNorm(7 * 7 * 256, relu),
               x->reshape(x, 7, 7, 256, :),
               ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2),
               BatchNorm(128, relu),
               ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1),
               BatchNorm(64, relu),
               ConvTranspose((4, 4), 64 => 1, tanh; stride = 2, pad = 1),
               )
    # The last layer applies `tanh`, so the outputs lie in [-1, 1]
    return nn(xs)
end

Now that we have constructed our forward functions we need to specify our probabilistic models. We follow the syntax of the `BayesianNeuralNetworks` notebook here.

In [None]:
# Regularization strength of the Gaussian prior on the network parameters
alpha = 0.09
# Scale of the Gaussian prior derived from it: sqrt(1 / alpha)
sig = sqrt(inv(alpha))

# Probabilistic model for the discriminator.
#
# `xs` are the network inputs and `ts` the observed binary labels. The network
# parameters get an isotropic Gaussian prior with scale `sig`; the length `20` is
# the placeholder size carried over from the sketch.
@model bayes_discriminator(xs, ts) = begin
    # Create the weight and bias vector
    nn_params ~ MvNormal(zeros(20), sig .* ones(20))

    # Calculate predictions for the inputs given the weights and biases in theta
    preds = Discriminator_forward(xs, nn_params)

    # Observe each prediction. The discriminator ends in a `Dense` layer without an
    # activation, i.e. it returns logits, whereas `Bernoulli` needs a probability
    # in [0, 1] - hence the `sigmoid`.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(sigmoid(preds[i]))
    end
end;

# Probabilistic model for the generator, mirroring the discriminator model:
# Gaussian prior over the flat parameter vector, forward pass, then one Bernoulli
# observation per entry of `ts`.
@model bayes_generator(xs, ts) = begin
    # Prior over the weight and bias vector
    nn_params ~ MvNormal(zeros(20), sig .* ones(20))

    # Forward pass of the generator under the sampled parameters
    generated = Generator_forward(xs, nn_params)

    # Score every observation against the corresponding network output
    # NOTE(review): the generator ends in `tanh`, so its outputs lie in [-1, 1],
    # while `Bernoulli` expects a probability - confirm the intended likelihood.
    for i in 1:length(ts)
        ts[i] ~ Bernoulli(generated[i])
    end
end;

For inference we would now have to construct a third `@model`, i.e. create a higher-order structure which inherits from the two probabilistic models above, which makes for a very involved inference problem - also expensive - but it should not be outright intractable. A big issue here would be the instability of the training procedure though, which could lead our samplers to fail.

### 1.2 Train a Generative Model Emulator

Akin to inference compilation we can also train a GAN to fulfill that role. This would not be entirely automated, but once the generator is trained on the model it should be highly useful as a surrogate in inference routines, model-based optimization, model-based reinforcement learning, etc. Expanding upon the language from yesterday's lecture, we essentially seek to condition the GAN on the observations of our actual model. For this we only need to modify the training procedure and the helper function pipeline.

In [None]:
# Number of samples to draw from the probabilistic model; these samples form the
# training set of the GAN emulator below
N = 10000

# Train the GAN as an emulator of a probabilistic model (sketch).
#
# Instead of loading MNIST, the training data are `N` samples drawn from
# `example_model()` with HMC. All keyword arguments are forwarded to
# `HyperParams`. `example_model` and `create_output_sample` are not defined in
# this notebook and have to be supplied by the surrounding boilerplate.
function train(; kws...)
    # Model Parameters
    hparams = HyperParams(; kws...)

    # Load probabilistic model data
    trainings_samples = sample(example_model(), HMC(0.05, 4), N)
    # Partition trainings data into batches. The index range is `1:N`; the previous
    # `partition(110000, …)` iterated over a single integer instead of a range.
    # NOTE(review): the four-dimensional indexing is carried over from the image
    # pipeline - confirm it matches the layout of the drawn samples.
    data = [trainings_samples[:, :, :, r] |> gpu for r in partition(1:N, hparams.batch_size)]

    # Float32 noise, matching the element type used for the noise constants in the
    # loss functions
    n_samples = hparams.output_x * hparams.output_y
    fixed_noise = [randn(Float32, hparams.latent_dim, 1) |> gpu for _ in 1:n_samples]

    # Discriminator
    dscr = Discriminator() |> gpu

    # Generator
    gen = Generator() |> gpu

    # Optimizers
    opt_dscr = ADAM(hparams.lr_dscr)
    opt_gen = ADAM(hparams.lr_gen)

    # `save` does not create missing directories, so make sure the target exists
    output_dir = "output"
    mkpath(output_dir)

    # Training
    train_steps = 0
    for ep in 1:hparams.epochs
        @info "Epoch $ep"
        for x in data
            # Update discriminator and generator
            loss_dscr = train_discriminator!(gen, dscr, x, opt_dscr, hparams)
            loss_gen = train_generator!(gen, dscr, x, opt_gen, hparams)

            if train_steps % hparams.verbose_freq == 0
                @info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")
                # Save generated samples
                generated_sample = create_output_sample(gen, fixed_noise, hparams)
                save(joinpath(output_dir, @sprintf("dcgan_steps_%06d.png", train_steps)), generated_sample)
            end
            train_steps += 1
        end
    end

    generated_sample = create_output_sample(gen, fixed_noise, hparams)
    return save(joinpath(output_dir, @sprintf("dcgan_steps_%06d.png", train_steps)), generated_sample)
end

Quite a bit of the surrounding boilerplate code is missing here. But after training we would then be able to use the generator as a surrogate model in our pipeline, e.g. for HMC to sample from in order to save computational costs. We just need to stay cautious that the sampler does not try to sample from regions where the surrogate has no "support" and would hence have to generalize - a feat that would be quite remarkable for current machine learning models.

### 1.3 Abstract the GAN as a Probabilistic Program

Viewing the GAN as its own probabilistic program incurs the highest syntactical overhead as, recalling our introduction to higher-order probabilistic programming languages and Turing yesterday, we need to custom-define our neural network layers to act in a distributional sense. This is akin to the Bayesian construction but given a less-constrained stochastic control flow faces a lot fewer programmatic constraints and can even be viewed in a similar light as a neural architecture search.

To begin with such a task we have to define the custom types for our architecture. It makes the most sense to decompose the GAN into different blocks - using the generator as the example here - for which we then define our custom distribution structures.

In [None]:
# Generator network, repeated here as the example to be decomposed into blocks.
#
# `latent_dim` is the number of rows of the noise input. It defaults to the default
# `latent_dim` of `HyperParams`, so the zero-argument call `Generator()` no longer
# depends on a global variable named `hparams` being defined.
function Generator(latent_dim::Integer = HyperParams().latent_dim)
    return Chain(
        # Block 1: expand the noise to 7 * 7 * 256 features
        Dense(latent_dim, 7 * 7 * 256),
        BatchNorm(7 * 7 * 256, relu),
        # Block 2: reshape to a 7 × 7 × 256 feature map and convolve
        x -> reshape(x, 7, 7, 256, :),
        ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2),
        BatchNorm(128, relu),
        # Remaining transposed convolutions; the last one applies `tanh`
        ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1),
        BatchNorm(64, relu),
        ConvTranspose((4, 4), 64 => 1, tanh; stride = 2, pad = 1),
    )
end

A classical split-up would in this case be `ConvTranspose` together with the respective `BatchNorm`, with the two other custom blocks being `Dense` with `BatchNorm`, and `reshape` with `ConvTranspose`. Following the customized distribution definition for Turing we then first have to establish the structures, taking one as an example here:

In [None]:
# Custom distribution type standing in for one block of network layers.
#
# Julia only allows subtyping abstract types, and `Multinomial` is a concrete type,
# so the struct derives from an abstract distribution type instead. The univariate
# supertype matches the `logpdf(d::custom_nn_layer, x::Real)` method sketched for
# this type.
struct custom_nn_layer <: ContinuousUnivariateDistribution
end

Next we have to define the sampling and evaluation of the log-pdf of our custom layers

In [None]:
# Sampling and log-density methods required for a custom distribution. The
# right-hand sides are placeholders (`...`) to be filled in per layer block.
# Draw a sample from the layer distribution using the supplied random number generator
Distributions.rand(rng::AbstractRNG, d::custom_nn_layer) = ...
# Evaluate the log-density of the layer distribution at `x`
Distributions.logpdf(d::custom_nn_layer, x::Real) = ...

Given the computational complexity of neural network layers we would, in addition to the bijection definition, also have to vectorize the operators, as inference would otherwise be intractable. We could then reassemble our generator structure as its own model with the unshackled stochastic control flow (there would be a few more difficulties to such use :) )

In [None]:
# Sketch: the generator reassembled from custom distribution-valued layer blocks.
# `(..)` are placeholders for the block arguments. The definition uses the same
# `@model name(args) = begin … end` form as the other models in this notebook; the
# previous mix of `function` and `= begin` was not a valid definition.
# NOTE(review): `custom_nn_layer_3` appears twice - confirm whether the last entry
# is meant to be a fourth block.
@model unshackled_Generator() = begin
    Chain(custom_nn_layer_1(..),
          custom_nn_layer_2(..),
          custom_nn_layer_3(..),
          custom_nn_layer_3(..)
    )
end

We would now have defined a highly flexible stochastic control flow, which would open up new opportunities on the architectural front, but would also present a highly challenging inference problem.

## 2. Sketch of VAE Inclusion in Probabilistic Programming Systems

To train a VAE like a probabilistic model we first need to repeat the same steps as for the GAN and write down the supporting functions. Assuming that we still have the MNIST data loading pipeline from the GAN example we begin by establishing our encoder and decoder structures

In [None]:
"""
    Encoder(input_dim, latent_dim, hidden_dim, device)

Encoder of the variational autoencoder.

A shared `tanh` layer maps the input to `hidden_dim` features, from which two
separate linear heads produce the mean `μ` and the log standard deviation `logσ`
of the latent distribution, each of size `latent_dim`. Every layer is moved to
`device` on construction.

The fields are type-parametrised so that calls through the encoder are
specialised on the concrete layer types instead of going through `Any` fields.
"""
struct Encoder{L,M,S}
    linear::L
    μ::M
    logσ::S
    function Encoder(input_dim, latent_dim, hidden_dim, device)
        linear = Dense(input_dim, hidden_dim, tanh) |> device
        μ = Dense(hidden_dim, latent_dim) |> device
        logσ = Dense(hidden_dim, latent_dim) |> device
        return new{typeof(linear),typeof(μ),typeof(logσ)}(linear, μ, logσ)
    end
end

# Apply the encoder to `x`: run the shared hidden layer once, then return the
# outputs of the mean head and of the log-standard-deviation head as a tuple.
function (encoder::Encoder)(x)
    hidden = encoder.linear(x)
    latent_mean = encoder.μ(hidden)
    latent_logstd = encoder.logσ(hidden)
    return latent_mean, latent_logstd
end

In [None]:
"""
    Decoder(input_dim, latent_dim, hidden_dim, device)

Build the decoder of the variational autoencoder: a `tanh` layer from `latent_dim`
to `hidden_dim` followed by a linear layer back to `input_dim`. The resulting
chain is moved to `device`.
"""
function Decoder(input_dim, latent_dim, hidden_dim, device)
    hidden_layer = Dense(latent_dim, hidden_dim, tanh)
    output_layer = Dense(hidden_dim, input_dim)
    return device(Chain(hidden_layer, output_layer))
end

A reconstruction function

In [None]:
"""
    reconstuct(encoder, decoder, x, device)

Encode `x`, draw a latent sample as `μ + ε .* exp.(logσ)` with standard normal
noise `ε`, and decode it.

Returns the tuple `(μ, logσ, decoder(z))`.
"""
function reconstuct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    # Standard normal noise with the shape of `logσ`, placed on `device`
    ε = device(randn(Float32, size(logσ)))
    z = μ + ε .* exp.(logσ)
    return μ, logσ, decoder(z)
end

And the model loss

In [None]:
"""
    model_loss(encoder, decoder, λ, x, device)

Return the VAE training loss for the batch `x`: the reconstruction term (summed
logit binary cross-entropy between the decoder output and `x`) plus the
KL-divergence term, both divided by the size of the last dimension of `x`, plus
`λ` times the sum of squared decoder parameters.
"""
function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, recon_logits = reconstuct(encoder, decoder, x, device)
    n_obs = size(x)[end]

    # KL-divergence
    kl_q_p = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 -1f0 - 2f0 * logσ)) / n_obs

    # Negative log-likelihood of `x` under the decoder output
    neg_logp_x_z = sum(logitbinarycrossentropy.(recon_logits, x)) / n_obs

    # Regularization: squared norm of every decoder parameter array
    reg = λ * sum(p -> sum(p .^ 2), Flux.params(decoder))

    return neg_logp_x_z + kl_q_p + reg
end

As we now have both our encoder and decoder, we can define our probabilistic model (absolute sketches at this point)

In [None]:
# Sketch of the generative model of the VAE; `(..)` are placeholders for the
# arguments. The `begin` block is now closed with its matching `end`.
@model vae_model(..) = begin
    z ~ Normal(..) # sample from the prior
    decoder(z)
    # score against actual samples
end

In [None]:
# Sketch of the variational guide (approximate posterior); `(..)` are placeholders
# for the arguments. The `begin` block is now closed with its matching `end`.
@model guide(..) = begin
    # The encoder returns the mean and the *log* standard deviation
    μ, logσ = encoder(..)
    # `Normal` takes a standard deviation, so the log has to be undone first
    Normal(μ, exp(logσ))
end

The training routine then looks along the lines of the following code:

In [None]:
# Sketch of the VAE training routine driven by variational inference.
# NOTE(review): `Args`, `kws`, `get_data`, `convert_to_image` and `Progress` are not
# defined in this notebook and have to come from the surrounding boilerplate.

# load hyperparameters
args = Args(; kws...)
# Seed the global random number generator only when a positive seed is given
args.seed > 0 && Random.seed!(args.seed)

device = cpu

# load MNIST images
loader = get_data(args.batch_size)

# initialize encoder and decoder
encoder = Encoder(args.input_dim, args.latent_dim, args.hidden_dim, device)
decoder = Decoder(args.input_dim, args.latent_dim, args.hidden_dim, device)

# ADAM optimizer
opt = ADAM(args.η)

# fixed input: save the first batch of originals as a reference image
original, _ = first(get_data(args.sample_size^2))
original = original |> device
image = convert_to_image(original, args.sample_size)
image_path = joinpath(args.save_path, "original.png")
save(image_path, image)

# Configure ADVI
# NOTE(review): `advi` is not passed to `vi` below - confirm the intended call.
advi = ADVI(10, 10_000)

# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"
    progress = Progress(length(loader))
    # Start every epoch from zero; without this the `+=` below reads a variable
    # that was never assigned
    epoch_loss = 0.0

    for (x, _) in loader 
        q = vi(vae_model, guide)
        epoch_loss += AdvancedVI.elbo()
    end

    @info "Epoch $(epoch) finished" epoch_loss
end