fix(pu): fix unizero_multitask
jiayilee65 committed May 8, 2024
1 parent a3eeeb3 commit dfa47c3
Showing 5 changed files with 109 additions and 108 deletions.
1 change: 1 addition & 0 deletions lzero/entry/train_unizero_multi_task_v2.py
@@ -199,6 +199,7 @@ def train_unizero_multi_task_v2(
print(f'='*20)
print(f'collect task_id: {task_id}...')
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

if cfg.policy.update_per_collect is None:
# If update_per_collect is None, it is set to the number of collected transitions multiplied by model_update_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
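For readers following this hunk, the update_per_collect fallback described in the comment can be summarized in a short sketch. This is illustrative only; the helper name and the guarantee of at least one update step are assumptions, not code from this commit.

def compute_update_per_collect(new_data, model_update_ratio):
    # Number of transitions gathered in this collection cycle; new_data[0] is
    # assumed to be the list of game segments, as in the line above.
    collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
    # Scale by model_update_ratio and guarantee at least one update step.
    return max(1, int(collected_transitions_num * model_update_ratio))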
114 changes: 40 additions & 74 deletions lzero/model/gpt_models/world_model_multi_task.py
@@ -170,16 +170,16 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
# nn.Linear(config.embed_dim, self.support_size)
# )
# )
self.head_observations = Head( # TODO
max_blocks=config.max_blocks,
block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19
head_module=nn.Sequential(
nn.Linear(config.embed_dim, config.embed_dim),
nn.GELU(),
nn.Linear(config.embed_dim, self.obs_per_embdding_dim),
self.sim_norm,
)
)
# self.head_observations = Head( # TODO
# max_blocks=config.max_blocks,
# block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19
# head_module=nn.Sequential(
# nn.Linear(config.embed_dim, config.embed_dim),
# nn.GELU(),
# nn.Linear(config.embed_dim, self.obs_per_embdding_dim),
# self.sim_norm,
# )
# )
# self.head_policy = Head(
# max_blocks=config.max_blocks,
# block_mask=value_policy_tokens_pattern, # [0,...,1,0]
@@ -202,10 +202,13 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
self.head_policy_multi_task = nn.ModuleList()
self.head_value_multi_task = nn.ModuleList()
self.head_rewards_multi_task = nn.ModuleList()
self.head_observations_multi_task = nn.ModuleList()


# TODO:======================
# for task_id in range(3):
for task_id in range(2): # TODO
action_space_size=18
action_space_size=18 # TODO:======================
self.head_policy = Head(
max_blocks=config.max_blocks,
block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0]
@@ -239,12 +242,25 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
)
self.head_rewards_multi_task.append(self.head_rewards)

self.head_observations = Head( # TODO
max_blocks=config.max_blocks,
block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19
head_module=nn.Sequential(
nn.Linear(config.embed_dim, config.embed_dim),
nn.GELU(),
nn.Linear(config.embed_dim, self.obs_per_embdding_dim),
self.sim_norm,
)
)
self.head_observations_multi_task.append(self.head_observations)



self.apply(init_weights)

last_linear_layer_init_zero = True  # TODO: helps convergence speed.
if last_linear_layer_init_zero:
for head in self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task:
for head in self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task:
# Initialize the weights and bias of the last linear layer of each head module to zero
for _, layer in enumerate(reversed(head.head_module)):
if isinstance(layer, nn.Linear):
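The zero-initialization loop above is cut off by the hunk boundary; below is a minimal self-contained sketch of what it does. The exact body is an assumption based on the visible isinstance check, not code from this commit.

import torch.nn as nn

def zero_init_last_linear(head_module: nn.Sequential) -> None:
    # Walk the head module from the back and zero the first nn.Linear found,
    # i.e. the last linear layer of the head.
    for layer in reversed(head_module):
        if isinstance(layer, nn.Linear):
            nn.init.zeros_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
            break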
@@ -440,15 +456,15 @@ def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysV
# ========== for visualize ==========

# 1,...,0,1 https://github.com/eloialonso/iris/issues/19
logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)

logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)

# logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
# logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
# logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
# logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
# logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)

logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)

# TODO: root reward value
return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value)
@@ -704,7 +720,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True):
return self.keys_values_wm_size_list



# =================== The world_model is shared; should the kv_cache be stored per task_id? ===========
def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None):
if self.context_length <= 2:
# i.e., everything is a single frame, with no context
@@ -910,7 +926,7 @@ def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int):
total_memory_gb = total_memory_bytes / (1024 ** 3)
return total_memory_gb

def compute_loss(self, batch, target_tokenizer: Tokenizer=None, task_id=0, **kwargs: Any) -> LossWithIntermediateLosses:
def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_transform_handle=None, task_id=0, **kwargs: Any) -> LossWithIntermediateLosses:
# Encode observations into latent state representations
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False)

@@ -1032,7 +1048,7 @@ def compute_policy_entropy_loss(self, logits, mask):
log_probs = torch.log_softmax(logits, dim=1)
entropy = -(probs * log_probs).sum(1)
# Apply the mask and return the entropy loss
entropy_loss = (entropy * mask).mean()
entropy_loss = (entropy * mask)
return entropy_loss
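Note that dropping `.mean()` above means the method now returns an unreduced, masked entropy tensor, so any reduction has to happen at the call site. A possible caller-side reduction, shown only as a sketch; the helper below is not part of this commit.

import torch

def masked_mean_entropy(masked_entropy: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # masked_entropy is the unreduced output of compute_policy_entropy_loss
    # (already multiplied by the mask); average only over unmasked positions.
    return masked_entropy.sum() / mask.sum().clamp(min=1)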

def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
@@ -1057,55 +1073,5 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta
labels_value = target_value.masked_fill(mask_fill_value, -100)
return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu)

def render_img(self, obs: int, rec_img: int):
import torch
from PIL import Image
import matplotlib.pyplot as plt

# Assume batch is a dict that contains an 'observations' key,
# whose shape is torch.Size([B, N, C, H, W])
# batch_observations = batch_for_gpt['observations']
# batch_observations = batch['observations']
batch_observations = obs.unsqueeze(0)
# batch_observations = rec_img.unsqueeze(0)

# batch_observations = observations.unsqueeze(0)
# batch_observations = x.unsqueeze(0)
# batch_observations = reconstructions.unsqueeze(0)

B, N, C, H, W = batch_observations.shape  # automatically infer the dimensions

# Width of the separator bar (adjust as needed)
separator_width = 2

# Iterate over each sample
for i in range(B):
# Extract all frames of the current sample
frames = batch_observations[i]

# Compute the total width of the concatenated image (including separators)
total_width = N * W + (N - 1) * separator_width

# Create a new image that includes the separator bars
concat_image = Image.new('RGB', (total_width, H), color='black')

# Paste each frame together with its separator
for j in range(N):
frame = frames[j].permute(1, 2, 0).cpu().numpy()  # convert to (H, W, C)
frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB')

# Compute the position of the current frame within the concatenated image
x_position = j * (W + separator_width)
concat_image.paste(frame_image, (x_position, 0))

# Display the image
plt.imshow(concat_image)
plt.title(f'Sample {i+1}')
plt.axis('off')  # hide the axes
plt.show()

# Save the image to a file
concat_image.save(f'sample_{i+1}.png')

def __repr__(self) -> str:
return "world_model"
80 changes: 50 additions & 30 deletions lzero/policy/unizero_multi_task.py
@@ -52,15 +52,18 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type)
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
if torch.cuda.is_available():
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
else:
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")

return optimizer
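A brief usage sketch for configure_optimizers as defined above; the hyperparameter values and the model handle are placeholders for illustration, not taken from this commit.

import torch

# Hypothetical call site; weight_decay, learning_rate and betas are placeholder values.
optimizer = configure_optimizers(
    model=world_model,  # e.g. self._learn_model.world_model inside the policy
    weight_decay=1e-4,
    learning_rate=1e-4,
    betas=(0.9, 0.95),
    device_type='cuda' if torch.cuda.is_available() else 'cpu',
)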

@POLICY_REGISTRY.register('unizero_multi_task')
class MuZeroGPTMTPolicy(Policy):
class UniZeroMTPolicy(Policy):
"""
Overview:
The policy class for multi-task UniZero.
@@ -308,6 +311,10 @@ def _init_learn(self) -> None:
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)
self.intermediate_losses = defaultdict(float)
self.l2_norm_before = 0.
self.l2_norm_after = 0.
self.grad_norm_before = 0.
self.grad_norm_after = 0.

#@profile
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -445,7 +452,7 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
# update world model
# ==============================================================
intermediate_losses = defaultdict(float)
losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._target_model.world_model.tokenizer, task_id)
losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._target_model.world_model.tokenizer, task_id=task_id)

weighted_total_loss += losses.loss_total
# weighted_total_loss = weighted_total_loss + losses.loss_total  # changed to a non-in-place operation
@@ -479,21 +486,31 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
for name, parameter in self._learn_model.tokenizer.named_parameters():
print(name)
"""
gradient_scale = 1 / self._cfg.num_unroll_steps
# TODO(pu): test the effect of gradient scale.
weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)
# gradient_scale = 1 / self._cfg.num_unroll_steps
# # TODO(pu): test the effect of gradient scale.
# weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)

self._optimizer_world_model.zero_grad()
weighted_total_loss.backward()

# ============= for analysis ============= TODO
if self._cfg.analysis_sim_norm:
del self.l2_norm_before
del self.l2_norm_after
del self.grad_norm_before
del self.grad_norm_after
self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
self._target_model.encoder_hook.clear_data()  # very, very important!!!
# ============= for analysis =============


# Used in the training loop
# self.monitor_weights_and_grads(self._learn_model.tokenizer.representation_network)
# print('torch.cuda.memory_summary():', torch.cuda.memory_summary())

if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)
total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(
self._learn_model.world_model.parameters(), self._cfg.grad_clip_value
)
total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value)

self._optimizer_world_model.step()
if self._cfg.lr_piecewise_constant_decay:
@@ -503,20 +520,23 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
# the core target model update step.
# ==============================================================
self._target_model.update(self._learn_model.state_dict())
if self._cfg.use_rnd_model:
self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict())

# Make sure all CUDA kernels have finished so that GPU memory usage is counted accurately
torch.cuda.synchronize()
# Get the total GPU memory currently allocated (in bytes)
current_memory_allocated = torch.cuda.memory_allocated()
# Get the maximum GPU memory allocated so far during the run (in bytes)
max_memory_allocated = torch.cuda.max_memory_allocated()

# Convert the memory usage from bytes to GB
current_memory_allocated_gb = current_memory_allocated / (1024**3)
max_memory_allocated_gb = max_memory_allocated / (1024**3)
# Log the current and maximum memory usage with SummaryWriter
# # Make sure all CUDA kernels have finished so that GPU memory usage is counted accurately
if torch.cuda.is_available():
torch.cuda.synchronize()
# Get the total GPU memory currently allocated (in bytes)
current_memory_allocated = torch.cuda.memory_allocated()
# Get the maximum GPU memory allocated so far during the run (in bytes)
max_memory_allocated = torch.cuda.max_memory_allocated()

# Convert the memory usage from bytes to GB
current_memory_allocated_gb = current_memory_allocated / (1024**3)
max_memory_allocated_gb = max_memory_allocated / (1024**3)
# Log the current and maximum memory usage with SummaryWriter
else:
# TODO
current_memory_allocated_gb = 0.
max_memory_allocated_gb = 0.
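The comment above mentions recording these statistics with SummaryWriter; one hypothetical way to do so is sketched below. The tb_logger handle, tag names and train_iter variable are assumptions, not part of this commit.

# Hypothetical: tb_logger is a torch.utils.tensorboard.SummaryWriter, train_iter the current iteration.
tb_logger.add_scalar('learner_step/current_memory_allocated_gb', current_memory_allocated_gb, train_iter)
tb_logger.add_scalar('learner_step/max_memory_allocated_gb', max_memory_allocated_gb, train_iter)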


# Then, in your code, use this function to build the loss dictionary:
@@ -803,13 +823,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
# Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
# sampling during the evaluation phase.

# action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
# distributions, temperature=1, deterministic=True
# )
# TODO: eval for breakout
action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
distributions, temperature=0.25, deterministic=False
distributions, temperature=1, deterministic=True
)
# TODO: eval for breakout
# action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
# distributions, temperature=0.25, deterministic=False
# )
# NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
# entire action set.
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
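To illustrate the legal-action-index conversion in the line above, here is a small made-up example; the mask values are invented for illustration only.

import numpy as np

# Suppose only actions 0, 2 and 5 are legal in an 18-dimensional action space.
action_mask_i = np.zeros(18)
action_mask_i[[0, 2, 5]] = 1.0
action_index_in_legal_action_set = 1  # the second legal action
action = np.where(action_mask_i == 1.0)[0][action_index_in_legal_action_set]
print(action)  # 2: the corresponding index in the full action set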
20 changes: 16 additions & 4 deletions zoo/atari/config/atari_unizero_config_stack1_multitask.py
@@ -26,6 +26,8 @@

# share action space
action_space_size = 18
# action_space_size = 6

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
@@ -36,7 +38,9 @@
update_per_collect = 1000 # TODO
model_update_ratio = 0.25

max_env_step = int(3e6)
max_env_step = int(1e6) # TODO
# max_env_step = int(5e5)

reanalyze_ratio = 0.
batch_size = 64
num_simulations = 50
@@ -45,9 +49,12 @@
eps_greedy_exploration_in_collect = True


exp_name_prefix = 'data_unizero_mt_stack1_pong-mspacman/'
exp_name_prefix = f'data_unizero_mt_stack1_pong-mspacman_action{action_space_size}_taskembedding/'
# exp_name_prefix = 'data_unizero_mt_stack1_pong-action{action_space_size}_notaskembedding/'
# exp_name_prefix = f'data_unizero_mt_stack1_pong-action{action_space_size}_taskembedding/'

# debug

# only for debug =========
# batch_size = 2
# update_per_collect = 1 # debug
# num_simulations = 1 # debug
@@ -59,7 +66,8 @@

atari_muzero_config = dict(
# mcts_ctree, muzero_collector: empty_cache
exp_name=exp_name_prefix+f'{env_id[:-14]}/{env_id[:-14]}_unizero_upc{update_per_collect}-mur{model_update_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_conlen{8}_lsd768-nlayer4-nh8_bacth-kvmaxsize_taskembdding_sharehead_seed0',
# exp_name=exp_name_prefix+f'{env_id[:-14]}/{env_id[:-14]}_unizero_upc{update_per_collect}-mur{model_update_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_conlen{8}_lsd768-nlayer4-nh8_bacth-kvmaxsize_notaskembdding_sharehead_seed0',
exp_name=exp_name_prefix+f'{env_id[:-14]}/{env_id[:-14]}_unizero_upc{update_per_collect}-mur{model_update_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_conlen{8}_lsd768-nlayer4-nh8_bacth-kvmaxsize_taskembdding_nosharehead_seed0',
# exp_name=exp_name_prefix+f'{env_id[:-14]}/{env_id[:-14]}_unizero_upc{update_per_collect}-mur{model_update_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_conlen{8}_lsd768-nlayer4-nh8_bacth-kvmaxsize_taskembdding_sharehead_seed0',
env=dict(
stop_value=int(1e6),
@@ -186,6 +194,10 @@
main_config_2.policy.task_id = 1
main_config_3.policy.task_id = 2

# Pong
# train_unizero_multi_task_v2([[0, [main_config, create_config]]], seed=0, max_env_step=max_env_step)

# Pong Mspacman
train_unizero_multi_task_v2([[0, [main_config, create_config]], [1, [main_config_2, create_config_2]]], seed=0, max_env_step=max_env_step)

# train_unizero_multi_task([[0, [main_config, create_config]], [1, [main_config_2, create_config_2]], [2, [main_config_3, create_config_3]]], seed=0, max_env_step=max_env_step)
