In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np
from collections import deque
import random


# DQN model: a small fully connected Q-network.
class DQN(nn.Module):
    """Multi-layer perceptron mapping a state vector to one Q-value per action.

    Architecture: Linear(state_size, hidden) -> ReLU -> Linear(hidden, hidden)
    -> ReLU -> Linear(hidden, action_size). The output layer has no activation,
    so Q-value estimates are unbounded.

    Args:
        state_size: Number of features in the flat state vector.
        action_size: Number of discrete actions (one output per action).
        hidden_size: Width of both hidden layers. Defaults to 24, the value
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, state_size, action_size, hidden_size=24):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Return Q-values of shape (..., action_size) for x of shape (..., state_size)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Experience replay buffer.
class ReplayMemory:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Backed by ``deque(maxlen=capacity)``: once full, each new push silently
    evicts the oldest stored transition.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        """Store one transition (any object; the agent pushes 5-tuples)."""
        buffer = self.memory
        buffer.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` transitions drawn without replacement.

        Raises:
            ValueError: (from ``random.sample``) if fewer than ``batch_size``
                transitions are stored.
        """
        population = self.memory
        return random.sample(population, k=batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        stored = len(self.memory)
        return stored


# DQN agent: epsilon-greedy policy network plus a periodically synced target network.
class DQNAgent:
    """Deep Q-Network agent with experience replay and a target network.

    Args:
        state_size: Length of the flat observation vector.
        action_size: Number of discrete actions.

    Notes:
        * ``epsilon`` is decayed once per ``replay()`` call that actually trains
          (i.e. per gradient step), not once per episode.
        * ``update_target_every`` is only stored here; the training loop decides
          when to call ``update_target_network()``.
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        self.memory = ReplayMemory(10000)
        self.gamma = 0.95  # discount factor
        self.epsilon = 1.0  # exploration rate for epsilon-greedy action selection
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995  # multiplicative decay applied per replay() step
        self.learning_rate = 0.001
        self.batch_size = 64

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_size, action_size).to(self.device)
        self.target_net = DQN(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.update_target_every = 10  # sync the target network every 10 episodes

    def select_action(self, state):
        """Choose an action epsilon-greedily.

        With probability about ``epsilon`` return a uniformly random action
        index; otherwise return the argmax of the policy network's Q-values
        for ``state``.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state_t = torch.as_tensor(
            np.asarray(state, dtype=np.float32), device=self.device
        ).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_t)
        return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        """Push one transition into replay memory.

        ``done`` should be True only for a genuine terminal state; it zeroes
        the bootstrap term in ``replay()``.
        """
        self.memory.push((state, action, reward, next_state, done))

    def replay(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing (returns None) until memory holds at least
        ``batch_size`` transitions. The TD target is
        ``r + gamma * max_a' Q_target(s', a') * (1 - done)``, computed without
        gradient tracking so only the policy network receives gradients.
        Decays ``epsilon`` after each step, down to ``epsilon_min``.
        """
        if len(self.memory) < self.batch_size:
            return

        minibatch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        def as_batch(values, dtype):
            # Stack into one ndarray first: building a tensor from a sequence
            # of ndarrays is slow (recent PyTorch versions warn about it).
            return torch.as_tensor(np.asarray(values, dtype=dtype), device=self.device)

        states = as_batch(states, np.float32)
        next_states = as_batch(next_states, np.float32)
        actions = as_batch(actions, np.int64).unsqueeze(1)
        rewards = as_batch(rewards, np.float32).unsqueeze(1)
        dones = as_batch(dones, np.float32).unsqueeze(1)

        # Q(s, a) for the actions actually taken, shape (batch, 1).
        current_q = self.policy_net(states).gather(1, actions)

        # The target is a constant for this update. Without no_grad, backward()
        # would also compute and accumulate .grad on target_net's parameters,
        # which are never optimized.
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0].unsqueeze(1)
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())


# Training loop: environment-API helpers and the train_dqn entry point.
def _reset_observation(reset_result):
    """Return the initial observation from ``env.reset()``.

    Newer Gym/Gymnasium return ``(observation, info)``; older Gym returns the
    observation alone.
    """
    if isinstance(reset_result, (tuple, list)):
        return reset_result[0]
    return reset_result


def _unpack_step(step_result):
    """Normalize ``env.step()`` output to ``(next_state, reward, terminated, truncated)``.

    Supports the old 4-tuple API ``(obs, reward, done, info)`` and the new
    5-tuple API ``(obs, reward, terminated, truncated, info)``. For the old
    API, a time-limit cut-off is read from ``info["TimeLimit.truncated"]``
    (set by old Gym's TimeLimit wrapper) so it is not reported as termination.

    Raises:
        ValueError: if ``step_result`` has neither 4 nor 5 elements.
    """
    if len(step_result) == 4:
        # Old Gym API.
        next_state, reward, done, info = step_result
        truncated = bool(done) and isinstance(info, dict) and bool(
            info.get("TimeLimit.truncated", False)
        )
        return next_state, reward, bool(done) and not truncated, truncated
    if len(step_result) == 5:
        # New Gym / Gymnasium API.
        next_state, reward, terminated, truncated, _ = step_result
        return next_state, reward, bool(terminated), bool(truncated)
    raise ValueError("Unexpected number of values returned by env.step()")


def train_dqn(episodes):
    """Train a DQN agent on CartPole-v1 for ``episodes`` episodes.

    Prints one progress line per episode (total reward and current epsilon)
    and syncs the target network every ``agent.update_target_every`` episodes.
    The environment is closed even if training raises.

    Args:
        episodes: Number of episodes to run; episodes are numbered from 1.
    """
    env = gym.make('CartPole-v1')
    try:
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n
        agent = DQNAgent(state_size, action_size)

        for e in range(1, episodes + 1):
            state = _reset_observation(env.reset())
            total_reward = 0
            done = False
            while not done:
                action = agent.select_action(state)
                next_state, reward, terminated, truncated = _unpack_step(env.step(action))
                done = terminated or truncated
                # Store only true termination as "done". A time-limit truncation
                # must still bootstrap from next_state in the TD target; otherwise
                # reaching the step limit is learned as having no future value.
                agent.remember(state, action, reward, next_state, terminated)
                state = next_state
                total_reward += reward
                agent.replay()

            if e % agent.update_target_every == 0:
                agent.update_target_network()
            print(f"回合 {e}/{episodes} - 奖励: {total_reward} - 探索率: {agent.epsilon:.4f}")
    finally:
        env.close()


# Script entry point: train for 500 episodes when executed directly
# (also true when run as a notebook cell).
if __name__ == "__main__":
    train_dqn(500)


回合 1/500 - 奖励: 32.0 - 探索率: 1.00
回合 2/500 - 奖励: 26.0 - 探索率: 1.00
回合 3/500 - 奖励: 26.0 - 探索率: 0.90
回合 4/500 - 奖励: 22.0 - 探索率: 0.81
回合 5/500 - 奖励: 13.0 - 探索率: 0.76
回合 6/500 - 奖励: 11.0 - 探索率: 0.71
回合 7/500 - 奖励: 19.0 - 探索率: 0.65
回合 8/500 - 奖励: 13.0 - 探索率: 0.61
回合 9/500 - 奖励: 13.0 - 探索率: 0.57
回合 10/500 - 奖励: 18.0 - 探索率: 0.52
回合 11/500 - 奖励: 11.0 - 探索率: 0.49
回合 12/500 - 奖励: 11.0 - 探索率: 0.47
回合 13/500 - 奖励: 18.0 - 探索率: 0.43
回合 14/500 - 奖励: 12.0 - 探索率: 0.40
回合 15/500 - 奖励: 13.0 - 探索率: 0.38
回合 16/500 - 奖励: 11.0 - 探索率: 0.36
回合 17/500 - 奖励: 13.0 - 探索率: 0.33
回合 18/500 - 奖励: 11.0 - 探索率: 0.32
回合 19/500 - 奖励: 10.0 - 探索率: 0.30
回合 20/500 - 奖励: 8.0 - 探索率: 0.29
回合 21/500 - 奖励: 13.0 - 探索率: 0.27
回合 22/500 - 奖励: 8.0 - 探索率: 0.26
回合 23/500 - 奖励: 10.0 - 探索率: 0.25
回合 24/500 - 奖励: 11.0 - 探索率: 0.23
回合 25/500 - 奖励: 12.0 - 探索率: 0.22
回合 26/500 - 奖励: 11.0 - 探索率: 0.21
回合 27/500 - 奖励: 10.0 - 探索率: 0.20
回合 28/500 - 奖励: 9.0 - 探索率: 0.19
回合 29/500 - 奖励: 10.0 - 探索率: 0.18
回合 30/500 - 奖励: 11.0 - 探索率: 0.17
回合 31/500 - 奖励: 14.0 -