In [1]:
using BSON
using CUDA
using DrWatson: struct2dict
using Flux
using Flux: @functor, chunk
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Random
using JSServe
using Markdown
using WGLMakie

In [2]:
# Hyperparameters and run settings for the VAE experiment.
# Every field has a default, so `Args()` yields the standard configuration and
# individual values can be overridden by keyword, e.g. `Args(epochs = 5)`.
# Fields carry concrete types so that reads such as `args.λ` are type-stable
# instead of being stored (and returned) as `Any`.
@with_kw mutable struct Args
    η::Float64 = 1e-3            # learning rate
    λ::Float32 = 0.01f0          # regularization parameter
    batch_size::Int = 128        # batch size
    sample_size::Int = 10        # sampling size for output
    epochs::Int = 60             # number of epochs
    seed::Int = 0                # random seed
    cuda::Bool = true            # use GPU
    input_dim::Int = 28^2        # image size
    latent_dim::Int = 3          # latent dimension
    hidden_dim::Int = 500        # hidden dimension
    verbose_freq::Int = 10       # logging for every verbose_freq iterations
    save_path::String = "output" # results path
end

Args

In [3]:
"""
    get_data(batch_size)

Load the MNIST training split, flatten every image into a column of length
`28^2`, and return a shuffling `DataLoader` that yields `(images, labels)`
batches of `batch_size` observations.
"""
function get_data(batch_size)
    images, labels = MLDatasets.MNIST(:train)[:]
    # one column per image
    flat_images = reshape(images, 28^2, :)
    return DataLoader((flat_images, labels); batchsize=batch_size, shuffle=true)
end

"""
    Encoder(linear, μ, logσ)
    Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)

Container for the three layers of the encoder: a shared `linear` layer whose
output is fed to both the `μ` layer and the `logσ` layer.
"""
struct Encoder
    linear
    μ
    logσ
end
@functor Encoder

function Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)
    shared = Dense(input_dim, hidden_dim, tanh)
    mean_head = Dense(hidden_dim, latent_dim)
    logstd_head = Dense(hidden_dim, latent_dim)
    return Encoder(shared, mean_head, logstd_head)
end

# Apply the encoder to `x`: run the shared layer once, then return the outputs
# of the `μ` and `logσ` layers as a tuple.
function (enc::Encoder)(x)
    hidden = enc.linear(x)
    return enc.μ(hidden), enc.logσ(hidden)
end

"""
    Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)

Build the decoder as a two-layer `Chain`: a `tanh` layer from `latent_dim` to
`hidden_dim`, followed by a layer without activation from `hidden_dim` to
`input_dim`.
"""
function Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)
    hidden_layer = Dense(latent_dim, hidden_dim, tanh)
    output_layer = Dense(hidden_dim, input_dim)
    return Chain(hidden_layer, output_layer)
end

"""
    reconstuct(encoder, decoder, x, device=gpu)

Encode `x`, draw a latent sample, and decode it.

`encoder(x)` must return `(μ, logσ)`. The latent sample is
`z = μ + ε .* exp.(logσ)` with `ε` standard-normal `Float32` noise of the same
size as `logσ`.

# Arguments
- `encoder`, `decoder`: callables as described above.
- `x`: input batch passed to `encoder`.
- `device`: function applied to the freshly drawn noise so that it ends up on
  the same device as `μ`/`logσ`. Defaults to `gpu`, which keeps the previous
  behaviour; pass `cpu` (or `identity`) when the models live on the CPU.

# Returns
The tuple `(μ, logσ, decoder(z))`.
"""
function reconstuct(encoder, decoder, x, device=gpu)
    μ, logσ = encoder(x)
    # Noise is generated on the host and then moved by `device`.
    ε = device(randn(Float32, size(logσ)))
    z = μ + ε .* exp.(logσ)
    return μ, logσ, decoder(z)
end

"""
    model_loss(encoder, decoder, λ, x)

Training objective for the VAE on the batch `x`.

The result is the sum of three terms:
- the summed logit binary cross-entropy between the decoder output and `x`,
  divided by the batch size (last dimension of `x`);
- the KL term `0.5 * sum(exp(2logσ) + μ^2 - 1 - 2logσ)`, also divided by the
  batch size;
- `λ` times the sum of squared entries of the decoder parameters.
"""
function model_loss(encoder, decoder, λ, x)
    μ, logσ, logits = reconstuct(encoder, decoder, x)
    batch = size(x, ndims(x))

    # KL-divergence, averaged over the batch
    kl = 0.5f0 * sum(exp.(2f0 .* logσ) .+ μ .^ 2 .- 1f0 .- 2f0 .* logσ) / batch

    # reconstruction term (negative log-likelihood), averaged over the batch
    nll = logitbinarycrossentropy(logits, x, agg=sum) / batch

    # squared-norm penalty on the decoder parameters
    penalty = λ * sum(p -> sum(p .^ 2), Flux.params(decoder))

    return nll + kl + penalty
end

"""
    convert_to_image(x, y_size)

Arrange the columns of `x` into one grayscale picture.

`x` is moved to the CPU and split into `y_size` chunks; each chunk is reshaped
to 28 rows, the chunks are stacked vertically, the result is transposed, and
every entry is wrapped in `Gray`.
"""
function convert_to_image(x, y_size)
    host = x |> cpu
    strips = [reshape(part, 28, :) for part in chunk(host, y_size)]
    canvas = permutedims(vcat(strips...), (2, 1))
    return Gray.(canvas)
end

convert_to_image (generic function with 1 method)

In [41]:
# End-to-end training cell: builds the models, trains them for `args.epochs`
# epochs, writes a reconstruction image after every epoch, and finally saves
# the models together with the hyperparameters.

# load hyperparameters
args = Args(;)
# the RNG is only seeded for a positive seed; the default `seed = 0` skips seeding
args.seed > 0 && Random.seed!(args.seed)

# load MNIST images
loader = get_data(args.batch_size)

# initialize encoder and decoder
# NOTE(review): `args.cuda` is not consulted here — both models are always piped through `gpu`
encoder = Encoder(args.input_dim, args.latent_dim, args.hidden_dim) |> gpu
decoder = Decoder(args.input_dim, args.latent_dim, args.hidden_dim) |> gpu

# ADAM optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(encoder.linear, encoder.μ, encoder.logσ, decoder)

# create the output directory if it does not exist yet
!ispath(args.save_path) && mkpath(args.save_path)

# fixed input
# first batch of a separate shuffled loader with sample_size^2 images; reused for every epoch's picture
original, _ = first(get_data(args.sample_size^2))
original = original |> gpu
image = convert_to_image(original, args.sample_size)
image_path = joinpath(args.save_path, "original.png")
save(image_path, image)

# training
# counts optimisation steps; only incremented in the visible code, never read
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"
    progress = Progress(length(loader))

    # labels are not needed for training, hence the `_`
    for (x, _) in loader 
        # `loss` is the objective value, `back` maps a seed to gradients w.r.t. `ps`
        loss, back = Flux.pullback(ps) do
            model_loss(encoder, decoder, args.λ, x |> gpu)
        end
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
        # progress meter
        next!(progress; showvalues=[(:loss, loss)]) 

        train_steps += 1
    end
    
    # save image
    _, _, rec_original = reconstuct(encoder, decoder, original)
    # the decoder is trained against a logit loss, so sigmoid is applied before rendering
    rec_original = sigmoid.(rec_original)
    image = convert_to_image(rec_original, args.sample_size)
    image_path = joinpath(args.save_path, "epoch_$(epoch).png")
    save(image_path, image)
    @info "Image saved: $(image_path)"
end

# save model
model_path = joinpath(args.save_path, "model.bson") 
# the `let` rebinds CPU copies of the models and a Dict version of `args` under the
# names that `BSON.@save` writes, leaving the outer variables untouched
let encoder = cpu(encoder), decoder = cpu(decoder), args=struct2dict(args)
    BSON.@save model_path encoder decoder args
    @info "Model saved: $(model_path)"
end

│  - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
└ @ ProgressMeter /home/server/.julia/packages/ProgressMeter/sN2xr/src/ProgressMeter.jl:618
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[34m  loss:  151.85818[39m
┌ Info: Image saved: output/epoch_60.png
└ @ Main In[41]:52
┌ Info: Model saved: output/model.bson
└ @ Main In[41]:59


In [4]:
# Plotting setup for the cells below.
# NOTE(review): presumably configures JSServe so that WGLMakie figures render inline
# and remain usable in an exported/offline notebook — confirm against the JSServe docs
Page(exportable=true, offline=true)
# make WGLMakie the active Makie backend
WGLMakie.activate!()

In [5]:
# Restore the trained models and the hyperparameter Dict written by the training cell.
BSON.@load "output/model.bson" encoder decoder args
# rebuild the `Args` struct from the saved Dict
args = Args(; args...)
# move both models to the GPU
encoder = gpu(encoder)
decoder = gpu(decoder)
# load MNIST images
loader = get_data(args.batch_size)

DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, true, false, Val{nothing}(), Random._GLOBAL_RNG())

In [17]:
# Encode the first 700 training images and keep the latent means as a CPU array.
# `loader.data` is the `(images, labels)` tuple the loader was built from.
x = first(loader.data)[:, 1:700]
y = last(loader.data)[1:700]
μ, logσ2 = x |> gpu |> encoder
μ = collect(μ)

3×700 Matrix{Float32}:
 0.300659  2.29361    0.680258  -0.628466  …   1.68725   -1.95022   0.599636
 1.19346   1.748     -1.17821   -0.456532     -4.19172   -0.199626  1.8607
 0.377838  0.374356  -3.30575    3.05091      -0.279771   3.13446   0.145775

In [18]:
# 3-D scatter of the latent means: one axis per row of `μ`, coloured by label `y`
meshscatter(μ[1, :], μ[2, :], μ[3, :]; color=y)

In [48]:
# clustering in the latent space
# visualize first two dims
# NOTE(review): `scatter3d`, `savefig` and the `markerstrokewidth`/`markeralpha`/
# `aspect_ratio`/`markercolor` keywords look like Plots.jl API, but only WGLMakie is
# imported in the visible source — confirm this cell still runs with the current imports
plt = scatter3d(palette=:rainbow)
for (i, (x, y)) in enumerate(loader)
    # only the first 19 batches are plotted
    i < 20 || break
    μ, logσ = encoder(x |> gpu)
    # rows 1 and 2 of μ are the plotted coordinates; points are coloured by label `y`
    scatter!(μ[1, :], μ[2, :], 
        markerstrokewidth=0, markeralpha=0.8,
        aspect_ratio=1,
        markercolor=y, label="")
end
savefig(plt, "output/clustering.png")

In [16]:
# Decode a regular 11×11 grid over the first two latent coordinates (all other
# coordinates stay at zero) and save the decoded digits as one image.
z = range(-2.0, stop=2.0, length=11)
len = Base.length(z)
# z1 cycles through the grid values, z2 holds each value `len` times in a row,
# so the pairs (z1[k], z2[k]) enumerate the whole grid
z1 = repeat(z, len)
z2 = repeat(z, inner=len)
x = gpu(zeros(Float32, args.latent_dim, len^2))
x[1, :] = z1
x[2, :] = z2
# the decoder emits logits; sigmoid maps them to pixel intensities
samples = sigmoid.(decoder(x))
image = convert_to_image(samples, len)
save("output/manifold.png", image)