In [3]:
import sys
from pathlib import Path
sys.path.append(str(Path().resolve().parent))

In [4]:
import gym
from starlette.requests import Request
import requests

import ray
import ray.rllib.agents.ppo as ppo
from ray import serve

In [5]:
def train_ppo_model():
    trainer = ppo.PPOTrainer(
        config={
            "framework": "torch",
            "num_workers": 0
        },
        env="CartPole-v0",
    )
    # Train for one iteration
    trainer.train()
    checkpoint = trainer.save("./tmp/rllib_checkpoint")
    print(checkpoint)
    print(trainer.save())
    return "./tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"


checkpoint_path = train_ppo_model()

./tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1
/home/napnel/ray_results/PPO_CartPole-v0_2022-02-11_17-27-342qqst2jo/checkpoint_000001/checkpoint-1


In [6]:
@serve.deployment(route_prefix="/cartpole-ppo")
class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        self.trainer = ppo.PPOTrainer(
            config={
                "framework": "torch",
                # only 1 "local" worker with an env (not really used here).
                "num_workers": 0,
            },
            env="CartPole-v0")
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        obs = json_input["observation"]

        action = self.trainer.compute_single_action(obs, explore=False)
        # action = self.trainer.compute_action(obs, explore=False)
        return {"action": int(action)}

In [7]:
serve.start()
ServePPOModel.deploy(checkpoint_path)

2022-02-11 17:27:47,314	INFO services.py:1338 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
[2m[36m(ServeController pid=15063)[0m 2022-02-11 17:27:48,592	INFO checkpoint_path.py:16 -- Using RayInternalKVStore for controller checkpoint and recovery.
[2m[36m(ServeController pid=15063)[0m 2022-02-11 17:27:48,595	INFO http_state.py:98 -- Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:dcpUYk:SERVE_PROXY_ACTOR-node:192.168.118.75-0' on node 'node:192.168.118.75-0' listening on '127.0.0.1:8000'
2022-02-11 17:27:48,812	INFO api.py:463 -- Started Serve instance in namespace 'serve'.
2022-02-11 17:27:48,821	INFO api.py:242 -- Updating deployment 'ServePPOModel'. component=serve deployment=ServePPOModel
[2m[36m(HTTPProxyActor pid=15056)[0m INFO:     Started server process [15056]
[2m[36m(ServeController pid=15063)[0m 2022-02-11 17:27:48,920	INFO deployment_state.py:912 -- Adding 1 replicas to deployment 'ServePPOModel'. component=serve deployment=ServePP

In [8]:
env = gym.make("CartPole-v0")
obs = env.reset()

In [9]:
# That's it! Let's test it
for _ in range(10):
    print(f"-> Sending observation {obs}")
    resp = requests.get(
        "http://localhost:8000/cartpole-ppo",
        json={"observation": obs.tolist()})
    print(f"<- Received response {resp.json()}")
# Output:
# <- Received response {'action': 1}
# -> Sending observation [0.04228249 0.02289503 0.00690076 0.03095441]
# <- Received response {'action': 0}
# -> Sending observation [ 0.04819471 -0.04702759 -0.00477937 -0.00735569]
# <- Received response {'action': 0}
# ...

-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.01783823  0.04272411  0.01209767  0.04633382]
<- Received response {'action': 1}
-> Sending observation [-0.0

[2m[36m(ServePPOModel pid=15054)[0m E0211 17:28:38.046864600   16354 backup_poller.cc:134]       Run client channel backup poller: {"created":"@1644568118.046833000","description":"pollset_work","file":"src/core/lib/iomgr/ev_epollex_linux.cc","file_line":320,"referenced_errors":[{"created":"@1644568118.046827700","description":"Bad file descriptor","errno":9,"file":"src/core/lib/iomgr/ev_epollex_linux.cc","file_line":950,"os_error":"Bad file descriptor","syscall":"epoll_wait"}]}
