In [4]:
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
np.bool8 = np.bool_


env = gym.make("Pong-v4")
H = 200  # 隐藏层维度
D = 80 * 80  # 输入维度 (80x80)

class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(D, H)
        self.fc2 = nn.Linear(H, 1)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
        self.fc1.bias.data.zero_()
        self.fc2.bias.data.zero_()

    def forward(self, x):
        h = F.relu(self.fc1(x))
        logit = self.fc2(h)
        return h, torch.sigmoid(logit)

model = PolicyNet()
optimizer = optim.RMSprop(model.parameters(), lr=1e-3, alpha=0.99, eps=1e-5)

def prepro(I):
    """预处理帧图像"""
    I = I[35:195]
    I = I[::2, ::2, 0]
    I[I == 144] = 0
    I[I == 109] = 0
    I[I != 0] = 1
    return I.astype(np.float32).ravel()

def discounted_rewards(rewards):
    discounted = np.zeros_like(rewards, dtype=np.float32)
    running_add = 0
    for t in reversed(range(len(rewards))):
        if rewards[t] != 0: running_add = 0
        running_add = running_add * 0.99 + rewards[t]
        discounted[t] = running_add
    return (discounted - discounted.mean()) / (discounted.std() + 1e-8)

# 训练循环
observation, _ = env.reset()
prev_x = None
running_reward = None
batch_size = 10

xs, hs, logps, rewards = [], [], [], []
episode_number = 0

while True:
    current_x = prepro(observation)
    x = current_x - prev_x if prev_x is not None else np.zeros(D)
    x_tensor = torch.FloatTensor(x)
    
    # 前向传播
    with torch.no_grad():
        h_tensor, prob = model(x_tensor)
    
    # 采样动作
    action = 2 if np.random.rand() > prob.item() else 3
    y = 1 if action == 3 else 0
    
    # 需要重新计算梯度跟踪的版本
    h_tensor.requires_grad_()
    _, prob = model(x_tensor)
    log_prob = y * torch.log(prob) + (1 - y) * torch.log(1 - prob)
    
    # 存储轨迹信息
    xs.append(x_tensor)
    hs.append(h_tensor)
    logps.append(log_prob)
    
    # 执行动作
    observation, reward, terminated, truncated, _ = env.step(action)
    rewards.append(reward)
    
    if terminated or truncated:
        episode_number += 1
        
        # 计算折扣回报
        dr_tensor = torch.FloatTensor(discounted_rewards(np.array(rewards)))
        
        # 计算损失
        logps_tensor = torch.stack(logps)
        loss = -torch.sum(logps_tensor * dr_tensor)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 更新运行奖励
        reward_sum = sum(rewards)
        running_reward = reward_sum if running_reward is None else running_reward*0.99 + reward_sum*0.01
        print(f"Episode {episode_number} reward: {reward_sum}, running: {running_reward:.2f}")
        
        # 重置环境
        xs, hs, logps, rewards = [], [], [], []
        observation, _ = env.reset()
        prev_x = None

Episode 1 reward: -20.0, running: -20.00
Episode 2 reward: -21.0, running: -20.01
Episode 3 reward: -21.0, running: -20.02
Episode 4 reward: -21.0, running: -20.03
Episode 5 reward: -20.0, running: -20.03
Episode 6 reward: -21.0, running: -20.04
Episode 7 reward: -21.0, running: -20.05
Episode 8 reward: -21.0, running: -20.06
Episode 9 reward: -19.0, running: -20.05
Episode 10 reward: -21.0, running: -20.06
Episode 11 reward: -21.0, running: -20.07
Episode 12 reward: -21.0, running: -20.08
Episode 13 reward: -20.0, running: -20.08
Episode 14 reward: -21.0, running: -20.08
Episode 15 reward: -20.0, running: -20.08
Episode 16 reward: -21.0, running: -20.09
Episode 17 reward: -19.0, running: -20.08
Episode 18 reward: -17.0, running: -20.05
Episode 19 reward: -21.0, running: -20.06
Episode 20 reward: -21.0, running: -20.07
Episode 21 reward: -20.0, running: -20.07
Episode 22 reward: -21.0, running: -20.08
Episode 23 reward: -20.0, running: -20.08
Episode 24 reward: -21.0, running: -20.09
E

Episode 195 reward: -21.0, running: -20.30
Episode 196 reward: -20.0, running: -20.30
Episode 197 reward: -21.0, running: -20.31
Episode 198 reward: -21.0, running: -20.31
Episode 199 reward: -21.0, running: -20.32
Episode 200 reward: -21.0, running: -20.33
Episode 201 reward: -21.0, running: -20.33
Episode 202 reward: -21.0, running: -20.34
Episode 203 reward: -21.0, running: -20.35
Episode 204 reward: -20.0, running: -20.34
Episode 205 reward: -21.0, running: -20.35
Episode 206 reward: -21.0, running: -20.36
Episode 207 reward: -19.0, running: -20.34
Episode 208 reward: -21.0, running: -20.35
Episode 209 reward: -21.0, running: -20.36
Episode 210 reward: -21.0, running: -20.36
Episode 211 reward: -19.0, running: -20.35
Episode 212 reward: -21.0, running: -20.36
Episode 213 reward: -18.0, running: -20.33
Episode 214 reward: -20.0, running: -20.33
Episode 215 reward: -21.0, running: -20.34
Episode 216 reward: -19.0, running: -20.32
Episode 217 reward: -20.0, running: -20.32
Episode 218

Episode 386 reward: -21.0, running: -20.39
Episode 387 reward: -20.0, running: -20.38
Episode 388 reward: -21.0, running: -20.39
Episode 389 reward: -21.0, running: -20.39
Episode 390 reward: -20.0, running: -20.39
Episode 391 reward: -21.0, running: -20.40
Episode 392 reward: -20.0, running: -20.39
Episode 393 reward: -21.0, running: -20.40
Episode 394 reward: -18.0, running: -20.37
Episode 395 reward: -21.0, running: -20.38
Episode 396 reward: -20.0, running: -20.38
Episode 397 reward: -21.0, running: -20.38
Episode 398 reward: -21.0, running: -20.39
Episode 399 reward: -21.0, running: -20.40
Episode 400 reward: -17.0, running: -20.36
Episode 401 reward: -19.0, running: -20.35
Episode 402 reward: -20.0, running: -20.34
Episode 403 reward: -19.0, running: -20.33
Episode 404 reward: -21.0, running: -20.34
Episode 405 reward: -21.0, running: -20.34
Episode 406 reward: -21.0, running: -20.35
Episode 407 reward: -20.0, running: -20.35
Episode 408 reward: -21.0, running: -20.35
Episode 409

Episode 577 reward: -19.0, running: -20.40
Episode 578 reward: -21.0, running: -20.40
Episode 579 reward: -21.0, running: -20.41
Episode 580 reward: -21.0, running: -20.41
Episode 581 reward: -18.0, running: -20.39
Episode 582 reward: -21.0, running: -20.40
Episode 583 reward: -21.0, running: -20.40
Episode 584 reward: -21.0, running: -20.41
Episode 585 reward: -21.0, running: -20.41
Episode 586 reward: -19.0, running: -20.40
Episode 587 reward: -21.0, running: -20.41
Episode 588 reward: -21.0, running: -20.41
Episode 589 reward: -21.0, running: -20.42
Episode 590 reward: -21.0, running: -20.42
Episode 591 reward: -21.0, running: -20.43
Episode 592 reward: -20.0, running: -20.42
Episode 593 reward: -20.0, running: -20.42
Episode 594 reward: -21.0, running: -20.43
Episode 595 reward: -21.0, running: -20.43
Episode 596 reward: -20.0, running: -20.43
Episode 597 reward: -21.0, running: -20.43
Episode 598 reward: -21.0, running: -20.44
Episode 599 reward: -21.0, running: -20.44
Episode 600

Episode 768 reward: -20.0, running: -20.37
Episode 769 reward: -21.0, running: -20.38
Episode 770 reward: -20.0, running: -20.38
Episode 771 reward: -21.0, running: -20.38
Episode 772 reward: -21.0, running: -20.39
Episode 773 reward: -18.0, running: -20.37
Episode 774 reward: -21.0, running: -20.37
Episode 775 reward: -21.0, running: -20.38
Episode 776 reward: -21.0, running: -20.38
Episode 777 reward: -21.0, running: -20.39
Episode 778 reward: -21.0, running: -20.40
Episode 779 reward: -21.0, running: -20.40
Episode 780 reward: -20.0, running: -20.40
Episode 781 reward: -19.0, running: -20.38
Episode 782 reward: -21.0, running: -20.39
Episode 783 reward: -20.0, running: -20.39
Episode 784 reward: -21.0, running: -20.39
Episode 785 reward: -20.0, running: -20.39
Episode 786 reward: -21.0, running: -20.40
Episode 787 reward: -21.0, running: -20.40
Episode 788 reward: -21.0, running: -20.41
Episode 789 reward: -21.0, running: -20.41
Episode 790 reward: -20.0, running: -20.41
Episode 791

Episode 959 reward: -21.0, running: -20.56
Episode 960 reward: -21.0, running: -20.57
Episode 961 reward: -21.0, running: -20.57
Episode 962 reward: -21.0, running: -20.58
Episode 963 reward: -21.0, running: -20.58
Episode 964 reward: -21.0, running: -20.59
Episode 965 reward: -20.0, running: -20.58
Episode 966 reward: -21.0, running: -20.58
Episode 967 reward: -21.0, running: -20.59
Episode 968 reward: -21.0, running: -20.59
Episode 969 reward: -21.0, running: -20.60
Episode 970 reward: -21.0, running: -20.60
Episode 971 reward: -21.0, running: -20.60
Episode 972 reward: -20.0, running: -20.60
Episode 973 reward: -21.0, running: -20.60
Episode 974 reward: -21.0, running: -20.61
Episode 975 reward: -20.0, running: -20.60
Episode 976 reward: -21.0, running: -20.60
Episode 977 reward: -21.0, running: -20.61
Episode 978 reward: -21.0, running: -20.61
Episode 979 reward: -21.0, running: -20.62
Episode 980 reward: -20.0, running: -20.61
Episode 981 reward: -20.0, running: -20.60
Episode 982

Episode 1147 reward: -18.0, running: -20.49
Episode 1148 reward: -21.0, running: -20.50
Episode 1149 reward: -21.0, running: -20.50
Episode 1150 reward: -21.0, running: -20.51
Episode 1151 reward: -21.0, running: -20.51
Episode 1152 reward: -21.0, running: -20.52
Episode 1153 reward: -20.0, running: -20.51
Episode 1154 reward: -21.0, running: -20.52
Episode 1155 reward: -21.0, running: -20.52
Episode 1156 reward: -21.0, running: -20.53
Episode 1157 reward: -20.0, running: -20.52
Episode 1158 reward: -21.0, running: -20.53
Episode 1159 reward: -21.0, running: -20.53
Episode 1160 reward: -21.0, running: -20.54
Episode 1161 reward: -21.0, running: -20.54
Episode 1162 reward: -20.0, running: -20.54
Episode 1163 reward: -21.0, running: -20.54
Episode 1164 reward: -21.0, running: -20.54
Episode 1165 reward: -20.0, running: -20.54
Episode 1166 reward: -21.0, running: -20.54
Episode 1167 reward: -20.0, running: -20.54
Episode 1168 reward: -17.0, running: -20.50
Episode 1169 reward: -20.0, runn

Episode 1334 reward: -21.0, running: -20.52
Episode 1335 reward: -21.0, running: -20.53
Episode 1336 reward: -21.0, running: -20.53
Episode 1337 reward: -21.0, running: -20.54
Episode 1338 reward: -21.0, running: -20.54
Episode 1339 reward: -21.0, running: -20.55
Episode 1340 reward: -21.0, running: -20.55
Episode 1341 reward: -21.0, running: -20.55
Episode 1342 reward: -19.0, running: -20.54
Episode 1343 reward: -21.0, running: -20.54
Episode 1344 reward: -20.0, running: -20.54
Episode 1345 reward: -21.0, running: -20.54
Episode 1346 reward: -18.0, running: -20.52
Episode 1347 reward: -18.0, running: -20.49
Episode 1348 reward: -21.0, running: -20.50
Episode 1349 reward: -19.0, running: -20.48
Episode 1350 reward: -21.0, running: -20.49
Episode 1351 reward: -21.0, running: -20.49
Episode 1352 reward: -21.0, running: -20.50
Episode 1353 reward: -19.0, running: -20.48
Episode 1354 reward: -21.0, running: -20.49
Episode 1355 reward: -21.0, running: -20.49
Episode 1356 reward: -21.0, runn

KeyboardInterrupt: 

In [5]:
loss

tensor(-0.0154, grad_fn=<NegBackward0>)