In [4]:
# =========================================
# file: maddpg_with_arbitration.py
# =========================================

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import copy
from env_marl import MultiAgentRoadCharging
import os

# ============ 1) Replay Buffer =============
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, actions, actions_arbitrated, rewards, next_state, done):
        self.buffer.append((state, actions, actions_arbitrated, rewards, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, actions, actions_arbitrated, rewards, next_state, done = zip(*batch)
        return (np.array(state), 
                np.array(actions), 
                np.array(actions_arbitrated),
                np.array(rewards), 
                np.array(next_state), 
                np.array(done))
    
    def __len__(self):
        return len(self.buffer)

# ============ 2) Networks (Actor/Critic) =============
class ActorNetwork(nn.Module):
    def __init__(self, obs_dim, hidden_dim=64):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # 输出[0,1]
        return x

class CriticNetwork(nn.Module):
    def __init__(self, state_dim, n_agents, hidden_dim=64):
        super(CriticNetwork, self).__init__()
        self.input_dim = state_dim + n_agents  # 拼接全局状态 + n_agents动作
        self.fc1 = nn.Linear(self.input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        
    def forward(self, state, actions):
        x = torch.cat([state, actions], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# ============ 3) MADDPG Agent =============
class MADDPGAgent:
    def __init__(self, 
                 actor_lr, critic_lr, 
                 obs_dim, state_dim, 
                 n_agents, agent_index,
                 gamma=0.95, tau=0.01, hidden_dim=64,
                 device='cpu'):
        self.agent_index = agent_index
        self.gamma = gamma
        self.tau = tau
        self.n_agents = n_agents
        self.device = device
        
        # Actor
        self.actor = ActorNetwork(obs_dim, hidden_dim).to(device)
        self.target_actor = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        
        # Critic
        self.critic = CriticNetwork(state_dim, n_agents, hidden_dim).to(device)
        self.target_critic = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
    
    def select_action(self, obs, exploration=False):
        obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
        self.actor.eval()
        with torch.no_grad():
            p = self.actor(obs_t).item()
        self.actor.train()
        
        if exploration:
            # 简易随机：eps
            eps = 0.05
            if np.random.rand() < eps:
                return np.random.randint(0,2)
        # 伯努力采样
        action = 1 if np.random.rand() < p else 0
        return action
    
    def soft_update(self, net, target_net):
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    
    def update_targets(self):
        self.soft_update(self.actor, self.target_actor)
        self.soft_update(self.critic, self.target_critic)

# ============ 4) 仲裁函数 =============
def arbitrate_actions(actions, m, mode='random'):
    a_tilde = actions[:]
    total_requests = sum(a_tilde)
    if total_requests <= m:
        return a_tilde
    
    idx_ones = [i for i, val in enumerate(a_tilde) if val == 1]
    if mode == 'random':
        chosen = np.random.choice(idx_ones, m, replace=False)
    else:
        # 可做优先级排序
        chosen = idx_ones[:m]  # 演示
    
    for i in idx_ones:
        a_tilde[i] = 0
    for c in chosen:
        a_tilde[c] = 1
    return a_tilde

# ============ 5) 训练每个agent的函数 ============
def train_maddpg_agent(agent_i, agents, replay_buffer, batch_size):
    agent = agents[agent_i]
    device = agent.device
    
    # 采样
    state_b, actions_b, actions_arbi_b, rewards_b, next_state_b, done_b = replay_buffer.sample(batch_size)
    # 转tensor
    state_t = torch.FloatTensor(state_b).to(device)
    actions_arbi_t = torch.FloatTensor(actions_arbi_b).to(device)
    rewards_t = torch.FloatTensor(rewards_b[:, agent_i]).unsqueeze(-1).to(device)
    next_state_t = torch.FloatTensor(next_state_b).to(device)
    done_t = torch.FloatTensor(done_b).unsqueeze(-1).to(device)
    
    # Critic Update
    with torch.no_grad():
        # 下个时刻: 让所有agent用target_actor 输出动作(连续p?),拼合 => Critic
        next_actions_list = []
        for idx, ag in enumerate(agents):
            # 这里假设 local obs == next_state，若不同需拆分
            p_next = ag.target_actor(next_state_t).squeeze(-1)
            next_actions_list.append(p_next)
        next_actions_t = torch.stack(next_actions_list, dim=1)
        
        # 计算 target Q
        Q_next = agent.target_critic(next_state_t, next_actions_t)
        y = rewards_t + agent.gamma * (1 - done_t) * Q_next
    
    Q_now = agent.critic(state_t, actions_arbi_t)
    critic_loss = F.mse_loss(Q_now, y)
    
    agent.critic_optimizer.zero_grad()
    critic_loss.backward()
    agent.critic_optimizer.step()
    
    # Actor Update
    # 让agent_i 的actor输出新的动作(continuous p), 其他agent也用当前actor(或不更新)
    cur_actions_list = []
    for idx, ag in enumerate(agents):
        if idx == agent_i:
            p_cur = ag.actor(state_t).squeeze(-1)  # requires grad
        else:
            with torch.no_grad():
                p_cur = ag.actor(state_t).squeeze(-1)
        cur_actions_list.append(p_cur)
    cur_actions_t = torch.stack(cur_actions_list, dim=1)
    
    actor_loss = -agent.critic(state_t, cur_actions_t).mean()
    agent.actor_optimizer.zero_grad()
    actor_loss.backward()
    agent.actor_optimizer.step()

# ============ 6) 主训练循环 ============
def maddpg_train(env, 
                 n_agents, 
                 n_episodes=500,
                 max_steps=1000,
                 m=2,
                 gamma=0.95, tau=0.01,
                 actor_lr=1e-3, critic_lr=1e-3,
                 batch_size=64,
                 buffer_capacity=100000,
                 print_interval=10,
                 device='cpu'):
    
    obs_dim = env.obs_dim
    state_dim = env.state_dim
    # 初始化agents
    agents = []
    for i in range(n_agents):
        agent_i = MADDPGAgent(actor_lr, critic_lr, 
                              obs_dim, state_dim, 
                              n_agents, i, 
                              gamma=gamma, tau=tau, 
                              device=device)
        agents.append(agent_i)
    
    replay_buffer = ReplayBuffer(buffer_capacity)
    
    all_rewards = []
    for ep in range(n_episodes):
        state, obs_list = env.reset()  # 需要你自己封装
        ep_reward = np.zeros(n_agents)
        
        for t in range(max_steps):
            # 每个agent选择动作
            actions = []
            for i in range(n_agents):
                a_i = agents[i].select_action(obs_list[i], exploration=True)
                actions.append(a_i)
            # 仲裁
            actions_arbi = arbitrate_actions(actions, m, 'random')
            
            # 交互
            next_state, next_obs_list, reward_list, done, info = env.step(actions_arbi)
            
            # 存储
            replay_buffer.push(state, actions, actions_arbi, reward_list, next_state, done)
            ep_reward += reward_list
            
            # 更新
            if len(replay_buffer) > batch_size:
                for i in range(n_agents):
                    train_maddpg_agent(i, agents, replay_buffer, batch_size)
            
            state = next_state
            obs_list = next_obs_list
            
            if done:
                break
        
        all_rewards.append(ep_reward)
        
        # soft update
        for ag in agents:
            ag.update_targets()
        
        if (ep+1) % print_interval == 0:
            avg_r = np.mean(all_rewards[-print_interval:], axis=0)
            print(f"Episode {ep+1}/{n_episodes}, avg reward per agent = {avg_r}")
    
    return agents, all_rewards

# ============ 用法示例 ============
if __name__ == "__main__":
    n_EVs = 5
    n_chargers = 1
    avg_return = 0
    SoC_data_type = "high"
    data_folder = "test_cases"
    results_folder = "results"
    policy_name = "base_policy"
    instance_count = 20
    instance_num = 1
    
    
    test_case = f"all_days_negativePrices_{SoC_data_type}InitSoC_{n_chargers}for{n_EVs}"
    test_cases_dir = os.path.join(data_folder, test_case)  
    data_file = os.path.join(test_cases_dir, f"config{instance_num}_{n_EVs}EVs_{n_chargers}chargers.json")
    print(data_file)
    # env = ConstrainAction(data_file)
    env = MultiAgentRoadCharging(data_file)
    n_agents = env.n
    trained_agents, rewards_history = maddpg_train(env, n_agents, n_episodes=50, m=1)


test_cases/all_days_negativePrices_highInitSoC_1for5/config1_5EVs_1chargers.json


AttributeError: 'MultiAgentRoadCharging' object has no attribute 'obs_dim'