In [None]:
using Revise
using ChangesOfVariables, InverseFunctions, ArraysOfArrays, Statistics
using StatsBase
using Optimisers
using Plots
using QuadGK
using EuclidianNormalizingFlows
using ArraysOfArrays
using FunctionChains
using Distributions
using ValueShapes
using PyPlot
using BAT
using HypothesisTests
using SpecialFunctions
using IntervalArithmetic
using Base.Threads: @threads
using CUDA
using CUDAKernels
using KernelAbstractions
using Random
using Flux
using Intervals
using LinearAlgebra
using BenchmarkTools
import Base.in

In [None]:
# Show how many Julia threads this session was started with; the multivariate
# `rand_trunc` method below fills its sample matrix inside a `@threads` loop.
Threads.nthreads()

In [None]:
"""
    in(x::AbstractVector{<:Real}, s::AbstractVector{<:Intervals.Interval})

Return `true` when every coordinate `x[i]` lies inside the interval `s[i]`, i.e. when the
point `x` is inside the axis-aligned box described by `s`.

Throws a `DimensionMismatch` when `x` and `s` have different lengths, instead of silently
ignoring surplus intervals.

NOTE(review): this extends `Base.in` for argument types this notebook does not own
(type piracy); a dedicated function name would be safer outside a notebook.
"""
function in(x::AbstractVector{<:Real}, s::AbstractVector{<:Intervals.Interval})
    length(x) == length(s) || throw(DimensionMismatch(
        "point has $(length(x)) coordinates but $(length(s)) intervals were given"
    ))
    # `all` stops at the first coordinate that falls outside its interval.
    return all(xi in si for (xi, si) in zip(x, s))
end

"""
    rand_trunc(dist, space)

Draw a single sample from `dist` by rejection: keep calling `rand(dist)` until the draw
satisfies `draw in space`, then return it.

The loop has no iteration limit, so it does not terminate if `dist` never produces a
sample inside `space`.
"""
function rand_trunc(dist, space::Union{Intervals.Interval, AbstractVector{<:Intervals.Interval}})
    while true
        candidate = rand(dist)
        # Accept the first draw that lands inside the requested region.
        candidate in space && return candidate
    end
end


"""
    rand_trunc(dist, space, n::Integer)

Draw `n` samples from `dist` restricted to `space` using the single-sample `rand_trunc`.

When `length(dist) == 1` the result is a vector of `n` draws. Otherwise the result is a
`length(dist) × n` `Float64` matrix with one sample per column, filled in a `@threads`
loop (each iteration writes only its own column).
"""
function rand_trunc(dist, space::Union{Intervals.Interval, AbstractVector{<:Intervals.Interval}}, n::Integer)
    if length(dist) == 1
        return [rand_trunc(dist, space) for _ in Base.OneTo(n)]
    end

    draws = zeros(length(dist), n)
    @threads for col in Base.OneTo(n)
        draws[:, col] .= rand_trunc(dist, space)
    end
    return draws
end


"""
    rand_select(dist, space, n::Integer)

Draw `n` samples from `dist` in one call and return only those that lie inside `space`.

Unlike `rand_trunc`, nothing is redrawn, so the result usually holds fewer than `n`
samples (possibly none). Matrix-valued draws (one sample per column) keep the accepted
columns; vector-valued draws (univariate distributions) keep the accepted entries.
"""
function rand_select(dist, space::Union{Intervals.Interval, AbstractVector{<:Intervals.Interval}}, n::Integer)
    return _select_inside(rand(dist, n), space)
end

# Univariate draws: keep the scalar samples that fall inside `space`.
_select_inside(x::AbstractVector, space) = filter(xi -> xi in space, x)

# Multivariate draws: keep the columns (one sample each) that fall inside `space`.
function _select_inside(x::AbstractMatrix, space)
    # `eachcol` yields views, so testing a column does not copy it.
    keep = [col in space for col in eachcol(x)]
    return x[:, keep]
end


In [None]:
"""
    uniformity_measure(samples::AbstractVector{<:Real})

Return the integral over `[0, 1]` of `(F(x) - x)^2`, where `F` is the empirical CDF of
`samples` and `x` is the CDF of the uniform distribution on `[0, 1]`. Smaller values mean
the samples look more uniform. The samples are expected to lie in `[0, 1]`.

The input is not modified: sorting is done on a copy. Throws an `ArgumentError` for empty
input.
"""
function uniformity_measure(samples::AbstractVector{<:Real})
    isempty(samples) && throw(ArgumentError("samples must not be empty"))

    # Sort a copy so the caller's array keeps its order.
    sorted = sort(samples)
    n = length(sorted)
    gcdf = ecdf(sorted)
    f(x::Real) = (gcdf(x) - x)^2

    # The empirical CDF is a step function, so integrate piecewise between consecutive
    # sample points, where the integrand has no jumps.
    res = quadgk(f, 0, sorted[1])[1]
    for i in 2:n
        res += quadgk(f, sorted[i-1], sorted[i])[1]
    end
    res += quadgk(f, sorted[end], 1)[1]
    return res
end


In [None]:
"""
    get_rand_dist(_dist, _device; c::Real = 2)

Build a sampler `n -> samples` that draws `n` samples from `_dist`, subtracts
`mean(_dist)` and divides each dimension by `c` times its standard deviation (taken from
the diagonal of `cov(_dist)`).

If `_device isa GPU`, the returned closure passes the samples through `gpu` before
returning them.
"""
function get_rand_dist(_dist, _device; c::Real = 2)
    n_dims = length(_dist)
    _m = -mean(_dist)
    # Evaluate the covariance once rather than once per dimension.
    covmat = cov(_dist)
    _v = 1 ./ (c * sqrt.([covmat[i, i] for i in 1:n_dims]))

    # `let` pins the captured values so the closure does not depend on later rebinding.
    rand_dist = let dist = _dist, m = _m, v = _v, device = _device
        (n::Integer) -> begin
            samples = (rand(dist, n) .+ m) .* v
            return device isa GPU ? gpu(samples) : samples
        end
    end
    return rand_dist
end


"""
    get_rand_dist(_dist, _device, _m, _v)

Build a sampler `n -> samples` that draws `n` samples from `_dist` and applies the affine
map `(x .+ _m) .* _v` with a caller-supplied shift `_m` and scale `_v`.

If `_device isa GPU`, the returned closure passes the samples through `gpu` before
returning them.
"""
function get_rand_dist(_dist, _device, _m::Union{Real, AbstractVector{<:Real}}, _v::Union{Real, AbstractVector{<:Real}})
    # `let` pins the captured values so the closure does not depend on later rebinding.
    rand_dist = let dist = _dist, m = _m, v = _v, device = _device
        (n::Integer) -> begin
            samples = (rand(dist, n) .+ m) .* v
            return device isa GPU ? gpu(samples) : samples
        end
    end
    return rand_dist
end

            
"""
    get_rand_dist(_dist, _device, _space; width::Real = 5)

Build a sampler `n -> samples` that draws `n` samples from `_dist` restricted to `_space`
(via `rand_trunc`) and rescales them so that, per dimension, the interval in `_space` is
mapped onto `[-width/2, width/2]`.

If `_device isa GPU`, the returned closure passes the samples through `gpu` before
returning them.
"""
function get_rand_dist(
    _dist,
    _device,
    _space::Union{Intervals.Interval, AbstractVector{<:Intervals.Interval}};
    width::Real = 5,
)
    _low = Intervals.first.(_space)
    _high = Intervals.last.(_space)
    # Shift by the interval midpoint, then scale the interval length to `width`.
    _m = -(_high .+ _low) ./ 2
    _v = width ./ (_high .- _low)

    rand_dist = let dist = _dist, space = _space, m = _m, v = _v, device = _device
        (n::Integer) -> begin
            samples = (rand_trunc(dist, space, n) .+ m) .* v
            return device isa GPU ? gpu(samples) : samples
        end
    end
    return rand_dist
end
         

"""
    get_rand_dist_select(_dist, _device, _space; width::Real = 5)

Build a sampler `n -> samples` that draws `n` samples from `_dist`, keeps only those
inside `_space` (via `rand_select`) and rescales them so that, per dimension, the interval
in `_space` is mapped onto `[-width/2, width/2]`.

Because rejected draws are dropped rather than redrawn, a call with `n` may return fewer
than `n` samples. If `_device isa GPU`, the returned closure passes the samples through
`gpu` before returning them.
"""
function get_rand_dist_select(
    _dist,
    _device,
    _space::Union{Intervals.Interval, AbstractVector{<:Intervals.Interval}};
    width::Real = 5,
)
    _low = Intervals.first.(_space)
    _high = Intervals.last.(_space)
    # Shift by the interval midpoint, then scale the interval length to `width`.
    _m = -(_high .+ _low) ./ 2
    _v = width ./ (_high .- _low)

    rand_dist = let dist = _dist, space = _space, m = _m, v = _v, device = _device
        (n::Integer) -> begin
            samples = (rand_select(dist, space, n) .+ m) .* v
            return device isa GPU ? gpu(samples) : samples
        end
    end
    return rand_dist
end
            

In [None]:
# Target distribution used throughout the notebook: BAT's funnel distribution built with n=2.
dist = BAT.FunnelDistribution(n=2)

In [None]:
# Dimensionality of the target; reused below to build the flow and the reference samples.
n_dims = length(dist)

In [None]:
# Estimate the per-dimension range of the target from a large sample
# (samples are stored one per column, so the reduction runs over dims=2).
samples = rand(dist, 10^7)
@show low = minimum(samples; dims=2)
@show high = maximum(samples; dims=2)

# Widen the observed range by a factor of 3 around its centre. Centre and half-width are
# taken from the observed bounds BEFORE `low` is overwritten; computing `high` from the
# already-widened `low` would shift and stretch the upper bound incorrectly.
range_center = (high .+ low) ./ 2
range_halfwidth = (high .- low) ./ 2
@show low = range_center .- 3 .* range_halfwidth
@show high = range_center .+ 3 .* range_halfwidth;

In [None]:
# Axis-aligned box (one interval per dimension of `dist`) that training samples are
# restricted to; it is passed to the sampler constructor below.
space = [
    Intervals.Interval(-50.0, 50.0),
    Intervals.Interval(-2000.0, 2000.0),
]

In [None]:
# Switch for GPU execution; `train_a_nn` below also reads this global.
wanna_use_GPU = true

# Ask KernelAbstractions which device owns a small probe array allocated on the chosen backend.
_device = KernelAbstractions.get_device(wanna_use_GPU ? CUDA.rand(10) : rand(10))

In [None]:
# Pick how training samples are generated. The two commented lines are the alternative
# constructors defined above (standardise by mean/covariance, or rejection-sample in `space`).
# rand_dist = get_rand_dist(dist, _device; c=3)
# rand_dist = get_rand_dist(dist, _device, space)
# Active choice: draw from `dist`, keep only the samples inside `space`, then rescale.
rand_dist = get_rand_dist_select(dist, _device, space)

In [None]:
# Third argument of `get_flow` below; NOTE(review): its meaning is defined by `get_flow` — confirm there.
K = 10

# `get_flow` is not defined in this file (presumably provided by EuclidianNormalizingFlows).
# The training cell below iterates over the `fs` field of the result, one entry per block.
blocks = get_flow(n_dims, _device, K)

In [None]:
"""
    train_a_nn(initial_trafo, rand_dist; stepsize, batchsize, nbatches, max_nepochs,
               stationary_p_val, use_gpu) -> (trained_trafo, negll_history)

Train a copy of `initial_trafo` in successive stages and return the trained transformation
together with the concatenated negative-log-likelihood history of all stages.

Stage `i` calls `EuclidianNormalizingFlows.optimize_whitening_stationary` with an Adam
optimizer of step size `stepsize[i]` and batch size `batchsize[i]`; the result and the
optimizer state of one stage are the starting point of the next.

# Arguments
- `initial_trafo::Function`: transformation to start from (deep-copied, not modified).
- `rand_dist`: sampler forwarded to `optimize_whitening_stationary`.

# Keywords
- `stepsize`: Adam step size per stage.
- `batchsize`: batch size per stage; must provide an entry for every index of `stepsize`.
- `nbatches`, `max_nepochs`, `stationary_p_val`: forwarded unchanged to the optimizer.
- `use_gpu`: forwarded as `wanna_use_GPU`; defaults to the global `wanna_use_GPU`.

# Throws
- `DimensionMismatch` if `batchsize` has no entry for some index of `stepsize`.
"""
function train_a_nn(
    initial_trafo::Function,
    rand_dist;
    stepsize::AbstractVector{<:Real} = [1f-3, 5f-4, 2f-4, 1f-4, 5f-5, 2f-5, 1f-5, 5f-6, 2f-6, 1f-6],
    batchsize::AbstractVector{<:Integer} = [10^3, 2*10^3, 5*10^3, 10^4, 2*10^4, 5*10^4, 10^5, 2*10^5, 5*10^5, 10^6],
    nbatches::Integer = 100,
    max_nepochs::Integer = 100,
    stationary_p_val::Real = 1e-10,
    use_gpu::Bool = wanna_use_GPU,
)
    # Fail before any training starts rather than with a BoundsError in a late stage.
    checkbounds(Bool, batchsize, eachindex(stepsize)) || throw(DimensionMismatch(
        "batchsize (length $(length(batchsize))) must cover every stage of stepsize (length $(length(stepsize)))"
    ))

    trained_trafo = deepcopy(initial_trafo)
    negll_history = Vector{Float64}()
    optstate = nothing

    for i in eachindex(stepsize)

        optimizer = Optimisers.Adam(stepsize[i])
        # The optimizer state is created once and then carried from stage to stage.
        # NOTE(review): an Optimisers.jl state tree stores the rule it was set up with, so
        # whether later `stepsize[i]` values take effect depends on how
        # `optimize_whitening_stationary` combines `optimizer` and `optstate` — confirm.
        if i == firstindex(stepsize)
            optstate = Optimisers.setup(optimizer, deepcopy(trained_trafo))
        end

        r = EuclidianNormalizingFlows.optimize_whitening_stationary(
            rand_dist,
            trained_trafo,
            optimizer,
            nbatches = nbatches,
            batchsize = batchsize[i],
            max_nepochs = max_nepochs,
            stationary_p_val = stationary_p_val,
            optstate = optstate,
            wanna_use_GPU = use_gpu
        )

        trained_trafo = deepcopy(r.result)
        append!(negll_history, r.negll_history)
        optstate = deepcopy(r.optimizer_state)

        println("+++ DONE stepsize $(stepsize[i]), batchsize $(batchsize[i]), nbatches $(nbatches)")
    end

    return trained_trafo, negll_history
end
 

In [None]:
# Train the flow one block at a time and collect the trained blocks and their loss histories.
trained_trafo = Function[]
# Empty vector of per-block loss histories (one Vector{Float64} is pushed per block).
negll_history = fill(Vector{Float64}(), 0)

# Wall-clock start, used for the timing message at the end of the cell.
time_start = time()

for i in eachindex(blocks.fs)
    
    # The first block is trained directly on `rand_dist`. Every later block is trained on
    # samples pushed through the chain of blocks trained so far: the closure looks up the
    # global `trained_trafo` when it is called, so it sees all blocks pushed before it.
    tmp_trafo, tmp_negll_history = train_a_nn(
        blocks.fs[i],
        i==1 ? rand_dist : (n::Integer)->fchain(trained_trafo)(rand_dist(n));
#         stepsize = [1f-3, 5f-4, 2f-4, 1f-5, 5f-5, 2f-5, 1f-5],
#         batchsize = [10^3, 2*10^3, 5*10^3, 10^4, 2*10^4, 5*10^4, 10^5],
#         nbatches = 100,
    )
    
    push!(trained_trafo, tmp_trafo)
    push!(negll_history, tmp_negll_history)
    
    println("*** TRAINED transformation $(i)")
end

# Compose all trained blocks into a single transformation.
final_trafo = fchain(trained_trafo)
# Placeholder; it is only filled if the optional joint fine-tuning below is enabled.
final_negll_history = Vector{Float64}()

# Optional joint fine-tuning of the full chain (disabled):
# final_trafo, final_negll_history = train_a_nn(
#     fchain(trained_trafo), 
#     rand_dist;
#     stepsize = [5f-6, 1f-6],
#     batchsize = [2*10^4, 10^5],
#     nbatches = 500,
# )

# Pass the composed transformation through `cpu` (presumably Flux's device-transfer helper)
# so the evaluation cells below can apply it to CPU samples.
final_trafo = cpu(final_trafo)

println("!!! TRAINED compelete model in $(time() - time_start)s")


In [None]:
# Plot the training loss history of the second trained block (index 2 is hard-coded).
tmp = negll_history[2]
fig2, ax2 = plt.subplots(1, figsize=(8,4))
ax2.plot(1:length(tmp), tmp)
println()
ax2.set_ylabel("Cost")
ax2.set_xlabel("Iteration")

In [None]:
# Visual check: scatter of freshly drawn samples after the trained transformation (left)
# next to standard-normal reference samples of the same dimension (right).
new_samples = cpu(rand_dist(10^3))
samples_transformed = final_trafo(new_samples)
ref_samples = randn(n_dims, 10^3)

# Both panels use the same axis limits so they can be compared directly.
plt_1 = Plots.plot(samples_transformed[1, :], samples_transformed[2, :], seriestype=:scatter, label="", size=(400, 400), xlims=(-5, 5), ylims=(-5, 5))
plt_2 = Plots.plot(ref_samples[1, :], ref_samples[2, :], seriestype=:scatter, label="", size=(400, 400), xlims=(-5, 5), ylims=(-5, 5))

Plots.plot(plt_1, plt_2, layout=(1, 2), size=(800, 400))


In [None]:
# Visual check of the marginals: normalized histogram of each transformed dimension
# overlaid on the standard normal density.
new_samples = cpu(rand_dist(10^6))
samples_transformed = final_trafo(new_samples)

# Reference curve: standard normal pdf evaluated on a grid over [-5, 5].
_x = range(-5, 5; length=10^3)
_y = pdf.(Normal(), _x)

# First dimension: reference curve, then the histogram drawn into the same plot.
plt_1 = Plots.plot(_x, _y, linewidth=2, linealpha=0.7, label="")
plt_1 = Plots.plot!(samples_transformed[1, :], seriestype=:stephist, nbins=500, normalize=true, linewidth=2, label="")

# Second dimension, same layout.
plt_2 = Plots.plot(_x, _y, linewidth=2, linealpha=0.7, label="")
plt_2 = Plots.plot!(samples_transformed[2, :], seriestype=:stephist, nbins=500, normalize=true, linewidth=2, label="")

Plots.plot(plt_1, plt_2, layout=(1, 2), size=(800, 400))
