fix(lisong): fix icm/rnd+onppo config bugs and app_key env bugs #564

Merged
merged 25 commits on Mar 5, 2023

Commits (25)
0936556 polish(pu): polish icm_onppo_config (puyuan1996, Nov 21, 2022)
c03fadd polish(pu): polish icm rnd intrinsic_reward_weight and config (puyuan1996, Nov 29, 2022)
c66a6df Merge branch 'main' of https://github.com/opendilab/DI-engine into de… (puyuan1996, Nov 29, 2022)
39365e1 style(pu): yapf format (puyuan1996, Nov 29, 2022)
6a92682 Merge branch 'main' into dev-icm-onppo (puyuan1996, Dec 4, 2022)
6b11aff fix(lisong): fix config bugs and app_key env bugs (song2181, Dec 28, 2022)
3907dd9 Merge https://github.com/opendilab/DI-engine into dev-icm-onppo (song2181, Dec 28, 2022)
e9db602 Merge branch 'main' into dev-icm-onppo (PaParaZz1, Jan 2, 2023)
0bf6f14 Merge branch 'main' into dev-icm-onppo (PaParaZz1, Jan 2, 2023)
8d58e4c Merge branch 'main' into dev-icm-onppo (puyuan1996, Jan 9, 2023)
142fc44 polish(lisong): polish icm/rnd config and reward model (song2181, Jan 10, 2023)
e2ae39c Merge branch 'dev-icm-onppo' of github.com:song2181/DI-engine into de… (song2181, Jan 10, 2023)
3731e02 fix(lisong): add viewsizerapper in minigrid_wrapper (song2181, Jan 11, 2023)
c3a4710 Merge branch 'main' into dev-icm-onppo (puyuan1996, Jan 30, 2023)
ab6a2a2 Merge branch 'main' into dev-icm-onppo (puyuan1996, Feb 7, 2023)
d17f6b3 fix(lisong): add doorkey8x8 rnd+onppo config, save reward model, fix r… (song2181, Feb 9, 2023)
84a19c4 Merge branch 'main' of https://github.com/opendilab/DI-engine into de… (song2181, Feb 9, 2023)
9906d1b Merge branch 'dev-icm-onppo' of github.com:song2181/DI-engine into de… (song2181, Feb 9, 2023)
fb37768 Merge branch 'main' into dev-icm-onppo (puyuan1996, Feb 13, 2023)
d8191c2 fix(pu): fix augmented_reward tb_logging (puyuan1996, Feb 13, 2023)
be4fdc5 Merge branch 'dev-icm-onppo' of https://github.com/song2181/DI-engine… (puyuan1996, Feb 13, 2023)
153b8db feat(lisong): add noisy-tv env in minigrid (song2181, Feb 20, 2023)
0fc5f93 Merge branch 'dev-icm-onppo' of https://github.com/puyuan1996/DI-engi… (song2181, Feb 20, 2023)
1957316 Merge branch 'dev-icm-onppo' of github.com:song2181/DI-engine into de… (song2181, Feb 20, 2023)
f6985c5 fix(lisong): modify noisy_tv env (song2181, Feb 23, 2023)

27 changes: 26 additions & 1 deletion ding/entry/serial_entry_reward_model_offpolicy.py
@@ -14,6 +14,21 @@
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from .utils import random_collect
import numpy as np
from ding.utils import save_file


def save_reward_model(path, reward_model, weights_name='best'):
path = os.path.join(path, 'reward_model', 'ckpt')
if not os.path.exists(path):
try:
os.makedirs(path)
except FileExistsError:
pass
path = os.path.join(path, 'ckpt_{}.pth.tar'.format(weights_name))
state_dict = reward_model.reward_model.state_dict()
save_file(path, state_dict)
print('Saved reward model ckpt in {}'.format(path))


def serial_pipeline_reward_model_offpolicy(
@@ -87,11 +102,17 @@ def serial_pipeline_reward_model_offpolicy(
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
count = 0
best_reward = -np.inf
while True:
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
if reward_mean >= best_reward:
save_reward_model(cfg.exp_name, reward_model, 'best')
best_reward = reward_mean
if stop:
break
new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
@@ -103,7 +124,9 @@ def serial_pipeline_reward_model_offpolicy(
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# update reward_model
reward_model.train()
reward_model.clear_data()
# clear buffer per fix iters to make sure replay buffer's data count isn't too few.
if count % cfg.reward_model.clear_buffer_per_iters == 0:
reward_model.clear_data()
PaParaZz1 marked this conversation as resolved.
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
@@ -122,7 +145,9 @@
replay_buffer.update(learner.priority_info)
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
count += 1

# Learner's after_run hook.
learner.call_hook('after_run')
save_reward_model(cfg.exp_name, reward_model, 'last')
return policy
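Note on the off-policy entry above: `clear_buffer_per_iters` is read directly from `cfg.reward_model` (no `.get` fallback), so configs run through this entry must define it. A minimal sketch of such a reward_model config fragment follows; the field values and the surrounding fields are illustrative, not taken from this PR.

```python
from easydict import EasyDict

# Illustrative reward_model fragment for serial_pipeline_reward_model_offpolicy.
# Only `clear_buffer_per_iters` is newly required by this PR; the other fields
# are typical RND-style settings shown for context (assumed, not from the diff).
reward_model_cfg = EasyDict(
    dict(
        type='rnd',
        intrinsic_reward_type='add',
        learning_rate=1e-3,
        batch_size=32,
        update_per_collect=10,
        # clear the reward model's internal data every N collect iterations,
        # instead of after every train() call as before
        clear_buffer_per_iters=10,
    )
)
```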
21 changes: 21 additions & 0 deletions ding/entry/serial_entry_reward_model_onpolicy.py
@@ -14,6 +14,21 @@
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from .utils import random_collect
from ding.utils import save_file
import numpy as np


def save_reward_model(path, reward_model, weights_name='best'):
path = os.path.join(path, 'reward_model', 'ckpt')
if not os.path.exists(path):
try:
os.makedirs(path)
except FileExistsError:
pass
path = os.path.join(path, 'ckpt_{}.pth.tar'.format(weights_name))
state_dict = reward_model.reward_model.state_dict()
save_file(path, state_dict)
print('Saved reward model ckpt in {}'.format(path))


def serial_pipeline_reward_model_onpolicy(
@@ -88,11 +103,16 @@ def serial_pipeline_reward_model_onpolicy(
if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
count = 0
best_reward = -np.inf
while True:
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
if reward_mean >= best_reward:
save_reward_model(cfg.exp_name, reward_model, 'best')
best_reward = reward_mean
if stop:
break
new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
@@ -127,4 +147,5 @@

# Learner's after_run hook.
learner.call_hook('after_run')
save_reward_model(cfg.exp_name, reward_model, 'last')
return policy
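Both entries now write the reward model weights to `<exp_name>/reward_model/ckpt/ckpt_best.pth.tar` and `ckpt_last.pth.tar`. A sketch of how such a checkpoint could be restored later, assuming `save_file` stores a plain `torch.save`-compatible state dict and that the trainable module lives at `reward_model.reward_model`, as in `save_reward_model` above:

```python
import os
import torch

def load_reward_model(exp_name, reward_model, weights_name='best'):
    # Mirror the layout written by save_reward_model in this PR:
    # <exp_name>/reward_model/ckpt/ckpt_<weights_name>.pth.tar
    path = os.path.join(exp_name, 'reward_model', 'ckpt', 'ckpt_{}.pth.tar'.format(weights_name))
    state_dict = torch.load(path, map_location='cpu')
    reward_model.reward_model.load_state_dict(state_dict)
    return reward_model
```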
46 changes: 37 additions & 9 deletions ding/reward_model/icm_reward_model.py
@@ -151,6 +151,9 @@ class ICMRewardModel(BaseRewardModel):
update_per_collect=100,
# (float) the importance weight of the forward and reverse loss
reverse_scale=1,
intrinsic_reward_weight=0.003, # 1/300
Reviewer (Member): add comments for each field in the default config.

extrinsic_reward_norm=True,
extrinsic_reward_norm_max=1,
)

def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
@@ -171,8 +174,12 @@ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
self.ce = nn.CrossEntropyLoss(reduction="mean")
self.forward_mse = nn.MSELoss(reduction='none')
self.reverse_scale = config.reverse_scale
self.res = nn.Softmax(dim=-1)
Reviewer (Member): why use softmax here if we only need to sample the action by an argmax operation?

self.estimate_cnt_icm = 0
self.train_cnt_icm = 0

def _train(self) -> None:
self.train_cnt_icm += 1
train_data_list = [i for i in range(0, len(self.train_states))]
train_data_index = random.sample(train_data_list, self.cfg.batch_size)
data_states: list = [self.train_states[i] for i in train_data_index]
@@ -187,6 +194,13 @@ def _train(self) -> None:
)
inverse_loss = self.ce(pred_action_logit, data_actions.long())
forward_loss = self.forward_mse(pred_next_state_feature, real_next_state_feature.detach()).mean()
self.tb_logger.add_scalar('icm_reward/forward_loss', forward_loss, self.train_cnt_icm)
self.tb_logger.add_scalar('icm_reward/inverse_loss', inverse_loss, self.train_cnt_icm)
action = torch.argmax(self.res(pred_action_logit), -1)
accuracy = torch.sum(action == data_actions.squeeze(-1)).item() / data_actions.shape[0]
self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm)
loss = self.reverse_scale * inverse_loss + forward_loss
self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm)
loss = self.reverse_scale * inverse_loss + forward_loss
self.opt.zero_grad()
loss.backward()
@@ -195,7 +209,6 @@ def _train(self) -> None:
def train(self) -> None:
PaParaZz1 marked this conversation as resolved.
for _ in range(self.cfg.update_per_collect):
self._train()
self.clear_data()

def estimate(self, data: list) -> List[Dict]:
# NOTE: deepcopy reward part of data is very important,
@@ -207,17 +220,32 @@ def estimate(self, data: list) -> List[Dict]:
actions = torch.cat(actions).to(self.device)
with torch.no_grad():
real_next_state_feature, pred_next_state_feature, _ = self.reward_model(states, next_states, actions)
reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1)
reward = (reward - reward.min()) / (reward.max() - reward.min() + 1e-8)
reward = reward.to(train_data_augmented[0]['reward'].device)
reward = torch.chunk(reward, reward.shape[0], dim=0)
for item, rew in zip(train_data_augmented, reward):
raw_icm_reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1)
self.estimate_cnt_icm += 1
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_max', raw_icm_reward.max(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_mean', raw_icm_reward.mean(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_min', raw_icm_reward.min(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_std', raw_icm_reward.std(), self.estimate_cnt_icm)
icm_reward = (raw_icm_reward - raw_icm_reward.min()) / (raw_icm_reward.max() - raw_icm_reward.min() + 1e-8)
self.tb_logger.add_scalar('icm_reward/icm_reward_max', icm_reward.max(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/icm_reward_mean', icm_reward.mean(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/icm_reward_min', icm_reward.min(), self.estimate_cnt_icm)
self.tb_logger.add_scalar('icm_reward/icm_reward_std', icm_reward.std(), self.estimate_cnt_icm)
icm_reward = (raw_icm_reward - raw_icm_reward.min()) / (raw_icm_reward.max() - raw_icm_reward.min() + 1e-8)
Reviewer (Member): why norm twice here?

icm_reward = icm_reward.to(self.device)
for item, icm_rew in zip(train_data_augmented, icm_reward):
if self.intrinsic_reward_type == 'add':
item['reward'] += rew
if self.cfg.extrinsic_reward_norm:
item['reward'] = item[
'reward'] / self.cfg.extrinsic_reward_norm_max + icm_rew * self.cfg.intrinsic_reward_weight
else:
item['reward'] = item['reward'] + icm_rew * self.cfg.intrinsic_reward_weight
elif self.intrinsic_reward_type == 'new':
item['intrinsic_reward'] = rew
item['intrinsic_reward'] = icm_rew
if self.cfg.extrinsic_reward_norm:
item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
elif self.intrinsic_reward_type == 'assign':
item['reward'] = rew
item['reward'] = icm_rew

return train_data_augmented

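In the 'add' branch of estimate() above, the extrinsic reward is divided by `extrinsic_reward_norm_max` and the min-max-normalized ICM reward is scaled by `intrinsic_reward_weight` (0.003, i.e. roughly 1/300, in the default config). A minimal numeric sketch of that combination; the reward values are invented for illustration:

```python
import torch

intrinsic_reward_weight = 0.003  # default from the ICM config above (~1/300)
extrinsic_reward_norm_max = 1

extrinsic_reward = torch.tensor([1.0])  # e.g. a sparse goal reward
icm_reward = torch.tensor([0.7])        # min-max normalized intrinsic reward in [0, 1]

# 'add' branch with extrinsic_reward_norm=True
augmented = extrinsic_reward / extrinsic_reward_norm_max + icm_reward * intrinsic_reward_weight
print(augmented)  # tensor([1.0021])
```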
51 changes: 24 additions & 27 deletions ding/reward_model/rnd_reward_model.py
@@ -12,7 +12,7 @@
from .base_reward_model import BaseRewardModel
from ding.utils import RunningMeanStd
from ding.torch_utils.data_helper import to_tensor
import copy
import numpy as np


def collect_states(iterator):
@@ -60,19 +60,15 @@ class RndRewardModel(BaseRewardModel):
obs_norm=True,
obs_norm_clamp_min=-1,
obs_norm_clamp_max=1,
intrinsic_reward_weight=None,
# means the relative weight of RND intrinsic_reward.
# If intrinsic_reward_weight=None, we will automatically set it based on
# the absolute value of the difference between max and min extrinsic reward in the sampled mini-batch
# please refer to estimate() method for details.
intrinsic_reward_rescale=0.01,
# means the rescale value of RND intrinsic_reward only used when intrinsic_reward_weight is None
intrinsic_reward_weight=0.01,
PaParaZz1 marked this conversation as resolved.
extrinsic_reward_norm=True,
extrinsic_reward_norm_max=1,
)

def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None) -> None: # noqa
super(RndRewardModel, self).__init__()
self.cfg = config
self.intrinsic_reward_rescale = self.cfg.intrinsic_reward_rescale
assert device == "cpu" or device.startswith("cuda")
self.device = device
if tb_logger is None: # TODO
@@ -87,6 +83,7 @@ def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None) -> None:  # noqa
self.opt = optim.Adam(self.reward_model.predictor.parameters(), config.learning_rate)
self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
self.estimate_cnt_rnd = 0
self.train_cnt_icm = 0
self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)

def _train(self) -> None:
@@ -102,13 +99,15 @@ def _train(self) -> None:

predict_feature, target_feature = self.reward_model(train_data)
loss = F.mse_loss(predict_feature, target_feature.detach())
self.tb_logger.add_scalar('rnd_reward/loss', loss, self.train_cnt_icm)
self.opt.zero_grad()
loss.backward()
self.opt.step()

def train(self) -> None:
for _ in range(self.cfg.update_per_collect):
self._train()
self.train_cnt_icm += 1

def estimate(self, data: list) -> List[Dict]:
"""
@@ -132,14 +131,16 @@ def estimate(self, data: list) -> List[Dict]:
self._running_mean_std_rnd_reward.update(mse.cpu().numpy())

# Note: according to the min-max normalization, transform rnd reward to [0,1]
rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-11)
rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-8)

# save the rnd_reward statistics into tb_logger
self.estimate_cnt_rnd += 1
self.tb_logger.add_scalar('rnd_reward/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('rnd_reward/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('rnd_reward/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('rnd_reward/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd)

rnd_reward = rnd_reward.to(train_data_augmented[0]['reward'].device)
rnd_reward = rnd_reward.to(self.device)
rnd_reward = torch.chunk(rnd_reward, rnd_reward.shape[0], dim=0)
"""
NOTE: Following normalization approach to extrinsic reward seems be not reasonable,
@@ -148,30 +149,26 @@
# rewards = torch.stack([data[i]['reward'] for i in range(len(data))])
# rewards = (rewards - torch.min(rewards)) / (torch.max(rewards) - torch.min(rewards))

# TODO(pu): how to set intrinsic_reward_rescale automatically?
if self.cfg.intrinsic_reward_weight is None:
"""
NOTE: the following way of setting self.cfg.intrinsic_reward_weight is only suitable for the dense
reward env like lunarlander, not suitable for the dense reward env.
In sparse reward env, e.g. minigrid, if the agent reaches the goal, it obtain reward ~1, otherwise 0.
Thus, in sparse reward env, it's reasonable to set the intrinsic_reward_weight approximately equal to
the inverse of max_episode_steps.
"""
self.cfg.intrinsic_reward_weight = self.intrinsic_reward_rescale * max(
1,
abs(
max([train_data_augmented[i]['reward'] for i in range(len(train_data_augmented))]) -
min([train_data_augmented[i]['reward'] for i in range(len(train_data_augmented))])
)
)
for item, rnd_rew in zip(train_data_augmented, rnd_reward):
if self.intrinsic_reward_type == 'add':
item['reward'] = item['reward'] + rnd_rew * self.cfg.intrinsic_reward_weight
if self.cfg.extrinsic_reward_norm:
item['reward'] = item[
'reward'] / self.cfg.extrinsic_reward_norm_max + rnd_rew * self.cfg.intrinsic_reward_weight
else:
item['reward'] = item['reward'] + rnd_rew * self.cfg.intrinsic_reward_weight
elif self.intrinsic_reward_type == 'new':
item['intrinsic_reward'] = rnd_rew
if self.cfg.extrinsic_reward_norm:
item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
elif self.intrinsic_reward_type == 'assign':
item['reward'] = rnd_rew

# save the augmented_reward statistics into tb_logger
rew = [item['reward'].cpu().numpy() for item in train_data_augmented]
self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(rew), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('augmented_reward/reward_mean', np.mean(rew), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(rew), self.estimate_cnt_rnd)
self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(rew), self.estimate_cnt_rnd)
return train_data_augmented

def collect_data(self, data: list) -> None:
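The removed comment block in the RND model argued that in a sparse-reward environment the intrinsic weight should be roughly the inverse of the episode length, since the extrinsic reward is about 1 on success and 0 otherwise. With the fixed `intrinsic_reward_weight` field, that heuristic now has to be applied by hand in the config. A hedged sketch of what that could look like; the horizon and the other field values are illustrative:

```python
from easydict import EasyDict

max_episode_steps = 300  # assumed MiniGrid-style horizon, for illustration only

rnd_reward_model_cfg = EasyDict(
    dict(
        type='rnd',
        intrinsic_reward_type='add',
        # heuristic from the removed comment: weight ~ 1 / max_episode_steps
        intrinsic_reward_weight=1.0 / max_episode_steps,
        extrinsic_reward_norm=True,
        extrinsic_reward_norm_max=1,
        learning_rate=1e-3,
        batch_size=32,
        update_per_collect=10,
        obs_norm=True,
        obs_norm_clamp_min=-1,
        obs_norm_clamp_max=1,
    )
)
```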
6 changes: 4 additions & 2 deletions dizoo/minigrid/__init__.py
@@ -1,4 +1,4 @@
from gym.envs.registration import register
from gymnasium.envs.registration import register
PaParaZz1 marked this conversation as resolved.

register(id='MiniGrid-AKTDT-7x7-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_7x7_1')

@@ -10,4 +10,6 @@

register(id='MiniGrid-AKTDT-19x19-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19')

register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3')
register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3')

register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv')
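With the registration above, the new noisy-TV environment can be created through the usual gymnasium id lookup. A minimal usage sketch; it assumes the minigrid dependency is installed and that `NoisyTVEnv` (added elsewhere in this PR) follows the standard gymnasium API:

```python
import gymnasium as gym
import dizoo.minigrid  # noqa: F401  (runs the register() calls above)

env = gym.make('MiniGrid-NoisyTV-v0')
obs, info = env.reset(seed=0)
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```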