- 之前发的时候，cross_entropy_loss直接用的unsloth的，我结合unsloth和flash-attn库的，写一个更快的cross_entropy，并且支持vocab并行，无缝衔接megatron框架。但是不支持scale和smooth等功能，基本满足大多数人的训练需求

In [2]:
import triton
import triton.language as tl
import torch
import torch.distributed as dist
from copy import deepcopy
import os
from mdy_triton.core import fast_cross_entropy_loss
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
torch_ce = torch.nn.CrossEntropyLoss(reduce=False)




In [19]:
@triton.jit
def _cross_entropy_fwd_kernel(LOGITS, LABELS, LOSSES, LOGSUMEXP,
                             vocab_start_index, row_stride, 
                             M, N, SPLIT, BLOCK_SIZE: tl.constexpr, 
                             ):
    row_idx = tl.program_id(0)
    row_stride = row_stride.to(tl.int64)
    label_idx = tl.load(LABELS + row_idx).to(tl.int32)
    if (label_idx != -100):
        LOGITS += row_idx * row_stride
        base_cols = tl.arange(0, BLOCK_SIZE)
        m_i = -float("inf")
        l_i = 0.0
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)
            m_ij = tl.max(logits)
            new_m_i = tl.maximum(m_i, m_ij)
            l_i = l_i * tl.exp(m_i - new_m_i) + tl.sum(tl.exp(logits - new_m_i))
            m_i = new_m_i
        lse = tl.log(l_i) + m_i

        if (label_idx >= vocab_start_index) and (label_idx < (vocab_start_index + N)):
            x = -1.0 * tl.load(LOGITS+label_idx-vocab_start_index).to(tl.float32)
            if not SPLIT:
                loss = lse + x
                tl.store(LOSSES+row_idx, loss)
            else:
                tl.store(LOSSES+row_idx, x)
        tl.store(LOGSUMEXP+row_idx, lse)

# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)
#                  for bsz in [8192*2,8192*4, 8192*8]
#                  for ns in [1, 2,3,4]
#                  for nw in [16, 32]
#                  ], key=['M', 'N']
#                  )
@triton.jit
def _cross_entropy_bwd_kernel(DLOSSES, DLOGITS,
                            LOGITS, LABELS, LOGSUMEXP,
                             vocab_start_index, row_stride, 
                             M, N,  INPLACE,
                             BLOCK_SIZE: tl.constexpr,
                             ):
    row_idx = tl.program_id(0)
    LABELS += row_idx
    label_idx = tl.load(LABELS).to(tl.int32)
    row_stride = row_stride.to(tl.int64)
    if (label_idx != -100):
        # label_idx -= vocab_start_index
        LOGITS += row_idx * row_stride
        DLOGITS += row_idx * row_stride
        LOGSUMEXP += row_idx
        DLOSSES += row_idx
        lse = tl.load(LOGSUMEXP)
        dloss = tl.load(DLOSSES).to(tl.float32)
        base_cols = tl.arange(0, BLOCK_SIZE)
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(LOGITS+cols, mask=mask, other=0.).to(tl.float32)
            probs = tl.exp(logits - lse)
            tmp = vocab_start_index + start_n
            if (label_idx >= tmp) and (label_idx < (tmp + BLOCK_SIZE)):
                probs = tl.where(cols+vocab_start_index != label_idx, probs, probs-1.)
            tl.store(DLOGITS+cols, probs * dloss, mask=mask)
    elif INPLACE:
        DLOGITS += row_idx * row_stride
        base_cols = tl.arange(0, BLOCK_SIZE)
        zeros = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            tl.store(DLOGITS+cols, zeros, mask=mask)

class _FastCrossEntropyLoss(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, inplace):
        ctx.input_shape = logits.shape
        tp_rank = 0
        tp_size = 1
        tp_group = None
        N = ctx.input_shape[-1]
        logits = logits.view(-1, N)
        M = logits.size(0)
        losses = torch.zeros(*ctx.input_shape[:-1], device=logits.device, dtype=torch.float32)
        split = tp_size > 1
        vocab_start_index = N * tp_rank
        logsumexp = torch.zeros(M, device=logits.device, dtype=torch.float32)
        # print(logsumexp.stride(), losses.stride())
        _cross_entropy_fwd_kernel[(M,)](logits, labels, losses, logsumexp,
                                                    vocab_start_index, logits.stride(0),
                                                    M, N, split,
                                                    BLOCK_SIZE=4096, num_warps=4, num_stages=3
                                                    )
        if tp_size>1:
            lse_allgather = torch.empty(tp_size, M, dtype=logsumexp.dtype, device=logsumexp.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, logsumexp, group=tp_group)
            torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM,
            )
            logsumexp = torch.logsumexp(lse_allgather, dim=0)
            losses += logsumexp
            losses.masked_fill_(labels.view(-1)==-100, 0.)
        ctx.save_for_backward(logits, labels, logsumexp)
        ctx.inplace = inplace
        ctx.tp_rank = tp_rank
        return losses
    
    @staticmethod
    def backward(ctx, dlosses):
        logits, labels, logsumexp = ctx.saved_tensors
        dlogits = logits if ctx.inplace else torch.zeros_like(logits)
        N = logits.size(-1)
        logits = logits.view(-1, N)
        M = logits.size(0)
        vocab_start_index = N * ctx.tp_rank
        _cross_entropy_bwd_kernel[(M,)](dlosses, dlogits, 
                                        logits, labels, logsumexp,
                                        vocab_start_index, logits.stride(0),
                                        M, N, ctx.inplace, 
                                        BLOCK_SIZE=16384, num_warps=16, num_stages=4
                                        # BLOCK_SIZE=32768, num_warps=32, num_stages=1
                                                    )
        return dlogits.view(*ctx.input_shape), None, None
def triton_entropy_loss(logits, labels, inplace=False):
    return _FastCrossEntropyLoss.apply(logits, labels, inplace)

In [26]:
torch.cuda.empty_cache()
bs = 4
seq_len = 1024
vocab_size = 80000
dtype=torch.bfloat16
device = 'cuda:0'
factor = 4 if dtype == torch.float32 else 2
print('logits显存占用：',(bs * seq_len * vocab_size) / (1024)**3 * factor,"G")
logits1 = torch.randn(bs, seq_len, vocab_size, dtype=dtype, device=device)
logits1.requires_grad_(True)
logits2 = deepcopy(logits1)
labels = torch.randint(0, vocab_size-10, (bs, seq_len), device=device)
labels = torch.randint(0, vocab_size*2-1, (bs, seq_len)).cuda() - vocab_size
labels.masked_fill_(labels<0, -100)




logits显存占用： 0.6103515625 G


tensor([[54972,  -100,  -100,  ..., 67878,  -100, 40067],
        [70641, 29120, 42895,  ...,  -100,  -100,  -100],
        [ 6666,  -100,  -100,  ...,  -100,  -100, 68260],
        [79595, 72069,  7431,  ...,  -100,  -100, 67588]], device='cuda:0')

In [27]:
y1 = triton_entropy_loss(logits1, labels, True)
y2 = torch_ce(logits2.view(-1, vocab_size).to(torch.float32), labels.view(-1))
dy = torch.rand_like(y1)
y1.backward(dy)
y2.backward(dy.view(-1))
print(torch.allclose(y1.view(-1), y2.view(-1), 1e-4, 1e-4))
print(torch.allclose(logits1.grad, logits2.grad, 1e-4, 1e-4))

True
True


In [28]:
print(triton.testing.do_bench(lambda: triton_entropy_loss(logits1, labels)))    # my
print(triton.testing.do_bench(lambda: fast_cross_entropy_loss(logits1, labels)))# unsloth
print(triton.testing.do_bench(lambda: torch_ce(logits2.view(-1, vocab_size).float(), labels.view(-1)))) #torch

0.9373115301132202
2.3171751499176025
13.22261905670166


In [29]:
torch.cuda.empty_cache()
def fwd_bwd1():
    y = triton_entropy_loss(logits1, labels, inplace=True)
    y.sum().backward()
def fwd_bwd2():
    y = fast_cross_entropy_loss(logits1, labels)
    y.sum().backward()
def fwd_bwd3():
    y = torch_ce(logits1.view(-1, logits1.size(-1)).float(), labels.view(-1))
    y.sum().backward()
print(triton.testing.do_bench(lambda: fwd_bwd1(), rep=1000)) # my
print(triton.testing.do_bench(lambda: fwd_bwd2(), rep=1000)) # unsloth
print(triton.testing.do_bench(lambda: fwd_bwd3(), rep=1000)) # torch

7.344710350036621
9.676716804504395
41.73855209350586
