Skip to content

Commit

Permalink
test: add test for state_dict and b
Browse files Browse the repository at this point in the history
  • Loading branch information
jmclong committed Oct 22, 2022
1 parent ea9b07d commit 836960b
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ def test_gaussian_encoding(device):
decimal=5)


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_gaussian_encoding_register_buffer(device):
check_cuda(device)
b = rff.functional.sample_b(1.0, (256, 2)).to(device)
layer = rff.layers.GaussianEncoding(b=b).to(device)
assert 'b' in layer.state_dict()
np.testing.assert_almost_equal(
layer.state_dict()['b'].cpu().numpy(),
b.cpu().numpy(),
decimal=5)


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_basic_encoding(device):
check_cuda(device)
Expand Down

0 comments on commit 836960b

Please sign in to comment.