In [1]:
# Short aliases for the single-precision types used throughout this notebook.
# Plain `const` bindings replace the `typealias` keyword, which was deprecated in
# Julia 0.6 and no longer parses on Julia 1.0+; `const` works on old and new versions.
const F32 = Float32
const Vec32 = Vector{F32}
const Mat32 = Array{F32,2}
nothing  # keep the cell from echoing the last alias

- $L$ は学習データのサンプル数
- $M$ は出力ベクトルの次元 (分類クラス数 or 目的変数の個数)
- $N$ は入力ベクトルの次元 (説明変数の個数)

とする。この時

- $w$ は $(N, M)$ 行列
- $\mathit{xs}$ は $(L, N)$ 行列
- $\mathit{ts}$ は $(L, M)$ 行列

である。

In [2]:
"""
    train(rho, w, xs, ts)

Return the weights obtained from one batch gradient-descent step on the squared error.

With `xs` an (L, N) input matrix, `w` an (N, M) weight matrix and `ts` an (L, M)
teacher matrix, the step is `w - rho * xs' * (xs * w - ts)`. `rho` is the step size.
`w` itself is not modified; the updated weights are returned as a new value.
"""
function train(rho, w, xs, ts)
    residual = xs * w - ts     # (L, M): prediction minus teacher signal
    gradient = xs' * residual  # (N, L) * (L, M) -> (N, M), the same shape as w
    return w - rho * gradient
end

train (generic function with 1 method)

In [3]:
"""
    error(w, xs, ts)

Return half the sum of squared differences between the predictions `xs * w` and the
teacher matrix `ts`.

`xs` is an (L, N) input matrix, `w` an (N, M) weight matrix and `ts` an (L, M) matrix.
"""
function error(w, xs, ts)
    residual = xs * w - ts      # (L, M): prediction minus teacher signal
    squared = residual .^ 2.0   # `.^` squares each element separately
    return sum(squared) / 2.0
end

error (generic function with 1 method)

In [15]:
# One-hot teacher vector for each class; a class label is an index into this list.
classes = [F32[1, 0],   # class-1
           F32[0, 1]]   # class-2

xs = F32[ 1.0, 0.5, -0.2, -0.4, -1.3, -2.0 ]  # input vector
cs =    [   1,   1,    2,    1,    2,    2 ]  # corresponding classes

# Pick the one-hot vector of every sample, lay them side by side as columns, then
# transpose so that each row of `ts` is the teacher vector of one sample.
ts = reduce(hcat, classes[cs])'

hcat(xs, ts)  # show inputs next to their teacher rows

6×3 Array{Float32,2}:
  1.0  1.0  0.0
  0.5  1.0  0.0
 -0.2  0.0  1.0
 -0.4  1.0  0.0
 -1.3  0.0  1.0
 -2.0  0.0  1.0

In [16]:
# Initial weight matrix: transposing the 2-element vector yields a 1×2 row,
# i.e. one input dimension (N = 1) and two outputs (M = 2).
w = F32[0.2, 0.3]'

1×2 Array{Float32,2}:
 0.2  0.3

In [17]:
# Error of the initial weights; the training loop compares each new error against it.
last_error = error(w, xs, ts)

4.294099807739258

In [18]:
# Batch gradient descent: keep updating `w` until the error stops changing.
rho = 0.1          # step size
max_iters = 10000  # hard cap so the cell can never spin forever

for iter in 1:max_iters
    # Update the notebook-level variables explicitly, so the cell also works when it
    # is run as a plain script (where a top-level loop would otherwise create locals).
    global w, last_error
    w = train(rho, w, xs, ts)
    e = error(w, xs, ts)
    println(e)
    if abs(last_error - e) < 1e-9
        break  # converged: the error no longer changes
    end
    # With too large a step the error grows to Inf and then NaN; the convergence
    # test above is never true for NaN, so such a run has to be stopped here.
    if !isfinite(e) || iter == max_iters
        println("stopped without converging (try a smaller rho)")
        break
    end
    last_error = e
end

2.2403740882873535
2.072387456893921
2.0586469173431396
2.057523012161255
2.057431221008301
2.0574235916137695
2.0574231147766113
2.0574231147766113


In [19]:
# Second experiment: same inputs, but the labels of the third and fourth samples are
# swapped, so every class-1 input is now larger than every class-2 input.
xs = F32[ 1.0, 0.5, -0.2, -0.4, -1.3, -2.0 ]  # input vector
cs =    [   1,   1,    1,    2,    2,    2 ]  # corresponding classes

ts = reduce(hcat, map(i -> classes[i], cs))'  # one teacher row per sample
w = F32[0.2, 0.3]' # weight matrix, reset to the same starting point as before

last_error = error(w, xs, ts)
println(last_error)
rho = 0.1          # step size
max_iters = 10000  # hard cap so the cell can never spin forever

for iter in 1:max_iters
    # Update the notebook-level variables explicitly, so the cell also works when it
    # is run as a plain script (where a top-level loop would otherwise create locals).
    global w, last_error
    w = train(rho, w, xs, ts)
    e = error(w, xs, ts)
    println(e)
    if abs(last_error - e) < 1e-9
        break  # converged: the error no longer changes
    end
    # With too large a step the error grows to Inf and then NaN; the convergence
    # test above is never true for NaN, so such a run has to be stopped here.
    if !isfinite(e) || iter == max_iters
        println("stopped without converging (try a smaller rho)")
        break
    end
    last_error = e
end

4.3140997886657715
2.118553876876831
1.9389673471450806
1.9242777824401855
1.9230761528015137
1.9229780435562134
1.922970175743103
1.9229692220687866
1.9229692220687866
