# Application: Training an RL restrictor for a discrete action space

## Setup and definitions

### Imports

In [None]:
import os, sys
sys.path.append(f'{os.getcwd()}/../../')

In [None]:
import networkx as nx
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt

from drama.wrapper import RestrictionWrapper
from drama.utils import flatdim

from examples.traffic_new.env import TrafficEnvironment
from examples.traffic_new.agent import TrafficAgent
from examples.traffic_new.restrictor import TrafficRestrictor
from examples.traffic_new.utils import create_graph, powerset
from examples.utils import play, ReplayBuffer

### Definitions

In [None]:
def smoothen(data, kernel_size):
    """Smooth ``data`` with a centred moving average of width ``kernel_size``.

    The moving average is applied independently along the first axis, so each
    column of a 2-D array is smoothed separately. The result always has the same
    shape as ``data``, even when ``kernel_size`` exceeds the data length.
    Samples outside the data are treated as zeros (as with
    ``np.convolve(..., mode='same')``), so values near both ends are biased
    towards zero; callers should trim the edges before interpreting them.

    Parameters
    ----------
    data : array_like
        Data with at least one dimension; smoothing runs along axis 0.
    kernel_size : int
        Width of the moving-average window; must be >= 1.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``data``.

    Raises
    ------
    ValueError
        If ``kernel_size`` is smaller than 1.
    NotImplementedError
        If ``data`` is zero-dimensional.
    """
    data = np.asarray(data)
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be >= 1, got {kernel_size}')
    if data.ndim == 0:
        raise NotImplementedError('smoothen requires at least one-dimensional data')
    if data.size == 0:
        # Nothing to smooth; keep the shape instead of failing inside np.convolve.
        return np.zeros(data.shape)

    kernel = np.ones(kernel_size) / kernel_size
    # Start of the centred window inside the full convolution. This is the same
    # offset np.convolve(mode='same') uses when the kernel is not longer than the data.
    start = (kernel_size - 1) // 2

    def _smooth_1d(column):
        # Slice the 'full' convolution instead of using mode='same'. Mode 'same'
        # returns max(len(column), kernel_size) values and would change the length
        # whenever the kernel is longer than the data.
        full = np.convolve(column, kernel, mode='full')
        return full[start:start + len(column)]

    return np.apply_along_axis(_smooth_1d, 0, data)

In [None]:
# Edge specification of the traffic network: each entry pairs an edge (u, v)
# with the attribute tuple that `create_graph` attaches to that edge.
# NOTE(review): the meaning of the three numbers per edge is defined by
# `create_graph` (not visible here) -- presumably cost-function coefficients; confirm.
edge_specifications = [
    ((0, 1), (0, 8, 1)),
    ((0, 2), (11, 0, 0)),
    ((1, 2), (1, 0, 0)),
    ((1, 3), (11, 0, 0)),
    ((2, 3), (0, 8, 1)),
]
graph = create_graph(edge_specifications)

# All (start, target) node pairs that trips may run between.
possible_start_and_target_nodes = [(0, 3)]

In [None]:
number_of_nodes = graph.number_of_nodes()
number_of_edges = graph.number_of_edges()

# Integer index of every edge, following the graph's edge-iteration order.
edges = {edge: i for i, edge in enumerate(graph.edges)}

# All simple paths for every start/target pair, concatenated in pair order.
# A nested comprehension avoids `sum(..., [])`, which re-copies the list for every pair.
routes = [
    path
    for s, t in possible_start_and_target_nodes
    for path in nx.all_simple_paths(graph, s, t)
]

# Every node that is the start or the target of some trip.
all_start_and_target_nodes = {node for pair in possible_start_and_target_nodes for node in pair}

# Enumerate every subset of edges and keep those that still connect all start/target pairs.
# Each kept restriction is a boolean mask aligned with `edges` (True = edge allowed).
valid_edge_restrictions = []
for allowed_edges in powerset(graph.edges):
    # Materialise the subset once: it is used for the subgraph and for membership tests.
    allowed_edge_set = set(allowed_edges)
    subgraph = graph.edge_subgraph(allowed_edge_set)
    # The subset check comes first because nx.has_path raises for nodes missing from the subgraph.
    if all_start_and_target_nodes.issubset(subgraph.nodes) and all(
        nx.has_path(subgraph, s, t) for s, t in possible_start_and_target_nodes
    ):
        valid_edge_restrictions.append([edge in allowed_edge_set for edge in edges])

seed = 42

## Execution

### Simulation

In [None]:
number_of_agents = 2

# One traffic agent per name ('agent_0', 'agent_1', ...), all constructed with the same seed.
# NOTE(review): identical seeds may make the agents' random choices identical -- confirm intended.
agents = {
    name: TrafficAgent(routes, edges, seed=seed)
    for name in (f'agent_{index}' for index in range(number_of_agents))
}
env = TrafficEnvironment(
    graph,
    list(agents),
    possible_start_and_target_nodes,
    routes,
    number_of_steps=100,
    seed=seed,
)

# The restrictor chooses among `valid_edge_restrictions`.
# NOTE(review): its total_timesteps (1000) differs from the training-loop budget -- confirm intended.
restrictor = TrafficRestrictor(edges, routes, valid_edge_restrictions, total_timesteps=1000, seed=seed)


def restrictor_reward(env, rewards):
    """Return the entry of ``rewards`` for the agent given by ``env.agent_selection``."""
    return rewards[env.agent_selection]


env = RestrictionWrapper(
    env,
    restrictor,
    restrictor_reward_fns={'restrictor_0': restrictor_reward},
    return_object=True,
)

In [None]:
# Run episodes until `total_timesteps` acting steps have been logged. The
# restrictor learns from its own transitions, and every acting step is recorded
# in `restricted_history`.
total_timesteps = 500_000

history_columns = ['episode', 'episode_step', 'agent', 'observation', 'reward', 'action']
# NOTE(review): replay_buffer is constructed but not used anywhere in this notebook -- confirm it is needed.
replay_buffer = ReplayBuffer(state_dim=flatdim(restrictor.observation_space), action_dim=flatdim(restrictor.action_space))

# Do not render during training
env.unwrapped.render_mode = None

# Rows are collected in a list and turned into a DataFrame once. Assigning
# hundreds of thousands of rows one by one through `.loc` is very slow.
history_rows = []

current_timestep = 0
current_episode = 0

try:
    with tqdm(total=total_timesteps) as t:
        # Only steps where an agent acts are counted. The episode that crosses
        # the budget is still played to its end.
        while current_timestep < total_timesteps:
            env.reset()
            current_episode += 1
            current_episode_timestep = 0
            previous_restrictor_observation = None
            previous_restrictor_action = None

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                done = termination or truncation

                if agent == 'restrictor_0' and previous_restrictor_observation is not None:
                    # Close the restrictor's previous transition with the observation and
                    # reward it receives now, including on its final (done) step.
                    # NOTE(review): truncation is reported as `done` exactly like termination --
                    # confirm restrictor.learn does not need to distinguish them.
                    restrictor.learn(previous_restrictor_observation, previous_restrictor_action, observation, reward, done)

                if done:
                    # Finished agents are stepped with None. Their policy is not queried,
                    # since its action would be discarded anyway.
                    action = None
                else:
                    if agent == 'restrictor_0':
                        action = restrictor.act(observation)
                        previous_restrictor_observation = observation
                        previous_restrictor_action = action
                    else:
                        action = agents[agent].act(observation)

                    history_rows.append({
                        'episode': current_episode,
                        'episode_step': current_episode_timestep,
                        'agent': agent,
                        'observation': observation,
                        'reward': reward,
                        'action': action,
                    })

                    current_timestep += 1
                    current_episode_timestep += 1
                    # Advance the bar only for logged steps, so it agrees with its total.
                    t.update()

                env.step(action)
finally:
    # Build the history even if training is interrupted, so partial results can be inspected.
    restricted_history = pd.DataFrame(history_rows, columns=history_columns)

### Visualization

In [None]:
kernel_size = 5_000

# Legend label for every restriction: the indices of its allowed edges, e.g. '{0, 1, 3}'.
valid_edge_restriction_sets = ['{' + ', '.join(f'{i}' for i, allowed in enumerate(restriction) if allowed) + '}' for restriction in valid_edge_restrictions]

# Index of the restriction chosen at each restrictor step. The Series index is the
# logged time step at which that choice was made.
restrictor_actions = restricted_history[restricted_history.agent == 'restrictor_0']['action'].astype(int)
# One row per restrictor step, one column per restriction (1.0 for the chosen one).
one_hot_restrictor_actions = np.eye(len(valid_edge_restrictions))[restrictor_actions.to_numpy()]

fig, ax = plt.subplots()

# Moving-average frequency with which each restriction was chosen.
data = pd.DataFrame(smoothen(one_hot_restrictor_actions, kernel_size=kernel_size), index=restrictor_actions.index, columns=valid_edge_restriction_sets)
# Drop kernel_size rows at both ends, where the zero-padded moving average is biased low.
lines = ax.plot(data.iloc[kernel_size:-kernel_size], color='gray', lw=1)

# Draw one restriction in red.
# NOTE(review): index 11 presumably selects a specific restriction of interest; which edge
# set it is depends on the order produced by `powerset` -- check valid_edge_restriction_sets[11].
highlighted_restriction = 11
# Guard so the figure still renders when the graph yields fewer restrictions than that.
if highlighted_restriction < len(lines):
    lines[highlighted_restriction].set_color('red')
ax.legend(labels=valid_edge_restriction_sets, loc='center left', bbox_to_anchor=(0.95, 0.5))
# Thousands separators on the time-step axis.
ax.get_xaxis().set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))

ax.set_ylabel('Frequency of restriction')
ax.set_xlabel('Time step')

fig.savefig('traffic-result-actions.pdf', bbox_inches='tight')