# -*- coding: utf-8 -*-
"""DQfD agent trained with expert demonstrations for episodic tasks in OpenAI Gym.

- Author: Kh Kim, Curt Park
- Contact: kh.kim@medipixel.io, curt.park@medipixel.io
- Paper: https://arxiv.org/pdf/1704.03732.pdf (DQfD)
"""
import pickle
import time
from typing import Tuple

import numpy as np
import torch
from torch.nn.utils import clip_grad_norm_
import wandb

from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
from algorithms.common.buffer.replay_buffer import ReplayBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.dqn.agent import DQNAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class DQfDAgent(DQNAgent):
    """DQfD agent interacting with the environment.

    Attribute:
        memory (PrioritizedReplayBuffer): prioritized replay memory seeded with demos

    """

    # pylint: disable=attribute-defined-outside-init
    def _initialize(self):
        """Initialize non-common things."""
        if not self.args.test:
            # load demo replay memory
            demos = self._load_demos()

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params["N_STEP"], self.hyper_params["GAMMA"]
                )

                self.memory_n = ReplayBuffer(
                    buffer_size=self.hyper_params["BUFFER_SIZE"],
                    n_step=self.hyper_params["N_STEP"],
                    gamma=self.hyper_params["GAMMA"],
                    demo=demos_n_step,
                )
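
            # Per the DQfD paper, the demonstrations stay in the replay buffer
            # for the whole run; epsilon_d gives demo transitions an extra
            # priority bonus (see `new_priorities += eps_d` in update_model)
            # so they keep being sampled alongside self-generated data.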
            # replay memory
            self.beta = self.hyper_params["PER_BETA"]
            self.memory = PrioritizedReplayBuffer(
                self.hyper_params["BUFFER_SIZE"],
                self.hyper_params["BATCH_SIZE"],
                demo=demos,
                alpha=self.hyper_params["PER_ALPHA"],
                epsilon_d=self.hyper_params["PER_EPS_DEMO"],
            )

    def _load_demos(self) -> list:
        """Load expert's demonstrations."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demos = pickle.load(f)

        return demos

    def update_model(self) -> Tuple[torch.Tensor, ...]:
        """Train the model after each episode."""
        experiences_1 = self.memory.sample()
        weights, indices, eps_d = experiences_1[-3:]
        actions = experiences_1[1]
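        # weights: PER importance-sampling weights, indices: sampled buffer slots,
        # eps_d: per-transition demo bonus (non-zero only for demonstration data)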

        # 1 step loss
        gamma = self.hyper_params["GAMMA"]
        dq_loss_element_wise, q_values = self._get_dqn_loss(experiences_1, gamma)
        dq_loss = torch.mean(dq_loss_element_wise * weights)

        # n step loss
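        # sampling memory_n with the same indices keeps the n-step batch aligned
        # with the 1-step batch, so their losses and priorities can be blended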
        if self.use_n_step:
            experiences_n = self.memory_n.sample(indices)
            gamma = self.hyper_params["GAMMA"] ** self.hyper_params["N_STEP"]
            dq_loss_n_element_wise, q_values_n = self._get_dqn_loss(
                experiences_n, gamma
            )

            # to update loss and priorities
            q_values = 0.5 * (q_values + q_values_n)
            dq_loss_element_wise += (
                dq_loss_n_element_wise * self.hyper_params["LAMBDA1"]
            )
            dq_loss = torch.mean(dq_loss_element_wise * weights)
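
        # DQfD large-margin classification loss, applied only to demo transitions:
        #   J_E(Q) = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E)
        # where the margin l(a_E, a) is MARGIN for a != a_E and 0 for the expert action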
        # supervised loss using demo for only demo transitions
        demo_idxs = np.where(eps_d != 0.0)
        n_demo = demo_idxs[0].size
        if n_demo != 0:  # if 1 or more demos are sampled
            # get margin for each demo transition
            action_idxs = actions[demo_idxs].long()
            margin = torch.ones(q_values.size()) * self.hyper_params["MARGIN"]
            margin[demo_idxs, action_idxs] = 0.0  # demo actions have 0 margins
            margin = margin.to(device)

            # calculate supervised loss
            demo_q_values = q_values[demo_idxs, action_idxs].squeeze()
            supervised_loss = torch.max(q_values + margin, dim=-1)[0]
            supervised_loss = supervised_loss[demo_idxs] - demo_q_values
            supervised_loss = torch.mean(supervised_loss) * self.hyper_params["LAMBDA2"]
        else:  # no demo sampled
            supervised_loss = torch.zeros(1, device=device)

        # q_value regularization
        q_regular = torch.norm(q_values, 2).mean() * self.hyper_params["W_Q_REG"]

        # total loss
        loss = dq_loss + supervised_loss + q_regular
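        # i.e. J(Q) = J_DQ(1-step) + LAMBDA1 * J_n (folded into dq_loss above)
        #             + LAMBDA2 * J_E(margin) + W_Q_REG * ||Q||_2
        # (the paper regularizes the network weights; this code penalizes the
        # L2 norm of the predicted Q-values instead)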

        # train dqn
        self.dqn_optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), self.hyper_params["GRADIENT_CLIP"])
        self.dqn_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.dqn, self.dqn_target, tau)

        # update priorities in PER
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy().squeeze()
        new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
        new_priorities += eps_d
        self.memory.update_priorities(indices, new_priorities)

        # increase beta
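        # (anneal the PER importance-sampling exponent toward 1.0 so the
        # bias correction becomes exact by the end of training)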
        fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
        self.beta = self.beta + fraction * (1.0 - self.beta)

        if self.hyper_params["USE_NOISY_NET"]:
            self.dqn.reset_noise()
            self.dqn_target.reset_noise()

        return (
            loss.item(),
            dq_loss.item(),
            supervised_loss.item(),
            q_values.mean().item(),
            n_demo,
        )

    def write_log(
        self, i: int, avg_loss: np.ndarray, score: float, avg_time_cost: float
    ):
        """Write log about loss and score"""
        print(
            "[INFO] episode %d, episode step: %d, total step: %d, total score: %f\n"
            "epsilon: %f, total loss: %f, dq loss: %f, supervised loss: %f\n"
            "avg q values: %f, demo num in minibatch: %d (spent %.6f sec/step)\n"
            % (
                i,
                self.episode_step,
                self.total_step,
                score,
                self.epsilon,
                avg_loss[0],
                avg_loss[1],
                avg_loss[2],
                avg_loss[3],
                avg_loss[4],
                avg_time_cost,
            )
        )

        if self.args.log:
            wandb.log(
                {
                    "score": score,
                    "epsilon": self.epsilon,
                    "total loss": avg_loss[0],
                    "dq loss": avg_loss[1],
                    "supervised loss": avg_loss[2],
                    "avg q values": avg_loss[3],
                    "demo num in minibatch": avg_loss[4],
                    "time per each step": avg_time_cost,
                }
            )

    def pretrain(self):
        """Pretraining steps."""
        pretrain_loss = list()
        print("[INFO] Pre-Train %d step." % self.hyper_params["PRETRAIN_STEP"])
        for i_step in range(1, self.hyper_params["PRETRAIN_STEP"] + 1):
            t_begin = time.time()
            loss = self.update_model()
            t_end = time.time()
            pretrain_loss.append(loss)  # for logging

            # logging
            if i_step == 1 or i_step % 100 == 0:
                avg_loss = np.vstack(pretrain_loss).mean(axis=0)
                pretrain_loss.clear()
                self.write_log(0, avg_loss, 0.0, t_end - t_begin)
        print("[INFO] Pre-Train Complete!\n")