## DQN model

In [5]:
import pygame
from stable_baselines3 import DQN
import torch
import numpy as np

from envs import get_sb3_env

K = 3  # Number of top actions to display
NUM_AGENTS = 3  # Number of agents

# Initialize PyGame
pygame.init()

# Configuration
episodes = 10  # Number of episodes to run
font_size = 24
# Background colour of an action cell, indexed by the per-agent action id.
colors = [
    (255, 182, 193), # LANE LEFT
    (152, 251, 152), # IDLE
    (253, 253, 150), # LANE RIGHT
    (216, 191, 216), # FASTER
    (173, 216, 230)  # SLOWER
]

# Load the environment and model
env = get_sb3_env(n_agents=NUM_AGENTS, image_obs=False, density=1)
model = DQN.load("models/DQN/DQN_exp_0.zip")

# Get environment render size
rgb_array = env.render()
if rgb_array is None:
    # Without a frame the window cannot be sized. Release what was acquired and
    # fail with an explicit message instead of an AttributeError on `.shape`.
    env.close()
    pygame.quit()
    raise RuntimeError(
        "env.render() returned None; an RGB frame is required to size the "
        "window (check the environment's render mode)"
    )
env_render_size = (rgb_array.shape[1], rgb_array.shape[0])  # (Width, Height)

# Grid configuration
row_height = 100  # Height of each row of action cells
row_height_agent = 34  # Height of the agent header row and of each "Top N" label row
# One agent header row, then for every rank a label row followed by an action row.
action_display_height = row_height_agent + K * (row_height_agent + row_height)
window_size = (env_render_size[0], env_render_size[1] + action_display_height)

# Create a single extended PyGame window
screen = pygame.display.set_mode(window_size)
pygame.display.set_caption("DQN Demonstration")
font = pygame.font.Font(None, font_size)
clock = pygame.time.Clock()

# Helper function to draw a grid cell
def draw_cell(surface, text, color, rect, font):
    """Paint one grid cell on ``surface``.

    The cell is ``rect`` filled with ``color``, outlined with a 2-pixel grey
    border, and labelled with ``text`` rendered in black by ``font`` at a
    10-pixel offset from the cell's top-left corner.
    """
    border_color = (150, 150, 150)
    label_color = (0, 0, 0)
    padding = 10

    # Fill first, then the outline, so the border stays visible on top.
    pygame.draw.rect(surface, color, rect)
    pygame.draw.rect(surface, border_color, rect, 2)

    label = font.render(text, True, label_color)
    label_position = (rect.x + padding, rect.y + padding)
    surface.blit(label, label_position)

# Main loop for episodes
# Label shown for each per-agent action id; constant, so built once instead of every frame.
translate_action = {0: 'LANE LEFT', 1: 'IDLE', 2: 'LANE RIGHT', 3: 'FASTER', 4: 'SLOWER'}
running = True  # Cleared when the window is closed so that no further episode starts

try:
    for episode in range(episodes):
        if not running:
            break
        obs, _ = env.reset()
        done = False
        while not done:
            # Model selects an action. Inference only, so skip autograd bookkeeping.
            with torch.no_grad():
                obs_th = torch.as_tensor(obs, device=model.policy.device).unsqueeze(0)
                # Indices into the network output are used below as flat action ids.
                # The softmax only normalises the scores; the ranking comes from argsort.
                prob_actions_logits = model.policy.q_net(obs_th)
                prob_actions = torch.softmax(prob_actions_logits, dim=-1).squeeze().cpu().numpy()
            top_k_actions = np.argsort(-prob_actions)[:K]
            action = [top_k_actions[0]]

            # Step the environment
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Render the environment
            rgb_array = env.render()
            # The frame is indexed (row, column, channel); make_surface expects (x, y, channel).
            env_surface = pygame.surfarray.make_surface(np.transpose(rgb_array, (1, 0, 2)))
            screen.blit(env_surface, (0, 0))  # Display the environment render

            # Display actions in the grid below the environment
            start_y = env_render_size[1]
            agent_width = env_render_size[0] // NUM_AGENTS

            # Header row
            vehicles = env.original_env.unwrapped.controlled_vehicles
            for agent_idx in range(NUM_AGENTS):
                crashed = "CRASHED" if vehicles[agent_idx].crashed else ""
                draw_cell(screen, f"Agent {agent_idx} {crashed}", (220, 220, 220),
                          pygame.Rect(agent_idx * agent_width, start_y, agent_width, row_height_agent), font)
            start_y += row_height_agent

            # Top K actions. Iterating the array itself (not range(K)) avoids an
            # IndexError when fewer than K actions are available.
            for rank, action_id in enumerate(top_k_actions):
                draw_cell(screen, f"Top {rank + 1}", (220, 220, 220),
                          pygame.Rect(0, start_y, env_render_size[0], row_height_agent), font)
                start_y += row_height_agent

                # The decoded tuple depends only on the rank, so compute it once per
                # row (assumes _flat_to_tuple is a side-effect-free index conversion).
                action_tuple = env._flat_to_tuple(action_id)
                for agent_idx in range(NUM_AGENTS):
                    agent_action = action_tuple[agent_idx].item()
                    draw_cell(screen, translate_action[agent_action], colors[agent_action],
                              pygame.Rect(agent_idx * agent_width, start_y, agent_width, row_height), font)
                start_y += row_height

            pygame.display.update()

            # Manage PyGame events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Stop the whole demonstration, not only the current episode.
                    running = False
                    done = True
                    break

            clock.tick(30)
finally:
    # Close everything, even when the loop above raised.
    env.close()
    pygame.quit()


Cell output (the run ended with an error):
AttributeError: 'NoneType' object has no attribute 'get_image'

## PPO model

In [4]:
import pygame
from stable_baselines3 import PPO
import torch
import numpy as np

from envs import get_sb3_env

K = 3  # Number of top actions to display
NUM_AGENTS = 3  # Number of agents

# Initialize PyGame
pygame.init()

# Configuration
episodes = 10  # Number of episodes to run
font_size = 24
# Background colour of an action cell, indexed by the per-agent action id.
colors = [
    (255, 182, 193), # LANE LEFT
    (152, 251, 152), # IDLE
    (253, 253, 150), # LANE RIGHT
    (216, 191, 216), # FASTER
    (173, 216, 230)  # SLOWER
]

# Load the environment and model
env = get_sb3_env(n_agents=NUM_AGENTS, image_obs=False, density=1)
model = PPO.load("models/PPO/PPO_exp_6.zip")

# Get environment render size
rgb_array = env.render()
if rgb_array is None:
    # Without a frame the window cannot be sized. Release what was acquired and
    # fail with an explicit message instead of an AttributeError on `.shape`.
    env.close()
    pygame.quit()
    raise RuntimeError(
        "env.render() returned None; an RGB frame is required to size the "
        "window (check the environment's render mode)"
    )
env_render_size = (rgb_array.shape[1], rgb_array.shape[0])  # (Width, Height)

# Grid configuration
row_height = 100  # Height of each row of action cells
row_height_agent = 34  # Height of the agent header row and of each "Top N" label row
# One agent header row, then for every rank a label row followed by an action row.
action_display_height = row_height_agent + K * (row_height_agent + row_height)
window_size = (env_render_size[0], env_render_size[1] + action_display_height)

# Create a single extended PyGame window
screen = pygame.display.set_mode(window_size)
pygame.display.set_caption("PPO Demonstration")
font = pygame.font.Font(None, font_size)
clock = pygame.time.Clock()

# Helper function to draw a grid cell
def draw_cell(surface, text, color, rect, font):
    """Paint one grid cell on ``surface``.

    The cell is ``rect`` filled with ``color``, outlined with a 2-pixel grey
    border, and labelled with ``text`` rendered in black by ``font`` at a
    10-pixel offset from the cell's top-left corner.
    """
    border_color = (150, 150, 150)
    label_color = (0, 0, 0)
    padding = 10

    # Fill first, then the outline, so the border stays visible on top.
    pygame.draw.rect(surface, color, rect)
    pygame.draw.rect(surface, border_color, rect, 2)

    label = font.render(text, True, label_color)
    label_position = (rect.x + padding, rect.y + padding)
    surface.blit(label, label_position)

# Main loop for episodes
# Label shown for each per-agent action id; constant, so built once instead of every frame.
translate_action = {0: 'LANE LEFT', 1: 'IDLE', 2: 'LANE RIGHT', 3: 'FASTER', 4: 'SLOWER'}
running = True  # Cleared when the window is closed so that no further episode starts

try:
    for episode in range(episodes):
        if not running:
            break
        obs, _ = env.reset()
        done = False
        while not done:
            # Model selects an action. Inference only, so skip autograd bookkeeping.
            with torch.no_grad():
                obs_th = torch.as_tensor(obs, device=model.policy.device).unsqueeze(0)
                # Indices into the logits are used below as flat action ids.
                prob_actions_logits = model.policy.get_distribution(obs_th).distribution.logits
                prob_actions = torch.softmax(prob_actions_logits, dim=-1).squeeze().cpu().numpy()
            top_k_actions = np.argsort(-prob_actions)[:K]
            action = [top_k_actions[0]]

            # Step the environment
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Render the environment
            rgb_array = env.render()
            # The frame is indexed (row, column, channel); make_surface expects (x, y, channel).
            env_surface = pygame.surfarray.make_surface(np.transpose(rgb_array, (1, 0, 2)))
            screen.blit(env_surface, (0, 0))  # Display the environment render

            # Display actions in the grid below the environment
            start_y = env_render_size[1]
            agent_width = env_render_size[0] // NUM_AGENTS

            # Header row
            vehicles = env.original_env.unwrapped.controlled_vehicles
            for agent_idx in range(NUM_AGENTS):
                crashed = "CRASHED" if vehicles[agent_idx].crashed else ""
                draw_cell(screen, f"Agent {agent_idx} {crashed}", (220, 220, 220),
                          pygame.Rect(agent_idx * agent_width, start_y, agent_width, row_height_agent), font)
            start_y += row_height_agent

            # Top K actions. Iterating the array itself (not range(K)) avoids an
            # IndexError when fewer than K actions are available.
            for rank, action_id in enumerate(top_k_actions):
                draw_cell(screen, f"Top {rank + 1}", (220, 220, 220),
                          pygame.Rect(0, start_y, env_render_size[0], row_height_agent), font)
                start_y += row_height_agent

                # The decoded tuple depends only on the rank, so compute it once per
                # row (assumes _flat_to_tuple is a side-effect-free index conversion).
                action_tuple = env._flat_to_tuple(action_id)
                for agent_idx in range(NUM_AGENTS):
                    agent_action = action_tuple[agent_idx].item()
                    draw_cell(screen, translate_action[agent_action], colors[agent_action],
                              pygame.Rect(agent_idx * agent_width, start_y, agent_width, row_height), font)
                start_y += row_height

            pygame.display.update()

            # Manage PyGame events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Stop the whole demonstration, not only the current episode.
                    running = False
                    done = True
                    break

            clock.tick(30)
finally:
    # Close everything, even when the loop above raised.
    env.close()
    pygame.quit()


Cell output (the run ended with an error):
AttributeError: 'NoneType' object has no attribute 'get_image'