In [1]:
using BSON
using CUDA
using DrWatson: struct2dict
using Flux
using Flux: @functor, chunk
using Flux.Losses: logitbinarycrossentropy
using Flux: onehotbatch, onecold, @epochs
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random

In [2]:
"""
    get_data(batch_size)

Load the MNIST training split as `Float32`, flatten each image, one-hot encode the
labels over the classes `0:9`, and return a shuffling `DataLoader` over the
`(images, labels)` pair with mini-batches of `batch_size`.
"""
function get_data(batch_size)
    images, labels = MLDatasets.MNIST.traindata(Float32)

    # Collapse every image into a single column so that a batch is a plain matrix.
    flat_images = Flux.flatten(images)

    # Turn the integer labels into one-hot columns over the ten digit classes.
    onehot_labels = onehotbatch(labels, 0:9)

    return DataLoader((flat_images, onehot_labels); batchsize=batch_size, shuffle=true)
end

get_data (generic function with 1 method)

In [3]:
"""
    Encoder(linear, μ, logσ)
    Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)

Encoder with one shared hidden layer (`linear`) followed by two parallel heads.
Calling `encoder(x)` returns the tuple `(μ(h), logσ(h))` where `h = linear(x)`.

The three-integer constructor builds a `tanh` hidden `Dense` layer of width
`hidden_dim` and two linear `Dense` heads of width `latent_dim`.

The struct is parametric so that each field has a concrete type; with untyped
fields every layer call would go through dynamic dispatch.
"""
struct Encoder{L,M,S}
    linear::L
    μ::M
    logσ::S
end
@functor Encoder

Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Encoder(
    Dense(input_dim, hidden_dim, tanh),   # linear
    Dense(hidden_dim, latent_dim),        # μ
    Dense(hidden_dim, latent_dim),        # logσ
)

function (encoder::Encoder)(x)
    # Both heads read the same hidden representation.
    h = encoder.linear(x)
    return encoder.μ(h), encoder.logσ(h)
end

In [4]:
"""
    Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)

Build the decoder as a two-layer `Chain`: a `tanh` `Dense` layer mapping
`latent_dim` to `hidden_dim`, followed by a `Dense` layer mapping `hidden_dim`
to `input_dim` with no activation.
"""
function Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int)
    hidden_layer = Dense(latent_dim, hidden_dim, tanh)
    output_layer = Dense(hidden_dim, input_dim)   # left without an activation
    return Chain(hidden_layer, output_layer)
end

Decoder (generic function with 1 method)

In [5]:
"""
    reconstuct(encoder, decoder, x, y, device)

Run one conditional encode/sample/decode pass.

`x` and `y` are concatenated vertically and fed to `encoder`, a latent `z` is
sampled as `μ + ε .* exp.(logσ)` with standard-normal `ε` moved through `device`,
and `decoder` is applied to `z` concatenated with `y`.

Returns `(μ, logσ, decoder_output)`.
"""
function reconstuct(encoder, decoder, x, y, device)
    # Condition the encoder on the label by stacking it under the input.
    μ, logσ = encoder(vcat(x, y))

    # Reparameterised sample: noise is drawn on the host, then sent to `device`.
    ε = device(randn(Float32, size(logσ)))
    z = μ + ε .* exp.(logσ)

    # The decoder sees the same label, stacked under the latent code.
    return μ, logσ, decoder(vcat(z, y))
end

reconstuct (generic function with 1 method)

In [6]:
"""
    model_loss(encoder, decoder, λ, x, y, device)

Training objective for one batch: the summed logit binary cross-entropy between
the decoder output and `x`, plus the KL term computed from `μ` and `logσ`, both
divided by the size of the last dimension of `x`, plus `λ` times the sum of
squared decoder parameters.
"""
function model_loss(encoder, decoder, λ, x, y, device)
    μ, logσ, logits = reconstuct(encoder, decoder, x, y, device)
    batch = size(x)[end]

    # KL-divergence term, averaged over the batch
    kl = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 - 1f0 - 2f0 * logσ)) / batch

    # Reconstruction term (negative log-likelihood), averaged over the batch
    nll = logitbinarycrossentropy(logits, x, agg=sum) / batch

    # Squared-norm penalty on the decoder parameters only
    penalty = λ * sum(p -> sum(p .^ 2), Flux.params(decoder))

    return nll + kl + penalty
end

model_loss (generic function with 1 method)

In [7]:
"""
    convert_to_image(x, y_size)

Arrange the columns of `x` into one grayscale image grid.

`x` is moved to the CPU and split into `y_size` chunks; every chunk is reshaped
to 28 rows, the chunks are stacked vertically, the result is transposed, and each
value is wrapped in `Gray`.
"""
function convert_to_image(x, y_size)
    pieces = chunk(cpu(x), y_size)
    strips = map(piece -> reshape(piece, 28, :), pieces)
    stacked = vcat(strips...)
    return Gray.(permutedims(stacked, (2, 1)))
end

convert_to_image (generic function with 1 method)

In [41]:
# Hyperparameters and runtime options for the `train` function.
# Every field has a concrete type so that reads from this mutable struct are
# type-stable; keyword values of a compatible type are converted on construction.
@with_kw mutable struct Args
    η::Float64 = 1e-3               # learning rate
    λ::Float32 = 0.01f0             # regularization parameter
    batch_size::Int = 128           # batch size
    sample_size::Int = 10           # sampling size for output
    epochs::Int = 80                # number of epochs
    seed::Int = 0                   # random seed (only applied when > 0)
    cuda::Bool = true               # use GPU
    input_dim::Int = 28^2           # image size
    latent_dim::Int = 2             # latent dimension
    hidden_dim::Int = 500           # hidden dimension
    verbose_freq::Int = 10          # logging for every verbose_freq iterations
    tblogger::Bool = false          # log training with tensorboard
    save_path::String = "output"    # results path
end

Args

In [24]:
"""
    train(; kws...)

Train the conditional VAE on MNIST and return the trained decoder (still on the
device that was used for training).

Keyword arguments are forwarded to `Args`. Under `args.save_path` the function
writes `original.png` (a fixed batch of inputs), `epoch_<n>.png` (that batch's
reconstruction after every epoch) and `model.bson` (CPU copies of the encoder and
decoder plus the arguments as a dictionary).
"""
function train(; kws...)
    # load hyperparameters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load MNIST images
    loader = get_data(args.batch_size)

    # initialize encoder and decoder; the extra 10 inputs carry the one-hot label
    encoder = Encoder(args.input_dim + 10, args.latent_dim, args.hidden_dim) |> device
    decoder = Decoder(args.input_dim, args.latent_dim + 10, args.hidden_dim) |> device

    # ADAM optimizer
    opt = ADAM(args.η)

    # parameters
    ps = Flux.params(encoder.linear, encoder.μ, encoder.logσ, decoder)

    # `mkpath` is a no-op when the directory already exists
    mkpath(args.save_path)

    # logging by TensorBoard.jl; always bound so later checks never hit an
    # undefined variable
    tblogger = args.tblogger ? TBLogger(args.save_path, tb_overwrite) : nothing

    # fixed input
    original, original_y = first(get_data(args.sample_size^2))
    original = original |> device
    # The labels must live on the same device as the images: `reconstuct`
    # concatenates them, and mixing a host array with a device array falls back
    # to scalar indexing on GPU.
    original_y = original_y |> device
    image = convert_to_image(original, args.sample_size)
    image_path = joinpath(args.save_path, "original.png")
    save(image_path, image)

    # training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(loader))

        for (x, y) in loader
            # Transfer the batch once, outside the differentiated closure.
            x_dev = x |> device
            y_dev = y |> device

            loss, back = Flux.pullback(ps) do
                model_loss(encoder, decoder, args.λ, x_dev, y_dev, device)
            end
            grad = back(1f0)
            Flux.Optimise.update!(opt, ps, grad)
            # progress meter
            next!(progress; showvalues=[(:loss, loss)])

            # logging with TensorBoard
            if tblogger !== nothing && train_steps % args.verbose_freq == 0
                with_logger(tblogger) do
                    @info "train" loss=loss
                end
            end

            train_steps += 1
        end

        # save image: the decoder emits logits, so squash them to [0, 1] first
        _, _, rec_original = reconstuct(encoder, decoder, original, original_y, device)
        rec_original = sigmoid.(rec_original)
        image = convert_to_image(rec_original, args.sample_size)
        image_path = joinpath(args.save_path, "epoch_$(epoch).png")
        save(image_path, image)
        @info "Image saved: $(image_path)"
    end

    # save model (moved to the CPU so the file does not hold device arrays)
    model_path = joinpath(args.save_path, "model.bson")
    let encoder = cpu(encoder), decoder = cpu(decoder), args = struct2dict(args)
        BSON.@save model_path encoder decoder args
        @info "Model saved: $(model_path)"
    end
    return decoder
end

train (generic function with 1 method)

In [42]:
# Train with the default `Args` and keep the returned decoder for the sampling cells below.
decoder = train()

│  - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
└ @ ProgressMeter C:\Users\user\.julia\packages\ProgressMeter\Vf8un\src\ProgressMeter.jl:620
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09[39m
[34m  loss:  163.73265[39m
┌ Info: Image saved: output\epoch_80.png
└ @ Main In[24]:74
┌ Info: Model saved: output\model.bson
└ @ Main In[24]:81


Chain(
  Dense(12, 500, tanh),                 [90m# 6_500 parameters[39m
  Dense(500, 784),                      [90m# 392_784 parameters[39m
)[90m                   # Total: 4 arrays, [39m399_284 parameters, 448 bytes.

In [45]:
# Generate ten new samples for one fixed class.
# Draw ten 2-dimensional latent codes from a standard normal.
z = randn(Float32, 2, 10)
# Hand-written one-hot label repeated for all ten samples; the 1 sits in the second
# slot, which is the digit 1 under the `0:9` encoding used by `get_data`.
y = repeat([0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 1, 10)
# The decoder input is the latent code stacked on top of the label.
x_decoder = vcat(z, y) |> gpu
# Raw decoder output; `sigmoid` is applied in the next cell before saving.
output = decoder(x_decoder)

784×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -7.79902  -7.56589  -7.24152  -7.59908  …  -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908  …  -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24151  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24151  -7.59908     -7.46114  -7.22245  -7.75028
 -7.79902  -7.56589  -7.24152  -7.59908  …  -7.46114  -7.22245  -7.75029
 -7.79902  -7.56589  -7.24152  -7.59908     -7.46114  -7.22245  -7.75029
 -7.78478  -7.57004  -7.23049  -7.57948     -7.46446  -7.16838  -7.73162


In [46]:
# Squash the decoder output with `sigmoid`, tile the ten samples into one image
# and write it to `output/1.png`.
rec_original = sigmoid.(output)
image = convert_to_image(rec_original, 10)
image_path = joinpath("output", "1.png")
save(image_path, image)
@info "Image saved: $(image_path)"

┌ Info: Image saved: output\1.png
└ @ Main In[46]:5
