In [206]:
import numpy as np
import pickle  # 用 pickle 替代 cPickle
import gym
np.bool8 = np.bool_

In [207]:
import gym

In [208]:
env = gym.make("CliffWalking-v0")
observation, _ = env.reset()

In [209]:
# 0: move up

# 1: move right

# 2: move down

# 3: move left

In [210]:
observation,_

(36, {'prob': 1})

In [211]:
for i in range(4):
    observation, reward, terminated, truncated, info = env.step(0)  # 更新为新的返回值
    print(i,observation,reward, terminated,truncated)

0 24 -1 False False
1 12 -1 False False
2 0 -1 False False
3 0 -1 False False


In [212]:
from torch.distributions import Categorical

In [213]:
from torch import nn
import torch

In [214]:
class MLP(nn.Module):
    
    
    def __init__(self,output_dim):
        super().__init__()
        self.embeddings = nn.Embedding(48,200)
        self.linear1 = nn.Linear(200,200,bias=False)
        self.linear2 = nn.Linear(200,output_dim,bias=False)
        
       # Xavier 初始化 + 偏置归零
        nn.init.xavier_normal_(self.embeddings.weight)
        nn.init.xavier_normal_(self.linear1.weight)
        nn.init.xavier_normal_(self.linear2.weight)

        
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self,idxs):
        x = self.embeddings(idxs)
        x = self.relu(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return self.softmax(x)
    
    def sample_action(self,x):
#         with torch.no_grad():
                # 添加探索机制
        if np.random.random() < 0.6:
            action = np.random.randint(0, 4)
            probs = self.forward(x).squeeze()
#             print(probs.shape,action)
            return torch.LongTensor([action]), torch.log(probs[action]+1e-8).unsqueeze(dim=-1)
        
        prob = self.forward(x)
        dist = Categorical(prob)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action,log_prob

In [215]:
mm = MLP(4)

In [216]:
mm.sample_action(torch.LongTensor([1]))

(tensor([1]), tensor([-1.4070], grad_fn=<UnsqueezeBackward0>))

In [217]:
def discount_rewards(rewards):
    ans = np.zeros_like(rewards)
    adding = 0
    for t in reversed(range(len(rewards))):
        if rewards[t] != 0:
            adding = 0
        adding = adding * 0.99 + rewards[t]
        ans[t] = adding
    return ans

In [218]:
model = MLP(4)

In [219]:
from torch.optim import RMSprop
import torch

In [220]:
optimizer = RMSprop(model.parameters(), lr=0.001, alpha=0.99, eps=1e-5)

In [221]:
def get_base_reward(next_state, done):
    # 默认环境奖励
    if done:  # 到达终点
        return 10.0  # 原为0，调整为正向奖励
    elif is_cliff(next_state):  # 掉入悬崖
        return -100.0
    else:  # 普通移动
        return -1.0  # 保持原惩罚
def distance_reward(state):
    # 状态转换为坐标 (行, 列)
    row = state // 12
    col = state % 12
    # 终点坐标 (3, 11)
    target_row, target_col = 3, 11
    # 曼哈顿距离（横向距离优先）
    distance = (target_col - col) + 0.1 * abs(target_row - row)  # 横向距离权重更高
    return 0.5 * distance  # 系数可调
def safe_path_reward(state):
    row = state // 12
    # 若在第一行（索引0），给予额外奖励
    if row == 0:
        return 2.0  # 系数可调
    else:
        return 0.0
def safe_path_reward(state):
    row = state // 12
    # 若在第一行（索引0），给予额外奖励
    if row == 0:
        return 2.0  # 系数可调
    else:
        return 0.0
def progress_reward(state, prev_col):
    current_col = state % 12
    # 若向右移动，给予奖励
    if current_col > prev_col:
        return 1.0  # 系数可调
    else:
        return -0.5  # 向左或不动有轻微惩罚
    
def is_cliff(state: int) -> bool:
    """
    判断状态是否位于悬崖区域
    :param state: 当前状态编号（0到47之间的整数）
    :return: True（是悬崖） / False（非悬崖）
    """
    # 将状态转换为网格坐标（行、列）
    row = state // 12  # 行索引（0到3）
    col = state % 12   # 列索引（0到11）
    
    # 悬崖区域：第3行（索引2）的第1到10列
    return (row == 2) and (1 <= col <= 10)

class CustomCliffWalkingReward:
    def __init__(self):
        self.prev_col = 0  # 记录上一步的列坐标

    def reset(self):
        self.prev_col = 0

    def __call__(self, state, next_state, done):
        # 基础奖励
        base_reward = get_base_reward(next_state, done)
        
        # 距离奖励
        dist_r = distance_reward(next_state)
        
        # 安全路径奖励
        safe_r = safe_path_reward(next_state)
        
        # 进度奖励
        current_col = next_state % 12
        prog_r = progress_reward(next_state, self.prev_col)
        self.prev_col = current_col  # 更新列坐标
        
        # 合并奖励（加权求和）
        total_reward = base_reward + dist_r + safe_r + prog_r
        
        return total_reward

# 初始化奖励函数
reward_shaper = CustomCliffWalkingReward()

In [185]:
def sample_action(self,x):
#         with torch.no_grad():
            # 添加探索机制
    if np.random.random() < 0.1:
        action = np.random.randint(0, 4)
        probs = self.forward(x).squeeze()
#             print(probs.shape,action)
        return torch.LongTensor([action]), torch.log(probs[action]+1e-8).unsqueeze(dim=-1)

    prob = self.forward(x)
    dist = Categorical(prob)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action,log_prob

# 替换方法
import types
model.sample_action = types.MethodType(sample_action, model)

In [None]:
# env = CliffWalking2x3Env()
observation, _ = env.reset()
episode_number = 0
prev_x = None
xs = []
logps = []
hs = []
actions = []
ys = []
rewards = []
train_time = 0
total_sum = 0
max_times = 0
while True:
    prev_o = observation
    idxs = torch.LongTensor([observation]) # shape (1,)
    action,log_prob = model.sample_action(idxs) # shape 1,  1,
    
    action = action.item() ## 转成数
    logps.append(log_prob)
#     log_prob = log_prob.item()
    
    observation, reward, terminated, truncated, info = env.step(action)  # 更新为新的返回值
    
        # 修改奖励结构
    if reward == -100:  # 掉入悬崖
        reward = -10  # 减小惩罚力度，避免过度规避风险
    elif terminated and observation == 47:  # 到达目标状态（右下角）
        reward = 10  # 给予明显的正向奖励

#     reward = reward_shaper(prev_o,observation,terminated)
# #     print(action,reward)
#     if terminated:
#         reward = 1000
    
    rewards.append(reward*1.0)
    xs.append(prev_o)
    actions.append(action)
    total_sum += reward
    max_times += 1
    

    if terminated or truncated or max_times > 1000:
#         print(len(actions),actions[:30])
        episode_number += 1
        xs = np.vstack(xs).ravel()
        actions = np.vstack(actions)
        rewards = np.vstack(rewards).ravel()
        rewards = discount_rewards(rewards)
        rewards -= rewards.sum()
        rewards /= (rewards.std()+1e-8)

        rewards = torch.FloatTensor(rewards)
        xs = torch.LongTensor(xs)
        actions = torch.LongTensor(actions) # B,1
        
        ps = model(xs) ### B,4
        # 防止概率为0或1导致log计算爆炸
        logps = torch.cat(logps)
#         print(logps.shape,rewards.shape)
#         logps = torch.log(ps.gather(dim=-1,index=actions)).squeeze() # B,1
#         print('ps',ps.shape)
#         print('action',actions.shape)
#         print('rewards',rewards)
#         print('logps',logps)
        loss = -(logps * rewards).mean()
        #### 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         print('rewards',rewards)
#         print('logps',ps,logps)
        print('total',loss,len(logps),total_sum)
        
        xs = []
        rewards = []
        logps = []
        hs = []
        actions = []
        prev_x = None
        total_sum = 0
        max_times = 0
        observation, _ = env.reset()
        
        

total tensor(103.3958, grad_fn=<NegBackward0>) 101 -108
total tensor(1605.2445, grad_fn=<NegBackward0>) 1001 -1469
total tensor(1658.8826, grad_fn=<NegBackward0>) 1001 -1523
total tensor(652.4825, grad_fn=<NegBackward0>) 426 -631
total tensor(724.8765, grad_fn=<NegBackward0>) 461 -720
total tensor(458.5815, grad_fn=<NegBackward0>) 353 -459
total tensor(1476.6172, grad_fn=<NegBackward0>) 1001 -1442
total tensor(127.8369, grad_fn=<NegBackward0>) 114 -139
total tensor(1645.8586, grad_fn=<NegBackward0>) 1001 -1586
total tensor(697.9135, grad_fn=<NegBackward0>) 524 -657
total tensor(175.7951, grad_fn=<NegBackward0>) 134 -204
total tensor(1269.9437, grad_fn=<NegBackward0>) 947 -1287
total tensor(134.5563, grad_fn=<NegBackward0>) 134 -177
total tensor(1284.7568, grad_fn=<NegBackward0>) 1001 -1343
total tensor(461.4749, grad_fn=<NegBackward0>) 378 -511
total tensor(1281.5402, grad_fn=<NegBackward0>) 1001 -1451
total tensor(1201.9324, grad_fn=<NegBackward0>) 1001 -1460
total tensor(341.2062, gr

total tensor(977.2396, grad_fn=<NegBackward0>) 1001 -1946
total tensor(973.9130, grad_fn=<NegBackward0>) 1001 -1865
total tensor(973.0195, grad_fn=<NegBackward0>) 1001 -1748
total tensor(987.3819, grad_fn=<NegBackward0>) 1001 -1595
total tensor(976.3616, grad_fn=<NegBackward0>) 1001 -1820
total tensor(984.3308, grad_fn=<NegBackward0>) 1001 -2009
total tensor(985.9248, grad_fn=<NegBackward0>) 1001 -1604
total tensor(132.9812, grad_fn=<NegBackward0>) 151 -284
total tensor(1008.0211, grad_fn=<NegBackward0>) 1001 -1496
total tensor(975.6362, grad_fn=<NegBackward0>) 1001 -1802
total tensor(975.4971, grad_fn=<NegBackward0>) 1001 -1820
total tensor(999.3065, grad_fn=<NegBackward0>) 1001 -1532
total tensor(977.8429, grad_fn=<NegBackward0>) 1001 -1721
total tensor(975.1783, grad_fn=<NegBackward0>) 1001 -1811
total tensor(984.3331, grad_fn=<NegBackward0>) 1001 -1622
total tensor(988.5807, grad_fn=<NegBackward0>) 1001 -1568
total tensor(980.2633, grad_fn=<NegBackward0>) 1001 -1937
total tensor(99

total tensor(981.0837, grad_fn=<NegBackward0>) 1001 -2009
total tensor(976.1320, grad_fn=<NegBackward0>) 1001 -1793
total tensor(693.2255, grad_fn=<NegBackward0>) 729 -1240
total tensor(991.3957, grad_fn=<NegBackward0>) 1001 -1577
total tensor(972.1302, grad_fn=<NegBackward0>) 1001 -1802
total tensor(987.2001, grad_fn=<NegBackward0>) 1001 -1622
total tensor(975.0482, grad_fn=<NegBackward0>) 1001 -1811
total tensor(478.4750, grad_fn=<NegBackward0>) 495 -718
total tensor(972.6896, grad_fn=<NegBackward0>) 1001 -1793
total tensor(974.9578, grad_fn=<NegBackward0>) 1001 -1757
total tensor(974.8975, grad_fn=<NegBackward0>) 1001 -1910
total tensor(975.0233, grad_fn=<NegBackward0>) 1001 -1838
total tensor(974.8759, grad_fn=<NegBackward0>) 1001 -1793
total tensor(975.4305, grad_fn=<NegBackward0>) 1001 -1739
total tensor(974.9070, grad_fn=<NegBackward0>) 1001 -1883
total tensor(974.7273, grad_fn=<NegBackward0>) 1001 -1784
total tensor(60.9108, grad_fn=<NegBackward0>) 75 -73
total tensor(358.5017,

total tensor(976.5488, grad_fn=<NegBackward0>) 1001 -1712
total tensor(975.1663, grad_fn=<NegBackward0>) 1001 -1892
total tensor(975.4703, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.4194, grad_fn=<NegBackward0>) 1001 -1874
total tensor(980.9238, grad_fn=<NegBackward0>) 1001 -1667
total tensor(977.2264, grad_fn=<NegBackward0>) 1001 -1928
total tensor(977.2518, grad_fn=<NegBackward0>) 1001 -1955
total tensor(883.0905, grad_fn=<NegBackward0>) 924 -1597
total tensor(974.9518, grad_fn=<NegBackward0>) 1001 -1820
total tensor(976.3486, grad_fn=<NegBackward0>) 1001 -1937
total tensor(974.9873, grad_fn=<NegBackward0>) 1001 -1901
total tensor(974.9828, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.6449, grad_fn=<NegBackward0>) 1001 -1892
total tensor(979.7445, grad_fn=<NegBackward0>) 1001 -1964
total tensor(983.9647, grad_fn=<NegBackward0>) 1001 -1622
total tensor(976.7778, grad_fn=<NegBackward0>) 1001 -1730
total tensor(975.0523, grad_fn=<NegBackward0>) 1001 -1766
total tensor(97

total tensor(975.0208, grad_fn=<NegBackward0>) 1001 -1748
total tensor(837.1588, grad_fn=<NegBackward0>) 865 -1349
total tensor(975.1447, grad_fn=<NegBackward0>) 1001 -1883
total tensor(974.0826, grad_fn=<NegBackward0>) 1001 -1820
total tensor(976.2225, grad_fn=<NegBackward0>) 1001 -1793
total tensor(975.8296, grad_fn=<NegBackward0>) 1001 -1910
total tensor(974.6727, grad_fn=<NegBackward0>) 1001 -1856
total tensor(976.4041, grad_fn=<NegBackward0>) 1001 -1712
total tensor(849.2159, grad_fn=<NegBackward0>) 883 -1430
total tensor(975.1460, grad_fn=<NegBackward0>) 1001 -1892
total tensor(976.4741, grad_fn=<NegBackward0>) 1001 -1892
total tensor(982.2851, grad_fn=<NegBackward0>) 1001 -2027
total tensor(975.7116, grad_fn=<NegBackward0>) 1001 -1901
total tensor(975.3069, grad_fn=<NegBackward0>) 1001 -1748
total tensor(986.2908, grad_fn=<NegBackward0>) 1001 -2063
total tensor(974.7208, grad_fn=<NegBackward0>) 1001 -1883
total tensor(976.2936, grad_fn=<NegBackward0>) 1001 -1892
total tensor(990

total tensor(979.4631, grad_fn=<NegBackward0>) 1001 -1982
total tensor(974.9645, grad_fn=<NegBackward0>) 1001 -1820
total tensor(975.0120, grad_fn=<NegBackward0>) 1001 -1784
total tensor(980.4601, grad_fn=<NegBackward0>) 1001 -1667
total tensor(977.6461, grad_fn=<NegBackward0>) 1001 -1703
total tensor(974.7169, grad_fn=<NegBackward0>) 1001 -1784
total tensor(980.5951, grad_fn=<NegBackward0>) 1001 -1667
total tensor(973.9219, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.6354, grad_fn=<NegBackward0>) 1001 -1865
total tensor(976.2838, grad_fn=<NegBackward0>) 1001 -1730
total tensor(978.4722, grad_fn=<NegBackward0>) 1001 -1685
total tensor(941.9990, grad_fn=<NegBackward0>) 973 -1538
total tensor(975.8426, grad_fn=<NegBackward0>) 1001 -1730
total tensor(978.5364, grad_fn=<NegBackward0>) 1001 -1685
total tensor(985.2059, grad_fn=<NegBackward0>) 1001 -1613
total tensor(975.6689, grad_fn=<NegBackward0>) 1001 -1910
total tensor(627.5876, grad_fn=<NegBackward0>) 645 -1390
total tensor(83.

total tensor(981.2369, grad_fn=<NegBackward0>) 1001 -2009
total tensor(978.4897, grad_fn=<NegBackward0>) 1001 -1973
total tensor(975.1922, grad_fn=<NegBackward0>) 1001 -1757
total tensor(205.9604, grad_fn=<NegBackward0>) 227 -423
total tensor(974.3973, grad_fn=<NegBackward0>) 1001 -1838
total tensor(209.9116, grad_fn=<NegBackward0>) 232 -356
total tensor(974.6002, grad_fn=<NegBackward0>) 1001 -1829
total tensor(984.5716, grad_fn=<NegBackward0>) 1001 -1631
total tensor(576.6793, grad_fn=<NegBackward0>) 576 -790
total tensor(976.8958, grad_fn=<NegBackward0>) 1001 -1721
total tensor(977.8438, grad_fn=<NegBackward0>) 1001 -1955
total tensor(974.9584, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.1541, grad_fn=<NegBackward0>) 1001 -1829
total tensor(979.6033, grad_fn=<NegBackward0>) 1001 -1982
total tensor(141.9005, grad_fn=<NegBackward0>) 161 -222
total tensor(133.8169, grad_fn=<NegBackward0>) 112 -101
total tensor(469.2984, grad_fn=<NegBackward0>) 491 -741
total tensor(975.2177, gra

total tensor(977.5846, grad_fn=<NegBackward0>) 1001 -1955
total tensor(976.8849, grad_fn=<NegBackward0>) 1001 -1712
total tensor(983.6606, grad_fn=<NegBackward0>) 1001 -2054
total tensor(993.3972, grad_fn=<NegBackward0>) 1001 -2162
total tensor(977.6192, grad_fn=<NegBackward0>) 1001 -1946
total tensor(973.9790, grad_fn=<NegBackward0>) 1001 -1775
total tensor(976.0363, grad_fn=<NegBackward0>) 1001 -1919
total tensor(975.2503, grad_fn=<NegBackward0>) 1001 -1847
total tensor(977.0776, grad_fn=<NegBackward0>) 1001 -1703
total tensor(107.4451, grad_fn=<NegBackward0>) 127 -179
total tensor(985.2336, grad_fn=<NegBackward0>) 1001 -2063
total tensor(978.3121, grad_fn=<NegBackward0>) 1001 -1973
total tensor(150.7765, grad_fn=<NegBackward0>) 167 -336
total tensor(975.9071, grad_fn=<NegBackward0>) 1001 -1757
total tensor(974.0603, grad_fn=<NegBackward0>) 1001 -1802
total tensor(724.1808, grad_fn=<NegBackward0>) 750 -1540
total tensor(984.7465, grad_fn=<NegBackward0>) 1001 -1631
total tensor(974.05

total tensor(973.8304, grad_fn=<NegBackward0>) 1001 -1829
total tensor(976.4727, grad_fn=<NegBackward0>) 1001 -1937
total tensor(977.2985, grad_fn=<NegBackward0>) 1001 -1712
total tensor(989.4697, grad_fn=<NegBackward0>) 1001 -2126
total tensor(976.4214, grad_fn=<NegBackward0>) 1001 -1928
total tensor(977.8829, grad_fn=<NegBackward0>) 1001 -1703
total tensor(658.4183, grad_fn=<NegBackward0>) 693 -1267
total tensor(984.8436, grad_fn=<NegBackward0>) 1001 -2072
total tensor(979.2573, grad_fn=<NegBackward0>) 1001 -1676
total tensor(979.2869, grad_fn=<NegBackward0>) 1001 -2009
total tensor(77.4877, grad_fn=<NegBackward0>) 92 -99
total tensor(974.9527, grad_fn=<NegBackward0>) 1001 -1856
total tensor(197.6115, grad_fn=<NegBackward0>) 203 -246
total tensor(976.5048, grad_fn=<NegBackward0>) 1001 -1901
total tensor(974.3334, grad_fn=<NegBackward0>) 1001 -1865
total tensor(975.9086, grad_fn=<NegBackward0>) 1001 -1919
total tensor(988.7063, grad_fn=<NegBackward0>) 1001 -2117
total tensor(989.6548,

total tensor(986.2364, grad_fn=<NegBackward0>) 1001 -2081
total tensor(974.8481, grad_fn=<NegBackward0>) 1001 -1829
total tensor(844.2546, grad_fn=<NegBackward0>) 879 -1705
total tensor(216.1822, grad_fn=<NegBackward0>) 231 -490
total tensor(991.6679, grad_fn=<NegBackward0>) 1001 -2153
total tensor(974.1887, grad_fn=<NegBackward0>) 1001 -1793
total tensor(975.0900, grad_fn=<NegBackward0>) 1001 -1865
total tensor(976.8420, grad_fn=<NegBackward0>) 1001 -1946
total tensor(975.2905, grad_fn=<NegBackward0>) 1001 -1838
total tensor(974.9669, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.4711, grad_fn=<NegBackward0>) 1001 -1838
total tensor(975.0574, grad_fn=<NegBackward0>) 1001 -1883
total tensor(976.3647, grad_fn=<NegBackward0>) 1001 -1928
total tensor(974.4050, grad_fn=<NegBackward0>) 1001 -1865
total tensor(973.9450, grad_fn=<NegBackward0>) 1001 -1820
total tensor(975.6713, grad_fn=<NegBackward0>) 1001 -1730
total tensor(990.5371, grad_fn=<NegBackward0>) 1001 -1586
total tensor(289.

total tensor(976.9145, grad_fn=<NegBackward0>) 1001 -1955
total tensor(983.0966, grad_fn=<NegBackward0>) 1001 -1640
total tensor(149.2548, grad_fn=<NegBackward0>) 170 -285
total tensor(977.2067, grad_fn=<NegBackward0>) 1001 -1946
total tensor(685.7981, grad_fn=<NegBackward0>) 719 -1194
total tensor(975.9849, grad_fn=<NegBackward0>) 1001 -1874
total tensor(209.5515, grad_fn=<NegBackward0>) 230 -336
total tensor(982.2710, grad_fn=<NegBackward0>) 1001 -1649
total tensor(976.5936, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.4781, grad_fn=<NegBackward0>) 1001 -1865
total tensor(1000.8843, grad_fn=<NegBackward0>) 1001 -1532
total tensor(976.6391, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.9945, grad_fn=<NegBackward0>) 1001 -1874
total tensor(973.7883, grad_fn=<NegBackward0>) 1001 -1838
total tensor(975.9349, grad_fn=<NegBackward0>) 1001 -1739
total tensor(974.9415, grad_fn=<NegBackward0>) 1001 -1766
total tensor(975.4268, grad_fn=<NegBackward0>) 1001 -1892
total tensor(974.8

total tensor(974.6129, grad_fn=<NegBackward0>) 1001 -1883
total tensor(973.7726, grad_fn=<NegBackward0>) 1001 -1820
total tensor(977.0892, grad_fn=<NegBackward0>) 1001 -1937
total tensor(974.4709, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.3510, grad_fn=<NegBackward0>) 1001 -1856
total tensor(981.8597, grad_fn=<NegBackward0>) 1001 -1658
total tensor(226.0160, grad_fn=<NegBackward0>) 247 -461
total tensor(977.3580, grad_fn=<NegBackward0>) 1001 -1946
total tensor(978.0164, grad_fn=<NegBackward0>) 1001 -1964
total tensor(974.0057, grad_fn=<NegBackward0>) 1001 -1811
total tensor(977.0681, grad_fn=<NegBackward0>) 1001 -1712
total tensor(974.7874, grad_fn=<NegBackward0>) 1001 -1874
total tensor(981.3076, grad_fn=<NegBackward0>) 1001 -1658
total tensor(333.9966, grad_fn=<NegBackward0>) 359 -573
total tensor(944.0042, grad_fn=<NegBackward0>) 986 -1713
total tensor(974.2016, grad_fn=<NegBackward0>) 1001 -1829
total tensor(976.3815, grad_fn=<NegBackward0>) 1001 -1928
total tensor(319.31

total tensor(978.0169, grad_fn=<NegBackward0>) 1001 -1955
total tensor(980.6692, grad_fn=<NegBackward0>) 1001 -2009
total tensor(975.9548, grad_fn=<NegBackward0>) 1001 -1910
total tensor(975.8639, grad_fn=<NegBackward0>) 1001 -1910
total tensor(978.0705, grad_fn=<NegBackward0>) 1001 -1694
total tensor(974.8865, grad_fn=<NegBackward0>) 1001 -1775
total tensor(1002.7408, grad_fn=<NegBackward0>) 1001 -1523
total tensor(1009.3819, grad_fn=<NegBackward0>) 1001 -2315
total tensor(976.5641, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.7328, grad_fn=<NegBackward0>) 1001 -1766
total tensor(976.1358, grad_fn=<NegBackward0>) 1001 -1919
total tensor(988.7278, grad_fn=<NegBackward0>) 1001 -1595
total tensor(974.7169, grad_fn=<NegBackward0>) 1001 -1901
total tensor(976.6736, grad_fn=<NegBackward0>) 1001 -1712
total tensor(848.7216, grad_fn=<NegBackward0>) 886 -1460
total tensor(974.4285, grad_fn=<NegBackward0>) 1001 -1847
total tensor(671.2661, grad_fn=<NegBackward0>) 699 -1399
total tensor(9

total tensor(974.2341, grad_fn=<NegBackward0>) 1001 -1820
total tensor(975.3032, grad_fn=<NegBackward0>) 1001 -1892
total tensor(976.5331, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.8115, grad_fn=<NegBackward0>) 1001 -1784
total tensor(975.3189, grad_fn=<NegBackward0>) 1001 -1775
total tensor(325.0294, grad_fn=<NegBackward0>) 346 -686
total tensor(974.1905, grad_fn=<NegBackward0>) 1001 -1838
total tensor(976.7880, grad_fn=<NegBackward0>) 1001 -1730
total tensor(911.0659, grad_fn=<NegBackward0>) 951 -1759
total tensor(977.6891, grad_fn=<NegBackward0>) 1001 -1955
total tensor(735.5247, grad_fn=<NegBackward0>) 771 -1300
total tensor(932.0596, grad_fn=<NegBackward0>) 974 -1710
total tensor(998.8555, grad_fn=<NegBackward0>) 1001 -1541
total tensor(638.1539, grad_fn=<NegBackward0>) 672 -1210
total tensor(974.1826, grad_fn=<NegBackward0>) 1001 -1820
total tensor(450.7178, grad_fn=<NegBackward0>) 479 -792
total tensor(974.2368, grad_fn=<NegBackward0>) 1001 -1847
total tensor(974.4292,

total tensor(974.2776, grad_fn=<NegBackward0>) 1001 -1865
total tensor(977.6044, grad_fn=<NegBackward0>) 1001 -1955
total tensor(979.2006, grad_fn=<NegBackward0>) 1001 -1982
total tensor(976.7750, grad_fn=<NegBackward0>) 1001 -1928
total tensor(974.3539, grad_fn=<NegBackward0>) 1001 -1802
total tensor(536.1988, grad_fn=<NegBackward0>) 567 -961
total tensor(975.5746, grad_fn=<NegBackward0>) 1001 -1757
total tensor(977.1207, grad_fn=<NegBackward0>) 1001 -1712
total tensor(975.9207, grad_fn=<NegBackward0>) 1001 -1910
total tensor(982.2372, grad_fn=<NegBackward0>) 1001 -2027
total tensor(979.2790, grad_fn=<NegBackward0>) 1001 -1982
total tensor(979.8705, grad_fn=<NegBackward0>) 1001 -1676
total tensor(974.5577, grad_fn=<NegBackward0>) 1001 -1838
total tensor(985.7316, grad_fn=<NegBackward0>) 1001 -2081
total tensor(737.2989, grad_fn=<NegBackward0>) 768 -1225
total tensor(981.7966, grad_fn=<NegBackward0>) 1001 -1658
total tensor(975.3378, grad_fn=<NegBackward0>) 1001 -1892
total tensor(977.

total tensor(994.9597, grad_fn=<NegBackward0>) 1001 -2189
total tensor(983.3659, grad_fn=<NegBackward0>) 1001 -1640
total tensor(979.7520, grad_fn=<NegBackward0>) 1001 -1982
total tensor(977.3315, grad_fn=<NegBackward0>) 1001 -1712
total tensor(977.4808, grad_fn=<NegBackward0>) 1001 -1955
total tensor(978.3359, grad_fn=<NegBackward0>) 1001 -1973
total tensor(974.5192, grad_fn=<NegBackward0>) 1001 -1856
total tensor(990.0177, grad_fn=<NegBackward0>) 1001 -2135
total tensor(974.8780, grad_fn=<NegBackward0>) 1001 -1883
total tensor(210.8587, grad_fn=<NegBackward0>) 224 -294
total tensor(974.9436, grad_fn=<NegBackward0>) 1001 -1766
total tensor(976.8272, grad_fn=<NegBackward0>) 1001 -1937
total tensor(980.3527, grad_fn=<NegBackward0>) 1001 -1667
total tensor(974.6550, grad_fn=<NegBackward0>) 1001 -1784
total tensor(976.7455, grad_fn=<NegBackward0>) 1001 -1937
total tensor(549.9261, grad_fn=<NegBackward0>) 582 -1003
total tensor(983.4794, grad_fn=<NegBackward0>) 1001 -2045
total tensor(974.

total tensor(974.9290, grad_fn=<NegBackward0>) 1001 -1883
total tensor(974.6972, grad_fn=<NegBackward0>) 1001 -1784
total tensor(979.4136, grad_fn=<NegBackward0>) 1001 -1676
total tensor(978.2606, grad_fn=<NegBackward0>) 1001 -1694
total tensor(979.9327, grad_fn=<NegBackward0>) 1001 -1667
total tensor(430.2388, grad_fn=<NegBackward0>) 459 -763
total tensor(974.7838, grad_fn=<NegBackward0>) 1001 -1874
total tensor(654.0964, grad_fn=<NegBackward0>) 687 -1144
total tensor(980.6427, grad_fn=<NegBackward0>) 1001 -1667
total tensor(974.2829, grad_fn=<NegBackward0>) 1001 -1793
total tensor(662.5247, grad_fn=<NegBackward0>) 696 -1270
total tensor(977.7841, grad_fn=<NegBackward0>) 1001 -1955
total tensor(981.4272, grad_fn=<NegBackward0>) 1001 -2018
total tensor(974.3326, grad_fn=<NegBackward0>) 1001 -1829
total tensor(975.4811, grad_fn=<NegBackward0>) 1001 -1748
total tensor(974.4308, grad_fn=<NegBackward0>) 1001 -1856
total tensor(978.7085, grad_fn=<NegBackward0>) 1001 -1973
total tensor(975.0

total tensor(975.6302, grad_fn=<NegBackward0>) 1001 -1739
total tensor(974.4052, grad_fn=<NegBackward0>) 1001 -1802
total tensor(980.2552, grad_fn=<NegBackward0>) 1001 -1667
total tensor(974.4911, grad_fn=<NegBackward0>) 1001 -1829
total tensor(826.1423, grad_fn=<NegBackward0>) 848 -1800
total tensor(974.4615, grad_fn=<NegBackward0>) 1001 -1856
total tensor(978.0492, grad_fn=<NegBackward0>) 1001 -1964
total tensor(793.9210, grad_fn=<NegBackward0>) 832 -1478
total tensor(974.8621, grad_fn=<NegBackward0>) 1001 -1874
total tensor(667.8792, grad_fn=<NegBackward0>) 699 -1354
total tensor(263.0002, grad_fn=<NegBackward0>) 286 -527
total tensor(369.6044, grad_fn=<NegBackward0>) 392 -597
total tensor(978.3268, grad_fn=<NegBackward0>) 1001 -1703
total tensor(713.7887, grad_fn=<NegBackward0>) 748 -1412
total tensor(978.2076, grad_fn=<NegBackward0>) 1001 -1955
total tensor(974.7753, grad_fn=<NegBackward0>) 1001 -1793
total tensor(976.6934, grad_fn=<NegBackward0>) 1001 -1937
total tensor(974.0624,

total tensor(976.7530, grad_fn=<NegBackward0>) 1001 -1937
total tensor(976.8325, grad_fn=<NegBackward0>) 1001 -1928
total tensor(985.3356, grad_fn=<NegBackward0>) 1001 -1622
total tensor(976.9935, grad_fn=<NegBackward0>) 1001 -1946
total tensor(982.4730, grad_fn=<NegBackward0>) 1001 -2036
total tensor(73.1456, grad_fn=<NegBackward0>) 92 -126
total tensor(976.0344, grad_fn=<NegBackward0>) 1001 -1928
total tensor(876.8474, grad_fn=<NegBackward0>) 917 -1671
total tensor(983.0913, grad_fn=<NegBackward0>) 1001 -2045
total tensor(979.9297, grad_fn=<NegBackward0>) 1001 -1676
total tensor(974.0272, grad_fn=<NegBackward0>) 1001 -1811
total tensor(973.9305, grad_fn=<NegBackward0>) 1001 -1838
total tensor(974.1091, grad_fn=<NegBackward0>) 1001 -1865
total tensor(976.9570, grad_fn=<NegBackward0>) 1001 -1721
total tensor(974.3588, grad_fn=<NegBackward0>) 1001 -1874
total tensor(975.7707, grad_fn=<NegBackward0>) 1001 -1910
total tensor(980.5742, grad_fn=<NegBackward0>) 1001 -2009
total tensor(974.32

total tensor(974.6113, grad_fn=<NegBackward0>) 1001 -1874
total tensor(980.2644, grad_fn=<NegBackward0>) 1001 -1667
total tensor(986.0136, grad_fn=<NegBackward0>) 1001 -2081
total tensor(974.6980, grad_fn=<NegBackward0>) 1001 -1865
total tensor(992.6861, grad_fn=<NegBackward0>) 1001 -2162
total tensor(975.0271, grad_fn=<NegBackward0>) 1001 -1874
total tensor(974.3136, grad_fn=<NegBackward0>) 1001 -1811
total tensor(995.0325, grad_fn=<NegBackward0>) 1001 -1559
total tensor(974.2779, grad_fn=<NegBackward0>) 1001 -1847
total tensor(974.4172, grad_fn=<NegBackward0>) 1001 -1793
total tensor(975.0706, grad_fn=<NegBackward0>) 1001 -1766
total tensor(975.2942, grad_fn=<NegBackward0>) 1001 -1901
total tensor(974.2980, grad_fn=<NegBackward0>) 1001 -1847
total tensor(975.1070, grad_fn=<NegBackward0>) 1001 -1910
total tensor(635.3826, grad_fn=<NegBackward0>) 661 -1343
total tensor(980.7341, grad_fn=<NegBackward0>) 1001 -2000
total tensor(581.0482, grad_fn=<NegBackward0>) 614 -1071
total tensor(974

total tensor(991.7339, grad_fn=<NegBackward0>) 1001 -2144
total tensor(342.8744, grad_fn=<NegBackward0>) 364 -722
total tensor(976.1344, grad_fn=<NegBackward0>) 1001 -1730
total tensor(987.9449, grad_fn=<NegBackward0>) 1001 -1604
total tensor(430.8287, grad_fn=<NegBackward0>) 453 -910
total tensor(102.7647, grad_fn=<NegBackward0>) 120 -217
total tensor(979.2141, grad_fn=<NegBackward0>) 1001 -1982
total tensor(991.8651, grad_fn=<NegBackward0>) 1001 -2153
total tensor(104.5792, grad_fn=<NegBackward0>) 121 -227
total tensor(974.4177, grad_fn=<NegBackward0>) 1001 -1838
total tensor(978.7020, grad_fn=<NegBackward0>) 1001 -1964
total tensor(133.2635, grad_fn=<NegBackward0>) 146 -315
total tensor(167.4648, grad_fn=<NegBackward0>) 189 -313
total tensor(986.8841, grad_fn=<NegBackward0>) 1001 -2099
total tensor(976.0906, grad_fn=<NegBackward0>) 1001 -1919
total tensor(980.4308, grad_fn=<NegBackward0>) 1001 -1667
total tensor(611.5718, grad_fn=<NegBackward0>) 642 -1225
total tensor(974.2088, grad

total tensor(981.2666, grad_fn=<NegBackward0>) 1001 -2018
total tensor(982.3528, grad_fn=<NegBackward0>) 1001 -1649
total tensor(982.2272, grad_fn=<NegBackward0>) 1001 -2036
total tensor(974.0010, grad_fn=<NegBackward0>) 1001 -1829
total tensor(978.1907, grad_fn=<NegBackward0>) 1001 -1694
total tensor(982.4268, grad_fn=<NegBackward0>) 1001 -1640
total tensor(974.7354, grad_fn=<NegBackward0>) 1001 -1865
total tensor(983.9814, grad_fn=<NegBackward0>) 1001 -1631
total tensor(989.1295, grad_fn=<NegBackward0>) 1001 -2126
total tensor(999.5116, grad_fn=<NegBackward0>) 1001 -1541
total tensor(976.1058, grad_fn=<NegBackward0>) 1001 -1757
total tensor(974.2590, grad_fn=<NegBackward0>) 1001 -1856
total tensor(978.5023, grad_fn=<NegBackward0>) 1001 -1973
total tensor(976.7609, grad_fn=<NegBackward0>) 1001 -1730
total tensor(977.9810, grad_fn=<NegBackward0>) 1001 -1703
total tensor(976.7625, grad_fn=<NegBackward0>) 1001 -1928
total tensor(981.8130, grad_fn=<NegBackward0>) 1001 -1658
total tensor(9

total tensor(981.4545, grad_fn=<NegBackward0>) 1001 -2027
total tensor(980.9434, grad_fn=<NegBackward0>) 1001 -1658
total tensor(983.7161, grad_fn=<NegBackward0>) 1001 -2054
total tensor(975.2858, grad_fn=<NegBackward0>) 1001 -1748
total tensor(975.3078, grad_fn=<NegBackward0>) 1001 -1901
total tensor(977.9985, grad_fn=<NegBackward0>) 1001 -1964
total tensor(973.6697, grad_fn=<NegBackward0>) 1001 -1820
total tensor(973.9173, grad_fn=<NegBackward0>) 1001 -1811
total tensor(977.6071, grad_fn=<NegBackward0>) 1001 -1712
total tensor(973.5096, grad_fn=<NegBackward0>) 1001 -1829
total tensor(974.4500, grad_fn=<NegBackward0>) 1001 -1802
total tensor(979.8038, grad_fn=<NegBackward0>) 1001 -1982
total tensor(975.7181, grad_fn=<NegBackward0>) 1001 -1910
total tensor(977.1027, grad_fn=<NegBackward0>) 1001 -1946
total tensor(982.4630, grad_fn=<NegBackward0>) 1001 -2036
total tensor(974.0623, grad_fn=<NegBackward0>) 1001 -1838
total tensor(974.7773, grad_fn=<NegBackward0>) 1001 -1775
total tensor(5

total tensor(977.3301, grad_fn=<NegBackward0>) 1001 -1946
total tensor(977.6738, grad_fn=<NegBackward0>) 1001 -1964
total tensor(974.2640, grad_fn=<NegBackward0>) 1001 -1784
total tensor(977.3352, grad_fn=<NegBackward0>) 1001 -1703
total tensor(976.1165, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.4271, grad_fn=<NegBackward0>) 1001 -1784
total tensor(978.3850, grad_fn=<NegBackward0>) 1001 -1973
total tensor(974.3923, grad_fn=<NegBackward0>) 1001 -1775
total tensor(976.4251, grad_fn=<NegBackward0>) 1001 -1730
total tensor(975.6133, grad_fn=<NegBackward0>) 1001 -1910
total tensor(974.4064, grad_fn=<NegBackward0>) 1001 -1793
total tensor(984.4764, grad_fn=<NegBackward0>) 1001 -2063
total tensor(832.2698, grad_fn=<NegBackward0>) 867 -1414
total tensor(977.6482, grad_fn=<NegBackward0>) 1001 -1964
total tensor(981.9033, grad_fn=<NegBackward0>) 1001 -2036
total tensor(975.2747, grad_fn=<NegBackward0>) 1001 -1892
total tensor(976.5797, grad_fn=<NegBackward0>) 1001 -1919
total tensor(97

total tensor(977.9979, grad_fn=<NegBackward0>) 1001 -1964
total tensor(974.9653, grad_fn=<NegBackward0>) 1001 -1883
total tensor(608.0649, grad_fn=<NegBackward0>) 636 -1255
total tensor(977.0875, grad_fn=<NegBackward0>) 1001 -1712
total tensor(974.1613, grad_fn=<NegBackward0>) 1001 -1838
total tensor(974.2325, grad_fn=<NegBackward0>) 1001 -1829
total tensor(974.6204, grad_fn=<NegBackward0>) 1001 -1775
total tensor(977.8008, grad_fn=<NegBackward0>) 1001 -1955
total tensor(975.1584, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.2800, grad_fn=<NegBackward0>) 1001 -1847
total tensor(975.9272, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.3429, grad_fn=<NegBackward0>) 1001 -1811
total tensor(975.2010, grad_fn=<NegBackward0>) 1001 -1766
total tensor(977.2405, grad_fn=<NegBackward0>) 1001 -1946
total tensor(974.1415, grad_fn=<NegBackward0>) 1001 -1829
total tensor(975.9358, grad_fn=<NegBackward0>) 1001 -1910
total tensor(975.8380, grad_fn=<NegBackward0>) 1001 -1739
total tensor(97

total tensor(974.7687, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.4420, grad_fn=<NegBackward0>) 1001 -1793
total tensor(979.6025, grad_fn=<NegBackward0>) 1001 -1676
total tensor(920.9684, grad_fn=<NegBackward0>) 960 -1822
total tensor(975.0795, grad_fn=<NegBackward0>) 1001 -1892
total tensor(974.9317, grad_fn=<NegBackward0>) 1001 -1766
total tensor(978.3433, grad_fn=<NegBackward0>) 1001 -1694
total tensor(974.9117, grad_fn=<NegBackward0>) 1001 -1775
total tensor(976.7013, grad_fn=<NegBackward0>) 1001 -1937
total tensor(975.1951, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.1115, grad_fn=<NegBackward0>) 1001 -1838
total tensor(977.5349, grad_fn=<NegBackward0>) 1001 -1703
total tensor(978.2871, grad_fn=<NegBackward0>) 1001 -1964
total tensor(978.4794, grad_fn=<NegBackward0>) 1001 -1973
total tensor(978.8239, grad_fn=<NegBackward0>) 1001 -1973
total tensor(974.3488, grad_fn=<NegBackward0>) 1001 -1838
total tensor(240.7513, grad_fn=<NegBackward0>) 263 -405
total tensor(974.

total tensor(979.1516, grad_fn=<NegBackward0>) 1001 -1685
total tensor(981.1088, grad_fn=<NegBackward0>) 1001 -1658
total tensor(172.8691, grad_fn=<NegBackward0>) 194 -336
total tensor(976.1224, grad_fn=<NegBackward0>) 1001 -1730
total tensor(983.7316, grad_fn=<NegBackward0>) 1001 -2054
total tensor(976.8764, grad_fn=<NegBackward0>) 1001 -1721
total tensor(974.2094, grad_fn=<NegBackward0>) 1001 -1793
total tensor(980.2609, grad_fn=<NegBackward0>) 1001 -1667
total tensor(974.2126, grad_fn=<NegBackward0>) 1001 -1802
total tensor(975.7886, grad_fn=<NegBackward0>) 1001 -1739
total tensor(975.7404, grad_fn=<NegBackward0>) 1001 -1910
total tensor(974.3483, grad_fn=<NegBackward0>) 1001 -1784
total tensor(150.9241, grad_fn=<NegBackward0>) 169 -320
total tensor(985.0491, grad_fn=<NegBackward0>) 1001 -2072
total tensor(769.2098, grad_fn=<NegBackward0>) 806 -1371
total tensor(975.2900, grad_fn=<NegBackward0>) 1001 -1883
total tensor(977.8616, grad_fn=<NegBackward0>) 1001 -1955
total tensor(986.87

total tensor(984.0237, grad_fn=<NegBackward0>) 1001 -1631
total tensor(980.8704, grad_fn=<NegBackward0>) 1001 -1658
total tensor(974.9354, grad_fn=<NegBackward0>) 1001 -1874
total tensor(397.7044, grad_fn=<NegBackward0>) 422 -654
total tensor(992.2008, grad_fn=<NegBackward0>) 1001 -1577
total tensor(974.5824, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.3149, grad_fn=<NegBackward0>) 1001 -1829
total tensor(975.1594, grad_fn=<NegBackward0>) 1001 -1892
total tensor(989.3514, grad_fn=<NegBackward0>) 1001 -1595
total tensor(991.4651, grad_fn=<NegBackward0>) 1001 -2144
total tensor(974.4311, grad_fn=<NegBackward0>) 1001 -1865
total tensor(974.5623, grad_fn=<NegBackward0>) 1001 -1829
total tensor(975.9131, grad_fn=<NegBackward0>) 1001 -1730
total tensor(979.4014, grad_fn=<NegBackward0>) 1001 -1991
total tensor(974.5112, grad_fn=<NegBackward0>) 1001 -1775
total tensor(978.8544, grad_fn=<NegBackward0>) 1001 -1685
total tensor(983.8455, grad_fn=<NegBackward0>) 1001 -2054
total tensor(974

total tensor(976.7584, grad_fn=<NegBackward0>) 1001 -1937
total tensor(357.0034, grad_fn=<NegBackward0>) 383 -624
total tensor(975.0487, grad_fn=<NegBackward0>) 1001 -1757
total tensor(974.7656, grad_fn=<NegBackward0>) 1001 -1865
total tensor(332.1354, grad_fn=<NegBackward0>) 345 -478
total tensor(974.6782, grad_fn=<NegBackward0>) 1001 -1775
total tensor(975.0034, grad_fn=<NegBackward0>) 1001 -1757
total tensor(986.2988, grad_fn=<NegBackward0>) 1001 -2090
total tensor(150.9333, grad_fn=<NegBackward0>) 172 -278
total tensor(974.4857, grad_fn=<NegBackward0>) 1001 -1847
total tensor(129.4340, grad_fn=<NegBackward0>) 148 -200
total tensor(975.7068, grad_fn=<NegBackward0>) 1001 -1730
total tensor(979.3967, grad_fn=<NegBackward0>) 1001 -1676
total tensor(975.9044, grad_fn=<NegBackward0>) 1001 -1919
total tensor(978.0847, grad_fn=<NegBackward0>) 1001 -1694
total tensor(831.0959, grad_fn=<NegBackward0>) 870 -1579
total tensor(975.2425, grad_fn=<NegBackward0>) 1001 -1748
total tensor(974.8680, 

total tensor(974.1847, grad_fn=<NegBackward0>) 1001 -1793
total tensor(975.8866, grad_fn=<NegBackward0>) 1001 -1910
total tensor(978.1876, grad_fn=<NegBackward0>) 1001 -1964
total tensor(974.2975, grad_fn=<NegBackward0>) 1001 -1784
total tensor(977.8800, grad_fn=<NegBackward0>) 1001 -1955
total tensor(987.5123, grad_fn=<NegBackward0>) 1001 -1604
total tensor(979.0326, grad_fn=<NegBackward0>) 1001 -1982
total tensor(974.6148, grad_fn=<NegBackward0>) 1001 -1856
total tensor(974.7585, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.1737, grad_fn=<NegBackward0>) 1001 -1829
total tensor(979.5936, grad_fn=<NegBackward0>) 1001 -1991
total tensor(975.1941, grad_fn=<NegBackward0>) 1001 -1883
total tensor(976.1992, grad_fn=<NegBackward0>) 1001 -1919
total tensor(981.9045, grad_fn=<NegBackward0>) 1001 -1649
total tensor(974.3159, grad_fn=<NegBackward0>) 1001 -1811
total tensor(981.9385, grad_fn=<NegBackward0>) 1001 -2027
total tensor(976.9057, grad_fn=<NegBackward0>) 1001 -1937
total tensor(9

total tensor(996.8013, grad_fn=<NegBackward0>) 1001 -1550
total tensor(974.7377, grad_fn=<NegBackward0>) 1001 -1874
total tensor(976.7912, grad_fn=<NegBackward0>) 1001 -1937
total tensor(976.4578, grad_fn=<NegBackward0>) 1001 -1928
total tensor(987.2399, grad_fn=<NegBackward0>) 1001 -2099
total tensor(974.5392, grad_fn=<NegBackward0>) 1001 -1829
total tensor(981.2057, grad_fn=<NegBackward0>) 1001 -2018
total tensor(984.1705, grad_fn=<NegBackward0>) 1001 -1631
total tensor(538.1922, grad_fn=<NegBackward0>) 566 -1095
total tensor(974.4252, grad_fn=<NegBackward0>) 1001 -1811
total tensor(976.8923, grad_fn=<NegBackward0>) 1001 -1937
total tensor(976.6835, grad_fn=<NegBackward0>) 1001 -1928
total tensor(975.9946, grad_fn=<NegBackward0>) 1001 -1739
total tensor(976.1798, grad_fn=<NegBackward0>) 1001 -1928
total tensor(975.4514, grad_fn=<NegBackward0>) 1001 -1748
total tensor(975.4705, grad_fn=<NegBackward0>) 1001 -1892
total tensor(975.5217, grad_fn=<NegBackward0>) 1001 -1901
total tensor(53

total tensor(974.6485, grad_fn=<NegBackward0>) 1001 -1865
total tensor(984.4385, grad_fn=<NegBackward0>) 1001 -2063
total tensor(974.8447, grad_fn=<NegBackward0>) 1001 -1766
total tensor(975.1177, grad_fn=<NegBackward0>) 1001 -1757
total tensor(975.4707, grad_fn=<NegBackward0>) 1001 -1901
total tensor(849.2209, grad_fn=<NegBackward0>) 887 -1488
total tensor(984.1364, grad_fn=<NegBackward0>) 1001 -1631
total tensor(977.3405, grad_fn=<NegBackward0>) 1001 -1946
total tensor(976.5621, grad_fn=<NegBackward0>) 1001 -1721
total tensor(976.3972, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.2582, grad_fn=<NegBackward0>) 1001 -1838
total tensor(583.2299, grad_fn=<NegBackward0>) 614 -1008
total tensor(974.5815, grad_fn=<NegBackward0>) 1001 -1793
total tensor(974.0949, grad_fn=<NegBackward0>) 1001 -1811
total tensor(978.6652, grad_fn=<NegBackward0>) 1001 -1685
total tensor(210.5912, grad_fn=<NegBackward0>) 229 -452
total tensor(974.5281, grad_fn=<NegBackward0>) 1001 -1856
total tensor(976.2

total tensor(975.3163, grad_fn=<NegBackward0>) 1001 -1892
total tensor(976.6241, grad_fn=<NegBackward0>) 1001 -1928
total tensor(975.3680, grad_fn=<NegBackward0>) 1001 -1748
total tensor(984.6466, grad_fn=<NegBackward0>) 1001 -2063
total tensor(974.6133, grad_fn=<NegBackward0>) 1001 -1775
total tensor(988.9747, grad_fn=<NegBackward0>) 1001 -2117
total tensor(988.0095, grad_fn=<NegBackward0>) 1001 -2108
total tensor(988.8547, grad_fn=<NegBackward0>) 1001 -1595
total tensor(976.6152, grad_fn=<NegBackward0>) 1001 -1721
total tensor(978.5872, grad_fn=<NegBackward0>) 1001 -1973
total tensor(976.1135, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.5005, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.2072, grad_fn=<NegBackward0>) 1001 -1829
total tensor(977.6227, grad_fn=<NegBackward0>) 1001 -1955
total tensor(974.3370, grad_fn=<NegBackward0>) 1001 -1856
total tensor(974.6764, grad_fn=<NegBackward0>) 1001 -1874
total tensor(974.3536, grad_fn=<NegBackward0>) 1001 -1793
total tensor(9

total tensor(974.3807, grad_fn=<NegBackward0>) 1001 -1802
total tensor(976.0495, grad_fn=<NegBackward0>) 1001 -1730
total tensor(989.6567, grad_fn=<NegBackward0>) 1001 -2126
total tensor(975.8629, grad_fn=<NegBackward0>) 1001 -1739
total tensor(980.1733, grad_fn=<NegBackward0>) 1001 -2000
total tensor(974.4266, grad_fn=<NegBackward0>) 1001 -1784
total tensor(28.4528, grad_fn=<NegBackward0>) 45 -34
total tensor(977.1087, grad_fn=<NegBackward0>) 1001 -1712
total tensor(979.3389, grad_fn=<NegBackward0>) 1001 -1676
total tensor(1012.8889, grad_fn=<NegBackward0>) 1001 -1487
total tensor(772.4556, grad_fn=<NegBackward0>) 797 -1659
total tensor(977.6423, grad_fn=<NegBackward0>) 1001 -1703
total tensor(978.9312, grad_fn=<NegBackward0>) 1001 -1685
total tensor(981.7633, grad_fn=<NegBackward0>) 1001 -2027
total tensor(980.1024, grad_fn=<NegBackward0>) 1001 -2000
total tensor(978.3177, grad_fn=<NegBackward0>) 1001 -1694
total tensor(974.2953, grad_fn=<NegBackward0>) 1001 -1811
total tensor(975.48

total tensor(974.8237, grad_fn=<NegBackward0>) 1001 -1775
total tensor(984.1759, grad_fn=<NegBackward0>) 1001 -1631
total tensor(978.2286, grad_fn=<NegBackward0>) 1001 -1964
total tensor(979.6595, grad_fn=<NegBackward0>) 1001 -1991
total tensor(974.2502, grad_fn=<NegBackward0>) 1001 -1829
total tensor(974.9014, grad_fn=<NegBackward0>) 1001 -1874
total tensor(974.4240, grad_fn=<NegBackward0>) 1001 -1811
total tensor(976.2050, grad_fn=<NegBackward0>) 1001 -1730
total tensor(977.5698, grad_fn=<NegBackward0>) 1001 -1703
total tensor(975.6707, grad_fn=<NegBackward0>) 1001 -1910
total tensor(975.8563, grad_fn=<NegBackward0>) 1001 -1910
total tensor(974.1208, grad_fn=<NegBackward0>) 1001 -1820
total tensor(977.0580, grad_fn=<NegBackward0>) 1001 -1712
total tensor(975.0790, grad_fn=<NegBackward0>) 1001 -1892
total tensor(985.2689, grad_fn=<NegBackward0>) 1001 -1622
total tensor(974.9067, grad_fn=<NegBackward0>) 1001 -1775
total tensor(974.6870, grad_fn=<NegBackward0>) 1001 -1775
total tensor(8

total tensor(974.4124, grad_fn=<NegBackward0>) 1001 -1856
total tensor(974.5046, grad_fn=<NegBackward0>) 1001 -1865
total tensor(890.8636, grad_fn=<NegBackward0>) 895 -1289
total tensor(977.2139, grad_fn=<NegBackward0>) 1001 -1946
total tensor(974.6915, grad_fn=<NegBackward0>) 1001 -1865
total tensor(974.6194, grad_fn=<NegBackward0>) 1001 -1775
total tensor(396.6337, grad_fn=<NegBackward0>) 424 -755
total tensor(974.3849, grad_fn=<NegBackward0>) 1001 -1811
total tensor(974.7747, grad_fn=<NegBackward0>) 1001 -1883
total tensor(977.0665, grad_fn=<NegBackward0>) 1001 -1712
total tensor(975.0955, grad_fn=<NegBackward0>) 1001 -1757
total tensor(978.7669, grad_fn=<NegBackward0>) 1001 -1973
total tensor(980.8324, grad_fn=<NegBackward0>) 1001 -1658
total tensor(974.2471, grad_fn=<NegBackward0>) 1001 -1847
total tensor(974.4893, grad_fn=<NegBackward0>) 1001 -1829
total tensor(976.1777, grad_fn=<NegBackward0>) 1001 -1919
total tensor(977.5344, grad_fn=<NegBackward0>) 1001 -1955
total tensor(974.

total tensor(974.3712, grad_fn=<NegBackward0>) 1001 -1784
total tensor(976.8745, grad_fn=<NegBackward0>) 1001 -1937
total tensor(974.8672, grad_fn=<NegBackward0>) 1001 -1883
total tensor(975.9958, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.4922, grad_fn=<NegBackward0>) 1001 -1766
total tensor(975.1481, grad_fn=<NegBackward0>) 1001 -1892
total tensor(975.8018, grad_fn=<NegBackward0>) 1001 -1910
total tensor(981.4572, grad_fn=<NegBackward0>) 1001 -2018
total tensor(974.3585, grad_fn=<NegBackward0>) 1001 -1865
total tensor(708.6122, grad_fn=<NegBackward0>) 745 -1301
total tensor(976.3038, grad_fn=<NegBackward0>) 1001 -1730
total tensor(362.1817, grad_fn=<NegBackward0>) 387 -610
total tensor(980.3076, grad_fn=<NegBackward0>) 1001 -1667
total tensor(973.8666, grad_fn=<NegBackward0>) 1001 -1811
total tensor(977.5308, grad_fn=<NegBackward0>) 1001 -1955
total tensor(984.1437, grad_fn=<NegBackward0>) 1001 -1631
total tensor(974.0456, grad_fn=<NegBackward0>) 1001 -1838
total tensor(982.

total tensor(974.6118, grad_fn=<NegBackward0>) 1001 -1865
total tensor(980.0938, grad_fn=<NegBackward0>) 1001 -2000
total tensor(974.7353, grad_fn=<NegBackward0>) 1001 -1874
total tensor(393.7507, grad_fn=<NegBackward0>) 421 -698
total tensor(976.5399, grad_fn=<NegBackward0>) 1001 -1928
total tensor(974.0960, grad_fn=<NegBackward0>) 1001 -1802
total tensor(982.1811, grad_fn=<NegBackward0>) 1001 -1649
total tensor(974.7161, grad_fn=<NegBackward0>) 1001 -1766
total tensor(984.9173, grad_fn=<NegBackward0>) 1001 -1622
total tensor(976.1417, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.0907, grad_fn=<NegBackward0>) 1001 -1784
total tensor(980.3431, grad_fn=<NegBackward0>) 1001 -2000
total tensor(974.3177, grad_fn=<NegBackward0>) 1001 -1811
total tensor(228.4847, grad_fn=<NegBackward0>) 249 -472
total tensor(974.5089, grad_fn=<NegBackward0>) 1001 -1847
total tensor(615.5543, grad_fn=<NegBackward0>) 644 -1020
total tensor(974.5792, grad_fn=<NegBackward0>) 1001 -1847
total tensor(976.09

total tensor(976.0003, grad_fn=<NegBackward0>) 1001 -1910
total tensor(975.7135, grad_fn=<NegBackward0>) 1001 -1748
total tensor(686.2181, grad_fn=<NegBackward0>) 720 -1204
total tensor(991.6931, grad_fn=<NegBackward0>) 1001 -1577
total tensor(977.4333, grad_fn=<NegBackward0>) 1001 -1703
total tensor(978.9214, grad_fn=<NegBackward0>) 1001 -1982
total tensor(985.0812, grad_fn=<NegBackward0>) 1001 -2072
total tensor(974.5355, grad_fn=<NegBackward0>) 1001 -1784
total tensor(994.5863, grad_fn=<NegBackward0>) 1001 -2180
total tensor(701.6884, grad_fn=<NegBackward0>) 737 -1347
total tensor(975.4229, grad_fn=<NegBackward0>) 1001 -1748
total tensor(986.2121, grad_fn=<NegBackward0>) 1001 -1613
total tensor(975.3721, grad_fn=<NegBackward0>) 1001 -1757
total tensor(975.5165, grad_fn=<NegBackward0>) 1001 -1748
total tensor(984.9801, grad_fn=<NegBackward0>) 1001 -1622
total tensor(973.9325, grad_fn=<NegBackward0>) 1001 -1802
total tensor(981.3952, grad_fn=<NegBackward0>) 1001 -1658
total tensor(982

total tensor(976.9219, grad_fn=<NegBackward0>) 1001 -1937
total tensor(604.8619, grad_fn=<NegBackward0>) 637 -1067
total tensor(974.1774, grad_fn=<NegBackward0>) 1001 -1847
total tensor(974.1810, grad_fn=<NegBackward0>) 1001 -1838
total tensor(486.6555, grad_fn=<NegBackward0>) 515 -837
total tensor(250.1419, grad_fn=<NegBackward0>) 274 -461
total tensor(976.5734, grad_fn=<NegBackward0>) 1001 -1721
total tensor(975.9899, grad_fn=<NegBackward0>) 1001 -1739
total tensor(982.2000, grad_fn=<NegBackward0>) 1001 -1649
total tensor(981.9836, grad_fn=<NegBackward0>) 1001 -2027
total tensor(725.6540, grad_fn=<NegBackward0>) 761 -1281
total tensor(981.9269, grad_fn=<NegBackward0>) 1001 -1649
total tensor(981.5564, grad_fn=<NegBackward0>) 1001 -2027
total tensor(801.4597, grad_fn=<NegBackward0>) 840 -1513
total tensor(974.7827, grad_fn=<NegBackward0>) 1001 -1865
total tensor(977.7054, grad_fn=<NegBackward0>) 1001 -1955
total tensor(975.5300, grad_fn=<NegBackward0>) 1001 -1739
total tensor(945.7990

total tensor(977.6846, grad_fn=<NegBackward0>) 1001 -1955
total tensor(974.7847, grad_fn=<NegBackward0>) 1001 -1874
total tensor(987.6787, grad_fn=<NegBackward0>) 1001 -1604
total tensor(977.0362, grad_fn=<NegBackward0>) 1001 -1937
total tensor(975.3033, grad_fn=<NegBackward0>) 1001 -1892
total tensor(977.2950, grad_fn=<NegBackward0>) 1001 -1946
total tensor(827.7232, grad_fn=<NegBackward0>) 864 -1645
total tensor(979.4971, grad_fn=<NegBackward0>) 1001 -1991
total tensor(974.6560, grad_fn=<NegBackward0>) 1001 -1874
total tensor(974.2183, grad_fn=<NegBackward0>) 1001 -1829
total tensor(974.9396, grad_fn=<NegBackward0>) 1001 -1883
total tensor(974.8209, grad_fn=<NegBackward0>) 1001 -1775
total tensor(979.6146, grad_fn=<NegBackward0>) 1001 -1991
total tensor(978.2168, grad_fn=<NegBackward0>) 1001 -1964
total tensor(866.2073, grad_fn=<NegBackward0>) 904 -1712
total tensor(974.8567, grad_fn=<NegBackward0>) 1001 -1874
total tensor(975.5123, grad_fn=<NegBackward0>) 1001 -1901
total tensor(977

total tensor(974.3234, grad_fn=<NegBackward0>) 1001 -1820
total tensor(983.1548, grad_fn=<NegBackward0>) 1001 -1640
total tensor(985.8862, grad_fn=<NegBackward0>) 1001 -2081
total tensor(978.5331, grad_fn=<NegBackward0>) 1001 -1973
total tensor(976.0781, grad_fn=<NegBackward0>) 1001 -1730
total tensor(974.1671, grad_fn=<NegBackward0>) 1001 -1838
total tensor(985.1293, grad_fn=<NegBackward0>) 1001 -1622
total tensor(974.6168, grad_fn=<NegBackward0>) 1001 -1775
total tensor(983.7308, grad_fn=<NegBackward0>) 1001 -2054
total tensor(782.7148, grad_fn=<NegBackward0>) 816 -1588
total tensor(976.1901, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.7098, grad_fn=<NegBackward0>) 1001 -1865
total tensor(974.4980, grad_fn=<NegBackward0>) 1001 -1784
total tensor(975.0681, grad_fn=<NegBackward0>) 1001 -1892
total tensor(974.1357, grad_fn=<NegBackward0>) 1001 -1838
total tensor(820.1125, grad_fn=<NegBackward0>) 844 -1292
total tensor(652.7596, grad_fn=<NegBackward0>) 683 -1329
total tensor(975.

total tensor(988.8380, grad_fn=<NegBackward0>) 1001 -2117
total tensor(410.0955, grad_fn=<NegBackward0>) 431 -879
total tensor(979.5430, grad_fn=<NegBackward0>) 1001 -1991
total tensor(974.3466, grad_fn=<NegBackward0>) 1001 -1793
total tensor(980.4734, grad_fn=<NegBackward0>) 1001 -1667
total tensor(974.2107, grad_fn=<NegBackward0>) 1001 -1802
total tensor(975.0602, grad_fn=<NegBackward0>) 1001 -1883
total tensor(975.3815, grad_fn=<NegBackward0>) 1001 -1892
total tensor(981.2162, grad_fn=<NegBackward0>) 1001 -1658
total tensor(975.5544, grad_fn=<NegBackward0>) 1001 -1748
total tensor(976.0004, grad_fn=<NegBackward0>) 1001 -1919
total tensor(978.5529, grad_fn=<NegBackward0>) 1001 -1973
total tensor(601.2280, grad_fn=<NegBackward0>) 626 -966
total tensor(978.1953, grad_fn=<NegBackward0>) 1001 -1694
total tensor(178.3441, grad_fn=<NegBackward0>) 200 -342
total tensor(975.0678, grad_fn=<NegBackward0>) 1001 -1883
total tensor(974.2437, grad_fn=<NegBackward0>) 1001 -1847
total tensor(976.185

total tensor(986.3732, grad_fn=<NegBackward0>) 1001 -1613
total tensor(974.2902, grad_fn=<NegBackward0>) 1001 -1829
total tensor(992.8145, grad_fn=<NegBackward0>) 1001 -2162
total tensor(979.7141, grad_fn=<NegBackward0>) 1001 -1991
total tensor(980.0886, grad_fn=<NegBackward0>) 1001 -2000
total tensor(978.8165, grad_fn=<NegBackward0>) 1001 -1685
total tensor(986.4881, grad_fn=<NegBackward0>) 1001 -1613
total tensor(433.2430, grad_fn=<NegBackward0>) 462 -784
total tensor(702.5507, grad_fn=<NegBackward0>) 732 -1162
total tensor(976.0132, grad_fn=<NegBackward0>) 1001 -1919
total tensor(978.5974, grad_fn=<NegBackward0>) 1001 -1973
total tensor(974.1021, grad_fn=<NegBackward0>) 1001 -1802
total tensor(974.6339, grad_fn=<NegBackward0>) 1001 -1865
total tensor(980.2742, grad_fn=<NegBackward0>) 1001 -2000
total tensor(980.7074, grad_fn=<NegBackward0>) 1001 -2009
total tensor(974.6647, grad_fn=<NegBackward0>) 1001 -1865
total tensor(980.1186, grad_fn=<NegBackward0>) 1001 -2000
total tensor(976.

total tensor(974.2487, grad_fn=<NegBackward0>) 1001 -1829
total tensor(433.7807, grad_fn=<NegBackward0>) 462 -838
total tensor(976.0580, grad_fn=<NegBackward0>) 1001 -1730
total tensor(983.2385, grad_fn=<NegBackward0>) 1001 -1640
total tensor(974.5844, grad_fn=<NegBackward0>) 1001 -1865
total tensor(989.5468, grad_fn=<NegBackward0>) 1001 -2126
total tensor(976.5922, grad_fn=<NegBackward0>) 1001 -1721
total tensor(976.9930, grad_fn=<NegBackward0>) 1001 -1712
total tensor(974.3488, grad_fn=<NegBackward0>) 1001 -1847
total tensor(492.6765, grad_fn=<NegBackward0>) 522 -862
total tensor(983.8633, grad_fn=<NegBackward0>) 1001 -2054
total tensor(983.1164, grad_fn=<NegBackward0>) 1001 -2045
total tensor(975.9874, grad_fn=<NegBackward0>) 1001 -1730
total tensor(975.5364, grad_fn=<NegBackward0>) 1001 -1748
total tensor(990.5518, grad_fn=<NegBackward0>) 1001 -1586
total tensor(974.6064, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.6444, grad_fn=<NegBackward0>) 1001 -1784
total tensor(974.1

total tensor(977.6623, grad_fn=<NegBackward0>) 1001 -1703
total tensor(976.5566, grad_fn=<NegBackward0>) 1001 -1739
total tensor(975.8550, grad_fn=<NegBackward0>) 1001 -1919
total tensor(975.5115, grad_fn=<NegBackward0>) 1001 -1748
total tensor(974.2460, grad_fn=<NegBackward0>) 1001 -1847
total tensor(973.9894, grad_fn=<NegBackward0>) 1001 -1847
total tensor(974.1763, grad_fn=<NegBackward0>) 1001 -1856
total tensor(985.1525, grad_fn=<NegBackward0>) 1001 -1622
total tensor(980.0492, grad_fn=<NegBackward0>) 1001 -1991
total tensor(998.1853, grad_fn=<NegBackward0>) 1001 -2216
total tensor(975.6647, grad_fn=<NegBackward0>) 1001 -1910
total tensor(997.3271, grad_fn=<NegBackward0>) 1001 -1550
total tensor(974.3301, grad_fn=<NegBackward0>) 1001 -1874
total tensor(978.4960, grad_fn=<NegBackward0>) 1001 -1694
total tensor(341.8052, grad_fn=<NegBackward0>) 368 -609
total tensor(974.5515, grad_fn=<NegBackward0>) 1001 -1775
total tensor(974.8498, grad_fn=<NegBackward0>) 1001 -1892
total tensor(92.

total tensor(974.7244, grad_fn=<NegBackward0>) 1001 -1865
total tensor(975.4954, grad_fn=<NegBackward0>) 1001 -1901
total tensor(404.5090, grad_fn=<NegBackward0>) 432 -727
total tensor(978.2018, grad_fn=<NegBackward0>) 1001 -1964
total tensor(977.2603, grad_fn=<NegBackward0>) 1001 -1946
total tensor(32.3621, grad_fn=<NegBackward0>) 50 -48
total tensor(993.5795, grad_fn=<NegBackward0>) 1001 -1568
total tensor(319.7371, grad_fn=<NegBackward0>) 335 -720
total tensor(984.4822, grad_fn=<NegBackward0>) 1001 -2063
total tensor(976.0131, grad_fn=<NegBackward0>) 1001 -1919
total tensor(974.9508, grad_fn=<NegBackward0>) 1001 -1766
total tensor(974.1773, grad_fn=<NegBackward0>) 1001 -1820
total tensor(974.3829, grad_fn=<NegBackward0>) 1001 -1811
total tensor(974.2792, grad_fn=<NegBackward0>) 1001 -1802
total tensor(974.6423, grad_fn=<NegBackward0>) 1001 -1775
total tensor(993.5497, grad_fn=<NegBackward0>) 1001 -1568
total tensor(974.6021, grad_fn=<NegBackward0>) 1001 -1865
total tensor(975.0590, 

In [None]:
np.log(0.6695)

In [None]:
for k in model.parameters():
    print(k)

In [None]:
torch.save(model.state_dict(),'model.pt')

In [None]:
!ls model.pt

In [None]:
env = gym.make("Pong-v4",render_mode="human")
observation, _ = env.reset()

In [31]:
model = MLP(D,1)

In [32]:
model.load_state_dict(torch.load('model.pt'))

In [33]:
prev_x = None
while True:
    current_x = prepro(observation)
    gap = np.zeros_like(current_x) if prev_x is None else current_x - prev_x
    prob = model(torch.from_numpy(gap))
    action = 2 if prob < 0.5 else 3
    observation, reward, terminated, truncated, info = env.step(action)  # 更新为新的返回值
    prev_x = current_x
    if terminated:
        observation, _ = env.reset()
        prev_x = None

In [34]:
env.close()

In [26]:
import numpy as np
print(11)

class CliffWalking2x3Env:
    def __init__(self):
        self.shape = (2, 3)  # 2行3列
        self.start_state = 3  # 起点在第二行第一列（状态3）
        self.cliff_state = 1  # 悬崖在第一行第二列（状态1）
        self.goal_state = 5  # 终点在第二行第三列（状态5）
        self.current_state = self.start_state
        
        # 动作空间：0=上，1=右，2=下，3=左
        self.action_space = 4
        self.observation_space = 6  # 状态空间大小
        
    def reset(self):
        """重置环境到初始状态"""
        self.current_state = self.start_state
        return self.current_state,{}
    
    def state_to_pos(self, state):
        """将状态编号转换为坐标"""
        row = state // 3
        col = state % 3
        return (row, col)
    
    def pos_to_state(self, row, col):
        """将坐标转换为状态编号"""
        return row * 3 + col
    
    def step(self, action):
        """执行动作并返回新状态、奖励、是否终止"""
        row, col = self.state_to_pos(self.current_state)
        new_row, new_col = row, col
        
        # 计算新坐标
        if action == 0:    # 上
            new_row = max(row - 1, 0)
        elif action == 1:  # 右
            new_col = min(col + 1, 2)
        elif action == 2:  # 下
            new_row = min(row + 1, 1)
        elif action == 3:  # 左
            new_col = max(col - 1, 0)
        
        # 检查是否越界（此逻辑已包含在max/min中，可省略）
        new_state = self.pos_to_state(new_row, new_col)
        reward = -1  # 默认每步-1奖励
        done = False
        
        # 判断是否掉下悬崖或到达终点
        if new_state == self.cliff_state:
            reward = -100
            self.current_state = self.start_state  # 回到起点
        elif new_state == self.goal_state:
            reward = 10
            done = True
            self.current_state = new_state
        else:
            self.current_state = new_state
        
        return self.current_state, reward, done,done, {}
    
    def render(self):
        """可视化当前状态"""
        grid = [
            ['S', 'C', 'S'],
            ['S', 'S', 'G']
        ]
        row, col = self.state_to_pos(self.current_state)
        grid[row][col] = 'A'  # 用A表示智能体位置
        print(f"+---+---+---+")
        print(f"|{grid[0][0]}|{grid[0][1]}|{grid[0][2]}|")
        print(f"+---+---+---+")
        print(f"|{grid[1][0]}|{grid[1][1]}|{grid[1][2]}|")
        print(f"+---+---+---+")
        print()


# env = CliffWalking2x3Env()
# state = env.reset()
# done = False

# print("初始状态：")
# env.render()

# # 示例动作序列：右→右→上→右→下→右
# actions = [1, 1, 0, 1, 2, 1]

# for action in actions:
#     next_state, reward, done, _,_ = env.step(action)
#     print(f"执行动作 {action} 后:")
#     env.render()
#     print(f"奖励: {reward}, 终止: {done}\n")
#     if done:
#         break

11
