-
Notifications
You must be signed in to change notification settings - Fork 455
/
ppo_trainer.py
1172 lines (985 loc) · 40.7 KB
/
ppo_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import os
import random
import time
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import tqdm
from gym import spaces
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from habitat import Config, VectorEnv, logger
from habitat.tasks.rearrange.rearrange_sensors import GfxReplayMeasure
from habitat.tasks.rearrange.utils import write_gfx_replay
from habitat.utils import profiling_wrapper
from habitat.utils.render_wrapper import overlay_frame
from habitat.utils.visualizations.utils import observations_to_image
from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.construct_vector_env import construct_envs
from habitat_baselines.common.obs_transformers import (
apply_obs_transforms_batch,
apply_obs_transforms_obs_space,
get_active_obs_transforms,
)
from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.common.tensorboard_utils import (
TensorboardWriter,
get_writer,
)
from habitat_baselines.rl.ddppo.algo import DDPPO
from habitat_baselines.rl.ddppo.ddp_utils import (
EXIT,
get_distrib_size,
init_distrib_slurm,
is_slurm_batch_job,
load_resume_state,
rank0_only,
requeue_job,
save_resume_state,
)
from habitat_baselines.rl.ddppo.policy import ( # noqa: F401.
PointNavResNetPolicy,
)
from habitat_baselines.rl.hrl.hierarchical_policy import ( # noqa: F401.
HierarchicalPolicy,
)
from habitat_baselines.rl.ppo import PPO
from habitat_baselines.rl.ppo.policy import NetPolicy
from habitat_baselines.utils.common import (
batch_obs,
generate_video,
get_num_actions,
inference_mode,
is_continuous_action_space,
)
@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
r"""Trainer class for PPO algorithm
Paper: https://arxiv.org/abs/1707.06347.
"""
supported_tasks = ["Nav-v0"]
SHORT_ROLLOUT_THRESHOLD: float = 0.25
_is_distributed: bool
envs: VectorEnv
agent: PPO
actor_critic: NetPolicy
def __init__(self, config=None):
    r"""Create the trainer shell; envs, policy and agent are built later.

    Args:
        config: experiment config, forwarded unchanged to the base trainer.
    """
    super().__init__(config)
    # Populated later by _init_train / _eval_checkpoint.
    self.actor_critic = None
    self.agent = None
    self.envs = None
    self.obs_transforms = []
    # Frozen-visual-encoder state: the flag is set by
    # _setup_actor_critic_agent, the encoder handle by _init_train.
    self._static_encoder = False
    self._encoder = None
    # Lazily filled from the environments by the ``obs_space`` property.
    self._obs_space = None
    # Distributed if the world size (third entry of get_distrib_size())
    # is greater than 1.
    world_size = get_distrib_size()[2]
    self._is_distributed = world_size > 1
@property
def obs_space(self):
    r"""Observation space used to build the policy.

    On first access (when nothing has been cached and environments exist) it
    is taken from the first environment and cached; afterwards the cached
    value is returned as-is.

    NOTE(review): the cache is never cleared when ``self.envs`` is replaced,
    so if envs are re-created after the first read, the previously cached
    (possibly already obs-transformed) space is returned — confirm intended.
    """
    if self._obs_space is not None or self.envs is None:
        return self._obs_space
    self._obs_space = self.envs.observation_spaces[0]
    return self._obs_space
@obs_space.setter
def obs_space(self, new_obs_space):
    r"""Replace the cached observation space."""
    self._obs_space = new_obs_space
def _all_reduce(self, t: torch.Tensor) -> torch.Tensor:
    r"""All-reduce ``t`` across workers when running distributed.

    The tensor is moved to ``self.device`` for the collective (which uses
    torch's default reduction op, a sum), then moved back to the device it
    came from. In single-process mode ``t`` is returned untouched.
    """
    if self._is_distributed:
        home_device = t.device
        reduced = t.to(device=self.device)
        torch.distributed.all_reduce(reduced)
        return reduced.to(device=home_device)
    return t
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
    r"""Build the actor-critic policy and the PPO / DD-PPO agent around it.

    In order:
      1. register ``self.config.LOG_FILE`` with the logger;
      2. look up the policy class named by ``RL.POLICY.name``;
      3. run the observation space through the active observation
         transforms, build the policy from the transformed space, and store
         that transformed space back into ``self.obs_space``;
      4. optionally load pretrained weights: the whole policy when
         ``RL.DDPPO.pretrained`` is set, otherwise only the visual encoder
         when ``RL.DDPPO.pretrained_encoder`` is set;
      5. freeze the visual encoder when ``RL.DDPPO.train_encoder`` is off;
      6. re-initialise the critic's ``fc`` layer when ``RL.DDPPO.reset_critic``;
      7. wrap the policy in ``DDPPO`` (distributed) or ``PPO``.

    Args:
        ppo_cfg: config node handed to the agent's ``from_config``.

    NOTE(review): ``logger.add_filehandler`` runs on every call; if this is
    invoked repeatedly it may attach duplicate handlers — confirm.
    """
    logger.add_filehandler(self.config.LOG_FILE)
    policy_cls = baseline_registry.get_policy(self.config.RL.POLICY.name)

    raw_space = self.obs_space
    self.obs_transforms = get_active_obs_transforms(self.config)
    policy_space = apply_obs_transforms_obs_space(
        raw_space, self.obs_transforms
    )
    self.actor_critic = policy_cls.from_config(
        self.config,
        policy_space,
        self.policy_action_space,
        orig_action_space=self.orig_policy_action_space,
    )
    self.obs_space = policy_space
    self.actor_critic.to(self.device)

    ddppo_cfg = self.config.RL.DDPPO
    if ddppo_cfg.pretrained_encoder or ddppo_cfg.pretrained:
        saved_weights = torch.load(
            ddppo_cfg.pretrained_weights, map_location="cpu"
        )["state_dict"]
        if ddppo_cfg.pretrained:
            # Drop the leading "actor_critic." from every saved key
            # (presumably because the file holds an agent state dict).
            head_len = len("actor_critic.")
            self.actor_critic.load_state_dict(
                {  # type: ignore
                    name[head_len:]: tensor
                    for name, tensor in saved_weights.items()
                }
            )
        else:
            # Only the visual-encoder entries, with their prefix removed.
            encoder_prefix = "actor_critic.net.visual_encoder."
            encoder_weights = {
                name[len(encoder_prefix) :]: tensor
                for name, tensor in saved_weights.items()
                if name.startswith(encoder_prefix)
            }
            self.actor_critic.net.visual_encoder.load_state_dict(
                encoder_weights
            )

    if not ddppo_cfg.train_encoder:
        self._static_encoder = True
        encoder = self.actor_critic.net.visual_encoder
        for encoder_param in encoder.parameters():
            encoder_param.requires_grad_(False)

    if ddppo_cfg.reset_critic:
        critic_head = self.actor_critic.critic.fc
        nn.init.orthogonal_(critic_head.weight)
        nn.init.constant_(critic_head.bias, 0)

    agent_cls = DDPPO if self._is_distributed else PPO
    self.agent = agent_cls.from_config(self.actor_critic, ppo_cfg)
def _init_envs(self, config=None, is_eval: bool = False):
    r"""Create the vectorized environments and store them on ``self.envs``.

    Args:
        config: config to build the envs from; ``self.config`` when ``None``.
        is_eval: forwarded as ``enforce_scenes_greater_eq_environments``.
    """
    env_config = self.config if config is None else config
    # Worker processes ignore signals when running as a SLURM batch job.
    ignore_signals = is_slurm_batch_job()
    self.envs = construct_envs(
        env_config,
        workers_ignore_signals=ignore_signals,
        enforce_scenes_greater_eq_environments=is_eval,
    )
def _init_train(self, resume_state=None):
    r"""Prepare everything the training loop needs.

    Resolves the resume state and optionally initialises distributed
    training. It then builds the environments, the policy/agent and the
    rollout storage, and inserts the first observation. Finally it resets
    the episode-statistics accumulators and timers.

    Args:
        resume_state: a previously saved resume-state dict. When ``None``, it
            is looked up with ``load_resume_state(self.config)``. When one is
            available, its ``"config"`` replaces ``self.config`` and its agent
            and optimizer states are loaded.
    """
    if resume_state is None:
        resume_state = load_resume_state(self.config)
    if resume_state is not None:
        # Continue with the config the interrupted run used.
        self.config: Config = resume_state["config"]
    if self.config.RL.DDPPO.force_distributed:
        # Take the distributed code path even if the world size is 1.
        self._is_distributed = True
    # Defined outside this view (presumably in BaseRLTrainer).
    self._add_preemption_signal_handlers()
    if self._is_distributed:
        local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        if rank0_only():
            logger.info(
                "Initialized DD-PPO with {} workers".format(
                    torch.distributed.get_world_size()
                )
            )
        # Pin this worker's torch and simulator work to its local GPU.
        self.config.defrost()
        self.config.TORCH_GPU_ID = local_rank
        self.config.SIMULATOR_GPU_ID = local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            torch.distributed.get_rank() * self.config.NUM_ENVIRONMENTS
        )
        self.config.freeze()
        # NOTE(review): RNGs are seeded here only on the distributed path;
        # the single-process path presumably relies on the caller — confirm.
        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)
        # Shared counter of workers that finished collecting their rollout.
        # should_end_early reads it; rank 0 resets it in _coalesce_post_step.
        self.num_rollouts_done_store = torch.distributed.PrefixStore(
            "rollout_tracker", tcp_store
        )
        self.num_rollouts_done_store.set("num_done", "0")
    if rank0_only() and self.config.VERBOSE:
        logger.info(f"config: {self.config}")
    profiling_wrapper.configure(
        capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
        num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
    )
    self._init_envs()
    action_space = self.envs.action_spaces[0]
    self.policy_action_space = action_space
    self.orig_policy_action_space = self.envs.orig_action_spaces[0]
    if is_continuous_action_space(action_space):
        # Assume ALL actions are NOT discrete
        action_shape = (get_num_actions(action_space),)
        discrete_actions = False
    else:
        # For discrete pointnav
        action_shape = (1,)
        discrete_actions = True
    ppo_cfg = self.config.RL.PPO
    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")
    if rank0_only() and not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    if resume_state is not None:
        # Restore before init_distributed so every worker starts from the
        # saved weights. NOTE(review): train() loads these states again.
        self.agent.load_state_dict(resume_state["state_dict"])
        self.agent.optimizer.load_state_dict(resume_state["optim_state"])
    if self._is_distributed:
        self.agent.init_distributed(find_unused_params=False)  # type: ignore
    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )
    obs_space = self.obs_space
    if self._static_encoder:
        # With a frozen encoder, its output is stored in the rollouts as an
        # extra "visual_features" observation, computed once per step.
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = spaces.Dict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            }
        )
    # Two buffers when the double-buffered sampler is enabled: the envs are
    # split into two groups that are stepped and collected alternately.
    self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1
    self.rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.policy_action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        is_double_buffered=ppo_cfg.use_double_buffered_sampler,
        action_shape=action_shape,
        discrete_actions=discrete_actions,
    )
    self.rollouts.to(self.device)
    # Seed step 0 of the rollout storage with the reset observations.
    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)  # type: ignore
    if self._static_encoder:
        with inference_mode():
            batch["visual_features"] = self._encoder(batch)
    self.rollouts.buffers["observations"][0] = batch  # type: ignore
    # Per-env accumulators (kept on CPU): the in-progress episode reward and
    # running sums over finished episodes ("count" = episodes finished).
    self.current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    self.running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    # Sliding window of running-stat snapshots, one appended per update.
    self.window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
    self.env_time = 0.0
    self.pth_time = 0.0
    self.t_start = time.time()
@rank0_only
@profiling_wrapper.RangeContext("save_checkpoint")
def save_checkpoint(
    self, file_name: str, extra_state: Optional[Dict] = None
) -> None:
    r"""Save the agent's state dict and config to the checkpoint folder.

    The checkpoint is written twice: once under ``file_name`` and once as
    ``latest.pth``, both inside ``self.config.CHECKPOINT_FOLDER``.

    Args:
        file_name: file name for the checkpoint.
        extra_state: optional dict stored under the ``"extra_state"`` key
            (omitted from the checkpoint when ``None``).

    Returns:
        None
    """
    checkpoint = dict(state_dict=self.agent.state_dict(), config=self.config)
    if extra_state is not None:
        checkpoint["extra_state"] = extra_state
    folder = self.config.CHECKPOINT_FOLDER
    for target_name in (file_name, "latest.pth"):
        torch.save(checkpoint, os.path.join(folder, target_name))
def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
    r"""Read a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path: path of the checkpoint to read.
        *args: extra positional arguments forwarded to ``torch.load``.
        **kwargs: extra keyword arguments forwarded to ``torch.load``
            (e.g. ``map_location``).

    Returns:
        The deserialized checkpoint dict.
    """
    loaded = torch.load(checkpoint_path, *args, **kwargs)
    return loaded
METRICS_BLACKLIST = {"top_down_map", "collisions.is_collision"}
@classmethod
def _extract_scalars_from_info(
    cls, info: Dict[str, Any]
) -> Dict[str, float]:
    r"""Flatten one environment ``info`` dict into ``{name: float}``.

    Nested dicts are flattened recursively with ``"."``-joined keys. The
    following are skipped:
      - non-string keys;
      - keys (at any level, after flattening) listed in ``METRICS_BLACKLIST``;
      - strings;
      - any value whose ``np.size`` is not 1.

    Args:
        info: the info dict returned by an environment step.

    Returns:
        Mapping from (flattened) metric name to its float value.
    """
    result = {}
    for k, v in info.items():
        if not isinstance(k, str) or k in cls.METRICS_BLACKLIST:
            continue
        if isinstance(v, dict):
            result.update(
                {
                    k + "." + subk: subv
                    for subk, subv in cls._extract_scalars_from_info(
                        v
                    ).items()
                    if isinstance(subk, str)
                    and k + "." + subk not in cls.METRICS_BLACKLIST
                }
            )
        # Things that are scalar-like will have an np.size of 1.
        # Strings also have an np.size of 1, so explicitly ban those
        elif np.size(v) == 1 and not isinstance(v, str):
            # Go through a numpy array: ``.item()`` unwraps Python scalars,
            # numpy scalars and 1-element lists/tuples/arrays alike.
            # ``float(v)`` directly raises TypeError on a 1-element list and
            # is deprecated for ndim > 0 arrays since NumPy 1.25.
            result[k] = float(np.asarray(v).item())
    return result
@classmethod
def _extract_scalars_from_infos(
    cls, infos: List[Dict[str, Any]]
) -> Dict[str, List[float]]:
    r"""Flatten a list of per-environment info dicts into per-metric lists.

    Each info dict is flattened with ``_extract_scalars_from_info``. Every
    resulting value is appended, in ``infos`` order, to the list stored
    under its metric name.

    Args:
        infos: one info dict per environment.

    Returns:
        A ``defaultdict(list)`` mapping metric name to its collected values.
    """
    collected = defaultdict(list)
    for env_info in infos:
        flat = cls._extract_scalars_from_info(env_info)
        for metric_name, metric_value in flat.items():
            collected[metric_name].append(metric_value)
    return collected
def _compute_actions_and_step_envs(self, buffer_index: int = 0):
    r"""Sample actions for one group of environments and start stepping them.

    The environments are split into ``self._nbuffers`` contiguous groups;
    this handles group ``buffer_index``. The steps are:
      1. sample actions from the policy under ``inference_mode``, using that
         group's current rollout entry;
      2. dispatch the actions with ``async_step_at`` (results are gathered
         later by ``_collect_environment_result``);
      3. write the policy outputs into the rollout storage.

    Args:
        buffer_index: which environment group to act for (0 when there is a
            single buffer).
    """
    n_envs = self.envs.num_envs
    lo = int(buffer_index * n_envs / self._nbuffers)
    hi = int((buffer_index + 1) * n_envs / self._nbuffers)
    env_slice = slice(lo, hi)

    t_sample_action = time.time()
    with inference_mode():
        step_idx = self.rollouts.current_rollout_step_idxs[buffer_index]
        step_batch = self.rollouts.buffers[step_idx, env_slice]
        profiling_wrapper.range_push("compute actions")
        (
            values,
            actions,
            actions_log_probs,
            recurrent_hidden_states,
        ) = self.actor_critic.act(
            step_batch["observations"],
            step_batch["recurrent_hidden_states"],
            step_batch["prev_actions"],
            step_batch["masks"],
        )
    self.pth_time += time.time() - t_sample_action
    profiling_wrapper.range_pop()  # compute actions

    t_step_env = time.time()
    action_space = self.policy_action_space
    continuous = is_continuous_action_space(action_space)
    for env_idx, env_action in zip(range(lo, hi), actions.cpu().unbind(0)):
        if continuous:
            # Keep continuous actions inside the declared [low, high] bounds.
            env_action = np.clip(
                env_action.numpy(), action_space.low, action_space.high
            )
        else:
            env_action = env_action.item()
        self.envs.async_step_at(env_idx, env_action)
    self.env_time += time.time() - t_step_env

    self.rollouts.insert(
        next_recurrent_hidden_states=recurrent_hidden_states,
        actions=actions,
        action_log_probs=actions_log_probs,
        value_preds=values,
        buffer_index=buffer_index,
    )
def _collect_environment_result(self, buffer_index: int = 0):
    r"""Wait for one group of environments to finish stepping and record it.

    Gathers the step results for environment group ``buffer_index`` (the
    same contiguous slice ``_compute_actions_and_step_envs`` uses). It then
    updates the per-episode reward and info statistics and stores the new
    observations, rewards and masks in the rollout storage.

    Args:
        buffer_index: which environment group to collect (0 when there is a
            single buffer).

    Returns:
        Number of environment steps collected (the size of the group).
    """
    num_envs = self.envs.num_envs
    env_slice = slice(
        int(buffer_index * num_envs / self._nbuffers),
        int((buffer_index + 1) * num_envs / self._nbuffers),
    )
    t_step_env = time.time()
    # Block on the async steps started by _compute_actions_and_step_envs.
    outputs = [
        self.envs.wait_step_at(index_env)
        for index_env in range(env_slice.start, env_slice.stop)
    ]
    # Transpose per-env (obs, reward, done, info) tuples into four lists.
    observations, rewards_l, dones, infos = [
        list(x) for x in zip(*outputs)
    ]
    self.env_time += time.time() - t_step_env
    t_update_stats = time.time()
    batch = batch_obs(observations, device=self.device)
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)  # type: ignore
    rewards = torch.tensor(
        rewards_l,
        dtype=torch.float,
        device=self.current_episode_reward.device,
    )
    rewards = rewards.unsqueeze(1)
    not_done_masks = torch.tensor(
        [[not done] for done in dones],
        dtype=torch.bool,
        device=self.current_episode_reward.device,
    )
    done_masks = torch.logical_not(not_done_masks)
    self.current_episode_reward[env_slice] += rewards
    current_ep_reward = self.current_episode_reward[env_slice]
    # Only envs whose episode just ended contribute their total reward and
    # one to "count"; the others add zero.
    self.running_episode_stats["reward"][env_slice] += current_ep_reward.where(done_masks, current_ep_reward.new_zeros(()))  # type: ignore
    self.running_episode_stats["count"][env_slice] += done_masks.float()  # type: ignore
    for k, v_k in self._extract_scalars_from_infos(infos).items():
        v = torch.tensor(
            v_k,
            dtype=torch.float,
            device=self.current_episode_reward.device,
        ).unsqueeze(1)
        # New metric names get a zero accumulator shaped like "count".
        if k not in self.running_episode_stats:
            self.running_episode_stats[k] = torch.zeros_like(
                self.running_episode_stats["count"]
            )
        # Info metrics are likewise accumulated only at episode end.
        # NOTE(review): assumes every env in the slice reports metric k;
        # otherwise v's length will not match the slice.
        self.running_episode_stats[k][env_slice] += v.where(done_masks, v.new_zeros(()))  # type: ignore
    # Start a fresh reward sum for envs whose episode ended.
    self.current_episode_reward[env_slice].masked_fill_(done_masks, 0.0)
    if self._static_encoder:
        with inference_mode():
            batch["visual_features"] = self._encoder(batch)
    self.rollouts.insert(
        next_observations=batch,
        rewards=rewards,
        next_masks=not_done_masks,
        buffer_index=buffer_index,
    )
    self.rollouts.advance_rollout(buffer_index)
    self.pth_time += time.time() - t_update_stats
    return env_slice.stop - env_slice.start
@profiling_wrapper.RangeContext("_collect_rollout_step")
def _collect_rollout_step(self):
    r"""Act and collect one rollout step for buffer 0 (single-buffer path).

    Returns:
        Number of environment steps collected.
    """
    self._compute_actions_and_step_envs()
    steps_collected = self._collect_environment_result()
    return steps_collected
@profiling_wrapper.RangeContext("_update_agent")
def _update_agent(self):
    r"""Run one PPO optimisation phase on the collected rollout.

    The steps are:
      1. bootstrap the value of the rollout's current (last) step under
         ``inference_mode``;
      2. compute returns, using GAE when ``use_gae`` is set;
      3. update the agent in train mode;
      4. let the rollout storage prepare for the next collection phase.

    Returns:
        The value returned by ``self.agent.update`` (the losses).
    """
    ppo_cfg = self.config.RL.PPO
    t_update_model = time.time()
    with inference_mode():
        bootstrap_step = self.rollouts.buffers[
            self.rollouts.current_rollout_step_idx
        ]
        value_inputs = (
            bootstrap_step["observations"],
            bootstrap_step["recurrent_hidden_states"],
            bootstrap_step["prev_actions"],
            bootstrap_step["masks"],
        )
        next_value = self.actor_critic.get_value(*value_inputs)
    self.rollouts.compute_returns(
        next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
    )
    self.agent.train()
    update_losses = self.agent.update(self.rollouts)
    self.rollouts.after_update()
    self.pth_time += time.time() - t_update_model
    return update_losses
def _coalesce_post_step(
    self, losses: Dict[str, float], count_steps_delta: int
) -> Dict[str, float]:
    r"""Aggregate post-update statistics (across workers when distributed).

    Always: snapshot the (all-reduced) running episode statistics into the
    sliding window, and add this update's step count to ``num_steps_done``.

    When distributed, additionally:
      - sum the step count across workers;
      - average the losses over the world size;
      - have rank 0 reset the shared ``num_done`` rollout counter.

    Args:
        losses: per-loss values from this worker's update.
        count_steps_delta: environment steps this worker collected.

    Returns:
        The (possibly worker-averaged) losses.
    """
    episode_keys = sorted(self.running_episode_stats.keys())
    episode_totals = self._all_reduce(
        torch.stack(
            [self.running_episode_stats[key] for key in episode_keys], 0
        )
    )
    for idx, key in enumerate(episode_keys):
        self.window_episode_stats[key].append(episode_totals[idx])

    if not self._is_distributed:
        self.num_steps_done += count_steps_delta
        return losses

    loss_keys = sorted(losses.keys())
    packed = torch.tensor(
        [losses[key] for key in loss_keys] + [count_steps_delta],
        device="cpu",
        dtype=torch.float32,
    )
    packed = self._all_reduce(packed)
    # Last slot is the summed step count; read it before averaging.
    count_steps_delta = int(packed[-1].item())
    packed /= torch.distributed.get_world_size()
    losses = {key: packed[idx].item() for idx, key in enumerate(loss_keys)}
    if rank0_only():
        self.num_rollouts_done_store.set("num_done", "0")
    self.num_steps_done += count_steps_delta
    return losses
@rank0_only
def _training_log(
    self, writer, losses: Dict[str, float], prev_time: float = 0
):
    r"""Log windowed episode statistics, losses and throughput.

    ``self.window_episode_stats`` holds snapshots of the cumulative running
    episode statistics, one appended per update by ``_coalesce_post_step``.
    The newest snapshot minus the oldest is therefore the total accumulated
    over the window. Every metric is divided by the number of episodes
    finished in that window (clamped to at least 1).

    Args:
        writer: tensorboard writer receiving the scalars.
        losses: per-loss values for this update, logged under ``learner/``.
        prev_time: wall-clock seconds spent in earlier runs of this job
            (restored from the resume state). It is added to this run's
            elapsed time when computing FPS. It holds ``time.time()``
            differences, hence ``float``; the default 0 still works.
    """
    deltas = {
        k: (
            (v[-1] - v[0]).sum().item()
            if len(v) > 1
            else v[0].sum().item()
        )
        for k, v in self.window_episode_stats.items()
    }
    # Avoid dividing by zero before any episode in the window has finished.
    deltas["count"] = max(deltas["count"], 1.0)
    writer.add_scalar(
        "reward",
        deltas["reward"] / deltas["count"],
        self.num_steps_done,
    )
    # Check to see if there are any metrics
    # that haven't been logged yet
    metrics = {
        k: v / deltas["count"]
        for k, v in deltas.items()
        if k not in {"reward", "count"}
    }
    for k, v in metrics.items():
        writer.add_scalar(f"metrics/{k}", v, self.num_steps_done)
    for k, v in losses.items():
        writer.add_scalar(f"learner/{k}", v, self.num_steps_done)
    fps = self.num_steps_done / ((time.time() - self.t_start) + prev_time)
    writer.add_scalar("perf/fps", fps, self.num_steps_done)
    # log stats
    if self.num_updates_done % self.config.LOG_INTERVAL == 0:
        logger.info(
            "update: {}\tfps: {:.3f}\t".format(
                self.num_updates_done,
                fps,
            )
        )
        logger.info(
            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
            "frames: {}".format(
                self.num_updates_done,
                self.env_time,
                self.pth_time,
                self.num_steps_done,
            )
        )
        logger.info(
            "Average window size: {} {}".format(
                len(self.window_episode_stats["count"]),
                " ".join(
                    "{}: {:.3f}".format(k, v / deltas["count"])
                    for k, v in deltas.items()
                    if k != "count"
                ),
            )
        )
def should_end_early(self, rollout_step) -> bool:
    r"""Decide whether this worker should cut its rollout short (DD-PPO).

    A worker preempts itself as a straggler when both of these hold:
      - it has already collected at least ``SHORT_ROLLOUT_THRESHOLD`` of
        ``num_steps``;
      - at least ``sync_frac`` of all workers have reported their rollout
        done in the shared store.

    Always ``False`` when not distributed.

    Args:
        rollout_step: number of steps collected so far in this rollout.
    """
    if not self._is_distributed:
        return False
    min_steps = self.config.RL.PPO.num_steps * self.SHORT_ROLLOUT_THRESHOLD
    if not rollout_step >= min_steps:
        # Too early to preempt; skip the (remote) store lookup.
        return False
    workers_done = int(self.num_rollouts_done_store.get("num_done"))
    required = (
        self.config.RL.DDPPO.sync_frac * torch.distributed.get_world_size()
    )
    return workers_done >= required
@profiling_wrapper.RangeContext("train")
def train(self) -> None:
    r"""Main method for training DD/PPO.

    Each iteration of the loop does the following:
      - decay the clip parameter, if enabled;
      - on rank 0, save a resume state when due;
      - exit for requeue if ``EXIT`` is set;
      - collect a rollout (double-buffered when ``_nbuffers == 2``);
      - update the agent and step the linear LR schedule, if enabled;
      - aggregate and log statistics, then checkpoint when due.

    The loop runs until ``self.is_done()``.

    Returns:
        None
    """
    resume_state = load_resume_state(self.config)
    self._init_train(resume_state)
    count_checkpoints = 0
    # Wall-clock seconds spent in earlier runs of this job (float).
    prev_time = 0
    # Linear LR decay: scale the base LR by the remaining training fraction.
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: 1 - self.percent_done(),
    )
    if self._is_distributed:
        torch.distributed.barrier()
    if resume_state is not None:
        # NOTE(review): agent/optimizer states were already loaded in
        # _init_train; this second load looks redundant.
        self.agent.load_state_dict(resume_state["state_dict"])
        self.agent.optimizer.load_state_dict(resume_state["optim_state"])
        lr_scheduler.load_state_dict(resume_state["lr_sched_state"])
        requeue_stats = resume_state["requeue_stats"]
        self.env_time = requeue_stats["env_time"]
        self.pth_time = requeue_stats["pth_time"]
        self.num_steps_done = requeue_stats["num_steps_done"]
        self.num_updates_done = requeue_stats["num_updates_done"]
        self._last_checkpoint_percent = requeue_stats[
            "_last_checkpoint_percent"
        ]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]
        self.running_episode_stats = requeue_stats["running_episode_stats"]
        self.window_episode_stats.update(
            requeue_stats["window_episode_stats"]
        )
    ppo_cfg = self.config.RL.PPO
    # Only rank 0 opens a tensorboard writer. Other ranks enter
    # contextlib.suppress(), whose __enter__ returns None, so writer is None
    # there; _training_log is decorated with @rank0_only.
    with (
        get_writer(
            self.config,
            flush_secs=self.flush_secs,
            purge_step=int(self.num_steps_done),
        )
        if rank0_only()
        else contextlib.suppress()
    ) as writer:
        while not self.is_done():
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")
            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * (
                    1 - self.percent_done()
                )
            if rank0_only() and self._should_save_resume_state():
                requeue_stats = dict(
                    env_time=self.env_time,
                    pth_time=self.pth_time,
                    count_checkpoints=count_checkpoints,
                    num_steps_done=self.num_steps_done,
                    num_updates_done=self.num_updates_done,
                    _last_checkpoint_percent=self._last_checkpoint_percent,
                    prev_time=(time.time() - self.t_start) + prev_time,
                    running_episode_stats=self.running_episode_stats,
                    window_episode_stats=dict(self.window_episode_stats),
                )
                save_resume_state(
                    dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    ),
                    self.config,
                )
            if EXIT.is_set():
                # Preempted: close the envs, requeue the job and stop.
                profiling_wrapper.range_pop()  # train update
                self.envs.close()
                requeue_job()
                return
            self.agent.eval()
            count_steps_delta = 0
            profiling_wrapper.range_push("rollouts loop")
            profiling_wrapper.range_push("_collect_rollout_step")
            # Prime every buffer with an in-flight async env step.
            for buffer_index in range(self._nbuffers):
                self._compute_actions_and_step_envs(buffer_index)
            for step in range(ppo_cfg.num_steps):
                # Stop at the configured length, or earlier when this
                # worker is a straggler (see should_end_early).
                is_last_step = (
                    self.should_end_early(step + 1)
                    or (step + 1) == ppo_cfg.num_steps
                )
                for buffer_index in range(self._nbuffers):
                    # Collect this buffer, then immediately re-issue its
                    # next step so other buffers' envs keep stepping.
                    count_steps_delta += self._collect_environment_result(
                        buffer_index
                    )
                    if (buffer_index + 1) == self._nbuffers:
                        profiling_wrapper.range_pop()  # _collect_rollout_step
                    if not is_last_step:
                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_push(
                                "_collect_rollout_step"
                            )
                        self._compute_actions_and_step_envs(buffer_index)
                if is_last_step:
                    break
            profiling_wrapper.range_pop()  # rollouts loop
            if self._is_distributed:
                # Announce this worker's rollout as done (see should_end_early).
                self.num_rollouts_done_store.add("num_done", 1)
            losses = self._update_agent()
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore
            self.num_updates_done += 1
            losses = self._coalesce_post_step(
                losses,
                count_steps_delta,
            )
            self._training_log(writer, losses, prev_time)
            # checkpoint model
            if rank0_only() and self.should_checkpoint():
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(
                        step=self.num_steps_done,
                        wall_time=(time.time() - self.t_start) + prev_time,
                    ),
                )
                count_checkpoints += 1
            profiling_wrapper.range_pop()  # train update
        self.envs.close()
def _eval_checkpoint(
self,
checkpoint_path: str,
writer: TensorboardWriter,
checkpoint_index: int = 0,
) -> None:
r"""Evaluates a single checkpoint.
Args:
checkpoint_path: path of checkpoint
writer: tensorboard writer object for logging to tensorboard
checkpoint_index: index of cur checkpoint for logging
Returns:
None
"""
if self._is_distributed:
raise RuntimeError("Evaluation does not support distributed mode")
# Map location CPU is almost always better than mapping to a CUDA device.
if self.config.EVAL.SHOULD_LOAD_CKPT:
ckpt_dict = self.load_checkpoint(
checkpoint_path, map_location="cpu"
)
step_id = ckpt_dict["extra_state"]["step"]
print(step_id)
else:
ckpt_dict = {}
if self.config.EVAL.USE_CKPT_CONFIG:
config = self._setup_eval_config(ckpt_dict["config"])
else:
config = self.config.clone()
ppo_cfg = config.RL.PPO
config.defrost()
config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
config.freeze()
if (
len(self.config.VIDEO_OPTION) > 0
and self.config.VIDEO_RENDER_TOP_DOWN
):
config.defrost()
config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
config.freeze()
if (
len(config.VIDEO_RENDER_VIEWS) > 0
and len(self.config.VIDEO_OPTION) > 0
):
config.defrost()
for render_view in config.VIDEO_RENDER_VIEWS:
uuid = config.TASK_CONFIG.SIMULATOR[render_view].UUID
config.TASK_CONFIG.GYM.OBS_KEYS.append(uuid)
config.SENSORS.append(render_view)
config.freeze()
if config.VERBOSE:
logger.info(f"env config: {config}")
self._init_envs(config, is_eval=True)
action_space = self.envs.action_spaces[0]
self.policy_action_space = action_space
self.orig_policy_action_space = self.envs.orig_action_spaces[0]
if is_continuous_action_space(action_space):
# Assume NONE of the actions are discrete
action_shape = (get_num_actions(action_space),)
discrete_actions = False
else:
# For discrete pointnav
action_shape = (1,)
discrete_actions = True
self._setup_actor_critic_agent(ppo_cfg)
if self.agent.actor_critic.should_load_agent_state:
self.agent.load_state_dict(ckpt_dict["state_dict"])
self.actor_critic = self.agent.actor_critic
observations = self.envs.reset()
batch = batch_obs(observations, device=self.device)
batch = apply_obs_transforms_batch(batch, self.obs_transforms) # type: ignore
current_episode_reward = torch.zeros(
self.envs.num_envs, 1, device="cpu"
)
test_recurrent_hidden_states = torch.zeros(
self.config.NUM_ENVIRONMENTS,
self.actor_critic.num_recurrent_layers,
ppo_cfg.hidden_size,
device=self.device,
)
prev_actions = torch.zeros(
self.config.NUM_ENVIRONMENTS,
*action_shape,
device=self.device,
dtype=torch.long if discrete_actions else torch.float,
)
not_done_masks = torch.zeros(
self.config.NUM_ENVIRONMENTS,
1,
device=self.device,
dtype=torch.bool,
)
stats_episodes: Dict[
Any, Any
] = {} # dict of dicts that stores stats per episode
ep_eval_count: Dict[Any, int] = defaultdict(lambda: 0)
rgb_frames = [
[] for _ in range(self.config.NUM_ENVIRONMENTS)
] # type: List[List[np.ndarray]]
if len(self.config.VIDEO_OPTION) > 0:
os.makedirs(self.config.VIDEO_DIR, exist_ok=True)
number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
evals_per_ep = self.config.EVAL.EVALS_PER_EP
if number_of_eval_episodes == -1:
number_of_eval_episodes = sum(self.envs.number_of_episodes)
else:
total_num_eps = sum(self.envs.number_of_episodes)
# if total_num_eps is negative, it means the number of evaluation episodes is unknown
if total_num_eps < number_of_eval_episodes and total_num_eps > 1:
logger.warn(
f"Config specified {number_of_eval_episodes} eval episodes"
", dataset only has {total_num_eps}."
)
logger.warn(f"Evaluating with {total_num_eps} instead.")
number_of_eval_episodes = total_num_eps
else:
assert evals_per_ep == 1
assert (
number_of_eval_episodes > 0
), "You must specify a number of evaluation episodes with TEST_EPISODE_COUNT"
pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
self.actor_critic.eval()
while (
len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
and self.envs.num_envs > 0
):