- https://verl.readthedocs.io/en/latest/perf/perf_tuning.html
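    - `use_remove_padding=True` enables sequence packing (i.e., removing padding and packing the remaining tokens).
        - rmpad: short for "remove padding", the suffix verl uses in names such as `input_ids_rmpad`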
    - https://github.com/volcengine/verl/blob/main/tests/model/test_transformer.py

```python
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from einops import rearrange
```
- `def unpad_input(hidden_states, attention_mask, unused_mask=None):`
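    - `input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)`
        - `(4, 128) => (4, 128, 1)`, `attention_mask.sum() == 301`
        - `input_ids_rmpad.shape == (301, 1)`; `input_ids_rmpad.transpose(0, 1)` then gives the `(1, 301)` packed row fed to the model
        - `indices.shape == (301,)`
            - `indices` records where each valid token sits in the original `(batch, seqlen)` tensor, as a flat index into its `batch * seqlen` view.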
    - `origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)`
        - origin_logits.shape == (4, 128, 32000)
        - origin_logits_rmpad.shape == (301, 32000)
- `index_first_axis(x, indices)`
    - Gathers the rows of `x` selected by `indices` along its first dimension (axis 0); the result is the same as `x[indices]`.
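For intuition, here is a minimal pure-PyTorch sketch of what `unpad_input`, `index_first_axis`, and `pad_input` compute. The `_ref` functions are hypothetical stand-ins written for illustration, not the flash-attn source; the real functions are autograd-aware, and their return tuple differs between flash-attn versions (which is why callers unpack it with a trailing `*_`).

```python
import torch
import torch.nn.functional as F


def index_first_axis_ref(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Advanced indexing along dim 0: rows x[indices[0]], x[indices[1]], ...
    return x[indices]


def unpad_input_ref(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # hidden_states: (batch, seqlen, ...); attention_mask: (batch, seqlen), 1 = real token
    batch, seqlen = attention_mask.shape
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # real tokens per row
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, l0, l0+l1, ...]
    flat = hidden_states.reshape(batch * seqlen, *hidden_states.shape[2:])
    return index_first_axis_ref(flat, indices), indices, cu_seqlens, int(seqlens.max())


def pad_input_ref(unpadded: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    # Scatter rows back into a zero-filled (batch * seqlen, ...) buffer, then restore (batch, seqlen, ...)
    out = torch.zeros(batch * seqlen, *unpadded.shape[1:], dtype=unpadded.dtype, device=unpadded.device)
    out[indices] = unpadded
    return out.reshape(batch, seqlen, *unpadded.shape[1:])


mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
x = torch.arange(1, 9).reshape(2, 4, 1)          # fake token ids 1..8, shape (2, 4, 1)
x_unpad, idx, cu, max_len = unpad_input_ref(x, mask)
print(x_unpad.squeeze(-1).tolist())              # [1, 2, 5, 6, 7]
print(idx.tolist(), cu.tolist(), max_len)        # [0, 1, 4, 5, 6] [0, 2, 5] 3
print(pad_input_ref(x_unpad, idx, 2, 4).squeeze(-1).tolist())  # [[1, 2, 0, 0], [5, 6, 7, 0]]
```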

https://github.com/Dao-AILab/flash-attention/issues/11#issuecomment-1156681278

> The most performant approach is to do the unpadding (i.e. remove padding tokens) before the first encoder block and add back the padding after the last encoder block, so that unpadding / padding is only called once (instead of 24 times if you have 24 layers). This has the added benefit of speeding up all other layers (LayerNorm, FFN, etc.) since they don't need to operate on padding tokens.
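The point about the other layers can be checked with a toy block: because LayerNorm and the MLP act on each token independently, running them on the unpadded tokens gives the same values as running them on the padded batch and discarding the pad slots afterwards. The module below is a hypothetical stand-in, not a real decoder layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 16

# Hypothetical stand-in for the non-attention parts of a block (norm + MLP).
# Every op acts on each token vector independently; nothing mixes along the sequence axis.
per_token = nn.Sequential(
    nn.LayerNorm(hidden),
    nn.Linear(hidden, 4 * hidden),
    nn.GELU(),
    nn.Linear(4 * hidden, hidden),
)

attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
x = torch.randn(2, 4, hidden)              # padded activations (batch, seqlen, hidden)
keep = attention_mask.flatten().bool()     # which of the 8 slots hold real tokens

with torch.no_grad():
    # Padded path: compute on all 8 slots, then drop the 2 pad slots.
    padded_out = per_token(x).reshape(-1, hidden)[keep]
    # Unpadded path: drop pad slots first, compute on the 6 real tokens only.
    packed_out = per_token(x.reshape(-1, hidden)[keep])

print(packed_out.shape)  # torch.Size([6, 16])
print(torch.allclose(padded_out, packed_out, atol=1e-5))
```

The unpadded path does 6 token computations instead of 8; with real batches the saving scales with the fraction of padding.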

- `AutoModelForCausalLM.from_config(xx, attn_implementation='flash_attention_2')`
    - `logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, ...).logits`

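Calling the model with `input_ids_rmpad` and `position_ids_rmpad` still gives correct results, for these reasons:
- Variable-length support in Flash Attention 2 (`flash_attn_varlen_func`):
    - The model is loaded with `attn_implementation='flash_attention_2'`.
    - Flash Attention 2 natively handles variable-length sequences packed into one row, so no square `(seqlen, seqlen)` padding mask is needed.
    - In recent versions of transformers, when no `attention_mask` is passed and `position_ids` restart from 0 inside the packed row, the FA2 integration derives the sequence boundaries (`cu_seqlens`) from those resets, so tokens of different sequences do not attend to each other.
- No valid information is lost:
    - `unpad_input` removes only the padding and keeps every valid token.
    - `indices` records where each token sat in the original batch, so outputs can be scattered back with `pad_input`.
    - After unpadding, the shape changes from `(batch_size, seqlen)` to `(1, total_nnz)`, where `total_nnz` ("number of nonzero") is the number of nonzero entries in `attention_mask`, i.e. the total count of real tokens.
- Position encodings stay aligned:
    - `position_ids_rmpad` holds the correct position ID of every valid token.
    - This keeps the rotary embeddings (RoPE) inside the model correct.
    - Each token therefore gets exactly the same position as in the padded run.
- Properties of the Transformer architecture:
    - Everything except attention (embeddings, normalization, MLP, LM head) processes each token independently.
    - Attention is the only layer that mixes tokens; given correct positions and sequence boundaries, it never needs to see meaningless padding.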
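A slow reference sketch of how sequence boundaries can be recovered from packed `position_ids`, and of what varlen causal attention computes with them. `cu_seqlens_from_position_ids` and `packed_causal_attention` are hypothetical helpers; they assume every packed sequence starts at position 0, and they materialize the block-diagonal causal mask that `flash_attn_varlen_func` avoids building. With a plain causal mask over all 12 tokens instead, the last 4 tokens would also attend to the first sequence.

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: (1, total_nnz) packed row; each 0 marks the start of a new sequence.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()            # start offset of each sequence
    end = torch.tensor([pos.numel()], device=pos.device)  # total number of tokens
    return torch.cat([starts, end]).to(torch.int32)       # [0, l0, l0 + l1, ..., total_nnz]


def scaled_scores(q, k):
    return (q @ k.T) / q.shape[-1] ** 0.5


def causal_attention(q, k, v):
    # Ordinary causal attention over one unpadded sequence: (n, d) -> (n, d)
    n = q.shape[0]
    allowed = torch.ones(n, n, dtype=torch.bool).tril()
    return scaled_scores(q, k).masked_fill(~allowed, float("-inf")).softmax(dim=-1) @ v


def packed_causal_attention(q, k, v, cu_seqlens):
    # Reference varlen attention over a packed row: token i may attend to token j
    # only if both belong to the same sequence and j <= i (block-diagonal causal mask).
    n = q.shape[0]
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    seq_id = torch.repeat_interleave(torch.arange(len(seqlens)), seqlens)  # sequence index per token
    same_seq = seq_id[:, None] == seq_id[None, :]
    allowed = same_seq & torch.ones(n, n, dtype=torch.bool).tril()
    return scaled_scores(q, k).masked_fill(~allowed, float("-inf")).softmax(dim=-1) @ v


position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]])  # layout of the Qwen example below
cu_seqlens = cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens.tolist())  # [0, 8, 12]

torch.manual_seed(0)
q, k, v = (torch.randn(12, 8) for _ in range(3))
packed = packed_causal_attention(q, k, v, cu_seqlens)
separate = torch.cat([causal_attention(q[:8], k[:8], v[:8]), causal_attention(q[8:], k[8:], v[8:])])
print(torch.allclose(packed, separate, atol=1e-5))
```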

## An example

- input_ids:
    - `[[sentence A token1, sentence A token2, PAD, PAD],`
    - ` [sentence B token1, sentence B token2, sentence B token3, PAD]]`
- attention_mask:
    - `[[1, 1, 0, 0],`
    - ` [1, 1, 1, 0]]` (1 = valid token, 0 = PAD)
- position_ids:
    - `[[0, 1, 0, 0],`
    - ` [0, 1, 2, 0]]` (simplified; real values may differ, but the values at PAD positions are not used)
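The same example in plain Python, with string placeholders (`"A1"`, `"B1"`, ...) standing in for real token IDs. It shows the flat `indices`, the packed row, position IDs that restart at 0 per sequence, and the cumulative sequence lengths `cu_seqlens` that varlen kernels take as boundaries; all variable names here are illustrative.

```python
# Placeholder tokens stand in for real token ids (illustrative values only)
input_ids = [["A1", "A2", "PAD", "PAD"],
             ["B1", "B2", "B3", "PAD"]]
attention_mask = [[1, 1, 0, 0],
                  [1, 1, 1, 0]]
batch_size, seqlen = len(attention_mask), len(attention_mask[0])

# indices: flat position (b * seqlen + s) of every real token, like unpad_input's `indices`
indices = [b * seqlen + s
           for b, row in enumerate(attention_mask)
           for s, m in enumerate(row) if m]

# Packed row: real tokens only, sequences back to back
flat_ids = [tok for row in input_ids for tok in row]
input_ids_rmpad = [flat_ids[i] for i in indices]

# Position ids restart at 0 for every sequence; cu_seqlens marks the boundaries
position_ids_rmpad, cu_seqlens = [], [0]
for row in attention_mask:
    n = sum(row)
    position_ids_rmpad += list(range(n))
    cu_seqlens.append(cu_seqlens[-1] + n)

print(indices)             # [0, 1, 4, 5, 6]
print(input_ids_rmpad)     # ['A1', 'A2', 'B1', 'B2', 'B3']
print(position_ids_rmpad)  # [0, 1, 0, 1, 2]
print(cu_seqlens)          # [0, 2, 5]

# Re-padding: scatter back by index (pad_input fills the gaps with zeros; PAD here)
flat = ["PAD"] * (batch_size * seqlen)
for i, tok in zip(indices, input_ids_rmpad):
    flat[i] = tok
repadded = [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch_size)]
print(repadded == input_ids)  # True
```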

## Qwen2.5-0.5B-Instruct example

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from einops import rearrange
```

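```python
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# flash_attention_2 needs a CUDA GPU and fp16/bf16 weights, so the CPU/float32 branch
# would fail here; the packed forward pass below also depends on flash_attention_2.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    attn_implementation="flash_attention_2",
    device_map=device,
)
```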

```python
tokenizer = AutoTokenizer.from_pretrained(model_name)
```

```python
print(tokenizer.pad_token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Use EOS if pad token is not set
```

```
<|endoftext|>
```

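```python
prompts = [
    "你好，请给我介绍一下大型语言模型。",  # "Hello, please give me an introduction to large language models."
    "今天天气怎么样？",  # "What's the weather like today?"
]
```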

### Original (padded) forward pass

```python
inputs_padded = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
input_ids = inputs_padded['input_ids']
attention_mask = inputs_padded['attention_mask']
batch_size, seqlen = input_ids.shape  # (2, 8)
```

```python
input_ids
```

```
tensor([[108386,  37945, 104169, 109432, 101951, 102064, 104949,   1773],
        [100644, 104307, 104472,  11319, 151643, 151643, 151643, 151643]],
       device='cuda:0')
```

```python
print(tokenizer.decode(input_ids[0]))
print(tokenizer.decode(input_ids[1]))
```

```
你好，请给我介绍一下大型语言模型。
今天天气怎么样？<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
```

```python
attention_mask
```

```
tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0]], device='cuda:0')
```

```python
# Position of each token within its own sequence: 0, 1, 2, ... over the real tokens
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # Use 1 for masked positions (consistent with tests)
```

```
tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 1, 1, 1, 1]], device='cuda:0')
```
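The `cumsum` formula above does not depend on which side is padded. This tokenizer pads on the right, but the same two lines also give correct positions for left-padded rows. `position_ids_from_mask` is a hypothetical wrapper written for this note.

```python
import torch


def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Running count of real tokens minus 1 -> 0-based position of each real token
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Pad slots get a dummy value of 1 (as in the cell above); unpadding drops them anyway
    return position_ids.masked_fill(attention_mask == 0, 1)


right = torch.tensor([[1, 1, 1, 1, 0, 0]])
left = torch.tensor([[0, 0, 1, 1, 1, 1]])
print(position_ids_from_mask(right).tolist())  # [[0, 1, 2, 3, 1, 1]]
print(position_ids_from_mask(left).tolist())   # [[1, 1, 0, 1, 2, 3]]
```

In both cases the real tokens get positions 0 to 3, which is all that matters once the pad slots are removed.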

```python
with torch.no_grad():
    outputs_standard = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
    )
origin_logits = outputs_standard.logits  # (batch_size, seqlen, vocab_size)
```

```python
origin_logits.shape, origin_logits
```

```
(torch.Size([2, 8, 151936]),
 tensor([[[ 7.0000,  6.5625,  1.6328,  ..., -2.9375, -2.9375, -2.9375],
          [ 8.5000,  5.6875,  4.3438,  ..., -2.9688, -2.9688, -2.9688],
          [ 3.8438,  6.8125,  2.4062,  ..., -4.2812, -4.2812, -4.2812],
          ...,
          [ 4.5625,  7.0312, -1.0703,  ..., -3.5938, -3.5938, -3.5938],
          [ 5.3125, 10.9375,  3.1094,  ..., -3.3281, -3.3281, -3.3281],
          [ 5.7188, 10.0000,  7.3750,  ..., -5.7812, -5.7812, -5.7812]],

         [[ 2.8438,  8.2500,  2.7812,  ..., -2.8281, -2.8281, -2.8281],
          [ 6.3750,  8.8125,  6.2188,  ..., -4.0312, -4.0312, -4.0312],
          [11.9375,  9.6875,  7.9375,  ..., -3.2344, -3.2344, -3.2344],
          ...,
          [ 0.4297, -3.3750,  5.2188,  ..., -0.2324, -0.2334, -0.2324],
          [ 0.4297, -3.3750,  5.2188,  ..., -0.2324, -0.2334, -0.2324],
          [ 0.4297, -3.3750,  5.2188,  ..., -0.2324, -0.2334, -0.2324]]],
        device='cuda:0', dtype=torch.bfloat16))
```

The last rows of the second sequence are PAD positions: the model still emits logits there, but they are identical and meaningless, so only valid positions are compared below.

### Unpadded (packed) forward pass

```python
# Drop pad tokens: (batch_size, seqlen, 1) -> (total_tokens, 1) -> (total_tokens,)
input_ids_unpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
input_ids_unpad = input_ids_unpad.squeeze(-1)  # Back to (total_tokens,)

# Gather the matching position ids with the same flat indices
position_ids_reshaped = rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ...")  # (b*s, 1)
position_ids_unpad = index_first_axis(position_ids_reshaped, indices)  # (total_tokens, 1)
position_ids_unpad = position_ids_unpad.squeeze(-1)  # (total_tokens,)
```
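`unpad_input` also returns `cu_seqlens` as its third value (discarded by `*_` above). As an alternative to gathering `position_ids` with `index_first_axis`, the packed position IDs can be rebuilt directly from `cu_seqlens`. `position_ids_from_cu_seqlens` is a hypothetical helper; it assumes each sequence's real tokens are contiguous and start at position 0.

```python
import torch


def position_ids_from_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # cu_seqlens = [0, l0, l0 + l1, ...] -> [0..l0-1, 0..l1-1, ...] for the packed row
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    total = int(cu_seqlens[-1])
    # Start offset of the sequence each token belongs to, repeated once per token
    starts = torch.repeat_interleave(cu_seqlens[:-1].long(), seqlens)
    return torch.arange(total, device=cu_seqlens.device) - starts


cu_seqlens = torch.tensor([0, 8, 12], dtype=torch.int32)  # the two prompts: 8 and 4 real tokens
print(position_ids_from_cu_seqlens(cu_seqlens).tolist())
# [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]
```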

```python
position_ids_reshaped
```

```
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [0],
        [1],
        [2],
        [3],
        [1],
        [1],
        [1],
        [1]], device='cuda:0')
```

```python
index_first_axis(position_ids_reshaped, indices)
```

```
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [0],
        [1],
        [2],
        [3]], device='cuda:0')
```

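```python
input_ids_unpad
```

```
tensor([108386,  37945, 104169, 109432, 101951, 102064, 104949,   1773, 100644,
        104307, 104472,  11319], device='cuda:0')
```

```python
indices
```

```
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], device='cuda:0')
```

Because the second row is right-padded, the valid flat positions are simply 0 through 11; flat positions 12 to 15 (the four PADs) are dropped.

```python
position_ids_unpad
```

```
tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3], device='cuda:0')
```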

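```python
with torch.no_grad():
    input_ids_unpad_batch = input_ids_unpad.unsqueeze(0)  # (1, total_tokens)
    position_ids_unpad_batch = position_ids_unpad.unsqueeze(0)  # (1, total_tokens)

    # No attention_mask: with flash_attention_2 (recent transformers), the reset of
    # position_ids to 0 marks where the second sequence starts inside the packed row
    outputs_unpad = model(
        input_ids=input_ids_unpad_batch,
        position_ids=position_ids_unpad_batch,  # Pass unpadded position_ids
        use_cache=False,
    )
```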
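The unpad, forward, re-pad steps of this section can be wrapped in one function. `rmpad_forward` is a hypothetical helper that uses plain tensor indexing instead of the flash-attn utilities; it assumes a model whose call signature and `.logits` output match the Hugging Face model used here, and for the packed call to be correct that model must use `flash_attention_2` as discussed earlier. `DummyLM` is a per-token stand-in used only to exercise the indexing.

```python
from types import SimpleNamespace

import torch


def rmpad_forward(model, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: run `model` on one packed row, return (batch, seqlen, vocab) logits.
    batch, seqlen = input_ids.shape
    indices = torch.nonzero(attention_mask.flatten()).flatten()  # same role as `indices` above
    position_ids = (attention_mask.long().cumsum(-1) - 1).flatten()[indices]
    packed_ids = input_ids.flatten()[indices].unsqueeze(0)       # (1, total_nnz)
    with torch.no_grad():
        logits = model(input_ids=packed_ids, position_ids=position_ids.unsqueeze(0), use_cache=False).logits
    logits = logits.squeeze(0)                                    # (total_nnz, vocab)
    out = logits.new_zeros(batch * seqlen, logits.shape[-1])      # PAD slots stay zero, like pad_input
    out[indices] = logits
    return out.reshape(batch, seqlen, -1)


class DummyLM(torch.nn.Module):
    # Per-token stand-in: logits = one_hot(token id) + position id (no attention at all)
    def __init__(self, vocab: int = 10):
        super().__init__()
        self.vocab = vocab

    def forward(self, input_ids, position_ids, use_cache=False):
        logits = torch.nn.functional.one_hot(input_ids, self.vocab).float() + position_ids[..., None]
        return SimpleNamespace(logits=logits)


ids = torch.tensor([[3, 4, 5], [6, 7, 0]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
out = rmpad_forward(DummyLM(), ids, mask)
print(out.shape)          # torch.Size([2, 3, 10])
print(out[1, 1, 7].item())  # 2.0 (one-hot 1 + position 1)
```

With the notebook's objects, the call would be `rmpad_forward(model, input_ids, attention_mask)`.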

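```python
logits_unpad = outputs_unpad.logits.squeeze(0)  # (total_tokens, vocab_size)
logits_unpad.shape
```

```
torch.Size([12, 151936])
```

```python
# Scatter back to (batch_size, seqlen, vocab_size); PAD slots are filled with zeros
logits_re_padded = pad_input(logits_unpad, indices, batch_size, seqlen)
```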

```python
logits_re_padded
```

```
tensor([[[ 7.0000,  6.5625,  1.6328,  ..., -2.9375, -2.9375, -2.9375],
         [ 8.5000,  5.6875,  4.3438,  ..., -2.9688, -2.9688, -2.9688],
         [ 3.8438,  6.8125,  2.4062,  ..., -4.2812, -4.2812, -4.2812],
         ...,
         [ 4.5625,  7.0312, -1.0703,  ..., -3.5938, -3.5938, -3.5938],
         [ 5.3125, 10.9375,  3.1094,  ..., -3.3281, -3.3281, -3.3281],
         [ 5.7188, 10.0000,  7.3750,  ..., -5.7812, -5.7812, -5.7812]],

        [[ 2.8438,  8.2500,  2.7812,  ..., -2.8281, -2.8281, -2.8281],
         [ 6.3750,  8.8125,  6.2188,  ..., -4.0312, -4.0312, -4.0312],
         [11.9375,  9.6875,  7.9375,  ..., -3.2344, -3.2344, -3.2344],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0', dtype=torch.bfloat16)
```

```python
attention_mask
```

```
tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0]], device='cuda:0')
```

```python
mask_expanded = attention_mask.unsqueeze(-1).bool()  # (batch_size, seqlen, 1), broadcasts over vocab
mask_expanded
```

```
tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False],
         [False],
         [False]]], device='cuda:0')
```

```python
# 1-D tensors holding the logits of all valid positions, in the same order
valid_origin_logits = torch.masked_select(origin_logits, mask_expanded)
valid_re_padded_logits = torch.masked_select(logits_re_padded, mask_expanded)
```

```python
valid_origin_logits
```

```
tensor([ 7.0000,  6.5625,  1.6328,  ..., -4.5000, -4.5000, -4.5000],
       device='cuda:0', dtype=torch.bfloat16)
```

```python
valid_re_padded_logits
```

```
tensor([ 7.0000,  6.5625,  1.6328,  ..., -4.5000, -4.5000, -4.5000],
       device='cuda:0', dtype=torch.bfloat16)
```
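The printed `masked_select` tensors only show the first and last few values. A more explicit check compares every valid position within a tolerance. The helper names and the bf16-sized tolerances are assumptions, since the padded and packed runs are not guaranteed to be bit-identical; the demo uses random tensors, not model output.

```python
import torch


def valid_positions(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # (batch, seqlen, vocab) -> (total_nnz, vocab); keeps whole rows, unlike masked_select's 1-D view
    return logits[attention_mask.bool()].float()


def max_abs_diff_on_valid(a, b, attention_mask) -> float:
    return (valid_positions(a, attention_mask) - valid_positions(b, attention_mask)).abs().max().item()


def assert_valid_logits_close(a, b, attention_mask, atol=1e-2, rtol=1e-2):
    # Assumed tolerances, loose enough for bf16 logits; tighten them for fp32 runs
    torch.testing.assert_close(
        valid_positions(a, attention_mask), valid_positions(b, attention_mask), atol=atol, rtol=rtol
    )


# Synthetic stand-ins for origin_logits / logits_re_padded
torch.manual_seed(0)
mask = torch.tensor([[1, 1, 1], [1, 0, 0]])
a = torch.randn(2, 3, 5).to(torch.bfloat16)
b = a.clone()
b[1, 1:] = 0  # pad_input leaves zeros in the pad slots; they are excluded from the comparison
print(max_abs_diff_on_valid(a, b, mask))  # 0.0
assert_valid_logits_close(a, b, mask)
```

In the notebook this would be `assert_valid_logits_close(origin_logits, logits_re_padded, attention_mask)`; printing `max_abs_diff_on_valid` first shows how much slack the tolerance actually needs.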