fix(pu): fix the context of max_kv_size in batch
jiayilee65 committed Apr 17, 2024
1 parent 6f99cef commit 51af708
Showing 7 changed files with 62 additions and 104 deletions.
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
@@ -61,7 +61,7 @@

# 'action_shape': 18,  # TODO: for multi-task

"device": 'cuda:0',
"device": 'cuda:4',
'action_shape': 6,  # TODO: for Pong, Qbert
# 'action_shape': 9,  # TODO: for MsPacman
# 'action_shape': 18,  # TODO: for Seaquest, Boxing, Frostbite
43 changes: 22 additions & 21 deletions lzero/model/gpt_models/cfg_memory.py
@@ -13,17 +13,17 @@
cfg['world_model'] = {
'tokens_per_block': 2,

# 'max_blocks': 16,
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
# "context_length": 2 * 16,
# "context_length_for_recurrent": 2 * 16,

# 'max_blocks': 16,
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
# "context_length": 2 * 16,
# "context_length_for_recurrent": 2 * 16,
# # "context_length": 6,
# # "context_length_for_recurrent": 6,

'max_blocks': 16,
"max_tokens": 2 * 16, # 1+0+15 memory_length = 0
"context_length": 2 * 16,
"context_length_for_recurrent": 2 * 16,
"recurrent_keep_deepth": 100,

# 'max_blocks': 21,
# "max_tokens": 2 * 21, # 1+0+15 memory_length = 0
# "context_length": 2 * 21,
# "context_length_for_recurrent": 2 * 21,
# "recurrent_keep_deepth": 100,

# 'max_blocks': 17, # TODO
@@ -32,17 +32,18 @@
# "context_length_for_recurrent": 2 * 17,
# "recurrent_keep_deepth": 100,

# 'max_blocks': 30,
# "max_tokens": 2 * 30, # 15+0+15 memory_length = 0

# 'max_blocks': 32,
# "max_tokens": 2 * 32, # 15+2+15 memory_length = 2
# 'max_blocks': 46,
# "max_tokens": 2 * 46, # 1+30+15=76 memory_length = 30
# "context_length": 2 * 46,
# "context_length_for_recurrent": 2 * 46,
# "recurrent_keep_deepth": 100,

'max_blocks': 76,
"max_tokens": 2 * 76, # 1+60+15=76 memory_length = 60
"context_length": 2 * 76,
"context_length_for_recurrent": 2 * 76,
"recurrent_keep_deepth": 100,
# 'max_blocks': 76,
# "max_tokens": 2 * 76, # 1+60+15=76 memory_length = 60
# "context_length": 2 * 76,
# "context_length_for_recurrent": 2 * 76,
# "recurrent_keep_deepth": 100,

# 'max_blocks': 80, # memory_length = 50
# "max_tokens": 2 * 80,
@@ -62,7 +63,7 @@
# 'max_blocks': 530, # memory_length = 500
# "max_tokens": 2 * 530,

"device": 'cuda:3',
"device": 'cuda:7',

'group_size': 8, # NOTE
# 'group_size': 768, # NOTE
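
The arithmetic behind the block-count comments above (e.g. "1+0+15 memory_length = 0") is: max_blocks = explore steps + memory_length distractor steps + reward steps, with tokens_per_block = 2 (one observation token plus one action token per block). A minimal sketch of that budget, under those assumptions — the helper name is illustrative, not part of the repository:

# Hypothetical helper illustrating the token budget in cfg_memory's comments.
def token_budget(memory_length, explore=1, reward=15, tokens_per_block=2):
    max_blocks = explore + memory_length + reward  # e.g. 1 + 0 + 15 = 16
    max_tokens = tokens_per_block * max_blocks     # one obs token + one action token per block
    return max_blocks, max_tokens

print(token_budget(0))   # (16, 32)  -> 'max_blocks': 16, "max_tokens": 2 * 16
print(token_budget(60))  # (76, 152) -> 'max_blocks': 76, "max_tokens": 2 * 76
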
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/transformer.py
@@ -155,7 +155,7 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_
# For each sample, zero out the invalid part according to its valid length
for i in range(B):
mask[i] = self.mask[L:L + T, :L + T].clone() # is .clone() needed here?
mask[i, :, :(L - valid_context_lengths[i])] = 0
mask[i, :, :(L - valid_context_lengths[i])] = 0 # zero out the invalid part
# Reshape the mask to match the last two dimensions of att
# (B, T, L + T) -> (B, nh, T, L + T)
mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
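
As a standalone illustration of the masking scheme in this hunk — left-padded caches make the first L - valid key columns invalid, so they are zeroed per sample before broadcasting over heads — here is a minimal sketch; B, T, L, nh and the causal base mask are made-up values for the example, not the repository's shapes:

import torch

B, T, L, nh = 2, 3, 5, 4  # batch, new tokens, cached tokens, attention heads
base_mask = torch.tril(torch.ones(L + T, L + T))  # causal mask, standing in for self.mask
valid_context_lengths = [5, 2]                    # per-sample valid cache lengths

mask = torch.zeros(B, T, L + T)
for i in range(B):
    mask[i] = base_mask[L:L + T, :L + T].clone()
    # caches are left-padded, so the first (L - valid) key columns are invalid
    mask[i, :, :(L - valid_context_lengths[i])] = 0

# (B, T, L + T) -> (B, nh, T, L + T), matching att's trailing dimensions
mask = mask.unsqueeze(1).expand(-1, nh, -1, -1)
print(mask.shape)  # torch.Size([2, 4, 3, 8])
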
77 changes: 14 additions & 63 deletions lzero/model/gpt_models/world_model.py
@@ -575,7 +575,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt)

# Input: self.keys_values_wm_list; output: self.keys_values_wm
self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False)
self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) # consistent with what self.retrieve_or_generate_kvcache returns above
self.keys_values_wm_size_list_current = self.keys_values_wm_size_list
for k in range(2): # assume there is exactly one action token each time.
# action_token obs_token, ..., obs_token 1+1
@@ -585,6 +585,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
obs_embeddings_or_act_tokens = {'obs_embeddings': token}

outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, kvcache_independent=False, is_init_infer=False)
print('keys_values_wm_size_list_current:', self.keys_values_wm_size_list_current)
self.keys_values_wm_size_list_current = [i+1 for i in self.keys_values_wm_size_list_current] # NOTE: +1 ===============
if k == 0:
# If k == 0, the token is an action_token, and outputs_wm.logits_rewards holds a value
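
Each iteration of this k-loop appends exactly one token (the action token when k == 0, then one observation token) to every environment's cache, which is why the size list advances in lockstep after every forward pass. A toy sketch of that bookkeeping, with illustrative values:

# Toy bookkeeping for per-environment KV-cache lengths during recurrent inference.
keys_values_wm_size_list_current = [7, 5, 6]  # valid cache length per environment

for k in range(2):  # one action token, then one observation token
    # ... the forward pass appends one token to every environment's cache ...
    keys_values_wm_size_list_current = [s + 1 for s in keys_values_wm_size_list_current]

print(keys_values_wm_size_list_current)  # [9, 7, 8]
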
@@ -636,10 +637,10 @@ def trim_and_pad_kv_cache(self, is_init_infer=True):

# If padding is needed, prepend 'pad_size' zeros to the cache ====================
if pad_size > 0:
# NOTE: first remove the trailing pad_size invalid kv entries; mind the correctness of the positional encoding
# NOTE: first remove the trailing pad_size invalid zero kv entries, then prepend 'pad_size' zeros to the cache; mind the correctness of the positional encoding
k_cache_trimmed = k_cache[:, :, :-pad_size, :]
v_cache_trimmed = v_cache[:, :, :-pad_size, :]
k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0)
k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) # prepend 'pad_size' zeros to the cache
v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0)
else:
k_cache_padded = k_cache
@@ -659,58 +660,6 @@ def trim_and_pad_kv_cache(self, is_init_infer=True):

return self.keys_values_wm_size_list
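
To see what the corrected branch does, here is a self-contained sketch of the trim-then-prepend step on a dummy 4-D cache (batch, heads, sequence, head_dim); the shapes and values are illustrative only:

import torch
import torch.nn.functional as F

k_cache = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6, 1)  # seq = [0..5]
pad_size = 2  # the trailing 2 slots are invalid

k_cache_trimmed = k_cache[:, :, :-pad_size, :]  # drop trailing invalid kv: [0, 1, 2, 3]
# (0, 0, pad_size, 0) pads dim -1 by (0, 0) and dim -2 by (pad_size, 0),
# i.e. it prepends pad_size zero slots along the sequence dimension.
k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0)

print(k_cache_padded[0, 0, :, 0])  # tensor([0., 0., 0., 1., 2., 3.])
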

# def trim_and_pad_kv_cache(self, is_init_infer=True):
# """
# This method trims and pads the key and value caches of the attention mechanism
# to a consistent size across all items in the batch, determined by the smallest cache size.
# """
# if is_init_infer:
# print('='*20)
# print(f'is_init_infer: {is_init_infer}')
# print(f'self.keys_values_wm_size_list: {self.keys_values_wm_size_list}')

# # Find the smallest size among all key-value sizes, used for padding/trimming
# min_size = min(self.keys_values_wm_size_list)

# # Iterate over each layer of the transformer
# for layer in range(self.num_layers):
# # Initialize lists to store the trimmed and padded k and v caches
# kv_cache_k_list = []
# kv_cache_v_list = []

# # Enumerate the list of key-value pairs
# for idx, keys_values in enumerate(self.keys_values_wm_list):
# # Retrieve the key and value caches of the current layer
# k_cache = keys_values[layer]._k_cache._cache
# v_cache = keys_values[layer]._v_cache._cache

# # Get the effective size of the current cache
# effective_size = self.keys_values_wm_size_list[idx]
# # Compute the size difference that needs trimming
# trim_size = effective_size - min_size if effective_size > min_size else 0

# # If trimming is needed, remove 'trim_size' entries from the start of the cache
# if trim_size > 0:
# k_cache_trimmed = k_cache[:, :, trim_size:, :]
# v_cache_trimmed = v_cache[:, :, trim_size:, :]
# # Zero-pad the trimmed caches along the third dimension
# k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
# v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
# else:
# k_cache_padded = k_cache
# v_cache_padded = v_cache

# # Append the processed caches to the lists
# kv_cache_k_list.append(k_cache_padded)
# kv_cache_v_list.append(v_cache_padded)

# # Stack the caches along a new dimension and remove the extra dimension with squeeze()
# self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1)
# self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1)

# # After trimming and padding, update the cache size to the minimum size
# self.keys_values_wm._keys_values[layer]._k_cache._size = min_size
# self.keys_values_wm._keys_values[layer]._v_cache._size = min_size

def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None):
if self.context_length <= 2:
@@ -738,14 +687,14 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde

trim_size = current_max_context_length-self.keys_values_wm_size_list_current[i]
# Trim according to the valid length TODO=======================================
# NOTE: first remove the leading pad_size/trim_size invalid kv entries; mind the correctness of the positional encoding
k_cache_trimmed = k_cache_current[:, trim_size:, :]
v_cache_trimmed = v_cache_current[:, trim_size:, :]

# If the valid length < current_max_context_length, append 'trim_size' zeros at the end of the cache ====================
if trim_size > 0:
# NOTE: first remove the trailing pad_size invalid kv entries; mind the correctness of the positional encoding
k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) # append 'trim_size' zeros at the end of the cache
v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0)
else:
k_cache_padded = k_cache_trimmed
v_cache_padded = v_cache_trimmed
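
The fix in this hunk is the pad direction: the leading trim_size slots are dropped, and the zeros now go at the end of the sequence dimension instead of the front. A minimal contrast of the two F.pad calls on a 3-D per-environment cache (heads, sequence, head_dim), with illustrative values:

import torch
import torch.nn.functional as F

k = torch.arange(5, dtype=torch.float32).reshape(1, 5, 1)  # seq = [0..4]
trim_size = 2
k_trimmed = k[:, trim_size:, :]  # drop the 2 oldest slots: [2, 3, 4]

before_fix = F.pad(k_trimmed, (0, 0, trim_size, 0), "constant", 0)  # zeros in front
after_fix = F.pad(k_trimmed, (0, 0, 0, trim_size), "constant", 0)   # zeros at the end

print(before_fix[0, :, 0])  # tensor([0., 0., 2., 3., 4.])
print(after_fix[0, :, 0])   # tensor([2., 3., 4., 0., 0.])
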
@@ -758,7 +707,9 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm_size_list_current[i]

# NOTE: this check is very important ============
if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length-1: # always keep only the context of the most recent 5 timesteps
if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length-1:
# always keep only the context of the most recent self.context_length-3 timesteps
# =============== For the memory env: training uses H steps, but recurrent_inference may exceed H steps =================
# print(f'self.keys_values_wm_size_list_current[i]:{self.keys_values_wm_size_list_current[i]}')
# We need to operate on self.keys_values_wm_single_env here, not self.keys_values_wm
# Trim-and-pad logic
@@ -777,8 +728,8 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Along dim 3, zero-pad the final 2 steps
padding_size = (0, 0, 0, 3) # F.pad's argument (0, 0, 0, 2) specifies the amount of padding for each dimension, given in (left, right, top, bottom) order; for a 3-D tensor this corresponds to padding on (the left of dim 2, the right of dim 2, the left of dim 1, the right of dim 1).
# Along dim 3, zero-pad the final 3 steps
padding_size = (0, 0, 0, 3) # F.pad's argument (0, 0, 0, 3) specifies the amount of padding for each dimension, given in (left, right, top, bottom) order; for a 3-D tensor this corresponds to padding on (the left of dim 2, the right of dim 2, the left of dim 1, the right of dim 1).
k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0)
v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0)
# Update the single-environment cache
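
The pos_emb_diff_k / pos_emb_diff_v correction above exists because absolute position information is baked into the cached keys and values; when the oldest slots are dropped, the surviving entries must be shifted to their new positions. A heavily simplified sketch of that idea, assuming additive learned position embeddings — an analogy for intuition, not the repository's exact computation:

import torch

seq, dim, shift = 6, 4, 2
pos_emb = torch.randn(seq, dim)   # absolute position embeddings (assumed additive)
content = torch.randn(seq, dim)   # position-free key content
k_cache = content + pos_emb       # keys with positions baked in

# Precomputed difference table: embedding at the new position minus the old one.
pos_emb_diff_k = pos_emb[:seq - shift] - pos_emb[shift:]

# Drop the oldest `shift` slots and re-align the survivors to positions 0..seq-shift-1.
k_cache_trimmed = k_cache[shift:] + pos_emb_diff_k

print(torch.allclose(k_cache_trimmed, content[shift:] + pos_emb[:seq - shift]))  # True
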
@@ -819,8 +770,8 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Along dim 3, zero-pad the final 2 steps
padding_size = (0, 0, 0, 3) # F.pad's argument (0, 0, 0, 2) specifies the amount of padding for each dimension, given in (left, right, top, bottom) order; for a 3-D tensor this corresponds to padding on (the left of dim 2, the right of dim 2, the left of dim 1, the right of dim 1).
# Along dim 3, zero-pad the final 3 steps
padding_size = (0, 0, 0, 3) # F.pad's argument (0, 0, 0, 3) specifies the amount of padding for each dimension, given in (left, right, top, bottom) order; for a 3-D tensor this corresponds to padding on (the left of dim 2, the right of dim 2, the left of dim 1, the right of dim 1).
k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0)
v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0)
# Update the single-environment cache
4 changes: 2 additions & 2 deletions zoo/atari/config/atari_xzero_config_stack1.py
@@ -1,7 +1,7 @@
from easydict import EasyDict

import torch
torch.cuda.set_device(0)
torch.cuda.set_device(4)
# ==== NOTE: action_shape in cfg_atari needs to be set accordingly =====

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
@@ -75,7 +75,7 @@
# TODO:
# mcts_ctree
# muzero_collector/evaluator: empty_cache
exp_name=f'data_xzero_atari_0416/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_grugating-false_latent-groupkl_conleninit{8}-conlenrecur{8}clear_lsd768-nlayer1-nh8_bacth-kvmaxsize-fix_seed0',
exp_name=f'data_xzero_atari_0416/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_grugating-false_latent-groupkl_conleninit{8}-conlenrecur{8}clear_lsd768-nlayer1-nh8_bacth-kvmaxsize-fix0417_seed0',
# exp_name=f'data_xzero_atari_0407/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_grugating-false_latent-groupkl_conleninit{20}-conlenrecur{20}clear-gamma1_lsd1536-nlayer12-nh12_steplosslog_seed0',

env=dict(
24 changes: 15 additions & 9 deletions zoo/memory/config/memory_xzero_config.py
@@ -5,15 +5,15 @@
env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door'
# env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'

memory_length = 60
# memory_length = 0
# memory_length = 60
memory_length = 0

# visual_match [2, 60, 100, 250, 500]
# key_to_door [2, 60, 120, 250, 500]

max_env_step = int(3e6)
# max_env_step = int(3e6)
# max_env_step = int(1e6)
# max_env_step = int(5e5)
max_env_step = int(5e5)


# ==== NOTE: action_shape in cfg_memory needs to be set accordingly =====
@@ -42,11 +42,14 @@
num_unroll_steps = 16 + memory_length
game_segment_length = 16 + memory_length # TODO: for "explore": 1

# num_unroll_steps = 21 + memory_length
# game_segment_length = 21 + memory_length # TODO: for "explore": 1

# num_unroll_steps = 17 + memory_length
# game_segment_length = 17 + memory_length # TODO: for "explore": 2

# num_unroll_steps = 9 + memory_length
# game_segment_length = 9 + memory_length # TODO: for "explore": 1
# num_unroll_steps = 30 + memory_length
# game_segment_length = 30 + memory_length # TODO: for "explore": 1


reanalyze_ratio = 0
@@ -59,14 +62,16 @@
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
torch.cuda.set_device(3)
torch.cuda.set_device(7)
memory_xzero_config = dict(
# (4,5,5) config, world_model, muzero_gpt_model, memory env
# mcts_ctree.py muzero_collector muzero_evaluator
exp_name=f'data_memory_{env_id}_0416/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_bs{batch_size}'
f'_seed{seed}_evalenv{evaluator_env_num}_collectenv{collector_env_num}_bacth-kvmaxsize-fix0417',
# exp_name=f'data_memory_{env_id}_0416/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_bs{batch_size}'
# f'_seed{seed}_eval{evaluator_env_num}_nl8-nh8-emd256_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer3_obschannel3_valuesize101_collectenv{collector_env_num}_bacth-kvmaxsize-fix0417',
# exp_name=f'data_memory_{env_id}_0413_debug/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_bs{batch_size}'
# f'_seed{seed}_eval{evaluator_env_num}_nl12-nh12-emd1536_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer4_obschannel4_reclw0',
exp_name=f'data_memory_{env_id}_0416/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_bs{batch_size}'
f'_seed{seed}_eval{evaluator_env_num}_nl8-nh8-emd256_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer3_obschannel3_valuesize101_collectenv{collector_env_num}_bacth-kvmaxsize-fixmask',
env=dict(
stop_value=int(1e6),
env_id=env_id,
@@ -78,6 +83,7 @@
"explore": 1, # for visual_match
"distractor": memory_length,
"reward": 15
# "reward": 20
# "reward": 8 # debug
}, # Maximum frames per phase
collector_env_num=collector_env_num,
14 changes: 7 additions & 7 deletions zoo/memory/config/memory_xzero_config_debug.py
@@ -39,11 +39,11 @@
# game_segment_length=30+memory_length # TODO: for "explore": 15

# for visual_match
# num_unroll_steps = 16 + memory_length
# game_segment_length = 16 + memory_length # TODO: for "explore": 1
num_unroll_steps = 16 + memory_length
game_segment_length = 16 + memory_length # TODO: for "explore": 1

num_unroll_steps = 17 + memory_length
game_segment_length = 17 + memory_length # TODO: for "explore": 2
# num_unroll_steps = 17 + memory_length
# game_segment_length = 17 + memory_length # TODO: for "explore": 2

# num_unroll_steps = 9 + memory_length
# game_segment_length = 9 + memory_length # TODO: for "explore": 1
@@ -75,7 +75,7 @@
flatten_observation=False, # Whether to flatten the observation
max_frames={
# "explore": 15, # for key_to_door
"explore": 2, # for visual_match
"explore": 1, # for visual_match
"distractor": memory_length,
"reward": 15
# "reward": 8 # debug
@@ -98,8 +98,8 @@
),
# sample_type='transition',
sample_type='episode', # NOTE: very important for memory env
# model_path=None,
model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match_0415/visual_match_memlen-0_xzero_H17_bs64_seed0_eval8_nl8-nh8-emd256_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer3_obschannel3_valuesize101_240415_165207/ckpt/ckpt_best.pth.tar',
model_path=None,
# model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match_0415/visual_match_memlen-0_xzero_H17_bs64_seed0_eval8_nl8-nh8-emd256_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer3_obschannel3_valuesize101_240415_165207/ckpt/ckpt_best.pth.tar',
# model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match_0413/visual_match_memlen-0_xzero_H16_bs64_seed0_eval8_nl8-nh8-emd768_phase3-fixed-colormap-bce_phase1-fixed-target-pos_random-target-color_reclw005_encoder-layer4_obschannel4_240414_172713/ckpt/ckpt_best.pth.tar',
transformer_start_after_envsteps=int(0),
update_per_collect_transformer=update_per_collect,
