# Variational Autoencoder (VAE) on dummy-data
A first attempt at working with the data provided by Caroline Broichhagen (broichha@imbi.uni-freiburg.de). The ultimate goal of this kernel is a simple Variational Autoencoder with a one-layer convolutional encoder and decoder that reconstructs the data.

First, let's import the data and have a look at it.

In [1]:
using MAT, Plots, ColorSchemes

# by broichha@imbi.uni-freiburg.de
file = matopen("/home/flo/projects/thesis/dummydata/dummyData2000.mat")
data = read(file, "pix3D")
close(file)

# store data in WHCN order (width, height, channel, batches)
data = reshape(data, (60, 60, 1, size(data, 3)))
print("Data imported.")

Data imported.

In [2]:
print("Data dimensions: ", size(data))

# by broichha@imbi.uni-freiburg.de
anim = @animate for i=1:size(data, 4)
    Plots.heatmap(data[:, :, 1, i], seriescolor=cgrad(ColorSchemes.gray.colors))
end
gif(anim, "in.gif", fps=15)

Data dimensions: (60, 60, 1, 2000)

┌ Info: Saved animation to 
│   fn = /home/flo/projects/thesis/code/in.gif
└ @ Plots /home/flo/.julia/packages/Plots/Iuc9S/src/animation.jl:95


The dummy data consists of 2000 images of 60x60 pixels that simulate the activity shown in the GIF above. For now, we'll just challenge my sparse Julia abilities and build a VAE that reconstructs this data. Later on, we will attempt to use the latent representation learned by the VAE to find regions of interest (ROIs) in the individual images of the dataset and to discriminate them from background activity and noise.

In [3]:
using Flux
using Flux: @epochs, mse
using Base.Iterators: partition
using Images
using Flux: Conv, MaxPool, Dense, ConvTranspose
using BSON: @save, @load

In [19]:
# hyperparameters
# TODO: all params over here
changed = false; # set to true, if model parameters were changed
latent_dimension = 15;
epochs = 50;
out_ch1 = 8;
learning_rate = 0.001;
batch_size = 20;

The first layer of our encoder is a convolutional one. For debugging purposes, we print the intermediate sizes of the data tensors. Good hyperparameters still need to be determined.

In [20]:
layer1 = Conv((3, 3), 1=>out_ch1, relu, pad=1);
sample1 = layer1(data[:, :, :, 1:10])
print("Size of output at layer1: ", size(sample1), "\n");

pool1 = MaxPool((3, 3));
sample2 = pool1(sample1)
out_size = size(sample2)
print("Size of output after max pooling: ", out_size)

conv1(x) = pool1(layer1(x));

Size of output at layer1: (60, 60, 8, 10)
Size of output after max pooling: (20, 20, 8, 10)

Next, we flatten the output tensor and feed it into a fully connected layer. Again, we keep track of the dimensions to avoid mistakes. There is probably a better way to do this than pushing data samples through the layers.

In [21]:
flattened_size = out_size[1] * out_size[2] * out_size[3];
layer2 = Dense(flattened_size, latent_dimension, relu);
sample3 = layer2(reshape(sample2, (flattened_size, 10)));
print("Size of output at layer2: ", size(sample3))

encoder(x) = layer2(reshape(conv1(x), (flattened_size, :)));

Size of output at layer2: (15, 10)
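
The sizes printed so far can also be derived without passing samples through the layers. The following is a minimal sketch in plain Julia, assuming no dilation; `conv_out` is a hypothetical helper, not part of Flux. It applies the usual output-size formula to one spatial dimension and reproduces the values 60, 20 and 3200 used by the encoder.

```julia
# Output length along one spatial dimension of a convolution or pooling window
# (no dilation); `fld` is floor division. Hypothetical helper, not a Flux function.
conv_out(n, k; pad=0, stride=1) = fld(n + 2 * pad - k, stride) + 1

width = 60
after_conv = conv_out(width, 3; pad=1)           # Conv((3, 3), pad=1)
after_pool = conv_out(after_conv, 3; stride=3)   # MaxPool((3, 3)): stride equals the window size
flattened = after_pool^2 * 8                     # 8 = out_ch1

println((after_conv, after_pool, flattened))
```

Because the images are square, one dimension is enough here; for non-square inputs the helper would be applied to width and height separately.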

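We can now define our z variable, the very heart of our VAE. To keep backpropagation usable for parameter optimization, we apply the reparametrization trick: z is written as a deterministic function of a mean, a log-variance and a random part epsilon drawn from a standard normal distribution, so that the randomness stays outside the trainable parameters.

In [22]:
# reparametrization trick in vae
# the second argument is the log-variance, so exp(0.5 * log_variance) is the standard deviation
function sampling(mean, log_variance)
    # epsilon ~ N(0, I): randn (standard normal) instead of rand (uniform on [0, 1)),
    # with an independent draw for every latent unit and every sample
    epsilon = randn(Float32, size(mean))
    return mean .+ exp.(0.5f0 .* log_variance) .* epsilon
end;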
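
For reference, the standard form of the reparametrization trick can be written as follows. The symbols $\mu$ and $\log\sigma^2$ are assumed here to correspond to the two arguments of `sampling`, and $\odot$ denotes the element-wise product:

$$
z = \mu + \exp\left(\tfrac{1}{2}\log\sigma^2\right) \odot \epsilon = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

Since $\epsilon$ does not depend on any trainable parameter, gradients can flow through $\mu$ and $\log\sigma^2$ as in any deterministic layer.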

In [23]:
# no activation here: relu would force the mean and the log-variance to be non-negative
z_mean = Dense(latent_dimension, latent_dimension)
z_log_var = Dense(latent_dimension, latent_dimension)
z(x) = sampling(z_mean(x), z_log_var(x))

sample4 = z(sample3)
print("Size of z: ", size(sample4))

Size of z: (15, 10)

For the decoder, we currently use a fully connected layer followed by two transposed-convolution (deconvolution) layers. It is not yet clear whether the output layer should instead be another fully connected one with sigmoid activation. For now, we'll leave it like this and improve it later.

In [24]:
# decoder
layer3 = Dense(latent_dimension, 400, relu)

layer4 = ConvTranspose((3, 3), 1=>out_ch1, relu, stride=3)
sample5 = layer4(reshape(layer3(sample4), (20, 20, 1, 10)))
print("Size of output at layer 4: ", size(sample5), "\n")

layer5 = ConvTranspose((3, 3), out_ch1=>1, relu, pad=1);
sample6 = layer5(sample5);
print("Size of output at layer 5: ", size(sample6))

decoder(x) = layer5(layer4(reshape(layer3(x), (20, 20, 1, :))));

Size of output at layer 4: (60, 60, 8, 10)
Size of output at layer 5: (60, 60, 1, 10)
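
The decoder sizes can be checked the same way. This is a small sketch assuming no dilation and no output padding; `convtranspose_out` is a hypothetical helper, not a Flux function. It inverts the convolution size formula and reproduces the 60x60 outputs printed above, starting from the 20x20 map that `layer3` produces (400 = 20 * 20 * 1).

```julia
# Output length of a transposed convolution along one spatial dimension
# (no dilation, no output padding). Hypothetical helper, not a Flux function.
convtranspose_out(n, k; pad=0, stride=1) = (n - 1) * stride - 2 * pad + k

after_layer4 = convtranspose_out(20, 3; stride=3)          # ConvTranspose((3, 3), stride=3)
after_layer5 = convtranspose_out(after_layer4, 3; pad=1)   # ConvTranspose((3, 3), pad=1)

println((after_layer4, after_layer5))
```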

We can now combine encoder, latent variable z and decoder and train the model afterwards.

In [25]:
cvae(x) = decoder(z(encoder(x)));

loss(x) = mse(cvae(x), x);

optimizer = Flux.ADAM(learning_rate);
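
Note that the loss defined here contains only the reconstruction error. The usual VAE objective adds a Kullback-Leibler term that pulls the latent distribution towards a standard normal; for a diagonal Gaussian it has the closed form $-\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The following is a minimal sketch of that term in plain Julia, not the author's implementation; `kl_divergence` and its argument names are hypothetical, and samples are assumed to be stored in columns as in the rest of the notebook.

```julia
# KL(N(mu, exp(log_var)) || N(0, I)), summed over the latent units and
# averaged over the batch (one sample per column). Hypothetical helper.
function kl_divergence(mu, log_var)
    per_entry = 1 .+ log_var .- mu .^ 2 .- exp.(log_var)
    return -0.5f0 * sum(per_entry) / size(mu, 2)
end

# Stand-in values with the notebook's latent shape (15 units, 10 samples).
mu = ones(Float32, 15, 10)
log_var = zeros(Float32, 15, 10)
println(kl_divergence(mu, log_var))
```

To use such a term in training, the loss would have to evaluate `encoder(x)` once, pass the result through `z_mean` and `z_log_var`, and add the weighted KL value to the `mse` of the reconstruction. How to weight the two terms is a further hyperparameter.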

In [26]:
# for some weird reason ∇maxpool has a problem with Float64 data
# as dy seems to be computed in Float32 and the function takes
# only arrays with values of same type
data = convert.(Float32, data);

# this also looks a bit hacked, but it does the job of bringing the 
# data into the shape preferred by Flux.train! and creating batches
# (we'll find a better solution)
batches = [reshape(data[:, :, :, ((i-1)*batch_size+1):(i*batch_size)], (60, 60, 1, batch_size)) for i in 1:size(data, 4)÷batch_size];

# load previous parameters, if existing
if isfile("cvae_model.bson") && !changed
    @load "cvae_model.bson" cvae
end

print("Training...")
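# cvae is a plain function, so params(cvae) would not find any trainable arrays;
# collect the parameters from the layers themselves
@epochs epochs Flux.train!(loss, params(layer1, layer2, z_mean, z_log_var, layer3, layer4, layer5), zip(batches), optimizer)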
print("done.")
# store model for later
@save "cvae_model.bson" cvae

Training...

┌ Info: Epoch 1
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 2
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 3
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 4
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 5
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 6
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 7
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 8
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 9
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 10
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 11
└ @ Main /home/flo/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105
┌ Info: Epoch 12
└ @ Main /hom

done.
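
The batching expression in the training cell can be written more compactly with `partition`, which is already imported from `Base.Iterators`. The sketch below uses a small random stand-in array (`dummy`, hypothetical) in the same layout as `data`. Unlike the integer-division version, it keeps a smaller final batch when the number of samples is not a multiple of `batch_size`.

```julia
using Base.Iterators: partition

# Stand-in for the real data: same layout (width, height, channel, sample), smaller size.
dummy = rand(Float32, 6, 6, 1, 50)
batch_size = 20

# partition splits the sample indices into chunks of at most batch_size elements.
batches = [dummy[:, :, :, idx] for idx in partition(1:size(dummy, 4), batch_size)]

println(length(batches))
println(size(batches[1]))
println(size(batches[end]))
```

With the 2000 samples and `batch_size = 20` used in this notebook, both approaches give the same 100 batches.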

The code runs through. Time to see how well our model does at reconstructing our dummy data.

In [28]:
sequence = [reshape(data[:, :, :, i], (60, 60, 1, 1)) for i in 1:size(data, 4)];
anim1 = @animate for i=1:2000
    output = cvae(sequence[i])
    output = reshape(Flux.Tracker.data(output), 60, 60)
    Plots.heatmap(output, seriescolor=cgrad(ColorSchemes.gray.colors))
end;

gif(anim1, "out.gif", fps=15)

┌ Info: Saved animation to 
│   fn = /home/flo/projects/thesis/code/out.gif
└ @ Plots /home/flo/.julia/packages/Plots/Iuc9S/src/animation.jl:95


Looks like we're not doing well so far. Even without looking at our metric, we can see that the reconstructed images are bad. We can identify the following reasons for this:

- our model has so far only been trained for very few epochs, as we did not have access to any high-powered machines
- we did not tweak the model architecture or the hyperparameters
- the author is not sure whether the decoder is implemented correctly
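
To put a number on the visual impression, the reconstruction error can be computed per frame. This is a minimal sketch in plain Julia with stand-in arrays; `per_frame_mse` is a hypothetical helper, and with the notebook's objects one would pass `data` and the stacked outputs of `cvae` instead.

```julia
# Mean squared error per frame for two arrays in (width, height, channel, sample) layout.
# Hypothetical helper, written without Flux so that it runs on its own.
function per_frame_mse(original, reconstruction)
    @assert size(original) == size(reconstruction)
    return map(1:size(original, 4)) do i
        o = view(original, :, :, :, i)
        r = view(reconstruction, :, :, :, i)
        sum(abs2, o .- r) / length(o)
    end
end

# Stand-in arrays: every pixel differs by exactly 1, so every frame has an error of 1.
original = ones(Float32, 4, 4, 1, 3)
reconstruction = zeros(Float32, 4, 4, 1, 3)
errors = per_frame_mse(original, reconstruction)
println(errors)
```

Tracking these values across training runs would make it easier to compare different architectures and hyperparameters.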

TODO: 
- run code on server and enable GPU support (https://fluxml.ai/Flux.jl/stable/gpu/)
- test different architectures for encoder and decoder
- run on bigger training set and run grid search for good parameters
- set up dual-encoder-decoder architecture
- think about decoder structure

Just some debugging down here...

In [57]:
print("Beginning to plot...\n")
@gif for i=1:10
    output = cvae(sequence[i])
    output = reshape(Flux.Tracker.data(output), 60, 60)
    Plots.heatmap(output, seriescolor=cgrad(ColorSchemes.gray.colors))
end;

Beginning to plot...


┌ Info: Saved animation to 
│   fn = /home/flo/projects/thesis/code/tmp.gif
└ @ Plots /home/flo/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [34]:
size(data)

(60, 60, 1, 2000)

In [39]:
sequence = [data[:, :, :, i] for i in 1:size(data, 4)]
size(sequence[1])

(60, 60, 1)

In [36]:
data[2]

0.5358083559526245

60×60×1×2000 Array{Float64,4}:
[:, :, 1, 1] =
 0.424376  0.540171  0.646003  0.696327  …  0.504878  0.487778  0.649834
 0.535808  0.43302   0.684178  0.487076     0.480037  0.401764  0.430195
 0.564302  0.561957  0.479192  0.502272     0.671898  0.486677  0.460516
 0.42173   0.354076  0.409455  0.533954     0.309892  0.544167  0.467044
 0.492546  0.391326  0.520655  0.453911     1.36031   0.389626  0.591512
 0.611339  0.629831  0.515894  0.664472  …  1.21283   0.439241  0.616249
 0.340172  0.406445  0.431677  0.510153     1.27216   0.333637  0.492283
 0.528009  0.52316   0.461787  0.506869     1.36426   0.474291  0.550353
 0.440486  0.37651   0.523665  0.637959     1.43239   0.655869  0.4776  
 0.684479  0.495542  0.405739  0.482995     1.50989   0.451748  0.358121
 0.528378  0.692808  0.65008   0.444677  …  0.328437  0.643109  0.46511 
 0.422358  0.626238  0.493982  0.552668     0.305274  0.696971  0.593037
 0.377155  0.562758  0.366504  0.388486     0.484221  0.541331  0.35749 
 ⋮   