This notebook is based on the following page:

http://gluon.mxnet.io/chapter01_crashcourse/probability.html#Basic-probability-theory

# Basic probability theory

Load the required library.

In [1]:
require 'mxnet'

true

Store the theoretical probability of each face of a fair die in the variable `probabilities`.

In [2]:
probabilities = MXNet::NDArray.ones(6) / 6


[0.166667, 0.166667, 0.166667, 0.166667, 0.166667, 0.166667]
<MXNet::NDArray 6 @cpu(0)>

Now let's roll the die according to these probabilities.
In MXNet, `MXNet::NDArray.sample_multinomial` draws samples from a multinomial distribution.

Let's draw a single sample. All we need to pass is the probability vector.

In [3]:
MXNet::NDArray.sample_multinomial(probabilities)


[3]
<MXNet::NDArray 1 @cpu(0)>
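
For intuition, drawing one value from such a distribution can be sketched in plain Ruby by walking the cumulative probabilities until they exceed a uniform random number. This is only an illustration of the idea, not how MXNet implements `sample_multinomial`; the helper name `sample_categorical` and the seed are assumptions made for this example.

```ruby
# Plain-Ruby sketch of drawing from a categorical distribution.
# `sample_categorical` is a hypothetical helper, not part of MXNet.
def sample_categorical(probabilities, rng = Random.new)
  u = rng.rand            # uniform number in [0, 1)
  cumulative = 0.0
  probabilities.each_with_index do |prob, index|
    cumulative += prob
    return index if u < cumulative
  end
  probabilities.size - 1  # guard against floating-point round-off
end

fair_die = Array.new(6) { 1.0 / 6 }
rng = Random.new(42)      # fixed seed so the run is repeatable
samples = Array.new(10) { sample_categorical(fair_die, rng) }
p samples                 # ten face indices, each between 0 and 5
```

As in the MXNet output above, the result is a zero-based index, so `3` stands for the fourth face of the die.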

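Each call to `MXNet::NDArray.sample_multinomial(probabilities)` returns one realization of the random variable, encoded as a zero-based index from 0 to 5. Collecting the results of repeated calls, for example with `Array.new(10) { ... }`, therefore yields multiple samples from the same distribution. This approach is slow, however.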

`sample_multinomial` can also generate multiple samples in a single call. Simply specify the shape of the resulting array with the `shape:` keyword, as in the following cells.

In [4]:
MXNet::NDArray.sample_multinomial(probabilities, shape: [10])


[3, 4, 5, 3, 5, 3, 5, 2, 3, 3]
<MXNet::NDArray 10 @cpu(0)>

In [5]:
MXNet::NDArray.sample_multinomial(probabilities, shape: [5, 10])


[[2, 2, 1, 5, 0, 5, 1, 2, 2, 4], 
 [4, 3, 2, 3, 2, 5, 5, 0, 2, 0], 
 [3, 0, 2, 4, 5, 4, 0, 5, 5, 5], 
 [2, 4, 4, 2, 3, 4, 4, 0, 4, 3], 
 [3, 0, 3, 5, 4, 3, 0, 2, 2, 1]]
<MXNet::NDArray 5x10 @cpu(0)>

Now let's generate the results of rolling the die 1000 times.

In [6]:
rolls = MXNet::NDArray.sample_multinomial(probabilities, shape: [1000])
nil # returning nil keeps the notebook from displaying the last result

In [7]:
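# counts[face, i] will hold how often `face` has appeared in rolls 0..i
counts = MXNet::NDArray.zeros([6, 1000])
# totals[face] is the running count of `face` so far
totals = MXNet::NDArray.zeros([6])
rolls.each_with_index do |roll, i|
  totals[roll.as_scalar.to_i] += 1
  counts[0..-1, i] = totals
end
nil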

In [8]:
totals / 1000


[0.167, 0.168, 0.175, 0.159, 0.158, 0.173]
<MXNet::NDArray 6 @cpu(0)>

In [9]:
counts


[[0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...], 
 [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ...], 
 [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5, 5, ...], 
 [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, ...], 
 [0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, ...], 
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 3, 3, ...]]
<MXNet::NDArray 6x1000 @cpu(0)>

In [10]:
counts[0..-1, -3..-1]


[[165, 166, 167], 
 [168, 168, 168], 
 [175, 175, 175], 
 [159, 159, 159], 
 [158, 158, 158], 
 [173, 173, 173]]
<MXNet::NDArray 6x3 @cpu(0)>

In [11]:
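# Divide each column of counts by the number of rolls made so far (1..1000)
# to get the running estimate of each face's probability
x = MXNet::NDArray.arange(1000).reshape([1, 1000]) + 1
estimates = counts / x
# Show the estimates after 1, 2, and 101 rolls
IRuby.display estimates[0..-1, 0]
IRuby.display estimates[0..-1, 1]
IRuby.display estimates[0..-1, 100]
nil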


[0, 1, 0, 0, 0, 0]
<MXNet::NDArray 6 @cpu(0)>


[0, 0.5, 0, 0, 0.5, 0]
<MXNet::NDArray 6 @cpu(0)>


[0.19802, 0.158416, 0.178218, 0.188119, 0.128713, 0.148515]
<MXNet::NDArray 6 @cpu(0)>
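
The same running-estimate computation can be sketched without MXNet to show how the relative frequencies settle near 1/6 as the number of rolls grows. This is a minimal plain-Ruby illustration; the seed and the choice of 6000 rolls are arbitrary assumptions, and `Random#rand(6)` stands in for `sample_multinomial` with a fair die.

```ruby
# Plain-Ruby sketch of the running relative-frequency estimate.
rng = Random.new(0)                          # fixed seed, chosen arbitrarily
n_rolls = 6000
rolls = Array.new(n_rolls) { rng.rand(6) }   # fair die, faces indexed 0..5

totals = Array.new(6, 0)
estimates = rolls.each_with_index.map do |face, i|
  totals[face] += 1
  totals.map { |count| count.to_f / (i + 1) }  # estimate after i + 1 rolls
end

p estimates.first  # after one roll: a single entry is 1.0, the rest are 0.0
p estimates.last   # after all rolls: every entry should be close to 1/6
```

After a single roll the estimate is maximally wrong for five of the six faces, which matches the `[0, 1, 0, 0, 0, 0]` output above; only with many rolls do the estimates approach the true probabilities.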

In [12]:
require 'rbplotly'

indices = [*0...estimates.shape[1]]
# One trace per face: the running estimate of P(die = face) against the roll index
traces = (0...6).map do |face|
  { x: indices, y: estimates[face, 0..-1].to_a, name: "Estimated P(die=#{face + 1})" }
end
plot = Plotly::Plot.new(data: traces)
plot.show


In [13]:
# Download and unpack MNIST unless the training files are already present
unless File.exist?('train-images-idx3-ubyte') &&
       File.exist?('train-labels-idx1-ubyte')
  system("wget http://data.mxnet.io/mxnet/data/mnist.zip")
  system("unzip -x mnist.zip")
end

In [14]:
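# Start every count at one so that no estimated probability is exactly zero
ycount = MXNet::NDArray.ones([10])
xcount = MXNet::NDArray.ones([784, 10])  # 784 == 28*28

# For each digit class, count the images and sum their pixel intensities
mnist_iter = MXNet::IO::MNISTIter.new(batch_size: 1)
mnist_iter.each do |batch|
  x = batch.data[0].reshape([784])
  y = batch.label[0].to_i
  ycount[y] += 1
  xcount[0..-1, y] += x
end

# Turn the per-class pixel sums into per-class averages
0.upto(9) do |i|
  xcount[0..-1, i] = xcount[0..-1, i] / ycount[i]
end

# Class prior: the relative frequency of each digit
py = ycount / MXNet::NDArray.sum(ycount)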


[0.0987169, 0.112365, 0.0993001, 0.102183, 0.0973671, 0.0903516, 0.0986336, ...]
<MXNet::NDArray 10 @cpu(0)>
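
The cell above initializes `ycount` and `xcount` with ones rather than zeros. A minimal plain-Ruby sketch with made-up labels shows the effect on the class prior: a class that never occurs in the data still receives a small nonzero probability, which matters later because the logarithm of zero is negative infinity. The toy labels and variable names below are assumptions for illustration only.

```ruby
# Plain-Ruby sketch of starting counts at one instead of zero.
labels = [0, 0, 1, 0, 1]             # made-up labels; class 2 never appears
n_classes = 3

raw = Array.new(n_classes, 0)
smoothed = Array.new(n_classes, 1)   # start from one, as ycount does above
labels.each do |y|
  raw[y] += 1
  smoothed[y] += 1
end

raw_prior = raw.map { |c| c.to_f / raw.sum }
smoothed_prior = smoothed.map { |c| c.to_f / smoothed.sum }

p raw_prior       # → [0.6, 0.4, 0.0]
p smoothed_prior  # → [0.5, 0.375, 0.125]
```

With 60,000 training images the extra count barely changes the estimates, but it keeps every probability strictly positive.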

In [17]:
require 'chunky_png'
require 'base64'

def imshow(ary)
  height, width = ary.shape
  fig = ChunkyPNG::Image.new(width, height, ChunkyPNG::Color::TRANSPARENT)
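  # Rescale to 0..255; guard against a constant image, where max == min
  min = ary.min
  range = ary.max - min
  range = 1.0 if range.zero?
  ary = ((ary - min) / range) * 255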
  0.upto(height - 1) do |i|
    0.upto(width - 1) do |j|
      v = ary[i, j].round
      fig[j, i] = ChunkyPNG::Color.rgba(v, v, v, 255)
    end
  end

  src = 'data:image/png;base64,' + Base64.strict_encode64(fig.to_blob)
  IRuby.display "<img src='#{src}' width='#{width*2}' height='#{height*2}' />", mime: 'text/html'
end

0.upto(9) do |i|
  imshow(xcount[0..-1, i].reshape([28, 28]).to_narray)
end

py


[0.0987169, 0.112365, 0.0993001, 0.102183, 0.0973671, 0.0903516, 0.0986336, ...]
<MXNet::NDArray 10 @cpu(0)>

In [18]:
# Work with logarithms so that the per-pixel terms are added rather than multiplied
log_xcount = MXNet::NDArray.log(xcount)
log_xcount_neg = MXNet::NDArray.log(1 - xcount)
log_py = MXNet::NDArray.log(py)

mnist_iter = MXNet::IO::MNISTIter.new(
  image: 't10k-images-idx3-ubyte',
  label: 't10k-labels-idx1-ubyte',
  batch_size: 1
)
mnist_iter.each_with_index do |batch, batch_i|
  x = batch.data[0].reshape([784])
  y = batch.label[0].to_i

  # Unnormalized log posterior for each digit: log prior plus per-pixel log-likelihood
  log_px = log_py.dup
  0.upto(9) do |i|
    log_px[i] += MXNet::NDArray.dot(log_xcount[0..-1, i], x) +
                   MXNet::NDArray.dot(log_xcount_neg[0..-1, i], 1-x)
  end
  # Subtract the maximum before exponentiating to avoid underflow, then normalize
  log_px -= MXNet::NDArray.max(log_px)
  px = MXNet::NDArray.exp(log_px).to_narray
  px /= px.sum

  imshow(x.reshape([28, 28]).to_narray)
  puts px.to_a.map.with_index {|v, i| "[#{i}] %0.3f" % v }.join(', ')
  break if batch_i == 10  # stop after the first 11 test images
end

[0] 0.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 1.000, [7] 0.000, [8] 0.000, [9] 0.000


[0] 0.000, [1] 1.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.000


[0] 0.999, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.001, [7] 0.000, [8] 0.000, [9] 0.000


[0] 1.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.000


[0] 0.000, [1] 0.000, [2] 0.000, [3] 1.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.000


[0] 0.000, [1] 0.979, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.021, [9] 0.000


[0] 0.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.885, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.115


[0] 0.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 1.000, [9] 0.000


[0] 1.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.000


[0] 0.000, [1] 0.000, [2] 0.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 1.000


[0] 0.000, [1] 0.000, [2] 1.000, [3] 0.000, [4] 0.000, [5] 0.000, [6] 0.000, [7] 0.000, [8] 0.000, [9] 0.000
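
The classification cell subtracts `MXNet::NDArray.max(log_px)` before calling `exp`. The plain-Ruby sketch below illustrates why, using made-up log scores of a magnitude that a sum over 784 pixels can plausibly reach: exponentiating them directly underflows to zero, while shifting by the maximum first leaves the normalized result unchanged and computable. The helper name `normalize_log_scores` and the example values are assumptions for illustration.

```ruby
# Plain-Ruby sketch of normalizing log scores into probabilities.
# `normalize_log_scores` is a hypothetical helper name.
def normalize_log_scores(log_scores)
  max = log_scores.max
  shifted = log_scores.map { |s| Math.exp(s - max) }  # largest term becomes exp(0) = 1
  total = shifted.sum
  shifted.map { |v| v / total }
end

log_scores = [-1050.0, -1052.0, -1060.0]  # made-up values for illustration

naive = log_scores.map { |s| Math.exp(s) }
p naive                                   # → [0.0, 0.0, 0.0] because exp underflows
probs = normalize_log_scores(log_scores)
p probs                                   # a valid distribution that sums to 1
```

Subtracting a constant from every log score multiplies all unnormalized probabilities by the same factor, so it cancels in the final division.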
