In [None]:
! pip install catalyst==21.04.2 gym==0.18.0

---

# Off-policy DQN

## Imports

In [None]:
from typing import Iterator, Optional, Sequence, Tuple
from collections import deque, namedtuple

import gym
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

from catalyst import dl, metrics, utils

## RL common

In [None]:
# One environment step as stored in the replay buffer:
# (state, action, reward, done, next_state).
Transition = namedtuple(
    "Transition", field_names=["state", "action", "reward", "done", "next_state"]
)


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    The storage is a ``deque(maxlen=capacity)``: once it is full, appending a new
    transition drops the oldest one.
    """

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def append(self, transition: "Transition"):
        """Stores one ``(state, action, reward, done, next_state)`` transition."""
        self.buffer.append(transition)

    def sample(self, size: int) -> Sequence[np.array]:
        """Draws ``size`` transitions uniformly at random.

        Sampling is without replacement unless ``size`` exceeds the number of stored
        transitions, in which case it falls back to sampling with replacement.

        Returns:
            ``(states, actions, rewards, dones, next_states)`` as numpy arrays with dtypes
            float32, int64, float32, bool and float32 respectively.

        Raises:
            ValueError: raised by ``np.random.choice`` when the buffer is empty.
        """
        indices = np.random.choice(len(self.buffer), size, replace=size > len(self.buffer))
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
        states = np.array(states, dtype=np.float32)
        actions = np.array(actions, dtype=np.int64)
        rewards = np.array(rewards, dtype=np.float32)
        # builtin ``bool``: the ``np.bool`` alias is deprecated and removed in recent NumPy
        dones = np.array(dones, dtype=bool)
        next_states = np.array(next_states, dtype=np.float32)
        return states, actions, rewards, dones, next_states

    def __len__(self) -> int:
        return len(self.buffer)


# since RL has no predefined dataset,
# we need to specify the epoch length ourselves
class ReplayDataset(IterableDataset):
    """Iterable dataset that serves ``epoch_size`` transitions sampled from a replay buffer.

    Every call to ``__iter__`` asks the buffer for a fresh sample of ``epoch_size``
    transitions and yields them one at a time.
    """

    def __init__(self, buffer: ReplayBuffer, epoch_size: int = int(1e3)):
        self.epoch_size = epoch_size
        self.buffer = buffer

    def __len__(self) -> int:
        return self.epoch_size

    def __iter__(self) -> Iterator[Sequence[np.array]]:
        # one draw for the whole epoch, then hand the transitions out element by element
        sampled = self.buffer.sample(self.epoch_size)
        states, actions, rewards, dones, next_states = sampled
        yield from zip(states, actions, rewards, dones, next_states)


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Blends the parameters of ``source`` into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``, so ``tau=1.0``
    turns the call into a plain copy of ``source``. Parameters are paired in the order
    returned by ``parameters()``.
    """
    keep = 1.0 - tau
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

## DQN

In [None]:
def get_action(env, network: nn.Module, state: np.array, epsilon: float = -1) -> int:
    """Picks an action with an epsilon-greedy policy.

    Args:
        env: environment; only ``env.action_space.sample()`` is used (for random actions).
        network: Q-network mapping a batch of states to per-action values.
        state: a single observation (no batch dimension).
        epsilon: probability of taking a random action; the default ``-1`` can never
            exceed a draw from ``[0, 1)``, so the policy is purely greedy.

    Returns:
        The chosen action as a python ``int``.
    """
    if np.random.random() < epsilon:
        action = env.action_space.sample()
    else:
        # build the input on the device the network lives on, so the forward pass also
        # works when the runner has moved the model to an accelerator
        first_param = next(network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        state = torch.tensor(state[None], dtype=torch.float32, device=device)
        with torch.no_grad():
            q_values = network(state).cpu().numpy()[0]
        action = np.argmax(q_values)

    return int(action)


def generate_session(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon: float = -1,
    replay_buffer: Optional[ReplayBuffer] = None,
) -> Tuple[float, int]:
    """Plays one episode with the epsilon-greedy policy of ``network``.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        network: Q-network handed to ``get_action``.
        t_max: maximum number of environment steps in the episode.
        epsilon: exploration rate handed to ``get_action``.
        replay_buffer: if given, every transition of the episode is appended to it.

    Returns:
        ``(total_reward, num_steps)``, where ``num_steps`` is the number of environment
        steps taken, i.e. the number of transitions produced (0 when ``t_max`` is 0).
    """
    total_reward = 0
    num_steps = 0
    state = env.reset()

    for _step in range(t_max):
        action = get_action(env, network, state=state, epsilon=epsilon)
        next_state, reward, done, _ = env.step(action)

        if replay_buffer is not None:
            transition = Transition(state, action, reward, done, next_state)
            replay_buffer.append(transition)

        total_reward += reward
        # count the step itself rather than reporting the last loop index,
        # which would be one short and undefined for an empty loop
        num_steps += 1
        state = next_state
        if done:
            break

    return total_reward, num_steps


def generate_sessions(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon: float = -1,
    replay_buffer: ReplayBuffer = None,
    num_sessions: int = 100,
) -> Tuple[float, int]:
    """Plays ``num_sessions`` episodes back to back via ``generate_session``.

    Returns:
        ``(reward_total, steps_total)``: sums over all episodes, not averages.
    """
    reward_total, steps_total = 0, 0
    for _ in range(num_sessions):
        episode_reward, episode_steps = generate_session(
            env=env,
            network=network,
            t_max=t_max,
            epsilon=epsilon,
            replay_buffer=replay_buffer,
        )
        reward_total = reward_total + episode_reward
        steps_total = steps_total + episode_steps
    return reward_total, steps_total

In [None]:
def get_network(env, num_hidden: int = 128):
    """Builds the Q-network: a two-layer ReLU MLP followed by a linear head.

    The input width is ``env.observation_space.shape[0]`` and the head has one output per
    discrete action (``env.action_space.n``). The hidden layers and the head are initialised
    with Catalyst's inner/outer init helpers respectively.
    """
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    init_hidden = utils.get_optimal_inner_init(nn.ReLU)
    init_head = utils.outer_init

    body = torch.nn.Sequential(
        nn.Linear(num_inputs, num_hidden),
        nn.ReLU(),
        nn.Linear(num_hidden, num_hidden),
        nn.ReLU(),
    )
    q_head = nn.Linear(num_hidden, num_actions)

    body.apply(init_hidden)
    q_head.apply(init_head)

    return torch.nn.Sequential(body, q_head)

## Catalyst

In [None]:
class GameCallback(dl.Callback):
    """Plays the environment with the current policy to feed the replay buffer and log metrics.

    - stage start: ``num_start_sessions`` epsilon-greedy sessions pre-fill the buffer;
    - epoch start: ``epsilon`` is multiplied by ``epsilon_k`` and the per-epoch counters reset;
    - batch end: whenever ``runner.global_batch_step`` is a multiple of ``session_period``,
      one more epsilon-greedy session is played and stored;
    - epoch end: ``num_valid_sessions`` greedy sessions are played (nothing is stored) and
      their averages are written to ``runner.epoch_metrics["_epoch_"]``.
    """

    def __init__(
        self,
        *,
        env,
        replay_buffer: ReplayBuffer,
        session_period: int,
        epsilon: float,
        epsilon_k: float,
        actor_key: str,
        num_start_sessions: int = int(1e3),
        num_valid_sessions: int = int(1e2),
    ):
        super().__init__(order=0)
        self.env = env
        self.replay_buffer = replay_buffer
        self.session_period = session_period
        self.epsilon = epsilon
        self.epsilon_k = epsilon_k
        self.actor_key = actor_key
        # resolved from ``runner.model[actor_key]`` at stage start
        self.actor: nn.Module = None
        self.num_start_sessions = num_start_sessions
        self.num_valid_sessions = num_valid_sessions
        # per-epoch counters of training sessions and of the steps they produced
        self.session_counter = 0
        self.session_steps = 0

    def on_stage_start(self, runner: dl.IRunner) -> None:
        """Looks up the actor and pre-fills the replay buffer before training starts."""
        self.actor = runner.model[self.actor_key]

        self.actor.eval()
        generate_sessions(
            env=self.env,
            network=self.actor,
            epsilon=self.epsilon,
            replay_buffer=self.replay_buffer,
            num_sessions=self.num_start_sessions,
        )
        self.actor.train()

    def on_epoch_start(self, runner: dl.IRunner):
        """Decays the exploration rate and resets the per-epoch counters."""
        self.epsilon *= self.epsilon_k
        self.session_counter = 0
        self.session_steps = 0

    def on_batch_end(self, runner: dl.IRunner):
        """Plays and stores one exploration session every ``session_period`` batches."""
        if runner.global_batch_step % self.session_period == 0:
            self.actor.eval()

            session_reward, session_steps = generate_session(
                env=self.env,
                network=self.actor,
                epsilon=self.epsilon,
                replay_buffer=self.replay_buffer,
            )

            self.session_counter += 1
            self.session_steps += session_steps

            runner.batch_metrics.update({"s_reward": session_reward, "s_steps": session_steps})

            self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Evaluates the greedy policy and publishes the epoch-level metrics."""
        self.actor.eval()
        valid_rewards, valid_steps = generate_sessions(
            env=self.env, network=self.actor, num_sessions=int(self.num_valid_sessions)
        )
        self.actor.train()

        valid_rewards /= float(self.num_valid_sessions)
        valid_steps /= float(self.num_valid_sessions)
        runner.epoch_metrics["_epoch_"]["epsilon"] = self.epsilon
        runner.epoch_metrics["_epoch_"]["num_sessions"] = self.session_counter
        runner.epoch_metrics["_epoch_"]["num_samples"] = self.session_steps
        # the step counter can legitimately be 0 (no session fell into this epoch, or the
        # reported step counts were all 0); never divide by zero at the end of an epoch
        runner.epoch_metrics["_epoch_"]["updates_per_sample"] = (
            runner.loader_sample_step / max(self.session_steps, 1)
        )
        runner.epoch_metrics["_epoch_"]["v_reward"] = valid_rewards
        runner.epoch_metrics["_epoch_"]["v_steps"] = valid_steps

In [None]:
class CustomRunner(dl.Runner):
    """DQN runner: regresses the online Q-network onto one-step targets from a target network.

    ``runner.model`` is expected to be a mapping holding the online network under
    ``origin_key`` and the target network under ``target_key``.
    """

    def __init__(
        self,
        *,
        gamma: float,
        tau: float,
        tau_period: int = 1,
        origin_key: str = "origin",
        target_key: str = "target",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # discount factor and target-network smoothing settings
        self.gamma: float = gamma
        self.tau: float = tau
        self.tau_period: int = tau_period
        # keys of the two networks inside ``self.model``; resolved at stage start
        self.origin_key: str = origin_key
        self.target_key: str = target_key
        self.origin_network: nn.Module = None
        self.target_network: nn.Module = None

    def on_stage_start(self, runner: dl.IRunner):
        """Resolves both networks and starts the target as an exact copy of the online one."""
        super().on_stage_start(runner)
        self.origin_network = self.model[self.origin_key]
        self.target_network = self.model[self.target_key]
        # tau=1.0 makes the soft update a plain copy
        soft_update(self.target_network, self.origin_network, 1.0)

    def on_loader_start(self, runner: dl.IRunner):
        """Creates a fresh running-average meter for the loss."""
        super().on_loader_start(runner)
        self.meters = {"loss": metrics.AdditiveValueMetric(compute_on_call=False)}

    def handle_batch(self, batch: Sequence[np.array]):
        """Computes the TD loss for one batch and, on train loaders, takes an optimizer step."""
        states, actions, rewards, dones, next_states = batch

        # Q(s, a) of the online network for the actions that were actually taken
        all_qvalues = self.origin_network(states)
        taken_qvalues = all_qvalues.gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # V(s') = max_a' Q_target(s', a'); terminal transitions have no next state,
        # so their bootstrap value is zeroed and the target reduces to the reward
        with torch.no_grad():
            next_values = self.target_network(next_states).max(1)[0]
            next_values[dones] = 0.0

        # one-step target: r + gamma * V(s')
        td_targets = next_values * self.gamma + rewards

        loss = self.criterion(taken_qvalues, td_targets.detach())
        self.batch_metrics.update({"loss": loss})
        self.meters["loss"].update(self.batch_metrics["loss"].item(), self.batch_size)

        if not self.is_train_loader:
            return

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # let the target network slowly track the online one
        if self.global_batch_step % self.tau_period == 0:
            soft_update(self.target_network, self.origin_network, self.tau)

    def on_loader_end(self, runner: dl.IRunner):
        """Publishes the averaged loss as a loader-level metric."""
        self.loader_metrics["loss"] = self.meters["loss"].compute()[0]
        super().on_loader_end(runner)

## Training

In [None]:
# data
batch_size = 64
# one epoch = 1000 batches sampled from the replay buffer
epoch_size = int(1e3) * batch_size
buffer_size = int(1e5)
# runner settings, ~training
gamma = 0.99
tau = 0.01
tau_period = 1  # in batches
# callback, ~exploration
session_period = 100  # in batches
epsilon = 0.98
epsilon_k = 0.9  # epsilon is multiplied by this factor at the start of every epoch
# optimization
lr = 3e-4

# env_name = "LunarLander-v2"
env_name = "CartPole-v1"
env = gym.make(env_name)
replay_buffer = ReplayBuffer(buffer_size)

# the target network is only ever updated through soft_update, never by the optimizer
network, target_network = get_network(env), get_network(env)
utils.set_requires_grad(target_network, requires_grad=False)
models = {"origin": network, "target": target_network}
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
loaders = {
    "train_game": DataLoader(
        ReplayDataset(replay_buffer, epoch_size=epoch_size), batch_size=batch_size,
    ),
}

# NOTE(review): unlike the DDPG and REINFORCE cells, no `engine` is passed here,
# so the runner picks its default device — confirm this is intended.
runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)
runner.train(
    model=models,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs_dqn",
    num_epochs=10,
    verbose=True,
    # model selection uses the epoch-level "v_reward" written by GameCallback.on_epoch_end
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[
        GameCallback(
            env=env,
            replay_buffer=replay_buffer,
            session_period=session_period,
            epsilon=epsilon,
            epsilon_k=epsilon_k,
            actor_key="origin",
        )
    ],
)

## Evaluating

In [None]:
# record greedy episodes of the trained network into ./videos_dqn/
env = gym.wrappers.Monitor(gym.make(env_name), directory="videos_dqn", force=True)
generate_sessions(env=env, network=runner.model["origin"], num_sessions=100)
env.close()

# show video
from IPython.display import HTML
import os

# os.listdir returns names in arbitrary order, so sort them to make the choice reproducible
video_names = sorted(filter(lambda s: s.endswith(".mp4"), os.listdir("./videos_dqn/")))

# the path must point at the directory the Monitor wrote to
HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format("./videos_dqn/" + video_names[-1]))
# shows the last recording in name order; try other indices for the other recordings

---

# Off-policy DDPG

## Imports

In [None]:
from typing import Iterator, Optional, Sequence, Tuple
from collections import deque, namedtuple

import gym
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

from catalyst import dl, metrics, utils

## RL common

In [None]:
# One environment step as stored in the replay buffer:
# (state, action, reward, done, next_state).
Transition = namedtuple(
    "Transition", field_names=["state", "action", "reward", "done", "next_state"]
)


class ReplayBuffer:
    """Fixed-capacity FIFO store of continuous-action transitions with uniform sampling.

    The storage is a ``deque(maxlen=capacity)``: once it is full, appending a new
    transition drops the oldest one.
    """

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def append(self, transition: "Transition"):
        """Stores one ``(state, action, reward, done, next_state)`` transition."""
        self.buffer.append(transition)

    def sample(self, size: int) -> Sequence[np.array]:
        """Draws ``size`` transitions uniformly at random.

        Sampling is without replacement unless ``size`` exceeds the number of stored
        transitions, in which case it falls back to sampling with replacement.

        Returns:
            ``(states, actions, rewards, dones, next_states)`` as numpy arrays with dtypes
            float32, float32, float32, bool and float32 respectively.

        Raises:
            ValueError: raised by ``np.random.choice`` when the buffer is empty.
        """
        indices = np.random.choice(len(self.buffer), size, replace=size > len(self.buffer))
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
        states = np.array(states, dtype=np.float32)
        # actions are continuous here: an integer dtype would truncate them
        # (e.g. 0.7 -> 0) before they ever reach the critic
        actions = np.array(actions, dtype=np.float32)
        rewards = np.array(rewards, dtype=np.float32)
        # builtin ``bool``: the ``np.bool`` alias is deprecated and removed in recent NumPy
        dones = np.array(dones, dtype=bool)
        next_states = np.array(next_states, dtype=np.float32)
        return states, actions, rewards, dones, next_states

    def __len__(self) -> int:
        return len(self.buffer)


# since RL has no predefined dataset,
# we need to specify the epoch length ourselves
class ReplayDataset(IterableDataset):
    """Iterable dataset that serves ``epoch_size`` transitions sampled from a replay buffer.

    Every call to ``__iter__`` asks the buffer for a fresh sample of ``epoch_size``
    transitions and yields them one at a time.
    """

    def __init__(self, buffer: ReplayBuffer, epoch_size: int = int(1e3)):
        self.epoch_size = epoch_size
        self.buffer = buffer

    def __len__(self) -> int:
        return self.epoch_size

    def __iter__(self) -> Iterator[Sequence[np.array]]:
        # one draw for the whole epoch, then hand the transitions out element by element
        sampled = self.buffer.sample(self.epoch_size)
        states, actions, rewards, dones, next_states = sampled
        yield from zip(states, actions, rewards, dones, next_states)


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Blends the parameters of ``source`` into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``, so ``tau=1.0``
    turns the call into a plain copy of ``source``. Parameters are paired in the order
    returned by ``parameters()``.
    """
    keep = 1.0 - tau
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

## DDPG

In [None]:
class NormalizedActions(gym.ActionWrapper):
    """Lets the agent act in ``[-1, 1]`` while the wrapped env keeps its own action bounds."""

    def action(self, action: float) -> float:
        """Maps an agent action from ``[-1, 1]`` to ``[low, high]`` of the action space."""
        low_bound = self.action_space.low
        upper_bound = self.action_space.high

        action = low_bound + (action + 1.0) * 0.5 * (upper_bound - low_bound)
        # inputs outside [-1, 1] (e.g. after exploration noise) are clamped to the env bounds
        action = np.clip(action, low_bound, upper_bound)

        return action

    def reverse_action(self, action: float) -> float:
        """Maps an env action from ``[low, high]`` back to the agent's ``[-1, 1]`` range."""
        low_bound = self.action_space.low
        upper_bound = self.action_space.high

        action = 2 * (action - low_bound) / (upper_bound - low_bound) - 1
        # the result lives in the normalized range, so that is what it must be clipped to
        action = np.clip(action, -1.0, 1.0)

        return action

    # previous (private) name of the reverse mapping, kept for backward compatibility
    _reverse_action = reverse_action


def get_action(
    env, network: nn.Module, state: np.array, sigma: Optional[float] = None
) -> np.array:
    """Returns the actor's action for a single state, optionally with Gaussian noise.

    ``env`` is not used. With ``sigma=None`` the network output is returned as is;
    otherwise the result is drawn from a normal distribution centred on the network
    output with standard deviation ``sigma``.
    """
    batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    mean_action = network(batch).detach().cpu().numpy()[0]
    if sigma is None:
        return mean_action
    return np.random.normal(mean_action, sigma)


def generate_session(
    env,
    network: nn.Module,
    sigma: Optional[float] = None,
    replay_buffer: Optional[ReplayBuffer] = None,
) -> Tuple[float, int]:
    """Plays one episode with the (optionally noisy) policy of ``network``.

    The episode runs for at most ``env.spec.max_episode_steps`` steps.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and ``spec``.
        network: actor handed to ``get_action``.
        sigma: exploration noise scale handed to ``get_action``; ``None`` disables noise.
        replay_buffer: if given, every transition of the episode is appended to it.

    Returns:
        ``(total_reward, num_steps)``, where ``num_steps`` is the number of environment
        steps taken, i.e. the number of transitions produced.
    """
    total_reward = 0
    num_steps = 0
    state = env.reset()

    for _step in range(env.spec.max_episode_steps):
        action = get_action(env, network, state=state, sigma=sigma)
        next_state, reward, done, _ = env.step(action)

        if replay_buffer is not None:
            transition = Transition(state, action, reward, done, next_state)
            replay_buffer.append(transition)

        total_reward += reward
        # count the step itself rather than reporting the last loop index,
        # which would be one short and undefined for an empty loop
        num_steps += 1
        state = next_state
        if done:
            break

    return total_reward, num_steps


def generate_sessions(
    env,
    network: nn.Module,
    sigma: Optional[float] = None,
    replay_buffer: Optional[ReplayBuffer] = None,
    num_sessions: int = 100,
) -> Tuple[float, int]:
    """Plays ``num_sessions`` episodes back to back via ``generate_session``.

    Returns:
        ``(reward_total, steps_total)``: sums over all episodes, not averages.
    """
    reward_total, steps_total = 0, 0
    for _ in range(num_sessions):
        episode_reward, episode_steps = generate_session(
            env=env,
            network=network,
            sigma=sigma,
            replay_buffer=replay_buffer,
        )
        reward_total = reward_total + episode_reward
        steps_total = steps_total + episode_steps
    return reward_total, steps_total

In [None]:
def get_network_actor(env):
    """Builds the actor: a 400-300 ReLU MLP with a tanh head, so actions lie in ``[-1, 1]``.

    The input width is ``env.observation_space.shape[0]``; the head has one output per
    action dimension (``env.action_space.shape[0]``), so environments with more than one
    continuous action component work as well. The hidden layers and the head are
    initialised with Catalyst's inner/outer init helpers respectively.
    """
    inner_fn = utils.get_optimal_inner_init(nn.ReLU)
    outer_fn = utils.outer_init

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.shape[0]

    network = torch.nn.Sequential(
        nn.Linear(num_inputs, 400), nn.ReLU(), nn.Linear(400, 300), nn.ReLU(),
    )
    head = torch.nn.Sequential(nn.Linear(300, num_actions), nn.Tanh())

    network.apply(inner_fn)
    head.apply(outer_fn)

    return torch.nn.Sequential(network, head)


def get_network_critic(env):
    """Builds the critic: a 400-300 LeakyReLU MLP with a scalar linear head.

    The critic consumes a state concatenated with an action, so its input width is
    ``env.observation_space.shape[0] + env.action_space.shape[0]``. The hidden layers and
    the head are initialised with Catalyst's inner/outer init helpers respectively.
    """
    inner_fn = utils.get_optimal_inner_init(nn.LeakyReLU)
    outer_fn = utils.outer_init

    # state features plus every action component (not just a single one)
    num_inputs = env.observation_space.shape[0] + env.action_space.shape[0]

    network = torch.nn.Sequential(
        nn.Linear(num_inputs, 400),
        nn.LeakyReLU(0.01),
        nn.Linear(400, 300),
        nn.LeakyReLU(0.01),
    )
    head = nn.Linear(300, 1)

    network.apply(inner_fn)
    head.apply(outer_fn)

    return torch.nn.Sequential(network, head)

## Catalyst

In [None]:
class GameCallback(dl.Callback):
    """Plays the environment with the current actor to feed the replay buffer and log metrics.

    - stage start: ``num_start_sessions`` noisy sessions pre-fill the buffer;
    - epoch start: the per-epoch counters are reset;
    - batch end: whenever ``runner.global_batch_step`` is a multiple of ``session_period``,
      one more noisy session is played and stored;
    - epoch end: ``num_valid_sessions`` noise-free sessions are played (nothing is stored)
      and their averages are written to ``runner.epoch_metrics["_epoch_"]``.
    """

    def __init__(
        self,
        *,
        env,
        replay_buffer: ReplayBuffer,
        session_period: int,
        sigma: float,
        actor_key: str,
        num_start_sessions: int = int(1e3),
        num_valid_sessions: int = int(1e2),
    ):
        super().__init__(order=0)
        self.env = env
        self.replay_buffer = replay_buffer
        self.session_period = session_period
        self.sigma = sigma
        self.actor_key = actor_key
        # resolved from ``runner.model[actor_key]`` at stage start
        self.actor: nn.Module = None
        self.num_start_sessions = num_start_sessions
        self.num_valid_sessions = num_valid_sessions
        # per-epoch counters of training sessions and of the steps they produced
        self.session_counter = 0
        self.session_steps = 0

    def on_stage_start(self, runner: dl.IRunner):
        """Looks up the actor and pre-fills the replay buffer before training starts."""
        self.actor = runner.model[self.actor_key]

        self.actor.eval()
        generate_sessions(
            env=self.env,
            network=self.actor,
            sigma=self.sigma,
            replay_buffer=self.replay_buffer,
            num_sessions=self.num_start_sessions,
        )
        self.actor.train()

    def on_epoch_start(self, runner: dl.IRunner):
        """Resets the per-epoch counters."""
        self.session_counter = 0
        self.session_steps = 0

    def on_batch_end(self, runner: dl.IRunner):
        """Plays and stores one exploration session every ``session_period`` batches."""
        if runner.global_batch_step % self.session_period == 0:
            self.actor.eval()

            session_reward, session_steps = generate_session(
                env=self.env,
                network=self.actor,
                sigma=self.sigma,
                replay_buffer=self.replay_buffer,
            )

            self.session_counter += 1
            self.session_steps += session_steps

            runner.batch_metrics.update({"s_reward": session_reward, "s_steps": session_steps})

            self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Evaluates the noise-free policy and publishes the epoch-level metrics."""
        self.actor.eval()
        valid_rewards, valid_steps = generate_sessions(
            env=self.env, network=self.actor, num_sessions=int(self.num_valid_sessions)
        )
        self.actor.train()

        valid_rewards /= float(self.num_valid_sessions)
        valid_steps /= float(self.num_valid_sessions)
        runner.epoch_metrics["_epoch_"]["num_sessions"] = self.session_counter
        runner.epoch_metrics["_epoch_"]["num_samples"] = self.session_steps
        # the step counter can legitimately be 0 (no session fell into this epoch, or the
        # reported step counts were all 0); never divide by zero at the end of an epoch
        runner.epoch_metrics["_epoch_"]["updates_per_sample"] = (
            runner.loader_sample_step / max(self.session_steps, 1)
        )
        runner.epoch_metrics["_epoch_"]["v_reward"] = valid_rewards
        runner.epoch_metrics["_epoch_"]["v_steps"] = valid_steps

In [None]:
class CustomRunner(dl.Runner):
    """DDPG runner: trains an actor and a critic against slowly-updated target copies.

    ``runner.model`` is expected to be a mapping with the four networks under the
    configured keys, and ``runner.optimizer`` a mapping with one optimizer for the actor
    and one for the critic.
    """

    def __init__(
        self,
        *,
        gamma: float,
        tau: float,
        tau_period: int = 1,
        actor_key: str = "actor",
        critic_key: str = "critic",
        target_actor_key: str = "target_actor",
        target_critic_key: str = "target_critic",
        actor_optimizer_key: str = "actor",
        critic_optimizer_key: str = "critic",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # discount factor and target-network smoothing settings
        self.gamma = gamma
        self.tau = tau
        self.tau_period = tau_period
        # keys into ``self.model`` / ``self.optimizer``; everything is resolved at stage start
        self.actor_key: str = actor_key
        self.critic_key: str = critic_key
        self.target_actor_key: str = target_actor_key
        self.target_critic_key: str = target_critic_key
        self.actor_optimizer_key: str = actor_optimizer_key
        self.critic_optimizer_key: str = critic_optimizer_key
        self.actor: nn.Module = None
        self.critic: nn.Module = None
        self.target_actor: nn.Module = None
        self.target_critic: nn.Module = None
        self.actor_optimizer: torch.optim.Optimizer = None
        self.critic_optimizer: torch.optim.Optimizer = None

    def on_stage_start(self, runner: dl.IRunner):
        """Resolves networks and optimizers and starts the targets as exact copies."""
        super().on_stage_start(runner)
        self.actor = self.model[self.actor_key]
        self.critic = self.model[self.critic_key]
        self.target_actor = self.model[self.target_actor_key]
        self.target_critic = self.model[self.target_critic_key]
        # tau=1.0 makes the soft update a plain copy
        soft_update(self.target_actor, self.actor, 1.0)
        soft_update(self.target_critic, self.critic, 1.0)
        self.actor_optimizer = self.optimizer[self.actor_optimizer_key]
        self.critic_optimizer = self.optimizer[self.critic_optimizer_key]

    def on_loader_start(self, runner: dl.IRunner):
        """Creates fresh running-average meters for both losses."""
        super().on_loader_start(runner)
        self.meters = {
            key: metrics.AdditiveValueMetric(compute_on_call=False)
            for key in ["critic_loss", "actor_loss"]
        }

    def handle_batch(self, batch: Sequence[torch.Tensor]):
        """Computes actor and critic losses for one batch and, on train loaders, updates both."""
        states, actions, rewards, dones, next_states = batch

        # actor objective: maximise the critic's value of the actor's own actions
        pred_actions = self.actor(states)
        pred_critic_states = torch.cat([states, pred_actions], 1)
        policy_loss = (-self.critic(pred_critic_states)).mean()

        with torch.no_grad():
            # value of the next state under the target actor/critic pair
            next_state_actions = self.target_actor(next_states)
            next_critic_states = torch.cat([next_states, next_state_actions], 1)
            # squeeze only the trailing value dimension: a bare ``squeeze()`` would also
            # drop the batch axis of a single-sample batch and break the masking below
            next_state_values = self.target_critic(next_critic_states).squeeze(-1)
            # terminal transitions have no next state to bootstrap from
            next_state_values[dones] = 0.0

        # one-step Bellman target: r + gamma * Q_target(s', pi_target(s'))
        target_state_values = next_state_values * self.gamma + rewards
        # critic prediction for the actions that were actually taken
        critic_states = torch.cat([states, actions], 1)
        state_values = self.critic(critic_states).squeeze(-1)

        value_loss = self.criterion(state_values, target_state_values.detach())

        self.batch_metrics.update({"critic_loss": value_loss, "actor_loss": policy_loss})
        for key in ["critic_loss", "actor_loss"]:
            self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)

        if self.is_train_loader:
            self.actor.zero_grad()
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            self.actor_optimizer.step()

            # the policy loss also back-propagated into the critic; clear those gradients
            # so the critic step only sees the value loss
            self.critic.zero_grad()
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            self.critic_optimizer.step()

            # let the target networks slowly track the trained ones
            if self.global_batch_step % self.tau_period == 0:
                soft_update(self.target_actor, self.actor, self.tau)
                soft_update(self.target_critic, self.critic, self.tau)

    def on_loader_end(self, runner: dl.IRunner):
        """Publishes the averaged losses as loader-level metrics."""
        for key in ["critic_loss", "actor_loss"]:
            self.loader_metrics[key] = self.meters[key].compute()[0]
        super().on_loader_end(runner)
## Training

In [None]:
# data
batch_size = 64
# one epoch = 1000 batches sampled from the replay buffer
epoch_size = int(1e3) * batch_size
buffer_size = int(1e5)
# runner settings, ~training
gamma = 0.99
tau = 0.01
tau_period = 1  # in batches
# callback, ~exploration
session_period = 1  # in batches
sigma = 0.3  # std of the Gaussian exploration noise added to the actor output
# optimization
lr_actor = 1e-4
lr_critic = 1e-3

# You can change game
# env_name = "LunarLanderContinuous-v2"
env_name = "Pendulum-v0"
# the wrapper lets the actor work in [-1, 1] regardless of the env's own action bounds
env = NormalizedActions(gym.make(env_name))
replay_buffer = ReplayBuffer(buffer_size)

# target networks are only ever updated through soft_update, never by an optimizer
actor, target_actor = get_network_actor(env), get_network_actor(env)
critic, target_critic = get_network_critic(env), get_network_critic(env)
utils.set_requires_grad(target_actor, requires_grad=False)
utils.set_requires_grad(target_critic, requires_grad=False)

models = {
    "actor": actor,
    "critic": critic,
    "target_actor": target_actor,
    "target_critic": target_critic,
}

criterion = torch.nn.MSELoss()
# keys match the runner's default actor_optimizer_key / critic_optimizer_key
optimizer = {
    "actor": torch.optim.Adam(actor.parameters(), lr_actor),
    "critic": torch.optim.Adam(critic.parameters(), lr=lr_critic),
}

loaders = {
    "train_game": DataLoader(
        ReplayDataset(replay_buffer, epoch_size=epoch_size), batch_size=batch_size,
    ),
}

runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period,)

runner.train(
    engine=dl.DeviceEngine("cpu"),  # for simplicity reasons, let's run everything on cpu
    model=models,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs_ddpg",
    num_epochs=10,
    verbose=True,
    # model selection uses the epoch-level "v_reward" written by GameCallback.on_epoch_end
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[
        GameCallback(
            env=env,
            replay_buffer=replay_buffer,
            session_period=session_period,
            sigma=sigma,
            actor_key="actor",
        )
    ],
)

## Evaluating

In [None]:
# record noise-free episodes of the trained actor into ./videos_ddpg/;
# the actor outputs actions in [-1, 1], so the env needs the same NormalizedActions
# wrapper that was used during training
env = gym.wrappers.Monitor(
    NormalizedActions(gym.make(env_name)), directory="videos_ddpg", force=True
)
generate_sessions(env=env, network=runner.model["actor"], num_sessions=100)
env.close()

# show video
from IPython.display import HTML
import os

# os.listdir returns names in arbitrary order, so sort them to make the choice reproducible
video_names = sorted(filter(lambda s: s.endswith(".mp4"), os.listdir("./videos_ddpg/")))

# the path must point at the directory the Monitor wrote to
HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format("./videos_ddpg/" + video_names[-1]))
# shows the last recording in name order; try other indices for the other recordings

---

# On-policy REINFORCE

## Imports

In [None]:
from typing import Iterator, Optional, Sequence, Tuple
from collections import deque, namedtuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

from catalyst import dl, metrics, utils

## RL common

In [None]:
# One complete episode: parallel per-step sequences of states, actions and rewards.
Rollout = namedtuple("Rollout", field_names=["states", "actions", "rewards",])


class RolloutBuffer:
    """Bounded FIFO store of whole episodes (rollouts).

    Backed by a ``deque(maxlen=capacity)``, so the oldest rollout is dropped once the
    buffer is full.
    """

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, rollout: Rollout):
        """Stores one episode."""
        self.buffer.append(rollout)

    def sample(self, idx: int) -> Sequence[np.array]:
        """Returns the ``idx``-th stored rollout as ``(states, actions, rewards)`` arrays.

        The arrays have dtypes float32, int64 and float32 respectively.
        """
        raw_states, raw_actions, raw_rewards = self.buffer[idx]
        return (
            np.array(raw_states, dtype=np.float32),
            np.array(raw_actions, dtype=np.int64),
            np.array(raw_rewards, dtype=np.float32),
        )


# since RL has no predefined dataset,
# we need to specify the epoch length ourselves
class RolloutDataset(IterableDataset):
    """Iterable dataset that serves every stored rollout once and then empties the buffer.

    The reported length is the buffer capacity, not the number of rollouts currently stored.
    """

    def __init__(self, buffer: RolloutBuffer):
        self.buffer = buffer

    def __len__(self) -> int:
        return self.buffer.capacity

    def __iter__(self) -> Iterator[Sequence[np.array]]:
        num_rollouts = len(self.buffer)
        for position in range(num_rollouts):
            states, actions, rewards = self.buffer.sample(position)
            yield states, actions, rewards
        # on-policy data is consumed exactly once: drop it after a full pass
        self.buffer.buffer.clear()

## REINFORCE

In [None]:
def get_cumulative_rewards(rewards, gamma=0.99):
    """Computes the discounted return ``G_t = r_t + gamma * G_{t+1}`` for every step.

    Args:
        rewards: per-step rewards of one episode; any sequence that supports ``reversed``.
        gamma: discount factor.

    Returns:
        A list with one return per step, in the same order as ``rewards``; the last
        element is the last reward itself. An empty input gives an empty list.
    """
    returns = []
    running = None
    # walk the episode backwards and append (O(n)) instead of inserting at the front
    for reward in reversed(rewards):
        running = reward if running is None else reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def to_one_hot(y, n_dims=None):
    """Converts a tensor of integer class ids into a one-hot float matrix.

    The result has one row per element of ``y`` and ``n_dims`` columns; when ``n_dims``
    is not given it is inferred as ``max(y) + 1``.
    """
    indices = y.type(torch.LongTensor).view(-1, 1)
    if n_dims is None:
        n_dims = int(torch.max(indices)) + 1
    num_rows = indices.size()[0]
    one_hot = torch.zeros(num_rows, n_dims)
    one_hot.scatter_(1, indices, 1)
    return one_hot


def get_action(env, network: nn.Module, state: np.array) -> int:
    """Samples an action from the softmax policy given by the network's logits.

    ``env`` is not used. ``state`` is a single observation without a batch dimension.
    """
    batch = torch.tensor(state[None], dtype=torch.float32)
    action_probs = F.softmax(network(batch).detach(), -1).cpu().numpy()[0]
    sampled = np.random.choice(len(action_probs), p=action_probs)
    return int(sampled)


def generate_session(
    env, network: nn.Module, t_max: int = 1000, rollout_buffer: Optional[RolloutBuffer] = None,
) -> Tuple[float, int]:
    """Plays one episode by sampling actions from the policy ``network``.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        network: policy network handed to ``get_action``.
        t_max: maximum number of environment steps in the episode.
        rollout_buffer: if given, the whole episode is appended to it as one ``Rollout``.

    Returns:
        ``(total_reward, num_steps)``, where ``num_steps`` is the number of environment
        steps taken (0 when ``t_max`` is 0).
    """
    total_reward = 0
    states, actions, rewards = [], [], []
    state = env.reset()

    for _step in range(t_max):
        action = get_action(env, network, state=state)
        next_state, reward, done, _ = env.step(action)

        # record session history to train later
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        total_reward += reward
        state = next_state
        if done:
            break
    if rollout_buffer is not None:
        rollout_buffer.append(Rollout(states, actions, rewards))

    # one recorded reward per step: report the step count itself rather than the last
    # loop index, which would be one short and undefined for an empty loop
    return total_reward, len(rewards)


def generate_sessions(
    env,
    network: nn.Module,
    t_max: int = 1000,
    rollout_buffer: Optional[RolloutBuffer] = None,
    num_sessions: int = 100,
) -> Tuple[float, int]:
    """Plays ``num_sessions`` episodes back to back via ``generate_session``.

    Returns:
        ``(reward_total, steps_total)``: sums over all episodes, not averages.
    """
    reward_total, steps_total = 0, 0
    for _ in range(num_sessions):
        episode_reward, episode_steps = generate_session(
            env=env,
            network=network,
            t_max=t_max,
            rollout_buffer=rollout_buffer,
        )
        reward_total = reward_total + episode_reward
        steps_total = steps_total + episode_steps
    return reward_total, steps_total

In [None]:
def get_network(env, num_hidden: int = 128):
    """Builds the policy network: a two-layer ReLU MLP followed by a linear logits head.

    The input width is ``env.observation_space.shape[0]`` and the head has one logit per
    discrete action (``env.action_space.n``). The hidden layers and the head are
    initialised with Catalyst's inner/outer init helpers respectively.
    """
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    init_hidden = utils.get_optimal_inner_init(nn.ReLU)
    init_head = utils.outer_init

    body = torch.nn.Sequential(
        nn.Linear(num_inputs, num_hidden),
        nn.ReLU(),
        nn.Linear(num_hidden, num_hidden),
        nn.ReLU(),
    )
    logits_head = nn.Linear(num_hidden, num_actions)

    body.apply(init_hidden)
    logits_head.apply(init_head)

    return torch.nn.Sequential(body, logits_head)

## Catalyst

In [None]:
class GameCallback(dl.Callback):
    """Collects on-policy rollouts before each epoch and evaluates the policy after it.

    - epoch start: ``num_train_sessions`` episodes are played and stored in the rollout
      buffer; their average reward/length are logged as ``t_reward`` / ``t_steps``;
    - epoch end: ``num_valid_sessions`` episodes are played without storing them; their
      averages are logged as ``v_reward`` / ``v_steps``.
    """

    def __init__(
        self,
        *,
        env,
        rollout_buffer: RolloutBuffer,
        num_train_sessions: int = int(1e2),
        num_valid_sessions: int = int(1e2),
    ):
        super().__init__(order=0)
        self.env = env
        self.rollout_buffer = rollout_buffer
        self.num_train_sessions = num_train_sessions
        self.num_valid_sessions = num_valid_sessions

    def on_epoch_start(self, runner: dl.IRunner):
        """Fills the rollout buffer with fresh episodes of the current policy."""
        self.actor = runner.model

        self.actor.eval()
        reward_total, steps_total = generate_sessions(
            env=self.env,
            network=self.actor,
            rollout_buffer=self.rollout_buffer,
            num_sessions=self.num_train_sessions,
        )
        num_sessions = float(self.num_train_sessions)
        runner.epoch_metrics["_epoch_"]["t_reward"] = reward_total / num_sessions
        runner.epoch_metrics["_epoch_"]["t_steps"] = steps_total / num_sessions
        self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Plays validation episodes and logs their average reward and length."""
        self.actor.eval()
        reward_total, steps_total = generate_sessions(
            env=self.env, network=self.actor, num_sessions=self.num_valid_sessions
        )
        self.actor.train()

        num_sessions = float(self.num_valid_sessions)
        runner.epoch_metrics["_epoch_"]["v_reward"] = reward_total / num_sessions
        runner.epoch_metrics["_epoch_"]["v_steps"] = steps_total / num_sessions

In [None]:
class CustomRunner(dl.Runner):
    """REINFORCE runner: policy-gradient loss with an entropy bonus, one episode per batch."""

    def __init__(
        self, *, gamma: float, entropy_coef: float = 0.1, **kwargs,
    ):
        super().__init__(**kwargs)
        # discount factor for the returns and weight of the entropy regulariser
        self.gamma: float = gamma
        self.entropy_coef: float = entropy_coef

    def on_loader_start(self, runner: dl.IRunner):
        """Creates a fresh running-average meter for the loss."""
        super().on_loader_start(runner)
        self.meters = {key: metrics.AdditiveValueMetric(compute_on_call=False) for key in ["loss"]}

    def handle_batch(self, batch: Sequence[np.array]):
        """Computes the REINFORCE loss for one episode and, on train loaders, takes a step."""
        # model train/valid step
        # ATTENTION:
        #   because of different trajectories lens
        #   ONLY batch_size==1 supported
        states, actions, rewards = batch
        states, actions, rewards = states[0], actions[0], rewards[0]
        # use the runner's own discount factor, not whatever global ``gamma`` happens
        # to be defined in the notebook
        cumulative_returns = torch.tensor(get_cumulative_rewards(rewards, self.gamma))

        logits = self.model(states)
        probas = F.softmax(logits, -1)
        logprobas = F.log_softmax(logits, -1)
        n_actions = probas.shape[1]
        # log pi(a_t | s_t) for the actions that were actually taken
        logprobas_for_actions = torch.sum(logprobas * to_one_hot(actions, n_dims=n_actions), dim=1)

        # policy-gradient objective and entropy of the policy (higher = more exploration)
        J_hat = torch.mean(logprobas_for_actions * cumulative_returns)
        entropy_reg = -torch.mean(torch.sum(probas * logprobas, dim=1))
        loss = -J_hat - self.entropy_coef * entropy_reg

        self.batch_metrics.update({"loss": loss})
        for key in ["loss"]:
            self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)

        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def on_loader_end(self, runner: dl.IRunner):
        """Publishes the averaged loss as a loader-level metric."""
        for key in ["loss"]:
            self.loader_metrics[key] = self.meters[key].compute()[0]
        super().on_loader_end(runner)

## Training

In [None]:
# data
batch_size = 1  # the runner handles exactly one rollout per batch
epoch_size = int(1e3) * batch_size
buffer_size = int(1e2)  # matches the default number of training sessions per epoch
# runner settings
gamma = 0.99
# optimization
lr = 3e-4

# env_name = "LunarLander-v2"
env_name = "CartPole-v1"
env = gym.make(env_name)
rollout_buffer = RolloutBuffer(buffer_size)

model = get_network(env)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loaders = {
    "train_game": DataLoader(RolloutDataset(rollout_buffer), batch_size=batch_size,),
}

runner = CustomRunner(gamma=gamma)
runner.train(
    engine=dl.DeviceEngine("cpu"),  # for simplicity reasons, let's run everything on cpu
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    # own log directory, so this run does not overwrite the DQN logs
    logdir="./logs_reinforce",
    num_epochs=10,
    verbose=True,
    # model selection uses the epoch-level "v_reward" written by GameCallback.on_epoch_end
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[GameCallback(env=env, rollout_buffer=rollout_buffer,)],
)

## Evaluating

In [None]:
# record episodes of the trained policy into ./videos_reinforce/
env = gym.wrappers.Monitor(gym.make(env_name), directory="videos_reinforce", force=True)
generate_sessions(env=env, network=model, num_sessions=100)
env.close()

# show video
from IPython.display import HTML
import os

# os.listdir returns names in arbitrary order, so sort them to make the choice reproducible
video_names = sorted(filter(lambda s: s.endswith(".mp4"), os.listdir("./videos_reinforce/")))

# the path must point at the directory the Monitor wrote to
HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format("./videos_reinforce/" + video_names[-1]))
# shows the last recording in name order; try other indices for the other recordings