In [110]:
from memory import MTMemoryBuffer
from policies import PPOTrainer
from utils import gaussian_likelihood, reshape_train_var
import tensorflow as tf
from tensorflow.layers import dense, conv2d, max_pooling2d, flatten
import numpy as np
import time
import gym
from env import CartPoleEnv
import time

In [117]:
# Build the environment, a policy head and a separate value head, then hand
# all three tensors to the PPO trainer.
env = CartPoleEnv()

# Batched observation input; the leading None leaves the batch size open.
obs = tf.placeholder(tf.float32, shape=[None]+list(env.observation_space.shape))
# Policy head: two 32-unit tanh layers, then a linear layer with
# env.action_space.shape[0] outputs.
dense1 = dense(obs, 32, activation=tf.tanh)
dense2 = dense(dense1, 32, activation=tf.tanh)
# NOTE(review): no softmax is applied here and act_type='c' is passed below,
# so these are presumably continuous-action outputs rather than
# probabilities -- confirm against PPOTrainer.
act_probs = dense(dense2, env.action_space.shape[0])

# Value head: its own two 32-unit tanh layers (nothing shared with the policy
# head except the input placeholder), ending in a single linear unit.
v_dense1 = dense(obs, 32, activation=tf.tanh)
v_dense2 = dense(v_dense1, 32, activation=tf.tanh)
value = dense(v_dense2, 1)

network = PPOTrainer(obs, act_probs, value, act_type='c')

In [118]:
# Rollout buffer shared by every EnvController and timing cell in this notebook.
mb = MTMemoryBuffer()

In [142]:
# Controller that runs CartPoleEnv rollouts on 8 worker threads, recording into mb.
# NOTE: EnvController is defined in a later cell of this notebook, so that
# cell has to be executed before this one.
ec = EnvController(CartPoleEnv, mb, n_threads=8)

In [140]:
# Baseline: time 1000 episodes run sequentially in the calling thread
# (recorded under agent id 1). perf_counter() is monotonic, so the measured
# interval cannot be skewed by system clock adjustments.
start_time = time.perf_counter()
ec.sim_thread(1, network, 1000)
print(mb.to_data().shape)
time.perf_counter() - start_time

(45540, 3)


31.576586484909058

In [143]:
# Time the same 1000 episodes when they are split across the controller's
# worker threads. perf_counter() is monotonic, unlike time.time().
start_time = time.perf_counter()
ec.sim_episodes(network, 1000)
time.perf_counter() - start_time

(46309, 3)


15.376994848251343

In [150]:
# Benchmark 1000 episodes for every thread count from 1 to 8.
# Note: this rebinds the notebook-level `ec`, which ends up as the 8-thread
# controller once the loop finishes.
for n_workers in range(1, 9):
    ec = EnvController(CartPoleEnv, mb, n_threads=n_workers)
    start_time = time.perf_counter()  # monotonic clock for interval timing
    ec.sim_episodes(network, 1000)
    print(n_workers, ':', time.perf_counter() - start_time)

1 : 30.95636796951294
2 : 13.943068504333496
3 : 14.795561790466309
4 : 14.825214862823486
5 : 15.121448278427124
6 : 15.156155586242676
7 : 15.415870189666748
8 : 15.219146013259888


In [146]:
import threading

class EnvController():
    """Runs policy rollouts in one or more threads and records every step.

    Each worker builds its own environment with ``make_env`` and writes its
    transitions to the shared memory buffer under its own agent id.

    Args:
        make_env: Zero-argument callable returning a new environment that
            exposes ``reset()``, ``step(act)`` and ``render()``.
        memory_buffer: Object exposing ``start_rollout(agent_id)`` and
            ``record(agent_id, obs, act, rew)``. It is shared by all worker
            threads (assumed safe for concurrent use -- TODO confirm).
        n_threads: Number of worker threads started by ``sim_episodes``.
        obs_transform: Optional callable applied to every observation before
            it is given to the network. Defaults to ``obs.squeeze()``.
        act_transform: Optional callable applied to every action before it is
            given to the environment. Defaults to the identity.
    """

    def __init__(self, make_env, memory_buffer, n_threads=1, obs_transform=None, act_transform=None):
        self.make_env = make_env
        self.mb = memory_buffer
        self.n_threads = n_threads
        # Instance attributes shadow the default methods defined below.
        if obs_transform is not None:
            self.obs_transform = obs_transform
        if act_transform is not None:
            self.act_transform = act_transform

    def obs_transform(self, obs):
        """Default observation transform: drop singleton dimensions."""
        return obs.squeeze()

    def act_transform(self, act):
        """Default action transform: pass the action through unchanged."""
        return act

    def set_obs_transform(self, transform_func):
        """Replace the observation transform with ``transform_func(obs)``."""
        self.obs_transform = transform_func

    def set_act_transform(self, transform_func):
        """Replace the action transform with ``transform_func(act)``."""
        self.act_transform = transform_func

    def sim_thread(self, agent_id, network, n_episodes=1, max_steps=200, render=False):
        """Run ``n_episodes`` rollouts in the calling thread.

        Args:
            agent_id: Identifier under which the rollouts are recorded.
            network: Object exposing ``gen_act(obs)``.
            n_episodes: Number of episodes to simulate.
            max_steps: Upper bound on the number of steps per episode.
            render: If true, render each step and sleep 0.02 s after it.
        """
        env = self.make_env()

        for _episode in range(n_episodes):
            self.mb.start_rollout(agent_id)
            obs = env.reset()
            for _step in range(max_steps):
                obs = self.obs_transform(obs)
                act = self.act_transform(network.gen_act(obs))

                obs_next, rew, done, _info = env.step(act)

                if render:
                    env.render()
                    time.sleep(0.02)

                # Record into this controller's own buffer rather than
                # whatever global happens to be called `mb`.
                self.mb.record(agent_id, obs, act, rew)
                obs = obs_next

                if done:
                    break

    def sim_episodes(self, network, n_episodes=1, max_steps=200, render=False):
        """Run ``n_episodes`` rollouts split across ``n_threads`` threads.

        Episodes are divided as evenly as possible; when the split is not
        exact, the lowest-numbered threads run one extra episode each. Thread
        ``i`` records under agent id ``i``. The call blocks until every thread
        has finished.

        Args:
            network: Object exposing ``gen_act(obs)``, shared by all threads.
            n_episodes: Total number of episodes to simulate.
            max_steps: Upper bound on the number of steps per episode.
            render: Forwarded to every worker's ``sim_thread`` call.

        Raises:
            ValueError: If ``n_threads`` is smaller than 1.
        """
        if self.n_threads < 1:
            raise ValueError('n_threads must be at least 1')

        base, extra = divmod(n_episodes, self.n_threads)
        # Episodes per thread: the first `extra` threads take one more.
        ept = [base + 1 if i < extra else base for i in range(self.n_threads)]

        threads = []
        for agent_id, episode_count in enumerate(ept):
            new_thread = threading.Thread(
                target=self.sim_thread,
                args=(agent_id, network, episode_count, max_steps, render),
            )
            threads.append(new_thread)
            new_thread.start()

        for thread in threads:
            thread.join()

In [1]:
def spawn_agent(agent_id, create_env, network, mb, n_episodes=4, max_steps=500, render=False):
    """Run rollouts for one agent and record every step in ``mb``.

    Args:
        agent_id: Identifier under which the rollouts are recorded.
        create_env: Zero-argument callable returning a new environment that
            exposes ``reset()``, ``step(act)`` and ``render()``.
        network: Either an object exposing ``gen_act(batch)`` or the
            action-generating callable itself (for example a bound
            ``gen_act`` method). It is called with a one-element list
            holding the current observation.
        mb: Buffer exposing ``start_rollout(agent_id)`` and
            ``record(agent_id, obs, act, rew)``.
        n_episodes: Number of episodes to simulate.
        max_steps: Upper bound on the number of steps per episode.
        render: If true, render before each step and sleep 0.02 s.
    """
    # Callers in this notebook pass both `network` and `network.gen_act`;
    # fall back to the argument itself when it has no `gen_act` attribute.
    gen_act = getattr(network, 'gen_act', network)

    env = create_env()
    for _episode in range(n_episodes):
        obs = env.reset()
        mb.start_rollout(agent_id)
        for _step in range(max_steps):
            act = gen_act([obs])

            if render:
                env.render()
                time.sleep(0.02)

            obs_next, rew, done, _info = env.step(act)

            mb.record(agent_id, obs, act, rew)
            obs = obs_next

            if done:
                break

In [2]:
def gather_data(mpmb, n_threads=7, n_episodes=4, return_rewards=False, reset_mem=True):
    """Collect rollouts on ``n_threads`` threads and return the buffered data.

    Every thread runs ``spawn_agent`` for ``ceil(n_episodes / n_threads)``
    episodes, so the total number of episodes can exceed ``n_episodes`` when
    the division is not exact. The call blocks until all threads finish.

    Relies on the notebook-level names ``network`` and ``make_lunar_lander_c``,
    which must be defined before this function is called.

    Args:
        mpmb: Buffer the agents record into; must also expose
            ``to_data(reset=...)`` and ``get_avg_reward()``.
        n_threads: Number of agent threads to start.
        n_episodes: Requested number of episodes across all threads.
        return_rewards: If true, also return ``mpmb.get_avg_reward()``.
        reset_mem: Passed to ``mpmb.to_data`` as ``reset``.

    Returns:
        ``mpmb.to_data(reset=reset_mem)``, or the tuple
        ``(mpmb.get_avg_reward(), mpmb.to_data(reset=reset_mem))`` when
        ``return_rewards`` is true.

    Raises:
        ValueError: If ``n_threads`` is smaller than 1.
    """
    # Imported here so the function does not depend on another cell having
    # brought these names into scope.
    import math
    import threading

    if n_threads < 1:
        raise ValueError('n_threads must be at least 1')

    episodes_per_thread = math.ceil(n_episodes / n_threads)
    agent_pool = []

    for i in range(n_threads):
        # Pass the network object itself: spawn_agent calls `.gen_act` on it.
        agent = threading.Thread(
            target=spawn_agent,
            args=(i, make_lunar_lander_c, network, mpmb),
            kwargs={'n_episodes': episodes_per_thread},
        )
        agent_pool.append(agent)
        agent.start()

    for agent in agent_pool:
        agent.join()

    if return_rewards:
        return mpmb.get_avg_reward(), mpmb.to_data(reset=reset_mem)
    return mpmb.to_data(reset=reset_mem)

In [3]:
# mpmb = MTMemoryBuffer()
# ep_rewards = []

# for i in range(n_episodes//update_freq):
# #     ep_reward, train_data = gather_data(mpmb, n_episodes=update_freq, return_rewards=True)
# #     ep_rewards.append(ep_reward)
#     spawn_agent(0, make_cart_pole, network.gen_act, mpmb, n_episodes=update_freq, max_steps=300, render=False)
#     print(mpmb.get_avg_reward())
#     train_data = mpmb.to_data()
#     network.train(train_data)
# #     if i % print_freq == 0 and i != 0:
# #         print(f'Update #{int(i)}, Recent Reward:', np.mean(ep_rewards[print_freq:]))