### 座標とポテンシャルエネルギーを訓練データとして、エネルギー関数のパラメータ($\epsilon$と$\sigma$)を学習する

### 必要なパッケージを読み込む

In [None]:
using Flux, Plots
using ChainRulesCore
using Statistics
using Random
using Distributions
using MDToolbox

### ポテンシャルエネルギー関数の定義

In [None]:
"""
    compute_energy(coords1, coords2, epsilon, sigma)

Return the pair potential energy
`epsilon * ((sigma / r)^12 - 2 * (sigma / r)^6)`, where `r` is the Euclidean
distance between `coords1` and `coords2`.

With this functional form the energy equals `-epsilon` when `r == sigma`.
"""
function compute_energy(coords1, coords2, epsilon, sigma)
    displacement = coords1 .- coords2
    r = sqrt(sum(displacement .^ 2))
    ratio = sigma / r
    return epsilon * (ratio^12 - 2 * ratio^6)
end

In [None]:
# Hand-written reverse-mode rule for `compute_energy`.
#
# Returns the energy `U` together with a pullback that maps the incoming
# cotangent `dU` to the tangents of (function, coords1, coords2, epsilon, sigma).
# The analytic derivatives follow from U = epsilon * ((sigma/r)^12 - 2 (sigma/r)^6):
#   dU/dcoords1 = -12 epsilon (sigma^12 / r^14 - sigma^6 / r^8) * (coords1 - coords2)
#   dU/dcoords2 = -dU/dcoords1
#   dU/depsilon = (sigma/r)^12 - 2 (sigma/r)^6
#   dU/dsigma   = 12 epsilon (sigma^11 / r^12 - sigma^5 / r^6)
function ChainRulesCore.rrule(::typeof(compute_energy), coords1, coords2, epsilon, sigma)
    dx = coords1 .- coords2
    r = sqrt(sum(dx .^ 2))
    # Dimensionless bracket; it is also the derivative of U with respect to epsilon.
    bracket = (sigma / r)^12 - 2 * (sigma / r)^6
    U = epsilon * bracket

    function compute_energy_pullback(dU)
        # Scalar prefactor multiplying the displacement vector in dU/dcoords1.
        radial = -12 * epsilon * ((sigma^12 / r^14) - (sigma^6 / r^8))
        dcoords1 = radial .* dx .* dU
        dcoords2 = -dcoords1
        depsilon = bracket * dU
        dsigma = 12 * epsilon * (sigma^11 / r^12 - sigma^5 / r^6) * dU
        # The function object itself carries no tangent.
        return NoTangent(), dcoords1, dcoords2, depsilon, dsigma
    end

    return U, compute_energy_pullback
end

In [None]:
"""
    compute_free_energy(coords1_trj, coords2_trj, beta, epsilon, sigma)

Compute a free energy `F` from a two-atom trajectory.

The pair energy of every frame (frames are paired with `zip`) is evaluated with
`compute_energy`, and `F = (-1 / beta) * (log(1 / N) + logsumexp(-beta .* U))`
is returned, where `N = size(coords1_trj, 1)`.

This mirrors the direct formula `(-1 / beta) * log((1 / N) * sum(exp.(-beta .* U)))`,
assuming `MDToolbox.logsumexp_1d(x)` returns `log(sum(exp.(x)))` -- TODO confirm.
"""
function compute_free_energy(coords1_trj, coords2_trj, beta, epsilon, sigma)
    N = size(coords1_trj, 1)
    U_array = map(zip(coords1_trj, coords2_trj)) do (coords1, coords2)
        compute_energy(coords1, coords2, epsilon, sigma)
    end
    log_mean_weight = log(1 / N) + MDToolbox.logsumexp_1d(-beta .* U_array)
    return (-1 / beta) * log_mean_weight
end

In [None]:
"""
    compute_deltaF(coords_trj, beta, epsilon, sigma, U_array)

Free-energy difference by exponential averaging.

Each element of `coords_trj` is destructured into a `(coords1, coords2)` pair and
its energy under (`epsilon`, `sigma`) is evaluated with `compute_energy`.
`U_array` holds the energies the same frames are compared against; the result is
`(-1 / beta) * (log(1 / N) + logsumexp(-beta .* (U_target .- U_array)))` with
`N = size(coords_trj, 1)`.
"""
function compute_deltaF(coords_trj, beta, epsilon, sigma, U_array)
    N = size(coords_trj, 1)
    U_target_array = map(coords_trj) do (coords1, coords2)
        compute_energy(coords1, coords2, epsilon, sigma)
    end
    deltaU = U_target_array .- U_array
    log_mean_weight = log(1 / N) + MDToolbox.logsumexp_1d(-beta .* deltaU)
    return (-1 / beta) * log_mean_weight
end

In [None]:
"""
    compute_distance(coords1, coords2)

Return the Euclidean distance between `coords1` and `coords2`.
"""
function compute_distance(coords1, coords2)
    displacement = coords1 .- coords2
    return sqrt(sum(displacement .^ 2))
end

### MCMCの定義

In [None]:
"""
    next_coords(coords1, coords2, delta_x = 1e-3 * 5)

Propose candidate coordinates for the next Monte Carlo step.

Each input is shifted by an independent draw of `(rand(3) .- 0.5) .* delta_x`,
i.e. every component moves by at most `delta_x / 2`. The random numbers for
`coords1` are drawn before those for `coords2`. The inputs are not modified.
"""
function next_coords(coords1, coords2, delta_x = 1e-3 * 5)
    perturb(coords) = coords .+ (rand(3) .- 0.5) .* delta_x
    next_coords1 = perturb(coords1)
    next_coords2 = perturb(coords2)
    return next_coords1, next_coords2
end

In [None]:
"""
    mcmc(coords1, coords2, epsilon, sigma, nstep = 100, beta = 1.0)

Sample a two-atom system with the Metropolis algorithm.

At every step a candidate pair is generated with `next_coords`, its energy is
evaluated with `compute_energy`, and the move is accepted when
`exp(-beta * (U_candidate - U_current)) > rand()`.

Returns `(coords1_trj, coords2_trj, alpha_trj)`:
- `coords1_trj`, `coords2_trj`: `nstep + 1` frames each; frame 1 is the initial
  configuration, frame `i + 1` is the state after step `i`.
- `alpha_trj`: the `nstep` relative probabilities `exp(-beta * delta_U)`.

The caller's `coords1` / `coords2` are left untouched: the chain runs on private
copies, so the stored initial frame can no longer be overwritten by later
accepted moves.
"""
function mcmc(coords1, coords2, epsilon, sigma, nstep = 100, beta = 1.0)
    # Private working state; every stored frame is an independent copy of it.
    current1 = copy(coords1)
    current2 = copy(coords2)
    # The current energy only changes on acceptance, so it is cached between steps.
    current_U = compute_energy(current1, current2, epsilon, sigma)

    coords1_trj = [copy(current1)]
    coords2_trj = [copy(current2)]
    # Element type taken from an actual acceptance-ratio expression.
    alpha_trj = typeof(exp(-beta * current_U))[]

    for _ in 1:nstep
        next_coords1, next_coords2 = next_coords(current1, current2) # candidate
        next_U = compute_energy(next_coords1, next_coords2, epsilon, sigma)
        alpha = exp(-beta * (next_U - current_U)) # relative probability

        # Accept always when alpha >= 1, otherwise with probability alpha.
        if alpha > rand()
            current1 .= next_coords1
            current2 .= next_coords2
            current_U = next_U
        end
        push!(coords1_trj, copy(current1))
        push!(coords2_trj, copy(current2))
        push!(alpha_trj, alpha)
    end
    return coords1_trj, coords2_trj, alpha_trj
end

### ΔFを求める

In [None]:
# Thermodynamic settings and chain length.
kBT = 0.1
beta = 1.0 / kBT
nstep = 100_000

Random.seed!(11)
K = 3

# Parameters of the K sampled states.
# (An evenly spaced alternative that was tried:
#  epsilon[i] = 1.0 + 0.02 * (i - 1), sigma[i] = 4.0 + 0.02 * (i - 1).)
epsilon = [1.0, 1.2, 1.1]
sigma = [4.0, 4.2, 4.1]

# One MCMC trajectory per state; the two atoms start from random positions,
# the second one shifted by 2 along every axis.
trj1 = Vector{Any}(undef, K)
trj2 = Vector{Any}(undef, K)
for k in 1:K
    start1 = rand(3)
    start2 = rand(3) .+ 2
    trj1[k], trj2[k], alpha = mcmc(start1, start2, epsilon[k], sigma[k], nstep, beta)
end

In [None]:
# u_kl[k, l]: reduced energies (beta * U) of the frames sampled in state k,
# re-evaluated with the parameters of state l.
u_kl = Array{Any}(undef, (K, K))

for k in 1:K, l in 1:K
    u_kl[k, l] = beta .* compute_energy.(trj1[k], trj2[k], epsilon[l], sigma[l])
end

In [None]:
# Run MBAR on the reduced energies u_kl.
# NOTE(review): `mbar` presumably returns the dimensionless free energies of the
# K sampled states -- confirm against the MDToolbox documentation.
f_k = mbar(u_kl)
# Cell output: f_k scaled by 1 / beta.
(1 ./ beta) .* f_k

In [None]:
# Parameters of the target state that the training should recover.
epsilon_target = 1.1
sigma_target = 4.1

# Reference trajectory sampled directly with the target parameters.
# It continues the current global RNG stream (no re-seeding here).
trj1_target, trj2_target, alpha = mcmc(rand(3), rand(3) .+ 2, epsilon_target, sigma_target, nstep, beta);

### 距離を計算

In [None]:
# Mean pair distance of the first sampled state (epsilon[1], sigma[1]).
@show r1 = mean(compute_distance.(trj1[1], trj2[1])) 
# Mean pair distance of the target trajectory; used later as the training label.
@show r_target = mean(compute_distance.(trj1_target, trj2_target))

In [None]:
"""
    _mbar_weight(u_kl, f_k, u_k=nothing)

Compute per-sample weights for a target state from MBAR results.

# Arguments
- `u_kl`: matrix of vectors; `u_kl[k, l]` holds the reduced energies of the
  samples drawn in state `k`, evaluated in state `l`.
- `f_k`: free energies passed through to `MDToolbox.mbar_log_wi_jn`.
- `u_k`: `u_k[k]` holds the reduced energies of the samples of state `k` in the
  target state. When `nothing`, the target energies are all zero.

# Returns
`w_k::Vector{Vector{Float64}}`, where `w_k[k][n]` is
`exp(log_w[k, n] - s)` and `s` is `MDToolbox.logsumexp_1d` over the log-weights
of all real (non-padding) samples.
NOTE(review): the weights sum to one provided `logsumexp_1d` is a plain
log-sum-exp -- confirm against MDToolbox.
"""
function _mbar_weight(u_kl, f_k, u_k=nothing)
    # K: number of sampled states (umbrella windows)
    K = size(u_kl, 1)

    # N_k[k]: number of samples drawn in state k
    N_k = Int64[length(u_kl[k, 1]) for k in 1:K]
    N_max = maximum(N_k)

    # Array-of-arrays (u_kl) -> dense, zero-padded array (u_kln)
    u_kln = zeros(Float64, K, K, N_max)
    for k in 1:K, l in 1:K
        u_kln[k, l, 1:N_k[k]] .= u_kl[k, l]
    end

    # Target energies (u_k) -> dense, zero-padded array (u_kn).
    # Without u_k the buffer simply stays zero; the previous code re-zeroed
    # row 1 on every iteration instead of row k.
    u_kn = zeros(Float64, K, N_max)
    if u_k !== nothing
        for k in 1:K
            u_kn[k, 1:N_k[k]] .= u_k[k]
        end
    end

    # valid[k, n] is true for real samples (n <= N_k[k]) and false for padding.
    valid = falses(K, N_max)
    for k in 1:K
        valid[k, 1:N_k[k]] .= true
    end

    log_w_kn = MDToolbox.mbar_log_wi_jn(N_k, f_k, u_kln, u_kn, K, N_max)

    # Normalisation constant over the real samples only.
    s = MDToolbox.logsumexp_1d(log_w_kn[valid])

    w_k = Vector{Vector{Float64}}(undef, K)
    for k in 1:K
        w_k[k] = exp.(log_w_kn[k, 1:N_k[k]] .- s)
    end

    return w_k
end


# Reverse-mode rule for `_mbar_weight` (three-argument call only).
#
# Only the target energies `u_k` receive a gradient; `u_kl` and `f_k` are treated
# as constants. For an incoming cotangent `dw_k` the pullback computes
#     T          = sum over all samples of dw_k[i][j] * w_k[i][j]
#     du_k[i][j] = w_k[i][j] * (T - dw_k[i][j])
# NOTE(review): this is the softmax-style derivative that results when each
# log-weight depends on its own target energy as `-u` plus terms independent of
# `u_k` -- confirm against `MDToolbox.mbar_log_wi_jn`.
function ChainRulesCore.rrule(::typeof(_mbar_weight), u_kl, f_k, u_k)
    w_k = _mbar_weight(u_kl, f_k, u_k)

    function _mbar_weight_pullback(dw_k)
        # Scalar contraction of the cotangent with the weights, accumulated
        # window by window, sample by sample.
        T = 0.0
        for i in eachindex(w_k), j in eachindex(w_k[i])
            T += dw_k[i][j] * w_k[i][j]
        end

        # Same vector-of-vectors layout as w_k.
        du_k = similar(w_k)
        for i in eachindex(w_k)
            du_k[i] = [w_k[i][j] * (T - dw_k[i][j]) for j in eachindex(w_k[i])]
        end

        # Tangents in order: (function itself, u_kl, f_k, u_k).
        return NoTangent(), ZeroTangent(), NoTangent(), du_k
    end

    return w_k, _mbar_weight_pullback
end


In [None]:
# Sanity check: reweight the K sampled trajectories to the target parameters and
# compare the reweighted mean distance with the directly sampled one.
u_k = Vector{Any}(undef, K)

# u_k[k]: reduced energies (beta * U) of the frames of state k under the target parameters.
for k = 1:K
    u_k[k] = map(x -> beta * compute_energy(x[1], x[2], epsilon_target, sigma_target), zip(trj1[k], trj2[k])) # TODO: check later
end

#f_target = mbar_f(u_kl, f_k, u_k)

weight_target = _mbar_weight(u_kl, f_k, u_k)

# r[k]: contribution of state k to the weighted mean distance.
r = Vector{Float64}(undef, K)
for k = 1:K
    r[k] = sum(compute_distance.(trj1[k], trj2[k]) .* weight_target[k])
end
# The two printed values should be close if the reweighting works.
@show r_target = mean(compute_distance.(trj1_target, trj2_target))
@show sum(r)

In [None]:
# Compare pair-distance histograms of sampled state k and of the target,
# using every 100th frame of each trajectory.
k = 1
r = compute_distance.(trj1[k][1:100:end], trj2[k][1:100:end])
histogram(r,label="current", c=:blue, alpha=0.5, fill=false, seriestype=:stephist)
r = compute_distance.(trj1_target[1:100:end], trj2_target[1:100:end])
histogram!(r,label="target", c=:red, alpha=0.5, fill=false, seriestype=:stephist)

# Axis limits, labels and figure size for the combined plot.
plot!(#title="MCMC-sampled pair distances", 
    xlim=(3, 5), xlabel="Pair distance r [nm]",
    ylabel="Frequency", size=(400,300), dpi=900)

In [None]:
# Training input: the (atom 1, atom 2) trajectory pair of every sampled state.
# Training label: the mean pair distance of the target trajectory.
X_train = Any[(trj1[k], trj2[k]) for k in 1:K]
y_train = r_target

In [None]:
# Scratch check: broadcast compute_energy over all frames of trajectory i.
# NOTE(review): the parameters are fixed to state 1 (epsilon[1], sigma[1])
# regardless of i -- confirm this is intended.
i = 1
u = compute_energy.(trj1[i], trj2[i], epsilon[1], sigma[1])

### 勾配法で訓練して$\epsilon$と$\sigma$を推定

In [None]:
"""
    compute_weighted_distance(X, f_k, u_kl, beta, epsilon, sigma)

Return the MBAR-reweighted mean pair distance for the parameters
(`epsilon`, `sigma`).

# Arguments
- `X`: indexable collection with one entry per sampled state; `X[k][1]` and
  `X[k][2]` are the trajectories of atom 1 and atom 2 in state `k`.
- `f_k`, `u_kl`: MBAR free energies and reduced-energy matrix of the sampled
  states, forwarded to `_mbar_weight`.
- `beta`: inverse temperature used to turn energies into reduced energies.
- `epsilon`, `sigma`: parameters of the target state.

The target energies are multiplied by `beta` so that they are reduced energies,
consistent with how `u_kl` is built. The trajectories are taken from `X` rather
than from global variables.
"""
function compute_weighted_distance(X, f_k, u_kl, beta, epsilon, sigma)
    # Reduced energies (beta * U) of every stored frame under the target parameters.
    u_k = [beta .* compute_energy.(X[k][1], X[k][2], epsilon, sigma) for k in eachindex(X)]
    weight_target = _mbar_weight(u_kl, f_k, u_k)
    # Weighted sum of the pair distances over all states and frames.
    weighted_distance = sum([
        sum(compute_distance.(X[k][1], X[k][2]) .* weight_target[k]) for k in eachindex(X)
    ])
    return weighted_distance
end

In [None]:
"""
    MBAR(P)

Container for the trainable parameters: `P[1]` is epsilon and `P[2]` is sigma.

The field type is a type parameter so that `P` is concretely typed for any
array passed to the constructor (e.g. `MBAR([1.0, 4.0])`).
"""
struct MBAR{T<:AbstractArray}
    P::T # P[1] = epsilon, P[2] = sigma
end

# Register `P` as the only field Flux collects as a trainable parameter.
Flux.@functor MBAR (P,)

# Calling the model evaluates the reweighted mean distance for its current
# parameters. `f_k`, `u_kl` and `beta` are read from global scope.
function (m::MBAR)(X)
    return compute_weighted_distance(X, f_k, u_kl, beta, m.P[1], m.P[2])
end

In [None]:
# Sanity check: evaluate the loss with the model set to the target parameters.
m = MBAR([epsilon_target, sigma_target])
# `loss` reads the global `m`, so it follows whatever model `m` is bound to.
loss(X, y) = Flux.Losses.mse(m(X), y)
loss(X_train, y_train)

In [None]:
# Model used for training, initialised with the parameters of sampled state 1.
m = MBAR([epsilon[1], sigma[1]])
# `loss` reads the global `m`.
loss(X, y) = Flux.Losses.mse(m(X), y)
# Loss before any training step.
loss(X_train, y_train)

In [None]:
# Define the trainable parameters and the optimizer.
#train_loader = Flux.Data.DataLoader(X_train, batchsize=10, shuffle=true)
ps = Flux.params(m)
opt = ADAM(1e-2)

In [None]:
# Time a single gradient evaluation of the loss with respect to `ps`.
@time gs = gradient(() -> loss(X_train, y_train), ps)

In [None]:
"""
    print_callback(epoch, loss, ps)

Print the epoch number and loss value on one line and the first entry of `ps`
on the next line. Returns `nothing`.
"""
function print_callback(epoch, loss, ps)
    println(stdout, "Epoch: $epoch, loss: $loss")
    println(stdout, "param: $(ps[1])")
    return nothing
end

In [None]:
# Loss value recorded after every epoch.
loss_train = []

# Ten full-batch gradient steps on the parameters in `ps`.
for epoch in 1:10
    gs = gradient(() -> loss(X_train, y_train), ps)
    Flux.Optimise.update!(opt, ps, gs)
    
    # Re-evaluate the loss with the updated parameters for logging.
    L = loss(X_train, y_train)
    push!(loss_train, L)
    print_callback(epoch, L, ps)
end

In [None]:
# Plot the training loss recorded per epoch.
plot(loss_train,
     xlabel = "Epoch",              # x-axis label
     ylabel = "Loss",               # y-axis label
     #title  = "Training Loss per Epoch",  # title
     label  = "Training Loss",      # name shown in the legend
     lw     = 2,                    # line width
     framestyle=:box,
     legend = :topright,            # legend position (top right)
     size=(400, 300),
     dpi=900)
#savefig("./figure/mcmc_loss")

In [None]:
# Print initial, estimated and target values of both parameters side by side.
println("Initial epsilon:   ", epsilon[1])
println("Estimated epsilon: ", m.P[1])
println("Target epsilon:    ", epsilon_target)
println()
println("Initial sigma:   ", sigma[1])
println("Estimated sigma: ", m.P[2])
println("Target sigma:    ", sigma_target)

In [None]:
# Snapshot the trained parameter values.
epsilon_estimated = deepcopy(m.P[1])
sigma_estimated = deepcopy(m.P[2])
# Re-seed so this validation run is reproducible on its own.
Random.seed!(10)
# Sample a fresh trajectory with the estimated parameters.
trj1_estimated, trj2_estimated, alpha = mcmc(rand(3), rand(3) .+ 2, epsilon_estimated, sigma_estimated, nstep, beta);

In [None]:
using Statistics, StatsBase, Plots

"""
    kl_divergence(p, q)

Kullback-Leibler divergence `sum(p .* log.(p ./ q))`, restricted to the entries
where both `p` and `q` are strictly positive (all other entries are skipped so
that no `log(0)` or division by zero occurs).

`p` and `q` must have the same length; any real-valued vectors are accepted.
"""
function kl_divergence(p::AbstractVector{<:Real}, q::AbstractVector{<:Real})
    mask = (p .> 0) .& (q .> 0)  # skip entries that would produce log(0) or x/0
    return sum(p[mask] .* log.(p[mask] ./ q[mask]))
end

"""
    js_divergence(p, q)

Jensen-Shannon divergence between `p` and `q` (natural logarithm).

Both inputs are normalised to sum to one first, so unnormalised weights such as
histogram counts may be passed directly.
"""
function js_divergence(p::AbstractVector{<:Real}, q::AbstractVector{<:Real})
    p_norm = p / sum(p)  # normalise to a probability distribution
    q_norm = q / sum(q)
    mixture = (p_norm + q_norm) / 2
    return 0.5 * kl_divergence(p_norm, mixture) + 0.5 * kl_divergence(q_norm, mixture)
end

"""
    histogram_js_divergence(data1, data2; bins=20)

Histogram both data sets on a shared grid and return the Jensen-Shannon
divergence of the two histograms.

The grid has `bins + 1` equally spaced edges running from the smallest to the
largest value found in either data set.
NOTE(review): depending on whether StatsBase closes bins on the left or right,
the sample equal to the outermost edge may fall outside the histogram -- confirm.
"""
function histogram_js_divergence(data1::Vector{Float64}, data2::Vector{Float64}; bins=20)
    # Common bin edges spanning the pooled range of both data sets.
    pooled = vcat(data1, data2)
    edges = range(minimum(pooled), maximum(pooled), length=bins+1)

    counts1 = fit(Histogram, data1, edges).weights
    counts2 = fit(Histogram, data2, edges).weights

    # Per-bin probabilities.
    p = counts1 ./ sum(counts1)
    q = counts2 ./ sum(counts2)

    return js_divergence(p, q)
end

# Create two random data sets to demonstrate the JSD helper.
using Random
Random.seed!(123)

data1 = randn(1000)          # standard normal distribution
data2 = randn(1000) .+ 1.0   # normal distribution with the mean shifted by 1

# Compute the JSD
jsd_value = histogram_js_divergence(data1, data2, bins=30)

println("Jensen-Shannon ダイバージェンス: ", jsd_value)

# Draw the two histograms on top of each other
histogram(data1, bins=30, alpha=0.5, label="Data 1", normalize=true, color=:blue)
histogram!(data2, bins=30, alpha=0.5, label="Data 2", normalize=true, color=:red)
title!("Histogram Comparison of Two Distributions")
xlabel!("Value")
ylabel!("Probability Density")


In [None]:
# Before training: pair-distance histograms of sampled state k versus the target,
# using every 100th frame and 30 bins.
k = 1
r = compute_distance.(trj1[k][1:100:end], trj2[k][1:100:end])
histogram(r,label="current", c=:blue, alpha=0.5, fill=false, seriestype=:stephist, bins=30)
r = compute_distance.(trj1_target[1:100:end], trj2_target[1:100:end])
histogram!(r,label="target", c=:red, alpha=0.5, fill=false, seriestype=:stephist, bins=30)
plot!(
    #title="Pair distances : before training", 
    xlim=(3, 5), xlabel="Pair distance r [nm]", ylabel="Frequency", size=(400,300), dpi=900, framestyle=:box)

#savefig("./figure/mcmc_before_training")

In [None]:
# JSD between the distance distributions of sampled state 1 and the target (all frames).
# Note: this rebinds `r_target` from the scalar mean to the full vector of distances.
r_current = compute_distance.(trj1[1], trj2[1])
r_target = compute_distance.(trj1_target, trj2_target)
jsd = histogram_js_divergence(r_current, r_target, bins=30)

In [None]:
# After training: pair-distance histograms of the trajectory sampled with the
# estimated parameters versus the target, using every 100th frame and 30 bins.
r = compute_distance.(trj1_estimated[1:100:end], trj2_estimated[1:100:end])
histogram(r,label="estimated", c=:blue, alpha=0.5, fill=false, seriestype=:stephist, bins=30)
r = compute_distance.(trj1_target[1:100:end], trj2_target[1:100:end])
histogram!(r,label="target", c=:red, alpha=0.5, fill=false, seriestype=:stephist, bins=30)
plot!(#title="Pair distances : after training", 
    xlim=(3, 5), xlabel="Pair distance r [nm]", ylabel="Frequency", size=(400, 300), dpi=900, framestyle=:box)
#savefig("./figure/mcmc_after_training")

In [None]:
# JSD between the distance distributions of the estimated-parameter trajectory
# and the target (all frames).
r_estimated = compute_distance.(trj1_estimated, trj2_estimated)
r_target = compute_distance.(trj1_target, trj2_target)
jsd = histogram_js_divergence(r_estimated, r_target, bins=30)