In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

from 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [9]:
# 定义地图环境
class GridEnv:
    def __init__(self, grid_size, start, goal, obstacles):
        self.grid_size = grid_size
        self.start = start
        self.goal = goal
        self.obstacles = obstacles
        self.state = start
        self.done = False

    def reset(self):
        self.state = self.start
        self.done = False
        return self._get_state()

    def step(self, action):
        if self.done:
            raise ValueError("环境已结束，请重置。")

        x, y = self.state
        if action == 0:  # 上
            next_state = (x, y - 1)
        elif action == 1:  # 下
            next_state = (x, y + 1)
        elif action == 2:  # 左
            next_state = (x - 1, y)
        elif action == 3:  # 右
            next_state = (x + 1, y)

        # 边界检查
        if (next_state[0] < 0 or next_state[0] >= self.grid_size[0] or
            next_state[1] < 0 or next_state[1] >= self.grid_size[1]):
            next_state = self.state

        # 障碍物检查
        if next_state in self.obstacles:
            reward = -1
            next_state = self.state
        elif next_state == self.goal:
            reward = 10
            self.done = True
        else:
            reward = -0.1

        self.state = next_state
        return self._get_state(), reward, self.done

    def _get_state(self):
        # 将位置状态编码为网格大小的向量，便于模型输入
        state = np.zeros(self.grid_size).flatten()
        idx = self.state[0] * self.grid_size[1] + self.state[1]
        state[idx] = 1
        return state

# 定义 Actor-Critic 模型
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc_policy = nn.Linear(128, action_dim)
        self.fc_value = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        policy_logits = self.fc_policy(x)
        state_value = self.fc_value(x)
        return policy_logits, state_value

    def choose_action(self, state):
        self.eval()
        logits, _ = self.forward(state)
        probs = F.softmax(logits, dim=-1)
        action = Categorical(probs).sample()
        return action.item()

# Worker 进程定义
class Worker(mp.Process):
    def __init__(self, global_model, optimizer, global_ep, global_ep_r, res_queue, env_params):
        super(Worker, self).__init__()
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.global_model, self.optimizer = global_model, optimizer
        self.env = GridEnv(*env_params)
        self.local_model = ActorCritic(25, 4).to(device)
        self.local_model.load_state_dict(self.global_model.state_dict())

    def run(self):
        while self.g_ep.value < 500:  # 假设500个episode
            state = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0
            while True:
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                action = self.local_model.choose_action(state_tensor)
                next_state, reward, done = self.env.step(action)

                ep_r += reward
                buffer_s.append(state)
                buffer_a.append(action)
                buffer_r.append(reward)

                if done:
                    self.update_global(buffer_s, buffer_a, buffer_r, done, next_state)
                    with self.g_ep.get_lock():
                        self.g_ep.value += 1
                    with self.g_ep_r.get_lock():
                        if self.g_ep_r.value == 0.:
                            self.g_ep_r.value = ep_r
                        else:
                            self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
                    self.res_queue.put(self.g_ep_r.value)
                    break
                state = next_state

    def update_global(self, buffer_s, buffer_a, buffer_r, done, next_state):
        if done:
            v_s_ = 0
        else:
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(device)
            _, v_s_ = self.local_model(next_state_tensor)
            v_s_ = v_s_.item()

        buffer_v_target = []
        for r in buffer_r[::-1]:
            v_s_ = r + 0.99 * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        loss = 0
        self.optimizer.zero_grad()
        for s, a, v_t in zip(buffer_s, buffer_a, buffer_v_target):
            s_tensor = torch.FloatTensor(s).unsqueeze(0).to(device)
            logits, v = self.local_model(s_tensor)
            advantage = v_t - v.item()

            action_loss = -torch.log_softmax(logits, dim=-1)[0, a] * advantage
            value_loss = F.mse_loss(v, torch.tensor(v_t).to(device))
            loss += (action_loss + value_loss)

        loss.backward()
        for global_param, local_param in zip(self.global_model.parameters(), self.local_model.parameters()):
            global_param._grad = local_param.grad
        self.optimizer.step()
        self.local_model.load_state_dict(self.global_model.state_dict())

In [10]:
# 主函数
if __name__ == "__main__":
    env_params = ((5, 5), (0, 0), (4, 4), [(1, 1), (2, 2), (3, 3)])
    global_model = ActorCritic(25, 4).to(device)
    global_model.share_memory()
    optimizer = optim.Adam(global_model.parameters(), lr=0.001)
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

    workers = [Worker(global_model, optimizer, global_ep, global_ep_r, res_queue, env_params) for _ in range(mp.cpu_count())]
    [w.start() for w in workers]
    [w.join() for w in workers]

    rewards = []
    while not res_queue.empty():
        rewards.append(res_queue.get())
    print("训练完成，累计奖励:", rewards)



训练完成，累计奖励: [-55.10000000000025, -55.10000000000025, -55.10000000000025, -53.91993000000024, -55.10000000000025, -52.774909259070235, -55.10000000000025, -55.10000000000025, -55.10000000000025, -54.50700000000025, -53.35041339300024, -52.20516016647953, -53.931730700000244, -51.74384109933117, -51.64110856481473, -51.67569747916659, -51.21762866145448, -51.70994050437492, -51.77740268833786, -51.29488785109153, -51.25645237483993, -51.332938972580614, -51.37060958285481, -51.23690348702627, -51.01953445215601, -50.88933910763445, -51.03544571655811, -51.180091259392526, -51.0392903467986, -50.92089744333062, -50.854688468897315, -50.862141584208345, -50.869520168366265, -50.7408249666826, -50.58741671701578, -50.43554254984562, -49.96918712434716, -49.498495253103684, -49.032510300572646, -48.57118519756692, -48.12347334559124, -48.198238612135334, -48.263256226013986, -48.15662366375385, -48.03605742711631, -48.10269685284515, -48.027669884316694, -48.022393185473526, -48.0441692536187