# Advantage Actor Critic
---
* detach 는 상수취급 해야할 때 적용

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import gym
import random

In [99]:
env = gym.make('CartPole-v1')

In [118]:
# hyper parameters
EPSILON = 1
ALPHA = .001
GAMMA = .95

In [119]:
class A2C(nn.Module):
    def __init__(self):
        super(A2C, self).__init__()
        self.fc_1 = nn.Linear(4, 128)
        self.fc_pi = nn.Linear(128, 2)
        self.fc_v = nn.Linear(128, 1)
        
    def pi(self, x):
        x = torch.relu(self.fc_1(x))
        x = torch.softmax(self.fc_pi(x), dim=0)
        return x
    
    def v(self, x):
        x = torch.relu(self.fc_1(x))
        x = self.fc_v(x)
        return x

In [120]:
def train(net, data, optimizer):
    s, a, r, s2, l_p = data
    v_target = r + GAMMA * net.v(torch.from_numpy(s2).float())
    advantage = v_target - net.v(torch.from_numpy(s).float())
    
    log_p = l_p
    loss = F.mse_loss(v_target.detach(), net.v(torch.from_numpy(s2).float())) -log_p * advantage.detach()
    optimizer.zero_grad()
    loss.mean().backward()
    optimizer.step()

In [121]:
ep = 1
total_ep = 1000
agent = A2C()
optimizer = optim.Adam(agent.parameters(), ALPHA)

while(ep < total_ep):
    done = False
    state = env.reset()
    total_reward = 0
    
    while(not done):
        
        prob = agent.pi(torch.from_numpy(state).float())
        m = Categorical(prob)
        
        if(random.random() < EPSILON):
            action = env.action_space.sample()
        else:
            action = m.sample().item()
        
        state_next, reward, done, _ = env.step(action)
        total_reward += reward

        if(done):
            reward = -100
        log_prob = torch.log(prob[action])
        transition = (state, action, reward, state_next, log_prob)
        train(agent, transition, optimizer)

        state = state_next
        
        if(done):
            ep += 1
            EPSILON = 1 / ((ep / 100) + 5)
            print(total_reward)

36.0
19.0
42.0
9.0
21.0
15.0
12.0
17.0
13.0
16.0
15.0
29.0
31.0
38.0
35.0
58.0
32.0
16.0
22.0
35.0
12.0
10.0
20.0
26.0
15.0
36.0
23.0
22.0
60.0
30.0
65.0
13.0
40.0
13.0
65.0
44.0
13.0
15.0
12.0
14.0
53.0
18.0
60.0
39.0
23.0
24.0
19.0
55.0
25.0
28.0
22.0
47.0
32.0
40.0
18.0
41.0
40.0
39.0
30.0
21.0
27.0
42.0
17.0
111.0
29.0
65.0
31.0
13.0
25.0
77.0
29.0
20.0
30.0
101.0
77.0
49.0
36.0
44.0
42.0
89.0
48.0
41.0
54.0
57.0
82.0
74.0
39.0
126.0
85.0
56.0
68.0
31.0
80.0
72.0
94.0
87.0
55.0
93.0
32.0
52.0
32.0
26.0
28.0
68.0
62.0
102.0
43.0
90.0
127.0
129.0
70.0
27.0
98.0
81.0
63.0
84.0
74.0
66.0
51.0
182.0
189.0
107.0
73.0
179.0
230.0
409.0
122.0
244.0
69.0
128.0
343.0
85.0
147.0
97.0
42.0
216.0
116.0
124.0
227.0
122.0
143.0
53.0
62.0
104.0
122.0
222.0
608.0
130.0
149.0
230.0
109.0
50.0
308.0
244.0
85.0
110.0
240.0
76.0
115.0
157.0
118.0
233.0
24.0
173.0
27.0
125.0
59.0
100.0
63.0
226.0
231.0
36.0
20.0
34.0
25.0
21.0
27.0
26.0
50.0
26.0
120.0
107.0
120.0
140.0
107.0
117.0
108.0
110.0
30.0
116.

KeyboardInterrupt: 