#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import logging
from typing import Optional, Union

import gym
import numpy as np
import reagent.types as rlt
import torch
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
from reagent.gym.types import PostStep
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.training.rl_dataset import RLDataset
from reagent.training.trainer import Trainer


logger = logging.getLogger(__name__)
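
# Note: a PostStep (imported from reagent.gym.types) is a callback the agent
# invokes after every environment transition. All three factories in this
# module return closures with the signature
#   (obs, actor_output, reward, terminal, possible_actions_mask) -> None.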
def add_replay_buffer_post_step(replay_buffer: ReplayBuffer):
    """
    Simply add transitions to replay_buffer.
    """

    def post_step(
        obs: np.ndarray,
        actor_output: rlt.ActorOutput,
        reward: float,
        terminal: bool,
        possible_actions_mask: Optional[torch.Tensor],
    ) -> None:
        action = actor_output.action.numpy()
        log_prob = actor_output.log_prob.numpy()
        if possible_actions_mask is None:
            possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
        possible_actions_mask = possible_actions_mask.numpy()
        replay_buffer.add(
            obs, action, reward, terminal, possible_actions_mask, log_prob.item()
        )

    return post_step
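

# Illustrative usage (not part of the original module): a minimal rollout
# sketch showing where a PostStep callback fits in an environment loop. The
# `policy` object and its `act(obs) -> rlt.ActorOutput` method are assumptions
# for illustration; ReAgent's actual Agent/Policy classes may differ, and this
# uses the classic gym 4-tuple step API.
def _example_rollout(env: gym.Env, policy, post_step: PostStep, max_steps: int = 200):
    obs = env.reset()
    for _ in range(max_steps):
        actor_output = policy.act(obs)  # assumed policy interface
        if isinstance(env.action_space, gym.spaces.Discrete):
            # One-hot action tensor -> integer index, mirroring the argmax
            # convention used by log_data_post_step below.
            env_action = actor_output.action.argmax().item()
        else:
            env_action = actor_output.action.numpy()
        next_obs, reward, terminal, _info = env.step(env_action)
        post_step(obs, actor_output, float(reward), bool(terminal), None)
        obs = next_obs
        if terminal:
            break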
def train_with_replay_buffer_post_step(
    replay_buffer: ReplayBuffer,
    trainer: Trainer,
    training_freq: int,
    batch_size: int,
    replay_burnin: Optional[int] = None,
    trainer_preprocessor=None,
    device: Union[str, torch.device] = "cpu",
) -> PostStep:
    """ Called in post_step of agent to train based on replay buffer (RB).

    Args:
        trainer: responsible for having a .train method to train the model
        trainer_preprocessor: format RB output for trainer.train
        training_freq: how many steps in between trains
        batch_size: how big of a batch to sample
        replay_burnin: optional requirement for minimum size of RB before
            training begins (i.e. burn in this many frames)
    """
    if isinstance(device, str):
        device = torch.device(device)

    _num_steps = 0
    size_req = batch_size
    if replay_burnin is not None:
        size_req = max(size_req, replay_burnin)

    if trainer_preprocessor is None:
        trainer_preprocessor = make_replay_buffer_trainer_preprocessor(trainer, device)

    def post_step(
        obs: np.ndarray,
        actor_output: rlt.ActorOutput,
        reward: float,
        terminal: bool,
        possible_actions_mask: Optional[torch.Tensor],
    ) -> None:
        nonlocal _num_steps

        action = actor_output.action.numpy()
        log_prob = actor_output.log_prob.numpy()
        if possible_actions_mask is None:
            possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
        possible_actions_mask = possible_actions_mask.numpy()
        replay_buffer.add(
            obs, action, reward, terminal, possible_actions_mask, log_prob.item()
        )

        if replay_buffer.size >= size_req and _num_steps % training_freq == 0:
            train_batch = replay_buffer.sample_transition_batch_tensor(
                batch_size=batch_size
            )
            preprocessed_batch = trainer_preprocessor(train_batch)
            trainer.train(preprocessed_batch)
        _num_steps += 1

    return post_step
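

# Illustrative usage (not part of the original module): a training PostStep
# that samples 32-transition batches every 10 environment steps, once 1000
# burn-in transitions have accumulated. The ReplayBuffer and Trainer instances
# are assumed to be constructed elsewhere; their constructors are not shown.
def _example_training_post_step(
    replay_buffer: ReplayBuffer, trainer: Trainer
) -> PostStep:
    return train_with_replay_buffer_post_step(
        replay_buffer=replay_buffer,
        trainer=trainer,
        training_freq=10,
        batch_size=32,
        replay_burnin=1000,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )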
def log_data_post_step(dataset: RLDataset, mdp_id: str, env: gym.Env) -> PostStep:
    sequence_number = 0

    def post_step(
        obs: np.ndarray,
        actor_output: rlt.ActorOutput,
        reward: float,
        terminal: bool,
        possible_actions_mask: Optional[torch.Tensor],
    ) -> None:
        """ Log data into dataset """
        nonlocal sequence_number

        if isinstance(env.action_space, gym.spaces.Discrete):
            # TimelineOperator expects str for discrete action
            action = str(actor_output.action.argmax().item())
            if possible_actions_mask is None:
                possible_actions_mask = torch.ones_like(actor_output.action).to(
                    torch.bool
                )
            if terminal:
                possible_actions_mask = torch.zeros_like(actor_output.action).to(
                    torch.bool
                )
        elif isinstance(env.action_space, gym.spaces.Box):
            # TimelineOperator expects map<long, double> for continuous action
            assert actor_output.action.dim() == 1, f"action dim > 1 in {actor_output}"
            action = {
                i: actor_output.action[i].item()
                for i in range(actor_output.action.size(0))
            }
        else:
            raise NotImplementedError(f"{env.action_space} not supported!")

        # TODO: make output of policy the desired type already (which means
        # altering RB logic to store scalar types). What to do about continuous?
        action_prob = actor_output.log_prob.exp().item()
        possible_actions = None  # TODO: this shouldn't be None if env passes it
        time_diff = 1  # TODO: should this be hardcoded?
        # Some environments return a numpy scalar instead of a float
        reward = float(reward)
        dataset.insert_pre_timeline_format(
            mdp_id=mdp_id,
            sequence_number=sequence_number,
            state=obs,
            action=action,
            reward=reward,
            possible_actions=possible_actions,
            time_diff=time_diff,
            action_probability=action_prob,
            possible_actions_mask=possible_actions_mask,
        )
        sequence_number += 1

    return post_step
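

# Illustrative usage (not part of the original module): collecting one episode
# of logged data with the rollout sketch above. RLDataset construction and
# persistence are assumed to happen elsewhere; each episode should get a fresh
# mdp_id, and sequence_number restarts at 0 because log_data_post_step closes
# over a new counter on every call.
def _example_log_episode(dataset: RLDataset, env: gym.Env, policy, mdp_id: str):
    post_step = log_data_post_step(dataset, mdp_id, env)
    _example_rollout(env, policy, post_step)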