### 座標とポテンシャルエネルギーを訓練データとして、エネルギー関数のパラメータ($\epsilon$と$\sigma$)を学習する

### 必要なパッケージを読み込む

In [1]:
using Flux, Plots
using ChainRulesCore
using Statistics
using Random

### ポテンシャルエネルギー関数の定義

In [2]:
"""
    compute_energy(coords1, coords2, epsilon, sigma)

Return the pair potential energy `epsilon * ((sigma/r)^12 - 2 * (sigma/r)^6)`,
where `r` is the Euclidean distance between `coords1` and `coords2`.
"""
function compute_energy(coords1, coords2, epsilon, sigma)
    displacement = coords1 .- coords2
    r = sqrt(sum(displacement .^ 2))
    ratio = sigma / r
    return epsilon * (ratio^12 - 2 * ratio^6)
end

compute_energy (generic function with 1 method)

In [3]:
# Hand-written reverse-mode rule for `compute_energy`.
#
# Returns the energy `U` together with a pullback that maps the incoming
# cotangent `dU` to the cotangents of (function, coords1, coords2, epsilon, sigma).
# With r = |coords1 - coords2| and U = ε((σ/r)^12 - 2(σ/r)^6):
#   ∂U/∂coords1 = -12ε(σ^12/r^14 - σ^6/r^8) * (coords1 - coords2)
#   ∂U/∂coords2 = -∂U/∂coords1
#   ∂U/∂ε       = (σ/r)^12 - 2(σ/r)^6
#   ∂U/∂σ       = 12ε(σ^11/r^12 - σ^5/r^6)
function ChainRulesCore.rrule(::typeof(compute_energy), coords1, coords2, epsilon, sigma)
    dx = coords1 .- coords2
    r = sqrt(sum(dx.^2))
    U = epsilon * ((sigma / r)^12 - 2 * (sigma / r)^6)

    function compute_energy_pullback(dU)
        # Chain rule through r: dr/dcoords1 = dx / r, hence the extra power of r.
        dcoords1 = -12 * epsilon * ((sigma^12 / r^14) - (sigma^6 / r^8)) .* dx .* dU
        # The energy depends only on coords1 - coords2, so the gradients are opposite.
        dcoords2 = - dcoords1
        depsilon = ((sigma / r)^12 - 2 * (sigma / r)^6) * dU
        dsigma = 12 * epsilon * (sigma^11 / r^12 - sigma^5 / r^6) * dU
        # The function object itself carries no differentiable state.
        return NoTangent(), dcoords1, dcoords2, depsilon, dsigma
    end
    return U, compute_energy_pullback
end

### ポテンシャルエネルギー関数を持つFluxのカスタムレイヤの定義

In [4]:
# Flux layer wrapping the potential energy function.
# `P` holds the trainable parameters: P[1] = epsilon, P[2] = sigma.
# The field type is a type parameter so that the struct stores a concretely
# typed array instead of an abstractly typed (boxed) field.
struct Energy{T<:AbstractArray}
    P::T
end

# Register `P` as the only trainable field.
Flux.@functor Energy (P,)

# Evaluate the energy for a pair of coordinates given as `(coords1, coords2)`.
(m::Energy)(coords) = compute_energy(coords..., m.P[1], m.P[2])

In [5]:
# Check that the custom layer can evaluate the potential energy:
# two particles 3 apart along x, with epsilon = sigma = 1.
m = Energy([1.0, 1.0])
coords = (zeros(3), [3.0, 0.0, 0.0])
m(coords)

-0.0027416025485425466

### 訓練データの作成

In [6]:
# Ground-truth parameters used to generate the training labels.
epsilon_true = 0.3
sigma_true = 3.0
nframe = 100

seed_value = 1234  # fix the RNG seed so the data set is reproducible
Random.seed!(seed_value)

# Concretely typed containers (instead of `[]`, which is a Vector{Any}).
X_train = Tuple{Vector{Float64},Vector{Float64}}[]
y_train = Float64[]
sizehint!(X_train, nframe)
sizehint!(y_train, nframe)
for _ in 1:nframe
    # Draw coords1 before coords2 so the random stream is consumed in a fixed order.
    coords1 = randn(3)
    coords2 = randn(3)
    push!(X_train, (coords1, coords2))
    push!(y_train, compute_energy(coords1, coords2, epsilon_true, sigma_true))
end

In [7]:
# Inspect the magnitudes of the training energies, largest first.
sort(map(abs, y_train); rev=true)

100-element Vector{Float64}:
     3.7393382587199235e8
     1.710397817833069e8
     2.644959624931498e7
     2.3602367950986173e7
     5.063294248546081e6
     2.428080873548686e6
     2.1500470102530373e6
 45153.00495316633
 19534.009087515184
 12516.534868073755
  9865.222442998134
  7021.695689971635
  6085.626957934721
     ⋮
     0.14732884760995343
     0.13949352355703026
     0.12880388362989648
     0.12711703086977033
     0.11866779492051184
     0.11537618042994668
     0.1094851473900713
     0.10356352985422172
     0.10041751898189936
     0.09141671781596213
     0.04478933956145604
     0.01244426500365124

### 勾配法で訓練して$\epsilon$と$\sigma$を推定

In [8]:
# Define the loss.
# Start from parameters that are already very close to the ground truth.
m = Energy([0.31, 3.1])

# Mean squared error between the energies predicted by `m` and the labels `y`.
function loss(X, y)
    predicted = m.(X)
    return Flux.Losses.mse(predicted, y)
end

loss(X_train, y_train)

4.8137975983583106e14

In [9]:
# Define the data loader, the trainable parameters and the optimiser.
# `Flux.DataLoader` and `Adam` replace the deprecated `Flux.Data.DataLoader` and `ADAM`.
train_loader = Flux.DataLoader((data=X_train, label=y_train); batchsize=10, shuffle=true)
ps = Flux.params(m)
opt = Adam(1e-3)

Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [10]:
"""
    print_callback(epoch, loss, ps)

Print the epoch number and loss value on one line, followed by the first
entry of `ps` on the next line.
"""
function print_callback(epoch, loss, ps)
    println("Epoch: $epoch, loss: $loss")
    println("param: $(ps[1])")
    return nothing
end

print_callback (generic function with 1 method)

In [None]:
# Training: run mini-batch gradient updates and record the full-data loss
# after every epoch.
loss_train = Float64[]

for epoch in 1:10001
    for (X, y) in train_loader
        gs = gradient(() -> loss(X, y), ps)
        Flux.Optimise.update!(opt, ps, gs)
    end
    # Evaluate the full-data loss once per epoch and reuse it for logging.
    epoch_loss = loss(X_train, y_train)
    push!(loss_train, epoch_loss)
    # Report on epochs 1, 1001, 2001, ...
    if epoch % 1000 == 1
        print_callback(epoch, epoch_loss, ps)
    end
end

Epoch: 1, loss: 3.8378721133948856e14
param: [0.30468614702252633, 3.094685559913893]
Epoch: 1001, loss: 1.291585932157722e6
param: [0.2529149509030946, 3.042979105785854]
Epoch: 2001, loss: 4.864093207660576e8
param: [0.2576468230644103, 3.0381540087353547]
Epoch: 3001, loss: 13009.568431803325
param: [0.26232265627805096, 3.0337384692508658]


In [None]:
# Plot the recorded training loss against the epoch index.
plot(
    loss_train;
    legend=nothing,
    framestyle=:box,
    linewidth=2,
    ylabel="Loss",
    xlabel="Epoch",
)

In [None]:
# First element of the first tracked parameter array
# (presumably epsilon, i.e. m.P[1] — `ps` was built from `Flux.params(m)`).
ps[1][1]

In [None]:
# Number of recorded loss values (one is pushed per completed epoch).
size(loss_train)

In [None]:
# Compare the learned parameters with the ground truth used to build the data.
epsilon_est, sigma_est = m.P

println("Ground-truth epsilon: ", epsilon_true)
println("Estimate epsilon:     ", epsilon_est)

println("Ground-truth sigma: ", sigma_true)
println("Estimate sigma:     ", sigma_est)