In [1]:
import tensorflow as tf

tf.reset_default_graph()

In [2]:
batch_size = 2
num_actions = 2
number_of_atoms = 3
Vmax = tf.cast(1, tf.float32)
Vmin = tf.cast(-1, tf.float32)
gamma = 0.99

In [3]:
q_init = tf.constant([
    [
        [0.3, 0.4, 0.3],
        [0.3, 0.4, 0.3]
    ],
    [
        [0.3, 0.4, 0.3],
        [0.3, 0.4, 0.3]
    ]   
], dtype=tf.float32)

target_q = tf.constant([
    [
        [0.1, 0.1, 0.8],
        [0.8, 0.1, 0.1]
    ],
    [
        [0.1, 0.1, 0.8],
        [0.8, 0.1, 0.1]
    ] 
], dtype=tf.float32)

r = tf.constant([-1., -1.], dtype=tf.float32)
a = tf.constant([0, 1], dtype=tf.int32)
done_mask = tf.constant([0., 1.], dtype=tf.float32)


In [4]:
q = tf.get_variable('q', initializer=q_init)
q = tf.nn.softmax(q)

In [5]:
delta_z = (Vmax - Vmin) / float(number_of_atoms - 1)
z = tf.tile(tf.reshape(tf.linspace(Vmin, Vmax, number_of_atoms), (1, -1)), [batch_size, 1])

# update and project support
z_update = tf.reshape(r,  (-1, 1)) + gamma * (1. - tf.reshape(done_mask, (-1, 1))) * z
z_update_clipped = tf.clip_by_value(z_update, Vmin + 1e-7, Vmax + 1e-7)
b = (z_update_clipped - Vmin - 1e-7) / delta_z
u = tf.ceil(b)
l = tf.floor(b)

# argmax_a' Q_target(s', a')
z_batch = tf.tile(tf.reshape(tf.linspace(Vmin, Vmax, number_of_atoms), (1, 1, -1)), [batch_size, num_actions, 1])
a_next = tf.reduce_mean(target_q * z_batch, axis=2)
a_next_max = tf.argmax(a_next, 1)
a_next_idx = tf.stack([tf.range(batch_size), tf.cast(a_next_max, tf.int32)], axis=1)
target_p = tf.gather_nd(target_q, a_next_idx)

# distribute probability masses
l_p = (u - b) * target_p
u_p = (b - l) * target_p

with tf.variable_scope('projection', reuse=False):
    m = tf.get_variable(
        'm', [batch_size, number_of_atoms], dtype=tf.float32,
        initializer=tf.constant_initializer(0),
        trainable=False
    )

projection_op_lst = []
# zero out m
op = tf.assign(m, tf.zeros_like(m))
projection_op_lst.append(op.op)
for batch in range(batch_size):
    op = tf.assign(m[batch], m[batch] +
                   tf.unsorted_segment_sum(l_p[batch], tf.cast(l[batch], tf.int32), number_of_atoms) +
                   tf.unsorted_segment_sum(u_p[batch], tf.cast(u[batch], tf.int32), number_of_atoms)
                   )
    projection_op_lst.append(op.op)

m_prob = tf.nn.softmax(m, axis=1)

# compute loss
a_idx = tf.stack([tf.range(batch_size), tf.cast(a, tf.int32)], axis=1)
predicted_p = tf.gather_nd(q, a_idx)

loss = tf.reduce_sum(m_prob * tf.log(predicted_p), axis=1)
loss = -tf.reduce_mean(loss)


optimizer = tf.train.AdamOptimizer(0.01)
train_op = optimizer.minimize(loss)

In [6]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [7]:
for i in range(1000):
    sess.run(projection_op_lst)
    loss_np = sess.run(loss)
    if i % 100 == 0:    
        print(loss_np)
    sess.run(train_op)

1.0976036
0.8911241
1.0081251
1.0070993
1.0079237
1.0072668
1.0075135
1.0078026
1.0081862
1.0071292


In [8]:
variables_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variables_names)
for k, v in zip(variables_names, values):
    print "Variable: ", k
    print "Shape: ", v.shape

Variable:  q:0
Shape:  (2, 2, 3)


In [9]:
sess.run(q)

array([[[0.29016688, 0.46323124, 0.24660183],
        [0.32204348, 0.35591307, 0.32204348]],

       [[0.32204348, 0.35591307, 0.32204348],
        [0.58076924, 0.20952903, 0.20970175]]], dtype=float32)

In [10]:
sess.run(target_q)

array([[[0.1, 0.1, 0.8],
        [0.8, 0.1, 0.1]],

       [[0.1, 0.1, 0.8],
        [0.8, 0.1, 0.1]]], dtype=float32)