In [1]:
using BSON: @save
using BSON: @load
using CSV
using DataFrames: DataFrame
using Flux
using Flux: logitbinarycrossentropy, binarycrossentropy, BatchNorm
using Flux.Data: DataLoader
using Flux: chunk
using ImageFiltering
using MLDatasets: FashionMNIST
using ProgressMeter: Progress, next!
using Random
using Zygote
using MLDatasets
using Images
using ImageIO
using LinearAlgebra
using FFTW

using NBInclude
@nbinclude("src/functions.ipynb")

All function imported




In [2]:
# We define a reshape layer to use in our decoder
struct Reshape
    shape
end
Reshape(args...) = Reshape(args)
(r::Reshape)(x) = reshape(x, r.shape)
Flux.@functor Reshape ()

In [5]:
function get_train_loader(batch_size, shuffle::Bool)
    # The MNIST training set is made up of 60k 28 by 28 greyscale images
    train_x, train_y = MNIST(split=:train)[:]
    train_x = 1 .- reshape(train_x, (784, :))
    return DataLoader((train_x, train_y), batchsize=batch_size, shuffle=shuffle, partial=false)
end

function save_model(encoder_μ, encoder_logvar, decoder, W, save_dir::String, epoch::Int)
    print("Saving model...")
    let encoder_μ = cpu(encoder_μ), encoder_logvar = cpu(encoder_logvar), decoder = cpu(decoder), W = cpu(W)
        @save joinpath(save_dir, "model-$epoch.bson") encoder_μ encoder_logvar decoder W
    end
    println("Done")
end

function create_vae()
    # Define the encoder and decoder networks
    encoder_features = Chain(
        Dense(784,500, relu),
        Dense(500,500, relu)
    )
    encoder_μ = Chain(encoder_features, Dense(500, 20))
    encoder_logvar = Chain(encoder_features, Dense(500, 20))
    
    decoder = Chain(
        Dense(20, 500, relu, bias = false),
        Dense(500,500, relu, bias = false)
    )
    W = randn(784,500)
    return encoder_μ, encoder_logvar, decoder, W
end

create_vae (generic function with 1 method)

In [6]:
grad_sigmoid(x) = @.(exp(-x)/(exp(-x)+1)^2)
loss_α(F,A) = maximum( sqrt.(sum((F*A).*(F*A), dims = 2))) + 100*norm(A'*A - I(500),2)^2


loss_α (generic function with 1 method)

In [7]:
function vae_loss(encoder_μ, encoder_logvar, decoder, W, x, β, λ, F)
    batch_size = size(x)[end]
    @assert batch_size != 0

    # Forward propagate through mean encoder and std encoders
    μ = encoder_μ(x)
    logvar = encoder_logvar(x)
    # Apply reparameterisation trick to sample latent
    z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar)
    # Reconstruct from latent sample

    x̂ = sigmoid(W * decoder(z)) 
    rand_input_last = decoder(randn(20))  

    # WL = (grad_sigmoid(W*rand_input_last)' .* W')' 
    α = loss_α(F, W) 


    # x_rand1 = sigmoid(decoder(randn(20,64)))
    # x_diff_1 = x_rand1 - x̂
    # x_diff_1_norm = sum(x_diff_1.^2, dims = 1)
    # Γ_1 = (F*(x_diff_1)) .^ 2
    # inf_norm_1_sum = maximum(Γ_1 ./ x_diff_1_norm)  


    # Negative reconstruction loss Ε_q[logp_x_z]
    logp_x_z = -sum(binarycrossentropy.(x̂, x)) 

    # KL(qᵩ(z|x)||p(z)) where p(z)=N(0,1) and qᵩ(z|x) models the encoder i.e. reverse KL
    kl_q_p = 0.5f0 * sum(@. (exp(logvar) + μ^2 - logvar - 1f0)) 
    # Weight decay regularisation term
    reg = λ * sum(x->sum(x.^2), Flux.params(encoder_μ, encoder_logvar, decoder, W))
    # We want to maximise the evidence lower bound (ELBO)
    elbo = logp_x_z - β .* kl_q_p
    # So we minimise the sum of the negative ELBO and a weight penalty
    return -elbo + reg + 10000 * α  + 100 * norm(W*decoder(z), Inf)
end

function train(encoder_μ, encoder_logvar, decoder, W, dataloader, num_epochs, λ, β, optimiser, save_dir)
    # The training loop for the model
    trainable_params = Flux.params(encoder_μ, encoder_logvar, decoder, W)
    progress_tracker = Progress(num_epochs, "Training a epoch done")

    for epoch_num = 1:num_epochs
        acc_loss = 0.0
        loss = 0
        F_sub = sample_fourier_without_1(100, 784)
        for (x_batch, y_batch) in dataloader
            # F_sub = sample_fourier(100, 784)
            
            # pullback function returns the result (loss) and a pullback operator (back)
            loss, back = pullback(trainable_params) do
                vae_loss(encoder_μ, encoder_logvar, decoder, W, x_batch, β, λ, F_sub)
            end
            # Feed the pullback 1 to obtain the gradients and update then model parameters
            gradients = back(1f0)
            Flux.Optimise.update!(optimiser, trainable_params, gradients)
            if isnan(loss)
                break
            end
            acc_loss += loss
        end
        next!(progress_tracker; showvalues=[(:loss, loss)])
        @assert length(dataloader) > 0
        avg_loss = acc_loss / length(dataloader)
        metrics = DataFrame(epoch=epoch_num, negative_elbo=avg_loss)
        # println(metrics)
        CSV.write(joinpath(save_dir, "metrics.csv"), metrics, header=(epoch_num==1), append=true)
        save_model(encoder_μ, encoder_logvar, decoder, W, save_dir, epoch_num)
    end
    println("Training complete!")
end



train (generic function with 1 method)

In [9]:
batch_size = 64
shuffle_data = true
η = 0.001
β = 1f0
λ = 0.01f0
num_epochs = 100
save_dir = "trained_GNN/test/trained_GNN/MNIST_relu_v2"
# Define the model and create our data loader
dataloader = get_train_loader(batch_size, shuffle_data)
encoder_μ, encoder_logvar, decoder, W = create_vae()
train(encoder_μ, encoder_logvar, decoder, W, dataloader, num_epochs, λ, β, ADAM(η), save_dir)

Saving model...Done


[32mTraining a epoch done   2%|█                             |  ETA: 1:51:22[39m[K
[34m  loss:  1.3020624802352723e13[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done   3%|█                             |  ETA: 1:50:16[39m[K
[34m  loss:  4.355833974640069e12[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done   4%|██                            |  ETA: 1:49:06[39m[K
[34m  loss:  1.6609350009713628e12[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done   5%|██                            |  ETA: 1:48:05[39m[K
[34m  loss:  6.746575154176963e11[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done   6%|██                            |  ETA: 1:47:16[39m[K
[34m  loss:  2.8202077731633875e11[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done   7%|███                           |  ETA: 1:46:04[39m[K
[34m  loss:  1.188972836854653e11[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done   8%|███                           |  ETA: 1:44:49[39m[K
[34m  loss:  4.98621074845266e10[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done   9%|███                           |  ETA: 1:43:51[39m[K
[34m  loss:  2.0560525594455326e10[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  10%|████                          |  ETA: 1:43:09[39m[K
[34m  loss:  8.232156921631474e9[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  11%|████                          |  ETA: 1:42:19[39m[K
[34m  loss:  3.1476635149590836e9[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  12%|████                          |  ETA: 1:41:39[39m[K
[34m  loss:  1.1211552787770045e9[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  13%|████                          |  ETA: 1:40:46[39m[K
[34m  loss:  3.575983157439938e8[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  14%|█████                         |  ETA: 1:39:40[39m[K
[34m  loss:  9.571872873485216e7[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  15%|█████                         |  ETA: 1:38:19[39m[K
[34m  loss:  1.929762662774878e7[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  16%|█████                         |  ETA: 1:36:59[39m[K
[34m  loss:  2.471253436114341e6[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  17%|██████                        |  ETA: 1:35:41[39m[K
[34m  loss:  183411.8803122576[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  18%|██████                        |  ETA: 1:34:24[39m[K
[34m  loss:  35764.89179251737[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  19%|██████                        |  ETA: 1:33:07[39m[K
[34m  loss:  28033.456938923842[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  20%|███████                       |  ETA: 1:31:51[39m[K
[34m  loss:  24120.32765324868[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  21%|███████                       |  ETA: 1:30:38[39m[K
[34m  loss:  20972.441123293498[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  22%|███████                       |  ETA: 1:29:23[39m[K
[34m  loss:  19839.356124307007[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  23%|███████                       |  ETA: 1:28:08[39m[K
[34m  loss:  19192.796093795892[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  24%|████████                      |  ETA: 1:26:55[39m[K
[34m  loss:  18721.433921579617[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  25%|████████                      |  ETA: 1:25:42[39m[K
[34m  loss:  19062.37514079121[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  26%|████████                      |  ETA: 1:24:29[39m[K
[34m  loss:  17998.1856377025[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  27%|█████████                     |  ETA: 1:23:17[39m[K
[34m  loss:  18206.824823914532[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  28%|█████████                     |  ETA: 1:22:06[39m[K
[34m  loss:  17429.350441273225[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  29%|█████████                     |  ETA: 1:20:55[39m[K
[34m  loss:  17320.494694815792[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  30%|██████████                    |  ETA: 1:19:43[39m[K
[34m  loss:  17564.088467403355[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  31%|██████████                    |  ETA: 1:18:33[39m[K
[34m  loss:  16804.475976155856[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  32%|██████████                    |  ETA: 1:17:22[39m[K
[34m  loss:  17142.060724772164[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  33%|██████████                    |  ETA: 1:16:11[39m[K
[34m  loss:  16887.5022877921[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  34%|███████████                   |  ETA: 1:15:01[39m[K
[34m  loss:  17025.384190906403[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  35%|███████████                   |  ETA: 1:13:50[39m[K
[34m  loss:  16673.221739581273[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  36%|███████████                   |  ETA: 1:12:40[39m[K
[34m  loss:  17040.269278439795[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  37%|████████████                  |  ETA: 1:11:31[39m[K
[34m  loss:  16674.44798251142[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  38%|████████████                  |  ETA: 1:10:21[39m[K
[34m  loss:  16758.775251769654[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  39%|████████████                  |  ETA: 1:09:12[39m[K
[34m  loss:  16414.427943221075[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  40%|█████████████                 |  ETA: 1:08:02[39m[K
[34m  loss:  16608.368209042183[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  41%|█████████████                 |  ETA: 1:06:53[39m[K
[34m  loss:  16366.430270436642[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  42%|█████████████                 |  ETA: 1:05:44[39m[K
[34m  loss:  16619.00471856997[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  43%|█████████████                 |  ETA: 1:04:35[39m[K
[34m  loss:  16789.332340854206[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  44%|██████████████                |  ETA: 1:03:26[39m[K
[34m  loss:  16686.275689430986[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  45%|██████████████                |  ETA: 1:02:17[39m[K
[34m  loss:  16786.684618941792[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  46%|██████████████                |  ETA: 1:01:08[39m[K
[34m  loss:  16795.364823679884[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  47%|███████████████               |  ETA: 0:59:59[39m[K
[34m  loss:  16722.306741132372[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  48%|███████████████               |  ETA: 0:58:50[39m[K
[34m  loss:  17119.303546238054[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  49%|███████████████               |  ETA: 0:57:42[39m[K
[34m  loss:  16742.3340047527[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  50%|████████████████              |  ETA: 0:56:33[39m[K
[34m  loss:  15994.045754852012[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  51%|████████████████              |  ETA: 0:55:24[39m[K
[34m  loss:  16703.005013783455[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  52%|████████████████              |  ETA: 0:54:16[39m[K
[34m  loss:  16427.624458374983[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  53%|████████████████              |  ETA: 0:53:08[39m[K
[34m  loss:  16573.48760088091[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  54%|█████████████████             |  ETA: 0:51:59[39m[K
[34m  loss:  16657.596846401753[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  55%|█████████████████             |  ETA: 0:50:51[39m[K
[34m  loss:  16453.048923908576[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  56%|█████████████████             |  ETA: 0:49:43[39m[K
[34m  loss:  16180.179730976935[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  57%|██████████████████            |  ETA: 0:48:35[39m[K
[34m  loss:  17066.55711833993[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  58%|██████████████████            |  ETA: 0:47:26[39m[K
[34m  loss:  17103.98293800538[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  59%|██████████████████            |  ETA: 0:46:18[39m[K
[34m  loss:  17215.744531635948[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  60%|███████████████████           |  ETA: 0:45:10[39m[K
[34m  loss:  17282.396630031668[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  61%|███████████████████           |  ETA: 0:44:02[39m[K
[34m  loss:  16149.05687700077[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  62%|███████████████████           |  ETA: 0:42:54[39m[K
[34m  loss:  16803.186458664626[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  63%|███████████████████           |  ETA: 0:41:46[39m[K
[34m  loss:  16388.638787607448[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  64%|████████████████████          |  ETA: 0:40:38[39m[K
[34m  loss:  16124.390944038118[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  65%|████████████████████          |  ETA: 0:39:30[39m[K
[34m  loss:  16235.070695137954[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  66%|████████████████████          |  ETA: 0:38:22[39m[K
[34m  loss:  16557.367735497603[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  67%|█████████████████████         |  ETA: 0:37:14[39m[K
[34m  loss:  16434.25065263416[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  68%|█████████████████████         |  ETA: 0:36:07[39m[K
[34m  loss:  16042.138537937502[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  69%|█████████████████████         |  ETA: 0:34:59[39m[K
[34m  loss:  16434.391681240922[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  70%|██████████████████████        |  ETA: 0:33:51[39m[K
[34m  loss:  16445.280611476606[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  71%|██████████████████████        |  ETA: 0:32:43[39m[K
[34m  loss:  16348.887707907827[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  72%|██████████████████████        |  ETA: 0:31:35[39m[K
[34m  loss:  16120.629819063199[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  73%|██████████████████████        |  ETA: 0:30:27[39m[K
[34m  loss:  16748.16212293111[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  74%|███████████████████████       |  ETA: 0:29:19[39m[K
[34m  loss:  16588.915784448833[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  75%|███████████████████████       |  ETA: 0:28:12[39m[K
[34m  loss:  16375.05502180533[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  76%|███████████████████████       |  ETA: 0:27:04[39m[K
[34m  loss:  16215.917220970558[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  77%|████████████████████████      |  ETA: 0:25:56[39m[K
[34m  loss:  16392.0635471284[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  78%|████████████████████████      |  ETA: 0:24:48[39m[K
[34m  loss:  16440.779665259146[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  79%|████████████████████████      |  ETA: 0:23:41[39m[K
[34m  loss:  15879.975012490555[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  80%|█████████████████████████     |  ETA: 0:22:33[39m[K
[34m  loss:  16202.243970869367[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  81%|█████████████████████████     |  ETA: 0:21:25[39m[K
[34m  loss:  16558.80335918391[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  82%|█████████████████████████     |  ETA: 0:20:17[39m[K
[34m  loss:  16616.896462059573[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  83%|█████████████████████████     |  ETA: 0:19:10[39m[K
[34m  loss:  16776.01427433276[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  84%|██████████████████████████    |  ETA: 0:18:02[39m[K
[34m  loss:  16477.099847775946[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  85%|██████████████████████████    |  ETA: 0:16:54[39m[K
[34m  loss:  15926.203232090147[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  86%|██████████████████████████    |  ETA: 0:15:47[39m[K
[34m  loss:  16003.172914545285[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  87%|███████████████████████████   |  ETA: 0:14:39[39m[K
[34m  loss:  16344.5498924393[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  88%|███████████████████████████   |  ETA: 0:13:31[39m[K
[34m  loss:  15739.68692527646[39m[K[A


Saving model...Done


[K[A[32mTraining a epoch done  89%|███████████████████████████   |  ETA: 0:12:24[39m[K
[34m  loss:  16542.076168129202[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  90%|████████████████████████████  |  ETA: 0:11:16[39m[K
[34m  loss:  15853.199917419344[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  91%|████████████████████████████  |  ETA: 0:10:08[39m[K
[34m  loss:  15873.56837651449[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  92%|████████████████████████████  |  ETA: 0:09:01[39m[K
[34m  loss:  16554.53939055796[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  93%|████████████████████████████  |  ETA: 0:07:53[39m[K
[34m  loss:  16674.680784796074[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  94%|█████████████████████████████ |  ETA: 0:06:46[39m[K
[34m  loss:  16196.195543244108[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  95%|█████████████████████████████ |  ETA: 0:05:38[39m[K
[34m  loss:  16637.330256298166[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  96%|█████████████████████████████ |  ETA: 0:04:30[39m[K
[34m  loss:  16535.138589707083[39m[K[A

Saving model...Done


[K[A[32mTraining a epoch done  97%|██████████████████████████████|  ETA: 0:03:23[39m[K
[34m  loss:  16406.627993205046[39m[K[A


Saving model...Done



[K[A[32mTraining a epoch done  98%|██████████████████████████████|  ETA: 0:02:15[39m[K
[34m  loss:  16630.717472662913[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done  99%|██████████████████████████████|  ETA: 0:01:08[39m[K
[34m  loss:  16384.399968524885[39m[K[A

Saving model...Done



[K[A[32mTraining a epoch done 100%|██████████████████████████████| Time: 1:52:36[39m[K
[34m  loss:  16303.007742751302[39m[K


Saving model...Done
Training complete!


In [22]:
function get_test_loader(batch_size, shuffle::Bool)
    # The FashionMNIST test set is made up of 10k 28 by 28 greyscale images
    test_x, test_y = MNIST(split=:test)[:]
    test_x = 1 .- reshape(test_x, (784, :))
    return DataLoader((test_x, test_y), batchsize=batch_size, shuffle=shuffle)
end

function save_to_images(x_batch, save_dir::String, prefix::String, num_images::Int64)
    @assert num_images <= size(x_batch)[2]
    for i=1:num_images
        save(joinpath(save_dir, "$prefix-$i.png"), colorview(Gray, reshape(x_batch[:, i], 28,28)' ))
    end
end

function reconstruct_images(encoder_μ, encoder_logvar, decoder, x)
    # Forward propagate through mean encoder and std encoders
    μ = encoder_μ(x)
    logvar = encoder_logvar(x)
    # Apply reparameterisation trick to sample latent
    z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar)
    # Reconstruct from latent sample

    x̂ = sigmoid(decoder(z))
    return clamp.(x̂, 0 ,1)
end

function load_model_identity(load_dir::String, epoch::Int)
    print("Loading model...")
    @load joinpath(load_dir, "model-$epoch.bson") encoder_μ encoder_logvar decoder 
    println("Done")
    return encoder_μ, encoder_logvar, decoder
end

function visualise()
    # Define some parameters
    batch_size = 1
    shuffle = true
    num_images = 1
    epoch_to_load = 20
    # Load the model and test set loader
    dir = "test/trained_GNN/MNIST_sigmoid_inco"
    encoder_μ, encoder_logvar, decoder= load_model_identity(dir, epoch_to_load)
    dataloader = get_test_loader(batch_size, shuffle)
    # Reconstruct and save some images
    for (x_batch, y_batch) in dataloader
        save_to_images(x_batch, dir, "test-image", num_images)
        x̂_batch = reconstruct_images(encoder_μ, encoder_logvar, decoder, x_batch)
        print(size(x_batch))
        save_to_images(x̂_batch, dir, "reconstruction", num_images)
        break
    end
end


visualise (generic function with 1 method)

In [None]:
visualise()

In [None]:
load("test/trained_GNN/MNIST_sigmoid_inco/reconstruction-1.png")

In [None]:
load("test/trained_GNN/MNIST_sigmoid_inco/test-image-1.png")

In [24]:
using NBInclude
@nbinclude("src/functions.ipynb")
epoch_to_load =40
# Load the model and test set loader
dir = "trained_GNN/test/trained_GNN/MNIST_relu_v2"
encoder_μ, encoder_logvar, decoder = load_model_identity(dir, epoch_to_load);


batch_size = 64; shuffle = true
dataloader = get_test_loader(batch_size, shuffle)
(x_batch, y_batch) = first(dataloader)

# # x = reshape(x_batch[:,1], 784,1)

# μ = encoder_μ(x_batch)
# logvar = encoder_logvar(x_batch)
# # Apply reparameterisation trick to sample latent
# z = μ + randn(Float32, size(logvar)) .* exp.(0.5f0 * logvar);


# z1 = z[:,1]
# z2 = z[:,2]
# β = 1
# colorview(Gray,reshape(sigmoid(decoder(β * z2 + (1-β) *z1))[:,1], 28,28)')
# colorview(Gray,reshape(sigmoid(decoder(randn(20)))[:,1], 28,28)')






All function imported
Loading model...Done


(Float32[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; … ; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0], [4, 0, 4, 2, 9, 9, 1, 3, 8, 5  …  8, 6, 7, 6, 0, 8, 3, 9, 7, 4])

In [30]:
mean(x_batch, dims = 2)



784×1 Matrix{Float32}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 ⋮
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

In [None]:
W*Flux.params(decoder_inco)[2]*Flux.params(decoder_inco)[1]

In [None]:
F_sub = sample_fourier(100, 784)
loss_α(F,A) = maximum(sqrt.(sum((F*A).*(F*A), dims = 2))) + 100*norm(A'*A - I(500),2)^2


loss_α(F_sub, J[1])