-
Notifications
You must be signed in to change notification settings - Fork 92
/
agent_config.py
219 lines (181 loc) · 7.5 KB
/
agent_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""util function to create a tf_agent."""
from typing import Any, Callable, Dict
import abc
import gin
import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
from tf_agents.agents.dqn import dqn_agent
from tf_agents.agents.ppo import ppo_agent
from tf_agents.specs import tensor_spec
from tf_agents.typing import types
from compiler_opt.rl import constant_value_network
from compiler_opt.rl.distributed import agent as distributed_ppo_agent
class AgentConfig(metaclass=abc.ABCMeta):
  """Abstract base pairing agent construction with policy-info parsing hooks.

  An instance keeps the time step spec and the action spec it was constructed
  with and exposes them through read-only properties. Subclasses must provide
  `create_agent`; the two policy-info hooks return empty dicts here and may be
  overridden.
  """

  def __init__(self, *, time_step_spec: types.NestedTensorSpec,
               action_spec: types.NestedTensorSpec):
    """Stores the two specs; both arguments are keyword-only.

    Args:
      time_step_spec: the time step spec exposed via `time_step_spec`.
      action_spec: the action spec exposed via `action_spec`.
    """
    self._action_spec = action_spec
    self._time_step_spec = time_step_spec

  @property
  def time_step_spec(self):
    """The time step spec given at construction."""
    return self._time_step_spec

  @property
  def action_spec(self):
    """The action spec given at construction."""
    return self._action_spec

  @abc.abstractmethod
  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
                   policy_network: types.Network) -> tf_agent.TFAgent:
    """Builds the agent; concrete configs must override this.

    Args:
      preprocessing_layers: preprocessing layers handed to the network.
      policy_network: the network (or network constructor) used by the agent.

    Raises:
      NotImplementedError: always, when the base implementation is invoked.
    """
    raise NotImplementedError()

  def get_policy_info_parsing_dict(
      self) -> Dict[str, tf.io.FixedLenSequenceFeature]:
    """Returns the feature-parsing dict for policy info (empty by default)."""
    no_features: Dict[str, tf.io.FixedLenSequenceFeature] = {}
    return no_features

  # pylint: disable=unused-argument
  def process_parsed_sequence_and_get_policy_info(
      self, parsed_sequence: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Extracts the policy_info from a parsed sequence.

    The base implementation ignores its argument and yields an empty dict.

    Args:
      parsed_sequence: A dict from feature_name to feature_value parsed from TF
        SequenceExample.

    Returns:
      A nested policy_info for given agent.
    """
    no_policy_info: Dict[str, Dict[str, Any]] = {}
    return no_policy_info
@gin.configurable
def create_agent(agent_config: AgentConfig,
                 preprocessing_layer_creator: Callable[[types.TensorSpec],
                                                       tf.keras.layers.Layer],
                 policy_network: types.Network):
  """Gin-configurable entry point that forwards to AgentConfig.create_agent.

  Exists because class members aren't gin-configurable.

  Args:
    agent_config: the config whose `create_agent` method builds the agent.
    preprocessing_layer_creator: callable applied to every element of the
      observation spec to produce the matching preprocessing layer.
    policy_network: passed through unchanged to `agent_config.create_agent`.

  Returns:
    Whatever `agent_config.create_agent` returns.
  """
  observation_spec = agent_config.time_step_spec.observation
  # Mirror the (possibly nested) observation structure with one layer per spec.
  layers_per_observation = tf.nest.map_structure(preprocessing_layer_creator,
                                                 observation_spec)
  return agent_config.create_agent(layers_per_observation, policy_network)
@gin.configurable(module='agents')
class BCAgentConfig(AgentConfig):
  """Configuration that builds a Behavioral Cloning agent."""

  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
                   policy_network: types.Network) -> tf_agent.TFAgent:
    """Builds a BehavioralCloningAgent around a freshly created network.

    Args:
      preprocessing_layers: forwarded to `policy_network` as the
        `preprocessing_layers` keyword.
      policy_network: callable invoked with the observation spec and the
        action spec to create the cloning network (named 'QNetwork').

    Returns:
      The BehavioralCloningAgent, created with `num_outer_dims=2`.
    """
    time_step_spec = self.time_step_spec
    cloning_network = policy_network(
        time_step_spec.observation,
        self.action_spec,
        preprocessing_layers=preprocessing_layers,
        name='QNetwork')
    bc_agent = behavioral_cloning_agent.BehavioralCloningAgent(
        self.time_step_spec,
        self.action_spec,
        cloning_network=cloning_network,
        num_outer_dims=2)
    return bc_agent
@gin.configurable(module='agents')
class DQNAgentConfig(AgentConfig):
  """Configuration that builds a DQN agent."""

  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
                   policy_network: types.Network) -> tf_agent.TFAgent:
    """Builds a DqnAgent around a freshly created Q network.

    Args:
      preprocessing_layers: forwarded to `policy_network` as the
        `preprocessing_layers` keyword.
      policy_network: callable invoked with the observation spec and the
        action spec to create the Q network (named 'QNetwork').

    Returns:
      The DqnAgent using the created network as its `q_network`.
    """
    time_step_spec = self.time_step_spec
    q_network = policy_network(
        time_step_spec.observation,
        self.action_spec,
        preprocessing_layers=preprocessing_layers,
        name='QNetwork')
    agent = dqn_agent.DqnAgent(
        self.time_step_spec, self.action_spec, q_network=q_network)
    return agent
@gin.configurable(module='agents')
class PPOAgentConfig(AgentConfig):
  """Configuration that builds a PPO/Reinforce agent."""

  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
                   policy_network: types.Network) -> tf_agent.TFAgent:
    """Builds a PPOAgent with a new actor network and a constant value net.

    Args:
      preprocessing_layers: forwarded to `policy_network` as the
        `preprocessing_layers` keyword.
      policy_network: callable invoked with the observation spec and the
        action spec to create the actor network (named
        'ActorDistributionNetwork').

    Returns:
      The PPOAgent wired with the actor network and a ConstantValueNetwork
      built from the observation spec.
    """
    actor_net = policy_network(
        self.time_step_spec.observation,
        self.action_spec,
        preprocessing_layers=preprocessing_layers,
        name='ActorDistributionNetwork')
    value_net = constant_value_network.ConstantValueNetwork(
        self.time_step_spec.observation, name='ConstantValueNetwork')
    agent = ppo_agent.PPOAgent(
        self.time_step_spec,
        self.action_spec,
        actor_net=actor_net,
        value_net=value_net)
    return agent

  def get_policy_info_parsing_dict(
      self) -> Dict[str, tf.io.FixedLenSequenceFeature]:
    """Returns the float32 sequence features holding the policy info.

    With a discrete action spec there is a single logits feature whose shape
    is `maximum - minimum + 1` of the action spec. Otherwise there are two
    scalar-shaped features, a scale and a loc.
    """
    if not tensor_spec.is_discrete(self._action_spec):
      scalar_feature_shape = ()
      return {
          'NormalProjectionNetwork_scale':
              tf.io.FixedLenSequenceFeature(
                  shape=scalar_feature_shape, dtype=tf.float32),
          'NormalProjectionNetwork_loc':
              tf.io.FixedLenSequenceFeature(
                  shape=scalar_feature_shape, dtype=tf.float32)
      }

    # Deliberately a plain value rather than a 1-tuple, as the feature shape.
    logits_shape = (
        self._action_spec.maximum - self._action_spec.minimum + 1)
    logits_feature = tf.io.FixedLenSequenceFeature(
        shape=logits_shape, dtype=tf.float32)
    return {'CategoricalProjectionNetwork_logits': logits_feature}

  def process_parsed_sequence_and_get_policy_info(
      self, parsed_sequence: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Moves the distribution parameters out of `parsed_sequence`.

    The policy-info entries are read first and only then deleted from
    `parsed_sequence`, so the argument is modified in place.

    Args:
      parsed_sequence: dict of parsed features; must contain the keys listed
        by `get_policy_info_parsing_dict` for the current action spec.

    Returns:
      `{'dist_params': {...}}` holding 'logits' for a discrete action spec,
      or 'scale' and 'loc' otherwise.
    """
    if tensor_spec.is_discrete(self._action_spec):
      logits_key = 'CategoricalProjectionNetwork_logits'
      dist_params = {'logits': parsed_sequence[logits_key]}
      consumed_keys = (logits_key,)
    else:
      scale_key = 'NormalProjectionNetwork_scale'
      loc_key = 'NormalProjectionNetwork_loc'
      dist_params = {
          'scale': parsed_sequence[scale_key],
          'loc': parsed_sequence[loc_key]
      }
      consumed_keys = (scale_key, loc_key)

    # Drop the consumed entries only after every one of them was read.
    for consumed_key in consumed_keys:
      del parsed_sequence[consumed_key]
    return {'dist_params': dist_params}
@gin.configurable(module='agents')
class DistributedPPOAgentConfig(PPOAgentConfig):
  """Configuration for the distributed PPO/Reinforce agent."""

  def _create_agent_implt(self, preprocessing_layers: tf.keras.layers.Layer,
                          policy_network: types.Network) -> tf_agent.TFAgent:
    """Builds an MLGOPPOAgent with fixed hyperparameters.

    NOTE(review): this method is not named `create_agent`, so as far as this
    class shows it does not override the inherited `create_agent` - confirm
    that this is intended.

    Args:
      preprocessing_layers: forwarded to `policy_network` as the
        `preprocessing_layers` keyword.
      policy_network: callable invoked with the observation spec and the
        action spec to create the actor network; it also receives a Keras
        `Concatenate` layer as `preprocessing_combiner`.

    Returns:
      The MLGOPPOAgent wired with the actor network, a ConstantValueNetwork
      built from the observation spec, and an Adam optimizer.
    """
    actor_net = policy_network(
        self.time_step_spec.observation,
        self.action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        name='ActorDistributionNetwork')
    value_net = constant_value_network.ConstantValueNetwork(
        self.time_step_spec.observation, name='ConstantValueNetwork')

    # Created after both networks, matching the original construction order.
    adam = tf.keras.optimizers.Adam(learning_rate=4e-4, epsilon=1e-5)
    loss_and_clipping_settings = dict(
        value_pred_loss_coef=0.0,
        entropy_regularization=0.01,
        importance_ratio_clipping=0.2,
        discount_factor=1.0,
        gradient_clipping=1.0,
        debug_summaries=False,
        value_clipping=None,
        aggregate_losses_across_replicas=True,
        loss_scaling_factor=1.0)
    return distributed_ppo_agent.MLGOPPOAgent(
        self.time_step_spec,
        self.action_spec,
        optimizer=adam,
        actor_net=actor_net,
        value_net=value_net,
        **loss_and_clipping_settings)