In [1]:
import tensorflow as tf
import numpy as np
from utils.options import Options
from utils.utils import *

from networks import RecurrentActorNetwork, RecurrentCriticNetwork
from utils.replay_buffer_trace import ReplayBufferTrace

In [2]:
# Start from the project defaults, then pin the trace/optimisation lengths and the
# state size to small values for this walkthrough.
opt = Options()
opt.agent_params.trace_length = 5
opt.agent_params.opt_length = 5
opt.agent_params.state_dim = 5

# Seed handed to the replay buffer constructor in a later cell.
RANDOM_SEED = 1256
# Size of replay buffer
BUFFER_SIZE = opt.agent_params.rm_size
# NOTE(review): MINIBATCH_SIZE is reassigned to 10 in a later cell before its first use,
# so the value read from the options here never takes effect in this notebook.
MINIBATCH_SIZE = opt.agent_params.batch_size

In [3]:
# Build the recurrent actor network from the options object.
actor = RecurrentActorNetwork(opt)

In [4]:
# Build the recurrent critic. It receives the actor's trainable-variable count as its
# first argument (presumably to locate its own variables after the actor's — confirm in networks).
critic = RecurrentCriticNetwork(actor.get_num_trainable_vars(), opt)

In [5]:
# Trace-based replay buffer, configured with the buffer size, trace length, save directory and seed.
replay_buffer = ReplayBufferTrace(BUFFER_SIZE, opt.agent_params.trace_length, opt.save_dir, RANDOM_SEED)

In [6]:
# Fixed dummy goal (2 values) and state (5 values); the same pair is reused for every
# transition written to the replay buffer later in the notebook.
goal = [0.4, 0.6]
state = [0.1,0.1,50,100,0]

# Initial (c, m) recurrent-state pair for a batch of one sample.
# NOTE(review): mode='g' is opaque from here (presumably a Gaussian draw) — confirm in state_initialiser.
init_actor_hidden_c = state_initialiser(shape=(1,actor.rnn_size),mode='g')
init_actor_hidden_m = state_initialiser(shape=(1,actor.rnn_size),mode='g')
actor_init_hidden_cm = (init_actor_hidden_c, init_actor_hidden_m)

# Add two leading length-1 axes so the single state/goal has the same rank as the
# (batch, step, feature) arrays built for training further down.
input_s = np.reshape(state, (1, 1, actor.s_dim))
input_g = np.reshape(goal, (1, 1, actor.goal_dim))

## Trying action selection

In [8]:
# Op that initialises every TF global variable defined so far.
init_op = tf.global_variables_initializer()

# Run one forward pass of the actor inside a short-lived session.
# NOTE(review): the session is closed when the with-block exits, yet actor and critic
# still hold it via set_session until a later cell hands them a new one.
with tf.Session() as sess:
    sess.run(init_op)
    actor.set_session(sess)
    critic.set_session(sess)
    # Initialize target network weights
    actor.update_target_network()
    critic.update_target_network()
    # predict returns the chosen action together with the recurrent state after the step.
    # NOTE(review): `action` is overwritten with a hard-coded list in the next code cell.
    action, actor_last_hidden_cm = actor.predict(input_s, input_g, actor_init_hidden_cm)

## Trying training

In [7]:
# Constant dummy contents shared by every transition written below.
action = [-1, 2, 5]
reward = 10
terminal = False

# Write 20 identical episodes of 5 transitions each into the buffer.
# Field order inside a transition: state, goal, action, reward, terminal, next state
# (the unchanged `state` is reused as the next state).
for episode_idx in range(20):
    episode = [
        [
            np.reshape(state, (actor.s_dim,)),
            np.reshape(goal, (actor.goal_dim,)),
            np.reshape(action, (actor.a_dim,)),
            reward,
            terminal,
            np.reshape(state, (actor.s_dim,)),
        ]
        for step in range(5)
    ]
    replay_buffer.add(episode)

In [8]:
# NOTE(review): this overrides the MINIBATCH_SIZE read from the options in the config cell;
# every later cell uses 10.
MINIBATCH_SIZE = 10
# Draw a batch of traces; the following cell indexes the result as minibatch[sample, step, field].
minibatch = replay_buffer.sample_batch(MINIBATCH_SIZE)

In [9]:
def _field_batch(cells, steps, width):
    """Stack the per-transition entries of `cells` into a (MINIBATCH_SIZE, steps, width) array.

    `cells` is a slice of the sampled minibatch holding one transition field
    (for example ``minibatch[:, :, 0]`` for the states).
    """
    return np.stack(cells.ravel()).reshape(MINIBATCH_SIZE, steps, width)


trace_steps = opt.agent_params.trace_length

# Fields 0-4 of every transition, one array per field.
state_trace_batch = _field_batch(minibatch[:, :, 0], trace_steps, actor.s_dim)
print('state_batch ', state_trace_batch.shape)

goal_trace_batch = _field_batch(minibatch[:, :, 1], trace_steps, actor.goal_dim)
print('goal_batch ', goal_trace_batch.shape)

action_trace_batch = _field_batch(minibatch[:, :, 2], trace_steps, actor.a_dim)
print('action_batch ', action_trace_batch.shape)

reward_trace_batch = _field_batch(minibatch[:, :, 3], trace_steps, 1)
print('reward_batch ', reward_trace_batch.shape)

done_trace_batch = _field_batch(minibatch[:, :, 4], trace_steps, 1)
print('done_batch ', done_trace_batch.shape)

# Append the next state of the last transition, giving a trace one step longer.
next_state_batch = _field_batch(minibatch[:, -1, 5], 1, actor.s_dim)
next_state_trace_batch = np.concatenate([state_trace_batch, next_state_batch], axis=1)

print('next_batch ', next_state_trace_batch.shape)

# Repeat the last goal once so the goal trace matches the extended state trace.
extra_goal_batch = _field_batch(minibatch[:, -1, 1], 1, actor.goal_dim)
extra_goal_trace_batch = np.concatenate([goal_trace_batch, extra_goal_batch], axis=1)

print('extra_goal_batch ', extra_goal_trace_batch.shape)

state_batch  (10, 5, 5)
goal_batch  (10, 5, 2)
action_batch  (10, 5, 3)
reward_batch  (10, 5, 1)
done_batch  (10, 5, 1)
next_batch  (10, 6, 5)
extra_goal_batch  (10, 6, 2)


In [10]:
# Initial (c, m) recurrent-state pair for the actor, one row per minibatch sample.
# Both members of the pair are the same array object.
# NOTE(review): mode='z' is opaque from here (presumably all zeros) — confirm in state_initialiser.
init_actor_hidden_batch = state_initialiser(shape=(MINIBATCH_SIZE, actor.rnn_size), mode='z')
actor_init_h_batch = (init_actor_hidden_batch, init_actor_hidden_batch)

# Same for the critic, sized from the critic's own rnn_size: a later cell reshapes the
# stacked critic states with critic.rnn_size, so sizing them from the actor would break
# as soon as the two networks use different widths.
init_critic_hidden_batch = state_initialiser(shape=(MINIBATCH_SIZE, critic.rnn_size), mode='z')
critic_init_h_batch = (init_critic_hidden_batch, init_critic_hidden_batch)

### First, trace_length < opt_length

In [11]:
# The target networks start from the same initial recurrent states as the online networks
# (these are aliases of the same tuples, not copies).
target_actor_init_h_batch = actor_init_h_batch
target_critic_init_h_batch = critic_init_h_batch
# Number of trailing steps of each trace that take part in the update; here the whole trace.
update_length = opt.agent_params.trace_length

In [12]:
# Fresh initialiser op and a session that stays open across the following cells.
# NOTE(review): this repeats the setup of the action-selection cell and re-initialises all
# variables; `sess` is never closed in the visible cells.
init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()

sess.run(init_op)
actor.set_session(sess)
critic.set_session(sess)
# Initialize target network weights
actor.update_target_network()
critic.update_target_network()

In [13]:
# Shape sanity check.
# NOTE(review): this prints goal_trace_batch (one step shorter than the state trace), whereas
# the next cell actually feeds extra_goal_trace_batch, whose step count matches.
print(next_state_trace_batch.shape, goal_trace_batch.shape)

(10, 6, 5) (10, 5, 2)


In [14]:
# Ask the target actor for an action over the extended (trace + next state) inputs.
# In the recorded run the result has no step axis: one action per sample.
next_action_batch = actor.predict_target(next_state_trace_batch, extra_goal_trace_batch, target_actor_init_h_batch)
print(next_action_batch.shape)

(10, 3)


In [15]:
# Give the target action a step axis and append it to the stored actions, so the action
# trace is as long as the extended state trace.
next_action_trace_batch = np.concatenate([action_trace_batch, np.expand_dims(next_action_batch, axis=1)], axis=1)
print(next_action_trace_batch.shape)

(10, 6, 3)


In [16]:
# Target critic's value for the extended state/goal/action traces; the recorded shape is
# one value per sample.
target_q_batch = critic.predict_target(next_state_trace_batch, extra_goal_trace_batch, next_action_trace_batch, target_critic_init_h_batch)
print(target_q_batch.shape)

(10, 1)


In [17]:
# Zero the bootstrap value for samples whose last transition in the trace is flagged terminal.
target_lastQ_batch_masked = target_q_batch * (1.- done_trace_batch[:,-1])
print(target_lastQ_batch_masked.shape)

(10, 1)


In [18]:
# Per sample: the last update_length rewards (trailing size-1 axis dropped) followed by the
# masked bootstrap value as one extra column.
rQ = np.concatenate([np.squeeze(reward_trace_batch[:,-update_length:],axis=-1), target_lastQ_batch_masked],axis=1)
print(rQ.shape)

(10, 6)


In [19]:
# Cache of discounting matrices keyed by update length. The dict is re-created every time
# this cell runs, so the lookup below always misses here; the lookup-then-build shape is
# kept so the cell mirrors a cached implementation.
discounting_mat_dict = {}
try:
    # If already defined
    discounting_mat = discounting_mat_dict[update_length]
except KeyError:
    # np.float64 replaces the np.float alias, which was removed in NumPy 1.24.
    discounting_mat = np.zeros(shape=(update_length, update_length + 1), dtype=np.float64)
    for i in range(update_length):
        # Row i stays zero before column i (np.zeros already did that) and then holds
        # 2**0, 2**1, ... up to the last column.
        # NOTE(review): base 2 makes the matrix easy to read in the printed output;
        # presumably a stand-in for the real discount factor — confirm.
        discounting_mat[i, i:] = 2 ** np.arange(0., update_length + 1 - i)
    # Transpose to (update_length + 1, update_length) so it can right-multiply rQ.
    discounting_mat = np.transpose(discounting_mat)
    discounting_mat_dict[update_length] = discounting_mat

In [20]:
# Display the cached matrices for visual inspection.
discounting_mat_dict

{5: array([[  1.,   0.,   0.,   0.,   0.],
        [  2.,   1.,   0.,   0.,   0.],
        [  4.,   2.,   1.,   0.,   0.],
        [  8.,   4.,   2.,   1.,   0.],
        [ 16.,   8.,   4.,   2.,   1.],
        [ 32.,  16.,   8.,   4.,   2.]])}

In [21]:
# One target per step: each column of discounting_mat weights the rewards and bootstrap
# value from that step onward. A trailing size-1 axis is added to the result.
y_trace_batch = np.expand_dims(np.matmul(rQ, discounting_mat), axis=-1)
print(y_trace_batch.shape)

(10, 5, 1)


In [22]:
# One critic training step on the sampled traces against the targets computed above.
# Three values come back; the middle one is discarded.
# NOTE(review): assigning to `_` shadows IPython's last-output variable for the rest of the session.
predicted_q_value, _, v_loss = critic.train(state_trace_batch,
                                            goal_trace_batch,
                                            action_trace_batch,
                                            y_trace_batch,
                                            update_length, critic_init_h_batch)
print(predicted_q_value.shape)

(10, 5, 1)


In [23]:
# Unroll both networks one step at a time, recording the recurrent state that is fed INTO
# each step and the action the actor proposes at that step.
#
# The running states live in their own names, so actor_init_h_batch / critic_init_h_batch
# keep their initial values and re-running this cell gives the same result.
actor_h_running = actor_init_h_batch
critic_h_running = critic_init_h_batch

actor_c_steps, actor_m_steps = [], []
critic_c_steps, critic_m_steps = [], []
proposed_action_steps = []

for i in range(update_length):
    # Snapshot the (c, m) state entering step i, adding axis 1 to index the step.
    actor_c_steps.append(np.expand_dims(actor_h_running[0], axis=1))
    actor_m_steps.append(np.expand_dims(actor_h_running[1], axis=1))
    critic_c_steps.append(np.expand_dims(critic_h_running[0], axis=1))
    critic_m_steps.append(np.expand_dims(critic_h_running[1], axis=1))

    # Step-i slices with a length-1 step axis restored.
    state_step = np.expand_dims(state_trace_batch[:, i], 1)
    goal_step = np.expand_dims(goal_trace_batch[:, i], 1)
    action_step = np.expand_dims(action_trace_batch[:, i], 1)

    # The actor returns its proposed action and the state to carry into the next step.
    action_trace_batch_for_gradients, actor_h_running = actor.action_trace(state_step,
                                                                           goal_step,
                                                                           actor_h_running)
    # With mode = 1 the return value is used as the critic's next recurrent state;
    # the stored action (not the actor's proposal) drives the critic here.
    critic_h_running = critic.predict(state_step, goal_step, action_step,
                                      critic_h_running, mode = 1)
    proposed_action_steps.append(action_trace_batch_for_gradients)

# Concatenate once per array instead of re-copying a growing stack on every iteration.
actor_init_h_batch_stack = (np.concatenate(actor_c_steps, axis=1),
                            np.concatenate(actor_m_steps, axis=1))
critic_init_h_batch_stack = (np.concatenate(critic_c_steps, axis=1),
                             np.concatenate(critic_m_steps, axis=1))
action_trace_batch_for_gradients_stack = np.concatenate(proposed_action_steps, axis=1)

# Recurrent states after the final step, kept available under their own names.
actor_last_h_batch = actor_h_running
critic_last_h_batch = critic_h_running

In [24]:
# Merge the sample and step axes: every array becomes (MINIBATCH_SIZE*update_length, 1, dim)
# and every recurrent state becomes (MINIBATCH_SIZE*update_length, rnn_size).
flat_rows = MINIBATCH_SIZE*update_length

state_trace_batch_stack = np.reshape(state_trace_batch, (flat_rows, 1, actor.s_dim))
print('state ', state_trace_batch_stack.shape)

action_trace_batch_stack = np.reshape(action_trace_batch, (flat_rows, 1, actor.a_dim))
print('action ', action_trace_batch_stack.shape)

goal_trace_batch_stack = np.reshape(goal_trace_batch, (flat_rows, 1, actor.goal_dim))
print('goal ', goal_trace_batch_stack.shape)

action_trace_batch_for_gradients_stack = np.reshape(
    action_trace_batch_for_gradients_stack, (flat_rows, 1, actor.a_dim)
)
print('action_trace ', action_trace_batch_for_gradients_stack.shape)

# Both members of each (c, m) pair get the same flattening.
actor_init_h_batch_stack = tuple(
    np.reshape(member, (flat_rows, actor.rnn_size)) for member in actor_init_h_batch_stack
)
print(actor_init_h_batch_stack[0].shape, actor_init_h_batch_stack[1].shape)

critic_init_h_batch_stack = tuple(
    np.reshape(member, (flat_rows, critic.rnn_size)) for member in critic_init_h_batch_stack
)
print(critic_init_h_batch_stack[0].shape, critic_init_h_batch_stack[1].shape)

state  (50, 1, 5)
action  (50, 1, 3)
goal  (50, 1, 2)
action_trace  (50, 1, 3)
(50, 200) (50, 200)
(50, 200) (50, 200)


In [25]:
# Critic's gradient with respect to the actor-proposed actions, evaluated on the flattened
# single-step inputs with their matching recurrent states.
# NOTE(review): the literal 1 is presumably the sequence length of the flattened inputs —
# confirm against critic.action_gradients.
q_gradient_trace_batch = critic.action_gradients(state_trace_batch_stack,
                                                 goal_trace_batch_stack,
                                                 action_trace_batch_for_gradients_stack,
                                                 1,
                                                 critic_init_h_batch_stack)
print(q_gradient_trace_batch.shape)

(50, 1, 3)


In [29]:
# One actor training step driven by the critic's action gradients, on the same flattened
# inputs and the actor's per-step recurrent states.
# NOTE(review): the literal 1 is presumably the sequence length — confirm against actor.train.
actor.train(state_trace_batch_stack,
             goal_trace_batch_stack,
             q_gradient_trace_batch,
             1,
             actor_init_h_batch_stack)


In [30]:
# Refresh both target networks after the training step.
actor.update_target_network()
critic.update_target_network()

In [31]:
# Persist the replay buffer (presumably under the save directory given at construction — confirm).
replay_buffer.save_pickle()

Successfuly saved: RDPG Buffer
Buffer length saved: 20


In [33]:
# Scratch check unrelated to the training walkthrough: `not False` is truthy, so this prints False.
# NOTE(review): candidate for removal from the final notebook.
a = False
if not a:
    print(a)

False


In [35]:
# Scratch check unrelated to the training walkthrough; reuses the name `a` for a list.
# NOTE(review): candidate for removal from the final notebook.
a = [3,5,7,8]
np.mean(a)

5.75