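These functions emit a RuntimeWarning when np.expm1/np.exp overflows, even though the overflowed values are already discarded by np.where. The warning appears because np.where evaluates both branches for every element before selecting between them.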
def inverse_softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    """Numerically stabilized inverse softplus function."""
    return np.where(beta * x > threshold, x, np.log(beta * np.expm1(x)) / beta)
...
def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    """Numerically stabilized softplus function."""
    return np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta)
Expected behavior: no warning should be raised when the overflow occurs only in the branch that np.where discards.
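One possible way to get this behavior is to clamp the argument passed to np.exp/np.expm1 so that the discarded branch never overflows. The following is only a sketch, not the project's actual fix. The names softplus_quiet and inverse_softplus_quiet are hypothetical. The sketch also assumes that the inverse should apply beta inside expm1, i.e. log(expm1(beta * x)) / beta, which is the mathematical inverse of the softplus above and differs from the quoted code when beta != 1.

```python
import warnings

import numpy as np


def softplus_quiet(x, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    """Softplus sketch that avoids overflow in the discarded np.where branch."""
    x = np.asarray(x, dtype=float)
    bx = beta * x
    # Clamp the exponent: wherever bx > threshold the result of this branch
    # is thrown away by np.where, so its exact value does not matter.
    safe = np.minimum(bx, threshold)
    return np.where(bx > threshold, x, np.log1p(np.exp(safe)) / beta)


def inverse_softplus_quiet(x, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    """Inverse softplus sketch (assumes x > 0), with the same clamping trick."""
    x = np.asarray(x, dtype=float)
    bx = beta * x
    safe = np.minimum(bx, threshold)
    # Assumption: beta goes inside expm1, matching the softplus definition above.
    return np.where(bx > threshold, x, np.log(np.expm1(safe)) / beta)


if __name__ == "__main__":
    with warnings.catch_warnings():
        # Turn any warning into an exception to show that none is raised.
        warnings.simplefilter("error")
        values = np.array([0.5, 3.0, 1000.0])
        roundtrip = inverse_softplus_quiet(softplus_quiet(values, beta=2.0), beta=2.0)
        print(np.allclose(roundtrip, values))
```

A less invasive alternative would be to wrap the original return statements in `with np.errstate(over="ignore"):`, which silences only the overflow warning and leaves the computation unchanged. Note that the inverse still warns for x <= 0, where the logarithm is not finite.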