sync code
jiayilee65 committed May 13, 2024
1 parent 8f1e079 commit c31c43d
Showing 13 changed files with 235 additions and 183 deletions.
2 changes: 2 additions & 0 deletions lzero/model/common.py
@@ -464,6 +464,8 @@ def __init__(

self.normalize_pixel = normalize_pixel
self.sim_norm = SimNorm(simnorm_dim=group_size)
# self.sim_norm = nn.Sigmoid() # only for ablation



def forward(self, image):
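
For reference, SimNorm here is the simplicial normalization applied to the encoder's latent state, and the commented-out nn.Sigmoid() is an ablation that swaps it for element-wise squashing. A minimal sketch of the group-softmax form, assuming the LightZero/TD-MPC2 definition (class name and default group size below are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimNormSketch(nn.Module):
    """Split the last dimension into groups of `simnorm_dim` and softmax each group."""
    def __init__(self, simnorm_dim: int = 8):
        super().__init__()
        self.dim = simnorm_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(*shape[:-1], -1, self.dim)  # (..., num_groups, simnorm_dim)
        x = F.softmax(x, dim=-1)               # each group becomes a point on a probability simplex
        return x.view(*shape)
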
14 changes: 7 additions & 7 deletions lzero/model/gpt_models/cfg_atari.py
@@ -75,15 +75,15 @@
# "gru_gating": False,


"device": 'cuda:6',
"device": 'cuda:3',
'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

'action_shape': 6, # TODO:for multi-task
# 'action_shape': 6, # TODO:for multi-task

# 'action_shape': 6, # TODO:for pong qbert
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
'action_shape': 18,# TODO:for Seaquest boxing Frostbite
# 'action_shape': 4,# TODO:for breakout

# 'embed_dim':512, # TODO:for atari
@@ -133,12 +133,12 @@
# "env_num":16, # TODO
# "env_num":1, # TODO

'latent_recon_loss_weight': 0.05,
'perceptual_loss_weight': 0.05, # for stack1 rgb obs
# 'latent_recon_loss_weight': 0.05,
# 'perceptual_loss_weight': 0.05, # for stack1 rgb obs
# # 'perceptual_loss_weight':0., # for stack4 gray obs

# 'latent_recon_loss_weight': 0.,
# 'perceptual_loss_weight': 0., # for stack1 rgb obs
'latent_recon_loss_weight': 0.,
'perceptual_loss_weight': 0., # for stack1 rgb obs

# 'latent_recon_loss_weight':0.,
# 'perceptual_loss_weight':0.,
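
The change above flips latent_recon_loss_weight and perceptual_loss_weight from 0.05 to 0. for this Atari config. A hypothetical helper showing how such weights typically enter the observation-model loss (the actual weighting happens inside world_model.compute_loss and may differ):

import torch

def weighted_obs_loss(predict_latent_loss: torch.Tensor,
                      latent_recon_loss: torch.Tensor,
                      perceptual_loss: torch.Tensor,
                      latent_recon_loss_weight: float = 0.0,
                      perceptual_loss_weight: float = 0.0) -> torch.Tensor:
    # With both weights at 0., the decoder reconstruction and perceptual terms
    # contribute nothing to the gradient; only the latent prediction term drives learning.
    return (predict_latent_loss
            + latent_recon_loss_weight * latent_recon_loss
            + perceptual_loss_weight * perceptual_loss)
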
83 changes: 50 additions & 33 deletions lzero/model/gpt_models/cfg_memory.py
@@ -23,6 +23,11 @@
# "context_length": 2 * (16+5),
# "context_length_for_recurrent": 2 * (16+5),

# 'max_blocks': 18+5,
# "max_tokens": 2 * (18+5), # 1+2+15 memory_length = 2
# "context_length": 2 * (18+5),
# "context_length_for_recurrent": 2 * (18+5),

# 'max_blocks': 76+5,
# "max_tokens": 2 * (76+5), # 1+60+15 memory_length = 60
# "context_length": 2 * (76+5),
@@ -33,6 +38,11 @@
# "context_length": 2 * (116+5),
# "context_length_for_recurrent": 2 * (116+5),

# 'max_blocks': 136+5,
# "max_tokens": 2 * (136+5), # 1+120+15 memory_length = 120
# "context_length": 2 * (136+5),
# "context_length_for_recurrent": 2 * (136+5),

# 'max_blocks': 266+5,
# "max_tokens": 2 * (266+5), # 1+250+15 memory_length = 250
# "context_length": 2 * (266+5),
@@ -43,50 +53,57 @@
# "context_length": 2 * (516+5),
# "context_length_for_recurrent": 2 * (516+5),

'max_blocks': 1016+5,
"max_tokens": 2 * (1016+5), # 1+1000+15 memory_length = 1000
"context_length": 2 * (1016+5),
"context_length_for_recurrent": 2 * (1016+5),
# 'max_blocks': 1016+5,
# "max_tokens": 2 * (1016+5), # 1+1000+15 memory_length = 1000
# "context_length": 2 * (1016+5),
# "context_length_for_recurrent": 2 * (1016+5),

# 'max_blocks': 16,
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
# "context_length": 2 * 16,
# "context_length_for_recurrent": 2 * 16,
# Key2Door

# 'max_blocks': 90+5,
# "max_tokens": 2 * (90+5), # 15+60+15 memory_length = 60
# "context_length": 2 * (90+5),
# "context_length_for_recurrent": 2 * (90+5),

# 'max_blocks': 21,
# "max_tokens": 2 * 21, # 1+0+15 memory_length = 0
# "context_length": 2 * 21,
# "context_length_for_recurrent": 2 * 21,
# "recurrent_keep_deepth": 100,
# 'max_blocks': 150+5,
# "max_tokens": 2 * (150+5), # 15+120+15 memory_length = 120
# "context_length": 2 * (150+5),
# "context_length_for_recurrent": 2 * (150+5),

# 'max_blocks': 17, # TODO
# "max_tokens": 2 * 17, # 1+0+15 memory_length = 0
# "context_length": 2 * 17,
# "context_length_for_recurrent": 2 * 17,
# "recurrent_keep_deepth": 100,
'max_blocks': 280+5,
"max_tokens": 2 * (280+5), # 15+250+15 memory_length = 250
"context_length": 2 * (280+5),
"context_length_for_recurrent": 2 * (280+5),

# 'max_blocks': 46,
# "max_tokens": 2 * 46, # 1+30+15=76 memory_length = 30
# "context_length": 2 * 46,
# "context_length_for_recurrent": 2 * 46,
# "recurrent_keep_deepth": 100,
# 'max_blocks': 530+5,
# "max_tokens": 2 * (530+5), # 15+500+15 memory_length = 500
# "context_length": 2 * (530+5),
# "context_length_for_recurrent": 2 * (530+5),



"device": 'cuda:0',
"device": 'cuda:7',
'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

'group_size': 8, # NOTE
# 'group_size': 64, # NOTE

# 'group_size': 768, # NOTE
'attention': 'causal',

# 'num_layers': 8,
# 'num_heads': 8,
# 'embed_dim': 64, # TODO: for visual_match [1000]

'num_layers': 4,
'num_heads': 4,
'embed_dim': 32, # TODO: for memlen=500,1000
'embed_dim': 64, # TODO: for visual_match [250, 500]
# 'embed_dim': 32, # TODO: for visual_match [250, 500]

# 'embed_dim': 64, # TODO: for memlen=500,1000
# 'embed_dim': 96, # TODO: for memlen=500,1000
# 'num_layers': 4,
# 'num_heads': 4,
# 'embed_dim': 64, # TODO: for visual_match [2, 60, 100]

# 'num_layers': 8,
# 'num_heads': 8,
@@ -123,17 +140,17 @@
'support_size': 101, # TODO
'action_shape': 4, # NOTE:for memory
'max_cache_size': 5000,
"env_num": 10, # for memeory_length=1000 TODO
# "env_num": 20,
"env_num": 20, # for memeory_length=1000 TODO
# "env_num": 10,

'latent_recon_loss_weight': 0.05,
# 'latent_recon_loss_weight': 0.0,
# 'latent_recon_loss_weight': 0.05,
'latent_recon_loss_weight': 0.0, # TODO
# 'latent_recon_loss_weight':0.5,
# 'latent_recon_loss_weight':10,

'perceptual_loss_weight': 0.,
'policy_entropy_weight': 1e-4, # NOTE:for key_to_door
# 'policy_entropy_weight': 1e-1, # NOTE:for visual_match
'policy_entropy_weight': 1e-4, # NOTE:for visual_match
# 'policy_entropy_weight': 1e-3, # NOTE:for key_to_door

'predict_latent_loss_type': 'group_kl',
# 'predict_latent_loss_type': 'mse',
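
All of the presets above, commented and active, follow the same arithmetic: one block per environment step, two tokens (observation + action) per block, and a step budget of prefix + memory_length + suffix plus a margin of 5. A small illustrative helper (not part of the repository) that reproduces them:

def memory_env_context(memory_length: int, prefix: int = 1, suffix: int = 15, margin: int = 5) -> dict:
    # prefix=1 matches the visual_match presets (1+X+15); prefix=15 the key_to_door ones (15+X+15).
    max_blocks = prefix + memory_length + suffix + margin
    return {
        'max_blocks': max_blocks,
        'max_tokens': 2 * max_blocks,              # one obs token + one action token per block
        'context_length': 2 * max_blocks,
        'context_length_for_recurrent': 2 * max_blocks,
    }

For example, memory_env_context(250, prefix=15) gives max_blocks = 285 and max_tokens = 570, i.e. the active 280+5 preset above.
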
10 changes: 6 additions & 4 deletions lzero/model/gpt_models/world_model.py
@@ -918,12 +918,14 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_t
# Reconstruct observations from the latent state representation
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# Compute the reconstruction loss and perceptual loss
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
# Compute the reconstruction loss and perceptual loss TODO
# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
# perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1

# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # NOTE: for stack=4
# perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4
latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4
perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4


elif self.obs_type == 'vector':
perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4
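
The hunk above swaps the computed reconstruction and perceptual losses for zero tensors, so the rest of compute_loss keeps a uniform interface whether or not the decoder terms are used. An illustrative version of that pattern (not the repository's exact code; the tokenizer method names are taken from the lines above):

import torch

def recon_terms(observations: torch.Tensor, reconstructed: torch.Tensor,
                tokenizer=None, use_recon: bool = False):
    if use_recon and tokenizer is not None:
        # stack=1 RGB observations, reshaped to (N, 3, 64, 64)
        latent_recon_loss = tokenizer.reconstruction_loss(
            observations.reshape(-1, 3, 64, 64), reconstructed)
        perceptual_loss = tokenizer.perceptual_loss(
            observations.reshape(-1, 3, 64, 64), reconstructed)
    else:
        zero = torch.tensor(0., device=observations.device, dtype=observations.dtype)
        latent_recon_loss, perceptual_loss = zero, zero
    return latent_recon_loss, perceptual_loss
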
8 changes: 4 additions & 4 deletions lzero/worker/muzero_collector.py
@@ -419,8 +419,8 @@ def collect(self,
# policy forward
# ==============================================================
# print(f'ready_env_id:{ready_env_id}')
# policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) # for unizero
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id) # for muzero_rnn_allobs
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) # for unizero and muzero
# policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id) # for muzero_rnn_allobs


actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
@@ -535,8 +535,8 @@ def collect(self,
# print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')

if hasattr(self._policy.get_attribute('collect_model'), 'world_model'):
# if eps_steps_lst[env_id] % 2000 == 0: # TODO: NOTE for memory
if eps_steps_lst[env_id] % 200 == 0: # TODO: NOTE for atari unizero
if eps_steps_lst[env_id] % 2000 == 0: # TODO: NOTE for memory
# if eps_steps_lst[env_id] % 200 == 0: # TODO: NOTE for atari unizero
# if eps_steps_lst[env_id] % 32 == 0: # TODO: NOTE
# if eps_steps_lst[env_id] % 90 == 0:
# if eps_steps_lst[env_id] % 130 == 0:
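
Both the collector and the evaluator gate a periodic world-model reset on eps_steps_lst[env_id]; this commit only changes the interval (2000 for the memory environments instead of 200 for Atari UniZero). A hypothetical sketch of that guard — the actual reset call sits in the body of the `if`, which this hunk does not show, and clear_caches below is a placeholder name:

CLEAR_INTERVAL = 2000  # memory envs; 200 was the Atari UniZero setting

def maybe_reset_world_model(eps_steps: int, collect_model) -> None:
    if hasattr(collect_model, 'world_model') and eps_steps % CLEAR_INTERVAL == 0:
        collect_model.world_model.clear_caches()  # placeholder method name, for illustration only
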
8 changes: 4 additions & 4 deletions lzero/worker/muzero_evaluator.py
@@ -281,8 +281,8 @@ def eval(
# ==============================================================
# policy forward
# ==============================================================
# policy_output = self._policy.forward(stack_obs, action_mask, to_play) # for unizero
policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) # for muzero_rnn_allobs
policy_output = self._policy.forward(stack_obs, action_mask, to_play) # for unizero and muzero
# policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) # for muzero_rnn_allobs

actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
@@ -326,8 +326,8 @@ def eval(
eps_steps_lst[env_id] += 1

if hasattr(self._policy.get_attribute('collect_model'), 'world_model'):
# if eps_steps_lst[env_id] % 2000 == 0: # TODO: NOTE for memory
if eps_steps_lst[env_id] % 200 == 0: # TODO: NOTE for atari unizero
if eps_steps_lst[env_id] % 2000 == 0: # TODO: NOTE for memory
# if eps_steps_lst[env_id] % 200 == 0: # TODO: NOTE for atari unizero
# if eps_steps_lst[env_id] % 32 == 0: # TODO: NOTE
# if eps_steps_lst[env_id] % 90 == 0:
# if eps_steps_lst[env_id] % 130 == 0:
9 changes: 5 additions & 4 deletions requirements.txt
@@ -1,8 +1,9 @@
DI-engine[common_env]>=0.5.0
# DI-engine[common_env]>=0.5.0
gym[accept-rom-license]==0.25.1
numpy>=1.22.4
# numpy>=1.22.4
pympler
bsuite
# bsuite
minigrid
moviepy
pycolab
pycolab
opencv-python
21 changes: 15 additions & 6 deletions zoo/atari/config/atari_muzero_config_46464_stack1_context.py
@@ -48,12 +48,12 @@
num_simulations = 50
model_update_ratio = 0.25
batch_size = 256
# max_env_step = int(5e5)
max_env_step = int(1e6)
max_env_step = int(5e5)
# max_env_step = int(1e6)
reanalyze_ratio = 0.
eps_greedy_exploration_in_collect = True

torch.cuda.set_device(1)
torch.cuda.set_device(3)

num_unroll_steps = 10
context_length_init = 4 # 1
@@ -98,7 +98,7 @@
hook=dict(
load_ckpt_before_run='',
log_show_after_iter=100,
save_ckpt_after_iter=100000, # default is 1000
save_ckpt_after_iter=500000, # default is 1000
save_ckpt_after_run=True,
),
),
@@ -185,5 +185,14 @@
create_config = atari_muzero_create_config

if __name__ == "__main__":
from lzero.entry import train_muzero_context
train_muzero_context([main_config, create_config], seed=0, max_env_step=max_env_step)
# Define a list of seeds for multiple runs
seeds = [1,2] # You can add more seed values here
# seeds = [1,2] # You can add more seed values here
# seeds = [2] # You can add more seed values here

for seed in seeds:
# Update exp_name to include the current seed TODO
main_config.exp_name=f'data_paper_muzero_variants_0511/stack1_mlp/{env_id[:-14]}_muzero_stack1_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed{seed}'
# main_config.exp_name=f'data_paper_muzero_atari-20-games_0510/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_seed{seed}'
from lzero.entry import train_muzero_context
train_muzero_context([main_config, create_config], seed=seed, max_env_step=max_env_step)
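
The exp_name strings above embed env_id[:-14]; since the 'NoFrameskip-v4' suffix is exactly 14 characters, the slice keeps just the game name, e.g.:

env_id = 'SeaquestNoFrameskip-v4'
assert len('NoFrameskip-v4') == 14
print(env_id[:-14])  # -> 'Seaquest'
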
31 changes: 22 additions & 9 deletions zoo/atari/config/atari_muzero_config_46464_stack4_context.py
@@ -47,15 +47,16 @@
num_simulations = 50
model_update_ratio = 0.25
batch_size = 256
# max_env_step = int(5e5)
max_env_step = int(1e6)
max_env_step = int(5e5)
# max_env_step = int(1e6)
reanalyze_ratio = 0.
eps_greedy_exploration_in_collect = True
torch.cuda.set_device(2)
torch.cuda.set_device(0)

num_unroll_steps = 10
ssl_loss_weight = 2
context_length_init = 4 # 1
context_length_init = 4 # 1,4
ssl_loss_weight = 2 # 0,2


# for debug ===========
# collector_env_num = 1
@@ -74,7 +75,7 @@
# ==============================================================

atari_muzero_config = dict(
exp_name=f'data_paper_muzero_variants_0429/stack4/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed0',
exp_name=f'data_paper_muzero_variants_0511/stack4_mlp/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed0',
# exp_name=f'data_paper_learn-dynamics_0423/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_adamw1e-4_analysis_dratio0025_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_simnorm-cossim_adamw1e-4_seed0',
# exp_name=f'data_paper_muzero_variants_0422/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_conlen1_simnorm-cossim_adamw1e-4_seed0',
Expand All @@ -100,7 +101,7 @@
hook=dict(
load_ckpt_before_run='',
log_show_after_iter=100,
save_ckpt_after_iter=100000, # default is 1000
save_ckpt_after_iter=500000, # default is 1000
save_ckpt_after_run=True,
),
),
Expand Down Expand Up @@ -185,5 +186,17 @@
create_config = atari_muzero_create_config

if __name__ == "__main__":
from lzero.entry import train_muzero_context
train_muzero_context([main_config, create_config], seed=0, max_env_step=max_env_step)
# from lzero.entry import train_muzero_context
# train_muzero_context([main_config, create_config], seed=0, max_env_step=max_env_step)

# Define a list of seeds for multiple runs
seeds = [1,2] # You can add more seed values here
# seeds = [1,2] # You can add more seed values here
# seeds = [2] # You can add more seed values here

for seed in seeds:
# Update exp_name to include the current seed TODO
main_config.exp_name=f'data_paper_muzero_variants_0511/stack4_mlp/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_sslw{ssl_loss_weight}_seed{seed}'
# main_config.exp_name=f'data_paper_muzero_atari-20-games_0510/{env_id[:-14]}_muzero_stack4_H{num_unroll_steps}_initconlen{context_length_init}_simnorm-cossim_sgd02_seed{seed}'
from lzero.entry import train_muzero_context
train_muzero_context([main_config, create_config], seed=seed, max_env_step=max_env_step)
@@ -1,6 +1,6 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(6)
torch.cuda.set_device(3)

# env_id = 'AlienNoFrameskip-v4'
# env_id = 'AmidarNoFrameskip-v4'
Expand All @@ -25,7 +25,7 @@
# env_id = 'KungFuMasterNoFrameskip-v4'
# env_id = 'PrivateEyeNoFrameskip-v4'
# env_id = 'RoadRunnerNoFrameskip-v4'
env_id = 'UpNDownNoFrameskip-v4'
# env_id = 'UpNDownNoFrameskip-v4'

update_per_collect = None # for others
# model_update_ratio = 1.
