In [1]:
from pyrl.agents.classic import DQNAgent
from pyrl.agents.survival import SurvivalDQNAgent
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import numpy as np

from tensorforce.agents import Agent
from tensorforce.environments import Environment
from pyrl.environments import CustomEnvironment
from tensorforce.execution import Runner


In [2]:
map_size = 50
horizon = 5000
points = 6
repeat = 50
survival_threshold = 100
initial_budgets = np.linspace(100, horizon, points, dtype=int)

replay_capacity = 6000
batch_size = 32

In [3]:
env = Environment.create(
    environment=CustomEnvironment, max_episode_timesteps=horizon
)

dqn_time_mean = np.full(initial_budgets.shape, -1)
dqn_exploration_rate = np.full(initial_budgets.shape, -1)
dqn_alive_rate = np.full(initial_budgets.shape, -1)
dqn_budget_evolutions_mean = np.full(initial_budgets.shape, None)
dqn_budget_evolutions_max = np.full(initial_budgets.shape, None)
dqn_budget_evolutions_min = np.full(initial_budgets.shape, None)


observation_space = env.states
action_space = env.actions()

for i, b in enumerate(initial_budgets):
    nb_alive = 0
    for j in range(repeat):
        print(f"====> Classic DQN {b} | Try {j + 1}")
        agent = DQNAgent(environment=env, memory=replay_capacity, batch_size=batch_size, initial_budget=b)
        states = env.reset()
        # print(states)
        agent.reset(states)
        exploration_matrix = np.zeros((env.observation_space.n, env.action_space.n))
        
        for t in count():
            actions = agent.act(states=states)
            states, terminated, reward = env.execute(actions=actions)
            
            done = terminated or t >= horizon or agent.b <= 0
            # print(states, actions.item())
            exploration_matrix[int(states), int(actions)] = exploration_matrix[int(states), int(actions)] + 1
            
            agent.observe(states, reward, terminated=terminated)

            
            if done:
                print("t = ", t)
                if dqn_time_mean[i] == -1:
                    dqn_time_mean[i] = t
                else:
                    dqn_time_mean[i] = dqn_time_mean[i] + (1/j) * (t - dqn_time_mean[i])
                break
        
        
        exploration_rate = (np.count_nonzero(exploration_matrix) / (env.observation_space.n * env.action_space.n)) * 100

        if dqn_exploration_rate[i] == -1:
            dqn_exploration_rate[i] = exploration_rate
        else:
            dqn_exploration_rate[i] = dqn_exploration_rate[i] + (1 / j) * (exploration_rate - dqn_exploration_rate[i])

        if agent.b > 0:
            nb_alive = nb_alive + 1

        dqn_alive_rate[i] = nb_alive / (j+1) * 100
        
         # budget evolution mean
        if dqn_budget_evolutions_mean[i] is None:
            dqn_budget_evolutions_mean[i] = agent.b
        else:
            dqn_budget_evolutions_mean[i] = dqn_budget_evolutions_mean[i] + (1 / j) * (agent.b - dqn_budget_evolutions_mean[i])

        # budget evolution max
        if dqn_budget_evolutions_max[i] is None:
            dqn_budget_evolutions_max[i] = agent.b
        else:
            dqn_budget_evolutions_max[i] = np.maximum(dqn_budget_evolutions_max[i], agent.b)

        # budget evolution min
        if dqn_budget_evolutions_min[i] is None:
            dqn_budget_evolutions_min[i] = agent.b
        else:
            dqn_budget_evolutions_min[i] = np.minimum(dqn_budget_evolutions_min[i], agent.b)
        

        print(f"Time mean : {dqn_time_mean[i]}")
        print(f"Alive rate : {dqn_alive_rate[i]}%")
        print(f"Exploration rate: {dqn_exploration_rate[i]}%")

====> Classic DQN 100 | Try 1




t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 2
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Classic DQN 100 | Try 3
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Classic DQN 100 | Try 4
t =  111
Time mean : 103
Alive rate : 0%
Exploration rate: 3%
====> Classic DQN 100 | Try 5
t =  111
Time mean : 105
Alive rate : 0%
Exploration rate: 3%
====> Classic DQN 100 | Try 6
t =  100
Time mean : 104
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 7
t =  100
Time mean : 103
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 8
t =  100
Time mean : 102
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 9
t =  100
Time mean : 101
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 10
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 2%
====> Classic DQN 100 | Try 11
t =  111
Time mean : 101
Alive rate : 0%
Exploration rate: 2%
====> Classic DQ

In [4]:
env = Environment.create(
    environment=CustomEnvironment, max_episode_timesteps=horizon
)

survival_dqn_time_mean = np.full(initial_budgets.shape, -1)
survival_dqn_exploration_rate = np.full(initial_budgets.shape, -1)
survival_dqn_alive_rate = np.full(initial_budgets.shape, -1)
survival_dqn_budget_evolutions = list()

survival_dqn_budget_evolutions_mean = np.full(initial_budgets.shape, None)
survival_dqn_budget_evolutions_max = np.full(initial_budgets.shape, None)
survival_dqn_budget_evolutions_min = np.full(initial_budgets.shape, None)


observation_space = env.states
action_space = env.actions()

for i, b in enumerate(initial_budgets):
    nb_alive = 0
    for j in range(repeat):
        print(f"====> Survival DQN {b} | Try {j + 1}")
        agent = SurvivalDQNAgent(environment=env, memory=replay_capacity, batch_size=batch_size, initial_budget=b, threshold=survival_threshold)
        states = env.reset()
        
        agent.reset(states)
        exploration_matrix = np.zeros((env.observation_space.n, env.action_space.n))
        
        for t in count():
            actions = agent.act(states=states)
            states, terminated, reward = env.execute(actions=actions)
            
            done = terminated or t >= horizon or agent.b <= 0
            
            exploration_matrix[int(states), int(actions)] = exploration_matrix[int(states), int(actions)] + 1
            

            agent.observe(states, reward, terminated=terminated)

            if done:
                print("t = ", t)
                if survival_dqn_time_mean[i] == -1:
                    survival_dqn_time_mean[i] = t
                else:
                    survival_dqn_time_mean[i] = survival_dqn_time_mean[i] + (1/j) * (t - survival_dqn_time_mean[i])
                break
        
        
        exploration_rate = (np.count_nonzero(exploration_matrix) / (env.observation_space.n * env.action_space.n)) * 100

        if survival_dqn_exploration_rate[i] == -1:
            survival_dqn_exploration_rate[i] = exploration_rate
        else:
            survival_dqn_exploration_rate[i] = survival_dqn_exploration_rate[i] + (1 / j) * (exploration_rate - survival_dqn_exploration_rate[i])

        if agent.b > 0:
            nb_alive = nb_alive + 1

        survival_dqn_alive_rate[i] = nb_alive / (j+1) * 100

        # budget evolution max
        if survival_dqn_budget_evolutions_max[i] is None:
            survival_dqn_budget_evolutions_max[i] = agent.b
        else:
            survival_dqn_budget_evolutions_max[i] = np.maximum(survival_dqn_budget_evolutions_max[i], agent.b)

        # budget evolution min
        if survival_dqn_budget_evolutions_min[i] is None:
            survival_dqn_budget_evolutions_min[i] = agent.b
        else:
            survival_dqn_budget_evolutions_min[i] = np.minimum(survival_dqn_budget_evolutions_min[i], agent.b)
        
        
        print(f"Time mean : {survival_dqn_time_mean[i]}")
        print(f"Alive rate : {survival_dqn_alive_rate[i]}%")
        print(f"Exploration rate: {survival_dqn_exploration_rate[i]}%")

====> Survival DQN 100 | Try 1
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 2
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 3
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 4
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 5
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 6
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 7
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 8
t =  100
Time mean : 100
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 9
t =  133
Time mean : 104
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 10
t =  100
Time mean : 103
Alive rate : 0%
Exploration rate: 3%
====> Survival DQN 100 | Try 11
t =  133
Time mean : 106
Alive rate :

KeyboardInterrupt: 

## Results

### Survival time

In [None]:
plt.plot(initial_budgets, dqn_time_mean, label="Classic QLearning")
plt.plot(initial_budgets, survival_dqn_time_mean, label="Survival QLearning")
plt.xlabel("Initial budget")
plt.ylabel("Survival time")
plt.legend()
plt.title(f"Survival time in function of initial budget with horizon {horizon} repeated {repeat} times and survival threshold {survival_threshold} for map of size {map_size}")
plt.show()

: 

### Alive rate

In [None]:
plt.plot(initial_budgets, dqn_alive_rate, label="Classic QLearning")
plt.plot(initial_budgets, survival_dqn_alive_rate, label="Survival QLearning")
plt.xlabel("Initial budget")
plt.ylabel("Alive rate (%)")
plt.legend()
plt.title(f"Alive rate in function of initial budget with horizon {horizon} repeated {repeat} times and survival threshold {survival_threshold} for map of size {map_size}")
plt.show()

: 

### Exploration rate

In [None]:
plt.plot(initial_budgets, dqn_exploration_rate, label="Classic QLearning")
plt.plot(initial_budgets, survival_dqn_exploration_rate, label="Survival QLearning")
plt.xlabel("Initial budget")
plt.ylabel("Exploration rate (%)")
plt.legend()
plt.title(f"Exploration rate (state + action) in function of initial budget with horizon {horizon} repeated {repeat} times and survival threshold {survival_threshold} for map of size {map_size}")
plt.show()

: 

### Budget evolution

In [None]:
print("Available budgets :")
print(initial_budgets)

In [5]:
plt.rcParams["figure.figsize"] = (10,10)

with_bounds = False

show_only = []

#### Classic DQN

In [None]:
t = np.arange(horizon)

for i, evo in enumerate(dqn_budget_evolutions_mean):
    if len(show_only) > 0 and initial_budgets[i] not in show_only:
        continue
    
    lines = plt.plot(t, evo, label=f"Start with budget {initial_budgets[i]}")

    if with_bounds:
        plt.fill_between(t, dqn_budget_evolutions_min[i], dqn_budget_evolutions_max[i], color=lines[0].get_color(), alpha=0.15)

plt.plot(t, np.full((horizon,), survival_threshold), color="magenta", label=f"Survival threshold {survival_threshold}")
plt.plot(t, 4.5 * t, color="lawngreen", ls="--", label=f"Budget optimal minor")
plt.plot(t, 49.5 * t, color="turquoise", ls="--", label=f"Budget optimal major")
plt.xlabel("Time")
plt.ylabel("Budget")
plt.legend()
plt.grid()
plt.title(f"Budget evolution in function of time with horizon {horizon} \n repeated {repeat} times and survival threshold {survival_threshold} for map of size {map_size}")
plt.show()

#### Survival DQN

In [None]:
import matplotlib.colors as mcolors

t = np.arange(horizon)

for i, evo in enumerate(survival_dqn_budget_evolutions_mean):
    if len(show_only) > 0 and initial_budgets[i] not in show_only:
        continue
    
    lines = plt.plot(t, evo, label=f"Start with budget {initial_budgets[i]}")

    if with_bounds:
        plt.fill_between(t, survival_dqn_budget_evolutions_min[i], survival_dqn_budget_evolutions_max[i], color=lines[0].get_color(), alpha=0.15)

plt.plot(t, np.full((horizon,), survival_threshold), color="magenta", label=f"Survival threshold {survival_threshold}")
plt.plot(t, 4.5 * t, color="lawngreen", ls="--", label=f"Budget optimal minor")
plt.plot(t, 49.5 * t, color="turquoise", ls="--", label=f"Budget optimal major")
plt.xlabel("Time")
plt.ylabel("Budget")
plt.legend()
plt.grid()
plt.title(f"Budget evolution in function of time with horizon {horizon} \n repeated {repeat} times and survival threshold {survival_threshold} for map of size {map_size}")
plt.show()