diff --git a/test/test_layers.py b/test/test_layers.py index 6b97ff8..5f295d9 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -29,6 +29,18 @@ def test_gaussian_encoding(device): decimal=5) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_gaussian_encoding_register_buffer(device): + check_cuda(device) + b = rff.functional.sample_b(1.0, (256, 2)).to(device) + layer = rff.layers.GaussianEncoding(b=b).to(device) + assert 'b' in layer.state_dict() + np.testing.assert_almost_equal( + layer.state_dict()['b'].cpu().numpy(), + b.cpu().numpy(), + decimal=5) + + @pytest.mark.parametrize('device', ['cpu', 'cuda']) def test_basic_encoding(device): check_cuda(device)