diff --git a/torch_struct/test_distributions.py b/torch_struct/test_distributions.py index 39b3adec..53a6b749 100644 --- a/torch_struct/test_distributions.py +++ b/torch_struct/test_distributions.py @@ -46,8 +46,8 @@ def test_simple(data, seed): marginals = dist.marginals assert ((samples.mean(0) - marginals).abs() < 0.2).all() - kmax = dist.kmax(5) - count = dist.count + dist.kmax(5) + dist.count @given(data(), integers(min_value=1, max_value=20))