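```python
# Next-token targets: labels[:, t] is the token that follows input_ids[:, t].
labels = input_ids[:, 1:].contiguous()
# One forward pass over the whole sequence; nothing is generated, so the KV cache is disabled.
output = self.fsdp_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
```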

$$
p(W)=p(w_1,w_2,\cdots,w_N)=p(w_1)p(w_2|w_1)\cdots p(w_N|w_1,\cdots,w_{N-1})  
$$
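
As a small numeric illustration of the chain rule above, the sketch below multiplies a few conditional probabilities and checks that summing their logs gives the same joint probability. The probability values are made up for the example and are not model outputs.

```python
import math

# Hypothetical conditional probabilities p(w_t | w_1..w_{t-1}) for a 4-token sequence.
cond_probs = [0.5, 0.25, 0.8, 0.1]

# Chain rule: the joint probability is the product of the conditionals.
joint = math.prod(cond_probs)

# Same quantity in log space: a sum of per-token log-probabilities.
log_joint = sum(math.log(p) for p in cond_probs)

# Average negative log-likelihood per token, the quantity cross-entropy minimizes.
nll = -log_joint / len(cond_probs)

print(joint, math.exp(log_joint), nll)
```

Working in log space is what makes a per-token loss possible: each factor of the product becomes one term of the sum.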

- SFT training involves no decoding or autoregressive generation by the model.
    - The whole response is training data: `input_ids = prompt + response + eos_token`.
    - The forward pass simply computes, at every position, the logits over the full vocabulary for predicting the next token.
        - https://www.bilibili.com/video/av1005936005/
    - Shapes: `(batch_size, sequence_length)` => `(batch_size, sequence_length, vocab_size)`
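- Teacher forcing
    - At every training step, the ground-truth target token is fed as the model's next input, not the model's own prediction from the previous step.
    - It is like correcting a student at every step of an exercise and handing over the right answer, so the next step always builds on correct work.
    - Learning the piano
        - Without teacher forcing: you play a note, then choose the next note based on what you just played. One wrong note can derail the rest of the melody, and getting back on track may take a long time.
        - With teacher forcing: you play a note, and whether or not it was right, the teacher immediately tells you the correct next note on the score so you continue from there. You learn the correct way to play the whole piece much faster.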
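- If sequence_length (prompt + response) is N, shifting the logits/labels means:
    - `labels[1:]` has length N-1, and `logits[:-1]` has length N-1
        - The lengths match, so the loss is computed token by token.
        - `labels[1:]`: a left shift, so each position is paired with the next token as its target.
        - `logits[:-1]`: drops the last position, because the logits at the eos position have no next token to predict.
- During trl/swift training, token accuracy is logged as a metric alongside the loss.
    - It is computed from the argmax of the logits.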
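
The shift and the two metrics described above can be sketched as follows. This is a minimal illustration, not the trl/swift implementation: the tensor sizes are arbitrary and random logits stand in for a real model's output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (assumptions for illustration only).
batch_size, seq_len, vocab_size = 2, 5, 11

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# Stand-in for model(input_ids).logits: one distribution over the vocab per position.
logits = torch.randn(batch_size, seq_len, vocab_size)

# Shift: position t predicts token t+1, so drop the last logit and the first label.
shift_logits = logits[:, :-1, :]              # (batch, N-1, vocab)
shift_labels = input_ids[:, 1:].contiguous()  # (batch, N-1)

# Cross-entropy averaged over all N-1 positions of every sequence.
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
)

# Token accuracy: fraction of positions where argmax over the vocab equals the label.
pred = shift_logits.argmax(dim=-1)
token_accuracy = (pred == shift_labels).float().mean()

print(shift_logits.shape, shift_labels.shape, loss.item(), token_accuracy.item())
```

In real SFT the prompt and padding positions would also be excluded from both metrics. That masking is left out here to keep the shift itself visible.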

- prompt: `你好吗？`
- response: `我很好，谢谢你！EOS`
    - EOS: `<|im_end|>`, which marks the end of a message: the end of the (user) prompt and the end of the (assistant) response.
  
|target (label)|~~好~~|~~吗~~|~~？~~|我|很|好|，|谢|谢|你|!|EOS|x|
|--|--|--|--|--|--|--|--|--|--|--|--|--|--|
|input token|~~你~~|~~好~~|~~吗~~|？|我|很|好|，|谢|谢|你|!|~~EOS~~|

- Teacher forcing involves no autoregressive decoding/generation.
- Every response token, including the eos token, is supervised and contributes to the loss.
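
The table above can be reproduced with a few lines of plain Python. This sketch treats each character as one token purely for readability. A real tokenizer splits the text differently, and `"EOS"` is a placeholder string.

```python
# Character-level "tokens" for illustration; a real tokenizer splits differently.
prompt = list("你好吗？")
response = list("我很好，谢谢你！") + ["EOS"]
tokens = prompt + response

inputs = tokens[:-1]   # what the model sees at each step (ground truth, not its own output)
targets = tokens[1:]   # the next token it has to predict

# A pair is supervised only when its target belongs to the response (EOS included).
first_supervised = len(prompt) - 1  # the last prompt token predicts the first response token
supervised = [i >= first_supervised for i in range(len(inputs))]

for inp, tgt, keep in zip(inputs, targets, supervised):
    print(f"{inp} -> {tgt}  {'loss' if keep else 'masked'}")
```

The number of supervised pairs equals the number of response tokens, EOS included, which is exactly the set of unstruck cells in the target row.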

In [54]:
import torch
from transformers import AutoTokenizer

In [55]:
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')

In [56]:
tokenizer.special_tokens_map

{'eos_token': '<|im_end|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|im_start|>',
  '<|im_end|>',
  '<|object_ref_start|>',
  '<|object_ref_end|>',
  '<|box_start|>',
  '<|box_end|>',
  '<|quad_start|>',
  '<|quad_end|>',
  '<|vision_start|>',
  '<|vision_end|>',
  '<|vision_pad|>',
  '<|image_pad|>',
  '<|video_pad|>']}

In [57]:
tokenizer.eos_token, tokenizer.pad_token_id, tokenizer.encode(['<|im_start|>', '<|im_end|>'])

('<|im_end|>', 151643, [151644, 151645])

In [58]:
tokenizer.decode(151643)

'<|endoftext|>'

In [12]:
# tokenizer.eos_token == '<|im_end|>'

In [59]:
prompt = '你好吗？'
response = '我很好，谢谢你！'

In [60]:
prompt_chat = [{"role": "user", "content": prompt}]

- Earlier episodes on chat templates
    - https://www.bilibili.com/video/BV1LKXSYqE3T/
    - https://www.bilibili.com/video/BV1dsdWYuEXw/
    - https://www.bilibili.com/video/BV1JZLcz4EUC/

In [61]:
prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)
response_chat_str = response + tokenizer.eos_token

In [62]:
prompt_chat_str

'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n你好吗？<|im_end|>\n<|im_start|>assistant\n'

In [63]:
print(prompt_chat_str)

<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
你好吗？<|im_end|>
<|im_start|>assistant



In [68]:
print(tokenizer.apply_chat_template(prompt_chat, tokenize=False))

<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
你好吗？<|im_end|>



In [8]:
response_chat_str

'我很好，谢谢你！<|im_end|>'

In [66]:
msgs = [{"role": "user", "content": prompt},
        {"role": "assistant", "content": response}]
print(tokenizer.apply_chat_template(msgs, tokenize=False))

<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
你好吗？<|im_end|>
<|im_start|>assistant
我很好，谢谢你！<|im_end|>
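
To make the layout of these strings explicit, here is a minimal sketch that rebuilds them by hand. `render_chatml` is a hypothetical helper written only to mimic the text-only output printed above. It is not the tokenizer's actual template, which handles more cases, such as tool calls.

```python
# Default system prompt as it appears in the rendered output above.
DEFAULT_SYSTEM = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."


def render_chatml(messages, add_generation_prompt=False):
    """Hypothetical helper that mimics the ChatML layout shown above (text-only chats)."""
    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM}] + list(messages)
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open the assistant turn so the response can be appended directly.
        text += "<|im_start|>assistant\n"
    return text


prompt, response = "你好吗？", "我很好，谢谢你！"
prompt_str = render_chatml([{"role": "user", "content": prompt}], add_generation_prompt=True)
full_str = render_chatml(
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
)

# prompt + response + eos matches the full conversation up to the final newline.
print(full_str == prompt_str + response + "<|im_end|>" + "\n")
```

This is why `prompt_chat_str + response_chat_str` is a valid training sequence: with the generation prompt added, the prompt string ends exactly where the assistant's content begins.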



### tokenize

In [13]:
prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False)

In [14]:
prompt_ids = prompt_ids_output["input_ids"][0]
prompt_attention_mask = prompt_ids_output["attention_mask"][0]
prompt_ids, prompt_attention_mask

(tensor([151644,   8948,    198,   2610,    525,   1207,  16948,     11,   3465,
            553,  54364,  14817,     13,   1446,    525,    264,  10950,  17847,
             13, 151645,    198, 151644,    872,    198, 108386, 101037,  11319,
         151645,    198, 151644,  77091,    198]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]))

In [19]:
tokenizer.decode([108386, 101037,  11319])

'你好吗？'

In [21]:
response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False)
response_ids = response_ids_output["input_ids"][0]
response_attention_mask = response_ids_output["attention_mask"][0]
response_ids

tensor([ 35946, 101243,   3837, 116642,   6313, 151645])

In [69]:
tokenizer.decode([35946, 101243,   3837, 116642,   6313, 151645])

'我很好，谢谢你！<|im_end|>'

### padding

In [24]:
prompt_length = prompt_ids.shape[0]
response_length = response_ids.shape[0]
prompt_length, response_length

(32, 6)

In [28]:
input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
input_ids.shape, attention_mask.shape

(torch.Size([38]), torch.Size([38]))

In [39]:
sequence_length = input_ids.shape[0]

In [36]:
max_length = 40

In [40]:
padded_input_ids = torch.full((max_length - sequence_length,), tokenizer.pad_token_id, dtype=input_ids.dtype)
padded_attention_mask = torch.zeros(size=(max_length - sequence_length,), dtype=attention_mask.dtype)

In [42]:
padded_input_ids, padded_attention_mask

(tensor([151643, 151643]), tensor([0, 0]))

In [43]:
input_ids = torch.cat((input_ids, padded_input_ids))
attention_mask = torch.cat((attention_mask, padded_attention_mask))

In [45]:
input_ids, attention_mask

(tensor([151644,   8948,    198,   2610,    525,   1207,  16948,     11,   3465,
            553,  54364,  14817,     13,   1446,    525,    264,  10950,  17847,
             13, 151645,    198, 151644,    872,    198, 108386, 101037,  11319,
         151645,    198, 151644,  77091,    198,  35946, 101243,   3837, 116642,
           6313, 151645, 151643, 151643]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]))
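
The padding steps above can be folded into one function. The sketch below uses `torch.nn.functional.pad` for that. The name `right_pad` and the stand-in token ids are assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def right_pad(input_ids, attention_mask, max_length, pad_token_id):
    """Hypothetical helper: right-pad 1-D tensors to max_length."""
    pad_len = max_length - input_ids.shape[0]
    if pad_len < 0:
        raise ValueError("sequence is longer than max_length; truncate first")
    # The (left, right) tuple applies to the last dimension.
    input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
    return input_ids, attention_mask


ids = torch.arange(38)                    # stand-in for the 38 real token ids
mask = torch.ones(38, dtype=torch.long)
ids, mask = right_pad(ids, mask, max_length=40, pad_token_id=151643)
print(ids[-3:], mask[-3:])
```

Padding tokens get attention mask 0, so they never influence the real tokens and are later excluded from the loss.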

In [46]:
def compute_position_id_with_mask(mask):
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)

In [48]:
position_ids = compute_position_id_with_mask(attention_mask)
position_ids

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 37, 37])
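
The `clip` in `compute_position_id_with_mask` matters most for left padding, where `cumsum - 1` would otherwise yield -1 at the padded positions. A small check on two toy masks (chosen for illustration):

```python
import torch


def compute_position_id_with_mask(mask):
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


right_padded = torch.tensor([1, 1, 1, 0, 0])
left_padded = torch.tensor([0, 0, 1, 1, 1])

# Right padding: real tokens get 0, 1, 2 and the pads repeat the last position.
print(compute_position_id_with_mask(right_padded))
# Left padding: pads are clipped to 0 and real tokens still get 0, 1, 2.
print(compute_position_id_with_mask(left_padded))
```

Either way the real tokens are numbered from 0, and the positions assigned to padding are irrelevant because those tokens are masked out of attention.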

In [49]:
loss_mask = attention_mask.clone()

In [53]:
if prompt_length > 1:
    # Mask out the prompt for SFT, except its last token: that position predicts the first response token.
    loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0
# Mask out the position of the last response token (eos): it has no next token to predict.
loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0

In [52]:
min(prompt_length + response_length, loss_mask.size(0))

38

In [51]:
# Index t marks the position whose logits predict token t+1, so the 6 ones cover all 6 response targets, eos included.
loss_mask

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0])
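
One way this mask can feed into the loss is sketched below, under two assumptions: the mask is aligned with `logits[:-1]` by dropping its last entry, and the loss is averaged over supervised tokens only. `build_loss_mask` is a hypothetical wrapper around the cells above, and random tensors stand in for real token ids and model logits.

```python
import torch
import torch.nn.functional as F


def build_loss_mask(attention_mask, prompt_length, response_length):
    """Hypothetical wrapper around the masking steps shown above."""
    loss_mask = attention_mask.clone()
    if prompt_length > 1:
        loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0
    loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0
    return loss_mask


torch.manual_seed(0)
vocab_size = 16                                    # toy vocabulary (assumption)
attention_mask = torch.tensor([1] * 38 + [0] * 2)  # 38 real tokens, 2 pads
input_ids = torch.randint(0, vocab_size, (40,))    # stand-in token ids
logits = torch.randn(40, vocab_size)               # stand-in for model logits

loss_mask = build_loss_mask(attention_mask, prompt_length=32, response_length=6)

# Align with the shift: logits[t] is scored against input_ids[t + 1].
per_token = F.cross_entropy(logits[:-1], input_ids[1:], reduction="none")
mask = loss_mask[:-1].float()
loss = (per_token * mask).sum() / mask.sum()

print(loss_mask.nonzero().flatten().tolist(), loss.item())
```

The supervised indices run from 31 (the last prompt token, which predicts the first response token) to 36 (the token before eos, which predicts eos).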