def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
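For context, this is my understanding of what the function does with `num_items_in_batch` (a rough sketch of the body, not copied from any specific transformers version):

```python
import torch.nn as nn

def fixed_cross_entropy_sketch(source, target, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # If num_items_in_batch is given, sum the per-token losses and divide by that
    # count instead of letting cross_entropy average over this micro-batch alone.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```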
I checked the shapes of the inputs and found the following:
In [1]: logits.shape
Out[1]: torch.Size([4, 896, 152064])
In [2]: labels.shape
Out[2]: torch.Size([4, 896])
In [3]: num_items_in_batch
Out[3]: 4390
Why is 4390 greater than 4 * 896 = 3584?
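As a sanity check I also counted the label positions that can actually contribute to the loss. This is a sketch continuing the same session, so `labels` and `num_items_in_batch` are the objects inspected above:

```python
# Tokens whose label equals ignore_index (-100) are excluded from the loss,
# so this is the largest count this single micro-batch could contribute.
valid_tokens = (labels != -100).sum().item()
print(valid_tokens, num_items_in_batch)  # valid_tokens <= 4 * 896 = 3584, yet num_items_in_batch is 4390
```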