# PPO
---
![title](https://spinningup.openai.com/en/latest/_images/math/0a399dc49e3b45664a7edaf485ab5c23a7282f43.svg)
---
## torch 신경망 주의할 것
* 업데이트할 파라미터 정확히 지정하기 - `detach()`로 학습 대상이 아닌 값(예: 타깃 값)을 계산 그래프에서 확실하게 분리하기

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import gym
import random
import collections

In [12]:
# Module-level Gym environment ('CartPole-v1'); the rollout loop at the bottom
# of this file resets and steps it through the name `env`.
env = gym.make('CartPole-v1')

In [13]:
class PPO(nn.Module):
    """Actor-critic network with a shared first layer and two heads.

    ``fc1`` maps a 4-feature input to 32 features; ``fc_pi`` turns those into
    2 action scores and ``fc_v`` into a single state-value estimate.

    NOTE(review): no activation is applied after ``fc1``, so both heads are
    affine functions of the input -- confirm that this is intended.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 32)
        self.fc_pi = nn.Linear(32, 2)
        self.fc_v = nn.Linear(32, 1)

    def pi(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: Float tensor whose last dimension has size 4; either a single
                observation of shape ``(4,)`` or a batch of shape ``(N, 4)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 2 that sums to 1.
        """
        x = self.fc1(x)
        # Normalise over the last (action) axis. With dim=0 a batched input
        # was normalised across the batch instead of across the actions.
        return torch.softmax(self.fc_pi(x), dim=-1)

    def v(self, x):
        """Return the state-value estimate for ``x``.

        Args:
            x: Float tensor whose last dimension has size 4.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        x = self.fc1(x)
        return self.fc_v(x)

In [28]:
def batch_factory(memory):
    """Convert a sequence of transitions into column tensors.

    Args:
        memory: Iterable of 6-tuples
            ``(state, action, reward, next_state, done, action_prob)``.

    Returns:
        Tuple ``(states, actions, rewards, next_states, dones, probs)``.
        ``states``, ``next_states`` and ``dones`` are float tensors.
        ``actions``, ``rewards`` and ``dones`` carry a trailing dimension of
        size 1 because every scalar is wrapped in a one-element list, whereas
        ``probs`` is built from the bare values and therefore has one
        dimension fewer.
    """
    states, actions, rewards, next_states, dones, probs = ([] for _ in range(6))

    for state, action, reward, next_state, done, action_prob in memory:
        states.append(state)
        next_states.append(next_state)
        # Scalars are wrapped so the resulting tensors are column vectors.
        actions.append([action])
        rewards.append([reward])
        dones.append([done])
        # The stored probability is intentionally left unwrapped.
        probs.append(action_prob)

    return (
        torch.tensor(states, dtype=torch.float),
        torch.tensor(actions),
        torch.tensor(rewards),
        torch.tensor(next_states, dtype=torch.float),
        torch.tensor(dones, dtype=torch.float),
        torch.tensor(probs),
    )

In [29]:
def stack(memory, item):
    """Append ``item`` to ``memory`` in place by calling ``memory.append``."""
    memory.append(item)

In [34]:
def train(memory, net, optimizer=None, lmbda=None, learning_rate=5e-4):
    """Run one clipped-surrogate PPO update on the transitions in ``memory``.

    The discount factor ``gamma`` and the clipping range ``e`` are read from
    module scope when the function is called.

    Args:
        memory: Sequence of transitions accepted by ``batch_factory``; the
            order is treated as the time order of a single trajectory.
        net: Model exposing ``pi(states)``, ``v(states)`` and ``parameters()``.
        optimizer: Optimizer to step. When omitted, ``net.optimizer`` is used
            if present; otherwise an Adam optimizer is created and stored on
            ``net.optimizer`` so that its state survives across calls.
        lmbda: GAE decay factor. When omitted, a module-level ``lmbda`` is
            used if one is defined, else 0.95.
        learning_rate: Learning rate used only when an optimizer is created.

    Returns:
        The scalar loss as a Python float, or ``None`` when ``memory`` is
        empty (nothing to train on).
    """
    if not memory:
        return None

    s, a, r, s2, d, prob = batch_factory(memory)
    # Force the stored probabilities into a column so the ratio below stays
    # (N, 1); an (N,) vector would broadcast against pi_a to (N, N).
    prob = prob.reshape(-1, 1)

    if lmbda is None:
        # Fall back to a default instead of raising NameError when no
        # module-level `lmbda` exists.
        lmbda = globals().get("lmbda", 0.95)

    if optimizer is None:
        optimizer = getattr(net, "optimizer", None)
        if optimizer is None:
            optimizer = optim.Adam(net.parameters(), lr=learning_rate)
            net.optimizer = optimizer

    with torch.no_grad():
        # Do not bootstrap from the next state once the episode has ended.
        target = r + gamma * net.v(s2) * (1.0 - d)
        # TD error is target minus estimate, not the other way round.
        deltas = (target - net.v(s)).reshape(-1).tolist()

    # Generalised advantage estimation, accumulated backwards through time.
    advantage_lst = []
    running = 0.0
    for delta_t in reversed(deltas):
        running = gamma * lmbda * running + delta_t
        advantage_lst.append([running])
    advantage_lst.reverse()
    advantage = torch.tensor(advantage_lst, dtype=torch.float)

    pi = net.pi(s)
    pi_a = pi.gather(1, a)
    ratio = pi_a / prob

    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1 - e, 1 + e) * advantage
    # Reduce the policy term to a scalar so the loss can be back-propagated.
    loss = -torch.min(surr1, surr2).mean() + F.mse_loss(net.v(s), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [36]:
# Hyperparameters and the model. `gamma`, `lmbda` and `e` are module-level
# names so that train() can look them up when it runs.
net = PPO()
ep = 1
total_ep = 100
gamma = .95   # discount factor
lmbda = .95   # GAE decay factor (was never defined before)
e = .1        # clipping range for the probability ratio

# One iteration per episode: roll out until `done`, then train on the episode.
# The calls below use the Gym API in which reset() returns the observation and
# step() returns a 4-tuple.
while ep < total_ep:
    done = False
    state = env.reset()
    buffer = collections.deque()
    while not done:
        prob = net.pi(torch.from_numpy(state).float())
        action = Categorical(prob).sample().item()
        state_next, reward, done, _ = env.step(action)
        # Store the probability of the action that was actually taken.
        stack(buffer, (state, action, reward, state_next, done, prob[action].item()))
        # Advance the observation; without this the policy kept acting on,
        # and storing, the initial state for the whole episode.
        state = state_next
    train(buffer, net)
    ep += 1