In [None]:
using Plots

include("distributions.jl")

include("distances/new_distance.jl")
include("distances/distance_Wasserstein.jl")


We want to understand how $n$ and $m$ affect the true positive rate (power) of the test. In particular, under a fixed budget we have to choose between high $n$ with low $m$ or high $m$ with low $n$: which one should we choose?

For that we work on two different laws of RPM
$$
\text{DP}(1.0, P_0) \quad \text{and} \quad \text{DP}(2.0, P_0),
$$

where $P_0 = \text{Uniform}(-1,1)$

Suppose that the budget is $1000$, i.e. $n + m = 1000.$ How should we choose $n, m$?

We can try $n = d$, $m = 1000 - d$ for $d = 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000$.

In [None]:
function sample_distances_thresholds(q_1::PPM, q_2::PPM, n::Int, m::Int, s::Int, θ::Float64, n_permutations::Int)
    # Draw `s` pairs of hierarchical empirical measures from q_1 and q_2, record the WoW
    # (`ww`) and HIPM (`dlip`) distance of every pair, and attach to every pair a
    # permutation-based threshold for each of the two distances.
    #
    # q_1            :: Law of random probability measure Q^1
    # q_2            :: Law of random probability measure Q^2
    # n              :: Number of rows in hierarchical sample
    # m              :: Number of columns in hierarchical sample
    # s              :: Number of distances sampled
    # θ              :: Probability level for threshold
    # n_permutations :: Number of permutations for permutation approach
    #
    # Returns (d_wws, d_lips, perm_thresholds): two vectors of length s and an s×2 matrix
    # whose row i holds the thresholds of pair i (column 1 - WoW, column 2 - HIPM).

    ww_distances = Vector{Float64}(undef, s)
    lip_distances = Vector{Float64}(undef, s)
    thresholds = zeros(s, 2)
    quantile_level = 1 - θ

    for i in 1:s
        println("s = $i")

        # Q^1_{n,m} and Q^2_{n,m}
        sample_1 = generate_emp(q_1, n, m)
        sample_2 = generate_emp(q_2, n, m)
        ww_distances[i] = ww(sample_1, sample_2)
        lip_distances[i] = dlip(sample_1, sample_2)

        # The thresholds below depend on the generated empirical measures themselves,
        # unlike a threshold obtained via Rademacher complexity. They come from
        # approximate samples of d(Q^1_{n,m}, Q^2_{n,m}) obtained by re-splitting the
        # pooled rows at random.
        pooled_rows = vcat(sample_1.atoms, sample_2.atoms)
        null_ww = zeros(n_permutations)
        null_lip = zeros(n_permutations)

        for k in eachindex(null_ww)
            shuffled = randperm(2n) # random assignment of the pooled rows to two groups
            rows_1 = shuffled[1:n]
            rows_2 = shuffled[n+1:end]

            permuted_1 = emp_ppm(pooled_rows[rows_1, :], n, m, sample_1.a, sample_1.b)
            permuted_2 = emp_ppm(pooled_rows[rows_2, :], n, m, sample_2.a, sample_2.b)

            null_ww[k] = ww(permuted_1, permuted_2)
            null_lip[k] = dlip(permuted_1, permuted_2)
        end

        # Not exactly the thresholds from the theory: there is no rescaling by √(n/2).
        thresholds[i, 1] = quantile(null_ww, quantile_level)  # WoW
        thresholds[i, 2] = quantile(null_lip, quantile_level) # HIPM
    end

    return ww_distances, lip_distances, thresholds
end




function rejection_rate(d_wws::Vector{Float64}, d_lips::Vector{Float64}, perm_thresholds::Array{Float64, 2}, θ::Float64)
    # Given sampled distances and their per-sample thresholds, compute the rejection rate
    # of each distance function.
    #
    # d_wws           :: sampled WoW distances, one per generated pair
    # d_lips          :: sampled HIPM distances, one per generated pair
    # perm_thresholds :: matrix with one row per generated pair; column 1 holds the WoW
    #                    threshold and column 2 the HIPM threshold of that pair
    # θ               :: probability level the thresholds were computed for (not used in
    #                    the computation; kept so existing callers keep working)
    #
    # Returns a 2-element vector: [rejection rate for WoW, rejection rate for HIPM].

    n = length(d_wws)
    if length(d_lips) != n || size(perm_thresholds, 1) != n
        throw(DimensionMismatch(
            "expected $n HIPM distances and $n threshold rows, got " *
            "$(length(d_lips)) and $(size(perm_thresholds, 1))"))
    end

    rej_rates = zeros(2) # rejection rates for WoW and HIPM

    # Each distance is compared with the threshold of its own sample. Indexing the matrix
    # with a single integer would instead pick one column-major element, i.e. (1,1) and
    # (2,1), which is not the HIPM column whenever there is more than one sample.
    rej_rates[1] = count(d_wws .> view(perm_thresholds, :, 1)) / n  # WoW
    rej_rates[2] = count(d_lips .> view(perm_thresholds, :, 2)) / n # HIPM

    return rej_rates
end



In [None]:
function rejection_rates(q_1::PPM, q_2::PPM, n::Int, m::Int, s::Int, θ::Float64, n_permutations::Int)
    # Sample distances together with their permutation thresholds and turn them into
    # rejection rates for the two distance functions.
    #
    # q_1            : Law of random probability measure Q^1
    # q_2            : Law of random probability measure Q^2
    # n              : Number of rows in hierarchical sample
    # m              : Number of columns in hierarchical sample
    # s              : Number of distances sampled
    # θ              : Probability level for threshold
    # n_permutations : Number of permutations for permutation approach
    #
    # Returns whatever `rejection_rate` returns for the sampled distances and thresholds.

    # `sampled` is the tuple (d_wws, d_lips, perm_thresholds)
    sampled = sample_distances_thresholds(q_1, q_2, n, m, s, θ, n_permutations)

    return rejection_rate(sampled..., θ)
end

In [None]:
# Experiment configuration.
# NOTE(review): S = 1 and n_permutations = 1 look like smoke-test values; each threshold
# is then the quantile of a single permuted distance. Confirm before reading the rates
# as power estimates.
θ = 0.05           # probability level for the thresholds
S = 1              # number of distances sampled
n_permutations = 1 # number of permutations for permutation approach

# Two laws of random probability measures sharing the same base measure and interval,
# differing only in the first argument passed to DP.
α_1 = 1.0
α_2 = 2.0
P_0 = () -> probability("same") # uniform measure (per the original note)
a = -1.0
b = 1.0
q_1 = DP(α_1, P_0, a, b)
q_2 = DP(α_2, P_0, a, b)

# number of rows and columns in hierarchical sample
n = 5
m = 5

# Rejection rates as a function of the sample shape only; everything else is read from
# the globals defined above.
function rej_rates(n, m)
    return rejection_rates(q_1, q_2, n, m, S, θ, n_permutations)
end

println("Rejection rates for WoW and HIPM are $(rej_rates(n, m))")

In [None]:
# Row counts to try under the budget n + m = 1000: 10, then 100, 200, ..., 1000.
# NOTE(review): the last grid point d = 1000 gives m = 1000 - d = 0 columns; confirm
# that the sampling and distance code is meant to handle that case.
d = [10; 100:100:1000]

# One rejection-rate result per grid point, evaluated at n = n_rows, m = 1000 - n_rows.
powers = [rej_rates(n_rows, 1000 - n_rows) for n_rows in d]

In [None]:
# Show the rejection rates computed for every value in `d` (one result per grid point).
powers

In [None]:
# Split the per-grid-point results into one series per distance function:
# entry 1 is the WoW rate, entry 2 the HIPM rate.
wow_power = getindex.(powers, 1)
hipm_power = getindex.(powers, 2)

# WoW curve first, then the HIPM curve drawn dashed on the same axes.
plot(
    d, wow_power;
    label="WoW",
    xlabel="n",
    ylabel="Power",
    title="Power vs n (m = 1000 - n)",
    legend=:bottomright,
)
plot!(d, hipm_power; label="HIPM", linestyle=:dash)
