In [3]:
include("ParRep.jl"); using .ParRep



In [4]:
"""
    GelmanRubinDiagnostic(; observables, burn_in=10, tol=1e-4)

Mutable state for a Gelman–Rubin-style dephasing check across parallel replicas,
consumed by `check_dephasing!`.

# Fields
- `observables`: iterable collection of functions `f(replica)`; each result must be a
  scalar convertible to `Float64` (it is accumulated into `Float64` buffers).
- `means`, `sq_means`: running per-replica means of each observable and of its square,
  sized `(length(observables), length(replicas))`; reset to zeros by `check_dephasing!`
  whenever it is called with `step_n == 1`.
- `burn_in`: no dephasing verdict is produced while `step_n < burn_in`.
- `tol`: dephasing is declared once the diagnostic drops below this value.
- `gr_hist`: diagnostic values, one entry per post-burn-in call.
"""
Base.@kwdef mutable struct GelmanRubinDiagnostic{F}
    observables::F
    # Concretely typed so the threaded update loop is type-stable.
    means::Matrix{Float64} = Matrix{Float64}(undef, 1, 1)
    sq_means::Matrix{Float64} = Matrix{Float64}(undef, 1, 1)
    burn_in = 10
    tol = 1e-4
    gr_hist::Vector{Float64} = Float64[]
end

"""
    check_dephasing!(checker::GelmanRubinDiagnostic, replicas, step_n) -> Bool

Fold the current `replicas` into the running statistics of `checker` and report whether
the replicas have dephased.

Every observable `f` in `checker.observables` is evaluated on every replica. Its running
mean and running mean-square are updated with the incremental formula
`m += (x - m) / step_n`, so `step_n` must count calls starting at 1. A call with
`step_n == 1` resets the buffers to zeros.

While `step_n < checker.burn_in`, returns `false` and records nothing. Afterwards it computes

    R = max_j ( Σ_i E_i[(O_j - Ō_j)^2] / Σ_i (E_i[O_j^2] - E_i[O_j]^2) ) - 1

where `E_i` is the running average on replica `i` and `Ō_j` is the average over replicas
of the per-replica means. It appends `R` to `checker.gr_hist` and returns
`R < checker.tol`.
"""
function check_dephasing!(
    checker::GelmanRubinDiagnostic, replicas::Vector{X}, step_n
) where {X}
    nrep = length(replicas)
    if step_n == 1 # (re)initialize running mean and running square-mean buffers
        checker.means = zeros(length(checker.observables), nrep)
        checker.sq_means = copy(checker.means)
    end

    # Local bindings (never reassigned) keep the threaded closure free of boxing.
    means = checker.means
    sq_means = checker.sq_means
    observables = checker.observables

    # Iteration `i` writes only column `i`, so the tasks never touch the same entries.
    Threads.@threads for i in 1:nrep
        r = replicas[i]
        for (j, f) in enumerate(observables)
            val = f(r)
            sq_val = val^2
            means[j, i] += (val - means[j, i]) / step_n
            sq_means[j, i] += (sq_val - sq_means[j, i]) / step_n
        end
    end

    step_n < checker.burn_in && return false

    Obar = sum(means; dims=2) / nrep # (n_observables × 1) grand mean

    # E_i[(O - Obar)^2] expanded as E_i[O^2] - 2 E_i[O] Obar + Obar^2
    dev2 = @. sq_means - 2 * means * Obar + Obar^2
    total_var = sum(dev2; dims=2)
    within_var = sum(sq_means .- means .^ 2; dims=2)

    R = maximum(total_var ./ within_var) - 1
    push!(checker.gr_hist, R)
    return R < checker.tol
end


In [6]:
"""
    SteepestDescentState{X}(; ∇V, η=0.1, dist_tol=1e-1, grad_tol=1e-5, steps=100, minima=X[])

Configuration and memory for labelling states by the minimum that steepest descent
(`x ← x - η * ∇V(x)`) leads them to.

# Fields
- `η`: step size of each gradient-descent update.
- `dist_tol`: a descent trajectory is assigned to a known minimum as soon as its
  Euclidean distance to that minimum is below this value.
- `grad_tol`: descent stops, and the end point is recorded as a new minimum, once the
  gradient norm falls below this value.
- `steps`: maximum number of descent steps per `get_state!` call.
- `minima`: minima discovered so far; their indices are the returned state labels.
- `∇V`: function returning the gradient at a point of type `X`.
"""
Base.@kwdef mutable struct SteepestDescentState{X}
    η = 0.1
    dist_tol = 1e-1
    grad_tol = 1e-5
    steps = 100
    minima::Vector{X} = X[]
    ∇V::Function
end

"""
    get_state!(checker::SteepestDescentState, state, _) -> Union{Int,Nothing}

Run up to `checker.steps` steepest-descent steps from `state` and return the index in
`checker.minima` of the minimum it reaches.

- If, after any step, the iterate is within `checker.dist_tol` of a known minimum, the
  index of the first such minimum is returned.
- Otherwise, if the gradient norm fell below `checker.grad_tol`, the final iterate is
  appended to `checker.minima` and its new index is returned.
- Otherwise (step budget exhausted, or `steps == 0`), `nothing` is returned.

`state` itself is not mutated (each step rebinds it); only `checker.minima` may be.
The third argument is ignored. NOTE(review): presumably it is kept for a call signature
shared with other checkers; confirm. Concurrent calls on one `checker` would race on
`push!(checker.minima, …)`.
"""
function get_state!(checker::SteepestDescentState, state::X, _) where {X}
    # Defined up front so the convergence test after the loop is valid even if no
    # step runs (`steps == 0`).
    grad_norm = Inf
    for _ in 1:checker.steps
        grad = checker.∇V(state)
        state -= checker.η * grad # gradient descent step

        # Early exit: the trajectory entered the basin of an already known minimum.
        for (i, m) in enumerate(checker.minima)
            sqrt(sum(abs2, m - state)) < checker.dist_tol && return i
        end

        # Norm of the gradient evaluated before this step's update.
        grad_norm = sqrt(sum(abs2, grad))
        grad_norm < checker.grad_tol && break
    end

    grad_norm < checker.grad_tol || return nothing
    push!(checker.minima, state)
    return length(checker.minima)
end

get_state! (generic function with 1 method)