# RLlib usage flows
This notebook demonstrates various RLlib usage flows beyond the basic ones.
The goal is to train an agent and then use its policy for inference.

We want to save the trainer's weights during training, then load them later and run inference.

## Basic flows
We've seen in the [RLlib docs](https://ray.readthedocs.io/en/latest/rllib-training.html#getting-started) that we can run RLlib from the command line:
```
rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution

rllib rollout \
    ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
    --run DQN --env CartPole-v0 --steps 10000
```

We can also run the training from within Python (see [RLlib Training APIs](https://ray.readthedocs.io/en/latest/rllib-training.html#rllib-training-apis)):

![rllib-api](./rllib-intro/rllib-api.svg)

Although we can call the trainer's `train` method directly (see [this](https://ray.readthedocs.io/en/latest/rllib-training.html#basic-python-api)), 
it is recommended to call it through Tune, as in the following example:
```
import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    stop={"episode_reward_mean": 200},
    config={
        "env": "CartPole-v0",
        "num_gpus": 0,
        "num_workers": 1,
        "lr": tune.grid_search([0.01, 0.001, 0.0001]),
        "eager": False,
    },
)
``` 
Here we provided the trainer as the string "PPO". We were able to do so because PPO is already registered.

Alternatively, we could define a trainer and send it to tune directly - as done in [custom_tf_policy.py](https://github.com/ray-project/ray/blob/master/rllib/examples/custom_tf_policy.py):
```
# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss,
    postprocess_fn=calculate_advantages,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTFPolicy,
)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.iters},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
            "num_gpus":args.num_gpus,       # GuyK
        })
```


Or we could give it a custom training function:


In [None]:
import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

def train(config, reporter):
    trainer = PPOTrainer(config=config, env=YourEnv)
    while True:
        result = trainer.train()
        reporter(**result)
        if result["episode_reward_mean"] > 200:
            phase = 2
        elif result["episode_reward_mean"] > 100:
            phase = 1
        else:
            phase = 0
        trainer.workers.foreach_worker(
            lambda ev: ev.foreach_env(
                lambda env: env.set_phase(phase)))

ray.init()
tune.run(
    train,
    config={
        "num_gpus": 0,
        "num_workers": 2,
    },
    resources_per_trial={
        "cpu": 1,
        "gpu": lambda spec: spec.config.num_gpus,
        "extra_cpu": lambda spec: spec.config.num_workers,
    },
)

**Note** 
All RLlib trainers are compatible with the Tune API, which makes them easy to use in Tune experiments. 
They inherit from a base class called `Trainable`, whose [API](https://ray.readthedocs.io/en/latest/tune-usage.html#trainable-api) allows advanced operations.

In the following, we'll try to form a flow that allows us to:
1. train an agent and save it to a file
2. load a policy model from a file and do inference


# Training flow with checkpoints
In this section, we'll form a flow that uses the [Tune Training API](https://ray.readthedocs.io/en/latest/tune-usage.html#tune-training-api) and saves a checkpoint of the trainer so that we can later load and evaluate it. 

Let's assume that the config file has some parameters related to saving checkpoints:
1. we can save a checkpoint synchronously - e.g. every X timesteps in the environment
2. we can save a checkpoint asynchronously - e.g. when we break a record in evaluation score (be it in simulation or using off-policy evaluation, OPE)


It looks like this should be done via a callback. 
I should read a `.yaml` file as configuration. I'm not sure RLlib can consume it directly. 
In any case, I will parse it and prepare a config dict to send to Tune.
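
As a minimal sketch of that "digest" step, the helper below splits a flat settings dict into the trainer config and the keyword arguments that `tune.run()` takes itself. The flat layout, the name `split_config` and the `TUNE_RUN_KEYS` tuple are assumptions made for illustration; `raw` stands in for whatever the `.yaml` parser returns.

```python
# Keyword arguments that are passed to tune.run() itself rather than
# inside the trainer config (hypothetical selection for this sketch).
TUNE_RUN_KEYS = ("checkpoint_freq", "checkpoint_at_end")


def split_config(raw):
    """Split a flat settings dict into (trainer_config, tune_kwargs)."""
    trainer_config = {k: v for k, v in raw.items() if k not in TUNE_RUN_KEYS}
    tune_kwargs = {k: raw[k] for k in TUNE_RUN_KEYS if k in raw}
    return trainer_config, tune_kwargs


# `raw` stands in for the parsed .yaml file; this layout is an assumption.
raw = {
    "env": "CartPole-v0",
    "num_workers": 1,
    "checkpoint_freq": 50,
    "checkpoint_at_end": True,
}
trainer_config, tune_kwargs = split_config(raw)
print(trainer_config)
print(tune_kwargs)
```

The two parts could then be passed along as `tune.run("PG", config=trainer_config, **tune_kwargs)`.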

The script will be implemented in <font color='blue'>**train_w_tune.py**</font>


## Some background info 
In this script we'll still use the registry for both the agent and the environment.
We'll try to use the callbacks to save checkpoints.  
There are several possible callbacks (see [callbacks and custom metrics](https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics)):
```
analysis = tune.run(
    "PG",
    config={
        "env": "CartPole-v0",
        "callbacks": {
            "on_episode_start": on_episode_start,
            "on_episode_step": on_episode_step,
            "on_episode_end": on_episode_end,
            "on_train_result": on_train_result,
            "on_postprocess_traj": on_postprocess_traj,
        },
    },
)
```
It looks like the keys of `callbacks` name specific predefined points in the agent's code: at each point the agent checks whether a callback is registered and, if so, calls it from there.

There are 3 questions here:
1. What are all the possible callbacks that Tune supports?
2. Where are they called in the code?
3. What information is available inside the callback? Do we have access to the trainer so that we can invoke `trainer.save()`?


### What are all the possible callbacks?
We build the trainer with the `build_trainer` function, which builds a subclass of `Trainer`.  
`Trainer` is defined in [trainer.py](https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py). 
There you can also find the default configuration dict (called [`COMMON_CONFIG`](https://ray.readthedocs.io/en/latest/rllib-training.html#common-parameters)). 
This dictionary lists the possible callbacks:
```
    "callbacks": {
        "on_episode_start": None,     # arg: {"env": .., "episode": ...}
        "on_episode_step": None,      # arg: {"env": .., "episode": ...}
        "on_episode_end": None,       # arg: {"env": .., "episode": ...}
        "on_sample_end": None,        # arg: {"samples": .., "worker": ...}
        "on_train_result": None,      # arg: {"trainer": ..., "result": ...}
        "on_postprocess_traj": None,  # arg: {
                                      #   "agent_id": ..., "episode": ...,
                                      #   "pre_batch": (before processing),
                                      #   "post_batch": (after processing),
                                      #   "all_pre_batches": (other agent ids),
                                      # }
    },
```
In addition, it looks like further 'events' can be passed to the `build_trainer` function. For example, if we look at the [`dqn.py`](https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py) file:
```
GenericOffPolicyTrainer = build_trainer(
    name="GenericOffPolicyAlgorithm",
    default_policy=None,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config_and_setup_param_noise,
    get_initial_state=get_initial_state,
    make_policy_optimizer=make_policy_optimizer,
    before_train_step=update_worker_exploration,
    after_optimizer_step=update_target_if_needed,
    after_train_result=after_train_result,
    execution_plan=execution_plan)
```
We see that there are `before_train_step`, `after_optimizer_step` and `after_train_result`. What are these? 
In [`trainer_template.py`](https://github.com/ray-project/ray/blob/master/rllib/agents/trainer_template.py) we can find the documentation of these functions:
- **after_init** (func) : optional function to run at the end of trainer init that takes the trainer instance as argument  
- **before_train_step** (func): optional callback to run **before each train() call**. It takes the trainer instance as an argument.
- **after_optimizer_step** (func): optional callback to run after each step() call to the policy optimizer. It takes the trainer instance and the policy gradient fetches as arguments.  
- **after_train_result** (func): optional callback to run at the end of each train() call. It takes the trainer instance and result dict as arguments, and may mutate the result dict as needed.  
- **collect_metrics_fn** (func): override the method used to collect metrics. It takes the trainer instance as argument.  
- **before_evaluate_fn** (func): callback to run before evaluation. This takes the trainer instance as argument.

The train-step callbacks (`before_train_step`, `after_optimizer_step`, `collect_metrics_fn` and `after_train_result`) are called from within the `_train` implementation of the `trainer_cls` that is built (see the note below).

> **Note** The callback functions in the configuration all get an `info` dict as an argument, whereas the functions passed to `build_trainer` take the trainer instance as an argument. We need to understand what this `info` dict is, but at first glance it looks like it won't be enough to save a checkpoint. It looks like the more (or only) appropriate place to do that is in `after_train_result`.


> Note that the `build_trainer` function takes the `Trainer` class as a base class, adds some `mixins` and defines a subclass `trainer_cls`, which is returned to the caller of `build_trainer`.  
> Note that `trainer_cls` implements both `__init__` and `_init`. `_init` is called from `_setup` of the parent class (`Trainer`), and `Trainer._setup` is called from `Trainer`'s parent class (`Trainable`) during its `__init__`. 

If we look at the documentation on [Contributing to RLlib](https://ray.readthedocs.io/en/latest/rllib-dev.html#contributing-to-rllib), we see that it describes how to create an agent:  
"*It takes just two changes to add an algorithm to contrib. A minimal example can be found here. First, subclass `Trainer` and implement the `_init` and `_train` methods*"  
This is exactly what `build_trainer` does, so we'll use it.
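
To make the "base class plus mixins, returned as a new subclass" idea concrete, here is a toy sketch in plain Python. It is not RLlib code: `Base`, `LoggingMixin` and `build_cls` are hypothetical stand-ins, and the real `build_trainer` accepts many more arguments.

```python
class Base:
    """Stands in for the `Trainer` base class (toy version)."""

    def __init__(self, config):
        self.config = config
        self._setup()

    def _setup(self):
        # The parent drives the subclass hook, mirroring _setup -> _init.
        self._init(self.config)

    def _init(self, config):
        raise NotImplementedError


class LoggingMixin:
    """A hypothetical mixin that contributes one helper method."""

    def describe(self):
        return "{}(lr={})".format(type(self).__name__, self.config["lr"])


def build_cls(name, mixins=(), after_init=None):
    """Return a new subclass of Base composed with the given mixins."""
    bases = tuple(mixins) + (Base,)

    class cls(*bases):
        def _init(self, config):
            self.initialized = True
            if after_init:
                after_init(self)  # hook receives the instance itself

    cls.__name__ = name
    cls.__qualname__ = name
    return cls


seen = []
MyCls = build_cls("MyCls", mixins=[LoggingMixin], after_init=seen.append)
obj = MyCls({"lr": 0.01})
print(obj.describe())
```

The factory never instantiates anything itself; it only returns a class, which is why the result can be handed to `tune.run` in the same way as a hand-written subclass.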


### How are the callbacks related?
So we have 2 ways to define callbacks. What is the relation between them?
Note that the call to `tune.run` can be made without access to the trainer definition, i.e. we can `pip install ray` and write a script that defines the callback without changing the trainer's code.

#### `on_train_result` vs `after_train_result`
`on_train_result` is the callback that `tune.run` gets in `config['callbacks']`. We can define this function without having access to the trainer definition code: we can name a registered trainer ("PG", "DQN", "PPO") and Tune will use the trainer registered under that string. In the above example we used "PG", but we could have used "DQN", and then it would use the `DQNTrainer` that is defined in RLlib's `dqn.py`:
```
GenericOffPolicyTrainer = build_trainer(
    name="GenericOffPolicyAlgorithm",
    default_policy=None,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config_and_setup_param_noise,
    get_initial_state=get_initial_state,
    make_policy_optimizer=make_policy_optimizer,
    before_train_step=update_worker_exploration,
    after_optimizer_step=update_target_if_needed,
    after_train_result=after_train_result,
    execution_plan=execution_plan)

DQNTrainer = GenericOffPolicyTrainer.with_updates(
    name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
```

Note that within `dqn.py` the `DQNTrainer` is defined with a function called `after_train_result`, which is described next.


`after_train_result` 
If we build a custom trainer using `build_trainer`, we can also define the `after_train_result` callback ourselves and do whatever we need there. 


**So when is each of these routines called?**

Let's follow the `train()` method. 
As described, `build_trainer` returns a `trainer_cls`, which is a child of `Trainer` and implements the `_train()` method.
When we provide Tune with a trainer (e.g. "DQN"), it invokes the trainer's `train()` method. 
This `train()` is implemented in `Trainer`:

```
    @override(Trainable)
    @PublicAPI
    def train(self):
        """Overrides super.train to synchronize global vars."""
        ...
        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError as e:
                ...

        if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
            self._sync_filters_if_needed(self.workers)

        ...
        
        if self.config["evaluation_interval"] == 1 or (
                self._iteration > 0 and self.config["evaluation_interval"]
                and self._iteration % self.config["evaluation_interval"] == 0):
            evaluation_metrics = self._evaluate()
            assert isinstance(evaluation_metrics, dict), \
                "_evaluate() needs to return a dict."
            result.update(evaluation_metrics)

        return result

```
So we see a call to the parent `Trainable.train()` method with the `Trainer` instance. Let's look at this function:
```
    def train(self):
        """Runs one logical iteration of training.

        Subclasses should override ``_train()`` instead to return results.
        ...

        Returns:
            A dict that describes training progress.
        """
        start = time.time()
        result = self._train()
        assert isinstance(result, dict), "_train() needs to return a dict."

        # We do not modify internal state nor update this result if duplicate.
        if RESULT_DUPLICATE in result:
            return result

        result = result.copy()

        self._iteration += 1
        self._iterations_since_restore += 1

        if result.get(TIME_THIS_ITER_S) is not None:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = time.time() - start
        self._time_total += time_this_iter
        self._time_since_restore += time_this_iter

        result.setdefault(DONE, False)

        # self._timesteps_total should only be tracked if increments provided
        if result.get(TIMESTEPS_THIS_ITER) is not None:
            if self._timesteps_total is None:
                self._timesteps_total = 0
            self._timesteps_total += result[TIMESTEPS_THIS_ITER]
            self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]

        # self._episodes_total should only be tracked if increments provided
        if result.get(EPISODES_THIS_ITER) is not None:
            if self._episodes_total is None:
                self._episodes_total = 0
            self._episodes_total += result[EPISODES_THIS_ITER]

        # self._timesteps_total should not override user-provided total
        result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
        result.setdefault(EPISODES_TOTAL, self._episodes_total)
        result.setdefault(TRAINING_ITERATION, self._iteration)

        # Provides auto-filled neg_mean_loss for avoiding regressions
        if result.get("mean_loss"):
            result.setdefault("neg_mean_loss", -result["mean_loss"])

        now = datetime.today()
        result.update(
            experiment_id=self._experiment_id,
            date=now.strftime("%Y-%m-%d_%H-%M-%S"),
            timestamp=int(time.mktime(now.timetuple())),
            time_this_iter_s=time_this_iter,
            time_total_s=self._time_total,
            pid=os.getpid(),
            hostname=os.uname()[1],
            node_ip=self._local_ip,
            config=self.config,
            time_since_restore=self._time_since_restore,
            timesteps_since_restore=self._timesteps_since_restore,
            iterations_since_restore=self._iterations_since_restore)

        monitor_data = self._monitor.get_data()
        if monitor_data:
            result.update(monitor_data)

        self._log_result(result)

        return result

```

There are 2 important things to note:
1. At the beginning there is a call to `result = self._train()`, which is implemented by the subclass.
2. At the end there is a call to `self._log_result(result)`.


First, let's look into `self._train()`, which is implemented in `trainer_cls`:
```
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            while True:
                fetches = self.optimizer.step()
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

```
We see that this method calls the callbacks provided as arguments to `build_trainer`. Specifically, see the call to `after_train_result(self, res)` at the bottom. 



Second, look at the line `self._log_result(result)` towards the end of the `train()` implementation (in `Trainable`). Because `self` is a `Trainer`, it calls the following `Trainer._log_result()`:
```
    @override(Trainable)
    def _log_result(self, result):
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "trainer": self,
                "result": result,
            })
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)
```
We see that it calls the callback we defined in the `config['callbacks']` argument to `tune.run()`.

In this manner we can track each of the callbacks provided to `tune.run` and see at which point it is called. 
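
The ordering traced above can be reproduced with a toy model of the call chain. None of these classes are RLlib's: `ToyTrainable` and `build_toy_trainer` are hypothetical stand-ins that keep only the parts relevant to the two kinds of callback.

```python
events = []  # records the order in which the hooks fire


class ToyTrainable:
    """Stands in for `Trainable`: owns train() and the final logging step."""

    def train(self):
        result = self._train()
        self._log_result(result)
        return result

    def _log_result(self, result):
        events.append("log")


def build_toy_trainer(after_train_result=None):
    """Stands in for `build_trainer`: bakes a hook into a subclass."""

    class ToyTrainer(ToyTrainable):
        def __init__(self, config):
            self.config = config

        def _train(self):
            res = {"episode_reward_mean": 1.0}
            if after_train_result:
                after_train_result(self, res)  # build_trainer-style hook
            return res

        def _log_result(self, result):
            callback = self.config["callbacks"].get("on_train_result")
            if callback:
                callback({"trainer": self, "result": result})  # config hook
            ToyTrainable._log_result(self, result)

    return ToyTrainer


cls = build_toy_trainer(
    after_train_result=lambda trainer, res: events.append("after_train_result"))
trainer = cls({"callbacks": {
    "on_train_result": lambda info: events.append("on_train_result")}})
trainer.train()
print(events)
```

In this toy version `after_train_result` fires inside `_train()`, and `on_train_result` fires afterwards, just before the result is logged.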

**Bottom line**
We should strive to implement the model/checkpoint saving using the callbacks provided to `tune.run()`, so that we don't have to define a callback for each trainer (e.g. DQN, PG, PPO etc.). 
If we defined it as an input to `build_trainer`, we would have to implement the callback in each trainer's definition file (dqn.py, ppo.py etc.).
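
A minimal sketch of such a trainer-agnostic callback, assuming the `info` dict carries `"trainer"` and `"result"` as listed in `COMMON_CONFIG`. It saves a checkpoint whenever the mean episode reward sets a new record. The names `make_on_train_result` and `FakeTrainer` are hypothetical; the fake trainer exists only so the logic can be exercised without Ray.

```python
def make_on_train_result(metric="episode_reward_mean"):
    """Build an `on_train_result` callback that saves on a new best score."""
    state = {"best": float("-inf"), "path": None}

    def on_train_result(info):
        trainer, result = info["trainer"], info["result"]
        score = result.get(metric)
        # Comparisons with NaN are False, so NaN scores never trigger a save.
        if score is not None and score > state["best"]:
            state["best"] = score
            state["path"] = trainer.save()  # save() returns the checkpoint path

    return on_train_result


class FakeTrainer:
    """Test double; in a real run the RLlib trainer is passed instead."""

    def __init__(self):
        self.saved = 0

    def save(self):
        self.saved += 1
        return "checkpoint-{}".format(self.saved)


callback = make_on_train_result()
fake = FakeTrainer()
for reward in [10.0, 8.0, float("nan"), 25.0]:
    callback({"trainer": fake, "result": {"episode_reward_mean": reward}})
print(fake.saved)
```

The callback would be registered as `"callbacks": {"on_train_result": make_on_train_result()}` in the config passed to `tune.run`. A periodic rule (every X timesteps) could be expressed the same way by comparing a counter from `result` against a threshold.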

While tracking the trainer's `save` method, I noticed that its base class (`Trainable`) has a method `export_model`.
This method calls an internal `_export_model` method that should be implemented by the `Trainable` subclass, 
and it is indeed implemented in RLlib's `Trainer` class:
```
    def _export_model(self, export_formats, export_dir):
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        return exported
```

As we can see, it calls `self.export_policy_checkpoint` and `self.export_policy_model`. 
I think that, as the trainer has multiple workers, this is the place to call on each worker to save its policy instance. 
I guess that this function will use the policy API to save the checkpoint. 

The interaction between the `Trainer` object and the `policy` object is that the trainer manages the workers, and each worker has a policy object to run with. 
See, for example, the implementation of `export_policy_model` and `export_policy_checkpoint` in the `Trainer`, which delegates to the local worker only: 
```
    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        """Export policy model with given policy_id to local directory.
        ...
        """
        self.workers.local_worker().export_policy_model(export_dir, policy_id)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        """Export tensorflow policy model checkpoint to local directory.
        ...
        """
        self.workers.local_worker().export_policy_checkpoint(
            export_dir, filename_prefix, policy_id)
```

Question: how are `export_policy_model` and `export_policy_checkpoint` implemented?
We see that the trainer calls the `local_worker` method on its `workers` member (of type `WorkerSet`). If we look at the `WorkerSet` implementation ([worker_set.py](https://github.com/ray-project/ray/blob/master/rllib/evaluation/worker_set.py)), we see that the `local_worker` is a `RolloutWorker`:
```
# in  WorkerSet.__init__:
    # Always create a local worker
    self._local_worker = self._make_worker(
        RolloutWorker, env_creator, policy, 0, self._local_config)
```

This `RolloutWorker` is implemented in [rollout_worker.py](https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py) and has a method that calls its policy object to export the model:

```
    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_model(export_dir)

```

The checkpoint is saved in a very similar way.
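
The delegation chain (trainer, then worker set, then local worker, then policy) can be summarised with toy classes. These are hypothetical stand-ins rather than RLlib's implementations, and the value of `DEFAULT_POLICY_ID` is assumed for the sketch. Nothing is written to disk; the toy policy only records where it was asked to export.

```python
DEFAULT_POLICY_ID = "default_policy"  # stand-in value for this sketch


class ToyPolicy:
    def __init__(self):
        self.exported_to = None

    def export_model(self, export_dir):
        self.exported_to = export_dir  # a real policy would write files here


class ToyWorker:
    def __init__(self):
        self.policy_map = {DEFAULT_POLICY_ID: ToyPolicy()}

    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_model(export_dir)


class ToyWorkerSet:
    def __init__(self, num_remote):
        self._local_worker = ToyWorker()  # always created
        self._remote_workers = [ToyWorker() for _ in range(num_remote)]

    def local_worker(self):
        return self._local_worker

    def remote_workers(self):
        return self._remote_workers


class ToyTrainer:
    def __init__(self):
        self.workers = ToyWorkerSet(num_remote=2)

    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        # Only the local worker is asked to export.
        self.workers.local_worker().export_policy_model(export_dir, policy_id)


trainer = ToyTrainer()
trainer.export_policy_model("/tmp/export")
local = trainer.workers.local_worker().policy_map[DEFAULT_POLICY_ID]
remote = [w.policy_map[DEFAULT_POLICY_ID]
          for w in trainer.workers.remote_workers()]
print(local.exported_to)
print([p.exported_to for p in remote])
```

In this sketch the remote workers' policies are never touched, which matches the excerpts above, where the export methods go through `local_worker()` only.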


### RLlib Policy object
[`CLASS ray.rllib.policy.Policy(observation_space, action_space, config)`](https://ray.readthedocs.io/en/latest/rllib-package-ref.html#module-ray.rllib.policy)  
This object defines how to act in the environment, and also losses used to improve the policy based on its experiences. Note that both policy and loss are defined together for convenience, though the policy itself is logically separate.

All policies can directly extend Policy, however TensorFlow users may find TFPolicy simpler to implement. TFPolicy also enables RLlib to apply TensorFlow-specific optimizations such as fusing multiple policy graphs and multi-GPU support.

**TODO** Check the relation between the policy object and the policy optimizer.
The [RLlib policy](https://ray.readthedocs.io/en/latest/rllib-package-ref.html#module-ray.rllib.policy) object also has `export_model` and `export_checkpoint` methods. When are they called? 

**TODO** Look at `TFPolicy` and `build_tf_policy` to understand whether I should directly save the policy model at the end rather than a checkpoint of the trainer.

**How is a policy object created?**
When we build a trainer using `build_trainer`, we provide it with `default_policy`, which is the policy class that we want to train. For example, in DQN:
```
GenericOffPolicyTrainer = build_trainer(
    name="GenericOffPolicyAlgorithm",
    default_policy=None,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config_and_setup_param_noise,
    get_initial_state=get_initial_state,
    make_policy_optimizer=make_policy_optimizer,
    before_train_step=update_worker_exploration,
    after_optimizer_step=update_target_if_needed,
    after_train_result=after_train_result,
    execution_plan=execution_plan)

DQNTrainer = GenericOffPolicyTrainer.with_updates(name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
```
This policy class (e.g. `DQNTFPolicy`) can be built in 2 main ways, similar to `Trainer`:
1. Directly inherit from the `Policy` class (or one of its descendants: `TFPolicy` &rarr; `DynamicTFPolicy`)
2. Use `build_tf_policy` to build the class. For example, in dqn_policy.py:
```
DQNTFPolicy = build_tf_policy(
    name="DQNTFPolicy",
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model=build_q_model,
    action_sampler_fn=sample_action_from_q_network,
    log_likelihood_fn=get_log_likelihood,
    loss_fn=build_q_losses,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    gradients_fn=clip_gradients,
    extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    obs_include_prev_action_reward=False,
    mixins=[
        ParameterNoiseMixin,
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
```

    This builder is implemented in [`tf_policy_template.py`](https://github.com/ray-project/ray/blob/master/rllib/policy/tf_policy_template.py) and works very similarly to `build_trainer`: it defines a base class that is constructed from `DynamicTFPolicy`, adds the mixins that are provided as an argument (see the last argument in the above example), and returns this `policy_cls`.  

The whole mechanism of saving a checkpoint or a model is implemented in `TFPolicy` (or the equivalent PyTorch class).

What is the difference between saving a model and saving a checkpoint? See [the implementation](https://ray.readthedocs.io/en/latest/_modules/ray/rllib/policy/tf_policy.html#TFPolicy.export_checkpoint).




## Forming the flow
OK, given the above, let's write the code for training while saving checkpoints.

It turns out that I can also let `tune.run` do the checkpoint management automatically (I searched for 'checkpoint' in the help):
```
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.num_iters,
        },
        config={
            "num_gpus":1,
            "env": "CartPole-v0",
            "callbacks": {
                "on_episode_start": on_episode_start,
                "on_episode_step": on_episode_step,
                "on_episode_end": on_episode_end,
                "on_sample_end": on_sample_end,
                "on_train_result": on_train_result,
                "on_postprocess_traj": on_postprocess_traj,
            },
        },
        checkpoint_at_end=True,
        checkpoint_freq=50,
        return_trials=True)

```
`checkpoint_at_end` asks Tune to save a checkpoint at the end of the run.  
`checkpoint_freq=50` instructs Tune to save a checkpoint every 50 iterations.

We'll see the checkpoints in the results directory as `checkpoint_<iter>` folders.
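
Since the folder names carry the iteration number, picking the most recent checkpoint needs a numeric comparison (as plain text, `checkpoint_50` would sort after `checkpoint_110`). The helper below is a sketch that assumes each folder holds a file named `checkpoint-<iter>`, as in the `rllib rollout` example earlier; `latest_checkpoint` is a hypothetical name, and the temporary directory only imitates a trial directory.

```python
import os
import re
import tempfile


def latest_checkpoint(trial_dir):
    """Return the path of the highest-numbered checkpoint file, or None."""
    best_iter, best_path = -1, None
    for entry in os.listdir(trial_dir):
        match = re.fullmatch(r"checkpoint_(\d+)", entry)
        if not match:
            continue
        iteration = int(match.group(1))  # compare numerically, not as text
        path = os.path.join(trial_dir, entry,
                            "checkpoint-{}".format(iteration))
        if iteration > best_iter and os.path.isfile(path):
            best_iter, best_path = iteration, path
    return best_path


# Build a throwaway directory with the assumed layout to exercise the helper.
with tempfile.TemporaryDirectory() as trial_dir:
    for i in (50, 100, 110):
        ckpt_dir = os.path.join(trial_dir, "checkpoint_{}".format(i))
        os.makedirs(ckpt_dir)
        open(os.path.join(ckpt_dir, "checkpoint-{}".format(i)), "w").close()
    found = latest_checkpoint(trial_dir)
    print(os.path.relpath(found, trial_dir))
```

The returned path is the kind of argument that the trainer's `restore` method expects.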

How do we load the checkpoint? 
This is a checkpoint of the trainer. To use it for prediction, I followed the `rollout.py` code. See the following example:

In [1]:
path_to_trial_results='/home/guy/ray_results/PG/PG_CartPole-v0_0_2020-04-01_17-56-43vl31kz32'

In [2]:
cd $path_to_trial_results

/home/guy/ray_results/PG/PG_CartPole-v0_0_2020-04-01_17-56-43vl31kz32


In [3]:
!tree

.
├── checkpoint_100
│   ├── checkpoint-100
│   └── checkpoint-100.tune_metadata
├── checkpoint_110
│   ├── checkpoint-110
│   └── checkpoint-110.tune_metadata
├── checkpoint_50
│   ├── checkpoint-50
│   └── checkpoint-50.tune_metadata
├── events.out.tfevents.1585753003.guy-970
├── params.json
├── params.pkl
├── policy_model.h5
│   ├── events.out.tfevents.1585756992.guy-970
│   ├── saved_model.pb
│   └── variables
│       ├── variables.data-00000-of-00001
│       └── variables.index
├── progress.csv
└── result.json

5 directories, 15 files


In [4]:
import tensorflow as tf
import ray
import pickle
import numpy as np
from ray.rllib.agents.registry import get_agent_class
ray.init()

2020-04-02 07:55:53,998	INFO resource_spec.py:212 -- Starting Ray with 13.62 GiB memory available for workers and up to 6.82 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-04-02 07:55:54,348	INFO services.py:1120 -- View the Ray dashboard at localhost:8265


{'node_ip_address': '192.168.1.93',
 'redis_address': '192.168.1.93:55509',
 'object_store_address': '/tmp/ray/session_2020-04-02_07-55-53_997504_23713/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-04-02_07-55-53_997504_23713/sockets/raylet',
 'webui_url': 'localhost:8265',
 'session_dir': '/tmp/ray/session_2020-04-02_07-55-53_997504_23713'}

In [5]:
# we have to assume we know which agent it was...
cls = get_agent_class("PG")
with open('params.pkl', "rb") as f:
    config = pickle.load(f)
agent = cls(env=config['env'], config=config)

2020-04-02 07:55:59,297	INFO trainer.py:427 -- Tip: set 'eager': true or the --eager flag to enable TensorFlow eager execution
2020-04-02 07:55:59,322	INFO trainer.py:584 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2020-04-02 07:56:01,713	INFO trainable.py:217 -- Getting current IP.


Now that we have the agent, we can restore a checkpoint:

In [6]:
agent.restore('checkpoint_100/checkpoint-100')
policy=agent.get_policy()
policy

2020-04-02 07:56:07,775	INFO trainable.py:217 -- Getting current IP.
2020-04-02 07:56:07,775	INFO trainable.py:423 -- Restored on 192.168.1.93 from checkpoint: checkpoint_100/checkpoint-100
2020-04-02 07:56:07,776	INFO trainable.py:430 -- Current state after restoring: {'_iteration': 100, '_timesteps_total': None, '_time_total': 109.65056538581848, '_episodes_total': 548}


<ray.rllib.policy.tf_policy_template.PGTFPolicy at 0x7fe5cc663050>

In [7]:
policy.model.base_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
observations (InputLayer)       [(None, 4)]          0                                            
__________________________________________________________________________________________________
fc_1 (Dense)                    (None, 256)          1280        observations[0][0]               
__________________________________________________________________________________________________
fc_2 (Dense)                    (None, 256)          65792       fc_1[0][0]                       
__________________________________________________________________________________________________
fc_out (Dense)                  (None, 2)            514         fc_2[0][0]                       
______________________________________________________________________________________________

In [8]:
import gym
e=gym.make(config['env'])
e

<TimeLimit<CartPoleEnv<CartPole-v0>>>

In [9]:
o=e.reset()
o=e.observation_space.sample()
o

array([-3.1259754e+00, -2.3251294e+38, -4.0303564e-01, -8.3273986e+37],
      dtype=float32)

In [10]:
oo=np.vstack([e.observation_space.sample() for _ in range(5)])
oo.shape

(5, 4)

In [11]:
policy.compute_actions(oo,explore=False)

(array([0, 1, 1, 1, 0]),
 [],
 {'action_prob': array([1., 1., 1., 1., 1.], dtype=float32),
  'action_logp': array([0., 0., 0., 0., 0.], dtype=float32)})
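
Randomly sampled observations can be far from anything the environment produces (note the values around 1e38 above), so a more representative check is to roll the policy out for an episode. The loop below is a sketch that assumes the classic gym API, where `step()` returns `(obs, reward, done, info)`. `run_episode` and `CountdownEnv` are hypothetical names; the stand-in environment is there only so the loop runs without gym.

```python
def run_episode(env, act, max_steps=200):
    """Roll out one episode with `act(obs) -> action`; return total reward."""
    obs = env.reset()
    total = 0.0
    for _ in range(max_steps):
        # Classic gym API: step() returns (obs, reward, done, info).
        obs, reward, done, _info = env.step(act(obs))
        total += reward
        if done:
            break
    return total


class CountdownEnv:
    """Tiny stand-in environment: reward 1 per step, done after 3 steps."""

    def reset(self):
        self.t = 0
        return [0.0]

    def step(self, action):
        self.t += 1
        return [float(self.t)], 1.0, self.t >= 3, {}


print(run_episode(CountdownEnv(), act=lambda obs: 0))
```

With the restored policy, `act` could be something like `lambda obs: policy.compute_actions(np.array([obs]), explore=False)[0][0]`, which takes the first element (the actions array) of the tuple shown in the output above.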

In [28]:
policy.export_model('policy_model.h5')

In [33]:
ll 

total 78344
drwxrwxr-x 2 guy     4096 Apr  1 18:00 checkpoint_100/
drwxrwxr-x 2 guy     4096 Apr  1 18:00 checkpoint_110/
drwxrwxr-x 2 guy     4096 Apr  1 17:58 checkpoint_50/
-rw-rw-r-- 1 guy   510442 Apr  1 18:00 events.out.tfevents.1585753003.guy-970
-rw-rw-r-- 1 guy      494 Apr  1 17:56 params.json
-rw-rw-r-- 1 guy     2223 Apr  1 17:56 params.pkl
drwxr-xr-x 3 guy     4096 Apr  1 19:03 policy_model.h5/
-rw-rw-r-- 1 guy 39618902 Apr  1 18:00 progress.csv
-rw-rw-r-- 1 guy 40058562 Apr  1 18:00 result.json


Now we need to see how to load the model...

In [36]:
ll policy_model.h5/

total 732
-rw-rw-r-- 1 guy 356788 Apr  1 19:03 events.out.tfevents.1585756992.guy-970
-rw-rw-r-- 1 guy 382025 Apr  1 19:03 saved_model.pb
drwxr-xr-x 2 guy   4096 Apr  1 19:03 variables/


It looks like `export_model` simply exports a model for serving. It is not clear how it can be loaded, for example with the following:

In [34]:
m=tf.keras.models.load_model('policy_model.h5/saved_model.pb')

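OSError: SavedModel file does not exist at: policy_model.h5/saved_model.pb/{saved_model.pbtxt|saved_model.pb}

<font color='red'> TODO: figure out how to save the model so that it can be loaded as in the previous cell </font>

Note that the error message shows `load_model` treating its argument as a SavedModel directory, so the call should at least be given the `policy_model.h5` directory rather than the `saved_model.pb` file inside it. Whether Keras can then load this export has not been verified here.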

In [12]:
policy.model.base_model.save('base_model.h5')

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

In [13]:
policy.model.base_model.save_weights('model_weights.h5')

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

In [24]:
policy.model.base_model

<tensorflow.python.keras.engine.training.Model at 0x7fe5cc25fbd0>