9 changes: 6 additions & 3 deletions GAN/MMD_GAN/mmd_gan_1d.jl
@@ -38,9 +38,9 @@ end
epochs::Int = 1000
num_gen::Int = 1
num_enc_dec::Int = 1
-lr_enc::Float64 = 1.0e-10
-lr_dec::Float64 = 1.0e-10
-lr_gen::Float64 = 1.0e-10
+lr_enc::Float64 = 1.0e-3
+lr_dec::Float64 = 1.0e-3
+lr_gen::Float64 = 1.0e-3

lambda_AE::Float64 = 8.0

@@ -73,6 +73,8 @@ function train_mmd_gan_1d(enc, dec, gen, hparams::HyperParamsMMD1D)
@showprogress for epoch in 1:(hparams.epochs)
for _ in 1:(hparams.num_enc_dec)
loss, grads = Flux.withgradient(enc, dec) do enc, dec
+Flux.reset!(enc)
+Flux.reset!(dec)
target = Float32.(rand(hparams.target_model, hparams.batch_size))
noise = Float32.(rand(hparams.noise_model, hparams.batch_size))
encoded_target = enc(target')
@@ -93,6 +95,7 @@ function train_mmd_gan_1d(enc, dec, gen, hparams::HyperParamsMMD1D)
end
for _ in 1:(hparams.num_gen)
loss, grads = Flux.withgradient(gen) do gen
+Flux.reset!(gen)
target = Float32.(rand(hparams.target_model, hparams.batch_size))
noise = Float32.(rand(hparams.noise_model, hparams.batch_size))
encoded_target = enc(target')
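Note on the `Flux.reset!` additions above: they zero any recurrent hidden state inside the gradient closure, so state carried over from a previous batch cannot leak into the current loss; for purely feed-forward chains they are no-ops. A minimal sketch of the pattern (hypothetical recurrent encoder, explicit-gradient Flux API as used in the diff):

```julia
using Flux

# Hypothetical stateful encoder; reset! only affects recurrent layers
# such as RNN/LSTM/GRU and does nothing to Dense chains.
enc = Chain(RNN(1 => 8), Dense(8 => 1))

batch = rand(Float32, 1, 16)  # 1 feature × 16 samples

loss, grads = Flux.withgradient(enc) do enc
    Flux.reset!(enc)       # start this batch from a clean hidden state
    sum(abs2, enc(batch))  # forward pass is now independent of past batches
end
```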
9 changes: 5 additions & 4 deletions benchmarks/benchmark_multimodal.jl
@@ -13,7 +13,7 @@ include("benchmark_utils.jl")
dscr = Chain(
Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
)
-target_model = MixtureModel([Normal(4.0f0, 2.0f0), Normal(-2.0f0, 1.0f0)])
+target_model =
hparams = HyperParamsVanillaGan(;
data_size=100,
batch_size=1,
@@ -29,7 +29,7 @@ include("benchmark_utils.jl")
train_vanilla_gan(dscr, gen, hparams)

hparams = HyperParams(;
-samples=1000, K=100, epochs=100, η=1e-2, transform=noise_model
+samples=1000, K=10, epochs=1000, η=1e-2, transform=noise_model
)
#hparams = AutoAdaptativeHyperParams(;
# max_k=20, samples=1200, epochs=10000, η=1e-3, transform=noise_model
@@ -39,6 +39,7 @@ include("benchmark_utils.jl")

#save_gan_model(gen, dscr, hparams)

+
adaptative_block_learning_1(gen, loader, hparams)

ksd = KSD(noise_model, target_model, n_samples, 18:0.1:28)
@@ -53,13 +54,13 @@ include("benchmark_utils.jl")

#save_gan_model(gen, dscr, hparams)
plot_global(
-x -> -quantile.(target_model, cdf(noise_model, x)),
+x -> quantile.(target_model, cdf(noise_model, x)),
noise_model,
target_model,
gen,
n_samples,
(-3:0.1:3),
-(5:0.2:15),
+(-5:0.2:10),
)

#@test js_divergence(hist1.weights, hist2.weights)/hparams.samples < 0.03
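The sign fix in `plot_global` above matters: the closed-form map pushing the noise distribution onto the target is `Q_target(F_noise(x))`, which is monotone increasing, so the negated quantile plotted its mirror image. A sketch of the transform (the standard normal `noise_model` is an assumption here; it is not visible in the hunk):

```julia
using Distributions

noise_model  = Normal(0.0f0, 1.0f0)  # assumed, not shown in this diff
target_model = MixtureModel([Normal(4.0f0, 2.0f0), Normal(-2.0f0, 1.0f0)])

# Probability integral transform: cdf(noise_model, x) ~ Uniform(0,1), so the
# target's quantile function maps noise samples onto the target distribution.
transform(x) = quantile(target_model, cdf(noise_model, x))

xs = rand(noise_model, 10_000)
ys = transform.(xs)  # approximately distributed as target_model
```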
54 changes: 39 additions & 15 deletions benchmarks/benchmark_unimodal.jl
@@ -92,7 +92,7 @@ include("benchmark_utils.jl")
dscr = Chain(
Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
)
-target_model = Cauchy(23.0f0, 1.0f0)
+target_model = MixtureModel([Normal(-10.0, 1.0), Uniform(-5.0,5.0), Pareto(3.0, 10.0)])
hparams = HyperParamsVanillaGan(;
data_size=100,
batch_size=1,
Expand All @@ -108,7 +108,7 @@ include("benchmark_utils.jl")
train_vanilla_gan(dscr, gen, hparams)

hparams = AutoAdaptativeHyperParams(;
-max_k=10, samples=1000, epochs=400, η=1e-2, transform=noise_model
+max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
)
train_set = Float32.(rand(target_model, hparams.samples))
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
@@ -130,7 +130,7 @@ include("benchmark_utils.jl")
gen,
n_samples,
(-3:0.1:3),
-(18:0.1:55),
+(-20:0.2:30),
)
end

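Widening the plot window to `(-20:0.2:30)` tracks the new target's support: the left normal sits near -10 and the Pareto component starts at 10 with a heavy right tail. A quick empirical check, as a sketch:

```julia
using Distributions, Statistics

target_model = MixtureModel([Normal(-10.0, 1.0), Uniform(-5.0, 5.0), Pareto(3.0, 10.0)])

xs = rand(target_model, 100_000)
quantile(xs, [0.001, 0.99])  # roughly (-13, 32); the heavy Pareto tail makes
                             # the upper end of (-20:0.2:30) a pragmatic cutoff
```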
@@ -313,29 +313,27 @@ end
dscr = Chain(
Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
)
-target_model = MixtureModel([
-Normal(5.0f0, 2.0f0), Normal(-1.0f0, 1.0f0), Normal(-7.0f0, 0.4f0)
-])
+target_model = Pareto(1.0f0, 2.0f0)

hparams = HyperParamsWGAN(;
noise_model=noise_model,
target_model=target_model,
data_size=100,
batch_size=1,
epochs=1e3,
-n_critic=4,
+n_critic=2,
lr_dscr=1e-2,
#lr_gen = 1.4e-2,
lr_gen=1e-2,
)

loss = train_wgan(dscr, gen, hparams)

-hparams = HyperParams(; samples=100, K=10, epochs=2000, η=1e-3, noise_model)
+hparams = HyperParams(; samples=100, K=10, epochs=1000, η=1e-3, noise_model)
train_set = rand(target_model, hparams.samples)
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

-adaptative_block_learning(gen, loader, hparams)
+auto_adaptative_block_learning(gen, loader, hparams)

ksd = KSD(noise_model, target_model, n_samples, 20:0.1:25)
mae = min(
Expand All @@ -347,6 +345,9 @@ end
MSE(noise_model, x -> .-x .+ 23, n_sample),
)

+save_gan_model(gen, dscr, hparams)
+
+
#@test js_divergence(hist1.weights, hist2.weights)/hparams.samples < 0.03

end
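On the `n_critic` change above: dropping it from 4 to 2 halves the critic updates per generator step. A minimal sketch of the schedule this parameter controls (illustrative names and explicit-gradient Flux API; weight clipping/gradient penalty omitted, and the repo's `train_wgan` wraps its own version of this loop):

```julia
using Flux, Statistics

# Sketch only: WGAN alternates n_critic critic steps with one generator step.
function wgan_step!(critic, gen, opt_c, opt_g, real; n_critic=2)
    noise() = randn(Float32, size(real))
    for _ in 1:n_critic
        # Critic ascends E[f(real)] - E[f(fake)]; we minimize the negation.
        gc = Flux.gradient(c -> mean(c(gen(noise()))) - mean(c(real)), critic)[1]
        Flux.update!(opt_c, critic, gc)
    end
    # One generator step: minimize -E[f(fake)].
    gg = Flux.gradient(g -> -mean(critic(g(noise()))), gen)[1]
    Flux.update!(opt_g, gen, gg)
end

# Usage sketch:
# opt_c = Flux.setup(Descent(1e-2), critic); opt_g = Flux.setup(Descent(1e-2), gen)
# wgan_step!(critic, gen, opt_c, opt_g, rand(Float32, 1, 64); n_critic=2)
```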
@@ -567,6 +568,8 @@ end
mse = MSE(
noise_model, x -> quantile.(target_model, cdf(noise_model, x)), n_sample
)
+
+save_adaptative_model(gen, hparams)
end

@test_experiments "Uniform(-1,1) to Pareto(1,23)" begin
@@ -618,23 +621,34 @@ end
dec = Chain(Dense(29, 11), elu, Dense(11, 1))
gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))

-target_model = Normal(23.0f0, 1.0f0)
+target_model = Normal(4.0f0, 2.0f0)

hparams = HyperParamsMMD1D(;
noise_model=noise_model,
target_model=target_model,
-data_size=1,
+data_size=100,
batch_size=1,
num_gen=1,
num_enc_dec=5,
-epochs=1e5,
-lr_dec=1.0e-2,
-lr_enc=1.0e-2,
-lr_gen=1.0e-2,
+epochs=1000000,
+lr_dec=1.0e-3,
+lr_enc=1.0e-3,
+lr_gen=1.0e-3,
)

train_mmd_gan_1d(enc, dec, gen, hparams)

+plot_global(
+x -> quantile.(target_model, cdf(noise_model, x)),
+noise_model,
+target_model,
+gen,
+n_samples,
+(-3:0.1:3),
+(-5:0.2:10),
+)
+
+
hparams = HyperParams(; samples=100, K=10, epochs=2000, η=1e-3, noise_model)
train_set = rand(target_model, hparams.samples)
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
@@ -651,6 +665,16 @@
MSE(noise_model, x -> .-x .+ 23, n_sample),
)

+plot_global(
+x -> quantile.(target_model, cdf(noise_model, x)),
+noise_model,
+target_model,
+gen,
+n_samples,
+(-3:0.1:3),
+(-5:0.2:10),
+)
+
#@test js_divergence(hist1.weights, hist2.weights)/hparams.samples < 0.03

end
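The MAE/MSE calls in this block score the trained generator against the closed-form transport map. A sketch of the comparison they make (standard normal noise assumed; `gen` is the model trained above):

```julia
using Distributions, Statistics

noise_model  = Normal(0.0f0, 1.0f0)  # assumed
target_model = Normal(4.0f0, 2.0f0)
ideal(x) = quantile(target_model, cdf(noise_model, x))  # exact noise-to-target map

xs  = Float32.(rand(noise_model, 1000))
ŷ   = vec(gen(xs'))                   # generator output on fresh noise
mae = mean(abs.(ŷ .- ideal.(xs)))
mse = mean(abs2.(ŷ .- ideal.(xs)))
```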
4 changes: 2 additions & 2 deletions benchmarks/benchmark_utils.jl
@@ -88,7 +88,7 @@ function plot_transformation(real_transform, gen, range)
linecolor=:redsblues,
)
y = gen(range')
-return plot!(range, vec(y); legend=:bottomright, label="neural network", linecolor=get(ColorSchemes.rainbow, 0.2), ylims=(-20,20))
+return plot!(range, vec(y); legend=:bottomright, label="neural network", linecolor=get(ColorSchemes.rainbow, 0.2), ylims=(-10,10))
end

function plot_global(
@@ -149,7 +149,7 @@ function save_gan_model(gen, dscr, hparams)
function getName(hparams)
gan = gans[typeof(hparams)]
lr_gen = hparams.lr_gen
-dscr_steps = hparams.dscr_steps
+dscr_steps = hparams.n_critic
noise_model = replace(strip(string(hparams.noise_model)), "\n" => "", r"(K = .*)" => "", r"components\[.*\] " => "", r"prior = " => "", "μ=" => "", "σ=" => "", r"\{Float.*\}" => "")
target_model = replace(strip(string(hparams.target_model)), "\n" => "", r"(K = .*)" => "", r"components\[.*\] " => "", r"prior = " => "", "μ=" => "", "σ=" => "", r"\{Float.*\}" => "")
basename = "$gan-$noise_model-$target_model-lr_gen=$lr_gen-dscr_steps=$dscr_steps"
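The `dscr_steps = hparams.n_critic` fix aligns `getName` with the actual field on `HyperParamsWGAN`. For reference, a sketch of what the regex chain does to a distribution's printed form:

```julia
using Distributions

s = string(Normal(23.0f0, 1.0f0))  # "Normal{Float32}(μ=23.0f0, σ=1.0f0)"
name = replace(strip(s), "\n" => "", "μ=" => "", "σ=" => "", r"\{Float.*\}" => "")
# -> "Normal(23.0f0, 1.0f0)": type parameters and keyword labels are stripped,
#    leaving a filename-friendly description for the saved model's basename.
```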