diff --git a/README.md b/README.md index e66aa14..45428fa 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ The scoring rules available in `scoringrules` include, but are not limited to, t ## Features - **Fast** computation of several probabilistic univariate and multivariate verification metrics -- **Multiple backends**: support for numpy (accelerated with numba), jax, pytorch and tensorflow +- **Multiple backends**: support for numpy (accelerated with numba), jax and pytorch - **Didactic approach** to probabilistic forecast evaluation through clear code and documentation ## Installation diff --git a/docs/index.md b/docs/index.md index 7296b79..bb8e3b1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -53,7 +53,7 @@ reference.md ## Features - **Fast** computation of several probabilistic univariate and multivariate verification metrics -- **Multiple backends**: support for numpy (accelerated with numba), jax, pytorch and tensorflow +- **Multiple backends**: support for numpy (accelerated with numba), jax and pytorch - **Didactic approach** to probabilistic forecast evaluation through clear code and documentation ## Installation diff --git a/pyproject.toml b/pyproject.toml index 1310235..4121699 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,6 @@ jax = [ torch = [ "torch>=2.4.1", ] -tensorflow = [ - "tensorflow>=2.17.0", -] [dependency-groups] dev = [ diff --git a/scoringrules/backend/registry.py b/scoringrules/backend/registry.py index 7c4aae5..a4cd50b 100644 --- a/scoringrules/backend/registry.py +++ b/scoringrules/backend/registry.py @@ -4,7 +4,6 @@ from .base import ArrayBackend from .jax import JaxBackend from .numpy import NumbaBackend, NumpyBackend -from .tensorflow import TensorflowBackend from .torch import TorchBackend if tp.TYPE_CHECKING: @@ -14,7 +13,6 @@ "jax": JaxBackend, "numpy": NumpyBackend, "torch": TorchBackend, - "tensorflow": TensorflowBackend, "numba": NumbaBackend, } @@ -46,7 +44,7 @@ def available_backends(self): return avail_backends def register_backend(self, backend_name: "Backend"): - """Register a backend. Currently supported backends are numpy, numba, jax, torch and tensorflow.""" + """Register a backend. Currently supported backends are numpy, numba, jax, and torch.""" if backend_name not in self.available_backends: raise BackendNotAvailable( f"The backend '{backend_name}' is not available. " diff --git a/scoringrules/backend/tensorflow.py b/scoringrules/backend/tensorflow.py deleted file mode 100644 index 9d58712..0000000 --- a/scoringrules/backend/tensorflow.py +++ /dev/null @@ -1,358 +0,0 @@ -import math -import typing as tp -from importlib import import_module - -from .base import ArrayBackend - -if tp.TYPE_CHECKING: - import tensorflow as tf - - Tensor = tf.Tensor - TensorLike = Tensor | bool | float | int - DTYPE = tf.float32 -else: - tf = None - DTYPE = None -Dtype = tp.TypeVar("Dtype") - - -class TensorflowBackend(ArrayBackend): - """Tensorflow backend.""" - - name = "tensorflow" - - def __init__(self) -> None: - global tf, DTYPE - if tf is None and not tp.TYPE_CHECKING: - tf = import_module("tensorflow") - DTYPE = tf.float32 - self.pi = tf.constant(math.pi, dtype=DTYPE) - - def asarray( - self, - obj: "TensorLike", - /, - *, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.convert_to_tensor(obj, dtype=dtype) - - def broadcast_arrays(self, *arrays: "Tensor") -> tuple["Tensor", ...]: - raise NotImplementedError - - def mean( - self, - x: "Tensor", - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - ) -> "Tensor": - return tf.math.reduce_mean(x, axis=axis, keepdims=keepdims) - - def std( - self, - x: "Tensor", - /, - *, - axis: int | tuple[int, ...] | None = None, - bias: bool = False, - keepdims: bool = False, - ) -> "Tensor": - n = x.shape.num_elements() if axis is None else x.shape[axis] - if not bias: - resc = self.sqrt(n / (n - 1)) - return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) * resc - - def quantile( - self, - x: "Tensor", - q: "TensorLike", - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - ) -> "Tensor": - raise NotImplementedError - - def max( - self, - x: "Tensor", - axis: int | tuple[int, ...] | None, - keepdims: bool = False, - ) -> "Tensor": - return tf.math.reduce_max(x, axis=axis, keepdims=keepdims) - - def moveaxis( - self, - x: "Tensor", - /, - source: tuple[int, ...] | int, - destination: tuple[int, ...] | int, - ) -> "Tensor": - return tf.experimental.numpy.moveaxis(x, source, destination) - - def sum( - self, - x: "Tensor", - /, - axis: int | tuple[int, ...] | None = None, - *, - keepdims: bool = False, - ) -> "Tensor": - return tf.math.reduce_sum(x, axis=axis, keepdims=keepdims) - - def cumsum( - self, - x: "Tensor", - /, - axis: int | tuple[int, ...] | None = None, - ) -> "Tensor": - return tf.math.cumsum(x, axis=axis) - - def unique_values(self, x: "Tensor", /) -> "Tensor": - return tf.unique(x) - - def concat( - self, - arrays: tuple["Tensor", ...] | list["Tensor"], - /, - *, - axis: int | None = 0, - ) -> "Tensor": - return tf.concat(arrays, axis=axis) - - # tf.expand_dims() doesn't support tuples in v2.15 - def expand_dims(self, x: "Tensor", /, axis: int | tuple[int] = 0) -> "Tensor": - if isinstance(axis, int): - return tf.expand_dims(x, axis=axis) - elif isinstance(axis, tuple | list): - out_ndim = len(axis) + x.ndim - axis = [a + out_ndim if a < 0 else a for a in axis] - shape_it = iter(x.shape) - shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] - return tf.reshape(x, shape=shape) - - def squeeze( - self, x: "Tensor", /, *, axis: int | tuple[int, ...] | None = None - ) -> "Tensor": - return tf.squeeze(x, axis=axis) - - def stack( - self, arrays: tuple["Tensor", ...] | list["Tensor"], /, *, axis: int = 0 - ) -> "Tensor": - return tf.stack(arrays, axis=axis) - - def arange( - self, - start: int | float, - /, - stop: int | float | None = None, - step: int | float = 1, - *, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.range(start, stop, step, dtype=dtype) - - def zeros( - self, - shape: int | tuple[int, ...], - *, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.zeros(shape, dtype=dtype) - - def ones( - self, - shape: int | tuple[int, ...], - *, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.ones(shape, dtype=dtype) - - def abs(self, x: "Tensor") -> "Tensor": - return tf.math.abs(x) - - def exp(self, x: "Tensor") -> "Tensor": - return tf.math.exp(x) - - def isnan( - self, - x: "Tensor", - *, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.cast(tf.math.is_nan(x), dtype=dtype) - - def log(self, x: "Tensor") -> "Tensor": - return tf.math.log(x) - - def sqrt(self, x: "Tensor") -> "Tensor": - return tf.math.sqrt(x) - - def any( - self, - x: "Tensor", - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.cast(tf.math.reduce_any(x, axis=axis, keepdims=keepdims), dtype=dtype) - - def all( - self, - x: "Tensor", - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - dtype: Dtype | None = None, - ) -> "Tensor": - if dtype is None: - dtype = DTYPE - return tf.cast(tf.math.reduce_all(x, axis=axis, keepdims=keepdims), dtype=dtype) - - def argsort( - self, - x: "Tensor", - /, - *, - axis: int = -1, - descending: bool = False, - ) -> "Tensor": - direction = "DESCENDING" if descending else "ASCENDING" - return tf.argsort(x, axis=axis, direction=direction) - - def sort( - self, - x: "Tensor", - /, - *, - axis: int = -1, - descending: bool = False, - ) -> "Tensor": - direction = "DESCENDING" if descending else "ASCENDING" - return tf.sort(x, axis=axis, direction=direction) - - def norm(self, x: "Tensor", axis: int | tuple[int, ...] | None = None) -> "Tensor": - return tf.norm(x, axis=axis) - - def erf(self, x: "Tensor") -> "Tensor": - return tf.math.erf(x) - - def apply_along_axis( - self, func1d: tp.Callable[["Tensor"], "Tensor"], x: "Tensor", axis: int - ): - try: - x_shape = x.shape.as_list() - flat = tf.map_fn(func1d, tf.reshape(x, [-1, x_shape.pop(axis)])) - return tf.reshape(flat, x_shape) - except Exception: - return tf.stack( - [func1d(x_i) for x_i in tf.unstack(x, axis=axis)], axis=axis - ) - - def floor(self, x: "Tensor") -> "Tensor": - return tf.math.floor(x) - - def minimum(self, x: "Tensor", y: "TensorLike") -> "Tensor": - return tf.math.minimum(x, y) - - def maximum(self, x: "Tensor", y: "TensorLike") -> "Tensor": - return tf.math.maximum(x, y) - - def beta(self, x: "Tensor", y: "Tensor") -> "Tensor": - return tf.math.exp( - tf.math.lgamma(x) + tf.math.lgamma(y) - tf.math.lgamma(x + y) - ) - - def betainc(self, x: "Tensor", y: "Tensor", z: "Tensor") -> "Tensor": - return tf.math.betainc(x, y, z) - - def mbessel0(self, x: "Tensor") -> "Tensor": - return tf.math.bessel_i0(x) - - def mbessel1(self, x: "Tensor") -> "Tensor": - return tf.math.bessel_i1(x) - - def gamma(self, x: "Tensor") -> "Tensor": - return tf.math.exp(tf.math.lgamma(x)) - - def gammainc(self, x: "Tensor", y: "Tensor") -> "Tensor": - return tf.math.igamma(x, y) - - def gammalinc(self, x: "Tensor", y: "Tensor") -> "Tensor": - return tf.math.igamma(x, y) * tf.math.exp(tf.math.lgamma(x)) - - def gammauinc(self, x: "Tensor", y: "Tensor") -> "Tensor": - return tf.math.igammac(x, y) * tf.math.exp(tf.math.lgamma(x)) - - def factorial(self, n: "TensorLike") -> "TensorLike": - return tf.math.exp(tf.math.lgamma(n + 1)) - - def hypergeometric( - self, a: "Tensor", b: "Tensor", c: "Tensor", z: "Tensor" - ) -> "Tensor": - raise NotImplementedError - - def comb(self, n: "Tensor", k: "Tensor") -> "Tensor": - return self.factorial(n) // (self.factorial(k) * self.factorial(n - k)) - - def expi(self, x: "Tensor") -> "Tensor": - return tf.math.special.expint(x) - - def where(self, condition: "Tensor", x1: "Tensor", x2: "Tensor") -> "Tensor": - return tf.where(condition, x1, x2) - - def size(self, x: "Tensor") -> int: - return x.shape.num_elements() - - def indices(self, dimensions: tuple) -> "Tensor": - ranges = [self.arange(s) for s in dimensions] - index_grids = tf.meshgrid(*ranges, indexing="ij") - indices = tf.stack(index_grids) - return indices - - def gather(self, x: "Tensor", ind: "Tensor", axis: int) -> "Tensor": - d = len(x.shape) - return tf.gather(x, ind, axis=axis, batch_dims=d) - - def roll(self, x: "Tensor", shift: int = 1, axis: int = -1) -> "Tensor": - return tf.roll(x, shift=shift, axis=axis) - - def inv(self, x: "Tensor") -> "Tensor": - return tf.linalg.inv(x) - - def cov(self, x: "Tensor", rowvar: bool = True, bias: bool = False) -> "Tensor": - if not rowvar: - x = tf.transpose(x) - x = x - tf.reduce_mean(x, axis=1, keepdims=True) - correction = tf.cast(tf.shape(x)[1], x.dtype) - 1.0 - if bias: - correction += 1.0 - return tf.matmul(x, x, transpose_b=True) / correction - - def det(self, x: "Tensor") -> "Tensor": - return tf.linalg.det(x) - - def reshape(self, x: "Tensor", shape: int | tuple[int, ...]) -> "Tensor": - return tf.reshape(x, shape) - - -if __name__ == "__main__": - B = TensorflowBackend() - out = B.mean(tf.ones(10)) diff --git a/scoringrules/core/typing.py b/scoringrules/core/typing.py index c87c6eb..a0b897d 100644 --- a/scoringrules/core/typing.py +++ b/scoringrules/core/typing.py @@ -3,11 +3,10 @@ if tp.TYPE_CHECKING: from jax import Array as JaxArray from numpy.typing import NDArray - from tensorflow import Tensor as tensorflowTensor from torch import Tensor as torchTensor - _array = NDArray | JaxArray | torchTensor | tensorflowTensor + _array = NDArray | JaxArray | torchTensor Array = tp.TypeVar("Array", bound=_array) ArrayLike = tp.TypeVar("ArrayLike", bound=_array | float | int) - Backend = tp.Literal["numpy", "numba", "jax", "torch", "tensorflow"] | None + Backend = tp.Literal["numpy", "numba", "jax", "torch"] | None diff --git a/tests/test_crps.py b/tests/test_crps.py index 439a3cc..58c46ee 100644 --- a/tests/test_crps.py +++ b/tests/test_crps.py @@ -463,8 +463,8 @@ def test_crps_cnormal(backend): def test_crps_gtct(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, location, scale, lower, upper, lmass, umass = ( 0.9, 20.1, @@ -493,8 +493,8 @@ def test_crps_gtct(backend): def test_crps_tt(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, location, scale, lower, upper = -1.0, 2.9, 3.1, 4.2, 1.5, 17.3 expected = 5.084272 @@ -508,8 +508,8 @@ def test_crps_tt(backend): def test_crps_ct(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, location, scale, lower, upper = 1.8, 5.4, 0.4, 1.1, 0.0, 2.0 expected = 0.8028996 @@ -623,8 +623,8 @@ def test_crps_mixnorm(backend): def test_crps_negbinom(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") # test exceptions with pytest.raises(ValueError): @@ -703,8 +703,8 @@ def test_crps_poisson(backend): def test_crps_t(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, mu, sigma = 11.1, 5.2, 13.8, 2.3 expected = 1.658226 diff --git a/tests/test_logs.py b/tests/test_logs.py index 53857ea..4dd678a 100644 --- a/tests/test_logs.py +++ b/tests/test_logs.py @@ -439,8 +439,8 @@ def test_poisson(backend): def test_t(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, mu, sigma = 11.1, 5.2, 13.8, 2.3 res = sr.logs_t(obs, df, mu, sigma, backend=backend) @@ -483,8 +483,8 @@ def test_tnormal(backend): def test_tt(backend): - if backend in ["jax", "torch", "tensorflow"]: - pytest.skip("Not implemented in jax, torch or tensorflow backends") + if backend in ["jax", "torch"]: + pytest.skip("Not implemented in jax, torch backends") obs, df, location, scale, lower, upper = 1.9, 2.9, 3.1, 4.2, 1.5, 17.3 res = sr.logs_tt(obs, df, location, scale, lower, upper, backend=backend)