# Libraries

In [44]:
import random
import ray
import tensorflow as tf
import utils
from typing import List
from gym import spaces
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.algorithms.apex_dqn import ApexDQN
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.typing import ModelConfigDict, TensorType, AlgorithmConfigDict, EnvCreator
from environments.knapsack import KnapsackEnv

from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.models.tf.tf_action_dist import get_categorical_class_with_temperature
from ray.rllib.utils.tf_utils import reduce_mean_ignore_inf

# Define Custom Model
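* The three observation inputs (`observation_weights`, `observation_values`, `observation_knapsack`) are concatenated into a single feature vector
* The action mask is applied to the Q-value outputs, so masked-out actions receive a very large negative score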

In [45]:
class DQNCustomModel(DistributionalQTFModel):
    """Distributional Q-model for dict observations that carry an action mask.

    The observation space (or its ``original_space`` attribute, when present)
    must be a ``spaces.Dict`` with two entries:

    * ``"observations"`` -- a nested space providing ``observation_weights``,
      ``observation_values`` and ``observation_knapsack``. The three inputs are
      concatenated along the last axis and passed through a dense network.
    * ``"action_mask"`` -- a per-action mask. ``log(mask)``, clipped from below
      at ``tf.float32.min``, is added to the Q-value outputs, so an entry of 0
      yields a very large negative score and an entry of 1 leaves the score
      unchanged.
    """

    def __init__(
            self,
            obs_space: spaces.Space,
            action_space: spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            q_hiddens=(256,),
            dueling=True,
            num_atoms=51,
            v_min=-1.0,
            v_max=1.0,
            use_noisy=True,
            sigma0=0.5,
            add_layer_norm=True,
            verbose=True
    ):
        """Build the Q heads (via the parent class) and the feature network.

        Args:
            obs_space: Observation space; it, or its ``original_space``, must
                be a ``spaces.Dict`` with ``"action_mask"`` and
                ``"observations"`` keys.
            action_space: Forwarded unchanged to ``DistributionalQTFModel``.
            num_outputs: Width of the feature vector produced by ``forward``;
                also forwarded to the parent class.
            model_config: Forwarded unchanged to ``DistributionalQTFModel``.
            name: Forwarded unchanged to ``DistributionalQTFModel``.
            q_hiddens, dueling, num_atoms, v_min, v_max, use_noisy, sigma0,
            add_layer_norm: Forwarded unchanged to ``DistributionalQTFModel``.
            verbose: When true (and ``dueling`` is true), print summaries of
                the parent's ``state_value_head`` and ``q_value_head``.

        Raises:
            AssertionError: If the observation space does not have the
                required dict layout.
        """
        orig_space = getattr(obs_space, "original_space", obs_space)

        assert (
                isinstance(orig_space, spaces.Dict)
                and "action_mask" in orig_space.spaces
                and "observations" in orig_space.spaces
        )

        super().__init__(
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            q_hiddens=q_hiddens,
            dueling=dueling,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            use_noisy=use_noisy,
            sigma0=sigma0,
            add_layer_norm=add_layer_norm
        )
        # Remembered so get_q_value_distributions knows how many values the
        # parent implementation returns.
        self._num_atoms = num_atoms

        observation_space = orig_space['observations']
        weight_inputs = tf.keras.layers.Input(shape=observation_space['observation_weights'].shape, name='observation_weights')
        value_inputs = tf.keras.layers.Input(shape=observation_space['observation_values'].shape, name='observation_values')
        knapsack_inputs = tf.keras.layers.Input(shape=observation_space['observation_knapsack'].shape, name='observation_knapsack')
        inputs = tf.keras.layers.Concatenate(axis=-1)([weight_inputs, value_inputs, knapsack_inputs])
        hidden = tf.keras.layers.Dense(units=256, activation='tanh', name='hidden')(inputs)
        outputs = tf.keras.layers.Dense(units=num_outputs, activation='tanh', name='outputs')(hidden)
        self.base_model = tf.keras.Model(inputs=[weight_inputs, value_inputs, knapsack_inputs], outputs=outputs, name='base_model')

        # Populated by forward(); consumed by get_q_value_distributions().
        self._action_mask = None

        if verbose and dueling:
            print('--- Value Network ---')
            self.state_value_head.summary(expand_nested=True)

            print('--- Q Network ---')
            self.q_value_head.summary(expand_nested=True)

    def forward(
            self,
            input_dict: dict[str, TensorType],
            state: list[TensorType],
            seq_lens: TensorType
    ) -> tuple[TensorType, list[TensorType]]:
        """Compute the feature vector and remember the batch's action mask.

        Args:
            input_dict: Must contain ``input_dict['obs']['observations']`` (fed
                to ``base_model``) and ``input_dict['obs']['action_mask']``.
            state: Returned unchanged.
            seq_lens: Unused.

        Returns:
            A ``(model_out, state)`` pair, where ``model_out`` is the output of
            ``base_model``.
        """
        # The mask is stashed on the instance because the Q-head call below
        # only receives ``model_out``.
        self._action_mask = input_dict['obs']['action_mask']
        model_out = self.base_model(input_dict['obs']['observations'])
        return model_out, state

    def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]:
        """Return the parent's Q-value outputs with the action mask applied.

        Uses the mask stored by the most recent ``forward`` call.

        Args:
            model_out: Feature tensor as returned by ``forward``.

        Returns:
            ``[action_scores, logits, dist]`` when ``num_atoms == 1``, otherwise
            ``[action_scores, z, support_logits_per_action, logits, dist]``.
            The mask is added to ``action_scores`` and, in the second form, to
            ``support_logits_per_action`` (broadcast over its last axis); the
            remaining entries are passed through as returned by the parent.
        """
        q_values_out = super().get_q_value_distributions(model_out=model_out)
        # tf.math.log only accepts floating/complex tensors, so make sure an
        # integer or boolean mask is converted first (no-op for float32).
        mask = tf.cast(self._action_mask, tf.float32)
        # log(0) = -inf is clipped to the smallest finite float32.
        inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)

        if self._num_atoms == 1:
            action_scores, logits, dist = q_values_out
            return [action_scores + inf_mask, logits, dist]

        action_scores, z, support_logits_per_action, logits, dist = q_values_out
        return [
            action_scores + inf_mask,
            z,
            support_logits_per_action + tf.expand_dims(inf_mask, axis=-1),
            logits,
            dist
        ]

# Train Custom Masked Model

In [46]:
# Restart the local Ray instance so this cell can be re-run from scratch.
ray.shutdown()
ray.init()
ModelCatalog.register_custom_model("dqn_model", DQNCustomModel)

# One seed shared by TensorFlow, Python's `random` and the algorithm config.
SEED = 0
tf.random.set_seed(SEED)
random.seed(SEED)

# Environment settings used for both the training and the evaluation env;
# each consumer gets its own copy of the dict.
ENV_CONFIG = {"verbose": False}

apex_dqn_config = {
    "env_config": dict(ENV_CONFIG),
    "num_workers": 2,
    # Use the custom masked model registered above.
    "model": {
        "custom_model": "dqn_model",
        "custom_model_config": {},
    },
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 50000,
        "prioritized_replay": True,
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
        "replay_sequence_length": 1,
    },
    "num_steps_sampled_before_learning_starts": 10000,
    "target_network_update_freq": 10000,
    "rollout_fragment_length": 4,
    "train_batch_size": 128,
    "n_step": 3,
    "double_q": True,
    "dueling": False,
    "noisy": True,
    "num_atoms": 51,
    "v_min": 1.0,
    "v_max": 30.0,
    "exploration_config": {
        "epsilon_timesteps": 2,
        "final_epsilon": 0.0,
    },
    "seed": SEED,
    "gamma": 0.99,
    "lr": 0.0005,
    "num_gpus": 1,
}

agent = ApexDQN(env=KnapsackEnv, config=apex_dqn_config)

# Show the feature network of the policy's model before training starts.
agent.get_policy().model.base_model.summary(expand_nested=True)

eval_env = KnapsackEnv(env_config=dict(ENV_CONFIG))
utils.train(agent=agent, eval_env=eval_env)

2023-01-06 19:08:16,599	INFO worker.py:1538 -- Started a local Ray instance.
2023-01-06 19:08:31,660	INFO trainable.py:172 -- Trainable.setup took 13.849 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


Model: "base_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 observation_weights (InputLaye  [(None, 10)]        0           []                               
 r)                                                                                               
                                                                                                  
 observation_values (InputLayer  [(None, 10)]        0           []                               
 )                                                                                                
                                                                                                  
 observation_knapsack (InputLay  [(None, 1)]         0           []                               
 er)                                                                                     

[2m[36m(MultiAgentPrioritizedReplayBuffer pid=19312)[0m 2023-01-06 19:08:31,779	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.0035625 GB (12500.0 batches of size 1, 285 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=13368)[0m 2023-01-06 19:08:31,834	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.0035625 GB (12500.0 batches of size 1, 285 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=18604)[0m 2023-01-06 19:08:31,851	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.0035625 GB (12500.0 batches of size 1, 285 bytes each), available system memory is 34.160103424 GB
[2m[36m(MultiAgentPrioritizedReplayBuffer pid=5216)[0m 2023-01-06 19:08:31,920	INFO replay_buffer.py:63 -- Estimated max memory usage for replay buffer is 0.0035625 GB (12500.0 batches of size 1, 285 bytes each), avai

Iteration: 0, Average Returns: 20.2
Iteration: 1, Average Returns: 26.0
Iteration: 2, Average Returns: 25.6
Iteration: 3, Average Returns: 26.0
Iteration: 4, Average Returns: 26.0
Iteration: 5, Average Returns: 26.0
Iteration: 6, Average Returns: 26.0
Iteration: 7, Average Returns: 26.0
Iteration: 8, Average Returns: 26.0
Iteration: 9, Average Returns: 26.0
Iteration: 10, Average Returns: 26.0
Iteration: 11, Average Returns: 26.0
Iteration: 12, Average Returns: 26.0
Iteration: 13, Average Returns: 26.0
Iteration: 14, Average Returns: 26.0
Iteration: 15, Average Returns: 25.6


KeyboardInterrupt: 