In [None]:
"""
    proba_1_point(seq, h, J, N; q = 21, T = 1)

Return a `q × N` matrix of single-site probabilities given the sequence `seq`.

For every `site` in `1:N` and every state `a` in `1:q` the unnormalised weight is
`exp((h[a, site] + Σ_j J[a, seq[j], site, j]) / T)`; each column is then divided by
its own sum, so every column of the result sums to one.

# Arguments
- `seq`: indexable collection; `seq[j]` is used as the second index of `J`.
- `h`: array indexed as `h[a, site]`.
- `J`: array indexed as `J[a, b, site, j]`.
- `N`: number of sites that are looped over.

# Keywords
- `q`: number of states per site.
- `T`: divisor applied to the summed terms before exponentiation.

The sum over `j` covers all of `1:N`, including `j == site`.
"""
function proba_1_point(seq, h, J, N; q = 21, T = 1)
    prob = zeros(q, N)
    for site in 1:N
        # Start as Float64 so the accumulator never changes type inside the loop.
        norm = 0.0
        for a in 1:q
            log_proba = h[a, site]
            for j in 1:N
                log_proba += J[a, seq[j], site, j]
            end
            prob[a, site] = exp(log_proba / T)
            norm += prob[a, site]
        end

        # Normalise the column belonging to this site.
        for a in 1:q
            prob[a, site] /= norm
        end
    end
    return prob
end

"""
    single_proba_2_point(a, b, i, j, seq, h, J, N, T)

Return the unnormalised weight `exp(E / T)` for state `a` at site `i` together with
state `b` at site `j`, where

    E = h[a, i] + h[b, j] + J[a, b, i, j]
        + Σ_{k ≠ j} J[a, seq[k], i, k]
        + Σ_{l ≠ i} J[seq[l], b, l, j]

with `k` and `l` running over `1:N`.
"""
function single_proba_2_point(a, b, i, j, seq, h, J, N, T)
    log_proba = h[a, i] + h[b, j] + J[a, b, i, j]

    # Compare by value (`!=`): the identity test `!==` also compares the integer
    # type, so an Int32 site index would never be recognised as equal to the Int
    # loop variable and the excluded term would be counted after all.
    for k in 1:N
        if k != j
            log_proba += J[a, seq[k], i, k]
        end
    end

    for l in 1:N
        if l != i
            log_proba += J[seq[l], b, l, j]
        end
    end
    return exp(log_proba / T)
end


"""
    dist_proba_2_point(i, j, seq, h, J, N, q, T)

Return the `q × q` table whose entry `[a, b]` is
`single_proba_2_point(a, b, i, j, seq, h, J, N, T)` divided by the sum of all entries.
"""
function dist_proba_2_point(i, j, seq, h, J, N, q, T)
    table = zeros(q, q)
    total = 0.0
    # `b` is the fastest-varying index; the running total is accumulated in that order.
    for a in 1:q, b in 1:q
        table[a, b] = single_proba_2_point(a, b, i, j, seq, h, J, N, T)
        total += table[a, b]
    end
    return table / total
end


"""
    proba_2_point(seq, h, J, N; q = 21, T = 1)

Return a `q × N × q × N` array whose slice `[:, i, :, j]` is the normalised `q × q`
table produced by `dist_proba_2_point(i, j, seq, h, J, N, q, T)` for every pair of
sites `i`, `j` in `1:N`.

# Keywords
- `q`: number of states per site.
- `T`: forwarded unchanged to `dist_proba_2_point`.
"""
function proba_2_point(seq, h, J, N; q = 21, T = 1)
    # Size and loop bounds come from the argument `N`; the function no longer
    # depends on a global named `L` being defined.
    res = zeros(q, N, q, N)
    for i in 1:N
        for j in 1:N
            res[:, i, :, j] = dist_proba_2_point(i, j, seq, h, J, N, q, T)
        end
    end
    return res
end



## Dkl order

In [None]:
"""
    loc_softmax(x, temp)

Return an array with the same size as `x` containing `exp.(x .- maximum(x))` raised
to the power `1 / temp` and then divided by its sum.

When the maximum of `x` is `Inf`, entries equal to `Inf` get weight one and all other
entries weight zero before the power and the normalisation are applied.
"""
function loc_softmax(x, temp)
    peak = maximum(x)
    result = zeros(size(x))
    if !all(isfinite, peak)
        @fastmath @. result = ifelse(isequal(peak, Inf), ifelse(isequal(x, Inf), 1, 0), exp(x - peak))
    else
        # Shift by the maximum before exponentiating.
        @fastmath result .= exp.(x .- peak)
    end
    exponent = 1 / temp
    result .= result .^ exponent
    total = sum(result)
    result ./= total
    return result
end

"""
    DklAdd(site, seq, h, J, L; version = "Leo")

Greedily order the sites other than `site` by Kullback-Leibler divergence.

The reference distribution starts as the normalised `exp.(h[:, site])`. In each round
every site `k` that is neither `site` nor already selected is tried: the vector
returned by `proba_DNA_gibbs_masked` for the already selected sites plus `k` is
compared with the reference through `kldivergence`, and the site with the largest
value is selected. With `version == "Leo"` the reference is replaced by the
distribution of the site just selected; otherwise it keeps its initial value.

Returns `(ord, cdes)`: the selected sites in order of selection and, for each of
them, the first element of `get_entropy` applied to its distribution.
"""
function DklAdd(site, seq, h, J, L; version = "Leo")
    cond = zeros(L, 21)
    field = exp.(h[:, site])
    ref = field ./ sum(field)
    gain = zeros(L)
    chosen = []
    entropies = []

    # First round: a single candidate site, passed with the scalar `seq[k]`.
    for k in 1:L
        k === site && continue
        cond[k, :] = proba_DNA_gibbs_masked(site, [k], seq[k], h, J)
        gain[k] = kldivergence(cond[k, :], ref)
    end
    best = argmax(gain)
    push!(chosen, best)
    push!(entropies, get_entropy(cond[best, :])[1])
    if version == "Leo"
        ref .= cond[best, :]
    end
    fill!(cond, 0)
    fill!(gain, 0)

    # Remaining rounds: each candidate is appended to the sites selected so far.
    for _ in 2:(L - 1)
        for k in 1:L
            (k === site || k in chosen) && continue
            trial = vcat(chosen, k)
            cond[k, :] = proba_DNA_gibbs_masked(site, trial, seq[trial], h, J)
            gain[k] = kldivergence(cond[k, :], ref)
        end
        best = argmax(gain)
        push!(chosen, best)
        push!(entropies, get_entropy(cond[best, :])[1])
        if version == "Leo"
            ref .= cond[best, :]
        end
        fill!(cond, 0)
        fill!(gain, 0)
    end
    return chosen, entropies
end


"""
    cde_decrease(site, seq, h, J, L; version = "Leo")

Greedily order the sites other than `site` by lowest entropy.

In each round every site `k` that is neither `site` nor already selected is tried:
the vector returned by `proba_DNA_gibbs_masked` for the already selected sites plus
`k` is scored with the first element of `get_entropy`, and the site with the smallest
score is selected. `site` and the already selected sites are given the score `10` so
that they are not picked (assumes real scores stay below 10 — TODO confirm).

Returns `(ps, cdes)`: an `(L - 1) × 21` matrix whose row `n` is the distribution of
the site selected in round `n`, and the vector of the minimal scores of each round.

The keyword `version` is accepted but not used.
"""
function cde_decrease(site, seq, h, J, L; version = "Leo")
    cond = zeros(L, 21)
    score = zeros(L)
    chosen = []
    ps = zeros(L - 1, 21)
    cdes = []

    # First round: a single candidate site, passed with the scalar `seq[k]`.
    for k in 1:L
        k === site && continue
        cond[k, :] = proba_DNA_gibbs_masked(site, [k], seq[k], h, J)
        score[k] = get_entropy(cond[k, :])[1]
    end
    score[site] = 10
    score[chosen] .= 10
    pick = argmin(score)
    push!(chosen, pick)
    push!(cdes, minimum(score))
    ps[1, :] .= cond[pick, :]
    fill!(cond, 0)
    fill!(score, 0)

    # Remaining rounds: each candidate is appended to the sites selected so far.
    for n in 2:(L - 1)
        for k in 1:L
            (k === site || k in chosen) && continue
            trial = vcat(chosen, k)
            cond[k, :] = proba_DNA_gibbs_masked(site, trial, seq[trial], h, J)
            score[k] = get_entropy(cond[k, :])[1]
        end
        score[site] = 10
        score[chosen] .= 10
        pick = argmin(score)
        push!(chosen, pick)
        push!(cdes, minimum(score))
        ps[n, :] .= cond[pick, :]
        fill!(cond, 0)
        fill!(score, 0)
    end
    return ps, cdes
end
    
"""
    pathDkl(final_aa, site, seq, h, J, L; mult_loss = true, best = true, temp = 0.1, reg_E = 0.1)

Build a path of `L - 1` single-site substitutions starting from `seq`.

The target is a distribution over 21 states with all mass on `final_aa`, smoothed by
`pseudocount1!`. At every step each substitution `(k, aa)` with `k != site` is scored
by

    kldivergence(p, target) + reg_E * dE^2

where `p` is the (smoothed) output of `proba_DNA_gibbs_without_deg` for the mutated
sequence and `dE` is `delta_energy` of the mutated sequence relative to `seq`
(`mult_loss == true`) or relative to the previous sequence of the path (otherwise).
Row `site` keeps the initial score of `10^8`.

# Keywords
- `mult_loss`: selects the reference sequence of `dE` (see above).
- `best`: if `true` the substitution with the smallest score is taken; otherwise one
  is drawn with `sample`, using `loc_softmax(-score, temp)` as weights.
- `temp`: temperature passed to `loc_softmax` when sampling.
- `reg_E`: factor in front of the squared energy term.

# Returns
`(seqs, cdes, aas, sites, single_dEs, wt_dEs)` where
- `seqs` is an `L × L` `Int` matrix; column 1 is `seq`, column `n` the sequence after
  `n - 1` substitutions;
- `cdes` holds `cde_1site(site, seqs[:, n], h, J)[1]` for every column;
- `aas`, `sites` are the state and the site chosen at every step;
- `single_dEs`, `wt_dEs` are `delta_energy` of each new sequence relative to the
  previous one and relative to `seq`.
"""
function pathDkl(final_aa, site, seq, h, J, L; mult_loss = true, best = true, temp = 0.1, reg_E = 0.1)
    p = zeros(21)
    q = zeros(21)
    q[final_aa] = 1.
    pseudocount1!(q,q,10^-4,21)
    dkl = 10^8 .* ones(L, 21)
    dE = 10^8 .* ones(L, 21)
    wt_dEs = []
    single_dEs = []
    aas = []
    sites = []

    seqs = zeros(Int, L, L)
    seqs[:,1] .= seq
    cdes = [cde_1site(site, seq, h, J)[1]]

    for n in 2:L
        prev = seqs[:,n-1]
        # Reference sequence for the energy term of this step.
        ref = mult_loss == true ? seq : prev
        for k in 1:L
            if k != site
                for aa in 1:21
                    # `prev` is already a private copy of the column; a plain
                    # `copy` is enough to get a scratch sequence per candidate.
                    new_seq = copy(prev)
                    new_seq[k] = aa
                    p .= proba_DNA_gibbs_without_deg(site,new_seq,h,J,L)
                    pseudocount1!(p,p,10^-4,21)
                    dkl[k,aa] = kldivergence(p, q)
                    dE[k,aa] = delta_energy(h,J, new_seq, ref)
                end
            end
        end

        cost = dkl .+ (reg_E .* dE .* dE)
        if best == true
            i, amino = Tuple(argmin(cost))
        else
            idx = sample(1:length(cost), weights(loc_softmax(.-cost, temp)))
            # Decode the column-major linear index into (site, state). The former
            # manual decoding hard-coded row 76 for the last row and returned a
            # state one too large whenever `idx` was a multiple of `L`.
            i, amino = Tuple(CartesianIndices(cost)[idx])
        end
        seqs[:,n] .= prev
        seqs[i, n] = amino
        push!(cdes, cde_1site(site,seqs[:,n], h, J)[1])
        push!(aas, amino)
        push!(sites,i)
        push!(single_dEs, delta_energy(h, J, seqs[:,n], seqs[:,n-1]))
        push!(wt_dEs, delta_energy(h, J, seqs[:,n], seq))
        dkl .= 10^8
    end

    return seqs, cdes, aas, sites, single_dEs, wt_dEs
end

"""
    pseudocount1!(dest, f1, pc::AbstractFloat, q::Int)

Overwrite `dest` with `(1 - pc) .* f1 .+ pc / q` and return `dest`.
`dest` and `f1` may be the same array.
"""
function pseudocount1!(dest, f1, pc::AbstractFloat, q::Int)
    keep = 1 - pc
    uniform = pc / q
    dest .= (keep .* f1) .+ uniform
    return dest
end
