diff --git a/ivy/functional/backends/jax/statistical.py b/ivy/functional/backends/jax/statistical.py index bff5fca9a2d1e..d03e9411f2561 100644 --- a/ivy/functional/backends/jax/statistical.py +++ b/ivy/functional/backends/jax/statistical.py @@ -147,7 +147,7 @@ def var( # ------# -@with_unsupported_dtypes({"0.4.23 and below": "bfloat16"}, backend_version) +@with_unsupported_dtypes({"0.4.23 and below": ("bfloat16", "bool")}, backend_version) def cumprod( x: JaxArray, /, diff --git a/ivy/functional/backends/numpy/statistical.py b/ivy/functional/backends/numpy/statistical.py index cba78d4124988..bffc814767cc3 100644 --- a/ivy/functional/backends/numpy/statistical.py +++ b/ivy/functional/backends/numpy/statistical.py @@ -178,7 +178,7 @@ def var( # ------# -@with_unsupported_dtypes({"1.26.3 and below": ("bfloat16",)}, backend_version) +@with_unsupported_dtypes({"1.26.3 and below": ("bfloat16", "bool")}, backend_version) def cumprod( x: np.ndarray, /, diff --git a/ivy/functional/backends/tensorflow/statistical.py b/ivy/functional/backends/tensorflow/statistical.py index f111a2804e5a1..95db57526ef9b 100644 --- a/ivy/functional/backends/tensorflow/statistical.py +++ b/ivy/functional/backends/tensorflow/statistical.py @@ -178,7 +178,7 @@ def var( # ------# -@with_unsupported_dtypes({"2.15.0 and below": "bfloat16"}, backend_version) +@with_unsupported_dtypes({"2.15.0 and below": ("bfloat16", "bool")}, backend_version) def cumprod( x: Union[tf.Tensor, tf.Variable], /, diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index 76030c794dfb6..eb4666237ae0c 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -245,7 +245,7 @@ def var( # TODO: bfloat16 support is added in PyTorch 1.12.1 @with_unsupported_dtypes( { - "2.1.2 and below": ("uint8", "float16", "bfloat16"), + "2.1.2 and below": ("uint8", "float16", "bfloat16", "bool"), }, backend_version, ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py index 4d5066ded502d..7c6cd30e28a00 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py @@ -108,6 +108,7 @@ def _statistical_dtype_values(draw, *, function, min_value=None, max_value=None) dtype_x_axis_castable=_get_castable_dtype(), exclusive=st.booleans(), reverse=st.booleans(), + test_gradients=st.just(False), ) def test_cumprod( *,