In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

$$
-\sum_{i=1}^{C} y_i\log p_i
$$

- 对于 imagenet 的 1000 分类问题

$$
-\sum_{i=1}^{1000}y_i\log p_i
$$

- input_ids => logits
- logits 要与 labels 在除类别维以外的维度上对齐（logits 比 labels 多出最后一个 num_classes 维）；
    - 出于对齐 shape 的目的，至少在 seqlen 要保持一致；
- 整个 batch 一起送进去算，默认返回的也是平均意义上的 loss（即单样本/单 token 级别的 loss；设置了 ignore_index 时只对未被忽略的位置求平均）；
    - https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    - 默认 reduction = 'mean'
- https://www.bilibili.com/video/BV1NY4y1E76o/

In [2]:
import torch
# Seed PyTorch's global RNG so the random tensors drawn later are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x724a0d6039f0>

In [3]:
# Simulated inputs: a batch of 2 sequences with 4 token ids each.
input_ids = torch.tensor([
    [1, 2, 3, 4],
    [4, 3, 2, 1],
])
# Simulated labels: each row is the matching input row shifted left by one,
# with -100 filling the last position.
labels = torch.tensor([
    [2, 3, 4, -100],
    [3, 2, 1, -100],
])

In [4]:
# Simulated model output: random scores laid out as
# (batch_size, sequence_length, num_classes) = (2, 4, 10).
logits = torch.randn((2, 4, 10))

In [5]:
# Build the loss function. Positions whose target equals -100 are skipped, and
# the default mean reduction (spelled out here) averages over the remaining ones.
criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)

In [14]:
# Collapse the batch and sequence axes so each row holds the class scores of one position.
logits.view(-1, logits.size(-1)).shape

torch.Size([8, 10])

In [20]:
# Flatten the targets to one entry per position; display the values and the resulting shape.
labels.view(-1), labels.view(-1).shape

(tensor([   2,    3,    4, -100,    3,    2,    1, -100]), torch.Size([8]))

In [21]:
# Compute the loss: flatten the batch and sequence axes so the criterion sees
# one row of class scores and one target index per position.
loss = criterion(
    input=logits.view(-1, logits.size(-1)),
    target=labels.view(-1),
)
loss

tensor(2.5207)

In [23]:
# Per-position log-probabilities over the classes (softmax taken along dim=1).
# F.log_softmax is used instead of log(softmax(...)): the two are mathematically
# equal, but the fused form avoids softmax underflowing to 0 and producing -inf,
# and it does not build a throwaway nn.Softmax module.
F.log_softmax(logits.view(-1, logits.size(-1)), dim=1)

tensor([[-1.2072, -1.6468, -2.2334, -5.2397, -2.4557, -4.3687, -3.1772, -4.7388,
         -3.8863, -1.4854],
        [-2.9390, -3.9501, -3.2744, -3.1059, -3.3153, -1.7841, -0.9042, -2.7061,
         -3.0439, -2.1069],
        [-4.0058, -2.1694, -2.4469, -1.5671, -1.9685, -1.9512, -2.6372, -1.9129,
         -3.4793, -3.2059],
        [-2.8926, -1.7811, -4.0257, -3.5122, -2.8644, -0.9236, -2.3221, -3.0655,
         -2.3353, -3.4156],
        [-4.2205, -1.6673, -3.5427, -3.2640, -3.9371, -0.5401, -3.8976, -3.1508,
         -3.5767, -3.3210],
        [-2.3889, -1.9412, -2.9550, -1.2756, -3.2810, -3.2030, -3.8702, -2.4310,
         -2.5305, -1.7914],
        [-3.5394, -1.5970, -4.6261, -2.0580, -1.9965, -2.5852, -1.2235, -2.9184,
         -3.0949, -3.6389],
        [-3.5348, -1.2022, -2.6524, -1.9564, -2.4236, -2.0539, -1.9052, -3.1219,
         -4.6866, -3.2310]])

In [26]:
# Manual check of the loss: average the log-probabilities found at the target
# class of each non-ignored position (rows 0-2 and 4-6 of the table above;
# rows 3 and 7 have target -100, so the divisor is 6 rather than 8).
# The result is the negative of the reported loss, up to the rounding of the
# 4-decimal printed values.
((-2.2334 + (-3.1059) + (-1.9685)) + ((-3.2640) + (-2.9550) + (-1.5970)))/6

-2.520633333333333

### labels 作为 id

In [4]:
# Random scores for 3 samples over 5 classes.
x = torch.randn((3, 5))
x

tensor([[ 0.9580,  1.3221,  0.8172, -0.7658, -0.7506],
        [ 1.3525,  0.6863, -0.3278,  0.7950,  0.2815],
        [ 0.0562,  0.5227, -0.2384, -0.0499,  0.5263]])

In [13]:
# Targets given as class indices, one per row of x.
# NOTE: this rebinds `labels`, replacing the 2-D tensor from the earlier example.
labels = torch.tensor([1, 2, 3], dtype=torch.int64)

In [14]:
# Functional cross-entropy with class-index targets; the default mean
# reduction is spelled out explicitly.
F.cross_entropy(input=x, target=labels, reduction="mean")

tensor(1.8159)

In [12]:
# Row-wise log-probabilities over the classes. F.log_softmax replaces
# log(softmax(...)): mathematically the same, but computed in one numerically
# stable step, so tiny probabilities do not underflow to 0 and yield -inf.
F.log_softmax(x, dim=1)

tensor([[-1.2995, -0.9353, -1.4403, -3.0233, -3.0081],
        [-0.9613, -1.6276, -2.6417, -1.5189, -2.0324],
        [-1.7646, -1.2980, -2.0591, -1.8706, -1.2944]])

In [15]:
# Manual check: average the log-probabilities at the target class of each row
# (row 0 -> class 1, row 1 -> class 2, row 2 -> class 3 in the table above).
# The result is the negative of the cross-entropy value reported earlier,
# up to the rounding of the 4-decimal printed values.
(-0.9353 + (-2.6417) + (-1.8706))/3

-1.8158666666666665