
Commit

sync code
jiayilee65 committed May 9, 2024
1 parent 959959e commit 8f1e079
Showing 13 changed files with 322 additions and 48 deletions.
1 change: 1 addition & 0 deletions lzero/model/common.py
@@ -311,6 +311,7 @@ def __init__(

self.sim_norm = SimNorm(simnorm_dim=group_size)
# self.sim_norm = nn.Sigmoid() # only for ablation
# self.sim_norm = nn.Softmax() # only for ablation

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
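Note on SimNorm (added for context): SimNorm here is understood as the simplicial normalization used in the representation head (as in TD-MPC2): the latent vector is split into groups of group_size entries and a softmax is applied within each group, so every group lies on a probability simplex. A minimal sketch under that assumption; the actual class in lzero/model/common.py may differ in details:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimNormSketch(nn.Module):
    """Grouped softmax ("simplicial") normalization over the last dimension."""

    def __init__(self, simnorm_dim: int):
        super().__init__()
        self.simnorm_dim = simnorm_dim  # size of each group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(*shape[:-1], -1, self.simnorm_dim)  # (..., num_groups, group_size)
        x = F.softmax(x, dim=-1)                       # each group sums to 1
        return x.view(*shape)

# e.g. a 768-dim latent with group_size=8 becomes 96 groups of 8 values that each sum to 1;
# the commented-out Sigmoid/Softmax lines above swap this out for the ablations.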
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
@@ -75,7 +75,7 @@
# "gru_gating": False,


"device": 'cuda:0',
"device": 'cuda:6',
'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

5 changes: 3 additions & 2 deletions lzero/model/gpt_models/world_model.py
@@ -177,6 +177,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
nn.Linear(config.embed_dim, self.obs_per_embdding_dim),
self.sim_norm, # TODO
# nn.Sigmoid(), # only for ablation
# nn.Softmax(), # only for ablation
)
)
self.head_policy = Head(
@@ -472,8 +473,8 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_st
self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True)

buffer_action = buffer_action[:ready_env_num]
if ready_env_num<self.env_num:
print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}')
# if ready_env_num<self.env_num:
# print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}')
buffer_action = torch.from_numpy(np.array(buffer_action)).to(latent_state.device)
act_tokens = buffer_action.unsqueeze(-1)
outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_init_infer=True)
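Note on the ablation above (added for context): the change only swaps the final activation of the observation head inside an nn.Sequential. A minimal sketch of that pattern, with an illustrative helper name that is not the repo's API:

import torch.nn as nn

def build_obs_head(embed_dim: int, obs_embedding_dim: int, final_norm: nn.Module) -> nn.Sequential:
    # final_norm is SimNorm in the default setup; nn.Sigmoid() or nn.Softmax(dim=-1)
    # are the drop-in replacements used only for the ablations.
    return nn.Sequential(
        nn.Linear(embed_dim, obs_embedding_dim),
        final_norm,
    )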
8 changes: 4 additions & 4 deletions lzero/policy/muzero_rnn_full_obs.py
@@ -632,8 +632,8 @@ def _forward_collect(
self._collect_mcts_temperature = temperature
self.collect_epsilon = epsilon
active_collect_env_num = data.shape[0]
if active_collect_env_num != len(ready_env_id):
print('active_collect_env_num != len(ready_env_id)')
# if active_collect_env_num != len(ready_env_id):
# print('active_collect_env_num != len(ready_env_id)')

with torch.no_grad():
# data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
@@ -773,8 +773,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
self._eval_model.env_num = self._cfg.model.evaluator_env_num
self._eval_model.eval()
active_eval_env_num = data.shape[0]
if active_eval_env_num != len(ready_env_id):
print('active_collect_env_num != len(ready_env_id)')
# if active_eval_env_num != len(ready_env_id):
# print('active_collect_env_num != len(ready_env_id)')
with torch.no_grad():
# data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
# network_output = self._eval_model.initial_inference(data)
6 changes: 3 additions & 3 deletions lzero/policy/unizero.py
@@ -285,7 +285,7 @@ def _init_learn(self) -> None:
update_type='momentum',
# update_kwargs={'theta': 0.01} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01
update_kwargs={'theta': 0.05} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01
# update_kwargs={'theta': 1} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01
# update_kwargs={'theta': 0.999} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01
)
self._learn_model = self._model
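
Note on theta (added for context): theta is the interpolation coefficient of the momentum (EMA) target-model update being tuned in this hunk. A minimal sketch of what such an update typically does, assuming theta weights the online parameters; the exact behaviour of the momentum wrapper in DI-engine may differ:

import torch

@torch.no_grad()
def momentum_update(target_model: torch.nn.Module, online_model: torch.nn.Module, theta: float) -> None:
    # target <- (1 - theta) * target + theta * online; a larger theta tracks the online model faster.
    for t_param, o_param in zip(target_model.parameters(), online_model.parameters()):
        t_param.data.mul_(1.0 - theta).add_(o_param.data, alpha=theta)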

@@ -601,7 +601,7 @@ def _forward_collect(
temperature: float = 1,
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id: np.array = None,
ready_env_id: np.array = None
) -> Dict:
"""
Overview:
@@ -791,7 +791,7 @@ def _get_target_obs_index_in_step_k(self, step):
end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
return beg_index, end_index

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None) -> Dict:
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
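Note on the indexing above (added for context): when the observation is a flattened 1D vector, step k's stacked target observation occupies a contiguous slice of the concatenated buffer. A small hedged example; the beg_index line sits outside the shown hunk and is assumed to mirror the end_index pattern:

# Hypothetical numbers for illustration: flattened observation of size 4, frame_stack_num = 1.
observation_shape = 4
frame_stack_num = 1

def target_obs_slice(step: int):
    beg_index = observation_shape * step                      # assumed, mirroring end_index
    end_index = observation_shape * (step + frame_stack_num)  # as in the hunk above
    return beg_index, end_index

print(target_obs_slice(0))  # (0, 4)
print(target_obs_slice(2))  # (8, 12)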
5 changes: 3 additions & 2 deletions lzero/worker/muzero_collector.py
@@ -418,9 +418,10 @@ def collect(self,
# ==============================================================
# policy forward
# ==============================================================
# policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon)
# print(f'ready_env_id:{ready_env_id}')
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id)
# policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) # for unizero
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id) # for muzero_rnn_allobs


actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
4 changes: 2 additions & 2 deletions lzero/worker/muzero_evaluator.py
@@ -281,8 +281,8 @@ def eval(
# ==============================================================
# policy forward
# ==============================================================
# policy_output = self._policy.forward(stack_obs, action_mask, to_play)
policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id)
# policy_output = self._policy.forward(stack_obs, action_mask, to_play) # for unizero
policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) # for muzero_rnn_allobs

actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
@@ -77,7 +77,7 @@
# ==============================================================

atari_muzero_config = dict(
exp_name=f'data_paper_muzero_variants_0429/{env_id[:-14]}_muzero_stack1_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed0',
exp_name=f'data_paper_muzero_variants_0510/{env_id[:-14]}_muzero_stack1_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
262 changes: 262 additions & 0 deletions zoo/atari/config/atari_muzero_config_46464_stack1_context_20games.py
@@ -0,0 +1,262 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(5)

# env_id = 'AlienNoFrameskip-v4'
# env_id = 'AmidarNoFrameskip-v4'
# env_id = 'AssaultNoFrameskip-v4'
# env_id = 'AsterixNoFrameskip-v4'

# env_id = 'BankHeistNoFrameskip-v4'
# env_id = 'BattleZoneNoFrameskip-v4'
# env_id = 'ChopperCommandNoFrameskip-v4'
# env_id = 'CrazyClimberNoFrameskip-v4'

# env_id = 'DemonAttackNoFrameskip-v4'
# env_id = 'FreewayNoFrameskip-v4'
env_id = 'FrostbiteNoFrameskip-v4'
# env_id = 'GopherNoFrameskip-v4'

# env_id = 'HeroNoFrameskip-v4'
# env_id = 'JamesbondNoFrameskip-v4'
# env_id = 'KangarooNoFrameskip-v4'
# env_id = 'KrullNoFrameskip-v4'

# env_id = 'KungFuMasterNoFrameskip-v4'
# env_id = 'PrivateEyeNoFrameskip-v4'
# env_id = 'RoadRunnerNoFrameskip-v4'
# env_id = 'UpNDownNoFrameskip-v4'

# env_id = 'PongNoFrameskip-v4' # 6
# env_id = 'MsPacmanNoFrameskip-v4' # 9
# env_id = 'QbertNoFrameskip-v4' # 6
# env_id = 'SeaquestNoFrameskip-v4' # 18
# env_id = 'BoxingNoFrameskip-v4' # 18
# env_id = 'BreakoutNoFrameskip-v4' # TODO: eval_sample, episode_steps

update_per_collect = None # for others
# model_update_ratio = 1.
model_update_ratio = 0.5
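# Note (added for clarity; a rough sketch of the assumed semantics): with update_per_collect=None,
# the number of gradient updates per collection phase is derived from the newly collected data,
# roughly num_updates ~= model_update_ratio * num_collected_transitions, so model_update_ratio
# acts as a replay/update ratio.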


if env_id == 'AlienNoFrameskip-v4':
action_space_size = 18
elif env_id == 'AmidarNoFrameskip-v4':
action_space_size = 10
elif env_id == 'AssaultNoFrameskip-v4':
action_space_size = 7
elif env_id == 'AsterixNoFrameskip-v4':
action_space_size = 9
elif env_id == 'BankHeistNoFrameskip-v4':
action_space_size = 18
elif env_id == 'BattleZoneNoFrameskip-v4':
action_space_size = 18
elif env_id == 'ChopperCommandNoFrameskip-v4':
action_space_size = 18
elif env_id == 'CrazyClimberNoFrameskip-v4':
action_space_size = 9
elif env_id == 'DemonAttackNoFrameskip-v4':
action_space_size = 6
model_update_ratio = 0.25
elif env_id == 'FreewayNoFrameskip-v4':
action_space_size = 3
model_update_ratio = 0.25
elif env_id == 'FrostbiteNoFrameskip-v4':
action_space_size = 18
elif env_id == 'GopherNoFrameskip-v4':
action_space_size = 8
elif env_id == 'HeroNoFrameskip-v4':
action_space_size = 18
model_update_ratio = 0.25
elif env_id == 'JamesbondNoFrameskip-v4':
action_space_size = 18
elif env_id == 'KangarooNoFrameskip-v4':
action_space_size = 18
elif env_id == 'KrullNoFrameskip-v4':
action_space_size = 18
elif env_id == 'KungFuMasterNoFrameskip-v4':
action_space_size = 14
elif env_id == 'PrivateEyeNoFrameskip-v4':
action_space_size = 18
model_update_ratio = 0.25
elif env_id == 'RoadRunnerNoFrameskip-v4':
action_space_size = 18
elif env_id == 'UpNDownNoFrameskip-v4':
action_space_size = 6
elif env_id == 'PongNoFrameskip-v4':
action_space_size = 6
model_update_ratio = 0.25
elif env_id == 'MsPacmanNoFrameskip-v4':
action_space_size = 9
elif env_id == 'QbertNoFrameskip-v4':
action_space_size = 6
elif env_id == 'SeaquestNoFrameskip-v4':
action_space_size = 18
elif env_id == 'BoxingNoFrameskip-v4':
action_space_size = 18
model_update_ratio = 0.25
elif env_id == 'BreakoutNoFrameskip-v4':
action_space_size = 4
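
# Note (added for clarity; an equivalent, more compact sketch of the chain above, not the repo's code):
#   ACTION_SPACE = {'AlienNoFrameskip-v4': 18, 'AmidarNoFrameskip-v4': 10, ...}
#   SMALL_RATIO_GAMES = {'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', 'HeroNoFrameskip-v4', ...}
#   action_space_size = ACTION_SPACE[env_id]
#   model_update_ratio = 0.25 if env_id in SMALL_RATIO_GAMES else model_update_ratio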


# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
batch_size = 256
# max_env_step = int(1e6)
max_env_step = int(5e5)
reanalyze_ratio = 0.
eps_greedy_exploration_in_collect = True
num_unroll_steps = 5
context_length_init = 1
# for debug ===========
# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# num_simulations = 2
# update_per_collect = 2
# model_update_ratio = 0.25
# batch_size = 2
# max_env_step = int(5e5)
# reanalyze_ratio = 0.
# eps_greedy_exploration_in_collect = True
# num_unroll_steps = 5
# context_length_init = 1
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

atari_muzero_config = dict(
exp_name=f'data_paper_muzero_atari-20-games_0510/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_seed0',
# exp_name=f'data_paper_learn-dynamics_atari-20-games_0424/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_analysis_dratio0025_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_simnorm-cossim_adamw1e-4_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_simnorm-cossim_adamw1e-4_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_sslw2-cossim_adamw1e-4_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_sslw2-cossim_sgd02_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
observation_shape=(3, 64, 64),
frame_stack_num=1,
gray_scale=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# TODO: debug
# collect_max_episode_steps=int(50),
# eval_max_episode_steps=int(50),
# TODO: for breakout
# collect_max_episode_steps=int(5e3), # for breakout
# eval_max_episode_steps=int(5e3), # for breakout
),
policy=dict(

learn=dict(
learner=dict(
hook=dict(
load_ckpt_before_run='',
log_show_after_iter=100,
save_ckpt_after_iter=500000, # default is 1000
save_ckpt_after_run=True,
),
),
),
cal_dormant_ratio=False, # TODO
analysis_sim_norm=False, # TODO
model=dict(
analysis_sim_norm=False, # TODO
image_channel=3,
observation_shape=(3, 64, 64),
frame_stack_num=1,
gray_scale=False,
action_space_size=action_space_size,
downsample=True,
self_supervised_learning_loss=True, # default is False
discrete_action_encoding_type='one_hot',
norm_type='BN',
reward_support_size=101,
value_support_size=101,
support_scale=50,
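# Note (added for clarity; assumed MuZero-style categorical heads): reward and value are
# predicted as distributions over 101 bins spanning [-support_scale, support_scale],
# i.e. integer bins from -50 to 50, combined with the usual invertible value transform.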
context_length_init=context_length_init, # NOTE:TODO num_unroll_steps
use_sim_norm=True,
# use_sim_norm_kl_loss=True, # TODO
use_sim_norm_kl_loss=False, # TODO
),
cuda=True,
env_type='not_board_games',
game_segment_length=400, # for collector orig
# game_segment_length=50, # for collector game_segment
random_collect_episode_num=0,
eps=dict(
eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
# need to dynamically adjust the number of decay steps
# according to the characteristics of the environment and the algorithm
type='linear',
start=1.,
end=0.01,
decay=int(2e4), # TODO: 20k
),
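# Note (added for clarity; a sketch of the linear schedule configured above): the collect-time
# epsilon is assumed to decay linearly from start to end over decay environment steps,
# roughly eps(t) = max(end, start - (start - end) * t / decay), and stays at end afterwards.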
use_augmentation=True, # TODO
# use_augmentation=False,
use_priority=False,
model_update_ratio=model_update_ratio,
update_per_collect=update_per_collect,
batch_size=batch_size,
dormant_threshold=0.025,

optim_type='SGD', # for collector orig
lr_piecewise_constant_decay=True,
learning_rate=0.2,

# optim_type='AdamW', # for collector game_segment
# lr_piecewise_constant_decay=False,
# learning_rate=1e-4,

target_update_freq=100,
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
ssl_loss_weight=2, # default is 0
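# Note (added for clarity; an assumption based on self_supervised_learning_loss=True and the
# "sslw"/"cossim" tags in exp_name): ssl_loss_weight scales an EfficientZero-style
# cosine-similarity consistency loss between predicted and target latent states.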
n_episode=n_episode,
eval_freq=int(2e3),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)
atari_muzero_config = EasyDict(atari_muzero_config)
main_config = atari_muzero_config

atari_muzero_create_config = dict(
env=dict(
type='atari_lightzero',
import_names=['zoo.atari.envs.atari_lightzero_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='muzero_context',
import_names=['lzero.policy.muzero_context'],
),
)
atari_muzero_create_config = EasyDict(atari_muzero_create_config)
create_config = atari_muzero_create_config

if __name__ == "__main__":
# from lzero.entry import train_muzero_context
# train_muzero_context([main_config, create_config], seed=0, max_env_step=max_env_step)

# Define a list of seeds for multiple runs
seeds = [0, 1, 2]  # You can add more seed values here
# seeds = [2, 1]
# seeds = [1, 2]
# seeds = [2]

for seed in seeds:
# Update exp_name to include the current seed TODO
main_config.exp_name=f'data_paper_muzero_atari-6-games_0510/{env_id[:-14]}_muzero_stack1_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_seed{seed}'
from lzero.entry import train_muzero_context
train_muzero_context([main_config, create_config], seed=seed, max_env_step=max_env_step)
