In [1]:
import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils
import torch.nn as nn
import mars_explorer

class PolicyNet(nn.Module):
    """Convolutional actor: maps an image-like observation to action probabilities.

    Layer order is conv1 -> conv2 -> flatten -> dropout -> fc1 (tanh) -> fc3 ->
    softmax over dim 1.

    Args:
        in_channels: Number of channels of the input image (default 1).
        action_dim: Number of discrete actions, i.e. width of the output
            distribution (default 4).
        flat_features: Number of features produced by flattening the conv2
            output. The default 288 equals 32 channels * 3 * 3, which is what
            the conv stack yields for e.g. a 20x20 or 21x21 input.
        dropout: Dropout probability applied to the flattened conv features
            (default 0.5).

    The defaults reproduce the original fixed architecture, and the submodule
    names are unchanged, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels=1, action_dim=4, flat_features=288,
                 dropout=0.5):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=2), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1), nn.ReLU())
        self.flat = nn.Flatten()
        self.fc1 = nn.Sequential(nn.Linear(flat_features, 512), nn.Tanh())
        # NOTE(review): dropout is only disabled in eval() mode. While the
        # module is in train mode every forward pass is stochastic, which also
        # perturbs the action probabilities used for sampling.
        self.drop = nn.Dropout(dropout)
        self.fc3 = nn.Sequential(nn.Linear(512, action_dim))

    def forward(self, input):
        """Return action probabilities of shape (batch, action_dim).

        Args:
            input: Float tensor of shape (batch, in_channels, H, W); H and W
                must be such that the flattened conv2 output has
                ``flat_features`` elements per sample.
        """
        features = self.flat(self.conv2(self.conv1(input)))
        # Dropout is applied to the conv features, before the dense layer.
        hidden = self.fc1(self.drop(features))
        return nn.functional.softmax(self.fc3(hidden), dim=1)
class ValueNet(nn.Module):
    """Convolutional critic: maps an image-like observation to a scalar value.

    Layer order is conv1 -> conv2 -> flatten -> dropout -> fc1 (tanh) -> fc3,
    with a single linear output unit and no output activation.

    Args:
        in_channels: Number of channels of the input image (default 1).
        flat_features: Number of features produced by flattening the conv2
            output. The default 288 equals 32 channels * 3 * 3, which is what
            the conv stack yields for e.g. a 20x20 or 21x21 input.
        dropout: Dropout probability applied to the flattened conv features
            (default 0.5).

    The defaults reproduce the original fixed architecture, and the submodule
    names are unchanged, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels=1, flat_features=288, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=2), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1), nn.ReLU())
        self.flat = nn.Flatten()
        self.fc1 = nn.Sequential(nn.Linear(flat_features, 512), nn.Tanh())
        # NOTE(review): dropout is only disabled in eval() mode; in train mode
        # the value estimate differs from one forward pass to the next.
        self.drop = nn.Dropout(dropout)
        self.fc3 = nn.Sequential(nn.Linear(512, 1))

    def forward(self, input):
        """Return value estimates of shape (batch, 1).

        Args:
            input: Float tensor of shape (batch, in_channels, H, W); H and W
                must be such that the flattened conv2 output has
                ``flat_features`` elements per sample.
        """
        features = self.flat(self.conv2(self.conv1(input)))
        # Dropout is applied to the conv features, before the dense layer.
        hidden = self.fc1(self.drop(features))
        return self.fc3(hidden)
    
    
class PPO:
    """PPO agent using the clipped surrogate objective.

    Holds an actor (``PolicyNet``) and a critic (``ValueNet``), each with its
    own Adam optimizer.

    Args:
        state_dim, hidden_dim, action_dim: Accepted for interface
            compatibility only; both networks are constructed without
            arguments, so these values are not used.
        actor_lr: Learning rate of the actor optimizer.
        critic_lr: Learning rate of the critic optimizer.
        lmbda: Lambda forwarded to ``rl_utils.compute_advantage``.
        epochs: Number of optimisation passes over one batch of transitions.
        eps: Clipping range; the probability ratio is clamped to
            ``[1 - eps, 1 + eps]``.
        gamma: Discount factor.
        device: Torch device the networks and batches live on.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet().to(device)
        self.critic = ValueNet().to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # training passes over one trajectory's data
        self.eps = eps  # clipping range of the PPO objective
        self.device = device

    def _as_float_batch(self, items):
        """Stack a sequence of observations into one float tensor on the device."""
        # Going through a single ndarray avoids torch's slow element-wise
        # conversion of a Python list of ndarrays.
        return torch.tensor(np.array(items), dtype=torch.float,
                            device=self.device)

    def take_action(self, state):
        """Sample one action index (Python int) for a single observation."""
        batch = torch.tensor(np.asarray(state), dtype=torch.float,
                             device=self.device).unsqueeze(0)
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            probs = self.actor(batch)
        return torch.distributions.Categorical(probs).sample().item()

    def update(self, transition_dict):
        """Run ``self.epochs`` PPO updates on one batch of transitions.

        Args:
            transition_dict: Mapping with the keys ``'states'``,
                ``'actions'``, ``'rewards'``, ``'next_states'`` and
                ``'dones'``, each a sequence with one entry per transition.
        """
        if len(transition_dict['states']) == 0:
            return  # nothing to learn from

        states = self._as_float_batch(transition_dict['states'])
        next_states = self._as_float_batch(transition_dict['next_states'])
        # gather() requires int64 indices, so make the dtype explicit.
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.int64).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # Targets, advantages and the old log-probabilities are constants of
        # the optimisation, so compute them without tracking gradients.
        # NOTE(review): both networks contain Dropout and stay in train mode,
        # so old and new log-probs are computed under different random masks
        # and the ratio is not exactly 1 on the first epoch - confirm whether
        # dropout is intended here.
        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_states) * (
                1 - dones)
            td_delta = td_target - self.critic(states)
            old_log_probs = torch.log(self.actor(states).gather(1, actions))
        # presumably generalized advantage estimation; see rl_utils
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps,
                                1 + self.eps) * advantage  # clipped ratio
            actor_loss = -torch.min(surr1, surr2).mean()  # PPO-clip loss
            # mse_loss already averages over the batch.
            critic_loss = F.mse_loss(self.critic(states), td_target)
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

In [None]:
# --- Hyperparameters ---------------------------------------------------------
actor_lr = 1e-4
critic_lr = 1e-4
# An episode count is an integer; the float literal 5e5 only works as long as
# the training helper casts it itself.
num_episodes = 500_000
hidden_dim = 128  # forwarded to PPO, which does not use it
gamma = 0.98      # discount factor
lmbda = 0.95      # lambda passed on to rl_utils.compute_advantage
epochs = 10       # optimisation passes per collected trajectory
eps = 0.2         # PPO clipping range
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Environment -------------------------------------------------------------
env_name = 'mars-explorer-v1'
env = gym.make(env_name)

# PPO accepts state_dim/action_dim but builds its networks without arguments,
# so these two values are informational only.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# --- Training ----------------------------------------------------------------
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
            epochs, eps, gamma, device)

# Expected to return one return value per episode (it is averaged in chunks of
# 10000 entries further down) - see rl_utils for the exact contract.
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

In [None]:
# Persist the learning curve in condensed form: one line per block of 10000
# consecutive entries of return_list, holding that block's average.
with open('return_list.txt', mode='w', encoding='utf-8') as f:
    f.writelines(
        str(np.average(return_list[start:start + 10000])) + '\n'
        for start in range(0, len(return_list), 10000)
    )

In [None]:
import csv

# Reload the averaged returns written by the previous cell, one float per line.
with open('return_list.txt', 'r', encoding='utf-8') as f:
    # Each line ends in a single '\n' in text mode, so slicing off two
    # characters dropped the last digit of every value. float() ignores
    # surrounding whitespace, so no manual trimming is needed; blank lines
    # are skipped.
    return_list = [float(line) for line in f if line.strip()]


### Remember to adjust this path to your machine.
with open("D:/group/reinforcement/data/local/20240313/GridworldExperiment_equivariant/run_0/progress.csv",
          'r', newline='') as f:
    reader = csv.DictReader(f)
    # csv yields strings; convert to float so matplotlib plots numbers instead
    # of treating every distinct string as a category.
    column_equivariant = [float(row['ReturnMax']) for row in reader]

# In the plot below, orange is the ordinary run and blue is the equivariant one.

In [None]:
# Compare the two learning curves. The colour meaning used to live only in a
# comment of the previous cell; the legend makes the figure self-explanatory.
plt.plot(return_list, color="orange", label="ordinary (this notebook)")
plt.plot(column_equivariant, color="blue", label="equivariant (ReturnMax)")
# NOTE(review): the two series are indexed by their own logging step, which
# may not correspond to the same number of episodes - confirm before comparing.
plt.xlabel("logged point")
plt.ylabel("return")
plt.legend()

plt.show()