Use torch.testing in unit tests
Summary: Regular asserts like `self.assertTrue(torch.equal(actions[i], action))` don't print an informative message when they fail. Replace them with torch.testing asserts, which report useful information for debugging (e.g., the number of mismatched elements and the greatest difference).

Reviewed By: rodrigodesalvobraz

Differential Revision: D56644803

fbshipit-source-id: 82916ed5e63193c425f0bdf491a809d88846374c
Alex Nikulkov authored and facebook-github-bot committed Apr 27, 2024
1 parent b9ef4ed commit 0dfd91f
Showing 10 changed files with 164 additions and 163 deletions.
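For context, a minimal sketch (not part of this commit; the tensors below are made up) of the difference the summary describes between a plain boolean assert and torch.testing.assert_close:

# Illustration only: why torch.testing.assert_close fails more informatively
# than asserting on torch.equal.
import torch
import torch.testing as tt

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.1])

# Old style: on failure, unittest only reports "False is not true",
# with no hint about which elements differ or by how much.
#     self.assertTrue(torch.equal(actual, expected))

# New style: rtol=0.0 and atol=0.0 request exact equality (both must be
# passed together). On failure the error lists the mismatched element count
# and the greatest absolute/relative difference, roughly:
#     Mismatched elements: 1 / 3 (33.3%)
#     Greatest absolute difference: 0.1 at index (2,)
tt.assert_close(actual, expected, rtol=0.0, atol=0.0)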
3 changes: 2 additions & 1 deletion test/unit/with_pytorch/test_discrete_action_space.py
@@ -10,6 +10,7 @@
 import unittest
 
 import torch
+import torch.testing as tt
 from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
 
 
@@ -19,4 +20,4 @@ def test_iter(self) -> None:
         actions = [torch.randn(4) for _ in range(5)]
         action_space = DiscreteActionSpace(actions=actions)
         for i, action in enumerate(action_space):
-            self.assertTrue(torch.equal(actions[i], action))
+            tt.assert_close(actions[i], action, rtol=0.0, atol=0.0)
31 changes: 14 additions & 17 deletions test/unit/with_pytorch/test_disjoint_bandits.py
@@ -11,6 +11,7 @@
 import unittest
 
 import torch
+import torch.testing as tt
 from parameterized import parameterized_class
 from pearl.policy_learners.contextual_bandits.disjoint_bandit import (
     DisjointBanditContainer,
@@ -73,14 +74,13 @@ def test_learn_batch(self) -> None:
         for i, action in enumerate(self.batch.action):
             action = action.item()
             # check if linear regression works
-            self.assertTrue(
-                torch.allclose(
-                    self.policy_learner._linear_regressions[action](
-                        self.batch.state[i : i + 1]
-                    ),
-                    self.batch.reward[i : i + 1],
-                    atol=1e-1,
-                )
+            tt.assert_close(
+                self.policy_learner._linear_regressions[action](
+                    self.batch.state[i : i + 1]
+                ),
+                self.batch.reward[i : i + 1],
+                atol=1e-1,
+                rtol=0.0,
             )
 
     def test_ucb_act(self) -> None:
@@ -245,14 +245,11 @@ def test_learn_batch(self) -> None:
         for i, action in enumerate(self.batch.action):
             action = action.item()
             # check if each arm model works
-            self.assertTrue(
-                torch.allclose(
-                    policy_learner._arm_bandits[action].model(
-                        self.batch.state[i : i + 1]
-                    ),
-                    self.batch.reward[i : i + 1],
-                    atol=1e-1,
-                )
+            tt.assert_close(
+                policy_learner._arm_bandits[action].model(self.batch.state[i : i + 1]),
+                self.batch.reward[i : i + 1],
+                atol=1e-1,
+                rtol=0.0,
             )
 
     def test_ucb_act(self) -> None:
@@ -404,7 +401,7 @@ def test_get_scores(self) -> None:
             sigmas = model.calculate_sigma(features)
             expected_scores.append(mus + alpha * sigmas)
         expected_scores = torch.cat(expected_scores, dim=1)
-        self.assertTrue(torch.allclose(scores, expected_scores, atol=1e-1))
+        tt.assert_close(scores, expected_scores, atol=1e-1, rtol=0.0)
 
     def test_learn_batch_arm_subset(self) -> None:
         # test that learn_batch still works when the batch has a subset of arms
111 changes: 55 additions & 56 deletions test/unit/with_pytorch/test_dynamic_action_space.py
@@ -10,6 +10,7 @@
 import unittest
 
 import torch
+import torch.testing as tt
 from pearl.action_representation_modules.one_hot_action_representation_module import (
     OneHotActionTensorRepresentationModule,
 )
@@ -67,35 +68,33 @@ def test_basic(self) -> None:
         current_available_actions = batch.curr_available_actions
         current_available_actions_mask = batch.curr_unavailable_actions_mask
         self.assertIsNotNone(current_available_actions)
-        self.assertTrue(
-            torch.equal(
-                current_available_actions,
-                torch.tensor([[[0.0], [2.0], [4.0], [0.0], [0.0]]]),
-            )
+        tt.assert_close(
+            current_available_actions,
+            torch.tensor([[[0.0], [2.0], [4.0], [0.0], [0.0]]]),
+            rtol=0.0,
+            atol=0.0,
         )
         self.assertIsNotNone(current_available_actions_mask)
-        self.assertTrue(
-            torch.equal(
-                current_available_actions_mask,
-                torch.tensor([[False, False, False, True, True]]),
-            )
+        tt.assert_close(
+            current_available_actions_mask,
+            torch.tensor([[False, False, False, True, True]]),
+            rtol=0.0,
+            atol=0.0,
         )
 
         next_available_actions = batch.next_available_actions
         next_unavailable_actions_mask = batch.next_unavailable_actions_mask
         self.assertIsNotNone(next_available_actions)
-        self.assertTrue(
-            torch.equal(
-                next_available_actions,
-                torch.tensor([[[0.0], [3.0], [0.0], [0.0], [0.0]]]),
-            )
+        tt.assert_close(
+            next_available_actions,
+            torch.tensor([[[0.0], [3.0], [0.0], [0.0], [0.0]]]),
+            rtol=0.0,
+            atol=0.0,
         )
         self.assertIsNotNone(next_unavailable_actions_mask)
-        self.assertTrue(
-            torch.equal(
-                next_unavailable_actions_mask,
-                torch.tensor([[False, False, True, True, True]]),
-            )
+        tt.assert_close(
+            next_unavailable_actions_mask,
+            torch.tensor([[False, False, True, True, True]]),
         )
 
         policy_learner = DeepQLearning(
@@ -109,53 +108,53 @@ def test_basic(self) -> None:
         current_available_actions = batch.curr_available_actions
         current_unavailable_actions_mask = batch.curr_unavailable_actions_mask
         self.assertIsNotNone(current_available_actions)
-        self.assertTrue(
-            torch.equal(
-                current_available_actions,
-                torch.tensor(
+        tt.assert_close(
+            current_available_actions,
+            torch.tensor(
+                [
                     [
-                        [
-                            [1, 0, 0, 0, 0],
-                            [0, 0, 1, 0, 0],
-                            [0, 0, 0, 0, 1],
-                            [1, 0, 0, 0, 0],
-                            [1, 0, 0, 0, 0],
-                        ]
+                        [1, 0, 0, 0, 0],
+                        [0, 0, 1, 0, 0],
+                        [0, 0, 0, 0, 1],
+                        [1, 0, 0, 0, 0],
+                        [1, 0, 0, 0, 0],
                     ]
-                ),
-            )
+                ]
+            ),
+            rtol=0.0,
+            atol=0.0,
         )
         self.assertIsNotNone(current_unavailable_actions_mask)
-        self.assertTrue(
-            torch.equal(
-                current_unavailable_actions_mask,
-                torch.tensor([[False, False, False, True, True]]),
-            )
+        tt.assert_close(
+            current_unavailable_actions_mask,
+            torch.tensor([[False, False, False, True, True]]),
+            rtol=0.0,
+            atol=0.0,
        )
 
         next_available_actions = batch.next_available_actions
         next_unavailable_actions_mask = batch.next_unavailable_actions_mask
         self.assertIsNotNone(next_available_actions)
-        self.assertTrue(
-            torch.equal(
-                next_available_actions,
-                torch.tensor(
+        tt.assert_close(
+            next_available_actions,
+            torch.tensor(
+                [
                     [
-                        [
-                            [1, 0, 0, 0, 0],
-                            [0, 0, 0, 1, 0],
-                            [1, 0, 0, 0, 0],
-                            [1, 0, 0, 0, 0],
-                            [1, 0, 0, 0, 0],
-                        ]
+                        [1, 0, 0, 0, 0],
+                        [0, 0, 0, 1, 0],
+                        [1, 0, 0, 0, 0],
+                        [1, 0, 0, 0, 0],
+                        [1, 0, 0, 0, 0],
                     ]
-                ),
-            )
+                ]
+            ),
+            rtol=0.0,
+            atol=0.0,
         )
         self.assertIsNotNone(next_unavailable_actions_mask)
-        self.assertTrue(
-            torch.equal(
-                next_unavailable_actions_mask,
-                torch.tensor([[False, False, True, True, True]]),
-            )
+        tt.assert_close(
+            next_unavailable_actions_mask,
+            torch.tensor([[False, False, True, True, True]]),
+            rtol=0.0,
+            atol=0.0,
         )
10 changes: 6 additions & 4 deletions test/unit/with_pytorch/test_ensembles.py
@@ -9,8 +9,9 @@
 
 import unittest
 
-import numpy.testing as npt
 import torch
+
+import torch.testing as tt
 from pearl.neural_networks.common.epistemic_neural_networks import Ensemble
 from pearl.neural_networks.common.utils import ensemble_forward
 from torch import optim
@@ -49,10 +50,11 @@ def test_ensemble_values(self) -> None:
         for_loop_values = ensemble_forward(self.network.models, x, use_for_loop=True)
         vectorized_values = ensemble_forward(self.network.models, x, use_for_loop=False)
         self.assertEqual(for_loop_values.shape, vectorized_values.shape)
-        npt.assert_allclose(
-            for_loop_values.detach().numpy(),
-            vectorized_values.detach().numpy(),
+        tt.assert_close(
+            for_loop_values,
+            vectorized_values,
             atol=1e-5,
+            rtol=0.0,
         )
 
     def test_ensemble_optimization(self) -> None:
45 changes: 24 additions & 21 deletions test/unit/with_pytorch/test_fifo_buffer.py
@@ -10,6 +10,7 @@
 import unittest
 
 import torch
+import torch.testing as tt
 
 from pearl.replay_buffers.sequential_decision_making.fifo_on_policy_replay_buffer import (
     FIFOOnPolicyReplayBuffer,
@@ -65,32 +66,34 @@ def test_on_poliy_buffer_sarsa_match(self) -> None:
         )
         # expect S0 A0 R0 S1 A1 returned from sample
         batch = replay_buffer.sample(1)
-        self.assertTrue(
-            torch.equal(
-                batch.state,
-                self.states[0].view(1, -1),
-            )
+        tt.assert_close(
+            batch.state,
+            self.states[0].view(1, -1),
+            rtol=0.0,
+            atol=0.0,
         )
-        self.assertTrue(
-            torch.equal(
-                batch.action,
-                torch.tensor([self.actions[0]]),
-            )
+        tt.assert_close(
+            batch.action,
+            torch.tensor([self.actions[0]]),
+            rtol=0.0,
+            atol=0.0,
         )
-        self.assertTrue(torch.equal(batch.reward, torch.tensor([self.rewards[0]])))
+        tt.assert_close(
+            batch.reward, torch.tensor([self.rewards[0]]), rtol=0.0, atol=0.0
+        )
         assert (batch_next_state := batch.next_state) is not None
-        self.assertTrue(
-            torch.equal(
-                batch_next_state,
-                self.next_states[0].view(1, -1),
-            )
+        tt.assert_close(
+            batch_next_state,
+            self.next_states[0].view(1, -1),
+            rtol=0.0,
+            atol=0.0,
         )
         assert (batch_next_action := batch.next_action) is not None
-        self.assertTrue(
-            torch.equal(
-                batch_next_action,
-                torch.tensor([self.actions[1]]),
-            )
+        tt.assert_close(
+            batch_next_action,
+            torch.tensor([self.actions[1]]),
+            rtol=0.0,
+            atol=0.0,
         )
 
     def test_on_poliy_buffer_ternimal_push(self) -> None:
test/unit/with_pytorch/test_hindsight_experience_replay_buffer.py
@@ -11,6 +11,7 @@
 from typing import Dict
 
 import torch
+import torch.testing as tt
 
 from pearl.replay_buffers.sequential_decision_making.hindsight_experience_replay_buffer import (
     HindsightExperienceReplayBuffer,
@@ -89,6 +90,6 @@ def reward_fn(state: torch.Tensor, action: torch.Tensor) -> int:
         assert (batch_state := batch.state) is not None
         assert (batch_next_state := batch.next_state) is not None
         for i in range(2 * len(states) - 2):
-            self.assertTrue(
-                torch.all(torch.eq(batch_state[i][-2:], batch_next_state[i][-2:]))
+            tt.assert_close(
+                batch_state[i][-2:], batch_next_state[i][-2:], rtol=0.0, atol=0.0
             )
41 changes: 17 additions & 24 deletions test/unit/with_pytorch/test_linear_bandits.py
@@ -11,6 +11,7 @@
 import unittest
 
 import torch
+import torch.testing as tt
 from pearl.neural_networks.contextual_bandit.linear_regression import LinearRegression
 from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit
 from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import (  # noqa: E501
@@ -52,24 +53,20 @@ def setUp(self) -> None:
     def test_learn(self) -> None:
         batch = self.batch
         # a single input
-        self.assertTrue(
-            torch.allclose(
-                self.policy_learner.model(
-                    torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0),
-                ),
-                batch.reward[0:1],
-                atol=1e-4,
-            )
+        tt.assert_close(
+            self.policy_learner.model(
+                torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0),
+            ),
+            batch.reward[0:1],
+            atol=1e-3,
+            rtol=0.0,
         )
         # a batch input
-        self.assertTrue(
-            torch.allclose(
-                self.policy_learner.model(
-                    torch.cat([batch.state, batch.action], dim=1)
-                ),
-                batch.reward,
-                atol=1e-4,
-            )
+        tt.assert_close(
+            self.policy_learner.model(torch.cat([batch.state, batch.action], dim=1)),
+            batch.reward,
+            atol=1e-3,
+            rtol=0.0,
        )
 
     def test_linear_ucb_scores(self) -> None:
@@ -166,14 +163,10 @@ def test_linear_ucb_sigma(self) -> None:
         )
 
         # the 2nd arm's sigma is sqrt(10) times 1st arm's sigma
-        sigma_ratio = (sigma[-1] / sigma[0]).clone().detach()
-        self.assertTrue(
-            torch.allclose(
-                sigma_ratio,
-                torch.tensor(10.0**0.5),  # the 1st arm occured 10 times than 2nd arm
-                rtol=0.01,
-            )
-        )
+        sigma_ratio = (sigma[-1] / sigma[0]).detach().item()
+        self.assertAlmostEqual(
+            sigma_ratio, 10.0**0.5, delta=0.01
+        )  # the 1st arm occured 10 times than 2nd arm
 
     def test_linear_thompson_sampling_act(self) -> None:
         """
