In [3]:
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from gym import Env
import gymnasium as gym
from gymnasium import spaces
from gym.spaces import Discrete, Box
import numpy as np
import pygame

In [45]:
class Agent:
    """A point vehicle moving along a 1-D line.

    Args:
        position: Initial position.
        speed: Initial speed, i.e. the distance added by each ``update_position`` call.
        speed_step: Speed change applied by a slow-down (0) or speed-up (2) action.
        min_speed: Lower clamp applied when slowing down.
        max_speed: Upper clamp applied when speeding up.
    """

    def __init__(self, position, speed=1.0, speed_step=0.01, min_speed=0.0, max_speed=1.0):
        self.position = position
        self.speed = speed
        self.speed_step = speed_step
        self.min_speed = min_speed
        self.max_speed = max_speed

    def update_position(self):
        """Advance the agent by its current speed (one time step)."""
        self.position += self.speed

    def update_speed(self, action):
        """Apply a discrete speed action.

        Args:
            action: 0 slow down, 1 maintain speed, 2 speed up. Any other value
                leaves the speed unchanged.
        """
        if action == 0:
            # Float lower bound keeps ``speed`` a float (``max(0, ...)`` could yield int 0).
            self.speed = max(self.min_speed, self.speed - self.speed_step)
        elif action == 2:
            self.speed = min(self.max_speed, self.speed + self.speed_step)

class FollowLeaderEnv(gym.Env):
    """One-lane follow-the-leader environment.

    Agents stand on a 1-D line. The last agent in ``self.agents`` is the leader and
    is driven randomly by :meth:`leader_move`. Every other agent ``i`` is a follower
    whose "leader" for observation and reward is its immediate predecessor
    ``self.agents[i + 1]``.

    Observation: ``[distance to the agent ahead, own speed, speed of the agent ahead]``.
    Actions: 0 - slow down, 1 - keep speed, 2 - speed up.

    Note: ``step`` takes an extra ``agent_idx`` argument and returns the legacy
    4-tuple ``(state, reward, done, info)`` rather than Gymnasium's 5-tuple.
    """

    def __init__(self, num_agents, visualize=False):
        super().__init__()
        if num_agents < 2:
            raise ValueError(
                f"num_agents must be at least 2 (one follower plus the leader), got {num_agents}"
            )
        self.num_agents = num_agents

        # Define action space: 0 - slow down, 1 - keep speed, 2 - speed up
        self.action_space = spaces.Discrete(3)

        # Observation space: distance to the agent ahead, agent speed, speed of the agent ahead
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)

        # Gap each follower should keep to the agent ahead; agents also spawn this far apart.
        self.target_distance = 50.0

        # The episode ends once the leader reaches this x-position; render() draws it as the finish marker.
        self.finish_position = 900

        # Leader's speed range.
        # NOTE(review): not read anywhere in this class; leader_move relies on Agent.update_speed's clamping.
        self.leader_speed_min = 0.25
        self.leader_speed_max = 1.0

        # Index 0 is the rearmost follower, index -1 is the leader.
        self.agents = self._spawn_agents()

        # Rendering state is always defined, so render()/close() are safe no-ops when visualize=False.
        self.screen = None
        self.clock = None
        self.font = None
        self.is_pygame_initialized = False
        if visualize:
            self.init_pygame()

    def init_pygame(self):
        """Open the 1000x300 pygame window and create the clock and HUD font."""
        print("Initializing Pygame...")
        pygame.init()
        self.screen = pygame.display.set_mode((1000, 300))
        self.clock = pygame.time.Clock()
        self.is_pygame_initialized = True
        pygame.font.init()
        self.font = pygame.font.Font(None, 36)

    def _spawn_agents(self):
        """Return fresh agents spaced ``target_distance`` apart, the first at position 0."""
        return [Agent(position=i * self.target_distance) for i in range(self.num_agents)]

    @staticmethod
    def _get_observation(agent, leader):
        """Build ``[gap to leader, agent speed, leader speed]`` as a float32 array."""
        return np.array(
            [leader.position - agent.position, agent.speed, leader.speed], dtype=np.float32
        )

    def reset(self):
        """Respawn all agents and return follower 0's initial observation.

        Follower 0 observes ``self.agents[1]``, its immediate predecessor. This is
        the same pairing that ``step(action, 0)`` uses.
        """
        self.agents = self._spawn_agents()
        return self._get_observation(self.agents[0], self.agents[1])

    def step(self, action, agent_idx):
        """Apply ``action`` to follower ``agent_idx`` and move it one tick.

        Args:
            action: 0 slow down, 1 keep speed, 2 speed up.
            agent_idx: Follower index in ``[0, len(self.agents) - 2]``. Its leader is
                ``self.agents[agent_idx + 1]``.

        Returns:
            ``(state, reward, done, info)``: a float32 observation, a float reward,
            a bool termination flag and an empty info dict.

        Raises:
            IndexError: If ``agent_idx`` does not name a follower.
        """
        if not 0 <= agent_idx < len(self.agents) - 1:
            raise IndexError(
                f"agent_idx must be in [0, {len(self.agents) - 2}], got {agent_idx}"
            )
        agent = self.agents[agent_idx]
        leader = self.agents[agent_idx + 1]

        # Update agent speed based on action, then move it.
        agent.update_speed(action)
        agent.update_position()

        distance_to_leader = leader.position - agent.position

        # Dense reward: negative deviation from the target gap.
        reward = -abs(distance_to_leader - self.target_distance)

        # Terminate (without an extra penalty) if the gap collapses below 5 or exceeds 100.
        done = distance_to_leader < 5 or distance_to_leader > 100

        # Terminal bonus once the leader crosses the finish: larger when the gap is near target.
        if self.agents[-1].position >= self.finish_position:
            reward += 100 if 40 < distance_to_leader < 60 else 50
            done = True

        state = self._get_observation(agent, leader)
        return state, reward, done, {}

    def leader_move(self):
        """Drive the leader with a uniformly random speed action, then advance it."""
        leader = self.agents[-1]
        leader_action = np.random.choice([0, 1, 2])
        leader.update_speed(leader_action)
        leader.update_position()

    def render(self):
        """Draw the finish marker, all agents and the leader's speed. No-op without pygame."""
        if not self.is_pygame_initialized:
            return

        self.screen.fill((0, 0, 0))

        # Draw the finish marker
        pygame.draw.circle(self.screen, (255, 0, 0), (int(self.finish_position), 150), 10)

        # Draw the agents
        for agent in self.agents:
            pygame.draw.circle(self.screen, (0, 255, 0), (int(agent.position), 150), 10)

        # Display leader speed
        speed_text = self.font.render(f"Leader Speed: {round(self.agents[-1].speed, 3)}", True, (255, 255, 255))
        self.screen.blit(speed_text, (10, 10))

        pygame.display.flip()
        self.clock.tick(60)

    def close(self):
        """Shut down pygame if this env initialized it; safe to call repeatedly."""
        if self.is_pygame_initialized:
            print("Closing Pygame...")
            pygame.quit()
            self.is_pygame_initialized = False



In [40]:
# Neural network for Q-learning (DQN)
class DQN(nn.Module):
    """Fully connected Q-network: one Q-value per discrete action.

    Layout is ``state_dim -> 128 -> 128 -> action_dim`` with ReLU after the two
    hidden layers and no activation on the output. The attribute names
    ``fc1``/``fc2``/``fc3`` define the ``state_dict`` keys that saved weights are
    loaded against, so they must not be renamed.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_width = 128
        # Layers are created in fc1 -> fc2 -> fc3 order, so seeded initialization is reproducible.
        self.fc1 = nn.Linear(state_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, action_dim)

    def forward(self, x):
        """Map observations ``x`` (last dim = state_dim) to raw Q-values."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

In [48]:
import time

# Evaluation run: 6 agents = 5 DQN-controlled followers + 1 randomly driven leader (the last agent).
NUM_AGENTS = 6
WEIGHTS_PATH = "policy_net_weights_final.pth"
STEP_DELAY_S = 0.05  # extra per-tick delay to slow down the visualization (optional)

# Initialize the environment with visualization
env = FollowLeaderEnv(num_agents=NUM_AGENTS, visualize=True)

# Get the state and action dimensions from the environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

try:
    # Rebuild the network and load the trained weights.
    dqn_test = DQN(state_dim, action_dim)
    # map_location="cpu" lets GPU-trained weights load on a CPU-only machine.
    # weights_only=True (torch >= 1.13) restricts torch.load's pickle-based loader to tensor data.
    dqn_test.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True))
    dqn_test.eval()

    done = False
    while not done:
        # Handle window events first. A quit request ends the run immediately, instead
        # of being overwritten by the next env.step() result.
        quit_requested = any(
            event.type == pygame.QUIT
            or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE)
            for event in pygame.event.get()
        )
        if quit_requested:
            break

        time.sleep(STEP_DELAY_S)

        # Each follower i observes its immediate predecessor i + 1 (same pairing as env.step).
        for i in range(len(env.agents) - 1):  # '-1' excludes the leader
            follower, ahead = env.agents[i], env.agents[i + 1]
            observation = torch.tensor(
                [ahead.position - follower.position, follower.speed, ahead.speed],
                dtype=torch.float32,
            )
            # Greedy action: no exploration during evaluation.
            with torch.no_grad():
                action = torch.argmax(dqn_test(observation)).item()

            _, _, done, _ = env.step(action, i)
            if done:
                break

        # Move the leader only while the episode is live, then draw one frame per tick.
        if not done:
            env.leader_move()
        env.render()
finally:
    # Always release the pygame window, even if loading or stepping raised.
    env.close()

Initializing Pygame...
Closing Pygame...
