In [1]:
# Custom Model for LSTM
import numpy as np

import ray
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The custom model that will be wrapped by an LSTM
from turtle import forward


class MyCustomModel(TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.num_outputs = int(np.product(self.obs_space.shape))
        self._last_batch_size = None

        # Implement your own frward logic, whose output will then be sent through an LSTM
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict['obs_flat'] 
        # Store last batch size for value_function ouput
        self._last_batch_size = obs.shape[0]
        # Return 2x the obs (and empty states)
        # This will further ve sent through an automaticallu provided LSTM head (b/c we are settting use_lstm=True below)
        return obs * 2.0, []

    def value_function(self):
        return torch.from_numpy(np.zeros(shape=(self._last_batch_size,)))

In [3]:
ray.init()

2022-09-15 10:26:49,327	INFO worker.py:1518 -- Started a local Ray instance.
[2m[33m(raylet)[0m   aiogrpc.init_grpc_aio()
[2m[33m(raylet)[0m   loop = asyncio.get_event_loop()


0,1
Python version:,3.10.4
Ray version:,2.0.0


In [4]:
# Register the above custom model
ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)

# Create the Trainer
algo = ppo.PPO(
    env='CartPole-v0',
    config={
        "framework": "torch",
        "model": {
            # Auto-wrap the custom(!) model with an LSTM
            "use_lstm": True,
            # To further customize the LSTM autp-wrapper
            "lstm_cell_size": 64,
            # Specify out custom model from above
            "custom_model": "my_torch_model",
            # Extra kwargs to be passed to yout model's c'tor
            "custom_model_config": {}
        }
    }
)

2022-09-15 10:31:29,609	INFO ppo.py:378 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
2022-09-15 10:31:29,614	INFO algorithm.py:351 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


In [5]:
algo.train()

{'custom_metrics': {},
 'episode_media': {},
 'num_recreated_workers': 0,
 'info': {'learner': {'default_policy': {'learner_stats': {'allreduce_latency': 0.0,
     'grad_gnorm': 0.9096849176832424,
     'cur_kl_coeff': 0.20000000000000004,
     'cur_lr': 5.0000000000000016e-05,
     'total_loss': 8.564374659651069,
     'policy_loss': -0.02134779627605151,
     'vf_loss': 8.583539336727512,
     'vf_explained_var': -0.09675047948796262,
     'kl': 0.010915500630125587,
     'entropy': 0.6798596371245641,
     'entropy_coeff': 0.0},
    'model': {},
    'custom_metrics': {},
    'num_agent_steps_trained': 127.41935483870968}},
  'num_env_steps_sampled': 4000,
  'num_env_steps_trained': 4000,
  'num_agent_steps_sampled': 4000,
  'num_agent_steps_trained': 4000},
 'sampler_results': {'episode_reward_max': 70.0,
  'episode_reward_min': 8.0,
  'episode_reward_mean': 21.232432432432432,
  'episode_len_mean': 21.232432432432432,
  'episode_media': {},
  'episodes_this_iter': 185,
  'policy_re