In [None]:
using Flux, Plots
using ChainRulesCore
using Statistics
using Random
using MDToolbox
using BenchmarkTools
using EzXML
using LinearAlgebra
using SparseArrays
using Dates
using Distributed
using FiniteDifferences
# Set the COLUMNS environment variable to 130; presumably so that arrays/tables printed in
# this notebook use a 130-character width — confirm.
ENV["COLUMNS"] = 130

In [None]:
# Topology of the solute (alanine dipeptide without water), parsed once here.
pdb_filepath = "./alanine-dipeptide-nowater.pdb"
top = readpdb(pdb_filepath)

# Per-window simulations: trajectories and force-field XMLs live in the same directory tree.
traj_dir = ff_dir = "./sim_coulomb"
njobs = 10   # number of simulation windows (sim_1 … sim_njobs)

# Reference ("target") trajectory and the force field it was generated with.
target_traj_path = "./target/sim_target.dcd"
target_ff_filepath = "./data/amber14/protein.ff14SB.xml"

In [None]:
# Atoms whose pairwise distances are computed (1-based atom indices).
atom_list = [5 17]
# Every unordered pair (i < j) of those atoms, each stored as a 1×2 matrix, in a Vector{Any}.
atom_pairs = Vector{Any}([[atom_list[i] atom_list[j]] for i in 1:length(atom_list) for j in i+1:length(atom_list)])

In [None]:
# Keep every `slice`-th frame of each simulation trajectory.
slice = 1000

# Backbone dihedral atom indices: the (0-based) mdtraj indices shifted by +1.
phi_indices, psi_indices = [5, 7, 9, 15], [7, 9, 15, 17]   # ϕ, ψ

# Atom types of the same four atoms as named in the force-field XML file, in the same order.
phi_atom_type, psi_atom_type = ["C", "N", "CT", "C"], ["N", "CT", "C", "N"]

In [None]:
# Nonbonded constants: vacuum permittivity used by the Coulomb kernel, and the 1-4 scaling
# factors for Coulomb (1/1.2 = 5/6) and Lennard-Jones (1/2) interactions.
epsilon_0 = 1.0
coulomb14scale = 5 / 6
lj14scale = 1 / 2

In [None]:
# Per-atom atom names and residue names taken from the parsed topology.
atomname_list, resname_list = top.atomname, top.resname

In [None]:
# Force-field atom types, one per atom (22 entries); presumably entry i is the type of
# atom i in topology order — confirm. NOTE(review): not referenced in the visible code.
ff_atom = [
    # entries 1-6
    "protein-HC", "protein-CT", "protein-HC", "protein-HC", "protein-C", "protein-O",
    # entries 7-16
    "protein-N", "protein-H", "protein-CX", "protein-H1", "protein-CT",
    "protein-HC", "protein-HC", "protein-HC", "protein-C", "protein-O",
    # entries 17-22
    "protein-N", "protein-H", "protein-CT", "protein-H1", "protein-H1", "protein-H1",
]

In [None]:
# Covalent bonds (1-2 pairs) as (i, j) atom-index tuples with i < j, kept sorted.
bonded_1_2 = sort([
    (1, 2), (2, 3), (2, 4), (2, 5),
    (5, 6), (5, 7),
    (7, 8), (7, 9),
    (9, 10), (9, 11), (9, 15),
    (11, 12), (11, 13), (11, 14),
    (15, 16), (15, 17),
    (17, 18), (17, 19),
    (19, 20), (19, 21), (19, 22),
])

In [None]:
"""
    build_graph(edges::Vector{Tuple{Int, Int}})

Build an undirected adjacency list from `edges`. Every vertex that appears in an edge
maps to the vector of its neighbours, listed in the order the edges are given.
"""
function build_graph(edges::Vector{Tuple{Int, Int}})
    adjacency = Dict{Int, Vector{Int}}()
    for (a, b) in edges
        # get! creates an empty neighbour list the first time a vertex is seen
        push!(get!(Vector{Int}, adjacency, a), b)
        push!(get!(Vector{Int}, adjacency, b), a)
    end
    return adjacency
end

"""
    bfs_distances(graph::Dict{Int, Vector{Int}}, start::Int)

Breadth-first search from `start`. Return a `Dict` mapping every vertex reachable from
`start` to its shortest distance (number of edges) from `start`. Throws `KeyError` if a
visited vertex (including `start`) has no entry in `graph`.
"""
function bfs_distances(graph::Dict{Int, Vector{Int}}, start::Int)
    dist = Dict{Int, Int}(start => 0)
    frontier = [start]   # FIFO queue realised as a vector plus a read cursor
    head = 1
    while head <= length(frontier)
        node = frontier[head]
        head += 1
        for nb in graph[node]
            # a vertex gets its distance the first time it is discovered (BFS order = shortest)
            haskey(dist, nb) && continue
            dist[nb] = dist[node] + 1
            push!(frontier, nb)
        end
    end
    return dist
end

"""
    find_pairs_with_distance(edges::Vector{Tuple{Int, Int}}, n::Int)

Return all vertex pairs `(i, j)` with `i < j` whose shortest path in the undirected graph
built from `edges` has exactly `n` edges (e.g. `n = 2` gives 1-3 pairs, `n = 3` gives 1-4
pairs of a bond graph). The order of the returned vector is unspecified.
"""
function find_pairs_with_distance(edges::Vector{Tuple{Int, Int}}, n::Int)
    adjacency = build_graph(edges)
    found = Set{Tuple{Int, Int}}(
        minmax(src, dst)
        for src in keys(adjacency)
        for (dst, hops) in bfs_distances(adjacency, src)
        if hops == n
    )
    return collect(found)
end

# Atom pairs whose shortest bond path is 2 bonds (1-3) and 3 bonds (1-4), each sorted.
bonded_1_3, bonded_1_4 = (sort(find_pairs_with_distance(bonded_1_2, hops)) for hops in (2, 3))

In [None]:
"""
    tuples_to_matrix(tuples::Vector{Tuple{Int, Int}})

Stack the tuples as rows of an `Int` matrix: row `i` is `tuples[i]`. The column count is
taken from the element type, so an empty input returns a `0×2` matrix instead of failing.
"""
function tuples_to_matrix(tuples::Vector{Tuple{Int, Int}})
    num_rows = length(tuples)
    # number of columns from the tuple type itself (works for an empty vector too)
    num_cols = fieldcount(eltype(tuples))

    matrix = zeros(Int, num_rows, num_cols)
    for (i, t) in enumerate(tuples), j in 1:num_cols
        matrix[i, j] = t[j]
    end

    return matrix
end

# All atom pairs (i < j) that are neither 1-2, 1-3 nor 1-4 pairs, in (i, j) order.
nonbonded = Tuple{Int64, Int64}[]
natom = size(top, 2)
exception_parameters = vcat(bonded_1_2, bonded_1_3, bonded_1_4)
excluded = Set(exception_parameters)   # constant-time membership tests
for i in 1:natom, j in (i + 1):natom
    (i, j) in excluded || push!(nonbonded, (i, j))
end

# Index matrices (one pair per row) for the distance computations.
nonbonded_matrix = tuples_to_matrix(nonbonded)
bonded_14pair_matrix = tuples_to_matrix(bonded_1_4)

In [None]:
"""
    safe_acos(x::Real)

Return `acos(x)` after clamping `x` into `[-1, 1]`, so that round-off such as
`1.0000000002` from a normalised dot product does not raise a `DomainError`.
Accepts any real type (e.g. `Float32` trajectories), not only `Float64`.
"""
function safe_acos(x::Real)
    # clamp bounds in x's own type so Float32 stays Float32
    return acos(clamp(x, -one(x), one(x)))
end

"""
    _compute_dihedral(ta1, ta2, ta3, ta4)

Per-frame dihedral angle, in degrees, defined by the (mass-weighted) centres of mass of
the four atom groups `ta1`…`ta4`. The sign follows the direction of `cross(m1, m2)`
relative to the central bond vector.
Assumes each `centerofmass(...).xyz` is indexed as `[frame, coordinate]` with 3
coordinates — TODO confirm against MDToolbox.
"""
function _compute_dihedral(ta1::TrjArray{T, U}, ta2::TrjArray{T, U}, ta3::TrjArray{T, U}, ta4::TrjArray{T, U})::Vector{T} where {T, U}
    nframe = ta1.nframe
    c1, c2, c3, c4 = (centerofmass(t, isweight=true) for t in (ta1, ta2, ta3, ta4))

    angle = zeros(T, nframe)
    for f in 1:nframe
        b1 = [c1.xyz[f, k] - c2.xyz[f, k] for k in 1:3]
        b2 = [c3.xyz[f, k] - c2.xyz[f, k] for k in 1:3]
        b3 = [c3.xyz[f, k] - c4.xyz[f, k] for k in 1:3]
        n1 = cross(b1, b2)
        n2 = cross(b2, b3)
        θ = safe_acos(dot(n1, n2) / (norm(n1) * norm(n2)))
        # negative rotation when the normals turn against the central bond
        angle[f] = dot(b2, cross(n1, n2)) < zero(T) ? -θ : θ
    end
    angle .= (angle ./ pi) .* T(180)
    return angle
end

In [None]:
# Reload the topology and load the target (reference) trajectory.
top = readpdb(pdb_filepath)
ta = mdload(target_traj_path)

# Target distances divided by 10 (presumably Å → nm): one column per tracked atom pair.
distance_target = hcat([compute_distance(ta, pair) ./ 10 for pair in atom_pairs]...)

# Distance maps for all nonbonded pairs and all 1-4 pairs (same scaling).
nonbonded_distancemap_target = compute_distance(ta, nonbonded_matrix) ./ 10
bonded_14pair_distancemap = compute_distance(ta, bonded_14pair_matrix) ./ 10

In [None]:
"""
    calc_histogram(data::AbstractArray;
                   rng=nothing,
                   bin_width=0.005, # nm
                   nbin=nothing,
                   density::Bool=false,
                   weight::AbstractArray=ones(length(data)))

Calculate a (weighted) histogram of the input data `data`.

# Arguments
- `data::AbstractArray`: Input data; every element is one sample.
- `rng::Tuple{Real, Real}`: Range `(lo, hi)` covered by the histogram. Samples outside it
  are ignored. If not provided, `(minimum(data), maximum(data))` is used.
- `bin_width::Real=0.005`: Requested bin width. It is used only to derive `nbin` when
  `nbin` is not given.
- `nbin::Integer`: Number of bins. If not provided, `ceil((hi - lo) / bin_width)`, with a
  minimum of 1.
- `density::Bool=false`: If `true`, divide the counts by `sum(weight)`. This is the total
  weight of all samples, including those outside `rng`, so each bin holds a fraction of
  the total weight.
- `weight::AbstractArray=ones(length(data))`: Weight of each sample. It must have the same
  length as `data`.

The bins are the `nbin` equal-width intervals between the returned `bin_edge`. The last
bin also contains `hi`.

# Returns
- `hist::Vector{Float64}`: Weighted counts of data points in each bin.
- `bin_edge::Vector{Float64}`: The `nbin + 1` bin edges.
- `rng`: The range actually used.

# Throws
- `DimensionMismatch` if `weight` and `data` have different lengths.
- `ArgumentError` if `data` is empty and `rng` is not given.

# Examples
```julia-repl
julia> data = randn(1000)  # Generate random data
julia> hist, bin_edge, rng = calc_histogram(data, rng=(-3, 3), bin_width=0.1, density=true)
```
"""
function calc_histogram(data::AbstractArray;
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false,
                        weight::AbstractArray=ones(length(data)))
    if length(weight) != length(data)
        throw(DimensionMismatch("weight has $(length(weight)) elements but data has $(length(data))"))
    end

    # If range is not specified, use the range of the data
    if isnothing(rng)
        isempty(data) && throw(ArgumentError("cannot infer `rng` from empty data; pass `rng` explicitly"))
        rng = (minimum(data), maximum(data))
    end
    lo, hi = rng

    # If nbin is not specified, derive it from the requested bin width. Use at least one bin
    # so that a zero-width range (e.g. all samples identical) still works.
    if isnothing(nbin)
        nbin = max(ceil(Int, (hi - lo) / bin_width), 1)
    end
    # Width actually used for binning. It equals the spacing of the returned edges, so every
    # sample lands in the bin whose edges are reported.
    width = (hi - lo) / nbin

    hist = zeros(Float64, nbin)
    bin_edge = range(lo, hi, nbin + 1) |> Vector

    # data and weight are walked together so each weight stays attached to its own sample
    for (val, w) in zip(data, weight)
        lo <= val <= hi || continue   # outside the range (this also skips NaN)
        bin_index = width > 0 ? min(floor(Int, (val - lo) / width) + 1, nbin) : 1
        hist[bin_index] += w
    end

    # Normalize by total weight if density is true
    if density
        hist ./= sum(weight)
    end

    return hist, bin_edge, rng
end

In [None]:
"""
    calc_histogram(data_k::Array{<:AbstractArray}, weight_k::Array{<:AbstractArray}; rng, bin_width, nbin, density)

Histogram of the samples of all windows pooled together (`data_k[k]` paired with
`weight_k[k]`). Returns only the bin counts; the keywords are those of the
single-array `calc_histogram`.
"""
function calc_histogram(data_k::Array{<:AbstractArray},
                        weight_k::Array{<:AbstractArray};
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false)
    pooled_data = vcat(data_k...)
    pooled_weight = vcat(weight_k...)
    result = calc_histogram(pooled_data; rng=rng, bin_width=bin_width, nbin=nbin,
                            density=density, weight=pooled_weight)
    return first(result)   # the counts; edges and range are dropped
end

# Reverse rule for the pooled histogram. It is differentiable w.r.t. the sample weights only.
#
# hist[b] = Σ_{i in bin b} w_i, divided by S = Σ_i w_i when `density=true`, so
#   ∂hist[b]/∂w_j = 1[j in b]                     (density=false)
#   ∂hist[b]/∂w_j = (1[j in b] - hist[b]) / S     (density=true)
# Bin membership is recomputed from the range and bin count returned by the forward pass,
# i.e. from the equal-width bins whose edges it reports. Samples outside that range belong
# to no bin.
function ChainRulesCore.rrule(::typeof(calc_histogram), 
                        data_k::Array{<:AbstractArray},
                        weight_k::Array{<:AbstractArray};
                        rng=nothing,
                        bin_width=0.005, # nm
                        nbin=nothing,
                        density::Bool=false)

    data = vcat(data_k...)
    weight = vcat(weight_k...)
    hist, _, used_rng = calc_histogram(data, rng=rng, bin_width=bin_width, nbin=nbin, density=density, weight=weight)

    # Bin geometry actually used by the forward pass. The `nbin`/`bin_width` keywords may be
    # `nothing` or superseded, so they are not reused here.
    lo, hi = used_rng
    nb = length(hist)
    width = (hi - lo) / nb
    total_weight = sum(weight)

    # Bin index of every pooled sample; 0 marks samples outside `used_rng`
    bin_of = map(data) do x
        (lo <= x <= hi) || return 0
        return width > 0 ? min(floor(Int, (x - lo) / width) + 1, nb) : 1
    end

    function calc_histogram_pullback(dP_raw)
        dP = unthunk(dP_raw)
        dP isa AbstractZero && return NoTangent(), NoTangent(), ZeroTangent()

        dweight = zeros(Float64, length(weight))
        for i in eachindex(dweight)
            b = bin_of[i]
            if b > 0
                dweight[i] = dP[b]
            end
        end
        if density
            # add the derivative of the 1/S normalisation
            dweight .= (dweight .- sum(dP .* hist)) ./ total_weight
        end

        # Split the pooled gradient back into one vector per window, in `vcat` order
        dweight_k = similar(weight_k)
        istart = 1
        for k in eachindex(weight_k)
            iend = istart + length(weight_k[k]) - 1
            dweight_k[k] = dweight[istart:iend]
            istart = iend + 1
        end

        # One tangent per positional argument: the function, data_k and weight_k.
        # Keyword arguments get no tangent.
        return NoTangent(), NoTangent(), dweight_k
    end

    return hist, calc_histogram_pullback
end

In [None]:
# Histogram (30 bins, normalised) of every 10th target distance, drawn as a bar chart.
nbin = 30
data = distance_target[1:10:end]
@time hist, bin_edge, rng = calc_histogram(data, nbin=nbin, density=true)
bin_center = [(bin_edge[i] + bin_edge[i + 1]) / 2 for i in 1:length(bin_edge)-1]
bar(bin_center, hist; width=1, alpha=0.5, title="r distribution", ylim=(0, 0.1))

In [None]:
# Per-window histograms of the tracked distance, one panel per window.
# `distance_k` is built by the trajectory-loading cell further down the notebook. If this
# cell runs before that one, skip the plot with a warning instead of failing with
# UndefVarError.
if @isdefined(distance_k)
    p = []
    for i in 1:length(distance_k)
        nbin = 30
        data = vcat(distance_k[i])
        @time hist, bin_edge, rng = calc_histogram(data, nbin=nbin, density=true)
        bin_center = (bin_edge[1:end-1] .+ bin_edge[2:end]) ./ 2
        tmp = bar(bin_center, hist, width=1, alpha=0.5 , title = "k=$(i)", ylim=(0, 0.1))
        push!(p, tmp)
    end
    plot(p..., layout=(4, 3), size=(1000, 800))
else
    @warn "distance_k is not defined yet: run the cell that loads the simulation trajectories first"
end

In [None]:
"""
    kde_estimate(data::AbstractVector; weight=ones(length(data)), bandwidth=nothing, num_points=1000)

Weighted Gaussian kernel density estimate of `data`. It is evaluated on `num_points`
evenly spaced points from `minimum(data)` to `maximum(data)`.

If `bandwidth` is not given, it is chosen with Silverman's rule of thumb,
`0.9 * min(std, IQR / 1.34) * n^(-1/5)`. When the IQR is zero (heavily tied data), the
standard deviation alone is used instead.

Returns `(x_grid_dense, density_estimate)`.

Throws `DimensionMismatch` if `weight` and `data` differ in length. Throws `ArgumentError`
if the bandwidth is not positive and finite (e.g. constant data, or a single sample).
"""
function kde_estimate(data::AbstractVector; weight::AbstractVector = ones(length(data)), bandwidth=nothing, num_points::Int=1000)
    length(weight) == length(data) ||
        throw(DimensionMismatch("weight has length $(length(weight)) but data has length $(length(data))"))

    # If bandwidth is not specified, estimate it using Silverman's rule
    if isnothing(bandwidth)
        n = length(data)
        s = std(data)
        IQR = quantile(data, 0.75) - quantile(data, 0.25)
        spread = IQR > 0 ? min(s, IQR / 1.34) : s   # a zero IQR would force a zero bandwidth
        bandwidth = 0.9 * spread / n^(1/5)
    end
    # a zero/NaN bandwidth would otherwise silently produce an all-NaN density
    (bandwidth > 0 && isfinite(bandwidth)) ||
        throw(ArgumentError("bandwidth must be positive and finite, got $(bandwidth)"))

    total_weight = sum(weight)   # loop-invariant normaliser, computed once
    density_estimate = zeros(num_points)
    x_grid_dense = range(minimum(data), maximum(data), length=num_points)

    for i in 1:num_points
        x = x_grid_dense[i]
        kernel_sum = 0.0
        for (val, w) in zip(data, weight)
            kernel_sum += w * exp(-((x - val) / bandwidth)^2 / 2) / (bandwidth * sqrt(2 * π))
        end
        density_estimate[i] = kernel_sum / total_weight
    end

    return x_grid_dense, density_estimate
end

In [None]:
# Kernel density estimate of all target distances. `@time` prints how long the estimate
# took; `@timed` would only return the timing as a value, which this cell discards.
@time x_grid_dense, density_estimate = kde_estimate(vec(distance_target))
plot(x_grid_dense, density_estimate, label="Kernel Density Estimate", xlabel="x", ylabel="Density", linewidth=2)

In [None]:
# Display the atom-index pairs whose distances are tracked.
atom_pairs

In [None]:
# Per-window data from the simulations (window i = directory sim_i), thinned to every
# `slice`-th frame. Distances are divided by 10 (presumably Å → nm).
distance_k = Array{Array{Float64}}(undef, njobs)
nonbonded_distancemap_k = Array{Matrix}(undef, njobs)
bonded_14pair_distancemap_k = Array{Matrix}(undef, njobs)

for i in 1:njobs
    ta = mdload(joinpath(traj_dir, "sim_$(i)/traj_$(i).dcd"), top=top)[1:slice:end]
    # one column per tracked atom pair
    distance_k[i] = hcat([compute_distance(ta, pair) ./ 10 for pair in atom_pairs]...)
    nonbonded_distancemap_k[i] = compute_distance(ta, nonbonded_matrix) ./ 10
    bonded_14pair_distancemap_k[i] = compute_distance(ta, bonded_14pair_matrix) ./ 10
end

In [None]:
"""
    input_ff(ff_filepath)

Read per-atom partial charges from an OpenMM-style force-field XML file.

For every `<Residue>` under the root's first `<Residues>` element whose `name` attribute
appears in `resname_list`, each child element carries `name` and `charge` attributes. If
that `name` is an atom name of that residue in `atomname_list`, its charge is written to
every matching (residue, atom-name) position.

Returns a `Vector{Float64}` aligned with the global `atomname_list`/`resname_list`. Atoms
not found in the file keep a charge of 0.0.

Throws `ArgumentError` if the file has no `<Residues>` element.
"""
function input_ff(ff_filepath)
    charge = zeros(Float64, length(atomname_list))
    xmlroot = root(readxml(ff_filepath))

    residues_nodes = filter(el -> nodename(el) == "Residues", elements(xmlroot))
    isempty(residues_nodes) &&
        throw(ArgumentError("no <Residues> element found in $(ff_filepath)"))
    residues = first(residues_nodes)

    for residue in eachelement(residues)
        # Look the residue name up by attribute name instead of relying on attribute order
        haskey(residue, "name") || continue
        resname = residue["name"]
        resname in resname_list || continue

        atoms = atomname_list[resname_list .== resname]
        for element_residue in eachelement(residue)
            # Children without a `name` attribute (e.g. <Bond atomName1=... />) carry no charge
            haskey(element_residue, "name") || continue
            atomname = element_residue["name"]
            atomname in atoms || continue

            q = element_residue["charge"]
            charge[(atomname_list .== atomname) .& (resname_list .== resname)] .= parse(Float64, q)
        end
    end
    return charge
end

In [None]:
# Charges of the force field used in each simulation window (sim_i/sim_i.xml).
ff_charge_k = Array{Float64}[input_ff(joinpath(ff_dir, "sim_$(i)/sim_$(i).xml")) for i in 1:njobs]

# Charges of the target (reference) force field.
ff_charge_target = input_ff(target_ff_filepath)

In [None]:
# List every atom as "<index> <atom name>".
for (i, name) in enumerate(atomname_list)
    println(i, " ", name)
end

In [None]:
# Mean of each window's distance matrix (over all of its entries), one value per window.
mean.(distance_k)

In [None]:
# Coulomb pair energy without a cutoff: q1 * q2 / (4π ε0 r).
# NOTE(review): no Coulomb constant is applied. With `epsilon_0 = 1.0` the result is in
# charge²/length units rather than an energy unit, yet it is later multiplied by `beta`
# (built from KB_kcalpermol). Confirm whether a factor such as ≈332.06 kcal·Å/(mol·e²)
# (≈138.935 kJ·nm/(mol·e²)) is intentionally omitted.
compute_coulomb_interaction_without_cutoff(q1, q2, epsilon_0, r) = 1 / (4 * π * epsilon_0) * q1 * q2 / r

# The same interaction for a 1-4 pair, scaled by `coulomb14scale`.
compute_coulomb_interaction_14pair_without_cutoff(q1, q2, epsilon_0, r, coulomb14scale) =
    compute_coulomb_interaction_without_cutoff(q1, q2, epsilon_0, r) * coulomb14scale

In [None]:
"""
    compute_colomb_interaction(charge_array, nonbonded_pair_distance, bonded_14pair_distance)

Total Coulomb energy of one configuration. It is the sum over all `nonbonded` pairs, with
distances taken from `nonbonded_pair_distance`, plus the scaled sum over all `bonded_1_4`
pairs, with distances from `bonded_14pair_distance`. Element `p` of each distance vector
belongs to pair `p` of the corresponding global list. Reads the globals `nonbonded`,
`bonded_1_4`, `epsilon_0` and `coulomb14scale`.
"""
function compute_colomb_interaction(charge_array, nonbonded_pair_distance, bonded_14pair_distance)
    energy = 0.0
    for (p, (a, b)) in enumerate(nonbonded)
        energy += compute_coulomb_interaction_without_cutoff(
            charge_array[a], charge_array[b], epsilon_0, nonbonded_pair_distance[p])
    end
    for (p, (a, b)) in enumerate(bonded_1_4)
        energy += compute_coulomb_interaction_14pair_without_cutoff(
            charge_array[a], charge_array[b], epsilon_0, bonded_14pair_distance[p], coulomb14scale)
    end
    return energy
end

In [None]:
# compute u_kl: reduced energy (beta * Coulomb energy) of every stored frame of window k,
# evaluated with the charges of window l.
K = njobs
N_k = Array{Int}(undef, K)
for k in 1:K
    # Frames of window k = rows of its per-frame distance map. `length(distance_k[k])` would
    # count frames × tracked atom pairs and overrun the rows once more than one pair is tracked.
    N_k[k] = size(nonbonded_distancemap_k[k], 1)
end
# NOTE(review): assumes KB_kcalpermol is Boltzmann's constant in kcal/(mol·K), giving T = 300 K.
KBT = KB_kcalpermol * 300
beta = Float64(1.0/(KBT))

u_kl = Array{Array{Float64}}(undef, (K, K))
for k in 1:K
    for l in 1:K
        u_kl[k, l] = map(i -> beta * compute_colomb_interaction(ff_charge_k[l], nonbonded_distancemap_k[k][i, :],
                bonded_14pair_distancemap_k[k][i, :]), 1:N_k[k])
    end
end

In [None]:
# Solve MBAR on the reduced-energy matrix `u_kl` and convert the result to Float64.
# NOTE(review): assumes MDToolbox.mbar returns one (dimensionless) free energy per window,
# as consumed by `_mbar_weight` — confirm.
f_k = Float64.(MDToolbox.mbar(u_kl))

In [None]:
# Inspect the nonbonded distance map of window 1 (rows are indexed per frame and columns per
# nonbonded pair, as used when building u_kl).
nonbonded_distancemap_k[1]

In [None]:
"""
    compute_u_k_cpu(beta, nonbonded_distancemap_k, bonded_14pair_distancemap_k, charge_target)

Reduced target energies: for each window `k` and each stored frame `n` (row `n` of the
window's distance maps), `beta * compute_colomb_interaction(charge_target, …)`.
Returns a `Vector{Vector{T}}` with one vector per window.
"""
function compute_u_k_cpu(beta::T, nonbonded_distancemap_k, bonded_14pair_distancemap_k, charge_target) where {T}
    nwindow = length(nonbonded_distancemap_k)
    u_k = Vector{Vector{T}}(undef, nwindow)
    for k in 1:nwindow
        nb = nonbonded_distancemap_k[k]
        b14 = bonded_14pair_distancemap_k[k]
        # views avoid copying each frame's row; the energy only reads the entries
        u_k[k] = T[beta * compute_colomb_interaction(charge_target, view(nb, n, :), view(b14, n, :))
                   for n in axes(nb, 1)]
    end
    return u_k
end

# Reverse rule for `compute_u_k_cpu`. It is differentiable w.r.t. `charge_target` only.
#
# For a pair (a, b) at distance r the reduced energy is beta * s * q_a * q_b / (4π ε0 r),
# where s = 1 for nonbonded pairs and s = `coulomb14scale` for 1-4 pairs. Hence
#   ∂/∂q_a = beta * s * q_b / (4π ε0 r),   ∂/∂q_b = beta * s * q_a / (4π ε0 r).
# Writing the derivative this way, rather than as energy / q, stays finite when a charge
# is exactly 0.
# Like the primal, it reads the globals `nonbonded`, `bonded_1_4`, `epsilon_0` and
# `coulomb14scale`.
function ChainRulesCore.rrule(::typeof(compute_u_k_cpu), beta::T, nonbonded_distancemap_k, 
        bonded_14pair_distancemap_k, charge_target) where {T}
    K = length(nonbonded_distancemap_k)
    u_k = compute_u_k_cpu(beta, nonbonded_distancemap_k, bonded_14pair_distancemap_k, charge_target)

    function compute_u_k_pullback(dU_raw)
        dU = unthunk(dU_raw)
        dU isa AbstractZero && return NoTangent(), NoTangent(), NoTangent(), NoTangent(), ZeroTangent()

        dq = zero(charge_target)
        for k in 1:K
            dU_k = dU[k]
            # a missing cotangent for a window contributes nothing
            (dU_k === nothing || dU_k isa AbstractZero) && continue
            nb = nonbonded_distancemap_k[k]
            b14 = bonded_14pair_distancemap_k[k]
            for n in 1:size(nb, 1)
                g = beta * dU_k[n]
                # index the matrices directly; slicing a whole row for every pair costs O(npair^2) per frame
                for (i, (a1, a2)) in enumerate(nonbonded)
                    c = g / (4 * π * epsilon_0 * nb[n, i])
                    dq[a1] += c * charge_target[a2]
                    dq[a2] += c * charge_target[a1]
                end
                for (i, (a1, a2)) in enumerate(bonded_1_4)
                    c = g * coulomb14scale / (4 * π * epsilon_0 * b14[n, i])
                    dq[a1] += c * charge_target[a2]
                    dq[a2] += c * charge_target[a1]
                end
            end
        end
        return NoTangent(), NoTangent(), NoTangent(), NoTangent(), dq
    end

    return u_k, compute_u_k_pullback
end

In [None]:
"""
    _mbar_weight(u_kl, f_k, u_k=nothing)

MBAR reweighting weights of every stored sample for a target state.

`u_kl[k, l]` holds the reduced energies, evaluated with the parameters of window `l`, of
the samples stored for window `k`. `f_k` are the window free energies. `u_k[k]` holds the
target-state reduced energies of the samples of window `k`; `nothing` means a target
energy of zero.

Returns `w_k`, where `w_k[k]` has one weight per sample of window `k`. The weights are
`exp(log_w - s)`, where `s = MDToolbox.logsumexp_1d` over all samples, so presumably they
sum to one overall. Log-weights come from `MDToolbox.mbar_log_wi_jn`.

Throws `DimensionMismatch` if `u_k[k]` does not have one entry per sample of window `k`.
"""
function _mbar_weight(u_kl, f_k, u_k=nothing)
    # K: number of umbrella windows
    K = size(u_kl, 1)

    # N_k: number of data in k-th umbrella window
    N_k = zeros(Int64, K)
    for k = 1:K
        N_k[k] = length(u_kl[k, 1])
    end
    N_max = maximum(N_k)

    # conversion from array of array (u_kl) to array (u_kln), zero-padded up to N_max
    u_kln = zeros(Float64, K, K, N_max)
    for k = 1:K
        for l = 1:K
            u_kln[k, l, 1:N_k[k]] .= u_kl[k, l]
        end
    end

    # target-state reduced energies u_kn (rows stay zero when u_k === nothing)
    u_kn = zeros(Float64, K, N_max)
    if u_k !== nothing
        for k = 1:K
            if length(u_k[k]) != N_k[k]
                throw(DimensionMismatch("u_k[$k] has $(length(u_k[k])) entries but window $k has $(N_k[k]) samples in u_kl"))
            end
            u_kn[k, 1:N_k[k]] .= u_k[k]
        end
    end

    # mask of the real (non-padding) slots of the K × N_max layout
    idx = falses(K, N_max)
    for k = 1:K
        idx[k, 1:N_k[k]] .= true
    end

    log_w_kn = MDToolbox.mbar_log_wi_jn(N_k, f_k, u_kln, u_kn, K, N_max)
    log_w_n  = log_w_kn[idx]

    # normalise over all real samples of all windows
    s = MDToolbox.logsumexp_1d(log_w_n)
    w_k = Vector{Vector{Float64}}(undef, K)
    for k = 1:K
        w_k[k] = exp.((log_w_kn[k, 1:N_k[k]] .- s))
    end

    return w_k
end

# Reverse rule for `_mbar_weight`. It is differentiable w.r.t. the target energies `u_k` only.
#
# The weights are a softmax of log w_n. Assuming, as in MBAR, that log w_n = -u_n plus terms
# independent of u_k:
#   ∂w_m/∂u_n = w_m (w_n - δ_mn)
# so the pulled-back gradient is
#   du_n = Σ_m dw_m ∂w_m/∂u_n = w_n (Σ_m dw_m w_m - dw_n).
# This costs O(N) rather than a double loop over all O(N^2) sample pairs.
# The forward pass calls `_mbar_weight` itself, so primal and rule always agree.
function ChainRulesCore.rrule(::typeof(_mbar_weight), u_kl, f_k, u_k)
    w_k = _mbar_weight(u_kl, f_k, u_k)
    function mbar_weight_pullback(dw_raw)
        dw = unthunk(dw_raw)
        dw isa AbstractZero && return NoTangent(), ZeroTangent(), NoTangent(), ZeroTangent()
        # a missing per-window cotangent means "no gradient" for that window
        dw_k = [(dw[k] === nothing || dw[k] isa AbstractZero) ? zero(w_k[k]) : dw[k]
                for k in eachindex(w_k)]
        total = sum(sum(dw_k[k] .* w_k[k]) for k in eachindex(w_k))
        du_k = [w_k[k] .* (total .- dw_k[k]) for k in eachindex(w_k)]
        return NoTangent(), ZeroTangent(), NoTangent(), du_k
    end
    return w_k, mbar_weight_pullback
end

In [None]:
"""
    compute_average_property(A_k, nonbonded_distancemap_k, bonded_14pair_distancemap_k, f_k, u_kl, beta, charge_target)

MBAR-reweighted average of the per-sample observable `A_k` under the charges
`charge_target`. It computes the target reduced energies of every stored frame, turns them
into MBAR weights with `_mbar_weight`, and returns `Σ_k sum(w_k[k] .* A_k[k])`.
"""
function compute_average_property(A_k, nonbonded_distancemap_k, bonded_14pair_distancemap_k, f_k, u_kl, beta, charge_target)
    nwindow = size(A_k, 1)
    target_u = compute_u_k_cpu(beta, nonbonded_distancemap_k, bonded_14pair_distancemap_k, charge_target)
    weights = _mbar_weight(u_kl, f_k, target_u)

    average = 0.0
    for s in 1:nwindow
        average += sum(weights[s] .* A_k[s])
    end
    return average
end

In [None]:
# Training input: the per-window distance trajectories (the observable averaged by MBAR).
X_train = distance_k # distance trajectories
# Training target: the mean distance of the target simulation.
y_train = mean(distance_target) # mean of the target distance

# Trainable model whose single parameter array `P` is a per-atom partial-charge vector
# (it is constructed from charge vectors such as `ff_charge_target` / `ff_charge_k[1]`).
struct Energy{T<:AbstractArray}
    P::T # per-atom partial charges being optimised
end

# Register `P` as the trainable field of `Energy` for Flux/Optimisers.
Flux.@functor Energy (P,)

# Forward pass: MBAR-reweighted average of the observable `X` under the charges `m.P`.
# It uses the precomputed distance maps, window free energies `f_k`, reduced energies `u_kl`
# and `beta` (all globals).
(m::Energy)(X::AbstractArray) = compute_average_property(X, nonbonded_distancemap_k,
    bonded_14pair_distancemap_k, f_k, u_kl, beta, m.P)

# Mean squared error between the reweighted average and the target average.
loss(x, y) = Flux.Losses.mse(x, y)

In [None]:
# Sanity check: the loss when the model already uses the target charges
# (presumably close to zero if the reweighting is consistent).
m = Energy(copy(ff_charge_target))
loss(m(X_train), y_train)

In [None]:
# Starting point for training: the charges of window 1, and the loss they give.
m = Energy(copy(ff_charge_k[1]))
loss(m(X_train), y_train)

In [None]:
# Time one gradient of the loss w.r.t. the model (the first call presumably includes
# compilation); `g` holds the gradient for the model's fields.
@time g = gradient(m -> loss(m(X_train), y_train), m)[1]

In [None]:
# Print a one-line progress report; the model argument `m` is accepted but not used.
print_progress(epoch, loss, m) = println("Epoch: $(epoch), loss : $(loss)")

In [None]:
# History of the estimated charge vectors, one entry per training epoch.
charge_estimated_array = Any[]

In [None]:
# Squared deviation of each starting charge from the target charge, printed one per line.
initial_error = Any[(m.P[i] - ff_charge_target[i])^2 for i in 1:length(m.P)]

println("initial error")
foreach(println, initial_error)

In [None]:
"""
    project_charges!(P, resname_list, atomname_list)

Impose the charge constraints on `P` in place and return it:
1. equal charges within each group of equivalent hydrogens (`HH*` in ACE and NME, `HB*` in ALA);
2. zero net charge for each of the residues ACE, ALA and NME (the residue mean is subtracted).
"""
function project_charges!(P, resname_list, atomname_list)
    for (res, pattern) in (("ACE", r"^HH.*"), ("ALA", r"^HB.*"), ("NME", r"^HH.*"))
        index = (resname_list .== res) .& (occursin.(pattern, atomname_list))
        P[index] .= mean(P[index])
    end
    for res in ("ACE", "ALA", "NME")
        index = resname_list .== res
        P[index] .= P[index] .- mean(P[index])
    end
    return P
end

loss_train = []
nepoch = 5
learning_rate = 1e-1
println("Initial loss: $(loss(m(X_train), y_train))")
println("Initial param: $(m.P[1])")

# Start from a feasible point, so the first gradient is taken at constrained charges.
project_charges!(m.P, resname_list, atomname_list)

t = Flux.Optimisers.setup(Adam(learning_rate), m)

# Report progress every `nepoch ÷ 10` epochs (every epoch for short runs). Integer
# arithmetic avoids the float modulus `epoch % (nepoch / 10)`.
print_every = max(nepoch ÷ 10, 1)

@time for epoch in 1:nepoch
    g = gradient(m -> loss(m(X_train), y_train), m)[1]

    Flux.Optimisers.update!(t, m, g)
    # Projected gradient step: re-impose the constraints *after* the update, so the
    # parameters that are evaluated, stored and finally kept all satisfy them.
    project_charges!(m.P, resname_list, atomname_list)

    L = loss(m(X_train), y_train)
    push!(loss_train, L)

    # Store a snapshot. `m.P` is mutated in place (by the projection, and presumably by
    # `update!`), so pushing it directly would make every entry alias the same array.
    push!(charge_estimated_array, copy(m.P))

    if epoch % print_every == 0
        print_progress(epoch, L, m)
    end
end

In [None]:
# Training loss per epoch.
plot(loss_train; legend=nothing, framestyle=:box, linewidth=2, title="Loss",
     xlabel="Epoch", ylabel="Loss")

In [None]:
# Squared deviation of each trained charge from the target charge, printed one per line.
after_train_error = Any[(m.P[i] - ff_charge_target[i])^2 for i in 1:length(m.P)]

println("after train error")
foreach(println, after_train_error)

In [None]:
# Per-atom reduction of the squared charge error (positive = moved closer to the target).
initial_error .- after_train_error

In [None]:
# Total reduction of the squared charge error over all atoms.
sum(initial_error .- after_train_error)

In [None]:
"""
    asymptotic_covariance_matrix(w_k)

Return `Wᵀ (I - W N Wᵀ)⁺ W`. Here `W = hcat(w_k...)` has one column per entry of `w_k`,
so all entries must have the same length. `N` is diagonal, with every diagonal entry equal
to the number of rows of `W`. `⁺` is the Moore–Penrose pseudo-inverse (`pinv`).

NOTE(review): this has the form of the MBAR asymptotic covariance (Shirts & Chodera 2008).
That estimator expects `W[n, k]` = weight of sample n in state k over *all* samples, and
`N = Diag(N_1, …, N_K)` with per-state sample counts. For `w_k` as returned by
`_mbar_weight`, column k holds target-state weights of window k's samples, and every
`N_kk` here is `size(W, 1)`. Confirm this is the intended estimator.
"""
function asymptotic_covariance_matrix(w_k)
    W = hcat(w_k...)
    nrow, ncol = size(W)

    # dense diagonal "sample-count" matrix
    Ncount = Matrix{Float64}(Diagonal(fill(Float64(nrow), ncol)))

    # pseudo-inverse of (I - W N Wᵀ), sandwiched between Wᵀ and W
    inner = Matrix(I, nrow, nrow) - W * Ncount * W'
    return W' * pinv(inner) * W
end

"""
    compute_uncertainty(Σ)

Return the matrix `D` with `D[i, j] = Σ[i, i] - 2Σ[i, j] + Σ[j, j]`, i.e. the variance of
the difference between quantities i and j implied by the covariance matrix `Σ`.
"""
function compute_uncertainty(Σ)
    result = deepcopy(Σ)
    for j in axes(Σ, 2), i in axes(Σ, 1)
        result[i, j] = Σ[i, i] - 2 * Σ[i, j] + Σ[j, j]
    end
    return result
end

In [None]:
# Asymptotic MBAR covariance for the target charges, on a thinned subset of frames
# (presumably thinned to keep the frame-by-frame pseudo-inverse affordable).
cov_stride = 100   # keep every 100th stored frame
u_k = compute_u_k_cpu(beta, nonbonded_distancemap_k, bonded_14pair_distancemap_k, ff_charge_target)
u_k = [u_k[i][1:cov_stride:end] for i in 1:length(u_k)]
# Thin u_kl identically: _mbar_weight needs one target energy per sample of u_kl, and
# thinning only u_k made the lengths disagree.
u_kl_thinned = [u_kl[k, l][1:cov_stride:end] for k in axes(u_kl, 1), l in axes(u_kl, 2)]
w_k = _mbar_weight(u_kl_thinned, f_k, u_k)
@time Σ = asymptotic_covariance_matrix(w_k)
uncertainty = compute_uncertainty(Σ)