In [None]:
import Winston
using Interact
using PyBokeh

In [None]:
# NOTE(review): ENV stores values as strings, so this sets SEEP_NO_GPU to
# "false". If Seep only checks whether the variable is present (rather than its
# value), the GPU would still be disabled — confirm against Seep's source.
ENV["SEEP_NO_GPU"]=false
using Seep
using MNIST, JLD

# Element type for every tensor fed to the network. Declared `const` because the
# make_node methods below read it; a non-const global would be untyped there.
const T = Float32
# Test images (one image per column) divided by 255 and converted to T;
# presumably raw pixels are 0-255, giving values in [0, 1].
const ftest = map(T, testdata()[1]/255)

# Node constructors that always store data as T:
#   make_node(name, array) / make_node(array) — node initialised from `array`;
#   make_node(name, dims...) / make_node(dims...) — node of the given size.
make_node(s::AbstractString, a::Array) = ANode(s, map(T, a))
make_node(s::AbstractString, x::Int...) = ANode(s, x...)
make_node(a::Array) = ANode(map(T, a))
make_node(x::Int...) = ANode(x...)

# Identity accessor: returns its argument unchanged.
get_data(x) = x
nothing

In [None]:
# Load the saved training statistics. JLD's @load with no variable list defines
# every variable stored in the file as a global; the plotting cell below reads
# `Ekl` and `Erec` from it.
@load "snapshot-vae/statistics.jld"

In [None]:
# Training curves: `Ekl` in the default colour and `Erec` in red, each plotted
# against its 1-based record index, with the y-axis limited to [0, 2].
# (Presumably the KL and reconstruction terms of the VAE loss.)
f = line(figure(), collect(eachindex(Ekl)), Ekl)
line(f, collect(eachindex(Erec)), Erec; line_color="red")
ylim(f, 0, 2)

In [None]:
# Constants
const input_size = 28*28
const latent_size = 5
const output_size = input_size
const batch_size = 100

# Network
const X = make_node(input_size, batch_size)
# Placeholder bindings; the encoder `let` block below rebinds μ and lnσ to the
# encoder outputs.
μ = make_node(latent_size, batch_size)
lnσ = make_node(latent_size, batch_size)
x = X

# Generate latent variables (encoding):
# X -> tanh(e_encode) -> tanh(h_i_encode)... -> tanh(s_encode) -> μ_encode / lnσ_encode
let encode_size=[20, 20]
    # Explicitly global: the assignments below must rebind the top-level x, μ and
    # lnσ, not create locals (`let` is a hard scope in newer Julia versions).
    global x, μ, lnσ

    fe = Linear("e_encode", input_size, encode_size[1])
    x = tanh(fe(x))

    for i in 1:length(encode_size)
        # h_1 is square (encode_size[1] -> encode_size[1]); each later h_i maps
        # encode_size[i-1] -> encode_size[i].
        j = max(i - 1, 1)
        fi = Linear("h_$(i)_encode", encode_size[j], encode_size[i])
        x = tanh(fi(x))
    end
    fs = Linear("s_encode", encode_size[end], latent_size)
    x = tanh(fs(x))

    fμ = Linear("μ_encode", latent_size, latent_size)
    μ = fμ(x)

    fσ = Linear("lnσ_encode", latent_size, latent_size)
    lnσ = fσ(x)
end

# Sample latent variables: z = μ + exp(lnσ) .* ϵ, where the node ϵ is filled
# with noise from outside the graph before each evaluation.
const ϵ = make_node(latent_size, batch_size)
z = μ + exp(lnσ) .* ϵ

# Inference of image from latent variables (decode):
# z -> tanh(e_decode) -> tanh(h_i_decode)... -> tanh(s_decode) -> y_decode (no squashing)
yhat = z
let decode_size=[20, 20]
    # Explicitly global so the top-level yhat is rebound, not shadowed.
    global yhat

    fe = Linear("e_decode", latent_size, decode_size[1])
    yhat = tanh(fe(yhat))

    for i in 1:length(decode_size)
        # Same layer-width pattern as the encoder's hidden layers.
        j = max(i - 1, 1)
        fi = Linear("h_$(i)_decode", decode_size[j], decode_size[i])
        yhat = tanh(fi(yhat))
    end
    fs = Linear("s_decode", decode_size[end], output_size)
    yhat = tanh(fs(yhat))

    fy = Linear("y_decode", output_size, output_size)
    yhat = fy(yhat)
end

vae = instance(yhat, μ, lnσ)
;

In [None]:
"""
    epoch()

Run the `vae` instance over every column of `ftest`, `batch_size` columns at a
time, and return `(xout, μs)`:

- `xout` is an `input_size × N` Float32 matrix holding the network output `yhat`
  for each column.
- `μs` is a `latent_size × N` Float32 matrix holding the encoder means `μ`.
- `N = size(ftest, 2)`.

A fresh `randn` sample is written into the `ϵ` node for every batch, so
repeated calls give different reconstructions.

Errors if `N` is not a multiple of `batch_size`, because a partial final batch
would index past the end of `ftest`.
"""
function epoch()
    n = size(ftest, 2)
    n % batch_size == 0 ||
        error("number of samples ($n) must be a multiple of batch_size ($batch_size)")

    xout = zeros(Float32, input_size, n)
    μs = zeros(Float32, latent_size, n)
    for i in 1:batch_size:n
        ii = i:(i + batch_size - 1)

        # Copy the current batch of images into the instance's input buffer.
        xi = vae[X]
        xi[:, :] = ftest[:, ii]

        # Fresh standard-normal noise for the latent sampling step.
        ϵi = vae[ϵ]
        ϵi[:, :] = randn(latent_size, batch_size)
        vae()

        xout[:, ii] = vae[yhat]
        μs[:, ii] = vae[μ]
    end
    return xout, μs
end


In [None]:
# Resume from the newest snapshot in snapshot-vae/, if there is one.
# Expected file names look like "snapshot-<epoch>.<ext>". The epoch is the text
# after the last '-' in the part before the first '.'.
# `f`, `epochs` and `ind` are left as globals for later cells.
n_epoch_start = 0
if isdir("snapshot-vae")
    f = filter(readdir("snapshot-vae")) do name
        contains(name, "snapshot-")
    end
    if !isempty(f)
        epochs = map(f) do name
            stem = split(name, ".")[1]
            parse(Int, split(stem, "-")[end])
        end
        ind = indmax(epochs)
        n_epoch_start = epochs[ind]
        @show n_epoch_start
        load_snapshot(joinpath("snapshot-vae", f[ind]), yhat)
    end
end


In [None]:
# Evaluate the whole test set once: reconstructions in `xhat`, latent means in `μs`.
(xhat, μs) = epoch()
nothing

In [None]:
import PyPlot
ind = rand(1:10000)

# Subplot 121: a randomly chosen test image; subplot 122: the matching column of
# the network output, both reshaped to 28×28.
for (panel, images) in ((121, ftest), (122, xhat))
    PyPlot.subplot(panel)
    PyPlot.imshow(reshape(images[:, ind], (28,28)), cmap=PyPlot.cm[:bone])
end
@show μs[:, ind]

nothing

In [None]:
"""
    reconstruct(z, latent_size, decode_size, output_size)

Build a decoder graph on top of node `z` and return its output node.

Structure:
`z -> tanh(e_decode) -> tanh(h_i_decode) for each width in decode_size
   -> tanh(s_decode) -> y_decode` (the last layer has no squashing).

Layer names and shapes mirror the decoder of the full VAE network, so a
snapshot of the VAE can be loaded onto the returned node. This presumably
works because parameters are matched by layer name — confirm in Seep.

# Arguments
- `z`: input node with `latent_size` rows.
- `latent_size`: width of `z`.
- `decode_size`: hidden-layer widths; must contain at least one entry.
- `output_size`: width of the returned node.

# Throws
- `ArgumentError` if `decode_size` is empty.
"""
function reconstruct(z, latent_size, decode_size, output_size)
    isempty(decode_size) &&
        throw(ArgumentError("decode_size must contain at least one layer width"))

    fe = Linear("e_decode", latent_size, decode_size[1])
    yhat = tanh(fe(z))

    for i in 1:length(decode_size)
        # h_1 is square (decode_size[1] -> decode_size[1]); each later h_i maps
        # decode_size[i-1] -> decode_size[i].
        j = max(i - 1, 1)
        fi = Linear("h_$(i)_decode", decode_size[j], decode_size[i])
        yhat = tanh(fi(yhat))
    end

    fs = Linear("s_decode", decode_size[end], output_size)
    yhat = tanh(fs(yhat))

    fy = Linear("y_decode", output_size, output_size)
    return fy(yhat)
end

# Stand-alone decoder: a single latent column `zz` in, one image column `yy` out.
zz = make_node(latent_size, 1)
yy = reconstruct(zz, latent_size, [20, 20], output_size)
test = instance(yy)

# Re-list the snapshots here instead of reusing the global `f` from the resume
# cell: that `f` is still the Bokeh figure if snapshot-vae/ did not exist, and
# an empty list would make indmax fail with an unhelpful error.
f = filter(x->contains(x, "snapshot-"), readdir("snapshot-vae"))
isempty(f) && error("no snapshot-* files found in snapshot-vae/; train the model first")
epochs = map(x->parse(Int, split(split(x, ".")[1], "-")[end]), f)
ind = indmax(epochs)
load_snapshot(joinpath("snapshot-vae", f[ind]), yy)
;

In [None]:
@manipulate for z1=-2:0.1:2, z2=-2:0.1:2, z3=-2:0.1:2, z4=-2:0.1:2, z5=-2:0.1:2
    # Build a latent column from the five slider values. Any remaining entries
    # stay zero when latent_size > 5.
    latent = zeros(Float32, latent_size, 1)
    for (k, v) in enumerate((z1, z2, z3, z4, z5))
        latent[k] = v
    end

    # Push it through the stand-alone decoder and show the result as 28×28.
    buf = test[zz]
    buf[:] = latent
    test()
    Winston.imagesc(reshape(test[yy][:, 1], (28,28)))
end