# 概要
分布を扱う TensorFlow のモジュール（`tf.distributions`）があったので、動作をテストする。
以下のことができる。

- 分布を定義（正規分布とか）
- 分布からのデータのサンプリング
- KL Divergence 等の算出
- 上記からのバックプロパゲーション

これらを使うことで、サンプリングや KL Divergence を計算するネットワークを自前で組まなくても済む。

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Trainable parameters (two elements each) of the Normal distribution.
mean_op = tf.Variable([1.0, 2.0], dtype=tf.float32)
# NOTE: var_op is an unconstrained raw parameter. The variance actually used
# is softplus(var_op), so var_op itself is free to go negative while training.
var_op = tf.Variable([0.1, 0.2], dtype=tf.float32)
# Regression targets, fed at run time.
target_op = tf.placeholder(tf.float32, shape=[None])

# softplus keeps the variance positive; its square root is the scale (std).
variance_op = tf.nn.softplus(var_op)
std_op = tf.sqrt(variance_op)
normal_dist_op = tf.distributions.Normal(
    loc=mean_op,
    scale=std_op,
    validate_args=True,
)
sample_op = normal_dist_op.sample()

# Mean squared error between the fed targets and one draw from the
# distribution.
squared_error_op = tf.square(target_op - sample_op)
loss_op = tf.reduce_mean(squared_error_op)

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
opt_op = optimizer.minimize(loss=loss_op)

# 平均と分散を学習する実験

In [3]:
# Everything evaluated per step; fetching opt_op applies one optimizer update.
fetches = [opt_op, loss_op, sample_op, mean_op, var_op]
# The same fixed targets are fed on every step.
feed = {target_op: [3.0, 10.0]}
report = '{} - loss: {:.2e} \tsample: {} \tmean: {} \tvar:{}'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100000):
        _, loss, sample, mean, var = sess.run(fetches, feed_dict=feed)
        # Only report every 1000th step.
        if i % 1000:
            continue
        print(report.format(i, loss, sample, mean, var))

0 - loss: 3.68e+01 	sample: [ 1.07896209  1.63976455] 	mean: [ 1.00100005  2.00099993] 	var:[ 0.101  0.199]
1000 - loss: 2.28e+01 	sample: [ 1.36812723  3.45438623] 	mean: [ 1.80272901  2.96568656] 	var:[-0.26834628  0.06882958]
2000 - loss: 2.60e+01 	sample: [ 3.96813345  2.85415649] 	mean: [ 2.38556147  3.88340735] 	var:[-0.64259261 -0.03675965]
3000 - loss: 1.41e+01 	sample: [ 2.21759558  4.74383354] 	mean: [ 2.74415708  4.7609787 ] 	var:[-1.00791502 -0.13599555]
4000 - loss: 7.32e+00 	sample: [ 3.36986876  6.19222021] 	mean: [ 2.94839597  5.61151409] 	var:[-1.39750075 -0.2886343 ]
5000 - loss: 5.55e+00 	sample: [ 2.61850953  6.69075918] 	mean: [ 2.98395824  6.42000151] 	var:[-1.78333902 -0.40338746]
6000 - loss: 1.57e+00 	sample: [ 3.29949331  8.25481129] 	mean: [ 2.97924566  7.19836807] 	var:[-2.14142156 -0.5622772 ]
7000 - loss: 2.55e+00 	sample: [ 2.88946176  7.74394274] 	mean: [ 2.99523592  7.93425274] 	var:[-2.50681186 -0.78692597]
8000 - loss: 1.69e+00 	sample: [ 2.44166636  

66000 - loss: 8.88e-09 	sample: [ 2.99993682  9.9998827 ] 	mean: [ 3.00009155  9.99994469] 	var:[-19.71409416 -18.8612709 ]
67000 - loss: 1.27e-09 	sample: [  3.00004864  10.00001335] 	mean: [  2.99996018  10.00017262] 	var:[-19.81103134 -19.03563499]
68000 - loss: 5.68e-08 	sample: [  2.99966812  10.00005817] 	mean: [  3.00010729  10.00001907] 	var:[-19.90131378 -19.18191338]
69000 - loss: 7.96e-09 	sample: [  3.00009584  10.00008202] 	mean: [  2.99996042  10.00014782] 	var:[-19.98782349 -19.3136425 ]
70000 - loss: 2.66e-09 	sample: [  3.0000248   10.00006866] 	mean: [ 2.99988961  9.99982452] 	var:[-20.05804253 -19.44574356]
71000 - loss: 2.21e-08 	sample: [ 2.99988437  9.99982452] 	mean: [  3.00008631  10.00002289] 	var:[-20.13790512 -19.55576324]
72000 - loss: 1.84e-08 	sample: [ 2.99985361  9.99987602] 	mean: [ 3.00003958  9.99983025] 	var:[-20.20730972 -19.6564064 ]
73000 - loss: 9.55e-09 	sample: [  2.99993515  10.00012207] 	mean: [  2.99994302  10.00004959] 	var:[-20.26903343 -1

# KL Divergence を学習する実験

In [23]:
# Second trainable Normal distribution, parameterized like the first one:
# var2_op is a raw parameter and softplus(var2_op) is the variance.
mean2_op = tf.Variable([5.0, -1.0], dtype=tf.float32)
var2_op = tf.Variable([0.5, 0.0], dtype=tf.float32)
variance2_op = tf.nn.softplus(var2_op)
std2_op = tf.sqrt(variance2_op)
normal_dist2_op = tf.distributions.Normal(
    loc=mean2_op,
    scale=std2_op,
    validate_args=True,
)

# Element-wise KL(normal_dist_op || normal_dist2_op), averaged to a scalar.
kl_op = tf.distributions.kl_divergence(
    normal_dist_op,
    normal_dist2_op,
    allow_nan_stats=False,
)
kl_loss_op = tf.reduce_mean(kl_op)

# Combined objective: the sample regression loss plus the KL term.
total_loss_op = loss_op + kl_loss_op
total_opt_op = optimizer.minimize(loss=total_loss_op)

In [24]:
# Everything evaluated per step; fetching total_opt_op applies one optimizer
# update on the combined loss.
fetches = [
    total_opt_op,
    total_loss_op,
    loss_op,
    kl_loss_op,
    mean_op,
    var_op,
    mean2_op,
    var2_op,
]
# The same fixed targets are fed on every step.
feed = {target_op: [3.0, 10.0]}
report = '{} - total:{:.2e} \tloss: {:.2e} \tkl:{:.2e} \nmean: {} \tvar:{}\nmean2: {} \tvar2: {}\n'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100000):
        _, total_loss, loss, kl_loss, mean, var, mean2, var2 = sess.run(
            fetches, feed_dict=feed)
        # Only report every 1000th step.
        if i % 1000:
            continue
        print(report.format(
            i, total_loss, loss, kl_loss, mean, var, mean2, var2))

0 - total:3.70e+01 	loss: 2.97e+01 	kl:7.36e+00 
mean: [ 1.00100005  2.00099993] 	var:[ 0.101       0.20100001]
mean2: [ 4.99900007 -0.99900001] 	var2: [ 0.50099999  0.001     ]

1000 - total:3.69e+01 	loss: 3.39e+01 	kl:3.08e+00 
mean: [ 1.82319582  2.98637271] 	var:[-0.13083063  0.0790823 ]
mean2: [ 4.19095755 -0.11104828] 	var2: [ 1.15651107  0.83818573]

2000 - total:1.71e+01 	loss: 1.52e+01 	kl:1.94e+00 
mean: [ 2.41752148  3.91831946] 	var:[-0.31228971  0.06985957]
mean2: [ 3.62600327  0.67551881] 	var2: [ 1.36767101  1.50698698]

3000 - total:1.11e+01 	loss: 9.55e+00 	kl:1.55e+00 
mean: [ 2.79412317  4.8194313 ] 	var:[-0.46688846 -0.01598025]
mean2: [ 3.24204159  1.44327784] 	var2: [ 1.29367638  2.11252928]

4000 - total:1.16e+01 	loss: 1.02e+01 	kl:1.36e+00 
mean: [ 2.96852303  5.67079687] 	var:[-0.60265046 -0.07903674]
mean2: [ 3.04751372  2.21845603] 	var2: [ 0.99480426  2.68903303]

5000 - total:3.06e+00 	loss: 1.86e+00 	kl:1.21e+00 
mean: [ 2.98289037  6.47717953] 	var:[-0.

45000 - total:2.39e-04 	loss: 5.87e-05 	kl:1.80e-04 
mean: [ 2.99984455  9.99953938] 	var:[-9.83926582 -9.41777515]
mean2: [ 3.00003672  9.99952507] 	var2: [-9.83815098 -9.41668224]

46000 - total:1.47e-04 	loss: 1.24e-04 	kl:2.35e-05 
mean: [  2.99974775  10.00005627] 	var:[-9.86430168 -9.48896122]
mean2: [  2.99977922  10.0001297 ] 	var2: [-9.86493015 -9.48805046]

47000 - total:3.95e-05 	loss: 3.44e-05 	kl:5.13e-06 
mean: [ 2.99987864  9.99997044] 	var:[-9.88432217 -9.55307293]
mean2: [  2.99988127  10.00000763] 	var2: [-9.88426113 -9.55132484]

48000 - total:1.73e-05 	loss: 4.67e-06 	kl:1.27e-05 
mean: [  2.99993706  10.00013256] 	var:[-9.90282154 -9.60405159]
mean2: [  2.99993682  10.00018311] 	var2: [-9.90321732 -9.60287094]

49000 - total:2.76e-04 	loss: 2.31e-04 	kl:4.55e-05 
mean: [  3.00006223  10.00022316] 	var:[-9.92465115 -9.6542778 ]
mean2: [  3.00015569  10.00023937] 	var2: [-9.92306137 -9.65484047]

50000 - total:5.49e-05 	loss: 3.51e-06 	kl:5.14e-05 
mean: [  2.9999854

89000 - total:4.31e-05 	loss: 1.06e-06 	kl:4.21e-05 
mean: [  2.99988532  10.00001144] 	var:[-10.15978909 -10.14786339]
mean2: [ 2.99994469  9.99994564] 	var2: [-10.15980148 -10.14873505]

90000 - total:4.81e-05 	loss: 3.97e-05 	kl:8.38e-06 
mean: [  2.99983573  10.00004101] 	var:[-10.16109657 -10.15068817]
mean2: [  2.99982572  10.00004387] 	var2: [-10.16118431 -10.15064716]

91000 - total:7.05e-05 	loss: 9.49e-06 	kl:6.10e-05 
mean: [  2.99977326  10.00022697] 	var:[-10.15915489 -10.15288067]
mean2: [  2.99987483  10.0002346 ] 	var2: [-10.1578455 -10.1538105]

92000 - total:8.19e-05 	loss: 1.03e-06 	kl:8.09e-05 
mean: [  2.99995589  10.00023842] 	var:[-10.15421391 -10.1563797 ]
mean2: [  3.00001121  10.00017834] 	var2: [-10.15601158 -10.15653419]

93000 - total:1.69e-04 	loss: 5.65e-05 	kl:1.13e-04 
mean: [  3.00000882  10.00022793] 	var:[-10.15791607 -10.15495491]
mean2: [  2.99999809  10.00015068] 	var2: [-10.15701962 -10.15538311]

94000 - total:2.14e-04 	loss: 1.81e-04 	kl:3.38e-