### 加载环境

In [1]:
from wrapper import RelativePosition, FlattenDict, SerializeAction
import numpy as np
from IPython.display import clear_output
from gymnasium.envs.registration import register
import gymnasium as gym
import train_params_with_model as params
from loguru import logger

# Scenario settings taken from the shared training-parameter module.
size = params.size
relay_config = params.relay_config
client_config = params.client_config
init_config = params.init_config
is_polar = False

# Register the grid-world environment under the id that get_env() instantiates.
register(
    id='GridWorld-v0',
    entry_point='grid_world:GridWorldEnv',
    max_episode_steps=500,
    kwargs=dict(
        size=size,
        relay_config=relay_config,
        client_config=client_config,
        init_config=init_config,
        is_polar=is_polar,
        is_plot=False,
        is_log=False,
        use_model=True,
        keep_plot_data=True,
    ),
)

def get_env():
    """Create the registered "GridWorld-v0" environment with all wrappers applied.

    The base environment is wrapped, innermost first, by RelativePosition,
    FlattenDict and finally SerializeAction (which receives the module-level
    ``is_polar`` flag).

    Returns:
        The fully wrapped environment.
    """
    wrapped = gym.make("GridWorld-v0")
    for wrapper_cls in (RelativePosition, FlattenDict):
        wrapped = wrapper_cls(wrapped)
    return SerializeAction(wrapped, is_polar=is_polar)

# Create the single module-level environment instance; the policy loaders,
# RandomPolicy and set_seed() below all read this global.
env = get_env()

### 定义加载网络方法

In [2]:
import modified_DDPG
import OurDDPG

def load_modified_DDPG():
    """Build a modified_DDPG agent sized for the global ``env`` and load its checkpoint.

    State/action dimensions and the action bound are read from the global
    environment's spaces; the position range, relay/client dimensions and
    relay speed come from the module-level configuration objects.

    Returns:
        The ``modified_DDPG.DDPG`` instance after ``load`` has been called on it.
    """
    agent = modified_DDPG.DDPG(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        max_action=float(env.action_space.high[0]),
        discount=0.5,
        tau=0.005,
        position_range={
            "position": [-size / 2, size / 2],
            "height": [relay_config.min_height, relay_config.max_height],
        },
        # three values per relay and two per client
        relay_dim=relay_config.num * 3,
        client_dim=client_config.num * 2,
        speed=relay_config.speed,
    )
    agent.load("models/modified_DDPG_GridWorld-v0_with_model_0_2024-10-13_22-50-42")
    return agent

def load_OurDDPG():
    """Build an OurDDPG agent sized for the global ``env`` and load its checkpoint.

    Returns:
        The ``OurDDPG.DDPG`` instance after ``load`` has been called on it.
    """
    agent = OurDDPG.DDPG(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        max_action=float(env.action_space.high[0]),
        discount=0.5,
        tau=0.005,
    )
    agent.load("models/OurDDPG_GridWorld-v0_with_model_0_2024-11-06_20-35-13")
    return agent

class RandomPolicy:
    """Baseline policy exposing the same ``select_action`` method as the DDPG agents."""

    def select_action(self, state):
        """Return a sample from the global env's action space; ``state`` is ignored."""
        sampled_action = env.action_space.sample()
        return sampled_action
    
def set_seed(seed):
    """Seed the action space and NumPy's global RNG, then reset the global env.

    Args:
        seed: Value passed to ``env.action_space.seed``, ``np.random.seed``
            and ``env.reset``.

    Returns:
        The ``(state, info)`` pair produced by ``env.reset(seed=seed)``.
    """
    # Seed both random sources before resetting so the reset itself is covered.
    env.action_space.seed(seed)
    np.random.seed(seed)
    initial_state, reset_info = env.reset(seed=seed)
    return initial_state, reset_info

# Seed handed to set_seed() at the start of every run_and_draw() call.
seed = 3
# seed = 6

### 定义绘图函数

In [3]:
import numpy as np
from treelib import Tree
import matplotlib.pyplot as plt
from matplotlib.image import imread
import io
from PIL import Image
from typing import List, Optional, Union

def draw_map_for_show(tree: Tree, center_position: np.ndarray, relay_position: np.ndarray, client_position: np.ndarray,
             relay_height: Optional[np.ndarray] = None, relay_traffic: Optional[np.ndarray] = None, client_link: np.ndarray = None,
             size: Optional[List[list[float]]] = None, policy_name: str = "map", **kwargs) -> None:
    """
    Draw the network map, save it to ``./eval/{policy_name}.png`` and show it.

    The plot contains the center node, the relay nodes, the client nodes and
    dashed lines for the center-relay, relay-relay and relay-client links.

    Args:
        tree (Tree): The tree structure representing the connections between nodes.
            A node whose parent has identifier -1 is attached directly to the center;
            such a node must carry ``data["link_speed"]``.
        center_position (np.ndarray): The position of the center node (indexed as ``[0, 0]`` / ``[0, 1]``).
        relay_position (np.ndarray): The positions of the relay nodes, indexed by tree node id.
        client_position (np.ndarray): The positions of the client nodes.
        relay_height (Optional[np.ndarray], optional): Accepted but not used by this function. Defaults to None.
        relay_traffic (Optional[np.ndarray], optional): Accepted but not used by this function. Defaults to None.
        client_link (np.ndarray, optional): For each client, the index of the relay it is linked to,
            or -1 when it is not linked. When None, no relay-client links are drawn and all
            clients are plotted with a single "client" marker. Defaults to None.
        size (Optional[List[list[float]]], optional): ``[[x_min, x_max], [y_min, y_max]]`` of the map.
            When given, a new 6x6 figure with fixed limits and the background image is created;
            otherwise the current figure is drawn on. Defaults to None.
        policy_name (str, optional): Used as the plot title and the output file name. Defaults to "map".
        **kwargs: ``need_title`` (bool, default True) controls whether the title is drawn.
            Any other keyword is ignored.

    Returns:
        None. The figure is written to disk and displayed as a side effect.
    """
    if size:
        x_min, x_max = size[0][0], size[0][1]
        y_min, y_max = size[1][0], size[1][1]

        plt.figure(figsize=(6, 6))
        plt.xticks(np.linspace(x_min, x_max, 5))
        plt.xlim(x_min, x_max)

        plt.yticks(np.linspace(y_min, y_max, 5))
        plt.ylim(y_min, y_max)
        plt.gca().set_aspect('equal', adjustable='box')

        background = imread('./eval/background.png')
        plt.imshow(background, extent=[x_min, x_max, y_min, y_max])

    # Relays attached to the center whose link speed is not positive.
    invalid_node = []

    # Draw the backbone links first so the markers end up on top of them.
    for node_index in tree.expand_tree():
        parent = tree.parent(node_index)
        if not parent:
            # Nodes without a parent have no incoming link to draw.
            continue
        if parent.identifier == -1:
            # The parent is the center.
            if tree.get_node(node_index).data["link_speed"] > 0:
                plt.plot([center_position[0, 0], relay_position[node_index, 0]],
                         [center_position[0, 1], relay_position[node_index, 1]],
                         linestyle="--", c='pink')
            else:
                invalid_node.append(node_index)
        else:
            # The parent is another relay.
            plt.plot([relay_position[parent.identifier, 0], relay_position[node_index, 0]],
                     [relay_position[parent.identifier, 1], relay_position[node_index, 1]],
                     linestyle="--", c='pink')

    # Draw the link between every linked client and its relay.
    if client_link is not None:
        for i, client in enumerate(client_position):
            relay_index = client_link[i]
            if relay_index != -1:
                plt.plot([relay_position[relay_index, 0], client[0]],
                         [relay_position[relay_index, 1], client[1]],
                         linestyle="--", c='#add8e6')

    # Draw the markers now.
    plt.scatter(center_position[:, 0], center_position[:, 1], c='red', marker='s', label='center', zorder=10)
    plt.scatter(relay_position[:, 0], relay_position[:, 1], c='blue', marker='P', label='relay', zorder=9)

    if client_link is None:
        # Without link information the clients cannot be split into
        # linked/unlinked, so plot them all the same way instead of crashing.
        plt.scatter(client_position[:, 0], client_position[:, 1], c='green', marker='o', label='client')
    else:
        # Work on a copy so the caller's array is never modified; np.array also
        # accepts list-like input.
        effective_link = np.array(client_link)
        if invalid_node:
            # Clients attached to a relay without a usable center link count as unlinked.
            effective_link[np.isin(effective_link, invalid_node)] = -1
        client_position_linked = client_position[np.where(effective_link != -1)[0]]
        client_position_unlinked = client_position[np.where(effective_link == -1)[0]]
        plt.scatter(client_position_linked[:, 0], client_position_linked[:, 1], c='green', marker='o', label='client linked')
        plt.scatter(client_position_unlinked[:, 0], client_position_unlinked[:, 1], c='orange', marker='o', label='client unlinked')

    plt.legend(loc='upper left')
    if kwargs.get("need_title", True):
        plt.title(policy_name)
    plt.savefig(f"./eval/{policy_name}.png")
    plt.show()

def draw_trajectory(trajectory: np.ndarray, size: List[list[float]], policy_name: str = "trajectory", **kwargs) -> None:
    """
    Plot trajectories on the map background and show the figure.

    Args:
        trajectory (np.ndarray): Indexed as ``[path, step, coordinate]``; coordinates 0 and 1
            of every path are drawn as one solid line.
        size (List[list[float]]): ``[[x_min, x_max], [y_min, y_max]]`` of the map.
        policy_name (str, optional): Used as the plot title. Defaults to "trajectory".
        **kwargs: Optional extras:
            ``client_link`` and ``client_position`` - when both are present the clients are
            drawn, split into linked (link != -1) and unlinked (link == -1);
            ``tree`` - when present together with ``client_link``, clients linked to a node
            that is attached to the center (parent identifier -1) with
            ``data["link_speed"] == 0.0`` are treated as unlinked;
            ``need_title`` (bool, default True) - whether the title is drawn.

    Returns:
        None. The figure is displayed as a side effect; nothing is saved.
    """
    x_min, x_max = size[0][0], size[0][1]
    y_min, y_max = size[1][0], size[1][1]

    plt.figure(figsize=(6, 6))
    plt.xticks(np.linspace(x_min, x_max, 5))
    plt.xlim(x_min, x_max)

    plt.yticks(np.linspace(y_min, y_max, 5))
    plt.ylim(y_min, y_max)
    plt.gca().set_aspect('equal', adjustable='box')

    backdrop = imread('./eval/background.png')
    plt.imshow(backdrop, extent=[x_min, x_max, y_min, y_max])

    # One line per path.
    for path in trajectory:
        plt.plot(path[:, 0], path[:, 1], linestyle="-")

    has_link = "client_link" in kwargs
    links = None
    if has_link and "tree" in kwargs:
        topology = kwargs["tree"]

        # Collect nodes attached to the center whose link speed is exactly zero.
        dead_nodes = []
        for node_id in topology.expand_tree():
            parent = topology.parent(node_id)
            if not parent or parent.identifier != -1:
                continue
            if topology.get_node(node_id).data["link_speed"] == 0.0:
                dead_nodes.append(node_id)

        # Copy before editing so the caller's array stays untouched.
        links = kwargs["client_link"].copy()
        for node_id in dead_nodes:
            links[links == node_id] = -1

    if has_link and "client_position" in kwargs:
        if links is None:
            links = kwargs["client_link"]
        clients = kwargs["client_position"]

        linked = clients[np.where(links != -1)[0]]
        unlinked = clients[np.where(links == -1)[0]]
        plt.scatter(linked[:, 0], linked[:, 1], c='green', marker='o', label='client linked')
        plt.scatter(unlinked[:, 0], unlinked[:, 1], c='orange', marker='o', label='client unlinked')
        plt.legend(loc='upper left')

    if kwargs.get("need_title", True):
        plt.title(policy_name)
    plt.show()

from link_tree import Node
import math

def calculate_client_percent(node: Node, rate: float = 1.0) -> List[float]:
    """
    Collect one rate value per client for ``node`` and all of its descendants.

    When a node's ``link_speed`` is below its ``traffic_load`` the incoming rate
    is scaled by ``link_speed / traffic_load``; the scaled rate is recorded for
    the node's own clients and handed down to its children.

    Args:
        node (Node): The node to start from (callers pass the center node).
        rate (float, optional): The rate inherited from the ancestors. Defaults to 1.0.

    Returns:
        List[float]: The node's own entries first, followed by those of each child in order.
            The number of own entries is ``floor(data_amount / params.client_config.traffic)``.
    """
    effective_rate = rate
    if node.link_speed < node.traffic_load:
        # The link is the bottleneck: only this fraction of the load gets through.
        effective_rate = effective_rate * (node.link_speed / node.traffic_load)

    own_client_count = math.floor(node.data_amount / params.client_config.traffic)
    shares = [effective_rate] * own_client_count
    for child in node.children:
        shares += calculate_client_percent(child, effective_rate)
    return shares
    

### 定义运行模型和绘图方法

In [4]:

def run_and_draw(policy_name, policy, num_steps=500, snapshot_interval=100):
    """Roll out ``policy`` on the global env and plot maps, trajectories and client percentages.

    The environment is reset through ``set_seed(seed)``, stepped ``num_steps``
    times with the policy's actions, and the per-step ``info["plot_data"]`` and
    rewards are recorded. Afterwards the function:

    * draws the map for the first step, for every ``snapshot_interval``-th step
      and (with a title) for the last step;
    * draws the relay trajectories over the whole rollout;
    * computes the per-client percentages from the final ``center_node``, writes
      them to ``./eval/{policy_name}_client_percent.txt`` and shows a histogram;
    * logs rewards and summary statistics.

    Args:
        policy_name: Label used in logs, plot titles and output file names.
        policy: Object with a ``select_action(state)`` method.
        num_steps: Number of environment steps to run. Defaults to 500.
        snapshot_interval: Stride between intermediate map snapshots. Defaults to 100.

    Raises:
        ValueError: If ``num_steps`` or ``snapshot_interval`` is smaller than 1.
    """
    if num_steps < 1 or snapshot_interval < 1:
        raise ValueError("num_steps and snapshot_interval must be at least 1")

    state, info = set_seed(seed)
    plot_data = []
    reward_list = []

    for _ in range(num_steps):
        action = policy.select_action(np.array(state))
        next_state, reward, terminated, truncated, info = env.step(action=action)
        # Copy so a later step cannot overwrite the recorded dictionary.
        plot_data.append(info["plot_data"].copy())
        reward_list.append(reward)
        state = next_state

    logger.info("---------------------------------")
    logger.info(f"Policy: {policy_name}")
    logger.info(f"Step: {0}, reward: {reward_list[0]}")
    draw_map_for_show(**plot_data[0], policy_name=f"{policy_name}", need_title=False)
    for step in range(snapshot_interval - 1, num_steps, snapshot_interval):
        logger.info(f"Step: {step}, reward: {reward_list[step]}")
        draw_map_for_show(**plot_data[step], policy_name=f"{policy_name}", need_title=False)

    final_frame = plot_data[-1]
    draw_map_for_show(**final_frame, policy_name=policy_name, need_title=True)

    # Stack along axis 1 so the result is indexed as [relay, step, coordinate].
    trajectory = np.stack([frame["relay_position"] for frame in plot_data], axis=1)
    draw_trajectory(trajectory, final_frame["size"], policy_name=policy_name,
                    client_position=final_frame["client_position"], client_link=final_frame["client_link"],
                    tree=final_frame["tree"], need_title=False)

    client_percent = calculate_client_percent(final_frame["center_node"])
    np.savetxt(f"./eval/{policy_name}_client_percent.txt", client_percent)

    # Clients missing from client_percent are counted as 0 so the histogram
    # covers every configured client; `bins` sets the number of bars.
    missing_clients = params.client_config.num - len(client_percent)
    plt.hist(client_percent + [0] * missing_clients, bins=10, alpha=0.75)
    plt.title('Client Percent Histogram')
    plt.xlabel('Client Percent')
    plt.ylabel('Frequency')
    plt.show()

    logger.info(f"Client percent: {client_percent}")
    logger.info(f"Average client percent: {np.sum(client_percent) / params.client_config.num}")
    logger.info(f"Reward: {reward_list[-1]}")
    logger.info(f"Length of client percent: {len(client_percent)}")
    
    

### 绘制情况

In [None]:
# Load the saved modified-DDPG checkpoint and plot its rollout.
policy = load_modified_DDPG()
run_and_draw("Modified DDPG", policy)

In [None]:
# Load the saved OurDDPG checkpoint and plot its rollout.
policy = load_OurDDPG()
run_and_draw("DDPG", policy)

In [None]:
# Plot a rollout of the random-action baseline.
policy = RandomPolicy()
run_and_draw("Random", policy)