In [6]:
import logging
from multiprocessing import Pool

import gym
import numpy as np
import tensorflow as tf
import tqdm
# from src.data_collection import collect_data


In [7]:
class VPG():
    """Vanilla policy gradient agent with a learned state-value baseline.

    Builds a TensorFlow 1.x graph (policy network, value network, losses and
    optimizers) for the given environment and stores every tensor/op needed
    by the data collector and by ``train`` in the ``self._graph`` dictionary.
    """

    def __init__(self, env):
        """Store ``env`` and build the graph matching its action space.

        Args:
            env: Gym environment. A ``Box`` action space selects the
                continuous builder, anything else the categorical builder.
        """
        self._env = env
        if isinstance(env.action_space, gym.spaces.Box):
            self._build_computational_graph_continuous_actions()
        else:
            self._build_computational_graph_categorical_actions()


    def train(self, policy_learning_rate=0.0003, value_function_learning_rate = 0.001, n_value_function_updates = 10,
              n_epochs = 10, batch_size = 4000):
        """Run the policy-gradient training loop.

        Each epoch collects one batch with ``collect_data``, takes a single
        policy step, then ``n_value_function_updates`` value-function steps on
        the same batch, and finally records value loss and policy entropy.

        Args:
            policy_learning_rate: Adam step size fed to the policy optimizer.
            value_function_learning_rate: Adam step size fed to the value
                optimizer.
            n_value_function_updates: Value-function steps per epoch.
            n_epochs: Number of collect/update iterations.
            batch_size: Number of environment steps requested per epoch.

        Returns:
            Tuple ``(episode_returns, epoch_state_value_loss, epoch_entropy)``:
            all episode returns seen during training, and one value-loss and
            one entropy value per epoch.
        """
        logging.basicConfig(filename='training.log',level=logging.DEBUG)
        logging.debug('Current Epoch, mean return, std return, min return, max return')

        graph = self._graph
        # The session lives in the graph dictionary; there is no global one.
        sess = graph['sess']

        episode_returns = []
        epoch_state_value_loss = []
        epoch_entropy = []

        for i in tqdm.tqdm(range(n_epochs)):
            batch_obs, batch_acts, batch_rews, batch_rets, batch_len = collect_data(self._env, graph, batch_size)
            episode_returns.extend(batch_rets)
            logging.debug('%i, %f, %f, %f, %f',i, np.mean(batch_rets), np.std(batch_rets), np.min(batch_rets), np.max(batch_rets))

            # One feed dict serves every run call of this epoch. The learning
            # rates are graph inputs so they can be chosen here, at train
            # time, although the optimizers were created at build time.
            feed_dict = {
                graph['obs_ph']: np.array(batch_obs),
                graph['act_ph']: np.array(batch_acts),
                graph['rew_ph']: np.array(batch_rews),
                graph['policy_lr_ph']: policy_learning_rate,
                graph['value_lr_ph']: value_function_learning_rate,
            }

            sess.run(graph['train_policy'], feed_dict=feed_dict)
            for _ in range(n_value_function_updates):
                sess.run(graph['train_state_value'], feed_dict=feed_dict)
            v, e = sess.run([graph['state_value_loss'], graph['entropy']], feed_dict=feed_dict)
            epoch_state_value_loss.append(v)
            epoch_entropy.append(e)
        return episode_returns, epoch_state_value_loss, epoch_entropy


    def _build_computational_graph_continuous_actions(self):
        """Placeholder for the continuous-action (``Box``) graph builder.

        Raises:
            NotImplementedError: always; only categorical action spaces are
                implemented in this class. Subclasses may override this.
        """
        raise NotImplementedError(
            'VPG does not implement a policy for continuous (Box) action spaces yet')


    def _build_computational_graph_categorical_actions(self):
        """Build the graph for a discrete action space and open a session.

        Populates ``self._graph`` with the placeholders, the sampling op, the
        value estimate, both losses, both training ops and the session.
        """
        env = self._env
        obs_dim = env.observation_space.shape[0]
        n_acts = env.action_space.n

        # placeholders: observations, taken actions, and the return targets
        obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
        act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
        rew_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
        # Learning rates are scalar graph inputs (defaults mirror train()).
        policy_lr_ph = tf.placeholder_with_default(0.0003, shape=())
        value_lr_ph = tf.placeholder_with_default(0.001, shape=())

        # make core of policy network
        mlp = tf.keras.models.Sequential()
        mlp.add(tf.keras.layers.Dense(30, activation='tanh'))
        mlp.add(tf.keras.layers.Dense(30, activation='tanh'))
        mlp.add(tf.keras.layers.Dense(n_acts))
        logits = mlp(obs_ph)

        # value function network
        state_value_mlp = tf.keras.models.Sequential()
        state_value_mlp.add(tf.keras.layers.Dense(50, activation='relu'))
        state_value_mlp.add(tf.keras.layers.Dense(50, activation='relu'))
        state_value_mlp.add(tf.keras.layers.Dense(1))
        state_values = state_value_mlp(obs_ph)
        # (N, 1) -> (N,): without this, subtracting from the (N,) targets
        # would broadcast to an (N, N) matrix and corrupt both losses.
        state_values_flat = tf.squeeze(state_values, axis=1)

        # make action selection op (outputs int actions, sampled from policy)
        actions = tf.squeeze(tf.multinomial(logits=logits,num_samples=1), axis=1)
        all_log_probs = tf.nn.log_softmax(logits)
        action_probs = tf.nn.softmax(logits)
        # log_softmax instead of log(softmax) avoids log(0) -> NaN.
        entropy = -tf.reduce_mean(tf.reduce_sum(action_probs * all_log_probs, axis=1))

        # make loss function whose gradient, for the right data, is policy gradient
        action_masks = tf.one_hot(act_ph, n_acts)
        log_probs = tf.reduce_sum(action_masks * all_log_probs, axis=1)
        # The baseline is a constant for the policy update.
        advantages = tf.stop_gradient(rew_ph - state_values_flat)
        policy_loss = -tf.reduce_mean(advantages * log_probs)

        state_value_loss = tf.reduce_mean((rew_ph - state_values_flat)**2)

        # Each optimizer only touches the weights of its own network.
        optimizer_policy = tf.train.AdamOptimizer(policy_lr_ph)
        train_policy = optimizer_policy.minimize(policy_loss, var_list=mlp.trainable_weights)
        optimizer_state_value = tf.train.AdamOptimizer(value_lr_ph)
        train_state_value = optimizer_state_value.minimize(
            state_value_loss, var_list=state_value_mlp.trainable_weights)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        graph = {
                    'obs_ph': obs_ph,
                    'act_ph': act_ph,
                    'rew_ph': rew_ph,
                    'policy_lr_ph': policy_lr_ph,
                    'value_lr_ph': value_lr_ph,
                    'actions': actions,
                    # kept with shape (N, 1) for consumers that index [0][0]
                    'state_values': state_values,
                    'entropy': entropy,
                    'policy_loss': policy_loss,
                    'state_value_loss': state_value_loss,
                    'train_policy': train_policy,
                    'train_state_value': train_state_value,
                    'sess': sess
                }
        self._graph = graph

In [4]:
# Create the CartPole-v0 environment and build a VPG agent around it.
env = gym.make('CartPole-v0')
vpg = VPG(env)

In [5]:
def collect_data(env, computational_graph, batch_size, gamma = 0.99):
    """Collect one batch of on-policy experience.

    Delegates to ``collect_data_actor`` so callers receive the five-element
    batch tuple instead of ``None``.

    Args:
        env: Environment to act in.
        computational_graph: Dictionary of graph tensors/ops and the session.
        batch_size: Number of environment steps requested.
        gamma: Discount factor passed through to the actor.

    Returns:
        Whatever ``collect_data_actor`` returns for the same arguments.
    """
    return collect_data_actor(env, computational_graph, batch_size, gamma)


def _discounted_rewards_to_go(rewards, gamma, bootstrap_value=0.0):
    """Return the discounted reward-to-go for every step of one episode.

    Computes ``G_t = r_t + gamma * G_{t+1}`` backwards through ``rewards``,
    starting from ``bootstrap_value`` as the value after the last step.
    """
    rewards_to_go = [0.0] * len(rewards)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rewards_to_go[t] = running
    return rewards_to_go


def collect_data_actor(env, computational_graph, batch_size, gamma = 0.99):
    """Run the current policy in ``env`` until the batch is full.

    Whole episodes are collected; the loop stops at the first episode end
    after more than ``batch_size`` observations have been gathered.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        computational_graph: Dictionary providing ``'sess'``, ``'actions'``,
            ``'obs_ph'`` and ``'state_values'``.
        batch_size: Minimum number of environment steps to collect.
        gamma: Discount factor for the reward-to-go weights.

    Returns:
        Tuple ``(batch_obs, batch_acts, batch_weights, batch_rets,
        batch_lens)``: per-step observations, actions and discounted
        rewards-to-go, followed by per-episode returns and lengths.
    """
    # Everything TensorFlow-related comes from the graph dictionary rather
    # than from notebook globals.
    sess = computational_graph['sess']
    actions = computational_graph['actions']
    obs_ph = computational_graph['obs_ph']
    state_values = computational_graph['state_values']
    # Not every environment carries this attribute.
    max_episode_steps = getattr(env, '_max_episode_steps', None)

    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for reward-to-go weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths
    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    ep_rews = []            # list for rewards accrued throughout ep
    # collect experience by acting in the environment with current policy
    while True:
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1,-1)})[0]
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # An episode that reached the step limit is bootstrapped with
            # the value estimate of the final observation.
            if max_episode_steps is not None and ep_len == max_episode_steps:
                bootstrap_value = sess.run(state_values, {obs_ph:obs.reshape(1,-1)})[0][0]
            else:
                bootstrap_value = 0
            batch_weights += _discounted_rewards_to_go(ep_rews, gamma, bootstrap_value)

            # reset episode-specific variables
            obs, ep_rews = env.reset(), []

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break
    return batch_obs, batch_acts, batch_weights, batch_rets, batch_lens

KeyboardInterrupt: 