In [39]:
using Plots
using Pkg
using NPZ
using LinearAlgebra
using BenchmarkTools

In [40]:
outfile = npzread("data_2.npz")

X0 = outfile["X0"]
X0_min = outfile["X0_min"]
X0_max = outfile["X0_max"]
X1 = outfile["X1"]
X1_min = outfile["X1_min"]
X1_max = outfile["X1_max"]
x_n = outfile["x_n"]
T = outfile["T"]
:OK

:OK

In [61]:
function fit_line_D(x, t)
    # Xの転置行列
    a = x' * x
    # 逆行列
    b = inv(a)
    # ↑* x'
    c = b * x'
    # ↑* t
    w = c * t
    
    return w
end

fit_line_D (generic function with 1 method)

In [111]:
function fit_line_D2(X, T)
    return (X'X)\X'T
end

function fit_line_D3(X, t)
    return X\t
end

fit_line_D3 (generic function with 1 method)

In [131]:
fit_line_D2(Xn, T)

3-element Vector{Float64}:
  0.4560653182304476
  1.0864048983734247
 89.04744652297302

In [132]:
fit_line_D3(Xn, T)

3-element Vector{Float64}:
  0.4560653182304782
  1.0864048983733943
 89.04744652297434

In [83]:
f = [one, identity]
W = fit_line_D2(Xn, T)
norm2(x) = (x, x)
T_ = Xn * W
norm2(T - T_)
X0\T

9.424962344087536

In [71]:
X2 =  ones(Int, size(X0)) 
Xn = [X0 X1 X2]
@time W = fit_line_D(Xn, T)
println("fit_line_D :",W)
@time W = fit_line_D2(Xn, T)
println("fit_line_D2:",W)

  0.000059 seconds (7 allocations: 2.703 KiB)
fit_line_D :[0.456065318230489, 1.086404898373413, 89.04744652297406]
  0.000019 seconds (5 allocations: 656 bytes)
fit_line_D2:[0.4560653182304476, 1.0864048983734247, 89.04744652297302]


numpyの場合  
Wall time: 140 µs （0.000140）  
解析解 ベクトル式： [ 0.45606532  1.0864049  89.04744652]  

In [8]:
function dmse_line2(x0, x1, t, w)
    # 𝑦=w0𝑥0 + w1𝑥1 + w2
    #y = w[0] * x0 + w[1] * x1 + w[2] 
    y = w[1] .* x0 .+ w[2] .* x1 .+ w[3] 
    
    #d_w0 = 2 * np.mean((y - t) * x0)  # w0で偏微分
    #d_w1 = 2 * np.mean((y - t) * x1)  # w1で偏微分
    #d_w2 = 2 * np.mean(y - t)         # w2で偏微分
    d_w0 = 2 .* mean((y .- t) .* x0)  # w0で偏微分
    d_w1 = 2 .* mean((y .- t) .* x1)  # w1で偏微分
    d_w2 = 2 .* mean(y .- t)         # w2で偏微分
    
    return d_w0, d_w1, d_w2
end


dmse_line2 (generic function with 1 method)

In [27]:
function fit_line_num2(x0, x1, t)
    w_init = [1.5, 1, 90]  # 初期値(適当)
    
    # 学習率  0.001 と 0.00001 にも変えてみて動きを確かめてみよう
    alpha = 0.0001  
    
    tau_max = 100000    # ループ最大数
    eps = 0.1    # ループ停止勾配閾値
    
    #w_hist = np.zeros([tau_max, 3])
    w_hist = zeros(tau_max, 3)

    #w_hist[0,:] = w_init
    w_hist[1,:] = w_init

    #for tau in range(1, tau_max):
    tau2 = 1    # for文抜けた後もtauを使う　pythonとスコープが違う
    dmse = 0    # 同上
    for tau in range(1, tau_max)
        tau2 = tau    # tau2更新
        # 勾配を求める
        #dmse = dmse_line2(x0, x1, T, w_hist[tau - 1])
        dmse = dmse_line2(x0, x1, t, w_hist[tau, :])
        # w0, w1, w2を少しずらす
        #w_hist[tau, 0] = w_hist[tau - 1, 0] - alpha * dmse[0]
        #w_hist[tau, 1] = w_hist[tau - 1, 1] - alpha * dmse[1]
        #w_hist[tau, 2] = w_hist[tau - 1, 2] - alpha * dmse[2]
        w_hist[tau + 1, 1] = w_hist[tau, 1] - alpha * dmse[1]
        w_hist[tau + 1, 2] = w_hist[tau, 2] - alpha * dmse[2]
        w_hist[tau + 1, 3] = w_hist[tau, 3] - alpha * dmse[3]
        #print("##### w_hist[tau]: ", w_hist[tau])
        # 終了判定
        #if max(np.absolute(dmse)) < eps:
        if maximum(abs.(dmse)) < eps
            #print("np.absolute(dmse) ", np.absolute(dmse))
            break
        end
    end
    #w0 = w_hist[tau, 0]
    #w1 = w_hist[tau, 1]
    #w2 = w_hist[tau, 2]
    #w_hist = w_hist[:tau, :]
    w0 = w_hist[tau2 + 1, 1]
    w1 = w_hist[tau2 + 1, 2]
    w2 = w_hist[tau2 + 1, 3]
    #println("tau2:", w_hist[tau2, :])
    w_hist = w_hist[tau2, :] 
    return [ w0, w1, w2 ] , dmse, w_hist
end
        

fit_line_num2 (generic function with 1 method)

In [38]:
@time W, dMSE, W_history = fit_line_num2(X0, X1, T)
W

  0.002178 seconds (6.55 k allocations: 3.466 MiB)


3-element Vector{Float64}:
  0.4726079398178704
  1.0658282776968346
 90.01235305096417

numpy  
Wall time: 151 ms  (0.151s)  
勾配法： [0.47260793981787014, 1.0658282776968346, 90.01235305096417]  

In [170]:
@time (Xn'Xn)\Xn'T

  0.000042 seconds (7 allocations: 688 bytes)


3-element Vector{Float64}:
  0.4560653182304476
  1.0864048983734247
 89.04744652297302

In [151]:
AAA = round.( (Xn'Xn)\Xn' , digits=6)

3×16 Matrix{Float64}:
 -0.019035   0.027616   0.004922  …   0.025287  -0.009459   0.018347
  0.020278  -0.01435   -0.018998     -0.00613   -0.003791  -0.007079
 -0.833498   0.492496   1.08815       0.051596   0.417761   0.20649

In [152]:
BBB = round.( pinv(Xn) , digits=6) 

3×16 Matrix{Float64}:
 -0.019035   0.027616   0.004922  …   0.025287  -0.009459   0.018347
  0.020278  -0.01435   -0.018998     -0.00613   -0.003791  -0.007079
 -0.833498   0.492496   1.08815       0.051596   0.417761   0.20649

In [185]:
function aaaaa(x, w)
    y = w[1] .* x[1] .+ w[2] .* x[2] .+ w[3]
end
aaaaa([20, 65, 1], T,  Xn\T)

168.78507128185453

In [182]:
Xn

16×3 Matrix{Float64}:
 15.4256   70.4323  1.0
 23.0081   58.1548  1.0
  5.00286  37.2192  1.0
 12.5583   56.5145  1.0
  8.6689   57.3172  1.0
  7.30846  40.8392  1.0
  9.65651  57.7921  1.0
 13.639    56.9384  1.0
 14.9192   63.0313  1.0
 18.4704   65.6942  1.0
 15.4799   62.3298  1.0
 22.1305   64.9451  1.0
 10.1113   57.7298  1.0
 26.9529   66.8939  1.0
  5.68469  46.6784  1.0
 21.7617   61.0832  1.0