In [5]:
import torch
import triton
import triton.language as tl
from copy import deepcopy
import os
from tqdm import tqdm
import time
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
from transformers import Qwen2ForCausalLM

# torch code

In [6]:

# 代码是根据trl仓库改的，因此triton实现也是根据这个仓库的实现方式进行改进的
# 最主要的就是p(x)和p_old(x)是一样的
def get_log_probs(logits, input_ids):
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
    # logits通过以下计算得到
    # logits_to_keep = completion_ids.size(1)
    # logits = model(input_ids=input_ids, 
    #             attention_mask=attention_mask,
    #             logits_to_keep=logits_to_keep + 1).logits
    # 传ref_logp（bs*L）而不是ref_logits的原因是，该值可以在inference_mode()下得到，
    # 无需保存中间结果，ref_logits会浪费显存
    assert logits.is_contiguous() and ref_logp.is_contiguous()
    logits = logits[:, :-1] # 错一位，对应下一个输入token的概率         
    per_token_logps = get_log_probs(logits, input_ids) # logits是需要计算梯度，因此会保存中间结果log_probs
    ref_per_token_logps = ref_logp
    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    # old_model一步一更新， p(x) 和 p_old(x) 是一样的。
    # 下面这个指数部分等于1，做这一步，是让logits挂到奖励部分的Loss，反向传播时，奖励会对logits产生一部分梯度
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    if completion_mask is not None:
        per_token_loss *= completion_mask 
        if save_kl:
            per_token_kl *= completion_mask
    return per_token_loss if not save_kl else (per_token_loss, per_token_kl)# 外部进行reduce

# skip mask part

In [None]:
# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)
#                  for bsz in [2048*(2**i) for i in range(5)]
#                  for ns in [1,2,4]
#                  for nw in [8, 16, 32]
#                  ], key=['N']
#                  )
@triton.jit
def _grpo_loss_fwd(LOGITS, REF_LOGP, INPUT_IDS, ADVANTAGES, MASK, BETA,
                    LOSS, LSE, SAVE_KL,
                    M, N, L, INPUT_IDS_START_INDEX,
                    BLOCK_SIZE: tl.constexpr
                    ):
    row_idx = tl.program_id(0)
    # 因为最后一个位置不需要计算，实际上Logits是一个B*(L+1)行的向量，而我们只启动了B*L个程序
    # 比如3*4*vocab_size，每第4个位置不需要计算
    # row_idx从0开始，如果到第2行第一个为止，row_id为3，而真实的行id应该是4。
    # 因此用off_b去记录一个偏移量
    off_b = row_idx // L    
    N = tl.cast(N, tl.int64)

    LOGITS += N * (row_idx + off_b) # 加上偏移量
    REF_LOGP += row_idx
    # 同样input_ids前面介绍时也有多余的prompt部分
    # 比如prompt长度为64，第1行的起始位置应该从64开始
    INPUT_IDS += row_idx + (off_b+1) * INPUT_IDS_START_INDEX
    LOSS += row_idx
    LSE += row_idx
    ADVANTAGES += off_b
    
    MASK += row_idx
    not_skip = tl.load(MASK)# 跳过padding的部分，节约时间
    if not_skip == 1:       # 尤其是output长短不一时，都会pad到最长的那个，会浪费很多计算资源
        base_cols = tl.arange(0, BLOCK_SIZE)
        # 没啥好说的，计算两个lse，online-softmax那一套
        m_i = -float("inf")
        l_i = 0.0
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)
            m_ij = tl.max(logits)
            new_m_i = tl.maximum(m_i, m_ij)
            l_i = l_i * tl.exp(m_i - new_m_i) + tl.sum(tl.exp(logits - new_m_i))
            m_i = new_m_i
        lse = tl.log(l_i) + m_i

        # 有了lse，直接读取input_ids对应的logits即可，一个标量
        idx = tl.load(INPUT_IDS)
        x = tl.load(LOGITS+idx).to(tl.float32)
        advantage = tl.load(ADVANTAGES).to(tl.float32)
        ref_logp = tl.load(REF_LOGP)
        logp = x - lse
        diff = ref_logp - logp
        kl = tl.exp(diff) - diff - 1
        # 因为我们知道 torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        # 实际上等于 1 * advantages.unsqueeze(1)
        # loss我们直接减去一个advantage
        loss = kl * BETA - advantage
        tl.store(LOSS, loss)
        tl.store(LSE, lse)
        if SAVE_KL:
            tl.store(LOSS+M, kl)


# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)
#                  for bsz in [2048*(2**i) for i in range(5)]
#                  for ns in [1,2,4]
#                  for nw in [8, 16, 32]
#                  ], key=['N']
#                  )
@triton.jit
def _grpo_loss_bwd(DLOSS, DLOGITS, 
                   LOGITS, REF_LOGP, INPUT_IDS, ADVANTAGES, MASK, BETA,
                    LSE,
                    N, L, INPUT_IDS_START_INDEX,
                    BLOCK_SIZE: tl.constexpr
                    ):
    # 与forward部分如出一辙
    row_idx = tl.program_id(0)
    off_b = row_idx // L
    N = tl.cast(N, tl.int64)

    DLOSS += row_idx
    DLOGITS += N * (row_idx + off_b)
    LOGITS += N * (row_idx + off_b)
    REF_LOGP += row_idx
    INPUT_IDS += row_idx + (off_b+1) * INPUT_IDS_START_INDEX
    LSE += row_idx
    ADVANTAGES += off_b
    base_cols = tl.arange(0, BLOCK_SIZE)

    MASK += row_idx
    not_skip = tl.load(MASK)
    if not_skip == 1:
        dloss = tl.load(DLOSS).to(tl.float32)
        lse = tl.load(LSE)
        idx = tl.load(INPUT_IDS)
        x = tl.load(LOGITS+idx).to(tl.float32)
        advantage = tl.load(ADVANTAGES).to(tl.float32)
        ref_logp = tl.load(REF_LOGP)
        logp = x - lse

        # 算dlogp
        dlogp = (BETA * (-1.0 * tl.exp(ref_logp - logp) + 1) \
                        - advantage) \
                        * dloss

        # 用dlogp再去算dlogits
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)
            probs = tl.exp(logits - lse)
            dlogits = tl.where(cols==idx, 1-probs, -probs) * dlogp
            # DLOGITS的内存就对应REF_LOGITS，废物再利用
            tl.store(DLOGITS+cols, dlogits, mask=mask)
    else:
        dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            # DLOGITS的内存就对应REF_LOGITS，废物再利用
            tl.store(DLOGITS+cols, dlogits, mask=mask)


class _GrpoLoss(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace):
        # 设计思路：
        # 为什么输入是模型的原始输出，而不是logits[:, :-1]？
        # triton一般需要tensor是连续的，如果不连续，处理起来很麻烦
        # 而logits[:, :-1].contiguous() 会创建一个新的张量，增加显存开销
        # 实际上我们在内部计算时，忽略掉最后一个位置即可
        assert logits.is_contiguous() and ref_logp.is_contiguous()
        ctx.input_shape = logits.shape
        B, L_ADD_1, N = ctx.input_shape
        L = L_ADD_1 - 1 
        M = B * L # 我们实际需要计算的长度是 B * (L + 1 - 1)个行向量户即可
        # input_ids也需要是连续的， 如果是 input_ids[:, -logits_to_keep:]，这就不是连续的了
        # 当然也可以是input_ids[:, -logits_to_keep:].contiguous()，这少一个vocab_size维度，基本无开销
        # 但是我们也可以记录下output的起始位置，跳过prompt部分即可
        input_ids_start_index = input_ids.size(1) - L  
        # 下面都用fp32进行存储，因为都没有vocab_size这个维度，基本无额外显存开销，但是大大提高精度
        if not save_kl:
            loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32) 
        else:
            loss = torch.zeros(B*2, L, device=logits.device, dtype=torch.float32) # 后一半存kl
        lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)  # 等价 max(x) + logsumexp(x)，用于backward的快速计算

        if completion_mask is None:
            completion_mask = torch.ones(B,L, device=logits.device, dtype=torch.int32)
        else:
            loss[:B].masked_fill_(completion_mask.logical_not(), 0)
        kwargs = {'BLOCK_SIZE': 8192, 'num_warps': 8, 'num_stages':1}
        _grpo_loss_fwd[(M,)](logits, ref_logp, input_ids, advantages, completion_mask, beta,
                            loss, lse, save_kl,
                            M, N, L, input_ids_start_index,
                            **kwargs,
                            )
        ctx.beta = beta
        ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
        ctx.ref_logp = ref_logp
        ctx.inplace = inplace
        return loss
    
    @staticmethod
    def backward(ctx, dloss):
        # logits对应的grad来自两个部分，reward部分和kl部分
        # print(dloss.view(-1).stride(), dloss.shape)
        if not dloss.is_contiguous():
            dloss = dloss.contiguous()
        
        lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
        B, L_ADD_1, N = ctx.input_shape
        L = L_ADD_1 - 1
        M = B * L
        input_ids_start_index = input_ids.size(1) - L
        # 实际上当我们读取一些logits的值后，这个张量就一点用都没有了
        # 我们直接把logits的grad用logits存储，直接废物再利用，节省显存
        dlogits = logits if ctx.inplace else torch.empty_like(logits)
        kwargs = {'BLOCK_SIZE': 8192, 'num_warps': 32, 'num_stages':4}
        _grpo_loss_bwd[(M,)](dloss, dlogits, 
                            logits, ctx.ref_logp, input_ids, advantages, completion_mask, ctx.beta,
                            lse,
                            N, L, input_ids_start_index,
                            **kwargs
                                )
        # 最后一个位置的token并没有参与计算，梯度需要设置为0
        # 因为empty的初始化或者ref_logits的初始化，该位置都不是0，需要手动设置下
        dlogits[:, -1, :] = 0
        return dlogits.view(*ctx.input_shape), None, None, None, None, None, None, None

def triton_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False, inplace=True) -> torch.Tensor:
    '''
    compute grpo loss, save memory(no addition usage) and fast speed(6X for A800)

    Args:
        logtits: Tensor, [B, L+1, vocab_size], the origin output of model, it's not logits[:, :-1]
        ref_logp: Tensor, [B, L], the origin output of model, it's not ref_logits[:, :-1]
        input_ids: Tensor, [B, K+L], it's prompt_completion_id, it contains the prompt ids and output ids
        advantages: Tensor, [B], the advantages of each prompt
        beta: float, the weight of kl loss
        completion_mask: Tensor, loss mask
        save_kl: bool, if true will save kl
        inplace: bool, if true, in backward use logits to store the logits's grad, it can save memory

    Retutn:
        loss: Tensor, [B, L], the loss of grpo, it contains the advantage part and kl part

    NOTE: logits(ref_logits) is computed by these steps
        logits_to_keep = completion_ids.size(1)

        def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
            ).logits
            return logits
            
        logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
    '''
    out = _GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace)
    if not save_kl:
        return out
    else:
        return out.chunk(2, axis=0)
    
def get_random_ref_log_probs(logits, input_ids):
    with torch.inference_mode():
        logits = logits[:,:-1]
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
            log_probs = torch.randn_like(logits_row).log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        torch.cuda.empty_cache()
        return torch.stack(per_token_logps)

# 精度测试

In [25]:
dtype = torch.bfloat16
device = 'cuda'
bs, seq_len, vocab_size = 8, 1024, 150000
logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype) # 最后一个位置是eos token的logits，计算时会扔掉
logits.requires_grad_(True)
copy_logits = deepcopy(logits)
advantages = torch.randn(bs, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device) # 64是随便设置的，表示prompt ids的长度，剩下是output
ref_logp = get_random_ref_log_probs(logits, input_ids)
beta = 0.04
completion_mask = torch.ones(bs, seq_len, dtype=torch.int32, device=device)
completion_mask[::2, seq_len//2:] = 0  # 假设有一半的后半部分都是padding
save_kl = True

gold_logits = logits.detach().clone().float()
gold_logits.requires_grad_(True)
gold_ref_logp= deepcopy(ref_logp).float()


In [26]:
y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl=False)
y2.sum().backward()
# (y2.sum(-1) / completion_mask.sum(-1)).mean().backward()

(0,) torch.Size([8, 1024])


In [None]:
torch.cuda.empty_cache()
y1 = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)           # torch bf16
y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl=save_kl)     # triton
y3 = torch_grpo_loss(gold_logits, gold_ref_logp, input_ids, advantages, beta, completion_mask, save_kl) # torch fp32
if save_kl:
    y1, kl1 = y1
    y2, kl2 = y2
    y3, kl3 = y3
    print('='*50 + ' KL:')
    print((kl1-kl3).abs().max(), (kl1-kl3).abs().mean())  # kl, torch bf16 vs torch fp32,
    print((kl2-kl3).abs().max(), (kl2-kl3).abs().mean())  # kl, triton vs torch fp32,
dy = torch.randn_like(y1)
y1.backward(dy)
y2.backward(dy)
y3.backward(dy.float())
print('='*50 + ' Loss:')
print((y1-y3).abs().max(), (y1-y3).abs().mean())  # fwd, torch bf16 vs torch fp32
print((y2-y3).abs().max(), (y2-y3).abs().mean())  # fwd, triton vs torch fp32
print('='*50 + ' Grad:')
print((logits.grad - gold_logits.grad).abs().max(), (logits.grad - gold_logits.grad).abs().mean()) # bwd, torch bf16 vs torch fp32
print((copy_logits.grad - gold_logits.grad).abs().max(), (copy_logits.grad - gold_logits.grad).abs().mean()) # bwd, triton vs torch fp32
# 多尝试几次，使用triton计算的结果更精确， 误差更小一些

tensor(3.7727, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.0271, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0003, device='cuda:0', grad_fn=<MaxBackward1>) tensor(2.6195e-07, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.1397, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.0011, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.2875e-05, device='cuda:0', grad_fn=<MaxBackward1>) tensor(1.1167e-08, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.1768, device='cuda:0') tensor(5.9009e-08, device='cuda:0')
tensor(0.0132, device='cuda:0') tensor(8.0266e-09, device='cuda:0')


In [6]:
print(triton.testing.do_bench(lambda:torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)))
print(triton.testing.do_bench(lambda:triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask)))

2.588759660720825
Triton autotuning for function _grpo_loss_fwd finished after 12.51s; best config selected: BLOCK_SIZE: 8192, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;
0.514387845993042


In [8]:
# 重新生成数据，直接运行这个代码
y1 = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)
y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, inplace=False)
dy = torch.randn_like(y1)
print(triton.testing.do_bench(lambda:y1.backward(dy, retain_graph=True), grad_to_none=[logits]))
print(triton.testing.do_bench(lambda:y2.backward(dy, retain_graph=True), grad_to_none=[copy_logits]))

11.383981704711914
Triton autotuning for function _grpo_loss_bwd finished after 13.34s; best config selected: BLOCK_SIZE: 8192, num_warps: 32, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;
1.1460970640182495


# 显存

In [None]:
# 刷新运行
dtype = torch.bfloat16
device = 'cuda'
bs, seq_len, vocab_size = 8, 2048, 150000
logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype)
logits.requires_grad_(True)
advantages = torch.randn(bs, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device)
ref_logp = get_random_ref_log_probs(logits, input_ids)
beta = 0.04
iters = 500
dy = torch.randn(bs, seq_len, device=device, dtype=dtype)
factor = 4 if dtype == torch.float32 else 2
print('logits显存占用：',(bs * (seq_len+1) * vocab_size) / (1024)**3 * factor,"G")
time.sleep(3) # 初始化时观察显存，可以用nvitop
for i in tqdm(range(iters)):
    y = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)
    y.backward(dy)
    logits.grad = None
# 5.7G -> 24.6 G

logits显存占用： 4.579871892929077 G


100%|██████████| 500/500 [00:12<00:00, 38.89it/s]


In [None]:
# 刷新运行
dtype = torch.bfloat16
device = 'cuda'
bs, seq_len, vocab_size = 8, 2048, 150000
logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype)
logits.requires_grad_(True)
advantages = torch.randn(bs, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device)
ref_logp = get_random_ref_log_probs(logits, input_ids)
beta = 0.04
completion_mask = torch.ones(bs, seq_len, dtype=torch.int32, device=device)
completion_mask[::2, seq_len//2:] = 0 
iters = 2000
dy = torch.randn(bs, seq_len, device=device, dtype=dtype)
factor = 4 if dtype == torch.float32 else 2
print('logits显存占用：',(bs * (seq_len+1) * vocab_size) / (1024)**3 * factor,"G")
time.sleep(3) # 初始化时观察显存，可以用nvitop
pbar = tqdm(total=iters)
for i in range(iters):
    y = triton_grpo_loss(logits, ref_logp, input_ids, advantages, beta, completion_mask, inplace=True)
    y.backward(dy)
    logits.grad = None
    pbar.update(1)
# 5.7G -> 5.7G, 基本无任何额外开销

logits显存占用： 4.579871892929077 G


 99%|█████████▉| 1981/2000 [00:08<00:00, 258.84it/s]

In [6]:
ref_logp = get_random_ref_log_probs(logits, input_ids).clone()

In [7]:
torch.cuda.empty_cache()

In [6]:
5800 / 1024

5.6640625