# Libraries

In [1]:
import random
import ray
import tensorflow as tf
import utils
from gym import spaces
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.algorithms.apex_dqn import ApexDQN
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from environments.cartpole import CartPoleEnv

# Custom Model

In [2]:
class DQNCustomModel(DistributionalQTFModel):
    """Custom feature extractor for RLlib's (distributional / dueling / noisy) DQN heads.

    The Q-value heads (``q_value_head`` and, when ``dueling`` is set,
    ``state_value_head``) are created by the parent ``DistributionalQTFModel``.
    This subclass only supplies ``base_model``: a small Keras MLP that maps
    observations to a ``num_outputs``-wide feature vector returned by ``forward``.

    Args:
        obs_space: Observation space; its ``shape`` defines the Keras input layer.
        action_space: Action space, forwarded to the parent class.
        num_outputs: Width of the feature vector produced by ``base_model``
            (also passed to the parent constructor).
        model_config: RLlib model config dict, forwarded to the parent class.
        name: Model name, forwarded to the parent class.
        q_hiddens, dueling, num_atoms, v_min, v_max, use_noisy, sigma0,
        add_layer_norm: Q-head options forwarded unchanged to
            ``DistributionalQTFModel``.
        verbose: If True and ``dueling`` is True, print Keras summaries of the
            state-value and Q-value heads after construction.
        hidden_units: Width of the single tanh hidden layer of ``base_model``
            (default 256, the previously hard-coded value). Can be set through
            ``'custom_model_config'``.
    """

    def __init__(
            self,
            obs_space: spaces.Space,
            action_space: spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            q_hiddens=(256,),
            dueling=True,
            num_atoms=51,
            v_min=-1.0,
            v_max=1.0,
            use_noisy=True,
            sigma0=0.5,
            add_layer_norm=True,
            verbose=True,
            hidden_units: int = 256
    ):
        super().__init__(
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            q_hiddens=q_hiddens,
            dueling=dueling,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            use_noisy=use_noisy,
            sigma0=sigma0,
            add_layer_norm=add_layer_norm
        )

        # Observation -> hidden (tanh) -> features (tanh). The final width matches
        # the ``num_outputs`` handed to the parent constructor above.
        self.base_model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=obs_space.shape, name='observations'),
            tf.keras.layers.Dense(units=hidden_units, activation='tanh', name='hidden'),
            tf.keras.layers.Dense(units=num_outputs, activation='tanh', name='outputs')
        ], name='base_model')

        # Guarded on ``dueling`` because the state-value head is only expected
        # to exist for dueling architectures.
        if verbose and dueling:
            print('--- Value Network ---')
            self.state_value_head.summary(expand_nested=True)

            print('--- Q Network ---')
            self.q_value_head.summary(expand_nested=True)

    def forward(
            self,
            input_dict: dict[str, TensorType],
            state: list[TensorType],
            seq_lens: TensorType
    ) -> tuple[TensorType, list[TensorType]]:
        """Embed the observation batch with ``base_model``.

        Args:
            input_dict: Input batch; only ``input_dict['obs']`` is read.
            state: Recurrent state; returned unchanged since this model has no
                recurrent layers.
            seq_lens: Unused (no recurrence).

        Returns:
            ``(features, state)``, where ``features`` is the ``base_model``
            output with ``num_outputs`` units in its last dimension.
        """
        model_out = self.base_model(input_dict['obs'])
        return model_out, state

# Train with the Default Model

In [3]:
# Restart Ray and seed the driver-side RNGs before building the algorithm.
ray.shutdown()
ray.init()
tf.random.set_seed(seed=0)
random.seed(0)

# Request the GPU only when TensorFlow can actually see one. RLlib refuses to
# start a worker whose `num_gpus` exceeds the visible GPUs, which would break
# "Restart & Run All" on CPU-only machines. With one visible GPU this is 1,
# exactly as before.
num_gpus = min(1, len(tf.config.list_physical_devices('GPU')))

agent = ApexDQN(env=CartPoleEnv, config={
    'env_config': {'verbose': False},
    'num_workers': 4,
    'replay_buffer_config' : {
        "type": 'MultiAgentPrioritizedReplayBuffer',
        "capacity": 100000,
        'prioritized_replay': True,
        'prioritized_replay_alpha': 0.6,
        'prioritized_replay_beta': 0.4,
        'prioritized_replay_eps': 1e-6,
        'replay_sequence_length': 1,
    },
    'num_steps_sampled_before_learning_starts': 10000,
    'target_network_update_freq': 10000,
    'rollout_fragment_length': 4,
    'train_batch_size': 256,
    'n_step': 3,
    'double_q': True,
    'dueling': True,
    'noisy': True,
    # Support of the return distribution; chosen to cover CartPole returns,
    # which top out at 500 in the training output of this cell.
    'num_atoms': 51,
    'v_min': -500.0,
    'v_max': 500.0,
    'exploration_config': {
        'epsilon_timesteps': 2,
        'final_epsilon': 0.0
    },
    'seed': 0,
    'gamma': 0.99,
    'lr': 0.0005,
    'num_gpus': num_gpus
})
agent.get_policy().model.base_model.summary(expand_nested=True)
utils.train(agent=agent, eval_env=CartPoleEnv(env_config={'verbose': False}))

2023-01-06 15:26:52,858	INFO worker.py:1538 -- Started a local Ray instance.
2023-01-06 15:26:54,180	INFO algorithm_config.py:2503 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also then want to set eager_tracing=True in order to reach similar execution speed as with static-graph mode.
2023-01-06 15:26:54,203	INFO algorithm.py:501 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2023-01-06 15:27:14,931	INFO trainable.py:172 -- Trainable.setup took 20.729 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 observations (InputLayer)      [(None, 4)]          0           []                               
                                                                                                  
 fc_1 (Dense)                   (None, 256)          1280        ['observations[0][0]']           
                                                                                                  
 fc_out (Dense)                 (None, 256)          65792       ['fc_1[0][0]']                   
                                                                                                  
 value_out (Dense)              (None, 1)            257         ['fc_1[0][0]']                   
                                                                                              

[2m[36m(MultiAgentPrioritizedReplayBuffer pid=19712)[0m 2023-01-06 15:27:15,269	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.001725 GB (25000.0 batches of size 1, 69 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=5168)[0m 2023-01-06 15:27:15,208	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.001725 GB (25000.0 batches of size 1, 69 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=3408)[0m 2023-01-06 15:27:15,298	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.001725 GB (25000.0 batches of size 1, 69 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=11516)[0m 2023-01-06 15:27:15,318	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.001725 GB (25000.0 batches of size 1, 69 bytes each), available sys

Iteration: 0, Average Returns: 11.0
Iteration: 1, Average Returns: 23.4
Iteration: 2, Average Returns: 22.8
Iteration: 3, Average Returns: 102.0
Iteration: 4, Average Returns: 193.4
Iteration: 5, Average Returns: 127.8
Iteration: 6, Average Returns: 162.8
Iteration: 7, Average Returns: 184.0
Iteration: 8, Average Returns: 190.6
Iteration: 9, Average Returns: 208.0
Iteration: 10, Average Returns: 161.0
Iteration: 11, Average Returns: 141.0
Iteration: 12, Average Returns: 166.8
Iteration: 13, Average Returns: 355.4
Iteration: 14, Average Returns: 279.6
Iteration: 15, Average Returns: 243.6
Iteration: 16, Average Returns: 464.8
Iteration: 17, Average Returns: 500.0
Iteration: 18, Average Returns: 284.2
Iteration: 19, Average Returns: 500.0
Iteration: 20, Average Returns: 500.0
Iteration: 21, Average Returns: 500.0
Iteration: 22, Average Returns: 500.0
Iteration: 23, Average Returns: 500.0
Iteration: 24, Average Returns: 500.0
Iteration: 25, Average Returns: 500.0
Iteration: 26, Average Re


KeyboardInterrupt



# Train with the Custom Model

In [None]:
# Restart Ray, register the custom Q-model under the name referenced by
# `'custom_model'` below, and seed the driver-side RNGs.
ray.shutdown()
ray.init()
ModelCatalog.register_custom_model("dqn_model", DQNCustomModel)
tf.random.set_seed(seed=0)
random.seed(0)

# Request the GPU only when TensorFlow can actually see one. RLlib refuses to
# start a worker whose `num_gpus` exceeds the visible GPUs, which would break
# "Restart & Run All" on CPU-only machines. With one visible GPU this is 1,
# exactly as before.
num_gpus = min(1, len(tf.config.list_physical_devices('GPU')))

agent = ApexDQN(env=CartPoleEnv, config={
    'env_config': {'verbose': False},
    'num_workers': 4,
    'model': {
        'custom_model': 'dqn_model',
        # Extra keyword arguments for DQNCustomModel's constructor (none here).
        'custom_model_config': {}
    },
    'replay_buffer_config' : {
        "type": 'MultiAgentPrioritizedReplayBuffer',
        "capacity": 100000,
        'prioritized_replay': True,
        'prioritized_replay_alpha': 0.6,
        'prioritized_replay_beta': 0.4,
        'prioritized_replay_eps': 1e-6,
        'replay_sequence_length': 1,
    },
    'num_steps_sampled_before_learning_starts': 10000,
    'target_network_update_freq': 10000,
    'rollout_fragment_length': 4,
    'train_batch_size': 256,
    'n_step': 3,
    'double_q': True,
    'dueling': True,
    'noisy': True,
    'num_atoms': 51,
    'v_min': -500.0,
    'v_max': 500.0,
    'exploration_config': {
        'epsilon_timesteps': 2,
        'final_epsilon': 0.0
    },
    'seed': 0,
    'gamma': 0.99,
    'lr': 0.0005,
    'num_gpus': num_gpus
})
agent.get_policy().model.base_model.summary(expand_nested=True)
utils.train(agent=agent, eval_env=CartPoleEnv(env_config={'verbose': False}))