In [None]:
!pip install pygame
!pip install tyro

In [None]:
import gc
import os
import random
import time

from gym import spaces
import gymnasium as gym
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pygame
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tyro
from collections import deque
from collections import Counter
from dataclasses import dataclass
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter



In [None]:
# Inputs: xss - n lists of x values, where n is the amount of times the experiments 
def plot_binned_line_with_std(xss, yss, n_bins, y_label = "", title = "", plot_individuals = False):
    assert len(xss) == len(yss)

    mx = np.max([np.max(xs) for xs in xss])

    bin_size = (mx+1) / n_bins

    binned = []

    # What to plot on the x axis
    bins_x = np.linspace(0, mx, n_bins)
    
    for i in range(len(xss)):
        xs = xss[i]
        ys = yss[i]
        bins = [[] for _ in range(n_bins)]
        for j in range(len(xs)):
            bin_index = int(xs[j] / bin_size)
            bins[bin_index].append(ys[j])
        binned.append(bins)

    avgs = [[np.mean(bin) for bin in binned[i]] for i in range(len(xss))]
    avg = np.mean(avgs, axis=0)
    std = np.std(avgs, axis=0)

    plt.figure(figsize=(10, 6))

    if plot_individuals:
        c = 0
        for a in avgs:
            colours = ['red', 'green', 'blue', 'orange', 'purple', 'brown', 'pink', 'gray', 'cyan', 'magenta']
            plt.plot(bins_x, a, linestyle=':', linewidth=1, alpha=0.7, color=colours[c])
            c += 1

    plt.plot(bins_x, avg, label='Average', color='blue')
    plt.fill_between(bins_x, avg - std, avg + std, color='blue', alpha=0.2, label='Standard Deviation')

    plt.xlabel('Environment steps')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_grid_heatmap(uncertainties, best_actions, colour_scheme = "ryg"):
    # Possible colour schemes: "ryg", "light_ryg", "hot/cold"
    height, width, _ = uncertainties.shape
    heatmap_values = np.zeros((height, width))
    for x in range(width):
        for y in range(height):
            action = best_actions[y, x]
            heatmap_values[y, x] = uncertainties[y, x, action]

    cmap = {
        "ryg": mcolors.LinearSegmentedColormap.from_list("stoplight", [(0, "green"), (0.5, "yellow"), (1, "red")]),
        "light ryg": mcolors.LinearSegmentedColormap.from_list("stoplight", [(0, "#66cdaa"), (0.5, "#fffacd"), (1, "#ff9999")]),
        "hot/cold": "coolwarm",
        }[colour_scheme]

    plt.figure(figsize=(10, 8))
    plt.imshow(heatmap_values, cmap=cmap, origin='upper', interpolation='nearest')
    plt.colorbar(label='Uncertainty')
    plt.title('Uncertainties for each best action per cell')
    plt.xlabel('X')
    plt.ylabel('Y')

    for x in range(width):
        for y in range(height):
            plt.text(x, y, f'{heatmap_values[y, x]:.1f}', ha='center', va='center', color='black')

    plt.show()


def plot_barchart_rewards(reward_histories, y_label, title):
    """Bar chart of how often each distinct episode reward occurs, averaged per agent.

    reward_histories: one sequence of episode rewards per agent/run.
    y_label, title: axis label and figure title.
    """
    agent_count = len(reward_histories)
    flat_rewards = np.array([r for history in reward_histories for r in history])

    # Distinct reward values and how many times each occurred overall.
    values, counts = np.unique(flat_rewards, return_counts=True)

    # Bar width from the tightest spacing between distinct values.
    bar_width = 1.0
    if len(values) > 1:
        bar_width = np.min(np.diff(values)) / agent_count

    plt.figure()
    plt.bar(values, counts / agent_count, width=bar_width, align='center')
    plt.xlabel('Unique Values')
    plt.ylabel(y_label)
    plt.title(title)
    plt.xticks(values)  # one tick per distinct reward value
    plt.tight_layout()
    plt.show()


def plot_barchart_episode_length(episode_length_histories, n_bins, y_label, title):
    # Determine the number of agents and the maximum episode length
    n_agents = len(episode_length_histories)
    max_length = max(max(lengths) for lengths in episode_length_histories)

    # Flatten the 2D list into a 1D list of all episode lengths
    all_lengths = [length for agent_lengths in episode_length_histories for length in agent_lengths]

    # Initialize the figure and axis
    fig, ax = plt.subplots()

    # Create the histogram data with the specified number of bins
    counts, bin_edges = np.histogram(all_lengths, bins=n_bins)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Divide the counts by the number of agents to get the average occurrence
    average_counts = counts / n_agents

    # Plot the bar chart
    plt.xlabel('Episode Length')
    plt.ylabel(y_label)
    plt.title(title)
    plt.bar(bin_centers, average_counts, width=(bin_edges[1] - bin_edges[0]) - 0.1, align='center')


def qtable_directions_map(qtable, map):
    """Map a grid of best-action indices to arrow glyphs.

    qtable: array-like of action indices (0=left, 1=down, 2=right, 3=up), one per cell.
    map: the grid layout (list of rows); only its dimensions are used.
    Returns a ``(len(map), len(map[0]))`` array of arrow strings.
    """
    directions = {0: "←", 1: "↓", 2: "→", 3: "↑"}
    # Build a fresh string array. Writing glyphs back into ``qtable.flatten()``
    # raised ValueError whenever qtable had a numeric dtype.
    arrows = [directions[int(action)] for action in np.ravel(qtable)]
    return np.array(arrows).reshape(len(map), len(map[0]))

def plot_grid_statespace(state_history, optimal_moves, state_map):
    qtable_directions = qtable_directions_map(optimal_moves, state_map)
    for y in range(len(qtable_directions)):
        for x in range(len(qtable_directions[0])):
            if state_map[y][x] == 'X' or state_map[y][x] == 'G':
                qtable_directions[y, x] = ''
    state_counts = np.bincount(state_history)

    # Step 2: Normalize the visit counts
    max_count = np.max(state_counts)
    normalized_counts = state_counts / max_count
    reshaped_counts = np.reshape(normalized_counts, (len(state_map), len(state_map[0])))
    plt.figure()
    sns.heatmap(
        reshaped_counts,
        annot=qtable_directions,
        fmt="",
        cmap=sns.color_palette("Blues", as_cmap=True),
        linewidths=0.7,
        linecolor="black",
        xticklabels=[],
        yticklabels=[],
        annot_kws={"fontsize": "xx-large"},
    ).set(title="Learned Q-values\nArrows represent best action")

In [None]:
class FrozenLakeEnv(gym.Env):
    """Slippery grid world built from a list-of-rows layout.

    Cells: 'S' start, 'X' hole (reward -5, episode ends), 'G' goal (reward 10,
    episode ends); any other cell is walkable with reward 0. Observations are
    flat indices ``y * width + x`` (top-left is 0). Actions: 0=left, 1=down,
    2=right, 3=up.

    ``reset`` returns only the observation and ``step`` returns the 4-tuple
    ``(obs, reward, done, {})``. The training loop in this notebook unpacks
    them that way, so the older gym-style signatures are kept.
    """

    # Intended action -> [intended, side-slip, other side-slip]. The intended
    # action must come first: it receives probability 1 - 2 * slip_probability.
    SLIP_ACTIONS = {
        0: [0, 3, 1],  # Left  -> may slip Up or Down
        1: [1, 0, 2],  # Down  -> may slip Left or Right
        2: [2, 1, 3],  # Right -> may slip Down or Up
        3: [3, 2, 0],  # Up    -> may slip Right or Left
    }

    def __init__(self, grid):
        super(FrozenLakeEnv, self).__init__()
        self.grid = grid
        self.grid_height = len(grid)
        self.grid_width = len(grid[0])
        start_ys, start_xs = np.where(np.array(grid) == 'S')
        if len(start_ys) != 1:
            raise ValueError(f"grid must contain exactly one 'S' start cell, found {len(start_ys)}")
        self.start_y, self.start_x = start_ys[0], start_xs[0]
        # Probability to slip to ONE side; the total slip chance is 2 * this.
        self.slip_probability = 1/3
        assert self.slip_probability <= 1/2
        # Use gymnasium's spaces to match the gymnasium.Env base class
        # (the legacy ``gym`` package's spaces were mixed in before).
        self.action_space = gym.spaces.Discrete(4)  # Left, Down, Right, Up
        self.observation_space = gym.spaces.Discrete(self.grid_height * self.grid_width)

        # How often each intended action was chosen in each (x, y) cell.
        self.state_action_count = {
            (x, y): {0: 0, 1: 0, 2: 0, 3: 0}
            for x in range(self.grid_width)
            for y in range(self.grid_height)
        }

    def reset(self):
        """Put the agent on the start cell and return its observation."""
        # Top left corner is 0, 0
        self.state = (self.start_x, self.start_y)
        return self.to_observation(self.state)

    def step(self, action):
        """Apply ``action`` with random slipping; return ``(obs, reward, done, {})``."""
        self.state_action_count[self.state][action] += 1
        x, y = self.state

        # Draw the direction actually executed: intended w.p. 1 - 2p, each side w.p. p.
        p = self.slip_probability
        action = np.random.choice(self.SLIP_ACTIONS[action], p=[1 - 2 * p, p, p])

        # Move one cell in the executed direction; the grid edge blocks movement.
        if action == 0 and x > 0:
            x -= 1
        elif action == 1 and y < self.grid_height - 1:
            y += 1
        elif action == 2 and x < self.grid_width - 1:
            x += 1
        elif action == 3 and y > 0:
            y -= 1

        self.state = (x, y)
        reward = 0
        done = False

        cell = self.grid[y][x]
        if cell == 'X':
            reward = -5
            done = True
        elif cell == 'G':
            reward = 10
            done = True

        return self.to_observation(self.state), reward, done, {}

    def to_observation(self, state):
        """Flatten an ``(x, y)`` position to the index ``y * width + x``."""
        x, y = state
        return y * self.grid_width + x

    def render(self):
        """Print the grid with the agent's position marked 'A'."""
        grid = np.full((self.grid_height, self.grid_width), ' ')
        for y in range(self.grid_height):
            for x in range(self.grid_width):
                grid[y, x] = self.grid[y][x]
        x, y = self.state
        grid[y, x] = 'A'
        print('\n'.join(' '.join(row) for row in grid))
        print()

def make_frozenlake(grid):
    """Factory: build a new FrozenLakeEnv for the given grid layout (list of rows)."""
    env = FrozenLakeEnv(grid)
    return env

# Example grid: 'S' start, 'X' hole, 'G' goal, '.' walkable ice.
safe_3x3 = [list(row) for row in ("SX.",
                                  "...",
                                  "..G")]

# env = make_frozenlake(safe_3x3) # 3x3 grid call

In [None]:
# Debug aid: make CUDA errors surface at the failing call. It only takes
# effect if set before CUDA is initialised, so it comes before any torch.cuda
# call below (empty_cache can initialise CUDA). It also slows GPU work.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
gc.collect()
torch.cuda.empty_cache()

# Seed everything for reproducible results
seed = 2024
random.seed(seed)
# The legacy global numpy RNG is what the environment (slips) and the replay
# memory (sampling) draw from, so seed that one. The previous
# ``np.random.default_rng(seed)`` discarded its Generator and had no effect.
np.random.seed(seed)
# NOTE: PYTHONHASHSEED only affects interpreters started after this point
# (e.g. subprocesses); the running kernel's hash seed is already fixed.
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
class ReplayMemory:
    """Fixed-capacity FIFO experience buffer that returns batches as tensors on ``device``.

    Each transition field lives in its own bounded deque, so once ``capacity``
    transitions are stored, the oldest one is dropped on every new ``store``.
    """

    def __init__(self, capacity, device):
        self.capacity = capacity
        self.device = device
        self.states, self.actions, self.next_states, self.rewards, self.dones = (
            deque(maxlen=capacity) for _ in range(5)
        )

    def store(self, state, action, next_state, reward, done):
        """Append one transition."""
        fields = (
            (self.states, state),
            (self.actions, action),
            (self.next_states, next_state),
            (self.rewards, reward),
            (self.dones, done),
        )
        for buffer, item in fields:
            buffer.append(item)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns ``(states, actions, next_states, rewards, dones)`` with dtypes
        float32, long, float32, float32 and bool. numpy raises ValueError when
        fewer than ``batch_size`` transitions are stored.
        """
        picks = np.random.choice(len(self), size=batch_size, replace=False)
        dev = self.device

        def stacked(buffer):
            return torch.stack([torch.as_tensor(buffer[k], dtype=torch.float32, device=dev) for k in picks])

        def gathered(buffer, dtype):
            return torch.as_tensor([buffer[k] for k in picks], dtype=dtype, device=dev)

        return (
            stacked(self.states),
            gathered(self.actions, torch.long),
            stacked(self.next_states),
            gathered(self.rewards, torch.float32),
            gathered(self.dones, torch.bool),
        )

    def __len__(self):
        return len(self.dones)

In [None]:
class C51Network(nn.Module):
    """Small MLP that outputs a categorical return distribution per action (C51).

    ``forward`` returns shape ``(batch, num_actions, support_size)``, softmax-
    normalised over the last axis. The atoms are ``support_size`` evenly spaced
    values in ``[v_min, v_max]``.
    """

    def __init__(self, num_actions, input_dim, support_size, v_min, v_max):
        super(C51Network, self).__init__()
        self.support_size = support_size
        self.num_actions = num_actions
        self.v_min = v_min
        self.v_max = v_max
        self.delta_z = (v_max - v_min) / (support_size - 1)
        # A buffer, so ``.to(device)`` moves the atoms together with the weights.
        # Previously it was pinned to the notebook-global ``device`` and ignored
        # by ``.to()``. Non-persistent keeps the state_dict keys unchanged, so
        # existing checkpoints still load.
        self.register_buffer("support", torch.linspace(v_min, v_max, support_size), persistent=False)

        self.FC = nn.Sequential(
            nn.Linear(input_dim, 12),
            nn.ReLU(inplace=True),
            nn.Linear(12, 8),
            nn.ReLU(inplace=True),
            nn.Linear(8, num_actions * support_size)
        )

        # He (Kaiming) initialisation for every linear layer.
        for module in self.FC:
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')

    def forward(self, x):
        """Return per-action atom probabilities, shape ``(batch, num_actions, support_size)``."""
        x = self.FC(x)
        x = x.view(-1, self.num_actions, self.support_size)
        return F.softmax(x, dim=2)

    def get_q_values(self, dist):
        """Expected return per action: sum over atoms of probability * atom value."""
        return torch.sum(dist * self.support, dim=2)

In [None]:
class C51Agent:
    def __init__(self, env, seed, device, epsilon_max, epsilon_min, epsilon_decay,
                 clip_grad_norm, learning_rate, discount, memory_capacity, support_size, v_min, v_max):
        self.loss_history = []
        self.running_loss = 0
        self.learned_counts = 0

        self.epsilon_max = epsilon_max
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.discount = discount

        self.action_space = env.action_space
        self.action_space.seed(seed)
        self.observation_space = env.observation_space
        self.replay_memory = ReplayMemory(memory_capacity, device)

        self.support_size = support_size
        self.v_min = v_min
        self.v_max = v_max
        self.delta_z = (v_max - v_min) / (support_size - 1)

        self.main_network = C51Network(num_actions=self.action_space.n, input_dim=self.observation_space.n,
                                       support_size=support_size, v_min=v_min, v_max=v_max).to(device)
        self.target_network = C51Network(num_actions=self.action_space.n, input_dim=self.observation_space.n,
                                         support_size=support_size, v_min=v_min, v_max=v_max).to(device).eval()
        self.target_network.load_state_dict(self.main_network.state_dict())

        self.clip_grad_norm = clip_grad_norm
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.main_network.parameters(), lr=learning_rate)

    def select_action(self, state):
        if np.random.random() < self.epsilon_max:
#             print("action_space", self.action_space)
            return self.action_space.sample()
        else:
            with torch.no_grad():
                Q_values = self.main_network.get_q_values(self.main_network(state))
                action = torch.argmax(Q_values).item()
#                 print(f"State: {state}, Selected action: {action}")
                return action

    def learn(self, batch_size, done):
        if len(self.replay_memory) < batch_size:
            return

        states, actions, next_states, rewards, dones = self.replay_memory.sample(batch_size)
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1).float()

        predicted_dist = self.main_network(states).gather(1, actions.unsqueeze(1).expand(-1, -1, self.support_size)).squeeze(1)

        with torch.no_grad():
            next_dist = self.target_network(next_states)
            next_q_values = self.target_network.get_q_values(next_dist)
            next_actions = next_q_values.max(1)[1]
#             print("next_actions:", next_actions)
            next_dist = next_dist[range(batch_size), next_actions]

            Tz = rewards + (1 - dones) * self.discount * self.main_network.support.unsqueeze(0)
            Tz = Tz.clamp(min=self.v_min, max=self.v_max)
            b = (Tz - self.v_min) / self.delta_z
            l = b.floor().long()
            u = b.ceil().long()

            m = torch.zeros(batch_size, self.support_size).to(device)
            offset = torch.linspace(0, (batch_size - 1) * self.support_size, batch_size).unsqueeze(1).expand(batch_size, self.support_size).to(device)
            m.view(-1).index_add_(0, (l + offset).view(-1).long(), (next_dist * (u.float() - b)).view(-1))
            m.view(-1).index_add_(0, (u + offset).view(-1).long(), (next_dist * (b - l.float())).view(-1))
#             for i in range(batch_size):
#                 for j in range(self.support_size):
#                     m[i, l[i, j]] += next_dist[i, j] * (u[i, j] - b[i, j])
#                     m[i, u[i, j]] += next_dist[i, j] * (b[i, j] - l[i, j])


            

        loss = -torch.sum(m * predicted_dist.log(), dim=1).mean()
        self.optimizer.zero_grad()
        loss.backward()
#         print(f"Loss: {loss.item()}, Epsilon: {self.epsilon_max}")
        torch.nn.utils.clip_grad_norm_(self.main_network.parameters(), self.clip_grad_norm)
        self.optimizer.step()

        self.running_loss += loss.item()
        self.learned_counts += 1
        if done:
            episode_loss = self.running_loss / self.learned_counts
            self.loss_history.append(episode_loss)
            self.running_loss = 0
            self.learned_counts = 0

    def hard_update(self):
        self.target_network.load_state_dict(self.main_network.state_dict())

    def update_epsilon(self):
        self.epsilon_max = max(self.epsilon_min, self.epsilon_max * self.epsilon_decay)

    def save(self, path):
        torch.save(self.main_network.state_dict(), path)

In [None]:
class Model_TrainTest:
    """Runs training or evaluation episodes of an agent on the FrozenLake grid.

    Hyperparameters are read from the ``hyperparams`` dict (see the parameters
    cell). Each ``run`` refills ``reward_history``, ``episode_length_history``,
    ``environment_steps_history`` and ``state_history`` for that run only.
    """

    def __init__(self, seed, device, hyperparams, agent_type='c51'):
        # Used when (re)creating agents and one-hot states (previously ignored
        # in favour of the notebook globals of the same name).
        self.seed = seed
        self.device = device

        # Define RL Hyperparameters
        self.train_mode             = hyperparams["train_mode"]
        self.RL_load_path           = hyperparams["RL_load_path"]
        self.save_path              = hyperparams["save_path"]
        self.save_interval          = hyperparams["save_interval"]

        self.clip_grad_norm         = hyperparams["clip_grad_norm"]
        self.learning_rate          = hyperparams["learning_rate"]
        self.discount_factor        = hyperparams["discount_factor"]
        self.batch_size             = hyperparams["batch_size"]
        self.update_frequency       = hyperparams["update_frequency"]
        self.max_episodes           = hyperparams["max_episodes"]
        self.max_episode_length     = hyperparams["max_episode_length"]
        self.max_steps              = hyperparams["max_steps"]
        self.render                 = hyperparams["render"]

        self.epsilon_max            = hyperparams["epsilon_max"]
        self.epsilon_min            = hyperparams["epsilon_min"]
        self.epsilon_decay          = hyperparams["epsilon_decay"]

        self.memory_capacity        = hyperparams["memory_capacity"]

        self.num_states             = hyperparams["num_states"]
        self.map                    = hyperparams["map"]
        self.render_fps             = hyperparams["render_fps"]

        self.n_bins                 = hyperparams["n_bins"]

        self.env = make_frozenlake(self.map)
        # ``metadata`` is a class-level dict on gym.Env; copy it before editing
        # so other environments are unaffected. render_fps 0 = max frame rate.
        self.env.metadata = {**self.env.metadata, 'render_fps': self.render_fps}
        self.agent_type = agent_type
        self.agent = None
        self.reset_agent()
        self.state_history = []
        self.loss_bins = []

    def state_preprocess(self, state:int, num_states:int):
        """One-hot encode ``state`` as a float32 tensor of length ``num_states``.

        Example: state 2 of 5 -> tensor([0., 0., 1., 0., 0.]).
        """
        onehot_vector = torch.zeros(num_states, dtype=torch.float32, device=self.device)
        onehot_vector[state] = 1
        return onehot_vector

    def reset_agent(self):
        """Create a fresh agent for ``self.agent_type``.

        Only the C51 agent is defined in this notebook; for any other
        ``agent_type`` the current ``self.agent`` is left unchanged.
        """
        if self.agent_type == 'c51':
            support_size = 51  # number of atoms
            v_min = -10
            v_max = 10
            self.agent = C51Agent(env=self.env,
                                  seed=self.seed,
                                  device=self.device,
                                  epsilon_max=self.epsilon_max,
                                  epsilon_min=self.epsilon_min,
                                  epsilon_decay=self.epsilon_decay,
                                  clip_grad_norm=self.clip_grad_norm,
                                  learning_rate=self.learning_rate,
                                  discount=self.discount_factor,
                                  memory_capacity=self.memory_capacity,
                                  support_size=support_size,
                                  v_min=v_min,
                                  v_max=v_max)

    def run(self, training, agent=None):
        """Run episodes until ``max_steps`` or ``max_episodes`` is reached.

        training: if True, a fresh agent is created and trained; otherwise
            ``agent`` is evaluated as-is.
        Returns the agent used.
        """
        total_steps = 0
        episode = 0
        self.reward_history = []
        self.episode_length_history = []
        self.environment_steps_history = []
        # Fresh per run. Previously this list grew across runs and every run
        # handed out the same list object via get_plotting_data.
        self.state_history = []
        if not training:
            self.agent = agent
        else:
            self.reset_agent()

        # Loop over episodes
        while total_steps < self.max_steps and episode < self.max_episodes:
            state = self.env.reset()
            self.state_history.append(state)
            state = self.state_preprocess(state, num_states=self.num_states)
            done = False
            truncation = False
            step_size = 0
            episode_reward = 0

            while not done and not truncation:
                action = self.agent.select_action(state)
                next_state, reward, done, truncation = self.env.step(action)
                self.state_history.append(next_state)
                next_state = self.state_preprocess(next_state, num_states=self.num_states)

                # Cap the episode length in training AND evaluation (a greedy
                # evaluation policy could otherwise loop forever). This is decided
                # before learning so ``learn`` sees the episode end and logs its loss.
                hit_length_cap = not truncation and step_size >= self.max_episode_length
                if hit_length_cap:
                    truncation = True

                if training:
                    self.agent.replay_memory.store(state, action, next_state, reward, done)

                    if len(self.agent.replay_memory) > self.batch_size:
                        self.agent.learn(self.batch_size, (done or truncation))

                        # total_steps is only advanced at episode end, so add the
                        # current episode's steps. The old check fired on every step
                        # of any episode that started at a multiple of update_frequency.
                        if (total_steps + step_size) % self.update_frequency == 0:
                            self.agent.hard_update()

                # A truncated step is reported with 0 reward (non-Markovian; kept
                # from the original design). The stored transition keeps the real reward.
                if hit_length_cap:
                    reward = 0

                if self.render:
                    self.env.render()
                    print(f"Step: {step_size}, State: {state}, Action: {action}, Reward: {reward}, Done: {done}")

                state = next_state
                episode_reward += reward
                step_size += 1

            # Per-episode tracking
            total_steps += step_size
            self.reward_history.append(episode_reward)
            self.episode_length_history.append(step_size)
            self.environment_steps_history.append(total_steps)
            episode += 1

            # Decay epsilon at the end of each episode
            self.agent.update_epsilon()

        return self.agent

    def plot_training(self):
        """Plot reward, episode length and (if any) loss against environment steps."""
        plot_binned_line_with_std([self.environment_steps_history], [self.reward_history], self.n_bins, y_label="Reward", title=f"Reward over time ({self.agent_type})", plot_individuals=False)
        plot_binned_line_with_std([self.environment_steps_history], [self.episode_length_history], self.n_bins, y_label="Episode Length", title=f"Episode Length over time ({self.agent_type})", plot_individuals=False)
        loss_history = self.agent.loss_history
        if loss_history:
            # One loss entry is logged per episode once learning has started, i.e.
            # for the last len(loss_history) episodes; align with their step counts.
            loss_steps = self.environment_steps_history[-len(loss_history):]
            plot_binned_line_with_std([loss_steps], [loss_history], self.n_bins, y_label="Loss", title=f"Loss over time ({self.agent_type})", plot_individuals=False)

    def get_plotting_data(self):
        """Return the latest run's histories for plotting.

        The loss history is copied and left-padded with its first value to one
        entry per episode; it stays ``[]`` if no loss was recorded.
        Returns (environment_steps_history, reward_history, episode_length_history,
        padded_loss_history, state_history).
        """
        padded_loss_history = list(self.agent.loss_history)
        missing = len(self.environment_steps_history) - len(padded_loss_history)
        if padded_loss_history and missing > 0:
            padded_loss_history = [padded_loss_history[0]] * missing + padded_loss_history
        return self.environment_steps_history, self.reward_history, self.episode_length_history, padded_loss_history, self.state_history

In [None]:
# Map legend: 'S' start, '.' walkable ice, 'X' hole (-5, episode ends),
# 'G' goal (+10, episode ends). Each string below is one grid row.
no_aleatoric_uncertainty_3x3 = [list(row) for row in (
    "S..",
    "...",
    "..G",
)]

safe_3x3 = [list(row) for row in (
    "SX.",
    "...",
    "..G",
)]

long_safe_4x3 = [list(row) for row in (
    "SXG",
    ".X.",
    ".X.",
    "...",
)]

short_unsafe_long_safe_4x3 = [list(row) for row in (
    "SXG",
    "...",
    ".X.",
    "...",
)]

unsafe_path_safe_area_3x4 = [list(row) for row in (
    "S.X.",
    "....",
    "..XG",
)]

In [None]:
def plot_results(env_steps_histories, eps_length_histories, map_state_histories, agent_reward_histories, agent_loss_histories, state_map, agents, n_bins):
    """Plot training curves (train mode) or evaluation summaries (evaluation mode).

    Branches on the notebook-global ``train_mode`` flag from the parameters cell.
    All ``*_histories`` arguments hold one entry per run/agent, as returned by ``run_model``.
    NOTE(review): evaluation mode calls ``agent.get_best_actions(state_table, device)``,
    which C51Agent in this notebook does not define; confirm before using train_mode=False.
    """
    if train_mode:
        plot_binned_line_with_std(env_steps_histories, agent_reward_histories, n_bins, y_label="Reward", title="Reward over time", plot_individuals=True)
        plot_binned_line_with_std(env_steps_histories, eps_length_histories, n_bins, y_label="Episode Length", title="Episode Length over time", plot_individuals=True)
        # Runs that never learned have an empty loss history; skip them rather
        # than crash on mismatched x/y lengths.
        loss_pairs = [(steps, losses) for steps, losses in zip(env_steps_histories, agent_loss_histories) if len(losses) > 0]
        if loss_pairs:
            loss_steps = [steps for steps, _ in loss_pairs]
            loss_values = [losses for _, losses in loss_pairs]
            plot_binned_line_with_std(loss_steps, loss_values, n_bins, y_label="Loss", title="Loss over time", plot_individuals=True)
    else:
        for i in range(len(map_state_histories)):
            state_table = np.zeros_like(state_map)
            # Compute once; it was previously computed twice (print + plot).
            best_actions = agents[i].get_best_actions(state_table, device)
            print(best_actions)
            plot_grid_statespace(map_state_histories[i], best_actions, state_map)
        plot_barchart_rewards(agent_reward_histories, y_label="Reward Counts", title="Bar Chart of Unique Values and Their Corresponding Reward Counts")
        plot_barchart_episode_length(eps_length_histories, n_bins, y_label='Amount of Finished Runs', title="Histogram of Episode Lengths (averaged per agent)")
        
def run_model(agents, agent_type, n_runs=10):
    """Train ``n_runs`` fresh agents (train mode) or evaluate each agent in ``agents``.

    Reads the notebook globals ``seed``, ``device``, ``RL_hyperparams`` and ``train_mode``.
    In train mode, each trained agent is appended to ``agents`` (mutated in place).
    Returns (environment_steps_histories, reward_histories, episode_length_histories,
    loss_histories, state_histories, agents), with one entry per run in each history list.
    """
    DRL = Model_TrainTest(seed, device, RL_hyperparams, agent_type=agent_type)
    # Same order as Model_TrainTest.get_plotting_data's return tuple.
    histories = ([], [], [], [], [])

    def run_once(training, agent):
        # Start every run with an empty state log so runs neither share nor
        # accumulate one list object.
        DRL.state_history = []
        used_agent = DRL.run(training, agent)
        for bucket, data in zip(histories, DRL.get_plotting_data()):
            bucket.append(data)
        return used_agent

    if train_mode:
        for i in range(n_runs):
            print(i)
            agents.append(run_once(train_mode, None))
    else:
        for agent in agents:
            run_once(train_mode, agent)

    environment_steps_histories, reward_histories, episode_length_histories, loss_histories, state_histories = histories
    return environment_steps_histories, reward_histories, episode_length_histories, loss_histories, state_histories, agents

In [None]:
# Parameters:
train_mode = True
agent_type = 'c51'               # only 'c51' is implemented here (the DQN/UCB/IDS agents are not defined in this notebook)

render = not train_mode
# map_size = 4 # 4x4 or 8x8 (outdated)
state_map = short_unsafe_long_safe_4x3
# Must name the map selected above: it picks the checkpoint folder. (This was
# the misspelled ``map_mame = "safe_3x3"``, a different map than the one trained.)
map_name = "short_unsafe_long_safe_4x3"
RL_hyperparams = {
    "train_mode"            : train_mode,
    "RL_load_path"          : f'./level_stats/{map_name}/final_weights_3000.pth',
    "save_path"             : f'./level_stats/{map_name}/final_weights',
    "save_interval"         : 500,

    "clip_grad_norm"        : 3,
    "learning_rate"         : 8e-6,
    "discount_factor"       : 0.94,
    "batch_size"            : 64,
    "update_frequency"      : 10,
    "max_episodes"          : 100000           if train_mode else 200,
    "max_steps"             : 300000,
    "max_episode_length"    : 600,
    "render"                : render,

    # epsilon_max is the starting exploration rate; -1 in evaluation = always greedy.
    "epsilon_max"           : 0.999         if train_mode else -1,
    "epsilon_min"           : 0.01,
    "epsilon_decay"         : 0.999,

    "memory_capacity"       : 10_000        if train_mode else 0,

    "num_states"            : len(state_map) * len(state_map[0]),    # rows x columns of state_map
    "map"                   : state_map,
    "render_fps"            : 6,
    "n_bins"                : 100
    }

agent_histories = []
environment_steps_histories, reward_histories, episode_length_histories, loss_histories, state_histories, agent_histories = run_model(agent_histories, agent_type)

In [None]:


n_bins = 20  # x-bins for the result plots below (separate from RL_hyperparams["n_bins"])
plot_results(
    env_steps_histories=environment_steps_histories,
    eps_length_histories=episode_length_histories,
    map_state_histories=state_histories,
    agent_reward_histories=reward_histories,
    agent_loss_histories=loss_histories,
    state_map=state_map,
    agents=agent_histories,
    n_bins=n_bins,
)

