Commit 2cc3468: sync code
jiayilee65 committed May 11, 2024
1 parent ad246cf

Showing 4 changed files with 21 additions and 12 deletions.
5 changes: 4 additions & 1 deletion lzero/entry/train_unizero_multi_task_v2.py
@@ -145,6 +145,7 @@ def train_unizero_multi_task_v2(
# # TODO: comment if debugging
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

+ total_env_steps = 0
while True:
# Each environment collects data separately and stores it in its own replay buffer
for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(cfgs, collectors, evaluators, game_buffers)):
@@ -170,7 +171,9 @@
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
- collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+ total_env_steps += collector.envstep
+ # collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) # TODO
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(total_env_steps)
else:
collect_kwargs['epsilon'] = 0.0

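The hunk above switches the epsilon-greedy schedule from a single collector's step count to env steps summed over all tasks (note that collector.envstep is itself a running total, so adding it on every pass of the while loop re-counts earlier steps; the diff leaves this marked TODO). For reference, a minimal sketch of an exponential epsilon schedule of the kind epsilon_greedy_fn presumably provides; the real helper is DI-engine's get_epsilon_greedy_fn, and the names below are illustrative assumptions, not the library's code:

    import math

    def make_epsilon_greedy_fn(start: float, end: float, decay: int):
        # Anneal epsilon exponentially from `start` toward `end` as env steps grow,
        # mirroring an exp-type schedule selected via policy_config.eps.type.
        def epsilon_fn(env_step: int) -> float:
            return end + (start - end) * math.exp(-1.0 * env_step / decay)
        return epsilon_fn

    eps_fn = make_epsilon_greedy_fn(start=0.95, end=0.05, decay=100000)
    print(round(eps_fn(0), 3), round(eps_fn(100000), 3))  # 0.95 0.381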
10 changes: 5 additions & 5 deletions lzero/model/gpt_models/cfg_atari.py
@@ -75,7 +75,7 @@
# "gru_gating": False,


"device": 'cuda:5',
"device": 'cuda:0',
'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

@@ -131,12 +131,12 @@
# 'max_cache_size':500,
"env_num": 8,

- 'latent_recon_loss_weight': 0.05,
- 'perceptual_loss_weight': 0.05, # for stack1 rgb obs
+ # 'latent_recon_loss_weight': 0.05,
+ # 'perceptual_loss_weight': 0.05, # for stack1 rgb obs
# # 'perceptual_loss_weight':0., # for stack4 gray obs

- # 'latent_recon_loss_weight': 0.,
- # 'perceptual_loss_weight': 0., # for stack1 rgb obs
+ 'latent_recon_loss_weight': 0.,
+ 'perceptual_loss_weight': 0., # for stack1 rgb obs

# 'latent_recon_loss_weight':0.,
# 'perceptual_loss_weight':0.,
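The two weights flipped above gate auxiliary reconstruction terms in the world-model loss; this commit sets both to 0., disabling them for the multi-task run. A hedged sketch of how such weights typically enter the total loss (the function and variable names are assumptions, not the repository's exact code):

    import torch

    def weighted_aux_loss(latent_recon_loss: torch.Tensor,
                          perceptual_loss: torch.Tensor,
                          latent_recon_loss_weight: float = 0.,
                          perceptual_loss_weight: float = 0.) -> torch.Tensor:
        # With both weights at 0., as configured here, the auxiliary terms
        # contribute nothing to the gradient.
        return (latent_recon_loss_weight * latent_recon_loss
                + perceptual_loss_weight * perceptual_loss)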
6 changes: 4 additions & 2 deletions lzero/model/gpt_models/world_model_multi_task.py
@@ -337,8 +337,8 @@ def precompute_pos_emb_diff_kv(self):


def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, kvcache_independent=False, is_init_infer=True, valid_context_lengths=None, task_id=0) -> WorldModelOutput:
- # task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO
- task_embeddings = torch.zeros(768, device=self.device) # NOTE: TODO
+ task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO
+ # task_embeddings = torch.zeros(768, device=self.device) # NOTE: TODO

if kvcache_independent:
# Get the number of steps for each sample from past_keys_values
@@ -457,11 +457,13 @@ def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, kvcache_independent=False, is_init_infer=True, valid_context_lengths=None, task_id=0)
# ========== for visualize ==========

# 1,...,0,1 https://github.com/eloialonso/iris/issues/19
+ # one head
# logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)

+ # N head
logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
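Two changes land in forward: the learned task embedding is re-enabled (replacing the zero vector), and predictions switch from one shared head to per-task "N head" modules indexed by task_id. A minimal sketch of that pattern, assuming 768-d embeddings and illustrative head shapes (not the repository's exact modules):

    import torch
    import torch.nn as nn

    class MultiTaskValueHead(nn.Module):
        def __init__(self, num_tasks: int, embed_dim: int = 768, out_dim: int = 101):
            super().__init__()
            # One learned embedding per task, added to the transformer features.
            self.task_emb = nn.Embedding(num_tasks, embed_dim)
            # "N head": an independent output head per task, selected by task_id.
            self.head_value_multi_task = nn.ModuleList(
                nn.Linear(embed_dim, out_dim) for _ in range(num_tasks)
            )

        def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
            # x: (batch, seq, embed_dim); the (embed_dim,) embedding broadcasts.
            task_embeddings = self.task_emb(torch.tensor(task_id, device=x.device))
            return self.head_value_multi_task[task_id](x + task_embeddings)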
12 changes: 8 additions & 4 deletions zoo/atari/config/atari_unizero_config_stack1_multitask.py
@@ -1,6 +1,6 @@
from easydict import EasyDict
import torch
- torch.cuda.set_device(5)
+ torch.cuda.set_device(0)
# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
env_id = 'PongNoFrameskip-v4'
# env_id = 'MsPacmanNoFrameskip-v4'
@@ -42,15 +42,19 @@
# max_env_step = int(5e5)

reanalyze_ratio = 0.
- batch_size = 64

+ # batch_size = 64
+ batch_size = 32 # TODO: multitask


num_simulations = 50

num_unroll_steps = 10
eps_greedy_exploration_in_collect = True


# exp_name_prefix = f'data_unizero_mt_stack1/pong-mspacman_action{action_space_size}_taskembedding_one-head/'
- exp_name_prefix = f'data_unizero_mt_stack1/pong-mspacman_action{action_space_size}_notaskembedding_N-head/'
+ exp_name_prefix = f'data_unizero_mt_stack1/pong-mspacman_action{action_space_size}_taskembedding_N-head/'
+ # exp_name_prefix = f'data_unizero_mt_stack1/pong-mspacman_action{action_space_size}_notaskembedding_N-head/'
# exp_name_prefix = f'data_unizero_mt_stack1/pong-action{action_space_size}_notaskembedding/'


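One plausible reading of the batch-size change (an assumption; the commit does not say): with the two tasks named in exp_name_prefix training jointly, a per-task batch of 32 keeps the combined gradient batch at the previous single-task 64:

    num_tasks = 2             # Pong + MsPacman, per exp_name_prefix above (assumed)
    batch_size_per_task = 32  # value set in this commit
    assert num_tasks * batch_size_per_task == 64  # the former single-task batch_size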
