## 这是所有算法中第一个例子
一般的训练程序分为两部分
1.主程序(一个函数),创建env,创建agent,然后agent与env循环交互并进行学习
2.Agent(一个类),超参数、定义网络、控制器（选择动作）,以及将它们综合起来的learn方法.
此外还有一些其他组件,比如DQN的ReplayBuffer。如果算法比较复杂,可以将一些层次抽象出来,
比如网络结构、ReplayBuffer和一些特有算法。在PARL框架中就是固定把Algorithm抽象出来作为
一个基础层次

另外一个难点是数据结构的转换,需要掌握List、Dict、Tensor、Numpy的常见用法


In [1]:
#导入gym和torch相关包
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Hyperparameters (module-level constants read by the PPO class and main()).
learning_rate = 0.0005 # step size passed to the Adam optimizer
gamma         = 0.98   # discount factor used in the TD target and the advantage recursion
lmbda         = 0.95   # decay factor multiplied with gamma in the advantage recursion (GAE lambda)
eps_clip      = 0.1    # the probability ratio is clamped to [1 - eps_clip, 1 + eps_clip]
K_epoch       = 3      # number of gradient updates performed on each collected batch
T_horizon     = 20     # maximum number of environment steps collected before each training call


## 定义PPO架构,继承nn.Module的主要作用是可以用self.parameters()拿到参数放入优化器
class PPO(nn.Module):
    """Minimal clipped-PPO actor-critic with a shared hidden layer.

    The policy head and the value head share one fully connected input layer.
    The object also owns a simple list-based rollout buffer and the optimizer,
    so ``put_data`` followed by ``train_net`` is the whole training interface.

    Args:
        state_dim: Size of one observation vector (default 4).
        action_dim: Number of discrete actions (default 2).
        hidden_dim: Width of the shared hidden layer (default 256).
        lr: Adam learning rate. ``None`` uses the module-level
            ``learning_rate``.
        discount: Reward discount factor. ``None`` uses the module-level
            ``gamma`` at training time.
        gae_lambda: Advantage decay factor. ``None`` uses the module-level
            ``lmbda`` at training time.
        clip_range: Half-width of the ratio clipping interval. ``None`` uses
            the module-level ``eps_clip`` at training time.
        n_epochs: Gradient updates per batch. ``None`` uses the module-level
            ``K_epoch`` at training time.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=256, *,
                 lr=None, discount=None, gae_lambda=None, clip_range=None,
                 n_epochs=None):
        super().__init__()
        self.data = []  # rollout buffer: list of transition tuples

        # Overrides are stored as given; None is resolved against the
        # module-level constants only when train_net() runs, so editing the
        # constants after construction still takes effect.
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.clip_range = clip_range
        self.n_epochs = n_epochs

        self.fc1 = nn.Linear(state_dim, hidden_dim)     # layer shared by both heads
        self.fc_pi = nn.Linear(hidden_dim, action_dim)  # policy head (action logits)
        self.fc_v = nn.Linear(hidden_dim, 1)            # value head
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr)

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for observation(s) ``x``.

        ``softmax_dim`` is the axis holding the actions: 0 for a single
        observation vector, 1 for a batch of observations.
        """
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        """Return the state-value estimate for observation(s) ``x``."""
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        """Append one ``(s, a, r, s_prime, prob_a, done)`` tuple to the buffer."""
        self.data.append(transition)

    @staticmethod
    def _stack_states(states):
        """Turn a list of per-step observations into one float tensor."""
        if not states:
            # Same result torch.tensor([]) would give for an empty buffer.
            return torch.tensor(states, dtype=torch.float)
        # Converting item by item avoids torch.tensor's slow path (and its
        # warning) for a Python list of numpy arrays.
        return torch.stack(
            [torch.as_tensor(state, dtype=torch.float) for state in states])

    def make_batch(self):
        """Convert the buffered transitions into tensors and clear the buffer.

        Returns:
            Tuple ``(s, a, r, s_prime, done_mask, prob_a)``. Scalars are
            wrapped so that ``a``, ``r``, ``done_mask`` and ``prob_a`` have one
            column per row; ``done_mask`` is 0 where the transition ended the
            episode and 1 otherwise.
        """
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for s, a, r, s_prime, prob_a, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_lst.append([0.0 if done else 1.0])

        s = self._stack_states(s_lst)
        a = torch.tensor(a_lst, dtype=torch.long)  # gather() needs int64 indices
        r = torch.tensor(r_lst, dtype=torch.float)
        s_prime = self._stack_states(s_prime_lst)
        done_mask = torch.tensor(done_lst, dtype=torch.float)
        prob_a = torch.tensor(prob_a_lst, dtype=torch.float)

        self.data = []
        return s, a, r, s_prime, done_mask, prob_a

    def train_net(self):
        """Run the PPO update on the buffered transitions, then empty the buffer.

        Does nothing when the buffer is empty. Otherwise performs ``n_epochs``
        gradient steps on the clipped surrogate policy loss plus a smooth-L1
        value loss, combined into a single objective.
        """
        if not self.data:
            # Nothing collected: the networks cannot be evaluated on an empty batch.
            return

        discount = gamma if self.discount is None else self.discount
        lam = lmbda if self.gae_lambda is None else self.gae_lambda
        clip = eps_clip if self.clip_range is None else self.clip_range
        epochs = K_epoch if self.n_epochs is None else self.n_epochs

        s, a, r, s_prime, done_mask, prob_a = self.make_batch()

        for _ in range(epochs):
            # Single value pass: used with gradients in the loss and without
            # gradients for the TD error.
            v_s = self.v(s)
            with torch.no_grad():
                # done_mask zeroes the bootstrap term on terminal transitions.
                td_target = r + discount * self.v(s_prime) * done_mask
                delta = (td_target - v_s).squeeze(1)

            # Advantage: discounted sum of TD errors, accumulated backwards in time.
            running = 0.0
            advantage_lst = []
            for delta_t in reversed(delta.tolist()):
                running = discount * lam * running + delta_t
                advantage_lst.append([running])
            advantage_lst.reverse()
            advantage = torch.tensor(advantage_lst, dtype=torch.float)

            # Ratio between the current policy and the one that collected the data.
            pi = self.pi(s, softmax_dim=1)
            pi_a = pi.gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            # Clipping keeps a single update from moving the policy too far.
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip, 1 + clip) * advantage
            # Simplified PPO: policy loss and value loss share one objective.
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(v_s, td_target)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

In [3]:
#主函数：简化ppo 这里每交互T_horizon步（或回合提前结束）就停下来学习训练，再继续交互，一共进行10000个回合
def main(n_episodes=10000, print_interval=20):
    """Train a PPO agent on CartPole-v1 and print the running average score.

    Each episode is played in chunks of at most ``T_horizon`` steps; after
    every chunk (or when the episode ends) the collected transitions are used
    for one ``train_net`` call.

    Args:
        n_episodes: Number of episodes to play.
        print_interval: Print the average episode score every this many
            completed episodes.
    """
    env = gym.make('CartPole-v1')
    model = PPO()
    score = 0.0

    try:
        for n_epi in range(n_episodes):
            # Newer gym/gymnasium releases return (observation, info) from
            # reset(); older ones return the observation alone.
            reset_out = env.reset()
            s = reset_out[0] if isinstance(reset_out, tuple) else reset_out
            done = False
            while not done:
                for _ in range(T_horizon):
                    # Sample an action from the current policy. No gradient is
                    # needed here: only the probability value is stored.
                    with torch.no_grad():
                        prob = model.pi(torch.from_numpy(s).float())
                    a = Categorical(prob).sample().item()

                    # Newer APIs return (obs, reward, terminated, truncated,
                    # info); older ones return (obs, reward, done, info).
                    step_out = env.step(a)
                    if len(step_out) == 5:
                        s_prime, r, terminated, truncated, _ = step_out
                        done = terminated or truncated
                    else:
                        s_prime, r, done, _ = step_out

                    # Store the transition with a scaled-down reward.
                    model.put_data((s, a, r/100.0, s_prime, prob[a].item(), done))
                    s = s_prime

                    score += r
                    if done:
                        break

                # Learn from the transitions gathered in this chunk.
                model.train_net()

            # Report after every print_interval completed episodes, so each
            # average covers exactly print_interval episodes.
            finished = n_epi + 1
            if finished % print_interval == 0:
                print("# of episode :{}, avg score : {:.1f}".format(finished, score/print_interval))
                score = 0.0
    finally:
        # Release the environment even if training is interrupted.
        env.close()

# Start training only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

# of episode :20, avg score : 29.5
# of episode :40, avg score : 33.0
# of episode :60, avg score : 54.2
# of episode :80, avg score : 84.4
# of episode :100, avg score : 72.8
# of episode :120, avg score : 42.0
# of episode :140, avg score : 80.0
# of episode :160, avg score : 136.6
# of episode :180, avg score : 155.8
# of episode :200, avg score : 246.1
# of episode :220, avg score : 293.2
# of episode :240, avg score : 323.4
# of episode :260, avg score : 249.2
# of episode :280, avg score : 272.1
# of episode :300, avg score : 236.9
# of episode :320, avg score : 236.1
# of episode :340, avg score : 182.1
# of episode :360, avg score : 271.9
# of episode :380, avg score : 325.8
# of episode :400, avg score : 278.9
# of episode :420, avg score : 219.3
# of episode :440, avg score : 371.1
# of episode :460, avg score : 186.8
# of episode :480, avg score : 168.7
# of episode :500, avg score : 393.6
# of episode :520, avg score : 277.9


KeyboardInterrupt: 