In [None]:
# tag::gym_mock[]
import gym


class Env:
    """Bare-bones sketch of the attributes and methods a gym environment exposes.

    Every method is an intentionally empty stub (it returns None); the class only
    illustrates the shape of the interface.
    """

    action_space: gym.spaces.Space
    observation_space: gym.spaces.Space  # <1>

    def step(self, action):  # <2>
        """Stub for ``step``; takes an ``action`` and does nothing here."""

    def reset(self):  # <3>
        """Stub for ``reset``; does nothing here."""

    def render(self, mode="human"):  # <4>
        """Stub for ``render``; ``mode`` defaults to "human"; does nothing here."""
# end::gym_mock[]

In [None]:
# tag::rllib_dqn_simple[]
from ray.tune.logger import pretty_print
from maze_gym_env import GymEnvironment
from ray.rllib.algorithms.dqn import DQNConfig


config = DQNConfig().environment(GymEnvironment).rollouts(num_rollout_workers=4)
# pretty_print() only formats the dict into a string; print it so the config
# is actually shown (this line is not the last expression of the cell).
print(pretty_print(config.to_dict()))

algo = config.build()  # <1>

for _ in range(10):
    result = algo.train()  # <2>

print(pretty_print(result))  # <3>
# end::rllib_dqn_simple[]


# tag::rllib_simple_save[]
from ray.rllib.algorithms.algorithm import Algorithm


checkpoint = algo.save()  # <1>
print(checkpoint)

restored_algorithm = Algorithm.from_checkpoint(checkpoint)  # <2>

# Evaluate the algorithm rebuilt from the checkpoint, so this snippet really
# exercises the save/restore round trip.
evaluation = restored_algorithm.evaluate()  # <3>
print(pretty_print(evaluation))

# Nothing below uses the restored copy; release its workers.
restored_algorithm.stop()

# end::rllib_simple_save[]

# TODO: when the training results are pretty-printed inside the loop above,
# the "evaluation" output contains only NaNs -- investigate why.

# tag::rllib_manual_rollout[]
env = GymEnvironment()
observations = env.reset()
total_reward = 0

# Play one episode, asking the trained algorithm for every action and
# accumulating the rewards until the environment reports it is done.
while True:
    action = algo.compute_single_action(observations)  # <1>
    observations, reward, done, info = env.step(action)
    total_reward += reward
    if done:
        break
# end::rllib_manual_rollout[]

# tag::rllib_actions[]
# Several observations, keyed by arbitrary IDs, are answered in one call.
batched_observations = {"obs_1": observations, "obs_2": observations}
action = algo.compute_actions(batched_observations)  # <1>
print(action)
# {'obs_1': 0, 'obs_2': 1}
# end::rllib_actions[]

# tag::rllib_policy[]
policy = algo.get_policy()
model = policy.model

# Show the parameters held by the algorithm's policy.
print(policy.get_weights())
# end::rllib_policy[]

# tag::rllib_workers[]
workers = algo.workers
# Ask every worker for the weights of its policy. foreach_worker() returns the
# per-worker results as a list; keep and show it instead of discarding it.
worker_weights = workers.foreach_worker(
    lambda worker: worker.get_policy().get_weights()
)
print(worker_weights)
# end::rllib_workers[]

# tag::rllib_q_network[]
# NOTE(review): `base_model` with a `.summary()` method assumes the TF
# framework (a Keras model) -- confirm for other frameworks.
q_network = model.base_model
q_network.summary()

# end::rllib_q_network[]

# tag::rllib_model_output[]
from ray.rllib.models.preprocessors import get_preprocessor


env = GymEnvironment()
obs_space = env.observation_space
preprocessor_cls = get_preprocessor(obs_space)
preprocessor = preprocessor_cls(obs_space)  # <1>

observations = env.reset()
# Flatten the raw observation and add a leading batch dimension of size 1.
flat_observation = preprocessor.transform(observations)
transformed = flat_observation.reshape(1, -1)  # <2>

model_output, _ = model.from_batch({"obs": transformed})  # <3>
# end::rllib_model_output[]

# tag::rllib_q_values_action_dist[]
# Q-value distributions computed from the model output of the batch above.
q_values = model.get_q_value_distributions(model_output)  # <1>
print(q_values)

# Wrap the same model output in the policy's action-distribution class and
# draw one action from it.
action_distribution = policy.dist_class(  # <2>
    model_output, model
)
sample = action_distribution.sample()  # <3>
print(sample)
# end::rllib_q_values_action_dist[]

# model.get_state_value(model_output)

# tag::multi_agent_init[]
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from gym.spaces import Discrete
import os


class MultiAgentMaze(MultiAgentEnv):
    """Two seekers (agent IDs 1 and 2) move on a shared 5x5 grid toward (4, 4)."""

    def __init__(self, *args, **kwargs):  # <1>
        # Let RLlib's MultiAgentEnv base class initialize its own state first.
        super().__init__()
        self.action_space = Discrete(4)
        self.observation_space = Discrete(5*5)
        self.agents = {1: (4, 0), 2: (0, 4)}  # <2>
        # Advertise the agent IDs through the attribute RLlib's base class uses.
        self._agent_ids = set(self.agents)
        self.goal = (4, 4)
        self.info = {1: {'obs': self.agents[1]}, 2: {'obs': self.agents[2]}}  # <3>

    def reset(self):
        """Put both seekers back on their start cells; return their observations."""
        self.agents = {1: (4, 0), 2: (0, 4)}
        # Keep the per-agent info in sync with the restored positions.
        self.info = {1: {'obs': self.agents[1]}, 2: {'obs': self.agents[2]}}

        return {1: self.get_observation(1), 2: self.get_observation(2)}  # <4>
# end::multi_agent_init[]

# tag::multi_agent_helpers[]
    def get_observation(self, agent_id):
        """Encode the agent's (row, col) position as the integer 5 * row + col."""
        seeker = self.agents[agent_id]
        return 5 * seeker[0] + seeker[1]

    def get_reward(self, agent_id):
        """Return 1 if the agent stands on the goal, else 0."""
        return 1 if self.agents[agent_id] == self.goal else 0

    def is_done(self, agent_id):
        """An agent is done once it reaches the goal."""
        return self.agents[agent_id] == self.goal
# end::multi_agent_helpers[]

# tag::multi_agent_step[]
    def step(self, action):  # <1>
        """Apply one action per acting agent (0=down, 1=left, 2=up, 3=right).

        Moves are clamped to the 5x5 grid. Raises ValueError for any other
        action value.
        """
        agent_ids = action.keys()

        for agent_id in agent_ids:
            seeker = self.agents[agent_id]
            if action[agent_id] == 0:  # move down
                seeker = (min(seeker[0] + 1, 4), seeker[1])
            elif action[agent_id] == 1:  # move left
                seeker = (seeker[0], max(seeker[1] - 1, 0))
            elif action[agent_id] == 2:  # move up
                seeker = (max(seeker[0] - 1, 0), seeker[1])
            elif action[agent_id] == 3:  # move right
                seeker = (seeker[0], min(seeker[1] + 1, 4))
            else:
                raise ValueError("Invalid action")
            self.agents[agent_id] = seeker  # <2>

        observations = {i: self.get_observation(i) for i in agent_ids}  # <3>
        rewards = {i: self.get_reward(i) for i in agent_ids}
        done = {i: self.is_done(i) for i in agent_ids}

        done["__all__"] = all(done.values())  # <4>

        # Report the *current* positions, and only for the agents that acted,
        # so the info dict's keys line up with the observations dict.
        self.info = {i: {'obs': self.agents[i]} for i in agent_ids}
        return observations, rewards, done, self.info
# end::multi_agent_step[]

# tag::multi_agent_render[]
    def render(self, *args, **kwargs):
        """Clear the terminal and draw the grid with the goal (G) and both agents."""
        os.system('cls' if os.name == 'nt' else 'clear')
        grid = [['| ' for _ in range(5)] + ["|\n"] for _ in range(5)]
        grid[self.goal[0]][self.goal[1]] = '|G'
        grid[self.agents[1][0]][self.agents[1][1]] = '|1'
        grid[self.agents[2][0]][self.agents[2][1]] = '|2'
        print(''.join([''.join(grid_row) for grid_row in grid]))
# end::multi_agent_render[]


# tag::multi_agent_run[]
import time

env = MultiAgentMaze()

# Drive both agents with random actions until any entry of the done dict
# (including "__all__") becomes true, redrawing the maze after every step.
episode_over = False
while not episode_over:
    random_actions = {
        agent_id: env.action_space.sample() for agent_id in (1, 2)
    }
    obs, rew, done, info = env.step(random_actions)
    time.sleep(0.1)
    env.render()
    episode_over = any(done.values())
# end::multi_agent_run[]

# tag::multi_agent_simple[]
from ray.rllib.algorithms.dqn import DQNConfig

# A DQNConfig only describes the algorithm; .build() turns it into an
# Algorithm, which is what provides train().
simple_trainer = DQNConfig().environment(env=MultiAgentMaze).build()
simple_trainer.train()
# end::multi_agent_simple[]

# tag::multi_agent_mapping[]
algo = (
    DQNConfig()
    .environment(env=MultiAgentMaze)
    .multi_agent(
        policies={  # <1>
            "policy_1": (
                None, env.observation_space, env.action_space, {"gamma": 0.80}
            ),
            "policy_2": (
                None, env.observation_space, env.action_space, {"gamma": 0.95}
            ),
        },
        # RLlib invokes the mapping function with extra arguments besides the
        # agent ID (e.g. episode, worker); accept and ignore them.
        policy_mapping_fn=(  # <2>
            lambda agent_id, *args, **kwargs: f"policy_{agent_id}"
        ),
    )
    .build()
)

print(algo.train())
# end::multi_agent_mapping[]


# tag::advanced_env_init[]
from gym.spaces import Discrete
import random
import os


class AdvancedEnv(GymEnvironment):
    """An 11x11 maze with the goal in the far corner and penalty ("punish") cells.

    Without an explicit start position the seeker starts on a random cell.
    """

    # NOTE(review): step() and is_done() are inherited from GymEnvironment,
    # which is not shown here; confirm they clamp moves to self.maze_len rather
    # than to a fixed grid size.

    def __init__(self, seeker=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.maze_len = 11
        self.action_space = Discrete(4)
        self.observation_space = Discrete(self.maze_len * self.maze_len)
        # Defined before reset() may run, because reset() records it in info.
        self.goal = (self.maze_len-1, self.maze_len-1)

        # Truthiness on purpose: a falsy `seeker` (None, or an empty dict such
        # as an empty env config passed positionally by RLlib) means "start at
        # a random position".
        if seeker:  # <1>
            # Raise rather than `assert`, which is stripped under `python -O`.
            if not (0 <= seeker[0] < self.maze_len and
                    0 <= seeker[1] < self.maze_len):
                raise ValueError(
                    f"seeker {seeker} is outside the "
                    f"{self.maze_len}x{self.maze_len} maze"
                )
            self.seeker = seeker
        else:
            self.reset()

        self.info = {'seeker': self.seeker, 'goal': self.goal}

        # Cells with an odd row and an even column cost -1 reward.
        self.punish_states = [  # <2>
            (i, j) for i in range(self.maze_len) for j in range(self.maze_len)
            if i % 2 == 1 and j % 2 == 0
        ]
# end::advanced_env_init[]

# tag::advanced_env_rest[]
    def reset(self):
        """Reset seeker position randomly, return observations."""
        self.seeker = (
            random.randint(0, self.maze_len - 1),
            random.randint(0, self.maze_len - 1)
        )
        # Keep the info dict in sync with the new start position.
        self.info = {'seeker': self.seeker, 'goal': self.goal}
        return self.get_observation()

    def get_observation(self):
        """Encode the seeker position as integer"""
        return self.maze_len * self.seeker[0] + self.seeker[1]

    def get_reward(self):
        """Reward finding the goal and punish forbidden states"""
        reward = -1 if self.seeker in self.punish_states else 0
        reward += 5 if self.seeker == self.goal else 0
        return reward

    def render(self, *args, **kwargs):
        """Render the environment, e.g. by printing its representation."""
        os.system('cls' if os.name == 'nt' else 'clear')
        grid = [['| ' for _ in range(self.maze_len)] +
                ["|\n"] for _ in range(self.maze_len)]
        for punish in self.punish_states:
            grid[punish[0]][punish[1]] = '|X'
        grid[self.goal[0]][self.goal[1]] = '|G'
        grid[self.seeker[0]][self.seeker[1]] = '|S'
        print(''.join([''.join(grid_row) for grid_row in grid]))
# end::advanced_env_rest[]


# tag::task_settable[]
from ray.rllib.env.apis.task_settable_env import TaskSettableEnv


class CurriculumEnv(AdvancedEnv, TaskSettableEnv):
    """AdvancedEnv whose start positions respect a difficulty "task".

    The task is an upper bound on the Manhattan distance between the seeker's
    start position and the goal. It applies both to the current position when
    it is set and to every later reset().
    """

    def __init__(self, *args, **kwargs):
        # No task yet, so reset() samples uniformly like AdvancedEnv does. This
        # must exist before AdvancedEnv.__init__, which calls reset() when no
        # seeker is given.
        self.task_difficulty = None
        AdvancedEnv.__init__(self)

    def difficulty(self):  # <1>
        """Manhattan distance between the seeker and the goal."""
        return abs(self.seeker[0] - self.goal[0]) + \
               abs(self.seeker[1] - self.goal[1])

    def get_task(self):  # <2>
        """Return the difficulty of the current seeker position."""
        return self.difficulty()

    def set_task(self, task_difficulty):  # <3>
        """Store the maximum start difficulty and re-sample the current position."""
        # Difficulty is never negative; clamping avoids an endless re-sampling loop.
        self.task_difficulty = max(task_difficulty, 0)
        self._resample_within_task()

    def reset(self):
        """Reset to a random start position that satisfies the current task."""
        observation = AdvancedEnv.reset(self)
        if self.task_difficulty is None:
            return observation
        # Without this, every new episode would ignore the curriculum.
        return self._resample_within_task()

    def _resample_within_task(self):
        """Redraw random positions until difficulty <= task; return the observation."""
        while self.difficulty() > self.task_difficulty:
            AdvancedEnv.reset(self)
        return self.get_observation()
# end::task_settable[]


# tag::curriculum_fn[]
def curriculum_fn(train_results, task_settable_env, env_ctx,
                  steps_per_level=1000):
    """Map training progress to a curriculum difficulty level.

    Used as RLlib's ``env_task_fn``; the returned value is the new task for
    the environment.

    Args:
        train_results: Training-result dict; its "timesteps_total" entry is
            read (a missing or None value counts as 0).
        task_settable_env: The environment (unused here).
        env_ctx: The environment context (unused here).
        steps_per_level: Timesteps needed to advance one difficulty level.

    Returns:
        ``timesteps_total // steps_per_level``.
    """
    time_steps = train_results.get("timesteps_total") or 0
    difficulty = time_steps // steps_per_level
    print(f"Current difficulty: {difficulty}")
    return difficulty
# end::curriculum_fn[]


# tag::curriculum_trainer[]
from ray.rllib.algorithms.dqn import DQNConfig
import tempfile


# Scratch directory that receives the experiences sampled while training;
# it is read back as offline input by the imitation algorithm further below.
temp = tempfile.mkdtemp()  # <1>

trainer = (
    DQNConfig()
    .environment(  # <2>
        env=CurriculumEnv,
        env_task_fn=curriculum_fn,
    )
    .offline_data(output=temp)  # <3>
    .build()
)

for _ in range(15):
    trainer.train()
# end::curriculum_trainer[]

# tag::input_trainer[]
# Train DQN only from the experiences previously written to `temp`, with
# exploration switched off.
imitation_algo = (
    DQNConfig()
    .environment(env=AdvancedEnv)
    .offline_data(input_=temp, input_evaluation=[])
    .exploration(explore=False)
    .build()
)

for _ in range(10):
    imitation_algo.train()

imitation_algo.evaluate()
# end::input_trainer[]