Skip to content

Commit

Permalink
Merge pull request #8095 from asi1024/boxcox-llf-scipy112
Browse files Browse the repository at this point in the history
Fix `boxcox_llf` for SciPy 1.12 changes
  • Loading branch information
emcastillo committed Jan 24, 2024
2 parents 96ee0fa + 61be61e commit 9c01872
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
19 changes: 17 additions & 2 deletions cupyx/scipy/stats/_morestats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
import cupy
from cupyx.scipy import special


def _log_mean(logx):
# compute log of mean of x from log(x)
return special.logsumexp(logx, axis=0) - cupy.log(len(logx))


def _log_var(logx):
# compute log of variance of x from log(x)
neg_logmean = cupy.broadcast_to(_log_mean(logx) - cupy.pi * 1j, logx.shape)
logxmu = special.logsumexp(cupy.asarray([logx, neg_logmean]), axis=0)
return special.logsumexp(2 * logxmu, axis=0).real - cupy.log(len(logx))


def boxcox_llf(lmb, data):
Expand Down Expand Up @@ -41,7 +54,9 @@ def boxcox_llf(lmb, data):
# Compute the variance of the transformed data
if lmb == 0:
variance = cupy.var(logdata, axis=0)
logvar = cupy.log(variance)
else:
variance = cupy.var(data**lmb / lmb, axis=0)
logx = lmb * logdata - cupy.log(abs(lmb))
logvar = _log_var(logx)

return (lmb - 1) * cupy.sum(logdata, axis=0) - N/2 * cupy.log(variance)
return (lmb - 1) * cupy.sum(logdata, axis=0) - N/2 * logvar
28 changes: 13 additions & 15 deletions tests/cupyx_tests/scipy_tests/stats_tests/test_morestats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy
import scipy.stats

import cupy
from cupy import testing
Expand Down Expand Up @@ -42,25 +43,22 @@ def _make_data(shape, xp, dtype):

def _compute(xp, scp, lmb, data):
result = scp.stats.boxcox_llf(lmb, data)
if data.ndim == 1:
if data.dtype.kind == 'c':
assert result.dtype == xp.complex128
else:
assert result.dtype == xp.float64
expected_dtype = scipy.stats.boxcox_llf(lmb, cupy.asnumpy(data)).dtype
assert result.dtype == expected_dtype

if xp is cupy:
return result, _dtype(data.dtype, xp)
else:
if data.dtype.kind in 'cf':
assert result.dtype == data.dtype
elif lmb == 0:
for dtype1 in [xp.float16, xp.float32, xp.float64]:
if xp.can_cast(data.dtype, dtype1):
break
assert result.dtype == dtype1
assert xp is numpy
# Compute with higher precision
if data.dtype.kind == 'c':
result = scp.stats.boxcox_llf(lmb, data.astype(xp.complex128))
else:
assert result.dtype == xp.float64
return result, _dtype(data.dtype, xp)
result = scp.stats.boxcox_llf(lmb, data.astype(xp.float64))
return result, _dtype(data.dtype, xp)


@testing.with_requires('scipy')
@testing.with_requires('scipy>=1.12.0rc1')
class TestBoxcox_llf:

@testing.for_all_dtypes(no_bool=True)
Expand Down

0 comments on commit 9c01872

Please sign in to comment.