In [1]:
import numpy as np

# Cell-state codes shared by every grid in this notebook.
UNEXPLORED = -2  # no drone has scanned this cell yet
OBSTACLE = -1    # obstacle cell
SAFE = 0         # safe cell

# Shared 10x10 coverage map; every cell begins unexplored.
global_grid = np.full(shape=(10, 10), fill_value=UNEXPLORED)

def drone_scan(drone_pos, scan_range, actual_env):
    """
    Perform a local scan around a drone position.

    The scan is a ``scan_range`` x ``scan_range`` window starting
    ``scan_range // 2`` cells before ``drone_pos`` on each axis (for an even
    ``scan_range`` the window extends one cell further before the drone than
    after it). Window cells that fall outside ``actual_env`` stay UNEXPLORED.

    Args:
        drone_pos (tuple): Drone coordinates (x, y), used as (row, column)
            indices into ``actual_env``.
        scan_range (int): Side length of the drone's square field of view (e.g., 3 or 5).
        actual_env (np.array): The actual environment grid; its own shape
            defines the bounds of the scan.

    Returns:
        tuple: ``(local_info, top_left)`` where ``local_info`` is the
        ``scan_range`` x ``scan_range`` scanned window and ``top_left`` is the
        global (x, y) of the window's top-left cell (may be negative).
    """
    half_range = scan_range // 2
    top = drone_pos[0] - half_range
    left = drone_pos[1] - half_range
    n_rows, n_cols = actual_env.shape[:2]

    local_info = np.full((scan_range, scan_range), UNEXPLORED)

    # Clip the window to the part that overlaps the environment.
    row_lo, row_hi = max(top, 0), min(top + scan_range, n_rows)
    col_lo, col_hi = max(left, 0), min(left + scan_range, n_cols)
    if row_lo < row_hi and col_lo < col_hi:
        local_info[row_lo - top:row_hi - top, col_lo - left:col_hi - left] = \
            actual_env[row_lo:row_hi, col_lo:col_hi]

    return local_info, (top, left)

def stitch_information(global_grid, local_info, top_left):
    """
    Merge a local drone scan into the global coverage grid, in place.

    Merge rule per overlapping cell:
      * if the global cell is UNEXPLORED, it takes the scanned value;
      * otherwise, if the scan reports SAFE, the cell becomes SAFE
        (safe information wins conflicts);
      * otherwise the existing global value is kept.
    Scan cells that fall outside ``global_grid`` are ignored; the grid's own
    shape defines the bounds.

    Args:
        global_grid (np.array): Current global grid state (modified in place).
        local_info (np.array): Local scan result from a drone.
        top_left (tuple): Top-left coordinate of the local scan in the global grid.

    Returns:
        np.array: The same ``global_grid`` object, updated.
    """
    x_offset, y_offset = top_left
    n_rows, n_cols = global_grid.shape[:2]
    height, width = local_info.shape[:2]

    # Clip the scan window to the part that overlaps the global grid.
    row_lo, row_hi = max(x_offset, 0), min(x_offset + height, n_rows)
    col_lo, col_hi = max(y_offset, 0), min(y_offset + width, n_cols)
    if row_lo >= row_hi or col_lo >= col_hi:
        return global_grid  # scan lies entirely off the grid

    # Basic slicing returns a view, so writes to `target` land in global_grid.
    target = global_grid[row_lo:row_hi, col_lo:col_hi]
    patch = local_info[row_lo - x_offset:row_hi - x_offset,
                       col_lo - y_offset:col_hi - y_offset]

    take = (target == UNEXPLORED) | (patch == SAFE)
    target[take] = patch[take]
    return global_grid

In [2]:
# Example ground-truth environment: each cell is independently an obstacle with
# probability 0.2, otherwise safe. A fixed seed keeps this demo reproducible
# across Restart & Run All.
ENV_SEED = 42
env_rng = np.random.default_rng(ENV_SEED)
actual_env = env_rng.choice([OBSTACLE, SAFE], size=(10, 10), p=[0.2, 0.8])

Training a policy that decides where to place drones

In [4]:
from gymnasium import Env, spaces

class DronePlacementEnv(Env):
    """Gymnasium environment in which each action places one drone scan.

    The observation is the coverage map ``global_grid``. The hidden ground
    truth ``actual_env`` is resampled on every ``reset``. Each step scans a
    ``scan_range`` x ``scan_range`` window around the chosen cell. The reward
    is the number of newly explored cells minus a 0.1 step cost.
    """

    def __init__(self, scan_range=3, max_steps=10):
        """Create the environment.

        Args:
            scan_range (int): Side length of the square window each drone scans.
            max_steps (int): Placements per episode before truncation. With the
                defaults (3x3 scans, 10 steps) at most 90 of the 100 cells can
                be uncovered, so episodes always end by truncation rather than
                by ``terminated``.
        """
        super().__init__()
        self.grid_size = 10
        self.scan_range = scan_range
        self.max_steps = max_steps
        self.current_step = 0

        # One discrete action per cell, flattened row-major: action = x * grid_size + y.
        self.action_space = spaces.Discrete(self.grid_size * self.grid_size)
        self.observation_space = spaces.Box(low=UNEXPLORED, high=SAFE,
                                            shape=(self.grid_size, self.grid_size), dtype=np.int32)

        self.reset()

    def reset(self, *, seed=None, options=None):
        """Resample the hidden environment and clear the coverage map.

        ``self.np_random`` (seeded by ``super().reset``) is used instead of the
        global numpy RNG, so passing ``seed`` makes the obstacle layout reproducible.
        """
        super().reset(seed=seed)
        self.actual_env = self.np_random.choice(
            [OBSTACLE, SAFE], size=(self.grid_size, self.grid_size), p=[0.2, 0.8]
        ).astype(np.int32)
        self.global_grid = np.full((self.grid_size, self.grid_size), UNEXPLORED, dtype=np.int32)
        self.current_step = 0
        return self.global_grid.copy(), {}

    def step(self, action):
        """Place a drone at the cell encoded by ``action`` and scan around it."""
        self.current_step += 1
        # Policies may return a numpy scalar / 0-d array; normalise to a Python int.
        action = int(np.asarray(action).item())
        x, y = divmod(action, self.grid_size)

        local_info, top_left = drone_scan((x, y), self.scan_range, self.actual_env)
        prev_unexplored = np.sum(self.global_grid == UNEXPLORED)
        self.global_grid = stitch_information(self.global_grid, local_info, top_left)
        new_unexplored = np.sum(self.global_grid == UNEXPLORED)

        # Newly explored cells minus a small per-placement cost.
        reward = float(prev_unexplored - new_unexplored - 0.1)
        terminated = bool(new_unexplored == 0)
        truncated = bool(self.current_step >= self.max_steps)
        return self.global_grid.copy(), reward, terminated, truncated, {}

    def render(self):
        """Print the current coverage map."""
        print(self.global_grid)


In [5]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

def train_with_ppo(total_timesteps=10000, save_path="ppo_drone_model", seed=None):
    """Train a PPO policy on ``DronePlacementEnv`` and save it.

    Args:
        total_timesteps (int): Number of environment steps to train for.
        save_path (str): Path passed to ``model.save``.
        seed (int | None): Seed forwarded to PPO for reproducible training;
            ``None`` keeps PPO's default (unseeded) behaviour.

    Returns:
        PPO: The trained model.
    """
    env = DronePlacementEnv()
    check_env(env)  # make sure the environment conforms to the Gymnasium API
    model = PPO("MlpPolicy", env, verbose=1, seed=seed)
    model.learn(total_timesteps=total_timesteps)
    model.save(save_path)
    return model


In [6]:
model = train_with_ppo()

# Roll out the trained deployment policy once and show the coverage map after each placement.
env = DronePlacementEnv()
obs, _ = env.reset()

for step in range(env.max_steps):
    # deterministic=True takes the policy's most likely action instead of
    # sampling, so the rollout reflects what the agent learned.
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    env.render()
    if terminated or truncated:
        # Message reads "terminated at step <step>"; `step` is 0-based.
        print(f"终止于 step {step}")
        break




Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 10       |
|    ep_rew_mean     | 54.1     |
| time/              |          |
|    fps             | 382      |
|    iterations      | 1        |
|    time_elapsed    | 5        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 10           |
|    ep_rew_mean          | 54.8         |
| time/                   |              |
|    fps                  | 318          |
|    iterations           | 2            |
|    time_elapsed         | 12           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0119997645 |
|    clip_fraction        | 0.117        |
|    clip_range           | 0.2          |
|    entropy_loss         | -4.6         |
|    explained_variance   | 0.0307       |
|    learning_r

In [None]:
# Reload the policy saved by train_with_ppo(); the hand-placed demo below does not use it.
model = PPO.load("ppo_drone_model")

# Example drones, each given as ((x, y), scan_range).
drones = [((1, 1), 3), ((4, 4), 5)]

# Start from a fresh, fully unexplored grid. stitch_information mutates the grid
# in place, so without this a re-run would merge into a previous run's results.
global_grid = np.full_like(global_grid, UNEXPLORED)

# Perform local scans and stitch them into the global grid.
for pos, scan_range in drones:
    local_info, top_left = drone_scan(pos, scan_range, actual_env)
    global_grid = stitch_information(global_grid, local_info, top_left)

print("Final Global Grid after Stitching:")
print(global_grid)