# Tutorial - Creating a Simple GAN on 2-D Gaussian-Mixture Data

In [1]:
using Flux
using Statistics, Distributions
using Plots

┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1342


In [2]:
"""
    generate_real_data(n, k=10; means=(3.0, 8.0))

Draw a `k × n` matrix of "real" training data. Each column is one sample with `k`
features. Every entry is drawn independently from an equal-weight mixture of two
unit-variance Gaussians centred at `means[1]` and `means[2]`.

Note: the `D`/`G` networks in this notebook use 2 features (`Dense(2, …)` /
`Dense(…, 2)`), so callers feeding them should pass `k = 2` explicitly.
"""
function generate_real_data(n, k=10; means=(3.0, 8.0))
    μ1, μ2 = means
    # Per-entry component choice: 1 selects the first Gaussian, 0 the second.
    r = rand([0, 1], k, n)
    # The two `randn` draws are made in the same order as before, so with the
    # default `means` the output is unchanged for a given RNG state.
    return r .* (randn(k, n) .+ μ1) .+ (1 .- r) .* (randn(k, n) .+ μ2)
end

"""
    D()

Build the discriminator: a 2 → 25 (ReLU) → 1 MLP whose single output is an
unactivated logit (the final `Dense` has no activation).
"""
function D()
    hidden = Dense(2, 25, relu)
    head = Dense(25, 1)
    return Chain(hidden, head)
end

"""
    G(latent_dim::Int)

Build the generator: a `latent_dim` → 15 (ReLU) → 2 MLP that maps a noise
vector to a 2-feature sample.
"""
function G(latent_dim::Int)
    expand = Dense(latent_dim, 15, relu)
    project = Dense(15, 2)
    return Chain(expand, project)
end

G (generic function with 1 method)

In [None]:
# Visualise the real-data distribution. Every entry of `generate_real_data` is an
# independent draw from the two-component mixture, so flattening a single-row
# sample gives one 1-D histogram instead of many tiny overlapping series.
histogram(vec(generate_real_data(1000, 1)); alpha=0.5, bins=30, label="real data")

In [None]:
"""
    loss_D(x, y, dscr)

Discriminator loss: logit binary cross-entropy between `dscr(x)` and the labels
`y` (1 = real, 0 = fake). `dscr` is expected to output raw logits.
"""
function loss_D(x, y, dscr)
    logits = dscr(x)
    return sum(Flux.Losses.logitbinarycrossentropy(logits, y))
end

"""
    trainDiscriminator!(dscr, gen, train_size, optimizer=opt; latent_dim=5, batchsize=1)

Run one shuffled pass of discriminator training over `train_size` real samples
(label 1) and `train_size` generated samples (label 0), updating `dscr` in place.

The fake samples are produced once, before any gradient is taken, so `gen`'s
parameters are not touched here. `optimizer` defaults to the global `opt`,
looked up at call time. `latent_dim` is the size of the uniform noise fed to
`gen`. Returns `nothing`.
"""
function trainDiscriminator!(dscr, gen, train_size, optimizer=opt; latent_dim=5, batchsize=1)
    fake = gen(rand(latent_dim, train_size))
    # Match the real data's feature count to the generator's output instead of
    # relying on generate_real_data's default row count (10), which does not
    # fit a 2-input discriminator.
    real_data = generate_real_data(train_size, size(fake, 1))

    X = hcat(real_data, fake)
    Y = vcat(ones(train_size), zeros(train_size))
    # Labels as a 1×N row so each batch's targets line up with dscr's 1×batch output.
    data = Flux.Data.DataLoader((X, permutedims(Y)); batchsize=batchsize, shuffle=true)

    ps = Flux.params(dscr)
    for (x, y) in data
        gs = gradient(ps) do
            loss_D(x, y, dscr)
        end
        Flux.update!(optimizer, ps, gs)
    end
    return nothing
end

"""
    loss_G(z, gen, dscr)

Generator loss: logit binary cross-entropy of the discriminator's output on
`gen(z)` against the "real" label 1. The loss is low when `dscr` is fooled.
"""
function loss_G(z, gen, dscr)
    logits = dscr(gen(z))
    return sum(Flux.Losses.logitbinarycrossentropy(logits, 1))
end

"""
    trainGenerator!(gen, dscr, train_size, optimizer=opt; latent_dim=5, batchsize=128)

Run one pass of generator training. Draw `train_size` uniform noise vectors of
length `latent_dim`, then update `gen` in place, batch by batch, to minimise
`loss_G` against the current `dscr`. Only `gen`'s parameters are updated.

`optimizer` defaults to the global `opt`, looked up at call time. Returns a
fresh `train_size`-column batch of samples from the updated generator.
"""
function trainGenerator!(gen, dscr, train_size, optimizer=opt; latent_dim=5, batchsize=128)
    noise = rand(latent_dim, train_size)
    data = Flux.Data.DataLoader(noise; batchsize=batchsize, shuffle=true)

    ps = Flux.params(gen)
    for z in data
        gs = gradient(ps) do
            loss_G(z, gen, dscr)
        end
        Flux.update!(optimizer, ps, gs)
    end
    return gen(rand(latent_dim, train_size))
end

In [None]:
# Networks (generator takes 5-dimensional latent noise), the ADAM optimiser used
# by both training functions, and the number of samples per training step.
gen, dscr = G(5), D()
opt = ADAM()
train_size = 2000

In [None]:
# Alternate one discriminator pass and one generator pass per epoch. Every
# `log_every` epochs, report the discriminator's mean "real" probability on
# fresh real and generated samples. With `log_every == epochs` it would report
# only once, at the very end.
epochs = 1000
log_every = 100
for e in 1:epochs
    trainDiscriminator!(dscr, gen, train_size ÷ 2)
    trainGenerator!(gen, dscr, train_size)
    if e % log_every == 0
        fake_data = gen(rand(5, train_size))
        # Match the generator's feature count (the default k=10 would not fit dscr).
        real_data = generate_real_data(train_size, size(fake_data, 1))
        # dscr outputs logits; σ maps them to probabilities in (0, 1).
        @show e, mean(σ.(dscr(real_data))), mean(σ.(dscr(fake_data)))
    end
end

In [None]:
# Compare real samples with generator samples in the plane. The names avoid
# `real`, which would shadow `Base.real` at global scope.
n_plot = 100
fake_data = gen(rand(5, n_plot))
real_data = generate_real_data(n_plot, size(fake_data, 1))
scatter(real_data[1, :], real_data[2, :]; label="real")
scatter!(fake_data[1, :], fake_data[2, :]; label="generated")