a3c.py
from .a2c import *
from machin.parallel.server import PushPullGradServer
from torch.optim import Adam


class A3C(A2C):
    """
    A3C framework.
    """
    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 criterion: Callable,
                 grad_server: Tuple[PushPullGradServer,
                                    PushPullGradServer],
                 *_,
                 entropy_weight: float = None,
                 value_weight: float = 0.5,
                 gradient_max: float = np.inf,
                 gae_lambda: float = 1.0,
                 discount: float = 0.99,
                 update_times: int = 50,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Buffer = None,
                 visualize: bool = False,
                 **__):
"""
See Also:
:class:`.A2C`
Note:
A3C algorithm relies on parameter servers to synchronize
parameters of actor and critic models across samplers (
interact with environment) and trainers (using samples
to train.
The parameter server type :class:`.PushPullGradServer`
used here utilizes gradients calculated by trainers:
1. perform a "sum" reduction process on the collected
gradients, then apply this reduced gradient to the model
managed by its primary reducer
2. push the parameters of this updated managed model to
a ordered key-value server so that all processes,
including samplers and trainers, can access the updated
parameters.
The ``grad_servers`` argument is a pair of accessors to
two :class:`.PushPullGradServerImpl` class.
Args:
actor: Actor network module.
critic: Critic network module.
optimizer: Optimizer used to optimize ``actor`` and ``critic``.
criterion: Criterion used to evaluate the value loss.
grad_server: Custom gradient sync server accessors, the first
server accessor is for actor, and the second one is for critic.
entropy_weight: Weight of entropy in your loss function, a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight: Weight of critic value loss.
gradient_max: Maximum gradient.
gae_lambda: :math:`\\lambda` used in generalized advantage
estimation.
discount: :math:`\\gamma` used in the bellman function.
update_times: Number of update iterations per sample period. Buffer
will be cleared after ``update()``
replay_size: Replay buffer size. Not compatible with
``replay_buffer``.
replay_device: Device where the replay buffer locates on, Not
compatible with ``replay_buffer``.
replay_buffer: Custom replay buffer.
visualize: Whether visualize the network flow in the first pass.
"""
        # Adam is just a placeholder here; the actual optimizer is
        # set in the parameter servers.
        super(A3C, self).__init__(actor, critic, Adam, criterion,
                                  entropy_weight=entropy_weight,
                                  value_weight=value_weight,
                                  gradient_max=gradient_max,
                                  gae_lambda=gae_lambda,
                                  discount=discount,
                                  update_times=update_times,
                                  replay_size=replay_size,
                                  replay_device=replay_device,
                                  replay_buffer=replay_buffer,
                                  visualize=visualize)

        # Disable local stepping: gradients are applied by the servers,
        # not by the local optimizers.
        self.actor_optim.step = lambda: None
        self.critic_optim.step = lambda: None

        self.actor_grad_server, self.critic_grad_server = \
            grad_server[0], grad_server[1]
        self.is_syncing = True

    def set_sync(self, is_syncing):
        # Enable or disable automatic parameter pulling.
        self.is_syncing = is_syncing

    def manual_sync(self):
        # Pull the latest actor and critic parameters from the servers.
        self.actor_grad_server.pull(self.actor)
        self.critic_grad_server.pull(self.critic)

    def act(self, state: Dict[str, Any], **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super(A3C, self).act(state)

    def _eval_act(self,
                  state: Dict[str, Any],
                  action: Dict[str, Any],
                  **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super(A3C, self)._eval_act(state, action)

    def _criticize(self, state: Dict[str, Any], *_, **__):
        # DOC INHERITED
        if self.is_syncing:
            self.critic_grad_server.pull(self.critic)
        return super(A3C, self)._criticize(state)

    def update(self,
               update_value=True,
               update_policy=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        # Compute gradients locally without pulling parameters mid-update,
        # then push the accumulated gradients to the servers.
        org_sync = self.is_syncing
        self.is_syncing = False
        super(A3C, self).update(update_value, update_policy,
                                concatenate_samples)
        self.is_syncing = org_sync
        self.actor_grad_server.push(self.actor)
        self.critic_grad_server.push(self.critic)
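

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes two PushPullGradServer
# accessors, ``actor_server`` and ``critic_server``, have already been created
# by the launching code (e.g. via machin's distributed world and server
# helpers), and that ``SomeActor``/``SomeCritic`` are user-defined network
# modules; these names are placeholders, not part of this module.
#
#     actor, critic = SomeActor(), SomeCritic()
#     a3c = A3C(actor, critic,
#               nn.MSELoss(reduction="sum"),
#               (actor_server, critic_server))
#
#     # Sampler processes: syncing is enabled by default, so ``act`` pulls
#     # the latest parameters from the server before every forward pass.
#     action = a3c.act({"state": some_state})
#
#     # Trainer processes: after storing collected transitions in the replay
#     # buffer, ``update`` computes gradients locally (the local optimizer
#     # steps are no-ops) and pushes them to the servers, which reduce them,
#     # apply them to the managed models, and publish the new parameters.
#     a3c.update()
#
#     # When automatic syncing is disabled, parameters can still be pulled
#     # manually:
#     a3c.set_sync(False)
#     a3c.manual_sync()
# ----------------------------------------------------------------------------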