- https://www.bilibili.com/video/BV1oY1aYzEVi/
- https://unsloth.ai/blog/gradient

In [10]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

- bz, ga, gas
    - bz: batch size
    - ga: gradient accumulation
    - gas: gradient accumulation steps
- true batch size: bz * gas
- 核心点在于
    - 因为填充（padding）的存在，bz * gas 个样本的有效长度并不一致；如果每个 micro-batch 先在内部按自己的有效 token 数对 loss 取均值（cross entropy loss 的 mean reduction），再对各个 micro-batch 的均值取平均，其结果与在全部有效 token 上直接取均值的全局 loss 不等价
    - 所有的样本，都不存在 padding 时
        - gas: 2, bz: 1

    $$
    \frac{\frac{l_{11}+l_{12}+l_{13}+l_{14}}4 + \frac{l_{21}+l_{22} + l_{23} + l_{24}}4}2
    $$

    - 存在 padding 时（ignore index）：第二个句子的最后一个 token 是 padding，只剩 3 个有效 token

    $$
    \frac{\frac{l_{11}+l_{12}+l_{13}+l_{14}}4 + \frac{l_{21}+l_{22} + l_{23}}3}2
    $$

    - 其与如下的全局 loss，不等价

    $$
    \frac{l_{11} + l_{12} + l_{13} + l_{14}+ {l_{21}+l_{22} + l_{23}}}7
    $$


### Cross-entropy (X-entropy) loss in LLM

$$
\frac{1}{\sum_i \mathbb{I}\{y_i \ne -100\}} \sum_i \mathbb{I}\{y_i \ne -100\}\, L_i
$$

In [5]:
# Label value the loss should skip (stands in for padded positions).
ignore_index = -1

# Ground-truth class id for each row. None of them equals ignore_index here,
# so all four rows contribute to the loss.
target = torch.tensor([0, 2, 1, 1])

# Raw logits: 4 samples x 3 classes, tracked by autograd.
# NOTE: `input` shadows Python's builtin; the name is kept because the
# following cells refer to it.
logit_rows = [
    [2.0, 0.5, 1.0],
    [0.1, 1.0, 0.3],
    [0.5, 1.2, 0.3],
    [0.98, 0.6, 0.17],
]
input = torch.tensor(logit_rows, requires_grad=True)

In [6]:
# Reference value: per-row cross-entropy summed (not averaged) over the batch;
# rows whose label equals ignore_index are excluded by PyTorch.
F.cross_entropy(input, target, reduction="sum", ignore_index=ignore_index)

tensor(3.5869, grad_fn=<NllLossBackward0>)

In [7]:
# Class probabilities per row (softmax over the last, i.e. class, dimension).
torch.softmax(input, dim=-1)

tensor([[0.6285, 0.1402, 0.2312],
        [0.2136, 0.5254, 0.2609],
        [0.2609, 0.5254, 0.2136],
        [0.4698, 0.3213, 0.2090]], grad_fn=<SoftmaxBackward0>)

In [12]:
# Hand check: negative log of the probability at each row's label (classes
# 0, 2, 1, 1), copied from the 4-decimal softmax output. Because the inputs
# are rounded, the result matches the PyTorch sum only to about 3 decimals.
picked_probs = (0.6285, 0.2609, 0.5254, 0.3213)
-sum(np.log(p) for p in picked_probs)

3.587012752382311

#### 通过 gather 操作手动计算 x-entropy loss

In [14]:
# Turn the 1-D label vector into a column (one index per row), the index
# layout that gather() expects along dim=1.
torch.unsqueeze(target, dim=1)

tensor([[0],
        [2],
        [1],
        [1]])

In [18]:
# Manual cross-entropy, step 1: pick each row's log-probability at its label.
log_probs = F.log_softmax(input, dim=1)
# gather() rejects negative indices, so a label equal to ignore_index (-1 here)
# would raise before any mask could be applied. Point ignored rows at class 0
# instead; their gathered values are placeholders that must be masked out
# before reducing the loss.
gather_index = target.masked_fill(target == ignore_index, 0).unsqueeze(1)
selected = log_probs.gather(1, gather_index).squeeze(1)

In [19]:
# Per-row log-probability of the labelled class, as produced by the gather.
selected

tensor([-0.4644, -1.3435, -0.6435, -1.1355], grad_fn=<SqueezeBackward1>)

In [20]:
# Manual cross-entropy, step 2: negate the picked log-probabilities, keep only
# rows whose label is not ignore_index, and sum what remains.
mask = target.ne(ignore_index)  # True where the row has a real label
loss = selected.neg()[mask]
loss.sum()

tensor(3.5869, grad_fn=<SumBackward0>)