In [1]:
from memory import MemoryBuffer
from env import make_cart_pole, make_cart_pole_c, make_lunar_lander_c
from policies import VAPGTrainer
import tensorflow as tf
from tensorflow.layers import dense
from utils import reshape_train_var, gaussian_likelihood
import numpy as np
import time

In [2]:
# Lunar lander (continuous actions): build the policy and value networks.

env = make_lunar_lander_c()

# Sizes of the flat observation vector and of the action vector.
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Input placeholder: a batch of observations, one row per observation.
obs = tf.placeholder(dtype=tf.float32, shape=[None, obs_dim])

# Policy head: two 32-unit tanh layers followed by a linear layer with one
# output per action dimension.
dense1 = dense(inputs=obs, units=32, activation=tf.tanh)
dense2 = dense(inputs=dense1, units=32, activation=tf.tanh)
act_probs = dense(inputs=dense2, units=act_dim)

# Value head: a separate 32-unit tanh layer on the same input, followed by a
# single linear output.
v_dense1 = dense(inputs=obs, units=32, activation=tf.tanh)
value = dense(inputs=v_dense1, units=1)

# act_type='c' presumably selects the continuous-action variant of the
# trainer -- confirm in policies.py.
network = VAPGTrainer(obs, act_probs, value, act_type='c')

In [3]:
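# Cart pole network (discrete actions): alternative to the lunar lander setup above; uncomment this cell to use it instead.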

# env = make_cart_pole()

# obs = tf.placeholder(tf.float32, shape=[None, 4])
# dense1 = dense(obs, 32, activation=tf.tanh)
# dense2 = dense(dense1, 32, activation=tf.tanh)
# act_probs = dense(dense2, 2)
# softmax_probs = tf.nn.softmax(act_probs)

# v_dense1 = dense(obs, 32, activation=tf.tanh)
# v_dense2 = dense(v_dense1, 32, activation=tf.tanh)
# value = dense(v_dense2, 1)

# network = VAPGTrainer(obs, softmax_probs, value, act_type='d')

In [9]:
# Training hyperparameters used by the training loop.
n_episodes = 1_000_000  # upper bound on the number of episodes to run
max_steps = 300         # cap on environment steps within one episode
update_freq = 64        # run a training update every this many episodes
print_freq = 5          # print a progress line every this many updates

# Buffer that the training loop records rollouts into.
mb = MemoryBuffer()

In [None]:
# Training loop: collect rollouts with the current policy and run a training
# update every `update_freq` episodes.

render = True        # draw the environment while collecting rollouts
render_delay = 0.02  # seconds to pause after each rendered frame

all_rewards = []  # total reward of every finished episode, in order

for episode in range(n_episodes):
    ep_reward = 0

    mb.start_rollout()
    # The current observation is called `state`, not `obs`: `obs` is the
    # TensorFlow input placeholder created in the network cell, and rebinding
    # it here would leave the kernel with a numpy array under that name.
    state = env.reset()
    for step in range(max_steps):
        state = state.squeeze()
        act = network.gen_act(state)

        # `info` rather than `_`, so IPython's last-output variable is kept.
        state_next, rew, done, info = env.step(act)
        ep_reward += rew

        if render:
            env.render()
            time.sleep(render_delay)

        mb.record(state, act, rew)
        state = state_next

        if done:
            break

    all_rewards.append(ep_reward)

    # Episode 0 is skipped, so episodes 0..update_freq are collected before
    # the first update and `update_freq` episodes before each later one.
    # Presumably mb.to_data() returns everything recorded since the previous
    # call -- confirm in memory.py.
    if episode % update_freq == 0 and episode != 0:
        batch = mb.to_data()
        network.train(batch)

        # Report the mean episode reward over the episodes since the last report.
        report_window = update_freq * print_freq
        if episode % report_window == 0:
            mean_reward = np.mean(all_rewards[-report_window:])
            print(f'Update #{episode // update_freq}, Reward: {mean_reward}')
        