# PPO.py
from collections import namedtuple
from typing import List, Tuple, Dict
import gym
import numpy as np
import torch
import torch.nn as nn
from torch import optim, distributions
from torch.distributions import Categorical
from utils import RLAgent, train, torch_device, estimate_advantages, normalize, flatten
from utils.agents import MemoryAgent

# Hyperparameters
epochs = 10000
num_rollouts = 10
actor_hidden = 32
critic_hidden = 32
gamma = 0.99
lam = 0.95
train_pi_iters = 50
train_v_iters = 50
epsilon = 0.2 # clip epsilon
actor_lr = 3e-4
critic_lr = 1e-3
target_kl = 0.01
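
# Discounted cumulative sum over a sequence, computed backwards:
# out[i] = arr[i] + discount * out[i + 1], with out[len(arr)] = last.
# Used below for both the GAE advantages and the rewards-to-go.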
def discount_cumsum(arr, discount, last=0):
    discounted = [0.] * len(arr)
    for i in reversed(range(len(arr))):
        last = discounted[i] = arr[i] + discount * last
    return discounted
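
# PPO agent with a categorical (discrete-action) policy. The actor and critic are
# small two-layer MLPs; trajectories are accumulated by MemoryAgent and consumed
# in update().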
class PPOAgent(MemoryAgent):
    def __init__(self, env):
        super().__init__(env)
        obs_size = env.observation_space.shape[0]
        num_actions = env.action_space.n

        self.actor = nn.Sequential(nn.Linear(obs_size, actor_hidden),
                                   nn.ReLU(),
                                   nn.Linear(actor_hidden, num_actions))
        self.critic = nn.Sequential(nn.Linear(obs_size, critic_hidden),
                                    nn.ReLU(),
                                    nn.Linear(critic_hidden, 1))
        self.opt_actor = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.values = []
        self.advantages = []
        self.discounted = []

    def get_action(self, state: np.ndarray) -> int:
        return Categorical(logits=self.actor(torch.from_numpy(state).float().unsqueeze(0))).sample().item()

    def save_step(self, action: int, reward: float, next_state: np.ndarray) -> None:
        super().save_step(action, reward, next_state)
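
    # At the end of each trajectory, compute value estimates, GAE advantages
    # (delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), discounted by gamma * lam)
    # and rewards-to-go, and append them to the rollout buffers.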
    def on_trajectory_finished(self) -> None:
        # Estimate advantages and value targets
        rewards = self.current_rewards
        states = torch.as_tensor(self.current_states).float()
        values = self.critic(states).flatten().tolist()

        # GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), bootstrapping V = 0 past the end
        next_values = values[1:] + [0.]
        deltas = [rewards[i] + gamma * next_values[i] - values[i] for i in range(len(rewards))]
        advantages = discount_cumsum(deltas, gamma * lam)

        # Rewards-to-go
        discounted = discount_cumsum(self.current_rewards, gamma)

        self.values += values
        self.advantages += advantages
        self.discounted += discounted

        super().on_trajectory_finished()

    def reset_memory(self):
        super().reset_memory()
        self.values = []
        self.advantages = []
        self.discounted = []

    @property
    def tensored_data(self) -> Dict[str, torch.Tensor]:
        values = torch.as_tensor(self.values, dtype=torch.float)
        advantages = normalize(torch.as_tensor(self.advantages, dtype=torch.float))
        discounted = torch.as_tensor(self.discounted, dtype=torch.float)
        return {**super().tensored_data,
                'values': values, 'advantages': advantages, 'discounted': discounted}
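
    # Convenience helpers: update_critic performs one MSE step on a batch of
    # value errors, and model returns the policy distribution plus value
    # estimate for a batch of states.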
    def update_critic(self, advantages):
        loss = .5 * (advantages ** 2).mean()  # MSE
        self.opt_critic.zero_grad()
        loss.backward()
        self.opt_critic.step()

    def model(self, state):
        dist = distributions.Categorical(logits=self.actor(state))
        value = self.critic(state).squeeze(1)
        return dist, value
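
    # PPO update: maximise the clipped surrogate objective
    #   L = E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)]
    # for up to train_pi_iters steps, stopping early once the approximate KL
    # between old and new policy exceeds target_kl, then regress the critic
    # onto the rewards-to-go for train_v_iters steps.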
    def update(self) -> None:
        self.actor.to(torch_device)
        self.critic.to(torch_device)

        data = self.tensored_data
        for t in data.values():
            assert not t.requires_grad  # all of these tensors are treated as constants!
        data = {k: v.to(torch_device) for k, v in data.items()}

        actions = data['actions']
        states = data['states']
        advantages = data['advantages']

        # Log-probabilities under the old (pre-update) policy, held constant
        logp_old = Categorical(logits=self.actor(states)).log_prob(actions).detach()

        for i in range(train_pi_iters):
            logp = Categorical(logits=self.actor(states)).log_prob(actions)
            ratio = (logp - logp_old).exp()
            clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
            loss = -torch.min(ratio * advantages, clipped * advantages).mean()

            kl = (logp_old - logp).mean().item()  # approximate KL divergence
            if kl > target_kl:  # Early stopping
                break

            self.opt_actor.zero_grad()
            loss.backward()
            self.opt_actor.step()

        for _ in range(train_v_iters):
            # Squeeze the critic output so it matches the shape of the targets
            loss = ((self.critic(data['states']).squeeze(-1) - data['discounted']) ** 2).mean()
            self.opt_critic.zero_grad()
            loss.backward()
            self.opt_critic.step()

        self.reset_memory()
        self.actor.to('cpu')
        self.critic.to('cpu')
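
# Train a PPO agent on CartPole-v0 with the hyperparameters defined above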
train(gym.make('CartPole-v0'), PPOAgent, epochs=epochs, num_rollouts=num_rollouts)