In [71]:
import gym
import numpy as np
np.bool8 = np.bool_

In [72]:
# 创建环境
env = gym.make("Pong-v4")

In [73]:
# env.reset()
# state,_ = env.reset()
# done = False
# count = 0
# while not done:
# #     env.render()
#     action = int(np.random.choice([2,3]))
#     next_state, reward, done, truncated, _ = env.step(action)
#     print(action,reward)
#     count += 1
# print(count)

In [74]:
# env.close()

In [75]:
from torch import nn
import torch

In [76]:
def prepro(I):
    """将 210x160x3 uint8 帧预处理为 6400 (80x80) 1D float 向量"""
    I = I[35:195]  # 裁剪
    I = I[::2, ::2, 0]  # 下采样因子为 2
    I[I == 144] = 0  # 删除背景类型 1
    I[I == 109] = 0  # 删除背景类型 2
    I[I != 0] = 1  # 其他设置为 1
    return I.astype(np.float32).ravel()

In [77]:
class PolicyNet(nn.Module):
    
    def __init__(self,input_dim,output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim,200)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(200,output_dim)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self,state):
        ### n
        x = self.linear1(state)
        x = self.relu(x)
        x = self.linear2(x) # n
        x = self.softmax(x) # n
        return x

In [78]:
from torch.distributions import Categorical
import numpy as np
np.bool8 = np.bool_

from torch.optim import AdamW

In [92]:
class Agent:
    
    def __init__(self):
        self.policy_net = PolicyNet(6400,2)
        self.optimizer = AdamW(self.policy_net.parameters(),lr=1e-3)
    
    def sample_action(self,state):
        probs = self.policy_net(state) # 4
        if np.random.uniform() < 0.2:
            action = np.random.randint(0,2)
            return action + 2, torch.log(probs[action]+1e-8)
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item()+2,log_prob
    
    def update(self,rewards,log_probs):
        ### 一次游戏时间
        ret = []
        adding = 0
        for r in rewards[::-1]:
            if r != 0:
                adding = 0
            adding = adding * 0.99 + r
            ret.insert(0,adding)
        ret = torch.FloatTensor(ret)
        ret = ret - ret.mean()
        ret = ret / (ret.std()+1e-8)
        
        r_log_probs = []
        for r,log_prob in zip(ret,log_probs):
            r_log_probs.append(-r*log_prob)
        r_log_probs = torch.vstack(r_log_probs)
        
        loss = r_log_probs.sum()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss


In [102]:
def train(agent,env):
    success_count = []
    max_size = 2000
    for epoch in range(20000):
        rewards = []
        log_probs = []
        terminated = False
        state,_ = env.reset()
        prev_x = None
        while not terminated:
            x = prepro(state)
            diff = np.zeros(6400) if prev_x is None else x - prev_x
            prev_x = x
            diff = torch.FloatTensor(diff)
            action, log_prob = agent.sample_action(diff)
            next_state, reward, terminated, truncated, _ = env.step(action)
            state = next_state
            rewards.append(reward)
            log_probs.append(log_prob)
        
        loss = agent.update(rewards,log_probs) 
        
        
        if (epoch+1) % 10 == 0:
#             torch.save('pong.pt',agent.policy_net)
            torch.save(agent.policy_net,'pong.pt')
            print(f'epoch: {epoch}, loss: {loss}, rewards: {sum(rewards)}, count: {len(rewards)}')

In [81]:
agent = Agent()

In [100]:
# torch.save(agent.policy_net,'pong.pt')

In [101]:
# torch.load('pong.pt')

PolicyNet(
  (linear1): Linear(in_features=6400, out_features=200, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=200, out_features=2, bias=True)
  (softmax): Softmax(dim=-1)
)

In [98]:
env = gym.make("Pong-v4")
train(agent,env)

epoch: 9, loss: 100.39936828613281, rewards: -19.0, count: 1851
epoch: 19, loss: 99.27107238769531, rewards: -21.0, count: 2740
epoch: 29, loss: 97.95178985595703, rewards: -19.0, count: 2083
epoch: 39, loss: 81.42912292480469, rewards: -20.0, count: 1716
epoch: 49, loss: 80.0657958984375, rewards: -19.0, count: 1842
epoch: 59, loss: 69.82518005371094, rewards: -20.0, count: 2199
epoch: 69, loss: 80.46188354492188, rewards: -21.0, count: 1502
epoch: 79, loss: 81.82455444335938, rewards: -17.0, count: 2011
epoch: 89, loss: 46.32903289794922, rewards: -20.0, count: 2446
epoch: 99, loss: 112.66027069091797, rewards: -19.0, count: 2011
epoch: 109, loss: 110.7575912475586, rewards: -17.0, count: 2094
epoch: 119, loss: 63.838226318359375, rewards: -20.0, count: 2045
epoch: 129, loss: 43.82166290283203, rewards: -20.0, count: 1410
epoch: 139, loss: 73.86117553710938, rewards: -18.0, count: 2363
epoch: 149, loss: 42.780494689941406, rewards: -20.0, count: 1806
epoch: 159, loss: 78.479263305664

epoch: 1269, loss: 61.92820739746094, rewards: -19.0, count: 3370
epoch: 1279, loss: 63.04595947265625, rewards: -21.0, count: 1747
epoch: 1289, loss: 20.298131942749023, rewards: -20.0, count: 2264
epoch: 1299, loss: 62.66165542602539, rewards: -20.0, count: 1881
epoch: 1309, loss: 49.06715393066406, rewards: -18.0, count: 2439
epoch: 1319, loss: 27.15872573852539, rewards: -18.0, count: 2433
epoch: 1329, loss: 71.82215118408203, rewards: -20.0, count: 2925
epoch: 1339, loss: 26.968229293823242, rewards: -18.0, count: 2211
epoch: 1349, loss: 22.21343231201172, rewards: -20.0, count: 2035
epoch: 1359, loss: 18.843164443969727, rewards: -20.0, count: 2026
epoch: 1369, loss: 47.34242248535156, rewards: -18.0, count: 2355
epoch: 1379, loss: 23.19776725769043, rewards: -19.0, count: 2086
epoch: 1389, loss: 61.0006103515625, rewards: -18.0, count: 2916
epoch: 1399, loss: 28.574010848999023, rewards: -18.0, count: 2059
epoch: 1409, loss: 14.454166412353516, rewards: -18.0, count: 2778
epoch:

epoch: 2509, loss: 36.82316589355469, rewards: -19.0, count: 2644
epoch: 2519, loss: 18.072111129760742, rewards: -18.0, count: 2311
epoch: 2529, loss: 9.523689270019531, rewards: -18.0, count: 1738
epoch: 2539, loss: 33.95619201660156, rewards: -18.0, count: 2465
epoch: 2549, loss: 15.05644416809082, rewards: -19.0, count: 2020
epoch: 2559, loss: 27.503055572509766, rewards: -21.0, count: 1504
epoch: 2569, loss: 23.564620971679688, rewards: -18.0, count: 1980
epoch: 2579, loss: 2.8135719299316406, rewards: -19.0, count: 2089
epoch: 2589, loss: 20.001726150512695, rewards: -20.0, count: 2196
epoch: 2599, loss: 28.566858291625977, rewards: -19.0, count: 2481
epoch: 2609, loss: 6.392823219299316, rewards: -19.0, count: 1913
epoch: 2619, loss: 12.390830993652344, rewards: -19.0, count: 2017
epoch: 2629, loss: 26.430553436279297, rewards: -16.0, count: 2548
epoch: 2639, loss: 35.101131439208984, rewards: -14.0, count: 2569
epoch: 2649, loss: -2.57442045211792, rewards: -19.0, count: 2165
e

epoch: 3739, loss: 34.86646270751953, rewards: -20.0, count: 2910
epoch: 3749, loss: 40.12380599975586, rewards: -18.0, count: 2600
epoch: 3759, loss: 1.619724154472351, rewards: -19.0, count: 3876
epoch: 3769, loss: 30.28267478942871, rewards: -19.0, count: 2964
epoch: 3779, loss: -36.03532791137695, rewards: -17.0, count: 4134
epoch: 3789, loss: 2.9764137268066406, rewards: -18.0, count: 3689
epoch: 3799, loss: 23.281028747558594, rewards: -19.0, count: 2324
epoch: 3809, loss: 50.944889068603516, rewards: -17.0, count: 4035
epoch: 3819, loss: 5.531700134277344, rewards: -14.0, count: 3927
epoch: 3829, loss: 23.434860229492188, rewards: -20.0, count: 2863
epoch: 3839, loss: 5.27295446395874, rewards: -18.0, count: 2842
epoch: 3849, loss: -0.329531192779541, rewards: -19.0, count: 3053
epoch: 3859, loss: 18.55010223388672, rewards: -20.0, count: 1798
epoch: 3869, loss: 17.076791763305664, rewards: -20.0, count: 2867
epoch: 3879, loss: 14.21294116973877, rewards: -17.0, count: 2798
epoc

epoch: 4969, loss: -46.909061431884766, rewards: -17.0, count: 4125
epoch: 4979, loss: 1.9770617485046387, rewards: -18.0, count: 2208
epoch: 4989, loss: -19.459033966064453, rewards: -16.0, count: 3443
epoch: 4999, loss: -24.501022338867188, rewards: -15.0, count: 3138
epoch: 5009, loss: -2.1939568519592285, rewards: -14.0, count: 3923
epoch: 5019, loss: 0.3761183023452759, rewards: -20.0, count: 2500
epoch: 5029, loss: 3.325768232345581, rewards: -7.0, count: 5052
epoch: 5039, loss: -27.791160583496094, rewards: -13.0, count: 5128
epoch: 5049, loss: 35.28559112548828, rewards: -16.0, count: 3453
epoch: 5059, loss: -12.46896743774414, rewards: -9.0, count: 4402
epoch: 5069, loss: -0.4447190761566162, rewards: -16.0, count: 3099
epoch: 5079, loss: -10.230735778808594, rewards: -17.0, count: 3703
epoch: 5089, loss: 6.249671936035156, rewards: -16.0, count: 3670
epoch: 5099, loss: 4.142050266265869, rewards: -15.0, count: 3612
epoch: 5109, loss: -0.892456591129303, rewards: -18.0, count:

KeyboardInterrupt: 

In [90]:
def sample_action(self,state):
    probs = self.policy_net(state) # 4
    if np.random.uniform() < 0.0:
        action = np.random.randint(0,2)
        return action + 2, torch.log(probs[action]+1e-8)
    dist = Categorical(probs)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action.item()+2,log_prob

# 替换方法
import types
agent.sample_action = types.MethodType(sample_action, agent)

In [None]:
import time
def visualize_agent(env, agent, num_episodes=5):
    """
    渲染显示智能体的行动
    """
    env = gym.make('CliffWalking-v0', render_mode='human')  # 创建可视化环境
    
    for episode in range(num_episodes):
        state_tuple = env.reset()
        state = state_tuple[0] if isinstance(state_tuple, tuple) else state_tuple
        total_reward = 0
        steps = 0
        done = False
        
        print(f"\nEpisode {episode + 1}")
        
        while not done:
            env.render()  # 渲染当前状态
            
            # 将状态转换为one-hot编码
            state_onehot = np.zeros(48)
            state_onehot[state] = 1
            
            # 使用训练好的策略选择动作
            with torch.no_grad():
                if np.random.random() < 0.0:
                    action = np.random.randint(0, 4)
                else:
                    state_tensor = torch.FloatTensor(state_onehot)
                    probs = agent.policy_net(state_tensor)
                    action = probs.argmax().item()  # 使用最可能的动作
            
            # 执行动作
            step_result = env.step(action)
            if len(step_result) == 4:
                next_state, reward, done, _ = step_result
            else:
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            
            total_reward += reward
            steps += 1
            state = next_state
            
            # 添加小延迟使动作更容易观察
            time.sleep(0.5)
        
        print(f"Episode finished after {steps} steps. Total reward: {total_reward}")
    
    env.close()

# 在主程序最后添加：
if __name__ == "__main__":    
    # 训练完成后显示智能体行动
    print("\nVisualizing trained agent behavior...")
    env = gym.make('CliffWalking-v0',render_mode='human')
    visualize_agent(env, agent)

In [None]:
env.close()