In [None]:
using MNIST

In [None]:
# Set SEEP_NO_GPU before Seep's source is evaluated
# (presumably this makes Seep skip its GPU backend — TODO confirm against Seep).
ENV["SEEP_NO_GPU"]="1"
# Seep is loaded by including its main source file from the package directory
# instead of the usual `using Seep`, so every Seep name below is qualified as `Seep.…`.
#using Seep
include(Pkg.dir("Seep", "src", "Seep.jl"))
;

In [None]:
using ProgressMeter

In [None]:
using PyPlot

In [None]:
# Sizes shared by the network definitions, the graph inputs and the training loops.
# Number of samples (columns) in one batch.
const minibatch = 50
# Length of one flattened 28x28 image.
const np = 28*28
# Width of the hidden layers and of the encoder output.
const nh = 256
;

In [None]:
# Encoder: three Linear layers with tanh between them and no activation after the last.
# The size arguments are presumably (inputs, outputs), i.e. np -> nh -> nh -> nh
# — TODO confirm against Seep.Linear.
encoder = Seep.Flow[
    Seep.Linear(Float32, "enc1", np, nh), tanh,
    Seep.Linear(Float32, "enc2", nh, nh), tanh,
    Seep.Linear(Float32, "enc3", nh, nh)
];

In [None]:
# Decoder: mirror of the encoder, three Linear layers with tanh between them and no
# activation after the last. It is also used as the generator in the GAN graph.
# The size arguments are presumably (inputs, outputs), i.e. nh -> nh -> nh -> np
# — TODO confirm against Seep.Linear.
decoder = Seep.Flow[
    Seep.Linear(Float32, "dec1", nh, nh), tanh,
    Seep.Linear(Float32, "dec2", nh, nh), tanh,
    Seep.Linear(Float32, "dec3", nh, np)
];

In [None]:
# Discriminator: two tanh hidden layers followed by a single-unit layer passed
# through Seep.sigm (presumably the logistic sigmoid, giving a score in (0, 1) — confirm).
# The size arguments are presumably (inputs, outputs), i.e. np -> nh -> nh -> 1.
discriminator = Seep.Flow[
    Seep.Linear(Float32, "dis1", np, nh), tanh,
    Seep.Linear(Float32, "dis2", nh, nh), tanh,
    Seep.Linear(Float32, "dis3", nh, 1), Seep.sigm
];

In [None]:
# Backing store for every graph instance created below: a single Float32 array of
# 1<<30 elements (4 GiB at 4 bytes each) handed to Seep.BuddyPool.
# NOTE(review): `Array{Float32}(n)` is the pre-0.7 constructor form; it is assumed to
# leave the memory uninitialized — confirm for the Julia version in use.
pool = Seep.BuddyPool(Array{Float32}(1<<30))
;

In [None]:
# Regularization term shared by all optimizers: 1e-5 times the sum of squared entries
# of every parameter that Seep.reg_params returns for the three networks, named "reg".
# ∇R holds the gradients of that term and is indexed by parameter (∇R[θ]) below.
∇R = Seep.gradients(Seep.name!(1e-5*reduce(+, [sum(θ.^2) for θ in Seep.reg_params([encoder; decoder; discriminator])]), "reg"))
;

In [None]:
# Graph input nodes, each holding one batch with one column per sample.
# `x` carries flattened images (np rows); `z1` and `z2` are latent batches that are
# passed to the decoder when the GAN graph is built.
# The latent height is written as `nh` rather than a literal 256 because the decoder's
# first layer ("dec1") is constructed with `nh` inputs; the two must stay equal.
x = Seep.ANode("x", zeros(Float32, np, minibatch))
z1 = Seep.ANode("z1", zeros(Float32, nh, minibatch))
z2 = Seep.ANode("z2", zeros(Float32, nh, minibatch))
;

In [None]:
# Build the autoencoder graphs. The `let` keeps the intermediate nodes local; only the
# two instances are published through `global const`.
let
    # x̂: the input passed through the encoder and then the decoder.
    # (@Seep.named presumably binds the variable and names the node after it — confirm.)
    @Seep.named x̂ decoder(encoder(x))
    # loss: sum of squared differences between reconstruction and input; ∇A holds
    # its gradients, indexed by parameter.
    ∇A = Seep.gradients(@Seep.named(loss, sum((x̂ - x).^2)))

    # Training instance: input, loss, and one Seep.adam_fast update node per decoder and
    # encoder parameter, driven by the reconstruction gradient plus the regularizer gradient.
    global const autoencoder = Seep.instance(pool, Seep.ANode[x; loss;
        map(θ->Seep.adam_fast(θ, ∇A[θ] + ∇R[θ]), Seep.get_params([decoder; encoder]))])
    # Evaluation instance: same input, reconstruction and loss, with no update nodes.
    global const autoencoder_test = Seep.instance(pool, Seep.ANode[x; x̂; loss])
    # End the cell with `nothing` so the block itself yields no value.
    nothing
end

In [None]:
# Build the GAN graphs, with the decoder acting as the generator. The `let` keeps the
# intermediate nodes local; only `gan` and `gan_test` are published via `global const`.
let
    name! = Seep.name!

    # d1: discriminator output on samples decoded from z1.
    # generator_loss = sum(log(1 - d1)); its gradients drive the decoder updates below.
    d1 = name!(discriminator(@Seep.named x̂1 decoder(z1)), "d1")
    l1 = name!(sum(log1p(-d1)), "generator_loss")

    # d2: discriminator output on samples decoded from z2.
    # fake_loss = sum(-log(1 - d2)), the discriminator's loss on generated samples.
    d2 = name!(discriminator(@Seep.named x̂2 decoder(z2)), "d2")
    l2 = name!(sum(-log1p(-d2)), "fake_loss")

    # d3: discriminator output on the real batch x.
    # real_loss = sum(-log(d3)), the discriminator's loss on real samples.
    d3 = name!(discriminator(x), "d3")
    l3 = name!(sum(-log(d3)), "real_loss")

    # Separate gradient sets: ∇G for the generator objective, ∇D for the combined
    # discriminator objective (fake + real).
    ∇G = Seep.gradients(l1)
    ∇D = Seep.gradients(l2 + l3)

    # updates1 touches only discriminator parameters, updates2 only decoder parameters;
    # both add the regularizer gradient.
    # NOTE(review): the trailing 1e-4 and 0.5 are presumably Adam's step size and first
    # moment decay — confirm against Seep.adam_fast.
    updates1 = map(θ->Seep.adam_fast(θ, ∇D[θ] + ∇R[θ], 1e-4, 0.5), Seep.get_params(discriminator))
    updates2 = map(θ->Seep.adam_fast(θ, ∇G[θ] + ∇R[θ], 1e-4, 0.5), Seep.get_params(decoder))
    
    # Training instance: inputs, the three losses and both sets of update nodes.
    global const gan = Seep.instance(pool, Seep.ANode[x; z1; z2; l1; l2; l3; updates1; updates2])
    # Inspection instance: inputs, generated samples and discriminator outputs, no updates.
    global const gan_test = Seep.instance(pool, Seep.ANode[x; z1; z2; x̂1; x̂2; d1; d2; d3])
    # End the cell with `nothing` so the block itself yields no value.
    nothing
end

In [None]:
# do_minibatch(g, test=false)
#
# Load fresh random inputs into the graph instance `g` and then run it once.
#
# Only the inputs that `g` actually has (checked with `isdefined`) are touched:
#   * `x`  – each of the `minibatch` columns is set to one randomly indexed MNIST image
#            divided by 255; images come from `testfeatures` (indices 1:10_000) when
#            `test` is true and from `trainfeatures` (indices 1:60_000) otherwise.
#   * `z1`, `z2` – overwritten with standard-normal samples of the same size.
#
# Returns whatever calling `g()` returns.
function do_minibatch(g, test=false)
    if isdefined(g, :x)
        # Pick the image loader and the index range it is sampled from.
        features, nsamples = test ? (testfeatures, 10_000) : (trainfeatures, 60_000)
        for col in 1:minibatch
            g.x[:,col] = features(rand(1:nsamples))/255
        end
    end

    # Latent inputs are refreshed independently; z1 is always drawn before z2.
    isdefined(g, :z1) && copy!(g.z1, randn(size(g.z1)))
    isdefined(g, :z2) && copy!(g.z2, randn(size(g.z2)))

    return g()
end

In [None]:
# History of the autoencoder training loss; the training loop appends one value per epoch.
const ae_loss = Float64[];

In [None]:
# Train the autoencoder for 50 epochs. Each epoch runs 1000÷minibatch batches, so the
# accumulated loss covers 1000 samples and is stored divided by 1000.
@showprogress for epoch in 1:50
    total = 0.

    for batch in 1:1000÷minibatch
        do_minibatch(autoencoder)
        total += autoencoder.loss[1]
    end

    # Stop as soon as an epoch's accumulated loss is NaN instead of recording it.
    isnan(total) && error("NaN")

    push!(ae_loss, total/1000)
end

In [None]:
# Plot the recorded autoencoder loss history with a logarithmic y axis.
semilogy(ae_loss)

In [None]:
# Run the evaluation graph on one batch of randomly chosen test images.
do_minibatch(autoencoder_test, true)

# 4x5 grid showing the first ten samples of the batch: rows 1 and 3 hold inputs,
# rows 2 and 4 hold the reconstruction directly beneath the corresponding input.
for j in 0:1
    for i in 1:5
        # Input image (column i+5j of x), reshaped back to 28x28.
        subplot(4,5,i+10j)
        imshow(reshape(autoencoder_test.x[:,i+5j], 28, 28), vmin=0, vmax=1, cmap="bone")

        # Its reconstruction, one grid row below.
        subplot(4,5,i+5+10j)
        imshow(reshape(autoencoder_test.x̂[:,i+5j], 28, 28), vmin=0, vmax=1, cmap="bone")
    end
end

In [None]:
# History of the GAN losses, stored flat: the training loop appends three values per
# epoch in the order generator, real, fake.
const gan_loss = Float64[];

In [None]:
# Train the GAN for 200 epochs. Each epoch runs 1000÷minibatch batches and accumulates
# the generator, real and fake losses separately; the three totals are appended to
# `gan_loss` (in that order) after dividing by 1000.
@showprogress for epoch in 1:200
    g_sum = 0.
    r_sum = 0.
    f_sum = 0.

    for batch in 1:1000÷minibatch
        do_minibatch(gan)
        g_sum += gan.generator_loss[1]
        r_sum += gan.real_loss[1]
        f_sum += gan.fake_loss[1]
    end

    # Stop as soon as any of the accumulated losses is NaN instead of recording it.
    if any(isnan, (g_sum, r_sum, f_sum))
        error("NaN")
    end

    append!(gan_loss, [g_sum, r_sum, f_sum]/1000)
    #Seep.save_snapshot("gan.jld", Seep.get_params([decoder; discriminator]))
end

In [None]:
# Plot the three interleaved loss series from `gan_loss` on a logarithmic y axis.
# The generator series is negated before plotting: it is built from sum(log1p(-d1)),
# which is non-positive for d1 in [0, 1), and a log axis needs positive values.
semilogy(-gan_loss[1:3:end]; label="generator")
semilogy(gan_loss[2:3:end]; label="real")
semilogy(gan_loss[3:3:end]; label="fake")
# Reference line at -log(0.5), the value of -log(d) when the discriminator outputs 0.5.
hlines(-log(0.5), 0, length(gan_loss)÷3, linestyle="--", label="50/50")
legend()
;

In [None]:
# Run the GAN inspection graph once on a batch of randomly chosen test images and
# freshly sampled latent inputs.
do_minibatch(gan_test, true)

# 6x5 grid for the first five samples of the batch. Each (row, xi, di) entry fills two
# grid rows: the images stored in field `xi`, and beneath each image a pie chart of the
# discriminator output stored in field `di` against its complement.
#   row 0: real images `x` with `d3`; row 1: `x̂1` with `d1`; row 2: `x̂2` with `d2`.
for i in 1:5
    for (row,xi,di) in [(0,:x,:d3),(1,:x̂1,:d1),(2,:x̂2,:d2)]
        # Fields are looked up by symbol with `getfield`; the older `obj.(sym)` form is
        # deprecated since Julia 0.5 and later parses as a broadcast call.
        images = getfield(gan_test, xi)
        scores = getfield(gan_test, di)

        subplot(6,5,i+10*row)
        imshow(reshape(images[:,i], 28, 28), vmin=0, cmap="bone")

        subplot(6,5,i+5+10*row)
        pie([scores[i], 1-scores[i]])

    end
end
#colorbar()