In [15]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 定义Actor网络
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return self.softmax(x)

# 定义Critic网络
class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# 初始化Actor和Critic
actor = Actor(state_dim, action_dim)
critic = Critic(state_dim)
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

# 超参数
gamma = 0.99
num_episodes = 500

# 训练
for episode in range(num_episodes):
    state = env.reset()
    episode_reward = 0
    
    while True:
        # 添加调试信息
        if len(state) != state_dim:
            print(f"State length mismatch: expected {state_dim}, got {len(state)}")
            break
        
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_probs = actor(state_tensor)
        action = np.random.choice(action_dim, p=action_probs.detach().numpy()[0])
        
        next_state, reward, done, _ = env.step(action)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        
        # 计算目标和优势
        value = critic(state_tensor)
        next_value = critic(next_state_tensor)
        target = reward + gamma * next_value * (1 - int(done))
        advantage = target - value
        
        # 更新Critic
        critic_loss = advantage.pow(2).mean()
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()
        
        # 更新Actor
        log_prob = torch.log(action_probs[0, action])
        actor_loss = -log_prob * advantage.detach()
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()
        
        state = next_state
        episode_reward += reward
        
        if done:
            break
    
    if (episode + 1) % 10 == 0:
        print(f'Episode {episode + 1}, Reward: {episode_reward}')

env.close()








State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
Episode 10, Reward: 0
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
Episode 20, Reward: 0
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length mismatch: expected 4, got 2
State length 