Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.util.ops import selective_log_softmax
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer
Expand Down Expand Up @@ -137,16 +137,6 @@ def collate(batches: list[list[Episode]]):
return inputs, targets


def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
    """Return the log-probabilities of ``input_ids`` under ``logits``.

    The logits are sliced along dim 1 to the window that starts one position
    before the trailing ``input_ids.shape[1]`` positions and excludes the
    final position. The window is moved to the device of ``input_ids`` and
    divided by ``temperature`` before ``selective_log_softmax`` is applied.
    """
    num_targets = input_ids.shape[1]
    # Same start index as (logits.shape[1] - num_targets) - 1 in one step.
    window_start = logits.shape[1] - num_targets - 1
    window = logits[:, window_start:-1]
    window = window.to(input_ids.device)
    return selective_log_softmax(window / temperature, input_ids)


def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
Expand All @@ -155,7 +145,12 @@ def simple_grpo_loss(
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
logprobs = compute_logprobs(logits, response)
"""
Example GRPO Loss Function for RLTrainer
"""
logprobs: torch.Tensor = compute_logprobs(logits, response)

# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
Expand Down
23 changes: 23 additions & 0 deletions src/forge/util/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,26 @@ def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Te
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps


def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
    """
    Computes the log probabilities of the input tokens given the model logits and temperature.

    Args:
        logits (`torch.Tensor`):
            The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
        input_ids (`torch.Tensor`):
            The input token ids of shape `(batch_size, target_sequence_length)`.
        temperature (`float`, *optional*, defaults to 1.0):
            The temperature value for scaling logits before computing log probabilities.

    Returns:
        `torch.Tensor`:
            The log probabilities of `input_ids`, of shape
            `(batch_size, target_sequence_length)`.

    Raises:
        ValueError: If `logits` does not have at least
            `target_sequence_length + 1` positions along dim 1, so the
            logits cannot be aligned with `input_ids`.
    """
    target_len = input_ids.size(1)
    seq_len = logits.size(1)

    # One position is always dropped (see below), so at most seq_len - 1
    # targets can be scored. Without this check the negative slice start
    # would be clamped silently and the mismatch would only surface later as
    # an obscure shape error inside the gather.
    if target_len > max(seq_len - 1, 0):
        raise ValueError(
            f"logits has {seq_len} positions along dim 1 but input_ids has "
            f"{target_len} tokens; at least {target_len + 1} positions are "
            "required to align logits with input_ids."
        )

    # Ignore the last token from logits because it predicts the next token (-1)
    # And align logits with the input tokens length.
    logits = logits[:, -target_len - 1 : -1, :].to(input_ids.device)
    scaled_logits = logits / temperature
    logprobs = selective_log_softmax(scaled_logits, input_ids)
    return logprobs
111 changes: 111 additions & 0 deletions tests/unit_tests/util/test_compute_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
import torch.nn.functional as F
from forge.util.ops import compute_logprobs


def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
    """Reference helper: log-softmax over the last dim, then pick each id's entry."""
    index = input_ids.unsqueeze(-1)
    all_log_probs = F.log_softmax(logits, dim=-1)
    picked = all_log_probs.gather(dim=-1, index=index)
    return picked.squeeze(-1)


class TestComputeLogprobs:
    """Unit tests for ``compute_logprobs``.

    Expected values come from ``_textbook_log_softmax`` applied to the logits
    with the last position removed. Random inputs are drawn from a seeded
    generator so that failures are reproducible.
    """

    @pytest.mark.timeout(10)
    def test_single_batch_item(self):
        """Test with single batch item."""
        # Shape: (1, 2, 3)
        logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
        # Shape: (1, 1)
        input_ids = torch.tensor([[1]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation
        expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (1, 1)

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # Shape: (1, 3, 3)
        logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
        # Shape: (1, 2)
        input_ids = torch.tensor([[2, 0]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation
        expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (1, 2)

    @pytest.mark.timeout(10)
    def test_multi_batch(self):
        """Test with multiple batch items."""
        # Shape: (2, 2, 3)
        logits = torch.tensor(
            [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]]
        )
        # Shape: (2, 1)
        input_ids = torch.tensor([[1], [2]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation
        expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (2, 1)

    @pytest.mark.timeout(10)
    def test_temperature(self):
        """Test with different temperature values."""
        batch_size, seq_len, vocab_size = 2, 4, 6
        # Seeded generator keeps the random inputs identical across runs.
        generator = torch.Generator().manual_seed(0)
        logits = torch.randn(batch_size, seq_len, vocab_size, generator=generator)
        input_ids = torch.randint(
            0, vocab_size, (batch_size, seq_len - 1), generator=generator
        )

        # Manual calculation with temperature scaling
        def _manual(temperature: float):
            expected_logits = logits[:, 0:-1] / temperature
            return _textbook_log_softmax(expected_logits, input_ids)

        temperatures = [1.0, 2.0, 4.5]
        for temperature in temperatures:
            result = compute_logprobs(logits, input_ids, temperature=temperature)
            expected = _manual(temperature)
            assert torch.allclose(result, expected, atol=1e-5)
            assert result.shape == input_ids.shape

    @pytest.mark.timeout(10)
    def test_edge_cases(self):
        """Test edge cases."""
        # Test with very large values (numerical stability)
        logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]])
        input_ids = torch.tensor([[0]])
        result = compute_logprobs(logits, input_ids)
        # Should not be NaN or inf
        assert torch.isfinite(result).all()
        # And should still agree with the reference computation.
        expected = _textbook_log_softmax(logits[:, :-1], input_ids)
        assert torch.allclose(result, expected, atol=1e-5)

        # Test with very small values
        logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]])
        input_ids = torch.tensor([[1]])
        result = compute_logprobs(logits, input_ids)
        # Should not be NaN or inf
        assert torch.isfinite(result).all()
        # And should still agree with the reference computation.
        expected = _textbook_log_softmax(logits[:, :-1], input_ids)
        assert torch.allclose(result, expected, atol=1e-5)

    @pytest.mark.timeout(10)
    def test_compute_logprobs_empty_response(self):
        """Test logprobs computation with empty response."""
        batch_size, seq_len, vocab_size = 1, 5, 1000
        generator = torch.Generator().manual_seed(0)
        logits = torch.randn(batch_size, seq_len, vocab_size, generator=generator)
        # Token ids are integers; torch.tensor([[]]) would create a float tensor.
        input_ids = torch.empty((batch_size, 0), dtype=torch.long)

        result = compute_logprobs(logits, input_ids)
        assert result.shape == (batch_size, 0)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from forge.util.ops import selective_log_softmax


class TestOps:
class TestSelectiveLogSoftmax:
@pytest.mark.timeout(10)
def test_basic_2d(self):
"""Test basic 2D case."""
Expand Down
Loading