Commit

feature(nyz): add ppof cuda
PaParaZz1 committed Jan 6, 2023
1 parent 3a9f213 commit dfae2cc
Showing 4 changed files with 9 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ding/framework/middleware/collector.py

@@ -105,6 +105,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
         Input of ctx:
             - env_step (:obj:`int`): The env steps which will increase during collection.
         """
+        device = self.policy._device
         old = ctx.env_step
         target_size = self.n_sample * self.unroll_len
 
@@ -113,7 +114,9 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
 
         while True:
            obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
+            obs = obs.to(device)
             inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
+            inference_output = inference_output.cpu()
             action = inference_output.action.numpy()
             timesteps = self.env.step(action)
             ctx.env_step += len(timesteps)

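The change above (and the same round trip repeated in the evaluator below) moves observations onto the policy's device before inference and brings the output back to CPU afterwards, because torch.Tensor.numpy() only works on CPU tensors. A minimal sketch of that round trip with plain torch tensors and a toy actor standing in for the real PPOF policy (all names here are illustrative, not DI-engine API):

import torch
import torch.nn as nn

# Toy stand-in for the policy: a linear actor plus the device string that the
# ppof.py change below records as policy._device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
actor = nn.Linear(8, 3).to(device)

obs = torch.randn(4, 8, dtype=torch.float32)   # env observations arrive on CPU
obs = obs.to(device)                           # ship them to the policy's device
with torch.no_grad():
    logits = actor(obs)
action = logits.argmax(dim=-1).cpu()           # back to CPU first ...
env_action = action.numpy()                    # ... since .numpy() raises on CUDA tensors
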
3 changes: 3 additions & 0 deletions ding/framework/middleware/functional/evaluator.py

@@ -343,11 +343,14 @@ def _evaluate(ctx: "OnlineRLContext"):
         else:
             env.reset()
         policy.reset()
+        device = policy._device
         eval_monitor = VectorEvalMonitor(env.env_num, n_evaluator_episode)
 
         while not eval_monitor.is_finished():
             obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
+            obs = obs.to(device)
             inference_output = policy.eval(obs)
+            inference_output = inference_output.cpu()
             if render:
                 eval_monitor.update_video(env.ready_imgs)
             eval_monitor.update_output(inference_output)

3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/trainer.py

@@ -71,7 +71,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
 
         if ctx.train_data is None:  # no enough data from data fetcher
             return
-        train_output = policy.forward(ctx.train_data)
+        data = ctx.train_data.to(policy._device)
+        train_output = policy.forward(data)
         nonlocal last_log_iter
         if ctx.train_iter - last_log_iter >= log_freq:
             loss = np.mean([o['total_loss'] for o in train_output])

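Here the whole training batch is moved in one call: ctx.train_data is a tree tensor, and .to(device) on a tree tensor is applied leaf by leaf, so every nested field follows the model onto the GPU. A small, hypothetical illustration of that behaviour (the field names are made up and need not match the real contents of ctx.train_data):

import torch
import treetensor.torch as ttorch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Made-up mini-batch with the same tree-tensor structure as a training batch;
# ttorch is assumed to accept a nested dict of tensors here.
batch = ttorch.as_tensor({
    'obs': torch.randn(4, 8),
    'action': torch.randint(0, 3, (4, )),
    'adv': torch.randn(4),
})
batch = batch.to(device)    # one call moves every leaf tensor
print(batch.obs.device)     # cuda:0 when a GPU is available, otherwise cpu
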
1 change: 1 addition & 0 deletions ding/policy/ppof.py

@@ -64,6 +64,7 @@ def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str]
         self._model = model
         if self._cfg.cuda and torch.cuda.is_available():
             self._device = 'cuda'
+            self._model.cuda()
         else:
             self._device = 'cpu'
         assert self._cfg.action_space in ["continuous", "discrete", "hybrid", 'multi_discrete']

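The _device string set here is what the middleware changes above read back as policy._device, while self._model.cuda() moves the network's parameters and buffers in the same branch, so model and data always land on the same device; the torch.cuda.is_available() guard also lets a cuda-enabled config fall back cleanly on CPU-only machines. A reduced sketch of that idiom with a toy policy class rather than the real PPOFPolicy:

import torch
import torch.nn as nn

class ToyPolicy:
    """Illustrative only: mirrors the device-selection idiom, not the PPOF API."""

    def __init__(self, model: nn.Module, cuda: bool = True) -> None:
        self._model = model
        if cuda and torch.cuda.is_available():
            self._device = 'cuda'
            self._model.cuda()      # parameters and buffers follow the flag
        else:
            self._device = 'cpu'

policy = ToyPolicy(nn.Linear(8, 3))
x = torch.randn(2, 8).to(policy._device)    # data follows the model's device
out = policy._model(x).cpu()                # back to CPU for numpy / logging
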
