In [1]:
import gym
import numpy as np

In [2]:
# 创建环境
env = gym.make('CliffWalking-v0',render_mode='human')

In [3]:
from torch import nn
import torch

In [4]:
class PolicyNet(nn.Module):
    
    def __init__(self,input_dim,output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim,200)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(200,output_dim)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self,state):
        ### n
        x = self.linear1(state)
        x = self.relu(x)
        x = self.linear2(x) # n
        x = self.softmax(x) # n
        return x

In [5]:
class ValueNet(nn.Module):
    ### 用来学习值
    
    def __init__(self,input_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim,200)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(200,1)
    
    def forward(self,state):
        ### n
        x = self.linear1(state)
        x = self.relu(x)
        x = self.linear2(x) # n
        return x

In [6]:
from torch.distributions import Categorical
import numpy as np
np.bool8 = np.bool_

from torch.optim import AdamW

In [7]:
class Agent:
    
    def __init__(self):
        self.policy_net = PolicyNet(48,4)
        self.value_net = ValueNet(48)
        self.optimizer = AdamW(self.policy_net.parameters(),lr=1e-3)
        self.value_optimizer = AdamW(self.value_net.parameters(),lr=1e-3)
    
    def sample_action(self,state):
        probs = self.policy_net(state) # 4
        if np.random.uniform() < 0.5:
            action = np.random.randint(0,4)
            return action, torch.log(probs[action]+1e-8).detach()
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action).detach()
        return action.item(),log_prob
    
    def update(self,rewards,log_probs,xs,old_actions):
        ### 一次游戏时间
        ret = []
        adding = 0
        for r in rewards[::-1]:
            adding = adding * 0.99 + r
            ret.insert(0,adding)
        ret = torch.FloatTensor(ret)
        ret = ret - ret.mean()
        ret = ret / (ret.std()+1e-8)
        
        
        
        for _ in range(4):
            values = self.value_net(xs) ## B,48 ==> B,1
#             print(values.shape)

            new_probs = self.policy_net(xs) ## B,4
            dist = Categorical(new_probs) # B,4
#             actions = dist.sample() # B,1
            new_logprobs = dist.log_prob(old_actions.squeeze(-1)) # B,1
#             print('new log probs',new_logprobs.shape)
            advantages = ret - values.squeeze(-1).detach() # B
#             print('adv shape',advantages.shape,'ret shape',ret.shape,'value shape',values.shape,'log',log_probs.shape)
#             print('new_probs',new_probs.shape,'new_logprobs shape',new_logprobs.shape,'old_actions',old_actions.shape)
            ratio = torch.exp(new_logprobs - log_probs.squeeze(-1)) ## B
#             print('shape',ratio.shape)
            surr1 = ratio * advantages.squeeze(-1)
            surr2 = torch.clamp(ratio,0.8,1.2) * advantages
            loss = -torch.min(surr1,surr2).mean()


            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


            value_loss = (values.squeeze(0) - ret).pow(2).mean()
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()


    #         r_log_probs = []
    #         for r,log_prob in zip(ret,log_probs):
    #             r_log_probs.append(-r*log_prob)
    #         r_log_probs = torch.vstack(r_log_probs)

    #         loss = r_log_probs.sum()



            return loss

In [8]:
def convert2tensor(state):
    state_arr = np.zeros(48)
    state_arr[state] = 1
    state_arr = torch.FloatTensor(state_arr)
    return state_arr

In [9]:
def train(agent,env):
    success_count = []
    max_size = 1000
    for epoch in range(200000):
        rewards = []
        xs = []
        log_probs = []
        old_actions = []
        terminated = False
        success = True
        state,_ = env.reset()
        while not terminated and len(log_probs) < max_size:
            state_arr = convert2tensor(state)
            xs.append(state_arr)
            action, log_prob = agent.sample_action(state_arr)
            next_state, reward, terminated, truncated, _ = env.step(action)
            if reward == -100:
                reward = -10
                success = False
            if reward == -1 and state == 47:
                reward = 10
            state = next_state
            rewards.append(reward)
            log_probs.append(log_prob)
            old_actions.append(action)
        xs = torch.vstack(xs)
        log_probs = torch.vstack(log_probs)
        old_actions = torch.LongTensor(old_actions)
        loss = agent.update(rewards,log_probs,xs,old_actions) 
        success_count.append(success)
        
        
        if (epoch+1) % 10 == 0:
            print(f'success rate:  {len([s for s in success_count[-100:] if s]) / 100}')
            print(f'epoch: {epoch}, loss: {loss}, rewards: {sum(rewards)}, count: {len(rewards)}')

In [10]:
agent = Agent()



In [15]:
env = gym.make('CliffWalking-v0')
train(agent,env)

success rate:  0.1
epoch: 9, loss: -0.0021670260466635227, rewards: -232, count: 232
success rate:  0.2
epoch: 19, loss: -4.142966281506233e-05, rewards: -121, count: 121
success rate:  0.3
epoch: 29, loss: -3.549348912201822e-05, rewards: -332, count: 332
success rate:  0.4
epoch: 39, loss: 0.043539464473724365, rewards: -221, count: 221
success rate:  0.5
epoch: 49, loss: -3.709710290422663e-05, rewards: -202, count: 202
success rate:  0.6
epoch: 59, loss: -0.004479533061385155, rewards: -22, count: 22
success rate:  0.7
epoch: 69, loss: -0.0008025957504287362, rewards: -413, count: 413
success rate:  0.8
epoch: 79, loss: -7.551795988547383e-07, rewards: -212, count: 212
success rate:  0.9
epoch: 89, loss: -3.5992036373500014e-07, rewards: -26, count: 26
success rate:  0.99
epoch: 99, loss: 6.308158049250778e-07, rewards: -48, count: 48
success rate:  0.99
epoch: 109, loss: -2.554484801464696e-08, rewards: -35, count: 35
success rate:  0.99
epoch: 119, loss: -8.195011469069868e-05, r

success rate:  0.97
epoch: 989, loss: 1.4131409443507437e-05, rewards: -44, count: 35
success rate:  0.97
epoch: 999, loss: -1.6732796211726964e-05, rewards: -37, count: 37
success rate:  0.96
epoch: 1009, loss: -1.5368064850918017e-05, rewards: -21, count: 21
success rate:  0.96
epoch: 1019, loss: 1.2699295439233538e-05, rewards: -17, count: 17
success rate:  0.95
epoch: 1029, loss: 1.3030387890466955e-05, rewards: -22, count: 22
success rate:  0.95
epoch: 1039, loss: 3.680762165458873e-05, rewards: -17, count: 17
success rate:  0.95
epoch: 1049, loss: 1.0227615348412655e-05, rewards: -22, count: 22
success rate:  0.94
epoch: 1059, loss: -3.333389759063721e-05, rewards: -20, count: 20
success rate:  0.94
epoch: 1069, loss: 0.7994296550750732, rewards: -19, count: 19
success rate:  0.94
epoch: 1079, loss: -1.8080075960824615e-06, rewards: -18, count: 18
success rate:  0.96
epoch: 1089, loss: 1.961986163223628e-05, rewards: -18, count: 18
success rate:  0.96
epoch: 1099, loss: 4.5790393

success rate:  0.99
epoch: 1969, loss: 1.4129806913842913e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 1979, loss: -1.2068187061231583e-05, rewards: -17, count: 17
success rate:  0.98
epoch: 1989, loss: -5.1372189773246646e-05, rewards: -17, count: 17
success rate:  0.96
epoch: 1999, loss: 0.03667224943637848, rewards: -38, count: 29
success rate:  0.95
epoch: 2009, loss: 3.4844175388570875e-05, rewards: -17, count: 17
success rate:  0.95
epoch: 2019, loss: 9.038869393407367e-06, rewards: -17, count: 17
success rate:  0.95
epoch: 2029, loss: -2.2600679585593753e-05, rewards: -17, count: 17
success rate:  0.96
epoch: 2039, loss: -6.472363111242885e-06, rewards: -17, count: 17
success rate:  0.96
epoch: 2049, loss: 5.5775923101464286e-05, rewards: -17, count: 17
success rate:  0.96
epoch: 2059, loss: 9.186127499560826e-06, rewards: -17, count: 17
success rate:  0.96
epoch: 2069, loss: -8.57605664350558e-06, rewards: -17, count: 17
success rate:  0.96
epoch: 2079, loss: -5.6308

success rate:  0.97
epoch: 2969, loss: 3.195510362274945e-05, rewards: -17, count: 17
success rate:  0.97
epoch: 2979, loss: -3.5545406717574224e-05, rewards: -17, count: 17
success rate:  0.97
epoch: 2989, loss: -0.00023607646289747208, rewards: -17, count: 17
success rate:  0.97
epoch: 2999, loss: -0.00024008049513213336, rewards: -17, count: 17
success rate:  0.97
epoch: 3009, loss: -9.400003182236105e-05, rewards: -17, count: 17
success rate:  0.96
epoch: 3019, loss: 4.100799560546875e-05, rewards: -17, count: 17
success rate:  0.97
epoch: 3029, loss: -0.00010611029574647546, rewards: -17, count: 17
success rate:  0.97
epoch: 3039, loss: 0.00015987534425221384, rewards: -19, count: 19
success rate:  0.98
epoch: 3049, loss: -0.00016066606622189283, rewards: -17, count: 17
success rate:  0.98
epoch: 3059, loss: -3.637987174442969e-05, rewards: -17, count: 17
success rate:  0.98
epoch: 3069, loss: 0.00016046271775849164, rewards: -17, count: 17
success rate:  0.97
epoch: 3079, loss: 6

success rate:  0.99
epoch: 3949, loss: 2.3057064026943408e-05, rewards: -18, count: 18
success rate:  0.99
epoch: 3959, loss: -2.7060508728027344e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 3969, loss: 4.377084769657813e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 3979, loss: -0.00014030933380126953, rewards: -17, count: 17
success rate:  0.98
epoch: 3989, loss: -0.0001513677416369319, rewards: -17, count: 17
success rate:  0.98
epoch: 3999, loss: -0.0002603601024020463, rewards: -17, count: 17
success rate:  0.98
epoch: 4009, loss: -0.00024473335361108184, rewards: -18, count: 18
success rate:  0.98
epoch: 4019, loss: 0.00011493878992041573, rewards: -17, count: 17
success rate:  0.98
epoch: 4029, loss: 9.954677079804242e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 4039, loss: 2.4907729311962612e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 4049, loss: 3.0293183954199776e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 4059, loss: -5.

success rate:  1.0
epoch: 4939, loss: -4.492787775234319e-05, rewards: -17, count: 17
success rate:  1.0
epoch: 4949, loss: 8.15812309156172e-05, rewards: -17, count: 17
success rate:  1.0
epoch: 4959, loss: 2.6737941880128346e-05, rewards: -17, count: 17
success rate:  1.0
epoch: 4969, loss: -1.7530388504383154e-05, rewards: -18, count: 18
success rate:  1.0
epoch: 4979, loss: 9.817235735454233e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 4989, loss: -5.329356440597621e-07, rewards: -17, count: 17
success rate:  1.0
epoch: 4999, loss: 3.814697265625e-06, rewards: -17, count: 17
success rate:  1.0
epoch: 5009, loss: -1.2649429663724732e-06, rewards: -18, count: 18
success rate:  0.99
epoch: 5019, loss: -1.430511474609375e-06, rewards: -17, count: 17
success rate:  0.99
epoch: 5029, loss: -2.0298693925724365e-06, rewards: -18, count: 18
success rate:  0.99
epoch: 5039, loss: -3.085417006332136e-07, rewards: -17, count: 17
success rate:  0.99
epoch: 5049, loss: -0.0, rewards: -

success rate:  0.99
epoch: 5939, loss: 4.3476327959979244e-07, rewards: -17, count: 17
success rate:  0.99
epoch: 5949, loss: 2.0077354179193208e-07, rewards: -19, count: 19
success rate:  0.99
epoch: 5959, loss: 5.6098489409350805e-08, rewards: -17, count: 17
success rate:  0.99
epoch: 5969, loss: 9.817235735454233e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 5979, loss: 1.4024622352337701e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 5989, loss: 1.472585324790998e-07, rewards: -17, count: 17
success rate:  1.0
epoch: 5999, loss: -4.207386794519152e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 6009, loss: -1.4024622352337701e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 6019, loss: 1.3323391101494053e-07, rewards: -17, count: 17
success rate:  1.0
epoch: 6029, loss: 7.020102543719986e-07, rewards: -18, count: 18
success rate:  1.0
epoch: 6039, loss: 4.9086178677271164e-08, rewards: -17, count: 17
success rate:  1.0
epoch: 6049, loss: 2.945170649581

success rate:  0.98
epoch: 6899, loss: -5.743082965636859e-06, rewards: -17, count: 17
success rate:  0.98
epoch: 6909, loss: -1.116359908337472e-05, rewards: -17, count: 17
success rate:  0.98
epoch: 6919, loss: 5.220112143433653e-06, rewards: -19, count: 19
success rate:  0.98
epoch: 6929, loss: 0.8668215274810791, rewards: -27, count: 18
success rate:  0.98
epoch: 6939, loss: -4.184246063232422e-05, rewards: -17, count: 17
success rate:  0.98
epoch: 6949, loss: -3.834331801044755e-05, rewards: -17, count: 17
success rate:  0.98
epoch: 6959, loss: 2.2579642973141745e-06, rewards: -17, count: 17
success rate:  0.98
epoch: 6969, loss: 2.154182038793806e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 6979, loss: 2.086162567138672e-05, rewards: -17, count: 17
success rate:  0.99
epoch: 6989, loss: 0.9196115136146545, rewards: -21, count: 21
success rate:  0.99
epoch: 6999, loss: 2.6015675302915042e-06, rewards: -17, count: 17
success rate:  0.99
epoch: 7009, loss: -1.269228278033

KeyboardInterrupt: 

In [14]:
def sample_action(self,state):
    probs = self.policy_net(state) # 4
    if np.random.uniform() < 0.01:
        action = np.random.randint(0,4)
        return action, torch.log(probs[action]+1e-8).detach()
    dist = Categorical(probs)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action.item(),log_prob.detach()

# 替换方法
import types
agent.sample_action = types.MethodType(sample_action, agent)

In [16]:
import time
def visualize_agent(env, agent, num_episodes=5):
    """
    渲染显示智能体的行动
    """
    env = gym.make('CliffWalking-v0', render_mode='human')  # 创建可视化环境
    
    for episode in range(num_episodes):
        state_tuple = env.reset()
        state = state_tuple[0] if isinstance(state_tuple, tuple) else state_tuple
        total_reward = 0
        steps = 0
        done = False
        
        print(f"\nEpisode {episode + 1}")
        
        while not done:
            env.render()  # 渲染当前状态
            
            # 将状态转换为one-hot编码
            state_onehot = np.zeros(48)
            state_onehot[state] = 1
            
            # 使用训练好的策略选择动作
            with torch.no_grad():
                if np.random.random() < 0.0:
                    action = np.random.randint(0, 4)
                else:
                    state_tensor = torch.FloatTensor(state_onehot)
                    probs = agent.policy_net(state_tensor)
                    action = probs.argmax().item()  # 使用最可能的动作
            
            # 执行动作
            step_result = env.step(action)
            if len(step_result) == 4:
                next_state, reward, done, _ = step_result
            else:
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            
            total_reward += reward
            steps += 1
            state = next_state
            
            # 添加小延迟使动作更容易观察
            time.sleep(0.5)
        
        print(f"Episode finished after {steps} steps. Total reward: {total_reward}")
    
    env.close()

# 在主程序最后添加：
if __name__ == "__main__":    
    # 训练完成后显示智能体行动
    print("\nVisualizing trained agent behavior...")
    env = gym.make('CliffWalking-v0',render_mode='human')
    visualize_agent(env, agent)


Visualizing trained agent behavior...


2025-04-27 13:52:09.865 python[67964:170339742] +[IMKClient subclass]: chose IMKClient_Modern
2025-04-27 13:52:09.865 python[67964:170339742] +[IMKInputSession subclass]: chose IMKInputSession_Modern



Episode 1
Episode finished after 17 steps. Total reward: -17

Episode 2


KeyboardInterrupt: 

In [None]:
env.close()