In [None]:
using BSON: @save
using BSON: @load
using CSV
using DataFrames: DataFrame
using Flux
using Flux: logitbinarycrossentropy, binarycrossentropy, BatchNorm
using Flux.Data: DataLoader
using Flux: chunk
using ImageFiltering
using MLDatasets: FashionMNIST
using ProgressMeter: Progress, next!
using Random
using Zygote
using MLDatasets
using Images
using ImageIO
using LinearAlgebra
using FFTW

using NBInclude
@nbinclude("src/functions.ipynb")

In [None]:
# We define a reshape layer to use in our decoder
struct Reshape
    shape
end
Reshape(args...) = Reshape(args)
(r::Reshape)(x) = reshape(x, r.shape)
Flux.@functor Reshape ()

In [28]:
function get_train_loader(batch_size, shuffle::Bool)
    # The MNIST training set is made up of 60k 28 by 28 greyscale images
    train_x, train_y = MNIST(split=:train)[:]
    train_x = 1 .- reshape(train_x, (784, :))
    return DataLoader((train_x, train_y), batchsize=batch_size, shuffle=shuffle, partial=false)
end

function save_model(encoder_μ, encoder_logvar, decoder, W, save_dir::String, epoch::Int)
    print("Saving model...")
    let encoder_μ = cpu(encoder_μ), encoder_logvar = cpu(encoder_logvar), decoder = cpu(decoder), W = cpu(W)
        @save joinpath(save_dir, "model-$epoch.bson") encoder_μ encoder_logvar decoder  W
    end
    println("Done")
end

function create_vae()
    # Define the encoder and decoder networks
    encoder_features = Chain(
        Dense(784,500, relu),
        Dense(500,500, relu)
    )
    encoder_μ = Chain(encoder_features, Dense(500, 20))
    encoder_logvar = Chain(encoder_features, Dense(500, 20))
    
    decoder = Chain(
        Dense(20, 500, relu, bias = false),
        Dense(500,500, relu, bias = false),
        # Dense(500,784, bias = true)
    )
    W = randn(784,500)
    return encoder_μ, encoder_logvar, decoder, W
end

create_vae (generic function with 1 method)

In [40]:
function vae_loss(encoder_μ, encoder_logvar, decoder, W, x, β, λ, F)
    batch_size = size(x)[end]
    @assert batch_size != 0

    # Forward propagate through mean encoder and std encoders
    μ = encoder_μ(x)
    logvar = encoder_logvar(x)
    # Apply reparameterisation trick to sample latent
    z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar)
    # Reconstruct from latent sample

    x̂ = W * (decoder(z))  

    # cent = abs(sum(x̂))
    loss_α(F,A) = maximum(sqrt.(sum((F*A).*(F*A), dims = 2))) + 100*norm(A'*A - I(500),2)^2
    α = loss_α(F, W)
    
    # x_rand1 = (decoder(randn(20,64)))
    # # x_rand2 = (decoder(randn(20,64)))
    # x_diff = x_rand1 - x̂
    # x_diff_norm = sum(x_diff.^2, dims = 1)
    # Γ = (F*(x_diff)) .^ 2
    # inf_norm_sum = maximum(Γ ./ x_diff_norm) + .001(maximum(abs.(x_rand1)) + abs(maximum(x̂) + minimum(x̂)) )
    


    # for i in 1:64
    #     inf_norm_sum += norm(Γ[:,i], Inf) / x_diff_norm[i]
    # end

    # Negative reconstruction loss Ε_q[logp_x_z]
    logp_x_z = -sum(logitbinarycrossentropy.(x̂, x)) 
    # KL(qᵩ(z|x)||p(z)) where p(z)=N(0,1) and qᵩ(z|x) models the encoder i.e. reverse KL
    # The @. macro makes sure that all operates are elementwise
    kl_q_p = 0.5f0 * sum(@. (exp(logvar) + μ^2 - logvar - 1f0)) 
    # Weight decay regularisation term
    reg = λ * sum(x->sum(x.^2), Flux.params(encoder_μ, encoder_logvar, decoder, W))
    # We want to maximise the evidence lower bound (ELBO)
    elbo = logp_x_z - β .* kl_q_p
    # So we minimise the sum of the negative ELBO and a weight penalty
    return -elbo + reg + 0.1*norm(x̂ - x, 2)^2 + 10000α
end

function train(encoder_μ, encoder_logvar, decoder, W, dataloader, num_epochs, λ, β, optimiser, save_dir)
    # The training loop for the model
    trainable_params = Flux.params(encoder_μ, encoder_logvar, decoder, W)
    progress_tracker = Progress(num_epochs, "Training a epoch done")

    for epoch_num = 1:num_epochs
        acc_loss = 0.0
        loss = 0
        # F_sub = dct(diagm(ones(784)),2)
        for (x_batch, y_batch) in dataloader
            F_sub = sample_fourier(100, 784)
            
            # pullback function returns the result (loss) and a pullback operator (back)
            loss, back = pullback(trainable_params) do
                vae_loss(encoder_μ, encoder_logvar, decoder, W, x_batch, β, λ, F_sub)
            end
            # Feed the pullback 1 to obtain the gradients and update then model parameters
            gradients = back(1f0)
            Flux.Optimise.update!(optimiser, trainable_params, gradients)
            if isnan(loss)
                break
            end
            acc_loss += loss
        end
        next!(progress_tracker; showvalues=[(:loss, loss)])
        @assert length(dataloader) > 0
        avg_loss = acc_loss / length(dataloader)
        metrics = DataFrame(epoch=epoch_num, negative_elbo=avg_loss)
        # println(metrics)
        CSV.write(joinpath(save_dir, "metrics.csv"), metrics, header=(epoch_num==1), append=true)
        save_model(encoder_μ, encoder_logvar, decoder, W, save_dir, epoch_num)
    end
    println("Training complete!")
end

train (generic function with 2 methods)

In [41]:
batch_size = 64
shuffle_data = true
η = 0.001
β = 1f0
λ = 0.01f0
num_epochs = 40
save_dir = "test/trained_GNN/MNIST_relu"
# Define the model and create our data loader
dataloader = get_train_loader(batch_size, shuffle_data)
encoder_μ, encoder_logvar, decoder, W= create_vae()
train(encoder_μ, encoder_logvar, decoder, W, dataloader, num_epochs, λ, β, ADAM(η), save_dir)






Saving model...Done


[32mTraining a epoch done   5%|██                            |  ETA: 0:39:41[39m[K
[34m  loss:  1.3092465849116094e13[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done   8%|███                           |  ETA: 0:38:32[39m[K
[34m  loss:  4.384635184017403e12[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  10%|████                          |  ETA: 0:37:27[39m[K
[34m  loss:  1.672114909692885e12[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  12%|████                          |  ETA: 0:36:19[39m[K
[34m  loss:  6.788810061045419e11[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  15%|█████                         |  ETA: 0:35:13[39m[K
[34m  loss:  2.8355689927968024e11[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  18%|██████                        |  ETA: 0:34:08[39m[K
[34m  loss:  1.1942139533126756e11[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  20%|███████                       |  ETA: 0:33:04[39m[K
[34m  loss:  5.002059081147454e10[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  22%|███████                       |  ETA: 0:32:01[39m[K
[34m  loss:  2.059630010262917e10[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  25%|████████                      |  ETA: 0:30:58[39m[K
[34m  loss:  8.232465925454814e9[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  28%|█████████                     |  ETA: 0:29:56[39m[K
[34m  loss:  3.141317229983593e9[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  30%|██████████                    |  ETA: 0:28:53[39m[K
[34m  loss:  1.1160513312001517e9[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  32%|██████████                    |  ETA: 0:27:50[39m[K
[34m  loss:  3.548110631246367e8[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  35%|███████████                   |  ETA: 0:26:49[39m[K
[34m  loss:  9.455392497002943e7[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  38%|████████████                  |  ETA: 0:25:47[39m[K
[34m  loss:  1.8941920140802372e7[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  40%|█████████████                 |  ETA: 0:24:44[39m[K
[34m  loss:  2.405596192073841e6[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  42%|█████████████                 |  ETA: 0:23:42[39m[K
[34m  loss:  182679.2408216791[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  45%|██████████████                |  ETA: 0:22:41[39m[K
[34m  loss:  43048.227735772605[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  48%|███████████████               |  ETA: 0:21:39[39m[K
[34m  loss:  38018.033055404485[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  50%|████████████████              |  ETA: 0:20:37[39m[K
[34m  loss:  35739.02178098305[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  52%|████████████████              |  ETA: 0:19:35[39m[K
[34m  loss:  32178.477095113674[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  55%|█████████████████             |  ETA: 0:18:33[39m[K
[34m  loss:  29302.880524947817[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  58%|██████████████████            |  ETA: 0:17:31[39m[K
[34m  loss:  27803.85766012768[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  60%|███████████████████           |  ETA: 0:16:29[39m[K
[34m  loss:  27093.311004330862[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  62%|███████████████████           |  ETA: 0:15:27[39m[K
[34m  loss:  28136.15584714433[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  65%|████████████████████          |  ETA: 0:14:25[39m[K
[34m  loss:  27811.142165048248[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  68%|█████████████████████         |  ETA: 0:13:23[39m[K
[34m  loss:  26568.440955120783[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  70%|██████████████████████        |  ETA: 0:12:21[39m[K
[34m  loss:  26474.323538713437[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  72%|██████████████████████        |  ETA: 0:11:19[39m[K
[34m  loss:  26090.76011052048[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  75%|███████████████████████       |  ETA: 0:10:18[39m[K
[34m  loss:  26255.247139765874[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  78%|████████████████████████      |  ETA: 0:09:16[39m[K
[34m  loss:  25893.564563009553[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  80%|█████████████████████████     |  ETA: 0:08:14[39m[K
[34m  loss:  25599.715411894467[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  82%|█████████████████████████     |  ETA: 0:07:12[39m[K
[34m  loss:  25507.63992454712[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  85%|██████████████████████████    |  ETA: 0:06:11[39m[K
[34m  loss:  25998.303772211024[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  88%|███████████████████████████   |  ETA: 0:05:09[39m[K
[34m  loss:  26697.042364177374[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  90%|████████████████████████████  |  ETA: 0:04:07[39m[K
[34m  loss:  26146.513268748073[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  92%|████████████████████████████  |  ETA: 0:03:05[39m[K
[34m  loss:  26068.6216221265[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  95%|█████████████████████████████ |  ETA: 0:02:03[39m[K
[34m  loss:  26201.432185708058[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  98%|██████████████████████████████|  ETA: 0:01:02[39m[K
[34m  loss:  26838.69821468214[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done 100%|██████████████████████████████| Time: 0:41:09[39m[K
[34m  loss:  26297.0811641319[39m[K


Saving model...Done
Training complete!


In [31]:
function get_test_loader(batch_size, shuffle::Bool)
    # The FashionMNIST test set is made up of 10k 28 by 28 greyscale images
    test_x, test_y = MNIST(split=:test)[:]
    test_x = 1 .- reshape(test_x, (784, :))
    return DataLoader((test_x, test_y), batchsize=batch_size, shuffle=shuffle)
end

function save_to_images(x_batch, save_dir::String, prefix::String, num_images::Int64)
    @assert num_images <= size(x_batch)[2]
    for i=1:num_images
        save(joinpath(save_dir, "$prefix-$i.png"), colorview(Gray, reshape(x_batch[:, i], 28,28)' ))
    end
end

function reconstruct_images(encoder_μ, encoder_logvar, decoder, x)
    # Forward propagate through mean encoder and std encoders
    μ = encoder_μ(x)
    logvar = encoder_logvar(x)
    # Apply reparameterisation trick to sample latent
    z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar)
    # Reconstruct from latent sample

    x̂ = sigmoid(decoder(z))
    return clamp.(x̂, 0 ,1)
end

function load_model_identity(load_dir::String, epoch::Int)
    print("Loading model...")
    @load joinpath(load_dir, "model-$epoch.bson") encoder_μ encoder_logvar decoder W
    println("Done")
    return encoder_μ, encoder_logvar, decoder, W
end

function visualise()
    # Define some parameters
    batch_size = 1
    shuffle = true
    num_images = 1
    epoch_to_load = 20
    # Load the model and test set loader
    dir = "test/trained_GNN/MNIST_sigmoid_inco"
    encoder_μ, encoder_logvar, decoder= load_model_identity(dir, epoch_to_load)
    dataloader = get_test_loader(batch_size, shuffle)
    # Reconstruct and save some images
    for (x_batch, y_batch) in dataloader
        save_to_images(x_batch, dir, "test-image", num_images)
        x̂_batch = reconstruct_images(encoder_μ, encoder_logvar, decoder, x_batch)
        print(size(x_batch))
        save_to_images(x̂_batch, dir, "reconstruction", num_images)
        break
    end
end


visualise (generic function with 1 method)

In [None]:
visualise()

In [None]:
load("test/trained_GNN/MNIST_sigmoid_inco/reconstruction-1.png")

In [None]:
load("test/trained_GNN/MNIST_sigmoid_inco/test-image-1.png")

In [35]:
using NBInclude
@nbinclude("src/functions.ipynb")
epoch_to_load =20
# Load the model and test set loader
dir = "test/trained_GNN/MNIST_relu"
encoder_μ, encoder_logvar, decoder, W = load_model_identity(dir, epoch_to_load);


# batch_size = 64; shuffle = true
# dataloader = get_test_loader(batch_size, shuffle)
# (x_batch, y_batch) = first(dataloader)

# # x = reshape(x_batch[:,1], 784,1)

# μ = encoder_μ(x_batch)
# logvar = encoder_logvar(x_batch)
# # Apply reparameterisation trick to sample latent
# z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar);


# z1 = z[:,1]
# z2 = z[:,2]
# β = 1
# colorview(Gray,reshape(sigmoid(decoder(β * z2 + (1-β) *z1))[:,1], 28,28)')
# colorview(Gray,reshape(sigmoid(decoder(randn(20)))[:,1], 28,28)')






All function imported
Loading model...Done


(Chain(Chain(Dense(784 => 500, relu), Dense(500 => 500, relu)), Dense(500 => 20)), Chain(Chain(Dense(784 => 500, relu), Dense(500 => 500, relu)), Dense(500 => 20)), Chain(Dense(20 => 500, relu; bias=false), Dense(500 => 500, relu; bias=false)), [0.005073787634532151 0.019671017698285245 … 0.005401578961626009 0.03505487141940171; -0.02598047421298229 -0.0008783655326210292 … 0.012016930845491403 0.006034566162204524; … ; -0.002879707078845804 -0.016841959331383664 … 1.1762516885401466e-5 -0.00930029475700759; -0.007068885192699047 0.007394328279686085 … 0.0017658044990063785 -0.03774220016757633])

In [None]:
F = dct(diagm(ones(784)),2);

x_rand1 = sigmoid(decoder(randn(20,100)))
x_rand2 = sigmoid(decoder(randn(20,100)))
x_diff = x_rand1 - x_rand2
x_diff_norm = sum(x_diff.^2, dims = 1)
Γ = (F*(x_diff)) .^ 2


maximum(Γ ./ x_diff_norm)

# inf_norm_sum = 0

# for i in 1:64
#     inf_norm_sum += norm(Γ[:,i], Inf) / x_diff_norm[i]
# end



In [None]:
x_rand1 = sigmoid(decoder(randn(20,1)))
x_rand2 = sigmoid(decoder(randn(20,1)))
x_diff = x_rand1 - x_rand2
colorview(Gray,reshape(x_diff, 28,28)')


In [None]:
norm(randn(2,2), Inf)

In [None]:
maximum(abs(randn(2,2)))

In [None]:
Γ[:,2] / x_diff_norm[2]