In [84]:
using Nemo, ExactWrightFisher, KahanSummation, Random, Distributions



In [85]:
RR = RealField(256)

Real Field with 256 bits of precision and error bounds

In [86]:
RR(3)

3.00000000000000000000000000000000000000000000000000000000000000000000000000000

In [87]:
import Base.sign

function sign(x::arb)
    if x < 0
        return -1
    else
        return 1
    end
end

sign (generic function with 10 methods)

In [88]:
function signed_logsumexp_arb(lx, signs)
    res = sum(exp.(RR.(lx)) .* signs)
    return sign(res), log(abs(res))
    
end
function signed_logsumexp_naive(lx, signs)
# There might potentially be a problem when the larger terms cancel each other
  m = maximum(lx)
  scaled_sum = sum(signs .* exp.(lx .- m))
  if abs(scaled_sum) <= 10*eps(Float64)
    return [1., -Inf]
  elseif scaled_sum < 0
    sgn = -1
    scaled_sum = -1*scaled_sum
  else
    sgn = 1
  end
  return [sgn, m + log(scaled_sum) ]#Will give an error if the sum is negative
end

function logsumexp_kahan(X::AbstractArray{T}) where {T<:Real}
    isempty(X) && return log(zero(T))
    u = maximum(X)
    isfinite(u) || return float(u)
    let u=u # avoid https://github.com/JuliaLang/julia/issues/15276
        u + log(sum_kbn(exp.(X .- u)))
    end
end

function signed_logsumexp_kahan_naive(lx, signs)
# There might potentially be a problem when the larger terms cancel each other
  m = maximum(lx)
  scaled_sum = sum_kbn(signs .* exp.(lx .- m))
  if abs(scaled_sum) <= eps(Float64)
    return [1., -Inf]
  elseif scaled_sum < 0
    sgn = -1
    scaled_sum = -1*scaled_sum
  else
    sgn = 1
  end
  return [sgn, m + log(scaled_sum) ]#Will give an error if the sum is negative
end

function signed_logsumexp_kahan(lx, signs)
  # summing the positive terms together and the negative terms together should decrease the probability of cancellation of large terms
  # @assert length(lx) == length(signs)
  # pos = signs .== 1
  if all(signs .> 0)
    return [1.0, logsumexp_kahan(lx)]
  elseif all(signs .< 0)
    return [-1.0, logsumexp_kahan(lx)]
  else
    n = length(lx)
    @inbounds logsumexp_positive_terms = logsumexp_kahan([lx[i] for i in 1:n if signs[i] > 0])
    @inbounds logsumexp_negative_terms = logsumexp_kahan([lx[i] for i in 1:n if signs[i] < 0])
    if logsumexp_positive_terms > logsumexp_negative_terms
      sgn = 1
      res = log(1-exp(logsumexp_negative_terms - logsumexp_positive_terms)) + logsumexp_positive_terms
    elseif logsumexp_positive_terms < logsumexp_negative_terms
      sgn = -1
      res = log(1-exp(logsumexp_positive_terms - logsumexp_negative_terms)) + logsumexp_negative_terms
    else
      sgn = 1
      res = -Inf
    end
    return [sgn, res]
  end
end

signed_logsumexp_kahan (generic function with 1 method)

In [89]:
ExactWrightFisher.signed_logsumexp([0.2, -3., 1], [1, 1, -1])[2] - signed_logsumexp_arb([0.2, -3., 1], [1, 1, -1])[2]

[3.5229451118748631580263821001641342027628089453259875577554e-17 +/- 3.98e-76]

In [90]:
signed_logsumexp_naive([0.2, -3., 1], [1, 1, -1])[2] - signed_logsumexp_arb([0.2, -3., 1], [1, 1, -1])[2]

[3.5229451118748631580263821001641342027628089453259875577554e-17 +/- 3.98e-76]

In [91]:
signed_logsumexp_kahan_naive([0.2, -3., 1], [1, 1, -1])[2] - signed_logsumexp_arb([0.2, -3., 1], [1, 1, -1])[2]

[3.5229451118748631580263821001641342027628089453259875577554e-17 +/- 3.98e-76]

In [92]:
signed_logsumexp_kahan([0.2, -3., 1], [1, 1, -1])[2] - signed_logsumexp_arb([0.2, -3., 1], [1, 1, -1])[2]

[3.5229451118748631580263821001641342027628089453259875577554e-17 +/- 3.98e-76]

In [93]:
ExactWrightFisher.S_kvec_M_both_logsumexp([1,3,50, 10], 0.0002, 0.075)
function S_kvec_M_plus_logsum_test(kvec::Array{T, 1}, t::Real, θ::Real, logsumexpfun) where
T<:Integer
  M = length(kvec)

  two_kvec_plus_1 = sum(2*kvec .+ 1)

  U = typeof(t)

  logterms = Array{U}(undef, two_kvec_plus_1)
  signs = Array{Float64}(undef, two_kvec_plus_1)
  cnt = 1
  for m in 0:(M-1)
    for i in 0:(2*kvec[m+1])
    # for(int i = 0; i <= 2*kvec[m]; ++i) {
      logterms[cnt] = ExactWrightFisher.log_bk_t_θ_t(m+i, t, θ, m);
      signs[cnt] = ExactWrightFisher.minus_1_power_i(i);
      cnt += 1;
    end
  end
  return logsumexpfun(logterms, signs);
end
S_kvec_M_plus_logsum_arb(kvec::Array{T, 1}, t::Real, θ::Real) where
T<:Integer = S_kvec_M_plus_logsum_test(kvec, t, θ, signed_logsumexp_arb)
S_kvec_M_plus_logsum_naive(kvec::Array{T, 1}, t::Real, θ::Real) where
T<:Integer = S_kvec_M_plus_logsum_test(kvec, t, θ, signed_logsumexp_naive)
S_kvec_M_plus_logsum_kahan(kvec::Array{T, 1}, t::Real, θ::Real) where
T<:Integer = S_kvec_M_plus_logsum_test(kvec, t, θ, signed_logsumexp_kahan)
S_kvec_M_plus_logsum_kahan_naive(kvec::Array{T, 1}, t::Real, θ::Real) where
T<:Integer = S_kvec_M_plus_logsum_test(kvec, t, θ, signed_logsumexp_kahan_naive)

S_kvec_M_plus_logsum_kahan_naive (generic function with 1 method)

In [94]:
ExactWrightFisher.S_kvec_M_plus_logsum([1,3,50, 10], 0.0002, 0.075)  .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                                
  [-1.181392746419046249418274994134365143118042102038545823784209e-14 +/- 2.10e-75]

In [95]:
S_kvec_M_plus_logsum_naive([1,3,50, 10], 0.0002, 0.075) .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                             
  [6.2057041161129075856192474127353606881957897961454176215791e-16 +/- 2.10e-75]

In [96]:
S_kvec_M_plus_logsum_kahan([1,3,50, 10], 0.0002, 0.075) .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                                
  [-1.181392746419046249418274994134365143118042102038545823784209e-14 +/- 2.10e-75]

In [97]:
S_kvec_M_plus_logsum_kahan_naive([1,3,50, 10], 0.0002, 0.075) .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                             
  [6.2057041161129075856192474127353606881957897961454176215791e-16 +/- 2.10e-75]

In [98]:
ExactWrightFisher.S_kvec_M_plus_logsum([1,3,50, 10], 0.0002 |> BigFloat, 0.075 |> BigFloat) .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                               
  [4.450133000774413606158020874454439050366619305802873916556582e-14 +/- 4.04e-75]

In [99]:
S_kvec_M_plus_logsum_naive([1,3,50, 10], 0.0002 |> BigFloat, 0.075 |> BigFloat) .- S_kvec_M_plus_logsum_arb([1,3,50, 10], 0.0002, 0.075)

2-element Array{Any,1}:
 0.0                                                                                
  [4.4501330007744136061580208744544390503666193058028739165565816e-14 +/- 7.04e-76]

In [105]:
function S_kvec_M_both_logsumexp_kahan(kvec::Array{T, 1}, t::Real, θ::Real) where
T<:Integer
  logS_kvec_M_plus_res = S_kvec_M_plus_logsum_kahan(kvec, t, θ)
  sgn_logS_kvec_M_plus_res = logS_kvec_M_plus_res[1]
  sum_logS_kvec_M_plus_res = logS_kvec_M_plus_res[2]

  log_newterms = ExactWrightFisher.S_kvec_M_minus_log_newterms(kvec, t, θ)
  logsum_newterms =  signed_logsumexp_kahan(log_newterms, repeat([1.], length(log_newterms)))[2]

  return ExactWrightFisher.S_kvec_M_both_logsumexp_inner(kvec, t, θ, logS_kvec_M_plus_res, log_newterms, logsum_newterms)
end

S_kvec_M_both_logsumexp_kahan (generic function with 1 method)

In [106]:
 ExactWrightFisher.S_kvec_M_both_logsumexp([1,3,50, 10], 0.0002, 0.075)

2-element Array{Float64,1}:
 -999287.7254125979
  855577.1585386458

In [107]:
S_kvec_M_both_logsumexp_kahan([1,3,50, 10], 0.0002, 0.075)

2-element Array{Float64,1}:
 -999287.7254125979
  855577.1585386458

In [108]:

function Compute_A∞_given_U_debug(θ, t, U, m, kvec; start_print_from = 32, S_kvec_M_both_logsumexp_fun = ExactWrightFisher.S_kvec_M_both_logsumexp)
  ### 0 indexing to stick with the article's notation
    n=0
  while true
#     n = n+1
#     print(n)
#     if (n > 100){
#       return(m)
# #       break()
#     }
        if m >= start_print_from
            println(m)
        end
        kvec[m+1] = ceil(ExactWrightFisher.C_m_t_θ(m, t, θ)/2)
        #     print(km)
        S_kvec_M_BOTH = S_kvec_M_both_logsumexp_fun(kvec, t, θ)
    while (S_kvec_M_BOTH[1] < U) && (S_kvec_M_BOTH[2] > U)
        kvec = kvec .+ 1
            #       println(kvec)
        S_kvec_M_BOTH = S_kvec_M_both_logsumexp_fun(kvec, t, θ)
        if m >= start_print_from
            println(S_kvec_M_BOTH)
        end
    end
    if S_kvec_M_BOTH[1] > U
      println("A∞ = $m")
      return m
      # break()
    elseif (S_kvec_M_BOTH[2] < U)
      push!(kvec,0)
      m = m + 1
    end
  end
end

function Compute_A∞_given_U_arb_debug(θ, t, U, m, kvec; start_print_from = 32)
  ### 0 indexing to stick with the article's notation
    n=0
  while true
#     n = n+1
#     print(n)
#     if (n > 100){
#       return(m)
# #       break()
#     }
        if m >= start_print_from
            println(m)
        end
        kvec[m+1] = ceil(ExactWrightFisher.C_m_t_θ(m, t, θ)/2)
        #     print(km)
        S_kvec_M_BOTH = ExactWrightFisher.S_kvec_M_both_logsumexp_arb(kvec, t, θ)
    while (S_kvec_M_BOTH[1] < U) && (S_kvec_M_BOTH[2] > U)
        kvec = kvec .+ 1
        #       println(kvec)
        S_kvec_M_BOTH = ExactWrightFisher.S_kvec_M_both_logsumexp_arb(kvec, t, θ)
        if m >= start_print_from
            println(S_kvec_M_BOTH)
        end
    end
    if S_kvec_M_BOTH[1] > U
      println("A∞ = $m")
      return m
      # break()
    elseif (S_kvec_M_BOTH[2] < U)
      push!(kvec,0)
      m = m + 1
    end
  end
end

Compute_A∞_given_U_arb_debug (generic function with 1 method)

In [109]:
Random.seed!(0);
U = rand(Uniform())

0.8236475079774124

In [110]:
Compute_A∞_given_U_debug(sum(1:4), 0.05, U, 0, [0]; start_print_from = 35)

35
[-1.98681e8, 3.17618e8]
[-5.37652e7, 1.09657e8]
[-9.26533e6, 2.35536e7]
[-1.05549e6, 3.28686e6]
[-81788.0, 3.07845e5]
[-4408.53, 19847.2]
[-167.782, 899.304]
[-4.0237, 29.6581]
[0.461086, 1.23361]
[0.552392, 0.565406]
36
[-7.85502e7, 1.29293e8]
[-2.00415e7, 4.21e7]
[-3.25421e6, 8.52315e6]
[-3.49069e5, 1.12031e6]
[-25452.1, 98768.8]
[-1289.61, 5990.59]
[-45.5994, 255.693]
[-0.50923, 8.43057]
[0.681288, 0.8739]
[0.665231, 0.668277]
37
[-2.93485e7, 4.97617e7]
[-7.05363e6, 1.52677e7]
[-1.07822e6, 2.91076e6]
[-108814.0, 3.60074e5]
[-7459.51, 29857.9]
[-354.53, 1702.84]
[-11.2023, 68.8355]
[0.493801, 2.72463]
[0.828819, 0.873938]
A∞ = 37


37

In [111]:
Compute_A∞_given_U_debug(sum(1:4), 0.05, U, 0, [0]; start_print_from = 35, S_kvec_M_both_logsumexp_fun = S_kvec_M_both_logsumexp_kahan)

35
[-1.98681e8, 3.17618e8]
[-5.37652e7, 1.09657e8]
[-9.26533e6, 2.35536e7]
[-1.05549e6, 3.28686e6]
[-81788.0, 3.07845e5]
[-4408.48, 19847.3]
[-167.782, 899.304]
[-4.0237, 29.6581]
[0.512486, 1.28501]
[0.603792, 0.616807]
36
[-7.85502e7, 1.29293e8]
[-2.00415e7, 4.21e7]
[-3.25421e6, 8.52315e6]
[-3.49069e5, 1.12031e6]
[-25452.1, 98768.8]
[-1289.61, 5990.59]
[-45.5994, 255.693]
[-0.50923, 8.43057]
[0.681288, 0.8739]
[0.716637, 0.719683]
37
[-2.93485e7, 4.97617e7]
[-7.05363e6, 1.52677e7]
[-1.07822e6, 2.91076e6]
[-108814.0, 3.60074e5]
[-7459.46, 29857.9]
[-354.582, 1702.79]
[-11.2023, 68.8355]
[0.493801, 2.72463]
[0.828819, 0.873938]
A∞ = 37


37

In [103]:
Compute_A∞_given_U_arb_debug(sum(1:4), 0.05, U, 0, [0]; start_print_from = 35)

35
arb[[-198680896.946251151628310750515899244503643218755671309025196230351513739 +/- 5.81e-64], [317617718.315217425111758365367994390006956644429596731472376278002505047 +/- 6.90e-64]]
arb[[-53765238.389839813348973146760428530244230749519917916080625630897185594 +/- 7.31e-64], [109657170.241881014131926304954049855445558491051173587818534404514663919 +/- 7.33e-64]]
arb[[-9265327.593252635594179936881559471297163999286528789098657169825795648 +/- 8.65e-64], [23553631.984817601659307853300472156942607565416978962689977944243190421 +/- 6.97e-64]]
arb[[-1055493.57000013286511777054911118195393692335786838622612477609145224 +/- 3.79e-63], [3286857.52489748342376955742005826767943393913743621634651547227731861 +/- 6.77e-63]]
arb[[-81787.986003912952717137876933426191867193459482170635620347272075391 +/- 7.49e-64], [307845.203319719592303181662672180440158947672676179421873390789762318 +/- 4.38e-64]]
arb[[-4408.50734439478223734160645589069774276166454205737596774867177173 +/- 4.82e-63], 

38

In [None]:
ExactWrightFisher.Compute_A∞_given_U_arb(sum(1:4), 0.001, U, 0, [0])