# Import Libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
import ray
import tensorflow as tf
from gym import spaces
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from environments.cartpole import MaskedCartPoleEnv

# Custom LSTM Model

In [None]:
class LSTMPPOModel(TFModelV2):
    """Recurrent actor-critic network with action masking for PPO.

    The network is built from two Keras models:

    * ``base_model`` -- two 256-unit ``tanh`` dense layers applied to the
      ``cartpole_obs`` entry of the observation dict.
    * ``rnn_model`` -- an LSTM over the dense features followed by a linear
      policy head (action logits) and a linear value head.

    The observation space must be a ``Dict`` holding an ``action_mask`` entry
    and an ``observations`` entry; invalid actions are suppressed by adding
    the log of the mask to the logits.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Build the dense feature extractor and the LSTM policy/value heads.

        Args:
            obs_space: Observation space; its ``original_space`` attribute is
                used when present.
            action_space: Action space of the environment.
            num_outputs: Number of action logits to produce. When ``None``
                the model falls back to 2 logits.
            model_config: Model configuration dict, forwarded to the base class.
            name: Name of the model, forwarded to the base class.
        """
        super().__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name
        )

        orig_space = getattr(obs_space, "original_space", obs_space)

        assert (
                isinstance(orig_space, spaces.Dict)
                and "action_mask" in orig_space.spaces
                and "observations" in orig_space.spaces
        ), "obs_space must be a Dict space with 'action_mask' and 'observations' entries"

        self._cell_size = 64
        # Use the logit count supplied by the caller instead of always
        # overriding it; 2 is kept as the fallback for backward compatibility.
        self.num_outputs = num_outputs if num_outputs is not None else 2
        # Populated by forward_rnn(); read back by value_function().
        self._value_out = None

        # Feed-forward feature extractor applied to every timestep independently.
        inputs = tf.keras.layers.Input(shape=orig_space['observations']['cartpole_obs'].shape, name='cartpole_obs')
        x = tf.keras.layers.Dense(units=256, activation='tanh', name='hidden_1')(inputs)
        flat_output = tf.keras.layers.Dense(units=256, activation='tanh', name='hidden_2')(x)
        self.base_model = tf.keras.Model(inputs=inputs, outputs=flat_output, name='base_model')

        # Recurrent part: (batch, time, 256) features plus the h/c states and
        # the per-sequence lengths used to mask padded timesteps.
        lstm_input = tf.keras.layers.Input(shape=(None, 256), name='inputs')
        state_in_h = tf.keras.layers.Input(shape=(self._cell_size,), name='h')
        state_in_c = tf.keras.layers.Input(shape=(self._cell_size, ), name='c')
        seq_in = tf.keras.layers.Input(shape=(), name='seq_in', dtype=tf.int32)

        lstm_out, state_h, state_c = tf.keras.layers.LSTM(self._cell_size, return_sequences=True, return_state=True, name="lstm")(
            inputs=lstm_input,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c]
        )
        policy_out = tf.keras.layers.Dense(units=self.num_outputs, activation=None, name='policy_out')(lstm_out)
        value_out = tf.keras.layers.Dense(units=1, activation=None, name='value_out')(lstm_out)
        self.rnn_model = tf.keras.Model(
            inputs=[lstm_input, seq_in, state_in_h, state_in_c],
            outputs=[policy_out, value_out, state_h, state_c],
            name='lstm_model'
        )

    def forward(self, input_dict, state, seq_lens):
        """Compute masked action logits and the updated recurrent state.

        Args:
            input_dict: Batch dict; ``input_dict['obs']`` must provide the
                ``observations`` and ``action_mask`` entries.
            state: List ``[h, c]`` with the incoming LSTM states.
            seq_lens: Length of every sequence in the padded batch.

        Returns:
            A tuple ``(logits, new_state)`` where ``logits`` is reshaped to
            ``[-1, num_outputs]`` and ``new_state`` is the ``[h, c]`` list
            produced by the LSTM.
        """
        # Keep the state returned by the LSTM: handing back the incoming
        # `state` would mean the recurrent memory is never carried forward.
        model_out, new_state = self.forward_rnn(
            inputs=input_dict['obs']['observations'], state=state, seq_lens=seq_lens
        )

        action_mask = input_dict['obs']['action_mask']
        # log(0) = -inf for masked-out actions; clamp to the smallest float32
        # so the sum below stays finite.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return tf.reshape(model_out, [-1, self.num_outputs]) + inf_mask, new_state

    def forward_rnn(self, inputs, state, seq_lens):
        """Run the dense layers and the LSTM on a padded batch.

        Stores the value-head output on ``self._value_out`` as a side effect.

        Args:
            inputs: Observations accepted by ``base_model``.
            state: List ``[h, c]`` with the incoming LSTM states.
            seq_lens: Length of every sequence in the padded batch.

        Returns:
            A tuple ``(policy_out, [h, c])`` with the per-timestep logits and
            the final LSTM states.
        """
        x = self.base_model(inputs)
        # Restore the time axis that the flat batch does not carry.
        x = add_time_dimension(padded_inputs=x, seq_lens=seq_lens, framework='tf')

        model_out, self._value_out, h, c = self.rnn_model([x, seq_lens] + state)
        return model_out, [h, c]

    def get_initial_state(self) -> list:
        """Return zero-filled ``[h, c]`` states of size ``_cell_size``."""
        return [
            np.zeros(self._cell_size, np.float32),
            np.zeros(self._cell_size, np.float32)
        ]

    def value_function(self):
        """Return the value estimates of the most recent forward pass, flattened to 1-D."""
        return tf.reshape(self._value_out, [-1])

# Train LSTM PPO

In [None]:
def evaluate(agent, eval_env, eval_iterations, render):
    """Roll out ``eval_iterations`` episodes and return the mean episode return.

    Args:
        agent: Object exposing ``get_policy().get_initial_state()`` and
            ``compute_single_action(observation=..., state=...)``.
        eval_env: Environment with ``reset()``, ``step(action)`` returning a
            4-tuple, and ``render()``.
        eval_iterations: Number of episodes to run.
        render: When true, ``eval_env.render()`` is called after the reset and
            after every step.

    Returns:
        Sum of all collected rewards divided by ``eval_iterations``.
    """
    reward_sum = 0.0

    def show():
        # Rendering is optional; keep the check in one place.
        if render:
            eval_env.render()

    for _episode in range(eval_iterations):
        observation = eval_env.reset()
        # Every episode starts from a fresh recurrent state.
        rnn_state = agent.get_policy().get_initial_state()
        show()

        while True:
            action, rnn_state, _ = agent.compute_single_action(
                observation=observation, state=rnn_state
            )
            observation, reward, finished, _ = eval_env.step(action)
            reward_sum += reward
            show()

            if finished:
                break

    return reward_sum / eval_iterations


def train(
        agent,
        eval_env,
        train_iterations=50,
        iterations_per_eval=1,
        eval_iterations=5,
        plot_training=True,
        algo_name='Agent'
):
    """Train ``agent`` and periodically evaluate it on ``eval_env``.

    Args:
        agent: Object exposing ``train()`` plus whatever ``evaluate`` needs.
        eval_env: Environment used for the evaluation roll-outs.
        train_iterations: Number of ``agent.train()`` calls.
        iterations_per_eval: Evaluate whenever the iteration index is a
            multiple of this value.
        eval_iterations: Number of episodes per evaluation.
        plot_training: When true, plot the evaluation curve with matplotlib.
        algo_name: Name used in the plot title.

    Returns:
        List with the average return of every evaluation, in order.
    """
    average_returns = []
    # Training iteration at which each evaluation happened, so the plot's
    # x-axis really shows iterations even when iterations_per_eval > 1.
    eval_points = []

    for i in range(train_iterations):
        agent.train()

        if i % iterations_per_eval == 0:
            average_return = evaluate(agent, eval_env, eval_iterations=eval_iterations, render=False)
            average_returns.append(average_return)
            eval_points.append(i)

            print(f'Iteration: {i}, Average Returns: {average_return}')

    if plot_training:
        plt.plot(eval_points, average_returns)
        plt.title(f'{algo_name} Training Progress on CartPole')
        plt.xlabel('Iterations')
        plt.ylabel('Average Return')
        plt.show()
    return average_returns

# Build Agent

In [None]:
# Restart Ray so that re-running this cell does not call ray.init() twice.
ray.shutdown()
ray.init()

# Seed every random number generator used here.
tf.random.set_seed(seed=0)
np.random.seed(0)
random.seed(0)

# The name used for registration must be the same one referenced under
# 'custom_model' in the config below, otherwise the model cannot be found.
CUSTOM_MODEL_NAME = 'lstm_model'
ModelCatalog.register_custom_model(CUSTOM_MODEL_NAME, LSTMPPOModel)

ppo_agent = PPO(env=MaskedCartPoleEnv, config={
    'model': {
        'vf_share_layers': True,
        'custom_model': CUSTOM_MODEL_NAME,
        'custom_model_config': {},
        'max_seq_len': 20
    },
    'render_env': True,
    'num_workers': 1,
    'rollout_fragment_length': 256,
    'num_envs_per_worker': 1,
    'batch_mode': 'complete_episodes',
    'use_critic': True,
    'use_gae': True,
    'lambda': 0.95,
    'clip_param': 0.3,
    'kl_coeff': 0.2,
    'entropy_coeff': 0.01,
    'kl_target': 0.01,
    'vf_loss_coeff': 0.5,
    'shuffle_sequences': True,
    'num_sgd_iter': 40,
    'sgd_minibatch_size': 32,
    'train_batch_size': 512,
    'seed': 0,
    'gamma': 0.99,
    'lr': 0.0005,
    # NOTE(review): requests one GPU -- confirm the host actually has one.
    'num_gpus': 1
})

# Print the layer summaries of both Keras sub-models of the policy network.
ppo_agent.get_policy().model.base_model.summary(expand_nested=True)
ppo_agent.get_policy().model.rnn_model.summary(expand_nested=True)

# Train

In [None]:
# Evaluation uses its own environment instance, built with an empty env config.
evaluation_env = MaskedCartPoleEnv(env_config={})
train(agent=ppo_agent, eval_env=evaluation_env)