In [None]:
#import
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import argparse
import time
from typing import Tuple, Optional, Dict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

from env import MazeEnv

In [2]:
# Load the run settings from config.json.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding (which differs on Windows).
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

In [3]:
# Build the maze environment: layout parameters are fixed here, while the
# communication/lidar ranges, episode length and render mode come from config.
env = MazeEnv(
    size=20,
    walls_proportion=0.2,
    num_dynamic_obstacles=1,
    num_agents=1,
    **{
        option: config[option]
        for option in (
            "communication_range",
            "max_lidar_dist_main",
            "max_lidar_dist_second",
            "max_episode_steps",
            "render_mode",
        )
    },
)

In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
class DQN(nn.Module):
    """
    Deep Q Network to model the Q function.

    A fully connected network with two ReLU hidden layers of equal width
    followed by a linear output layer that emits one value per action.

    Args:
        input_dim: Size of the input feature vector.
        output_dim: Number of outputs (one Q-value per action).
        hidden_dim: Width of both hidden layers. Defaults to 128, the value
            that was previously hard-coded.
    """
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # Attribute names fc1/fc2/fc3 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map inputs of shape (..., input_dim) to outputs of shape (..., output_dim)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the last layer: the outputs are unbounded.
        return self.fc3(x)

In [7]:
# replay buffer
class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling.

    Once the buffer is full, each new transition overwrites the oldest one in
    place (ring buffer), so ``push`` is O(1) instead of shifting the whole
    list. As a consequence ``self.buffer`` is not kept in insertion order
    after the first wrap-around; sampling is uniform, so this does not affect
    ``sample``.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be >= 1.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        capacity = int(capacity)
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer")
        self.buffer = []
        self.capacity = capacity
        # Slot that the next push will fill (or overwrite once full).
        self._next_index = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest one when full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next_index] = transition
        self._next_index = (self._next_index + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [8]:
import heapq

In [10]:
# A star algorithm
class AStarPathfinder:
    """A* shortest-path search on a 4-connected grid with unit step cost.

    Cells are ``(x, y)`` pairs with ``0 <= x < grid_size[0]`` and
    ``0 <= y < grid_size[1]``. ``obstacles`` is read afresh on every query,
    so callers may mutate ``self.obstacles`` between searches.
    """

    def __init__(self, grid_size, obstacles):
        self.grid_size = grid_size
        self.obstacles = obstacles  # list of (x, y) tuples

    def heuristic(self, a, b):
        """Manhattan distance between cells ``a`` and ``b``."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    def _blocked_cells(self):
        """Snapshot the obstacles as a set of tuples for O(1) membership tests."""
        return {tuple(cell) for cell in self.obstacles}

    def _in_bounds(self, cell):
        """True when ``cell`` lies inside the grid."""
        return 0 <= cell[0] < self.grid_size[0] and 0 <= cell[1] < self.grid_size[1]

    def get_neighbors(self, node, blocked=None):
        """Return the free, in-bounds 4-neighbours of ``node``.

        Args:
            node: ``(x, y)`` cell.
            blocked: Optional precomputed set of obstacle cells; when omitted
                it is built from ``self.obstacles``.
        """
        if blocked is None:
            blocked = self._blocked_cells()
        x, y = node
        candidates = ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))
        return [c for c in candidates if self._in_bounds(c) and c not in blocked]

    def find_path(self, start, goal):
        """Find a shortest path from ``start`` to ``goal``.

        Returns:
            The list of cells to visit after ``start`` (``start`` excluded,
            ``goal`` included). An empty list means either that no path exists
            or that ``start == goal``.
        """
        # Nodes generated by the search are tuples; normalise the endpoints so
        # list/array inputs compare equal to them and can be used as dict keys.
        start, goal = tuple(start), tuple(goal)
        blocked = self._blocked_cells()

        # An unreachable goal would otherwise cost a full exploration of the grid.
        if goal in blocked or not self._in_bounds(goal):
            return []

        open_heap = [(self.heuristic(start, goal), start)]
        came_from = {}
        g_score = {start: 0}

        while open_heap:
            f_current, current = heapq.heappop(open_heap)
            if current == goal:
                # Walk the parent links back to start (which has no parent).
                path = []
                while current in came_from:
                    path.append(current)
                    current = came_from[current]
                path.reverse()
                return path

            # A cheaper route to this cell was queued after this entry was
            # pushed; that better entry has already been (or will be) expanded.
            if f_current > g_score[current] + self.heuristic(current, goal):
                continue

            tentative_g_score = g_score[current] + 1
            for neighbor in self.get_neighbors(current, blocked):
                if neighbor not in g_score or tentative_g_score < g_score[neighbor]:
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g_score
                    priority = tentative_g_score + self.heuristic(neighbor, goal)
                    heapq.heappush(open_heap, (priority, neighbor))
        return []


In [None]:
class MyAgents():
    """DQN agent: online/target Q-networks, replay buffer, epsilon-greedy policy.

    Raw observations are reduced to an 11-value feature vector by
    ``process_states`` before being fed to the networks.
    """

    def __init__(self, state_dim: int, action_dim: int, device: str, lr: float):
        """Build the networks, optimizer, hyperparameters and replay buffer.

        Args:
            state_dim: Network input size; ``process_states`` emits 11 values,
                so this is expected to be 11.
            action_dim: Number of discrete actions (network output size).
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).
            lr: Learning rate of the Adam optimizer.
        """
        self.device = torch.device(device)
        # state dim = 11 for process_states
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Online network and the target network used for bootstrapped targets.
        # DQN is the class defined in this notebook, so it is called directly.
        self.q_network = DQN(input_dim=state_dim, output_dim=action_dim).to(self.device)
        self.target_network = DQN(input_dim=state_dim, output_dim=action_dim).to(self.device)
        # Start both networks from identical weights; otherwise the first
        # targets come from an unrelated random initialisation.
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Optimizer and loss (attribute names kept for existing callers).
        self.optimizers = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.loss_fn = torch.nn.SmoothL1Loss()

        # Hyperparameters.
        self.batch_size = 128
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.gamma = 0.99
        self.TAU = 0.005  # soft-update interpolation factor

        # Replay buffer; passed positionally because its parameter is `capacity`.
        self.replay_buffer = ReplayBuffer(10000)

    def process_states(self, state: list) -> list:
        """Reduce one raw observation to a single 11-value integer feature vector.

        Layout read from ``state`` (as indexed by this method):
        ``[0:2]`` agent position, ``[2]`` one extra value kept as-is,
        ``[3]`` status flag, ``[4:6]`` goal position, ``[6:12]`` six lidar
        values, ``[12:]`` consecutive 10-value blocks for the other agents
        (position at offsets 0-1, status flag at offset 3).
        NOTE(review): a status of 0 is treated as "active" — confirm against the env.

        Returns:
            A one-element list holding the feature array: position relative to
            the goal (2), the extra value (1), lidar (6) and the rounded mean
            position of all active agents relative to the goal (2). When the
            agent's own status is non-zero the array is filled with -1.
        """
        if state[3] != 0:
            # Agent is not active: emit a constant placeholder vector.
            return [np.full(11, -1)]

        goal = (state[4], state[5])
        main_state = np.concatenate((state[:3], state[6:12]))
        # Express the agent position relative to the goal.
        main_state[0] = main_state[0] - goal[0]
        main_state[1] = main_state[1] - goal[1]

        other_state = state[12:]
        num_drone = len(other_state) // 10
        # Collect the positions of this agent and of every other active agent
        # to compute the mean position of the swarm.
        positions = [(state[0], state[1])]
        for i in range(num_drone):
            block = 10 * i
            if other_state[block + 3] == 0:
                positions.append((other_state[block], other_state[block + 1]))
        mean_position = np.round(np.mean(positions, axis=0)).astype(int) - goal

        # Final vector: relative position, extra value, lidar, swarm mean.
        return [np.concatenate((main_state, mean_position)).astype(int)]

    def get_action(self, state):
        """Pick an action epsilon-greedily for one raw observation.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the action with the highest Q-value.

        Returns:
            An int in ``[0, action_dim)``.
        """
        if np.random.rand() < self.epsilon:
            return random.randrange(self.action_dim)

        # process_states returns [vector], which already forms a batch of one.
        features = np.asarray(self.process_states(state), dtype=np.float32)
        with torch.no_grad():
            state_tensor = torch.as_tensor(features, device=self.device)
            # argmax over raw Q-values; a softmax would not change the winner.
            q_values = self.q_network(state_tensor)
        return int(torch.argmax(q_values).item())

    def update_target_network(self):
        """Soft-update the target network: target <- TAU*online + (1-TAU)*target."""
        target_net_state_dict = self.target_network.state_dict()
        policy_net_state_dict = self.q_network.state_dict()
        for key, policy_value in policy_net_state_dict.items():
            target_net_state_dict[key] = (
                policy_value * self.TAU + target_net_state_dict[key] * (1 - self.TAU)
            )
        self.target_network.load_state_dict(target_net_state_dict)

    def _states_to_tensor(self, raw_states):
        """Preprocess raw observations into a float tensor of shape (batch, 11)."""
        features = np.array(
            [self.process_states(raw)[0] for raw in raw_states], dtype=np.float32
        )
        return torch.as_tensor(features, device=self.device)

    def update_policy(self):
        """Run one optimisation step of the online network on a sampled batch.

        Epsilon is not decayed here and the target network is not updated
        here; callers handle both (see ``update_target_network``).

        Returns:
            The scalar loss as a float, or ``None`` when the replay buffer
            holds fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        batch = self.replay_buffer.sample(self.batch_size)
        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = zip(*batch)

        # Convert the batch to tensors on the training device.
        states = self._states_to_tensor(batch_states)
        next_states = self._states_to_tensor(batch_next_states)
        actions = torch.as_tensor(np.asarray(batch_actions), dtype=torch.long, device=self.device)
        rewards = torch.as_tensor(np.asarray(batch_rewards), dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(np.asarray(batch_dones), dtype=torch.float32, device=self.device)

        # Q-values of the actions that were actually taken.
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets; terminal transitions contribute the reward only.
        with torch.no_grad():
            next_q_max = self.target_network(next_states).max(dim=1).values
            target_q_values = rewards + self.gamma * next_q_max * (1 - dones)

        loss = self.loss_fn(q_values, target_q_values)
        self.optimizers.zero_grad()
        loss.backward()

        # Clamp each gradient element to [-100, 100] before stepping.
        torch.nn.utils.clip_grad_value_(self.q_network.parameters(), 100)
        self.optimizers.step()

        return loss.item()