In [32]:
from torch import nn

In [33]:
import torch.nn.functional as F

In [34]:
from collections import deque

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import deque
from torch.distributions import Categorical
import numpy as np
import gym
import os
import copy
import matplotlib.pyplot as plt
import seaborn as sns
np.bool8 = np.bool_
from torch.optim import Adam

In [36]:
env = gym.make("CartPole-v1")  # 或 "rgb_array"
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
print(f"状态空间维度：{n_states}，动作空间维度：{n_actions}")
state,_ = env.reset()
next_state, reward, terminated, truncated, _ = env.step(0)  # 注意这里的返回值

状态空间维度：4，动作空间维度：2


In [37]:
class Actor(nn.Module):
    
    def __init__(self,input_dim,ouput_dim,hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim,hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.fc3 = nn.Linear(hidden_dim,ouput_dim)
    
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        probs = F.softmax(x,dim=-1)
        return probs

In [38]:
class Critic(nn.Module):
    
    def __init__(self,input_dim,ouput_dim,hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim,hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.fc3 = nn.Linear(hidden_dim,ouput_dim)
    
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

In [53]:
 class ReplayQue:
        
        def __init__(self):
            self.buffer = deque()
        
        
        def push(self,transitions):
            self.buffer.append(transitions)
        
        def clear(self):
            self.buffer.clear()
        
        def __len__(self):
            return len(self.buffer)
        
        def sample(self):
            return zip(*list(self.buffer))

In [58]:
class Agent:
    
    
    def __init__(self):
        
        self.actor = Actor(n_states,n_actions,256)
        self.critic = Critic(n_states,1,256)
        self.actor_optimizer = Adam(self.actor.parameters())
        self.critic_optimizer = Adam(self.critic.parameters())
        self.memory = ReplayQue()
    
    
    def sample_action(self,state):
        state = torch.tensor(state)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.detach().cpu().numpy().item(), log_prob.detach()
    
    
    @torch.no_grad()
    def predict(self,state):
        state = torch.tensor(state)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.detach().cpu().numpy().item()
        
    
    def update(self):
        if len(self.memory) % 100 != 0:
            return
    
        
        old_actions,old_states,old_rewards,old_dones,old_logprobs = self.memory.sample()

        old_actions = torch.tensor(old_actions)
        old_states = torch.tensor(old_states)
        old_logprobs = torch.tensor(old_logprobs)
        
        returns = []
        discount_sum = 0
        for reward,done in zip(reversed(old_rewards),reversed(old_dones)):
            if done:
                discount_sum = 0
            discount_sum = discount_sum * 0.99 + reward
            returns.insert(0,discount_sum)
        
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        
        for _ in range(4):
            values = self.critic(old_states)
            probs = self.actor(old_states)
            dist = Categorical(probs)
            logprobs = dist.log_prob(old_actions)
            ratio = torch.exp(logprobs - old_logprobs)
            advantage = returns - values.detach()
            surr1 = advantage * ratio
            surr2 = torch.clamp(ratio,0.8,1.2) * advantage
            actor_loss = -torch.min(surr1,surr2).mean() + 0.01 * dist.entropy().mean()
            critic_loss = (returns - values).pow(2).mean()
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            
            actor_loss.backward()
            critic_loss.backward()
            
            self.actor_optimizer.step()
            self.critic_optimizer.step()
        self.memory.clear()

In [59]:
def train(env,agent):
    
    for eph in range(10000):
        
        ### train
        state,_ = env.reset()
        for _ in range(1000):
            
            action,logprob = agent.sample_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)  # 注意这里的返回值
            agent.memory.push((action,state,reward,terminated,logprob))
            state = next_state
            agent.update()
            if terminated:
                break
        
        ### eval
        if (eph+1) % 100 == 0:
            reward_sum = 0
            state,_ = env.reset()
            for _ in range(1000):
                action = agent.predict(state)
                next_state, reward, terminated, truncated, _ = env.step(action)  # 注意这里的返回值
                reward_sum += reward
                state = next_state
                if terminated:
                    break
            print('reward....',reward_sum)

In [60]:
agent = Agent()

In [61]:
train(env,agent)

reward.... 44.0
reward.... 247.0
reward.... 258.0
reward.... 174.0
reward.... 95.0
reward.... 287.0
reward.... 113.0
reward.... 208.0
reward.... 258.0
reward.... 201.0
reward.... 172.0
reward.... 200.0
reward.... 144.0
reward.... 223.0
reward.... 167.0
reward.... 105.0
reward.... 97.0
reward.... 78.0
reward.... 183.0
reward.... 148.0
reward.... 439.0
reward.... 148.0
reward.... 159.0
reward.... 505.0
reward.... 1000.0
reward.... 171.0
reward.... 189.0
reward.... 413.0
reward.... 151.0
reward.... 287.0
reward.... 1000.0
reward.... 1000.0
reward.... 503.0
reward.... 460.0
reward.... 1000.0
reward.... 884.0
reward.... 1000.0
reward.... 1000.0
reward.... 395.0
reward.... 575.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 502.0
reward.... 443.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 1000.0
reward.... 100