In [1]:
# Encode MNIST images as compressed vectors that can later be decoded back into
# images.
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition
using Parameters: @with_kw
using CUDAapi
# Optional GPU setup: this branch runs only when `has_cuda()` (from CUDAapi)
# returns true.
if has_cuda()
    @info "CUDA is on"
    # Imported only on this branch, so CuArrays is not loaded when `has_cuda()`
    # is false.
    import CuArrays
    # NOTE(review): presumably turns scalar indexing of GPU arrays into an error
    # so accidental element-by-element access is caught early — confirm against
    # the CuArrays documentation.
    CuArrays.allowscalar(false)
end

In [2]:
# Hyperparameters for the autoencoder example.
# `@with_kw` (Parameters) gives the struct a keyword constructor with the
# defaults below; `train` builds an instance with `Args(; kws...)`, so any field
# can be overridden per call, e.g. `train(epochs = 20)`.
@with_kw mutable struct Args
    lr::Float64 = 1e-3  # Learning rate (passed to the ADAM optimiser in `train`)
    epochs::Int = 10    # Number of epochs (passes over the training batches)
    N::Int = 32         # Size of the encoding (output width of the encoder layer)
    batchsize::Int = 1000  # Batch size for training (images per batch matrix)
    sample_len::Int = 20   # Number of random digits in the sample image
    throttle::Int = 5      # Throttle timeout for the loss callback (presumably seconds — confirm)
end

Args

In [3]:
"""
    get_processed_data(args)

Load the MNIST images and pack them into mini-batch matrices for training.

Every image is flattened into one column, so each batch is a float matrix with
one column per image. Batches hold `args.batchsize` images; the last batch is
smaller when the number of images is not a multiple of the batch size.

Returns a vector of batch matrices, each passed through `gpu`.

NOTE(review): `channelview` is not provided by the packages loaded in the first
cell; it presumably comes from `Images`, which the notebook loads in a later
cell. That `using Images` must have run before this function is called.
"""
function get_processed_data(args)
    # Expose the raw numeric channel values of the colorant-typed images.
    imgs = channelview.(MNIST.images())

    # One matrix per batch. `reduce(hcat, ...)` concatenates the flattened
    # images without splatting a whole batch into a varargs call, which is
    # costly for large batches.
    train_data = [
        float(reduce(hcat, vec.(batch))) for batch in partition(imgs, args.batchsize)
    ]

    return gpu.(train_data)
end

get_processed_data (generic function with 1 method)

In [4]:
"""
    train(; kws...)

Train a dense autoencoder on the MNIST batches from `get_processed_data`.

Keyword arguments are forwarded to `Args`, so every hyperparameter can be
overridden per call. The model is a `Chain` of two `Dense` layers: an encoder
mapping `28^2` inputs to `args.N` values and a decoder mapping them back. It is
fitted with ADAM on the mean squared reconstruction error.

Returns the tuple `(model, args)`.
"""
function train(; kws...)
    args = Args(; kws...)
    train_data = get_processed_data(args)

    @info("Constructing model......")
    # The encoder output (length `args.N`) is narrower than its input (28^2), so
    # it is a lossy compressed code of the image. Both layers can be widened or
    # deepened to experiment.
    encoder = gpu(Dense(28^2, args.N, leakyrelu))
    decoder = gpu(Dense(args.N, 28^2, leakyrelu))
    model = Chain(encoder, decoder)

    @info("Training model.....")
    # Reconstruction error: the input is its own target.
    loss(x) = mse(model(x), x)
    optimizer = ADAM(args.lr)

    # Print the loss on the first batch, rate-limited by `args.throttle`.
    report_loss = throttle(args.throttle) do
        @show(loss(train_data[1]))
    end

    # `zip` over a single collection yields 1-tuples, which `train!` splats into
    # `loss(x)`.
    batches = zip(train_data)
    @epochs args.epochs Flux.train!(
        loss, params(model), batches, optimizer; cb = report_loss
    )

    return model, args
end

train (generic function with 1 method)

In [5]:
using Images

"""
    img(x::Vector)

Turn a flat vector of 28*28 pixel intensities into a 28×28 grayscale image.
Values are clamped to the range 0–1 before conversion.
"""
function img(x::Vector)
    pixels = clamp.(x, 0, 1)
    return Gray.(reshape(pixels, 28, 28))
end

"""
    sample(m, args)

Build a comparison image of `args.sample_len` randomly chosen MNIST digits.

Each digit is stacked vertically on top of its reconstruction by the model `m`,
and the resulting pairs are laid out side by side.

Returns the combined image as a single matrix.
"""
function sample(m, args)
    # Expose the raw numeric channel values of the colorant-typed images.
    imgs = channelview.(MNIST.images())

    # `args.sample_len` random digits (drawn with replacement).
    before = rand(imgs, args.sample_len)

    # Move the model to the CPU once instead of once per sampled image.
    model = cpu(m)
    after = [img(model(float(vec(x)))) for x in before]

    # Original above reconstruction, pairs next to each other.
    return hcat(vcat.(before, after)...)
end

┌ Info: Precompiling Images [916415d5-f1e6-5110-898d-aaa5f9f070e0]
└ @ Base loading.jl:1260


sample (generic function with 1 method)

In [6]:
# Run from the notebook's own directory so the output image lands next to it.
cd(@__DIR__)
m, args = train()

# Sample output. The log message and `save` share one path so the announced
# file name always matches the file that is actually written.
output_path = "test_flux_autoencoder.png"
@info("Saving image sample as $output_path")
save(output_path, sample(m, args))

┌ Info: Constructing model......
└ @ Main In[4]:6
┌ Info: Training model.....
└ @ Main In[4]:18
┌ Info: Epoch 1
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.10221675f0
loss(train_data[1]) = 0.048363827f0


┌ Info: Epoch 2
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.03140944f0


┌ Info: Epoch 3
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.02301975f0


┌ Info: Epoch 4
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.019030891f0


┌ Info: Epoch 5
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.016889704f0


┌ Info: Epoch 6
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.015487493f0


┌ Info: Epoch 7
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.014464614f0


┌ Info: Epoch 8
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.013701256f0


┌ Info: Epoch 9
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.01316015f0


┌ Info: Epoch 10
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss(train_data[1]) = 0.012733687f0


┌ Info: Saving image sample as sample_ae.png
└ @ Main In[6]:4
┌ Info: Precompiling PNGFiles [f57f5aa1-a3ce-4bc8-8ab9-96f992907883]
└ @ Base loading.jl:1260


0

In [None]:
"""
    img(x::Vector)

Convert a flat vector of 28*28 intensities into a 28×28 grayscale image,
clamping every value to the range 0–1 first.
"""
function img(x::Vector)
    clamped = clamp.(x, 0, 1)
    square = reshape(clamped, 28, 28)
    return Gray.(square)
end

"""
    encoding_strip(code::AbstractVector, width::Integer = 28)

Render an encoding vector as a grayscale strip `width` pixels wide.

Values are clamped to 0–1 and written row by row; the last row is zero-padded
when `length(code)` is not a multiple of `width`. The strip has the same width
as a digit image, so it can be stacked under one with `vcat`.
"""
function encoding_strip(code::AbstractVector, width::Integer = 28)
    nrows = cld(length(code), width)
    # Fill column-major into a width × nrows buffer, then transpose so each
    # row of the result holds `width` consecutive code values.
    cells = zeros(eltype(code), width, nrows)
    copyto!(cells, clamp.(code, 0, 1))
    return Gray.(permutedims(cells))
end

"""
    sample_encoder(m, args)

Build an image of `args.sample_len` random MNIST digits with their encodings.

`m` is expected to be the `Chain` returned by `train`, whose first layer is the
encoder. Each digit is stacked on top of a strip showing its encoding, and the
pairs are laid out side by side.

Returns the combined image as a single matrix.
"""
function sample_encoder(m, args)
    # Expose the raw numeric channel values of the colorant-typed images.
    imgs = channelview.(MNIST.images())

    # `args.sample_len` random digits (drawn with replacement).
    before = rand(imgs, args.sample_len)

    # `m[1]` selects the encoder layer; `m(1)` would instead run the whole
    # chain on the integer 1. Move it to the CPU once, not once per image.
    encoder = cpu(m[1])

    # The code has `args.N` entries, not 28^2, so it cannot go through `img`;
    # draw it as a padded strip instead.
    after = [encoding_strip(encoder(float(vec(x)))) for x in before]

    # Digit above its encoding, pairs next to each other.
    return hcat(vcat.(before, after)...)
end