fix(pu): fix the context of max_kv_size in batch
jiayilee65 committed Apr 16, 2024
1 parent 2d362ff commit 6f99cef
Showing 7 changed files with 137 additions and 90 deletions.
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
@@ -61,7 +61,7 @@

# 'action_shape': 18, # TODO:for multi-task

"device": 'cuda:1',
"device": 'cuda:0',
'action_shape': 6, # TODO:for pong qbert
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
27 changes: 14 additions & 13 deletions lzero/model/gpt_models/cfg_memory.py
@@ -22,22 +22,27 @@
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
# "context_length": 2 * 16,
# "context_length_for_recurrent": 2 * 16,
# # "context_length": 6,
# # "context_length_for_recurrent": 6,
# "recurrent_keep_deepth": 100,

'max_blocks': 17, # TODO
"max_tokens": 2 * 17, # 1+0+15 memory_length = 0
"context_length": 2 * 17,
"context_length_for_recurrent": 2 * 17,
"recurrent_keep_deepth": 100,
# 'max_blocks': 17, # TODO
# "max_tokens": 2 * 17, # 1+0+15 memory_length = 0
# "context_length": 2 * 17,
# "context_length_for_recurrent": 2 * 17,
# "recurrent_keep_deepth": 100,

# 'max_blocks': 30,
# "max_tokens": 2 * 30, # 15+0+15 memory_length = 0

# 'max_blocks': 32,
# "max_tokens": 2 * 32, # 15+2+15 memory_length = 2

# 'max_blocks': 60,
# "max_tokens": 2 * 60, # 15+30+15 memory_length = 30
'max_blocks': 76,
"max_tokens": 2 * 76, # 1+60+15=76 memory_length = 60
"context_length": 2 * 76,
"context_length_for_recurrent": 2 * 76,
"recurrent_keep_deepth": 100,

# 'max_blocks': 80, # memory_length = 50
# "max_tokens": 2 * 80,
@@ -57,22 +62,18 @@
# 'max_blocks': 530, # memory_length = 500
# "max_tokens": 2 * 530,

"device": 'cuda:3',

# 'embed_dim': 64, # TODO:for memory # same as <Transformer shine in RL> paper
# 'embed_dim': 96, # TODO:for memory # same as <Transformer shine in RL> paper
'group_size': 8, # NOTE
# 'group_size': 768, # NOTE


"device": 'cuda:0',
'attention': 'causal',
# 'num_layers': 1,
# 'num_layers': 2, # same as <Transformer shine in RL> paper
# 'num_layers': 4,

'num_layers': 8,
'num_heads': 8,
# # 'embed_dim': 96, # TODO:
# 'embed_dim': 96, # TODO:for memory # same as <Transformer shine in RL> paper
# 'embed_dim': 768, # TODO:Gpt2 Base
'embed_dim': 256, # TODO:

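Note on the values enabled in this file: the inline comments suggest that each block contributes two tokens (one observation token and one action token) and that max_blocks is split as 1 + memory_length + 15 (e.g. 1 + 60 + 15 = 76). The following is a minimal sketch of that relationship under those assumptions; the helper name memory_cfg is hypothetical and not part of the repository.

def memory_cfg(memory_length: int) -> dict:
    """Sketch: derive the related cfg_memory.py fields from memory_length."""
    max_blocks = 1 + memory_length + 15          # as in the inline comment: 1 + 60 + 15 = 76
    return {
        'max_blocks': max_blocks,
        'max_tokens': 2 * max_blocks,            # two tokens (obs, action) per block
        'context_length': 2 * max_blocks,
        'context_length_for_recurrent': 2 * max_blocks,
    }

print(memory_cfg(60))  # {'max_blocks': 76, 'max_tokens': 152, 'context_length': 152, ...}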
12 changes: 7 additions & 5 deletions lzero/model/gpt_models/transformer.py
@@ -144,7 +144,7 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_
# else:
# # mask.shape: (B, nh, T, L + T)
# mask = self.mask[L:L + T, :L + T]
# # att.shape: (T, L + T)
# att.shape: (B, nh, T, L + T)
# att = att.masked_fill(mask == 0, float('-inf'))

if valid_context_lengths is not None:
Expand All @@ -155,14 +155,16 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_
# For each sample, set the invalid positions to 0 according to its valid length
for i in range(B):
mask[i] = self.mask[L:L + T, :L + T].clone()  # is .clone() unnecessary here?
# if L - valid_context_lengths[i]>0:
mask[i, :, :(L - valid_context_lengths[i])] = 0
# Adjust the mask dimensions to match the last two dimensions of att
# (B, T, L + T) -> (B, nh, T, L + T)
mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
else:
# mask.shape: (B, nh, T, L + T)
# mask.shape: (T, L + T)
mask = self.mask[L:L + T, :L + T]
# att.shape: (T, L + T)
att = att.masked_fill(mask == 0, float('-inf'))

# att.shape: (B, nh, T, L + T)
att = att.masked_fill(mask == 0, float('-inf'))

att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
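To make the masking logic in this hunk easier to follow in isolation, here is a self-contained sketch with made-up shapes; B, nh, T, L, the score tensor, and the valid lengths are placeholders rather than the repository's actual values. It builds a per-sample causal mask over L cached plus T new tokens, zeroes out cached positions beyond each sample's valid context length, broadcasts the mask over heads, and applies it before the softmax.

import torch
import torch.nn.functional as F

B, nh, T, L, max_tokens = 2, 2, 3, 5, 16                 # batch, heads, new tokens, cached tokens
causal = torch.tril(torch.ones(max_tokens, max_tokens))  # full causal mask, standing in for self.mask
att = torch.randn(B, nh, T, L + T)                       # raw attention scores
valid_context_lengths = torch.tensor([5, 2])             # valid cached tokens per sample

mask = torch.zeros(B, T, L + T)
for i in range(B):
    mask[i] = causal[L:L + T, :L + T]
    # invalidate cached positions beyond this sample's valid context
    mask[i, :, :(L - valid_context_lengths[i])] = 0
# (B, T, L + T) -> (B, nh, T, L + T), matching att's last two dimensions
mask = mask.unsqueeze(1).expand(-1, nh, -1, -1)

att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
print(att.shape)  # torch.Size([2, 2, 3, 8])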
