Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions tests/unit_tests/losses/test_grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def sample_data(self):
return logprobs, ref_logprobs, advantages, padding_mask

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_forward_basic(self, loss_fn, sample_data):
"""Test basic forward pass."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
Expand All @@ -48,7 +47,6 @@ def test_forward_basic(self, loss_fn, sample_data):
assert not torch.isnan(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_output_shape(self, loss_fn):
"""Test output shape for different input sizes."""
for batch_size in [1, 3, 8]:
Expand All @@ -62,7 +60,6 @@ def test_output_shape(self, loss_fn):
assert loss.shape == torch.Size([])

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_gradient_flow(self, loss_fn, sample_data):
"""Test that gradients flow through logprobs."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
Expand All @@ -76,7 +73,6 @@ def test_gradient_flow(self, loss_fn, sample_data):
assert torch.isfinite(logprobs.grad).all()

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data):
"""Test that gradients don't flow to reference logprobs."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
Expand All @@ -89,7 +85,6 @@ def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data):
assert ref_logprobs.grad is not None

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_padding_mask_effect(self, loss_fn):
"""Test that padding mask correctly ignores padded tokens."""
batch_size, seq_len = 2, 4
Expand All @@ -111,7 +106,6 @@ def test_padding_mask_effect(self, loss_fn):
assert not torch.allclose(loss_full, loss_partial)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_beta_parameter_effect(self, sample_data):
"""Test that different beta values produce different losses."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
Expand All @@ -128,7 +122,6 @@ def test_beta_parameter_effect(self, sample_data):
assert not torch.allclose(loss_1, loss_2, atol=1e-6)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_zero_advantages(self, loss_fn):
"""Test behavior with zero advantages."""
batch_size, seq_len = 2, 4
Expand All @@ -144,7 +137,6 @@ def test_zero_advantages(self, loss_fn):
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_identical_policies(self, loss_fn):
"""Test behavior when current and reference policies are identical."""
batch_size, seq_len = 2, 4
Expand All @@ -160,7 +152,6 @@ def test_identical_policies(self, loss_fn):
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_extreme_values(self, loss_fn):
"""Test with extreme but valid values."""
batch_size, seq_len = 2, 3
Expand All @@ -179,7 +170,6 @@ def test_extreme_values(self, loss_fn):
assert not torch.isnan(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_numerical_stability(self, loss_fn):
"""Test numerical stability with edge cases."""
batch_size, seq_len = 1, 2
Expand All @@ -195,7 +185,6 @@ def test_numerical_stability(self, loss_fn):
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_all_masked_sequence(self, loss_fn):
"""Test behavior when entire sequence is masked."""
batch_size, seq_len = 1, 3
Expand All @@ -211,7 +200,6 @@ def test_all_masked_sequence(self, loss_fn):
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
@pytest.mark.asyncio
def test_mathematical_correctness(self, loss_fn):
"""Test mathematical correctness with simpler verification."""
# Test with known simple case
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,7 @@ async def test_broadcast_fanout_vs_route():
# Router Tests


@pytest.mark.asyncio
async def test_session_router_with_round_robin_fallback():
def test_session_router_with_round_robin_fallback():
"""Switch fallback router to round-robin and verify assignment order."""
# Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas
replicas = [make_replica(0, load=0), make_replica(1, load=5)]
Expand Down
Loading