In [2]:
import torch


def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
                                    response_mask: torch.Tensor):
    """
    Compute ReMax advantages and returns for outcome rewards.

    Intended for the outcome-reward setting (one scalar reward per response),
    following the ReMax paper: https://arxiv.org/abs/2310.10505

    The return at each position is the sum of the masked rewards from that
    position through the end of the last dimension. The advantage is that
    return minus the per-sample baseline, where the baseline is only
    subtracted at positions whose mask is non-zero. The whole computation
    runs under ``torch.no_grad()``, so the outputs carry no gradient.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """

    with torch.no_grad():
        # Drop rewards that sit on masked-out positions before accumulating.
        masked_rewards = token_level_rewards * response_mask

        # Suffix sum along the last dim: reverse, take a running sum, reverse back.
        reversed_rewards = torch.flip(masked_rewards, dims=[-1])
        reversed_running_sum = torch.cumsum(reversed_rewards, dim=-1)
        returns = torch.flip(reversed_running_sum, dims=[-1])

        # Add a trailing axis so the per-sample baseline broadcasts across the
        # response; the mask zeroes the baseline wherever the mask is zero.
        masked_baselines = torch.unsqueeze(reward_baselines, -1) * response_mask
        advantages = returns - masked_baselines

    return advantages, returns

# Toy batch: two responses of length three. The last token of the second
# response is masked out, so its reward (6.0) is ignored.
token_level_rewards = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
    ]
)
# One baseline value per response.
reward_baselines = torch.tensor([6.0, 15.0])
response_mask = torch.tensor(
    [
        [1, 1, 1],
        [1, 1, 0],
    ]
)

advantages, returns = compute_remax_outcome_advantage(
    token_level_rewards=token_level_rewards,
    reward_baselines=reward_baselines,
    response_mask=response_mask,
)
print("Advantages:", advantages)
print("Returns:", returns)


Advantages: tensor([[  0.,  -1.,  -3.],
        [ -6., -10.,   0.]])
Returns: tensor([[6., 5., 3.],
        [9., 5., 0.]])
