# PPO - continuous action space
> Continuous action spaces use normal/gaussian distribution. In this case the model output is the mean+std which define the normal distribution to use for the action selection. The action is sampled from this distribution. The probability is in this case the probabilty of the actions value under the given normal distribution (I don't know the math for that but you can look it up).
---
* 연속액션 환경에서는 정규분포로 action 을 선택

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import gym
import random
import collections

In [2]:
env = gym.make('Pendulum-v0')

In [35]:
# hyper parameters
ALPHA = .0005
EPSILON = 1
T = 10 # T step 만큼 데이터 쌓고 학습할 것
LAMBDA = .95
K = 3
GAMMA = .99
e = .05

In [36]:
class PPO(nn.Module):
    def __init__(self):
        super(PPO, self).__init__()
        self.fc1 = nn.Linear(3, 128)
        self.fc_pi = nn.Linear(128, 32)
        self.fc_pi2 = nn.Linear(32, 1)
        self.fc_v = nn.Linear(128, 32)
        self.fc_v_2 = nn.Linear(32, 1)
        self.optimizer = optim.Adam(self.parameters(), ALPHA)
    
    # 출력이 Normal 분포 -> sampling 을 해서 실수값 뽑아서 사용해야함!
    def pi(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc_pi(x))
        prob = torch.tanh(self.fc_pi2(x))
        prob = torch.distributions.normal.Normal(prob, .1) # 평균 : prob, 분산 : .01
        return prob
    
    def v(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc_v(x))
        x = self.fc_v_2(x)
        return x

In [37]:
def train_net(net, data, optimizer):
    s, a, r, s2, d, prob = batch_factory(data)

    # epoch K 만큼
    for i in range(K):
        td_target = r + GAMMA * net.v(s2)
        delta = td_target - net.v(s)
        delta = delta.detach().numpy() # 1 step advantage
        advantage_lst = []
        advantage = 0.0

        # GAE 계산
        for delta_t in delta[::-1]:
            advantage = GAMMA * LAMBDA * advantage + delta_t[0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)

        pi_a = net.pi(s).sample()
        ratio = torch.exp(torch.log(pi_a) - torch.log(prob))

        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1 - e, 1 + e) * advantage
        loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(net.v(s) , td_target.detach())

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()

In [38]:
def batch_factory(memory):
    s_, a_, r_, s2_, d_, prob_ = [], [], [], [], [], []
    for s, a, r, s2, d, p in memory:
        s_.append(s)
        a_.append([a])
        r = -100 if d else r
        r_.append([r])
        s2_.append(s2)
        d = 0 if d else 1
        d_.append([d])
        prob_.append([p])
        
    s_ = torch.tensor(s_, dtype=torch.float)
    a_ = torch.tensor(a_)
    r_ = torch.tensor(r_, dtype=torch.float)
    s2_ = torch.tensor(s2_, dtype=torch.float)
    d_ = torch.tensor(d_, dtype=torch.float)
    prob_ = torch.tensor(prob_)
    
    return s_, a_, r_, s2_, d_, prob_


In [40]:
net = PPO()
ep = 1
total_ep = 10000
gamma = .95
total_reward = 0
data = []
epsilon = .1
optimizer = optim.Adam(net.parameters(), ALPHA)

while(ep < total_ep):
    done = False
    state = env.reset()
    while(not done):
        # T step 움직인 후 clipping - T 가 너무 크면 불안정??
        for t in range(T):
            prob = net.pi(torch.from_numpy(state).float())
            action = env.action_space.sample()
            #print(action)
            state_next, reward, done, _ = env.step(action)
            total_reward += reward
            data.append((state, action, reward/100.0, state_next, done, action))
            state = state_next
            if(done):
                break

        train_net(net, data, optimizer)
        data = []
        
    ep += 1
    if(ep%10 == 0):
        print(ep, total_reward/10.0)
        total_reward = 0

10 -1237.6364386075616
20 -1043.8922137266493
30 -1150.8968600248077
40 -1185.5823157431362
50 -1138.4329671282271
60 -1295.2223683316965
70 -1162.2388699547898
80 -1104.5940491548445
90 -1165.7624006159463
100 -1168.5612641646585
110 -1286.1790166919016
120 -1277.7257610524025
130 -1159.8409188495734
140 -1222.7197668823057
150 -1238.8814048214435
160 -1240.2987936256645
170 -1159.9631975307364
180 -1189.75953261942
190 -1126.2565989860407
200 -1150.4051968558783
210 -1126.682283970653
220 -1141.9709983210064
230 -1306.4914647852163
240 -1173.3628933216
250 -1178.7797471446115
260 -1199.2069013004823
270 -1365.470908529773
280 -1097.4698497700888
290 -1254.2666520994892
300 -1323.9119439986305
310 -1258.2200364034734
320 -1192.3119760919049
330 -1296.796959760763
340 -1253.675136864936
350 -1189.8647279849552
360 -1235.6765817651453
370 -1311.5256359973712
380 -1334.6938412084808
390 -1272.7938762211584
400 -1159.373705272553
410 -1140.7989696577329
420 -1211.9791548854141
430 -1306.0

KeyboardInterrupt: 