In [None]:
using Flux, Plots
using ChainRulesCore
using Statistics
using Random
using MDToolbox
using BenchmarkTools
using EzXML
using LinearAlgebra
using SparseArrays
using Dates
using Distributed
using FiniteDifferences

# Widen the display width Julia uses when printing arrays in this notebook
# (Julia consults ENV["COLUMNS"] for the display size).
ENV["COLUMNS"] = "130"

In [None]:
# Topology used to load the simulation trajectories.
pdb_filepath = "./alanine-dipeptide-nowater.pdb"
# Simulations live in <traj_dir>/sim_<i>/ (trajectory traj_<i>.dcd and force field sim_<i>.xml).
traj_dir = "./sim_dihedral/"
ff_dir = traj_dir
njobs = 4   # number of simulations

# Target trajectory and the target force-field XML.
target_traj_path, target_ff_filepath =
    "./target/sim_target.dcd", "./data/amber14/protein.ff14SB.xml"

In [None]:
slice = 1000  # keep every `slice`-th frame of each simulation trajectory
#atom_list = [2 11]  # alternative atom pair for the distance observable
atom_list = [5 17]   # atoms whose pairwise distances are computed

# All unordered pairs (i < j) of `atom_list`, each stored as a 1×2 row matrix
# (the same shape the original push! loop produced). A typed comprehension
# gives the container a concrete element type instead of `Vector{Any}`.
atom_pairs = let a = vec(atom_list)
    [[a[i] a[j]] for i in eachindex(a) for j in (i + 1):lastindex(a)]
end

In [None]:
# Atom indices of the four atoms defining each backbone dihedral
# (the mdtraj indices from Python, shifted by +1 to Julia's 1-based indexing).
phi_indices, psi_indices = [5, 7, 9, 15], [7, 9, 15, 17]   # dihedral φ, dihedral ψ

# Atom types of those atoms as named in the force-field XML file.
phi_atom_type, psi_atom_type = ["C", "N", "CT", "C"], ["N", "CT", "C", "N"]

In [None]:
"""
    safe_acos(x::Real)

Return `acos(x)` after clamping `x` into `[-1, 1]`. Rounding overshoots such as
a normalised dot product of `1.0000000000000002` therefore cannot raise a
`DomainError`. The bounds are built with `one(x)`, so `Float32` input stays
`Float32`. This lets `_compute_dihedral` work for any `TrjArray{T}`, not only
`T == Float64`.
"""
function safe_acos(x::Real)
    return acos(clamp(x, -one(x), one(x)))
end

"""
    _compute_dihedral(ta1, ta2, ta3, ta4) -> Vector{T}

Per-frame dihedral angle, in degrees within [-180, 180], defined by the
mass-weighted centres (`centerofmass(...; isweight=true)`) of the four
selections `ta1`…`ta4`.

With `d1 = c1 - c2`, `d2 = c3 - c2`, `d3 = c3 - c4`, `m1 = d1 × d2` and
`m2 = d2 × d3`, the magnitude is `acos(m1⋅m2 / (|m1||m2|))`. The sign is
negative when `d2 ⋅ (m1 × m2) < 0`. `safe_acos` clamps the cosine, so rounding
slightly outside [-1, 1] does not throw.
"""
function _compute_dihedral(ta1::TrjArray{T, U}, ta2::TrjArray{T, U}, ta3::TrjArray{T, U}, ta4::TrjArray{T, U})::Vector{T} where {T, U}
    c1 = centerofmass(ta1, isweight=true)
    c2 = centerofmass(ta2, isweight=true)
    c3 = centerofmass(ta3, isweight=true)
    c4 = centerofmass(ta4, isweight=true)
    nframe = ta1.nframe
    angles = zeros(T, nframe)
    for f in 1:nframe
        # bond vectors of frame f built from the x, y, z columns 1:3
        b1 = c1.xyz[f, 1:3] .- c2.xyz[f, 1:3]
        b2 = c3.xyz[f, 1:3] .- c2.xyz[f, 1:3]
        b3 = c3.xyz[f, 1:3] .- c4.xyz[f, 1:3]
        n1 = cross(b1, b2)
        n2 = cross(b2, b3)
        θ = safe_acos(dot(n1, n2) / (norm(n1) * norm(n2)))
        angles[f] = dot(b2, cross(n1, n2)) < zero(T) ? -θ : θ
    end
    # radians → degrees
    angles .= (angles ./ pi) .* T(180)
    return angles
end

In [None]:
top = readpdb(pdb_filepath)

# Target trajectory and its observables.
ta = mdload(target_traj_path)

# Backbone dihedrals φ and ψ per frame, converted from degrees to radians.
phi = _compute_dihedral(ta[:, phi_indices[1]], ta[:, phi_indices[2]],
                        ta[:, phi_indices[3]], ta[:, phi_indices[4]]) * π / 180
psi = _compute_dihedral(ta[:, psi_indices[1]], ta[:, psi_indices[2]],
                        ta[:, psi_indices[3]], ta[:, psi_indices[4]]) * π / 180
dihedral_target = hcat(phi, psi)   # columns: φ, ψ

# One column of distances per atom pair; the factor 1/10 presumably converts Å to nm.
distance_target = hcat([compute_distance(ta, pair) ./ 10 for pair in atom_pairs]...)

In [None]:
"""
    kde_estimate(data; weight=ones(length(data)), bandwidth=nothing, num_points=1000)

Weighted Gaussian kernel density estimate of `data`. It is evaluated on
`num_points` equally spaced points from `minimum(data)` to `maximum(data)`.

- `weight`: one weight per sample; must have `length(data)` entries.
- `bandwidth`: kernel width. When `nothing`, it is chosen by Silverman's rule
  of thumb, `0.9 * min(σ, IQR/1.34) * n^(-1/5)`. If that spread is not positive
  (e.g. IQR = 0 for heavily tied data, or σ = NaN for a single sample), `σ` is
  used instead, or `1` if `σ` is not positive either. This prevents a zero
  bandwidth from turning the estimate into NaN/Inf.

Returns `(x_grid, density)`. `density` is divided by the total weight.

Throws `DimensionMismatch` when `weight` and `data` differ in length.
"""
function kde_estimate(data::AbstractVector; weight::AbstractVector = ones(length(data)), bandwidth=nothing, num_points::Int=1000)
    length(weight) == length(data) || throw(DimensionMismatch(
        "weight has length $(length(weight)) but data has length $(length(data))"))

    if isnothing(bandwidth)
        n = length(data)
        s = std(data)
        iqr = quantile(data, 0.75) - quantile(data, 0.25)
        spread = min(s, iqr / 1.34)
        if !(spread > 0)   # also catches NaN
            spread = s > 0 ? s : one(s)
        end
        bandwidth = 0.9 * spread / n^(1/5)
    end

    x_grid_dense = range(minimum(data), maximum(data), length=num_points)
    # Loop invariants: Gaussian normalisation and total weight.
    norm_const = bandwidth * sqrt(2 * π)
    total_weight = sum(weight)

    density_estimate = zeros(num_points)
    for (i, x) in enumerate(x_grid_dense)
        kernel_sum = 0.0
        for (val, w) in zip(data, weight)
            kernel_sum += w * exp(-((x - val) / bandwidth)^2 / 2) / norm_const
        end
        density_estimate[i] = kernel_sum / total_weight
    end

    return x_grid_dense, density_estimate
end

In [None]:
"""
    kde_estimate(data, kernel; weight=ones(length(data)), bandwidth=nothing,
                 num_points=100, x_grid=Float64[])

Weighted kernel density estimate of `data` with a standardised kernel
`kernel(u)` (e.g. `gaussian_kernel`):

    density(x) = Σᵢ wᵢ kernel((x - dᵢ) / h) / (Σᵢ wᵢ · h)

The estimate is evaluated on `x_grid` when one is given. Otherwise it uses
`num_points` equally spaced points spanning the data. The output always has
exactly one entry per grid point.

When `bandwidth === nothing`, `h` comes from Silverman's rule and is echoed
with `@show`. If the rule's spread `min(σ, IQR/1.34)` is not positive, it
falls back to `σ`, then to `1`.

Returns `(x_grid, density)`. Throws `DimensionMismatch` when `weight` and
`data` differ in length.
"""
function kde_estimate(data::AbstractVector, kernel; 
        weight::AbstractVector = ones(length(data)), bandwidth=nothing, num_points::Int=100, x_grid = Float64[])
    length(weight) == length(data) || throw(DimensionMismatch(
        "weight has length $(length(weight)) but data has length $(length(data))"))

    if isnothing(bandwidth)
        n = length(data)
        s = std(data)
        iqr = quantile(data, 0.75) - quantile(data, 0.25)
        spread = min(s, iqr / 1.34)
        if !(spread > 0)   # also catches NaN
            spread = s > 0 ? s : one(s)
        end
        @show bandwidth = 0.9 * spread / n^(1/5)
    end

    if isempty(x_grid)
        x_grid = range(minimum(data), maximum(data), length=num_points) |> Vector
    end

    total_weight = sum(weight)
    # Size the output from the grid actually used, so a caller-supplied x_grid
    # of any length is handled.
    density_estimate = zeros(length(x_grid))
    for (i, x) in enumerate(x_grid)
        kernel_sum = 0.0
        for (val, w) in zip(data, weight)
            kernel_sum += w * kernel((x - val) / bandwidth)
        end
        density_estimate[i] = kernel_sum / total_weight / bandwidth
    end

    return x_grid, density_estimate
end
"""
    gaussian_kernel(x)

Standard normal probability density at `x`, i.e. `exp(-x²/2) / √(2π)`.
"""
gaussian_kernel(x) = exp(-0.5 * x^2) / sqrt(2π)

"""
    kde_estimate_cpu(data_k, weight_k, kernel; bandwidth=0.0, num_points=1000, x_grid=Float64[])

Weighted kernel density estimate pooled over several sample sets. `data_k[k]`
holds the samples of set `k` and `weight_k[k]` their weights. All sets are
concatenated (and flattened) before estimating:

    density(x) = Σᵢ wᵢ kernel((x - dᵢ) / h) / (Σᵢ wᵢ · h)

- `bandwidth == 0.0` selects Silverman's rule (echoed with `@show`). If the
  rule's spread `min(σ, IQR/1.34)` is not positive, it falls back to `σ`,
  then to `1`.
- An empty `x_grid` means `num_points` equally spaced points spanning the data.

Returns only the density vector, one entry per grid point. The number of sets
is taken from the arguments; no global is read.

Throws `DimensionMismatch` if the pooled weights and data differ in length.
"""
function kde_estimate_cpu(data_k, weight_k, kernel::Function; bandwidth=0.0, num_points=1000, x_grid=Float64[])
    data = vec(vcat(data_k...))
    weight = vec(vcat(weight_k...))
    length(weight) == length(data) || throw(DimensionMismatch(
        "pooled weights have length $(length(weight)) but pooled data has length $(length(data))"))

    if bandwidth == 0.0
        n = length(data)
        s = std(data)
        iqr = quantile(data, 0.75) - quantile(data, 0.25)
        spread = min(s, iqr / 1.34)
        if !(spread > 0)   # also catches NaN
            spread = s > 0 ? s : one(s)
        end
        @show bandwidth = 0.9 * spread / n^(1/5)
    end

    if isempty(x_grid)
        x_grid = Array(range(minimum(data), maximum(data), length=num_points))
    end

    total_weight = sum(weight)   # loop invariant
    density_estimate = similar(x_grid, float(eltype(x_grid)))
    for i in eachindex(x_grid)
        kernel_sum = sum(weight .* kernel.((x_grid[i] .- data) ./ bandwidth))
        density_estimate[i] = kernel_sum / total_weight / bandwidth
    end

    return density_estimate
end


# Reverse rule for `kde_estimate_cpu`. Only the weights `weight_k` are
# differentiated; data, kernel and keywords get no tangent.
#
# With S = Σᵢ wᵢ and ρⱼ = Σᵢ wᵢ K((xⱼ - dᵢ)/h) / (S h), the partial derivative is
#   ∂ρⱼ/∂wᵢ = (K((xⱼ - dᵢ)/h) / h - ρⱼ) / S.
# The second term comes from normalising by S. It is included so the rule is
# correct for unnormalised weights too; for weights summing to 1 it is a
# uniform shift.
#
# The Silverman bandwidth (when `bandwidth == 0.0`) and the automatic grid
# (when `x_grid` is empty) are resolved here and passed explicitly to the
# forward call. This guarantees the pullback uses exactly the same `h` and
# grid, instead of dividing by 0 or looping over an empty grid.
# The pullback returns one tangent per positional argument plus the function.
function ChainRulesCore.rrule(::typeof(kde_estimate_cpu), data_k, weight_k, kernel::Function;
        bandwidth=0.0, num_points=1000, x_grid=Float64[])
    data = vec(vcat(data_k...))
    weight = vec(vcat(weight_k...))

    h = bandwidth
    if h == 0.0
        n = length(data)
        s = std(data)
        iqr = quantile(data, 0.75) - quantile(data, 0.25)
        spread = min(s, iqr / 1.34)
        if !(spread > 0)   # also catches NaN
            spread = s > 0 ? s : one(s)
        end
        h = 0.9 * spread / n^(1/5)
        @show h
    end
    grid = isempty(x_grid) ? Array(range(minimum(data), maximum(data), length=num_points)) : x_grid

    density_estimate = kde_estimate_cpu(data_k, weight_k, kernel, bandwidth=h, num_points=num_points, x_grid=grid)
    sum_weight = sum(weight)
    N_k = [length(w) for w in weight_k]   # used to split the flat gradient back per set

    function kde_estimate_cpu_pullback(dU)
        Δ = unthunk(dU)

        # Σⱼ Δⱼ ρⱼ : contribution of the normalisation term
        Δρ = zero(eltype(density_estimate))
        for j in eachindex(grid)
            Δρ += Δ[j] * density_estimate[j]
        end

        dweight = zeros(float(eltype(weight)), length(weight))
        for i in eachindex(data)
            acc = 0.0
            for j in eachindex(grid)
                acc += kernel((grid[j] - data[i]) / h) * Δ[j]
            end
            dweight[i] = (acc / h - Δρ) / sum_weight
        end

        dweight_k = similar(weight_k)
        index_start = 1
        for k in eachindex(weight_k)
            index_end = index_start + N_k[k] - 1
            dweight_k[k] = dweight[index_start:index_end]
            index_start = index_end + 1
        end

        # (function, data_k, weight_k, kernel)
        return NoTangent(), NoTangent(), dweight_k, NoTangent()
    end

    return density_estimate, kde_estimate_cpu_pullback
end
# Quick look at the target distance distribution (every 100th frame).
num_points = 100
@time x_grid, density_estimate = kde_estimate(distance_target[1:100:end], gaussian_kernel; num_points)
plot(x_grid, density_estimate;
     label="Kernel Density Estimate", xlabel="x", ylabel="Density",
     linewidth=2, title="r distribution")

In [None]:
# Fixed kernel bandwidth, in the units of the distance data (presumably nm).
bandwidth = 2.4e-3

In [None]:
"""
    function calc_histogram(data::AbstractVector;
                            rng=nothing,
                            bin_width=0.005, # nm
                            nbin=nothing,
                            density::Bool=false,
                            weight::AbstractArray=ones(length(data)))

Calculate a histogram of the input data `data`.

# Arguments
- `data::AbstractVector`: Input data vector.
- `rng::Tuple{Real, Real}`: Range of values to consider for the histogram. If not provided, the minimum and maximum values of `data` will be used.
- `bin_width::Real=0.005`: Width of each histogram bin.
- `nbin::Integer`: Number of bins for the histogram. If not provided, it will be automatically calculated based on `rng` and `bin_width`.
- `density::Bool=false`: If `true`, normalize the histogram to form a probability density.
- `weight::AbstractArray=[]`: Optional weights associated with each data point.

# Returns
- `hist::Array{Float64,1}`: Counts of data points in each bin.
- `bin_edge::Array{Float64,1}`: Edges of the bins.

# Examples
```julia-repl
julia> data = randn(1000)  # Generate random data
julia> hist, bin_edge = calc_histogram(data, rng=(-3, 3), bin_width=0.1, density=true)
```
"""
function calc_histogram(data::AbstractArray;
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false,
                        weight::AbstractArray=ones(length(data)))
    
    # If range is not specified, use the range of the data
    if rng == nothing
        rng = (minimum(data), maximum(data))
    end
    # If data falls outside the specified range, ignore it
    data = filter(x -> rng[1] <= x && x <= rng[2], data)
    
    # If nbin is not specified, calculate it based on the bin width
    if nbin == nothing
        nbin = ceil(Int, (rng[2] - rng[1]) / bin_width)
    else
        # Recalculate bin width based on nbin
        bin_width = (rng[2] - rng[1]) / nbin
    end
    
    # Initialize histogram bins
    hist = zeros(Float64, nbin)
    
    # Calculate bin edges
    bin_edge = range(rng[1], rng[2], nbin+1) |> Vector
    
    # Calculate bin centers
    bin_center = (bin_edge[1:end-1] + bin_edge[2:end]) / 2

    min_value = minimum(data)
    # Fill histogram bins
    for (val, w) in zip(data, weight)
        #bin_index = argmin(abs.(bin_center .- val))
        bin_index = min(floor(Int, (val - min_value) / bin_width) + 1, nbin)    
        hist[bin_index] += w
    end
    
    # Normalize by total weight if density is true
    if density
        total_weight = sum(weight)
        hist ./= total_weight
    end
    
    return hist, bin_edge, rng
end

In [None]:
"""
    calc_histogram(data_k, weight_k; rng=nothing, bin_width=0.005, nbin=nothing, density=false)

Histogram of several weighted sample sets pooled together. The arrays
`data_k[k]` are concatenated, as are the weights `weight_k[k]`, and the result
is passed to the single-array `calc_histogram`. Only the bin values are
returned (no edges and no range).
"""
function calc_histogram(data_k::Array{<:AbstractArray},
                        weight_k::Array{<:AbstractArray};
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false)
    pooled = calc_histogram(vcat(data_k...);
                            rng, bin_width, nbin, density, weight=vcat(weight_k...))
    return first(pooled)   # bin values only
end

# Reverse rule for the pooled `calc_histogram(data_k, weight_k; ...)`. Only the
# weights are differentiated.
#
# The bin geometry is recovered from the forward result: `(lo, hi)` from the
# returned range and the width from `(hi - lo) / length(hist)`. The pullback
# therefore uses the same bins as the forward pass, whichever of
# `nbin`/`bin_width` was given. (The keyword `bin_width` is wrong when `nbin`
# overrides it.) Samples outside the range get no count gradient, since they
# are not counted.
#
# With `density = true`, hist_b = raw_b / S where S = Σ wᵢ, so
#   ∂hist_b/∂wᵢ = ([i ∈ b] - hist_b) / S.
# The `- hist_b` normalisation term is included.
# The pullback returns one tangent per positional argument plus the function.
function ChainRulesCore.rrule(::typeof(calc_histogram), 
                        data_k::Array{<:AbstractArray},
                        weight_k::Array{<:AbstractArray};
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false)

    data = vcat(data_k...)
    weight = vcat(weight_k...)
    hist, bin_edge, used_rng = calc_histogram(data, rng=rng, bin_width=bin_width, nbin=nbin, density=density, weight=weight)

    nbin_eff = length(hist)
    lo, hi = used_rng[1], used_rng[2]
    width = (hi - lo) / nbin_eff
    total_weight = sum(weight)
    # Split sizes of the flat weight gradient (one chunk per weight array)
    N_k = [length(w) for w in weight_k]

    function calc_histogram_pullback(dP)
        Δ = unthunk(dP)
        dweight = zeros(float(eltype(weight)), length(weight))

        for (i, val) in enumerate(data)
            lo <= val <= hi || continue
            b = width > 0 ? clamp(floor(Int, (val - lo) / width) + 1, 1, nbin_eff) : 1
            dweight[i] = Δ[b]
        end

        if density
            Δh = sum(Δ[b] * hist[b] for b in eachindex(hist))
            dweight .= (dweight .- Δh) ./ total_weight
        end

        dweight_k = similar(weight_k)
        istart = 1
        for k in eachindex(weight_k)
            iend = istart + N_k[k] - 1
            dweight_k[k] = dweight[istart:iend]
            istart = iend + 1
        end

        # (function, data_k, weight_k)
        return NoTangent(), NoTangent(), dweight_k
    end

    return hist, calc_histogram_pullback
end

In [None]:
# Histogram (30 bins, normalised by total weight) of all target distances.
nbin = 30
data = distance_target
@time hist, bin_edge, rng = calc_histogram(data; nbin, density=true)
bin_center = [(bin_edge[b] + bin_edge[b + 1]) / 2 for b in 1:length(bin_edge) - 1]
bar(bin_center, hist; width=1, alpha=0.5, title="r distribution", ylim=(0, 0.1))

In [None]:
# Input trajectories: for every simulation, load every `slice`-th frame. Then
# compute the same observables as for the target: backbone dihedrals φ, ψ in
# radians, and the atom-pair distances divided by 10 (presumably Å → nm).

dihedral_k = Array{Array{Float64}}(undef, njobs) # columns: dihedrals φ, ψ
distance_k = Array{Array{Float64}}(undef, njobs) # one column per atom pair

for i in 1:njobs
    traj_filepath = joinpath(traj_dir, "sim_$(i)/traj_$(i).dcd")
    ta = mdload(traj_filepath, top=top)
    ta = ta[1:slice:end]

    # Same clamped dihedral routine as for the target trajectory, so target and
    # simulation angles are computed identically.
    phi = _compute_dihedral(ta[:, phi_indices[1]], ta[:, phi_indices[2]],
                            ta[:, phi_indices[3]], ta[:, phi_indices[4]]) * π / 180
    psi = _compute_dihedral(ta[:, psi_indices[1]], ta[:, psi_indices[2]],
                            ta[:, psi_indices[3]], ta[:, psi_indices[4]]) * π / 180
    dihedral_k[i] = hcat(phi, psi)

    # Iterate over the pairs directly, so no inner index shadows the job index `i`.
    distance_k[i] = hcat([compute_distance(ta, pair) ./ 10 for pair in atom_pairs]...)
end

In [None]:
using Plots
using Statistics  # 標準偏差やその他統計量のために使用

"""
    calculate_pmf(x_data, y_data, grid_size=100; lims=(-π, π), bw_factor=0.5)

2-D Gaussian kernel density estimate of the paired samples
`(x_data[k], y_data[k])`. It is evaluated on a `grid_size × grid_size` grid
spanning `lims` along both axes; the default is `[-π, π]`, i.e. angles in
radians.

The per-axis bandwidth is `bw_factor * std * n^(-1/5)`. The default `0.5`
gives a narrower kernel than the usual normal-reference factor `1.06`.

Returns `(x_grid, y_grid, density)`, where `density[i, j]` is the estimate at
`(x_grid[i], y_grid[j])` (first index ↔ x). The returned array is a
probability density, not yet a free energy.

The product kernel separates:
`density = Kx * Ky' / (n hx hy)` with `Kx[i, k] = kern((x_grid[i] - x_data[k]) / hx)`.
This needs O(grid·n) kernel evaluations instead of O(grid²·n).

Throws `DimensionMismatch` if `x_data` and `y_data` differ in length.
"""
function calculate_pmf(x_data::Vector{<:Real}, y_data::Vector{<:Real}, grid_size::Int=100;
                       lims=(-π, π), bw_factor=0.5)
    length(x_data) == length(y_data) || throw(DimensionMismatch(
        "x_data has length $(length(x_data)) but y_data has length $(length(y_data))"))
    n = length(x_data)

    # grids
    x_grid = range(lims[1], stop=lims[2], length=grid_size)
    y_grid = range(lims[1], stop=lims[2], length=grid_size)

    # bandwidths (normal-reference form; Scott/Silverman use 1.06 instead of bw_factor)
    hx = bw_factor * std(x_data) * n^(-1/5)
    hy = bw_factor * std(y_data) * n^(-1/5)

    # Gaussian kernel
    kern(u) = exp(-0.5 * u^2) / sqrt(2 * π)

    Kx = kern.((x_grid .- x_data') ./ hx)   # grid_size × n
    Ky = kern.((y_grid .- y_data') ./ hy)   # grid_size × n
    density = (Kx * Ky') ./ (n * hx * hy)

    return x_grid, y_grid, density
end

"""
    plot_pmf_contour(x_grid, y_grid, pmf; title="PMF")

Filled contour plot of `-log10(pmf + 1e-8)`. The grids are in radians and the
axes are labelled in degrees ("Phi" on x, "Psi" on y). The colormap is a
hard-coded viridis palette followed by white, so the largest values (lowest
density) are drawn white. Returns the result of `title!`.

NOTE(review): Plots' `contourf(x, y, z)` reads `z[j, i]` at `(x[i], y[j])`.
`calculate_pmf` returns `density[i, j]` at `(x_grid[i], y_grid[j])`, so the
matrix is drawn transposed. The calling cells feed (ψ, φ) so that φ ends up
on the horizontal axis. Keep both sides in mind before changing either.
"""
function plot_pmf_contour(x_grid, y_grid, pmf; title="PMF")
    # Viridis colour stops, with white appended for the top of the scale
    viridis_colors = [
        RGB(0.267, 0.004, 0.329),
        RGB(0.283, 0.141, 0.458),
        RGB(0.254, 0.265, 0.530),
        RGB(0.207, 0.372, 0.553),
        RGB(0.164, 0.471, 0.558),
        RGB(0.128, 0.567, 0.551),
        RGB(0.136, 0.659, 0.517),
        RGB(0.267, 0.749, 0.441),
        RGB(0.478, 0.821, 0.318),
        RGB(0.741, 0.873, 0.150),
        RGB(0.993, 0.906, 0.144)
    ]
    extended_cmap = cgrad(vcat(viridis_colors, [RGB(1.0, 1.0, 1.0)]))

    # radians → degrees for the axes
    x_grid = x_grid .* 180 ./ π
    y_grid = y_grid .* 180 ./ π

    # The small offset keeps log10 finite where the density is zero
    threshold = 1e-8
    log_pmf = -log10.(pmf .+ threshold)
    contourf(x_grid, y_grid, log_pmf, color=extended_cmap, colorbar=true, levels=10, dpi=900)
    xlabel!("Phi")
    ylabel!("Psi")
    return title!(title)
end

In [None]:
# Target PMF from every 1000th frame. dihedral_target has columns (φ, ψ):
# ψ is passed as x_data and φ as y_data. NOTE(review): together with Plots'
# contourf(x, y, z) convention (z[j, i] at (x[i], y[j])), this presumably puts
# φ on the horizontal axis labelled "Phi"; confirm the orientation.
x_data, y_data = dihedral_target[1:1000:end, 2], dihedral_target[1:1000:end, 1]

x_grid, y_grid, pmf = calculate_pmf(x_data, y_data)
plot_pmf_contour(x_grid, y_grid, pmf; title="target")

In [None]:
# One PMF panel per simulation (same ψ/φ argument order as the target plot),
# arranged on a 2×2 grid.
p = []
for (k, dihedrals) in enumerate(dihedral_k)
    x_data = dihedrals[:, 2]
    y_data = dihedrals[:, 1]
    x_grid, y_grid, pmf = calculate_pmf(x_data, y_data)
    push!(p, plot_pmf_contour(x_grid, y_grid, pmf, title="k=$(k)"))
end
plot(p..., layout=(2, 2))

In [None]:
# Target distance histogram from every 100th frame.
nbin = 30
data = distance_target[1:100:end]
@time hist, bin_edge, rng = calc_histogram(data; nbin, density=true)
bin_center = [(bin_edge[b] + bin_edge[b + 1]) / 2 for b in 1:length(bin_edge) - 1]
bar(bin_center, hist; width=1, alpha=0.5, title="r distribution")

In [None]:
# One distance histogram per simulation; the panels are collected in `p`.
p = []
for i in eachindex(distance_k)
    nbin = 30
    data = copy(distance_k[i])
    @time hist, bin_edge, rng = calc_histogram(data; nbin, density=true)
    bin_center = (bin_edge[1:end-1] .+ bin_edge[2:end]) ./ 2
    push!(p, bar(bin_center, hist; width=1, alpha=0.5, title="r distribution", ylim=(0, 0.1)))
end

In [None]:
# Show the per-simulation distance histograms collected in `p` by the previous cell.
plot(p...)

In [None]:
# The dihedral (φ, ψ) force-field terms have three parameters per cosine term:
# periodicity n, phase θ₀ (theta_zero) and force constant k. Only k is optimised.

abstract type AbstractParam end

"""
    ParamCPU{T}

Periodic-torsion parameters held in CPU `Vector`s. Entry `i` of the three
vectors describes the `i`-th cosine term
`k[i] * (1 + cos(n[i] * θ - theta_zero[i]))`.
"""
struct ParamCPU{T<:AbstractFloat}<:AbstractParam
    n::Vector{T}
    theta_zero::Vector{T}
    k::Vector{T}
end

"""
    Param{T, A}

Same fields as `ParamCPU`, but generic over the vector type `A`, e.g. the
`CuArray`s produced by `init_Param(ff_array, true)`. Defining it does not
require CUDA to be loaded (a field typed `CuArray{T}` does), and the fields
stay concretely typed. All three fields must share one array type.
"""
struct Param{T<:AbstractFloat, A<:AbstractVector{T}}<:AbstractParam
    n::A
    theta_zero::A
    k::A
end

# Infer T and A from the arguments, e.g. `Param(n, theta_zero, k)`.
Param(n::A, theta_zero::A, k::A) where {T<:AbstractFloat, A<:AbstractVector{T}} =
    Param{T, A}(n, theta_zero, k)

"""
    init_Param(ff_array, gpu=false, T=Float64)

Build torsion parameters from `(attribute_name, value)` pairs taken from one
child element of a force field's `PeriodicTorsionForce`:
`periodicityN` → `n`, `phaseN` → `theta_zero`, `kN` → `k`. `typeN` entries
are skipped.

Values are appended in the order the attributes appear. Term `i` of each
vector is therefore the `i`-th occurrence of that attribute kind —
presumably periodicity1/phase1/k1, …; TODO confirm the XML attribute order.

Names are matched by prefix (`startswith`), so only names beginning with
`type`/`periodicity`/`phase`/`k` are recognised. Any other name emits a
`@warn` and is ignored.

Returns `ParamCPU{T}`, or a `Param` holding `CuArray`s when `gpu == true`.
The GPU path uses `adapt`/`CuArray`, which are not imported in the visible
code; presumably CUDA.jl must be loaded separately.
"""
function init_Param(ff_array::AbstractVector, gpu::Bool=false, T::DataType=Float64)
    n = T[]
    theta_zero = T[]
    k = T[]

    for (name, value) in ff_array
        if startswith(name, "type")
            continue
        elseif startswith(name, "periodicity")
            push!(n, parse(T, value))
        elseif startswith(name, "phase")
            push!(theta_zero, parse(T, value))
        elseif startswith(name, "k")
            push!(k, parse(T, value))
        else
            @warn "init_Param: unrecognised torsion attribute, ignored" name
        end
    end

    if gpu
        n = adapt(CuArray, n)
        theta_zero = adapt(CuArray, theta_zero)
        k = adapt(CuArray, k)
        return Param(n, theta_zero, k)
    end

    return ParamCPU(n, theta_zero, k)
end

In [None]:
"""
    input_ff(ff_filepath; gpu::Bool=false)

Read a force-field XML and return `(phi_param, psi_param)`, built by
`init_Param` from the `PeriodicTorsionForce` entries whose `type1…type4`
attributes match

- φ: `protein-C`, `protein-N`, `protein-CX`, `protein-C`
- ψ: `protein-N`, `protein-CX`, `protein-C`, `protein-N`

If several entries match, the last one wins. Entries without `type1…type4`
attributes are skipped instead of raising. A warning is emitted when no φ or
no ψ entry is found; the corresponding parameters are then empty.

Throws `ArgumentError` if the file has no `PeriodicTorsionForce` element.
"""
function input_ff(ff_filepath; gpu::Bool=false)
    xmlroot = root(readxml(ff_filepath))

    children = elements(xmlroot)
    torsion_pos = findfirst(c -> nodename(c) == "PeriodicTorsionForce", children)
    isnothing(torsion_pos) && throw(ArgumentError(
        "no <PeriodicTorsionForce> element in $(ff_filepath)"))
    torsion = children[torsion_pos]

    phi_atom_type = ["protein-C", "protein-N", "protein-CX", "protein-C"]
    psi_atom_type = ["protein-N", "protein-CX", "protein-C", "protein-N"]

    ff_phi = Tuple{String, String}[]
    ff_psi = Tuple{String, String}[]
    for ff_params in eachelement(torsion)
        params_name = [nodename(a) for a in eachattribute(ff_params)]
        params_content = [nodecontent(a) for a in eachattribute(ff_params)]
        attrs = Dict(zip(params_name, params_content))
        # A missing typeN becomes "" and can never match
        atom_type = [get(attrs, "type$(i)", "") for i in 1:4]
        if atom_type == phi_atom_type
            ff_phi = collect(zip(params_name, params_content))
        end
        if atom_type == psi_atom_type
            ff_psi = collect(zip(params_name, params_content))
        end
    end

    isempty(ff_phi) && @warn "input_ff: no phi torsion entry found" ff_filepath
    isempty(ff_psi) && @warn "input_ff: no psi torsion entry found" ff_filepath
    return init_Param(ff_phi, gpu), init_Param(ff_psi, gpu)
end

In [None]:
# φ/ψ torsion parameters of each simulation's force field ...
ff_phi_k = Array{ParamCPU}(undef, njobs)
ff_psi_k = Array{ParamCPU}(undef, njobs)
sim_ff_files = (joinpath(ff_dir, "sim_$(j)/sim_$(j).xml") for j in 1:njobs)
for (i, ff_filepath) in enumerate(sim_ff_files)
    ff_phi_k[i], ff_psi_k[i] = input_ff(ff_filepath)
end

# ... and of the target force field.
ff_phi_target, ff_psi_target = input_ff(target_ff_filepath)

In [None]:
# Energy of a single periodic torsion term: k · (1 + cos(n·θ − θ₀)).
@inline compute_dihedral_energy(theta::AbstractFloat, n::AbstractFloat,
        theta_zero::AbstractFloat, k::AbstractFloat) = k * (cos(n * theta - theta_zero) + 1)

# Total torsion energy at angle `theta`: the sum over all cosine terms in `ff_param`.
function sum_compute_dihedral_energy(theta::AbstractFloat, ff_param::AbstractParam)
    term_energies = compute_dihedral_energy.(theta, ff_param.n, ff_param.theta_zero, ff_param.k)
    return sum(term_energies)
end

In [None]:
# Reduced potentials for MBAR: u_kl[k, l][n] = β (E_φ + E_ψ) of frame n of
# simulation k, evaluated with the torsion parameters of simulation l.
K = njobs
N_k = map(d -> size(d, 1), dihedral_k)   # frames per simulation
KBT = KB_kcalpermol * 300
beta = Float64(1.0 / KBT)

u_kl = Array{Array{Float64}}(undef, (K, K))
for k in 1:K, l in 1:K
    u_kl[k, l] = [beta * sum_compute_dihedral_energy(dihedral_k[k][i, 1], ff_phi_k[l]) +
                  beta * sum_compute_dihedral_energy(dihedral_k[k][i, 2], ff_psi_k[l])
                  for i in 1:N_k[k]]
end

In [None]:
# MBAR estimates for the K simulation states (presumably dimensionless free energies).
f_k = map(Float64, MDToolbox.mbar(u_kl))

In [None]:
"""
    compute_u_k_cpu(beta, dihedral_k, n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)

Reduced torsion energy of every frame under a single set of φ/ψ parameters:
`u_k[k][frame] = β (Σⱼ E(φ; n_phi[j], theta_zero_phi[j], k_phi[j]) + Σⱼ E(ψ; …))`.
Here `φ = dihedral_k[k][frame, 1]` and `ψ = dihedral_k[k][frame, 2]`.
Returns a `Vector{Vector{T}}` with `T = typeof(beta)`.
"""
function compute_u_k_cpu(beta::T, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi) where {T}
    nsets = length(dihedral_k)
    u_k = Vector{Vector{T}}(undef, nsets)
    for k in 1:nsets
        angles = dihedral_k[k]
        u = zeros(T, size(angles, 1))
        for frame in eachindex(u)
            for j in 1:length(k_phi)
                u[frame] += beta * compute_dihedral_energy(angles[frame, 1], n_phi[j], theta_zero_phi[j], k_phi[j])
            end
            for j in 1:length(k_psi)
                u[frame] += beta * compute_dihedral_energy(angles[frame, 2], n_psi[j], theta_zero_psi[j], k_psi[j])
            end
        end
        u_k[k] = u
    end
    return u_k
end

# Reverse rule for `compute_u_k_cpu`. Only the force constants receive gradients:
#   ∂u_k[k][n] / ∂k_phi[i] = β (1 + cos(n_phi[i] φ_kn − theta_zero_phi[i]))
#   ∂u_k[k][n] / ∂k_psi[i] = β (1 + cos(n_psi[i] ψ_kn − theta_zero_psi[i]))
# where φ_kn = dihedral_k[k][n, 1] and ψ_kn = dihedral_k[k][n, 2]. All other
# arguments get NoTangent().
# NOTE(review): the gradients of k_phi[1], k_phi[4] and k_psi[1] are forced to
# zero, so those force constants are never changed by the optimiser. This is
# hard-coded; confirm it matches the intended set of trainable terms.
function ChainRulesCore.rrule(::typeof(compute_u_k_cpu), beta, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi)
    nsets = length(dihedral_k)
    nframes = [size(dihedral_k[k], 1) for k in 1:nsets]

    u_k = compute_u_k_cpu(beta, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi)

    frozen_phi = (1, 4)   # k_phi indices whose gradient stays zero
    frozen_psi = (1,)     # k_psi indices whose gradient stays zero

    function compute_u_k_pullback(dU)
        dk_phi = fill!(similar(k_phi), 0.0)
        dk_psi = fill!(similar(k_psi), 0.0)

        for i in 1:length(k_phi)
            i in frozen_phi && continue
            for k in 1:nsets, n in 1:nframes[k]
                dk_phi[i] += beta * (1.0 + cos(n_phi[i] * dihedral_k[k][n, 1] - theta_zero_phi[i])) * dU[k][n]
            end
        end

        for i in 1:length(k_psi)
            i in frozen_psi && continue
            for k in 1:nsets, n in 1:nframes[k]
                dk_psi[i] += beta * (1.0 + cos(n_psi[i] * dihedral_k[k][n, 2] - theta_zero_psi[i])) * dU[k][n]
            end
        end

        # (function, beta, dihedral_k, n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)
        return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dk_phi, NoTangent(), NoTangent(), dk_psi
    end

    return u_k, compute_u_k_pullback
end

# Convenience method: unpack the φ and ψ parameter structs into the flat form.
compute_u_k_cpu(beta, dihedral_k, ff_phi::AbstractParam, ff_psi::AbstractParam) =
    compute_u_k_cpu(beta, dihedral_k,
                    ff_phi.n, ff_phi.theta_zero, ff_phi.k,
                    ff_psi.n, ff_psi.theta_zero, ff_psi.k)

In [None]:
"""
    _mbar_weight(u_kl, f_k, u_k=nothing)

MBAR weights of all samples for a target state.

- `u_kl[k, l]`: reduced potentials of the samples of simulation `k` evaluated
  with the parameters of state `l`. `u_kl` is assumed square (K × K).
- `f_k`: presumably the MBAR free energies of the K states (from `MDToolbox.mbar`).
- `u_k[k]`: reduced potentials of the samples of simulation `k` in the target
  state; `nothing` means a target potential of zero.

Returns `w_k::Vector{Vector{Float64}}`, computed as
`w = exp(log w - logsumexp(log w))`, so the weights of all samples together
sum to one.
"""
function _mbar_weight(u_kl, f_k, u_k=nothing)
    K = size(u_kl, 1)   # number of states / simulations
    N_k = [length(u_kl[k, 1]) for k in 1:K]   # samples per simulation
    N_max = maximum(N_k)

    # ragged u_kl → dense zero-padded K × K × N_max array
    u_kln = zeros(Float64, K, K, N_max)
    for k in 1:K, l in 1:K
        u_kln[k, l, 1:N_k[k]] .= u_kl[k, l]
    end

    # target potentials, zero-padded (left at zero when u_k === nothing)
    u_kn = zeros(Float64, K, N_max)
    if !isnothing(u_k)
        for k in 1:K
            u_kn[k, 1:N_k[k]] .= u_k[k]
        end
    end

    # mask of real (non-padding) entries
    idx = falses(K, N_max)
    for k in 1:K
        idx[k, 1:N_k[k]] .= true
    end

    log_w_kn = MDToolbox.mbar_log_wi_jn(N_k, f_k, u_kln, u_kn, K, N_max)
    s = MDToolbox.logsumexp_1d(log_w_kn[idx])

    return [exp.(log_w_kn[k, 1:N_k[k]] .- s) for k in 1:K]
end

# Reverse rule for `_mbar_weight`, differentiating only w.r.t. the target
# potentials `u_k`. With w_i = exp(ℓ_i) / Σ_j exp(ℓ_j), and assuming
# ∂ℓ_i/∂u_i = −1 (NOTE(review): confirm against MDToolbox.mbar_log_wi_jn),
# the cotangent for a weight cotangent g = dw_k is
#   ∂L/∂u_i = w_i (Σ_j g_j w_j − g_i).
# Tangents are returned for (function, u_kl, f_k, u_k):
# NoTangent, ZeroTangent, NoTangent, du_k.
function ChainRulesCore.rrule(::typeof(_mbar_weight), u_kl, f_k, u_k)
    w_k = _mbar_weight(u_kl, f_k, u_k)

    function _mbar_weight_pullback(dw_k)
        # ⟨g, w⟩ over all samples
        gw = 0.0
        for i in eachindex(w_k), j in eachindex(w_k[i])
            gw += dw_k[i][j] * w_k[i][j]
        end

        # same ragged structure as w_k
        du_k = similar(w_k)
        for i in eachindex(w_k)
            du_k[i] = [w_k[i][j] * (gw - dw_k[i][j]) for j in eachindex(w_k[i])]
        end

        return NoTangent(), ZeroTangent(), NoTangent(), du_k
    end

    return w_k, _mbar_weight_pullback
end

In [None]:
# Uncertainty estimate
"""
    asymptotic_covariance_matrix(w_k)

Compute `Σ = Wᵀ (I − W N Wᵀ)⁺ W`. Here `W = hcat(w_k...)` (one column per
entry of `w_k`), `N` is a diagonal matrix whose diagonal entries all equal
`size(W, 1)`, and `⁺` is the Moore–Penrose pseudo-inverse (`pinv`).

NOTE(review): this resembles the MBAR asymptotic covariance
`Θ = Wᵀ (I − W N Wᵀ)⁺ W`, where `W` is (all samples × states) and
`N = diag(N_k)`. Here the columns are the per-simulation slices of one weight
vector. `hcat` then requires all `w_k[k]` to have equal length, and rows index
"frame within a simulation" rather than individual samples. Confirm this is
the intended matrix before relying on the result.
"""
function asymptotic_covariance_matrix(w_k)
    W = hcat(w_k...)

    N = zeros(Float64, (size(W, 2), size(W, 2)))
    for i in 1:size(W, 2)
        N[i, i] = size(W, 1)
    end
    
    # identity matrix I
    _I = Matrix(I, size(W, 1), size(W, 1))
    
    # intermediate: (I - W * N * W^T)
    M = _I - W * N * W'
    
    # pseudo-inverse (M need not be invertible)
    M_pseudo_inv = pinv(M)
    
    # asymptotic covariance matrix
    Σ = W' * M_pseudo_inv * W
    return Σ
end

"""
    compute_uncertainty(Σ)

Return the matrix of `Σ[i, i] - 2Σ[i, j] + Σ[j, j]` for every `(i, j)`: the
variance of the difference of components `i` and `j` given covariance `Σ`.
The result has the same type as `Σ`; `Σ` itself is not modified.
"""
function compute_uncertainty(Σ)
    result = deepcopy(Σ)
    for j in 1:size(Σ, 2), i in 1:size(Σ, 1)
        result[i, j] = Σ[i, i] - 2 * Σ[i, j] + Σ[j, j]
    end
    return result
end

In [None]:
# 計算時間が長いため、スライスの値を変えてカーネルと動かすと良い
#=
u_k = compute_u_k_cpu(beta, dihedral_k, ff_phi_target, ff_psi_target)
w_k = _mbar_weight(u_kl, f_k, u_k)
@time Σ = asymptotic_covariance_matrix(w_k)
uncertainty = compute_uncertainty(Σ)
=#

In [None]:
# Cross-entropy-of-histograms loss: model output
"""
    calculate_histogram_A(A_k, dihedral_k, f_k, u_kl, beta,
                          n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)

Reweighted histogram of the observable samples `A_k` (one array per
simulation) under the torsion parameters given for φ and ψ. The steps are:

1. compute the target reduced potentials with `compute_u_k_cpu`;
2. turn them into MBAR weights with `_mbar_weight`;
3. histogram `A_k` with those weights (`density=true`).

The binning comes from the globals `rng` and `nbin`, which the training cell
sets from the target histogram. Model and target histograms therefore share
the same bins.
"""
function calculate_histogram_A(A_k, dihedral_k, f_k, u_kl, beta, 
        n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)
    u_k = compute_u_k_cpu(beta, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi)
    w_k = _mbar_weight(u_kl, f_k, u_k)
    return calc_histogram(A_k, w_k, rng=rng, nbin=nbin, density=true)
end

# Training data: the simulated distance samples are the input. The label is
# the target distance histogram (every 10th frame, `nbin` bins). The `rng`
# returned here is the global bin range read by `calculate_histogram_A`.
X_train = distance_k
nbin = 50
y_train, bin_edge, rng = calc_histogram(distance_target[1:10:end]; nbin, density=true)

# Trainable model: P[1] holds k_phi and P[2] holds k_psi.
struct Energy{T<:AbstractArray}
    P::T
end

Flux.@functor Energy (P,)

# Forward pass: reweighted histogram of X under the current force constants.
# n and theta_zero are taken from the target force field.
(model::Energy)(X::AbstractArray) = calculate_histogram_A(X, dihedral_k, f_k,
    u_kl, beta, ff_phi_target.n, ff_phi_target.theta_zero, model.P[1],
    ff_psi_target.n, ff_psi_target.theta_zero, model.P[2])

# Cross-entropy between predicted and target histograms.
loss(ŷ, y) = Flux.Losses.crossentropy(ŷ, y)

In [None]:
# 平均値のMSEのLoss関数
#=
function compute_average_property(A_k, dihedral_k, f_k, u_kl, beta, 
        n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)
    K = size(A_k, 1)
    u_k = compute_u_k_cpu(beta, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi)
    w_k = _mbar_weight(u_kl, f_k, u_k)

    #=
    A_target = 0.0
    for k in 1:K
        A_target += sum(w_k[k] .* A_k[k])
    end
    =#
    A_target = sum(sum.([w_k[k] .* A_k[k] for k in 1:length(A_k)]))
    return A_target
end

X_train = distance_k #距離の軌跡
y_train = mean(distance_target) #ターゲットの距離の平均

struct Energy{T<:AbstractArray}
    P::T #P[1] = k_phi, P[2] = k_psi
end

Flux.@functor Energy (P,)

(m::Energy)(X::AbstractArray) = compute_average_property(X, dihedral_k, f_k,
    u_kl, beta, ff_phi_target.n, ff_phi_target.theta_zero, m.P[1],
    ff_psi_target.n, ff_psi_target.theta_zero, m.P[2])

loss(x, y) = sum((x .- y) .^ 2)
=#

In [None]:
# ksdensityのklダイバージェンスのLoss関数
#=
num_points = 100
x_grid, density_estimate = kde_estimate(distance_target[1:100:end], gaussian_kernel, num_points=num_points)

_, y_train = kde_estimate(distance_target[1:10:end], gaussian_kernel, x_grid=x_grid) #ターゲットの距離の分布
function compute_distribution_property(A_k, dihedral_k, f_k, u_kl, beta, 
        n_phi, theta_zero_phi, k_phi, n_psi, theta_zero_psi, k_psi)
    K = size(A_k, 1)
    u_k = compute_u_k_cpu(beta, dihedral_k, n_phi, theta_zero_phi,
        k_phi, n_psi, theta_zero_psi, k_psi)
    w_k = _mbar_weight(u_kl, f_k, u_k)
    
    density_estimate = kde_estimate_cpu(A_k, w_k, gaussian_kernel, x_grid=x_grid, bandwidth=bandwidth)
    #density_estimate = kde_estimate_cpu(A_k, w_k, gaussian_kernel, x_grid=x_grid)
    return density_estimate
end
struct Energy{T<:AbstractArray}
    P::T #P[1] = k_phi, P[2] = k_psi
end

Flux.@functor Energy (P,)

(m::Energy)(X::AbstractArray) = compute_distribution_property(X, dihedral_k, f_k,
    u_kl, beta, ff_phi_target.n, ff_phi_target.theta_zero, m.P[1],
    ff_psi_target.n, ff_psi_target.theta_zero, m.P[2])

# 確率分布を規格化する関数
function normalize(p)
    return p / sum(p)
end

# KLダイバージェンスを計算する関数
function kl_divergence(p, q)
    # 分布を規格化
    p = normalize(p)
    q = normalize(q)
    
    # KLダイバージェンスを計算
    return sum(p .* log.(p ./ q))
end

function ChainRulesCore.rrule(::typeof(kl_divergence), p, q)
    # 元の関数の値
    y = kl_divergence(p, q)

    # 微分ルール
    function kl_divergence_pullback(Δ)
        p = normalize(p)
        q = normalize(q)
        ∂p = Δ * (log.(p ./ q) .+ 1 .- sum(p .* log.(p ./ q)))
        ∂q = -Δ * (p ./ q) ./ length(q)
        return NoTangent(), ∂p, ∂q
    end

    return y, kl_divergence_pullback
end
loss(x, y) = kl_divergence(x, y)
=#

In [None]:
# Target force constants vs. those of simulation 1 (the training start point).
println("ff_phi_target.k = ", ff_phi_target.k)
println("ff_phi_k[1].k = ", ff_phi_k[1].k)
println()
println("ff_psi_target.k = ", ff_psi_target.k)
println("ff_psi_k[1].k = ", ff_psi_k[1].k)

In [None]:
# Sanity check: loss of a model that holds the target force constants.
@show m = Energy([ff_phi_target.k, ff_psi_target.k])
loss(m(X_train), y_train)

In [None]:
# Loss at the starting point: force constants of simulation 1. Deep copies are
# used so that changing m.P does not alter ff_phi_k[1] / ff_psi_k[1].
@show m = Energy([deepcopy(ff_phi_k[1].k), deepcopy(ff_psi_k[1].k)])
loss(m(X_train), y_train)

In [None]:
# Time one gradient evaluation of the loss w.r.t. the model parameters.
@time g = first(Flux.gradient(model -> loss(model(X_train), y_train), m))

In [None]:
# Print the loss reached at `epoch` (the model argument `m` is currently unused).
print_progress(epoch, loss, m) = println("Epoch: ", epoch, ", loss : ", loss)

In [None]:
# Choose the number of epochs and the learning rate to suit the selected loss.
k_phi_estimated_array = []   # k_phi after every epoch
k_psi_estimated_array = []   # k_psi after every epoch
n_eff_array = []
uncertainty_array = []       # filled only if the uncertainty block below is enabled

nepoch = 100
learning_rate = 0.01

# Start from the force constants of simulation 1 (copied, so optimiser updates
# do not touch ff_phi_k[1] / ff_psi_k[1]). The model is built *before* the
# initial loss is reported, so the printed values belong to the model that is
# actually trained, not to whatever `m` an earlier cell left behind.
m = Energy([deepcopy(ff_phi_k[1].k), deepcopy(ff_psi_k[1].k)])
println("Initial loss: $(loss(m(X_train), y_train))")
println("Initial param: $(m.P[1])")

t = Flux.Optimisers.setup(Adam(learning_rate), m)
loss_train = []
g1 = []   # gradient w.r.t. k_phi per epoch
g2 = []   # gradient w.r.t. k_psi per epoch
# Integer reporting interval (about ten reports, at least every epoch when nepoch < 10)
progress_every = max(nepoch ÷ 10, 1)
@time for epoch in 1:nepoch
    # `update!` returns the new (state, model). Rebinding them keeps the
    # optimiser state correct regardless of whether it is also mutated in place.
    global t, m
    g = gradient(m -> loss(m(X_train), y_train), m)[1]
    t, m = Flux.Optimisers.update!(t, m, g)

    L = loss(m(X_train), y_train)
    push!(loss_train, L)

    push!(k_phi_estimated_array, deepcopy(m.P[1]))
    push!(k_psi_estimated_array, deepcopy(m.P[2]))

    push!(g1, g.P[1])
    push!(g2, g.P[2])

    # Optional (slow): MBAR weights and their uncertainty for the current parameters.
    #=
    u_k = compute_u_k_cpu(beta, dihedral_k, ff_phi_target.n, ff_phi_target.theta_zero, m.P[1],
        ff_psi_target.n, ff_psi_target.theta_zero, m.P[2])
    w_k = _mbar_weight(u_kl, f_k, u_k)
    =#
    #=
    Σ = asymptotic_covariance_matrix(w_k)
    uncertainty = compute_uncertainty(Σ)
    push!(uncertainty_array, uncertainty)
    =#

    if epoch % progress_every == 0
        print_progress(epoch, L, m)
    end
end

In [None]:
# Training-loss curve (a title such as "Training Loss per Epoch" can be added if needed).
plot(loss_train;
     xlabel="Epoch", ylabel="Loss", label="Training Loss",
     lw=2, framestyle=:box, legend=:topright, size=(400, 300), dpi=900)
#savefig("./figure/histogram_2_11_loss.png")

In [None]:
# Per-epoch element-wise squared deviation of the estimated force constants from the target.
phi_dif = [(k_phi_estimated_array[i] - ff_phi_target.k) .^ 2 for i in 1:nepoch] |> cpu
psi_dif = [(k_psi_estimated_array[i] - ff_psi_target.k) .^ 2 for i in 1:nepoch] |> cpu

# Total squared error per epoch
phi_squared_error = map(sum, phi_dif)
psi_squared_error = map(sum, psi_dif)

# Regroup as one time series per force-constant term: phi_dif[term][epoch]
phi_dif = [[phi_dif[j][i] for j in 1:nepoch] for i in 1:size(ff_phi_target.k, 1)]
psi_dif = [[psi_dif[j][i] for j in 1:nepoch] for i in 1:size(ff_psi_target.k, 1)]

plot(phi_squared_error; xlabel="Epoch", title="Phi Squared Error")

In [None]:
# Total squared error of the ψ force constants per epoch.
plot(psi_squared_error; xlabel="Epoch", title="Psi Squared Error")

In [None]:
# Snapshot of the optimised force constants (read by `create_ff`).
phi_estimated, psi_estimated = deepcopy(m.P[1]), deepcopy(m.P[2])

## 推定されたパラメータからxmlファイルを作成

In [None]:
# Atom types (force-field XML naming) of the φ and ψ torsions whose force
# constants are replaced; these globals are read by `create_ff`.
phi_atom_type = ["protein-C", "protein-N", "protein-CX", "protein-C"]
psi_atom_type = ["protein-N", "protein-CX", "protein-C", "protein-N"]

"""
    create_ff(ff_input_filepath, ff_output_filepath)

Write a copy of the force-field XML `ff_input_filepath` to
`ff_output_filepath`, with the force constants of the φ/ψ torsions replaced.
Entries of `PeriodicTorsionForce` are matched on `type1…type4` against the
globals `phi_atom_type` / `psi_atom_type`. In a matching entry, each attribute
named exactly `k<digits>` (`kN`) is set to `phi_estimated[N]` or
`psi_estimated[N]`. Entries without `type1…type4` attributes are left
untouched.

Throws `ArgumentError` if the input has no `PeriodicTorsionForce` element.
"""
function create_ff(ff_input_filepath, ff_output_filepath)
    output_xml = readxml(ff_input_filepath)
    xmlroot = root(output_xml)
    children = elements(xmlroot)
    torsion_pos = findfirst(c -> nodename(c) == "PeriodicTorsionForce", children)
    isnothing(torsion_pos) && throw(ArgumentError(
        "no <PeriodicTorsionForce> element in $(ff_input_filepath)"))
    torsion = children[torsion_pos]   # gives access to the torsion entries

    for ff_params in eachelement(torsion)
        attrs = Dict(nodename(a) => nodecontent(a) for a in eachattribute(ff_params))
        atom_type = [get(attrs, "type$(i)", "") for i in 1:4]
        estimated = atom_type == phi_atom_type ? phi_estimated :
                    atom_type == psi_atom_type ? psi_estimated : nothing
        isnothing(estimated) && continue

        for a in eachattribute(ff_params)
            km = match(r"^k(\d+)$", nodename(a))
            isnothing(km) && continue
            a.content = string(estimated[parse(Int, km[1])])
        end
    end
    return write(ff_output_filepath, output_xml)
end

In [None]:
# Write the optimised force field once; an existing output file is left untouched.
estimated_dir = "estimated_dihedral"
mkpath(estimated_dir)
ff_output_filepath = joinpath(estimated_dir, "estimated.xml")
isfile(ff_output_filepath) || create_ff(target_ff_filepath, ff_output_filepath)

## 推定されたパラメータからシミュレーションを流す

In [None]:
nsteps = 2_500_000_000   # MD steps passed to sim.py
gpu_id = "n4"            # node requested via `sbatch -w`

"""
    run_job(ff_filepath, traj_filepath, log_filepath, i)

Write a temporary Slurm batch script `temp_<i>.sh` and submit it with
`sbatch --gpus-per-node=1 -w <gpu_id>`. The script runs
`time python sim.py <pdb_filepath> <ff_filepath> <traj_filepath> <nsteps>`,
using the globals `pdb_filepath`, `nsteps` and `gpu_id`, and logs to
`log_filepath`.

The file paths on the command line are shell-escaped. The temporary script is
removed even if writing or submission fails.
"""
function run_job(ff_filepath, traj_filepath, log_filepath, i)
    sbatch_file = "temp_$(i).sh"
    q(x) = Base.shell_escape(string(x))   # unchanged for ordinary paths

    try
        open(sbatch_file, "w") do of
            println(of, "#!/bin/bash")
            println(of, "#SBATCH -p all")
            println(of, "#SBATCH -J sim$(i) # job name")
            println(of, "#SBATCH -n 1  # num of total mpi processes")
            println(of, "#SBATCH -c 1  # num of threads per mpi processes")
            println(of, "#SBATCH -o $(log_filepath)")
            println(of, "time python sim.py $(q(pdb_filepath)) $(q(ff_filepath)) $(q(traj_filepath)) $(nsteps)")
        end
        sleep(5)
        run(`sbatch --gpus-per-node=1 -w $(gpu_id) $(sbatch_file)`)
        sleep(5)
    finally
        rm(sbatch_file; force=true)
    end
end

In [None]:
#=
ff_filepath = ff_output_filepath

traj_filepath = joinpath(estimated_dir, "estimated_dihedral")
log_filepath = joinpath(sim_dir, "estimated_dihedral")
run_job(ff_filepath, traj_filepath, log_filepath, 1)
=#