Temp add log in collector
AltmanD committed Sep 18, 2023
1 parent 4eebb53 commit 0624030
Showing 1 changed file with 93 additions and 8 deletions.
101 changes: 93 additions & 8 deletions ding/framework/middleware/functional/collector.py
@@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from typing import TYPE_CHECKING, Callable, List, Tuple, Any, Optional
from functools import reduce
import treetensor.torch as ttorch
import numpy as np
from ding.utils import EasyTimer, allreduce_data, build_logger
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0
@@ -83,7 +85,7 @@ def _inference(ctx: "OnlineRLContext"):
return _inference


def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList) -> Callable:
def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList, collect_print_freq: int = 100, tb_logger: Optional['SummaryWriter'] = None, exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector') -> Callable:
"""
Overview:
The middleware that executes the transition process in the env.
@@ -98,6 +100,63 @@ def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList)

env_episode_id = [_ for _ in range(env.env_num)]
current_id = env.env_num
timer = EasyTimer()
last_train_iter = 0
total_envstep_count = 0
total_episode_count = 0
total_duration = 0
total_train_sample_count = 0
env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
episode_info = []

# If an external SummaryWriter is supplied it is reused and only a plain text
# logger is built here (need_tb=False); otherwise build_logger also creates a
# TensorBoard writer under ./<exp_name>/log/<instance_name>.
if tb_logger is not None:
logger, _ = build_logger(
path='./{}/log/{}'.format(exp_name, instance_name),
name=instance_name,
need_tb=False
)
tb_logger = tb_logger
else:
logger, tb_logger = build_logger(
path='./{}/log/{}'.format(exp_name, instance_name), name=instance_name
)

def output_log(train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to the docs of `Best Practice` to understand \
the training generated logs and tensorboards.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
"""
nonlocal episode_info, timer, total_episode_count, total_duration, total_envstep_count, total_train_sample_count, last_train_iter
if (train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
last_train_iter = train_iter
episode_count = len(episode_info)
envstep_count = sum([d['step'] for d in episode_info])
train_sample_count = sum([d['train_sample'] for d in episode_info])
duration = sum([d['time'] for d in episode_info])
episode_return = [d['reward'] for d in episode_info]
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'train_sample_count': train_sample_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_sample_per_episode': train_sample_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_train_sample_per_sec': train_sample_count / duration,
'avg_episode_per_sec': episode_count / duration,
'reward_mean': np.mean(episode_return),
'reward_std': np.std(episode_return),
'reward_max': np.max(episode_return),
'reward_min': np.min(episode_return),
'total_envstep_count': total_envstep_count,
'total_train_sample_count': total_train_sample_count,
'total_episode_count': total_episode_count,
# 'each_reward': episode_return,
}
episode_info.clear()
logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
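# Illustrative aside with assumed toy numbers (not part of the patch): given
# episode_info = [{'time': 2.0, 'step': 100, 'train_sample': 25, 'reward': 1.0},
#                 {'time': 4.0, 'step': 300, 'train_sample': 75, 'reward': 3.0}]
# the aggregation above yields episode_count = 2, envstep_count = 400,
# avg_envstep_per_sec = 400 / 6.0 ≈ 66.7, reward_mean = 2.0 and reward_std = 1.0.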

def _rollout(ctx: "OnlineRLContext"):
"""
@@ -113,22 +172,48 @@ def _rollout(ctx: "OnlineRLContext"):
trajectory stops.
"""

nonlocal current_id
nonlocal current_id, env_info, episode_info, timer, total_episode_count, total_duration, total_envstep_count, total_train_sample_count, last_train_iter
with timer:
    timesteps = env.step(ctx.action)
ctx.env_step += len(timesteps)
timesteps = [t.tensor() for t in timesteps]
# TODO abnormal env step

# spread the env.step time evenly over the returned timesteps
interaction_duration = timer.value / len(timesteps)
# remember how many episodes were already recorded so that only episodes
# finished during this call are added to the running totals below
num_finished_before = len(episode_info)
for i, timestep in enumerate(timesteps):
transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
transition = ttorch.as_tensor(transition) # TBD
transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
transitions.append(timestep.env_id, transition)
with timer:
transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
transition = ttorch.as_tensor(transition) # TBD
transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
transitions.append(timestep.env_id, transition)
env_info[timestep.env_id]['step'] += 1

env_info[timestep.env_id]['time'] += timer.value + interaction_duration
if timestep.done:
info = {
'reward': timestep.info['eval_episode_return'],
'time': env_info[timestep.env_id]['time'],
'step': env_info[timestep.env_id]['step'],
'train_sample': env_info[timestep.env_id]['train_sample'],
}

episode_info.append(info)
# reset the per-env statistics so that the next episode starts from zero
env_info[timestep.env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
policy.reset([timestep.env_id])
env_episode_id[timestep.env_id] = current_id
current_id += 1
ctx.env_episode += 1

# aggregate only the episodes finished during this call; with DDP enabled,
# allreduce_data sums the scalars across workers before updating the totals
new_episodes = episode_info[num_finished_before:]
collected_episode = len(new_episodes)
collected_step = sum([d['step'] for d in new_episodes])
collected_sample = sum([d['train_sample'] for d in new_episodes])
collected_duration = sum([d['time'] for d in new_episodes])
collected_sample = allreduce_data(collected_sample, 'sum')
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')
total_envstep_count += collected_step
total_episode_count += collected_episode
total_duration += collected_duration
total_train_sample_count += collected_sample

output_log(ctx.train_iter)
# TODO log
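# Note: the tb_logger built above is kept but not written to yet; this temp
# commit only adds text logging of the aggregated collection statistics.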

return _rollout
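
For context, a minimal sketch (not part of the commit) of how the patched rolloutor might be wired up; the policy, environment manager and transition buffer below are placeholders assumed to be configured elsewhere, and only the keyword arguments added by this commit are spelled out.

from torch.utils.tensorboard import SummaryWriter
from ding.framework.middleware.functional.collector import rolloutor

# `policy`, `env` and `transitions` are assumed to exist already
# (a Policy, a BaseEnvManager and a TransitionList, respectively).
tb_logger = SummaryWriter('./my_experiment/log/collector')  # may also be left as None

rollout = rolloutor(
    policy,
    env,
    transitions,
    collect_print_freq=100,  # print aggregated stats at most once every 100 train iterations
    tb_logger=tb_logger,  # reuse an external writer instead of building a new one
    exp_name='my_experiment',
    instance_name='collector',
)
# The returned closure is then invoked as rollout(ctx) with an OnlineRLContext,
# exactly like the un-patched rolloutor.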
