Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add distance correlation estimator + unit-tests
- Loading branch information
1 parent
73ed8bb
commit bc370a9
Showing
2 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
"""Correlation based estimators.""" | ||
import numpy as np | ||
|
||
from frites.io import logger | ||
from frites.estimator.est_mi_base import BaseMIEstimator | ||
|
||
|
||
class DcorrEstimator(BaseMIEstimator):
    """Distance correlation-based estimator.

    This estimator can be used to estimate the correlation between two
    continuous variables (mi_type='cc').

    Parameters
    ----------
    implementation : {'auto', 'frites', 'dcor'}
        Choose which implementation of the distance correlation to use. If
        'frites', the home-made version defined in this module is used. If
        'dcor', the one of the dcor package is used (see
        `<https://dcor.readthedocs.io/>`_ for installation). The selection
        itself is delegated to `get_distance_correlation`.
    verbose : optional
        Forwarded unchanged to the base class constructor.
    """

    def __init__(self, implementation='auto', verbose=None):
        """Init."""
        self.name = 'Distance correlation-based Estimator'
        # get the distance correlation function; `implementation` is
        # overwritten with the name of the implementation actually selected
        fcn, implementation = get_distance_correlation(
            implementation=implementation)
        self._core_fun = wrap_dcorr(fcn)
        # instantiate base class
        super(DcorrEstimator, self).__init__(
            mi_type='cc', verbose=verbose,
            add_str=f', implementation={implementation}')
        # update internal settings
        settings = dict(mi_type='cc', core_fun=self._core_fun.__name__)
        self.settings.merge([settings])

    def estimate(self, x, y, z=None, categories=None):
        """Estimate the distance correlation between two variables.

        This method is made for computing the correlation on 3D variables
        (i.e (n_var, n_mv, n_samples)) where n_var is an additional dimension
        (e.g times, times x freqs etc.), n_mv is a multivariate axis and
        n_samples the number of samples.

        Parameters
        ----------
        x, y : array_like
            Array of shape (n_var, n_mv, n_samples).
        z : None
            Accepted in the signature but not used by this estimator.
        categories : array_like | None
            Row vector of categories. This vector should have a shape of
            (n_samples,) and should contain integers describing the category
            of each sample.

        Returns
        -------
        corr : array_like
            Array of correlation of shape (n_categories, n_var).
        """
        fcn = self.get_function()
        return fcn(x, y, categories=categories)

    def get_function(self):
        """Get the function to execute according to the input parameters.

        This can be particularly useful when computing correlation in
        parallel as it avoids pickling the whole estimator and therefore
        leads to faster computations.

        The returned function has the following signature :

            * fcn(x, y, *args, categories=None, **kwargs)

        and returns an array of shape (n_categories, n_var). Extra positional
        and keyword arguments are accepted and ignored.
        """
        core_fun = self._core_fun

        def estimator(x, y, *args, categories=None, **kwargs):
            # an empty vector never matches n_samples, which makes the core
            # function take its "no categories" path
            if categories is None:
                categories = np.array([], dtype=np.float32)

            # be sure that x is 3d : (n_var, n_mv, n_samples)
            if x.ndim == 1:
                x = x[np.newaxis, np.newaxis, :]
            elif x.ndim == 2:
                x = x[np.newaxis, ...]

            # repeat y along the n_var axis (if needed). np.tile prepends the
            # missing axes, so a 1d y becomes (n_var, 1, n_samples) and a 2d
            # (n_mv, n_samples) y becomes (n_var, n_mv, n_samples)
            if y.ndim in (1, 2):
                y = np.tile(y, (x.shape[0], 1, 1))

            return core_fun(x, y, categories)

        return estimator
|
||
|
||
def wrap_dcorr(fcn):
    """Wrap a pairwise distance correlation function for 3D inputs.

    Parameters
    ----------
    fcn : callable
        Function called as ``fcn(x_t, y_t)`` where both inputs are 2D slices
        of shape (n_samples, n_mv) and which returns a scalar.

    Returns
    -------
    correlate : callable
        Function with signature ``correlate(x, y, categories)`` taking arrays
        of shape (n_var, n_mv, n_samples) and returning a float32 array of
        shape (n_categories, n_var), or (1, n_var) when no usable categories
        are provided.
    """
    def _fill_row(x, y, row):
        """Write fcn(x[..., t], y[..., t]) in row[t] for each variable t."""
        for t in range(x.shape[-1]):
            row[t] = fcn(x[:, :, t], y[:, :, t])

    def correlate(x, y, categories):
        """3D distance correlation.

        `categories` is only used when it holds exactly one label per
        sample; `None` or a vector of any other length means that all of the
        samples are processed together.
        """
        # transpose x and y to be (n_samples, n_mv, n_var)
        x, y = np.transpose(x, (2, 1, 0)), np.transpose(y, (2, 1, 0))
        n_samples, _, n_var = x.shape

        # no (usable) categories : single row using every sample
        if (categories is None) or (len(categories) != n_samples):
            corr = np.zeros((1, n_var), dtype=np.float32)
            _fill_row(x, y, corr[0])
            return corr

        # one row per unique category; asarray makes the element-wise
        # comparison below valid for list inputs too
        categories = np.asarray(categories)
        u_cat = np.unique(categories)
        corr = np.zeros((len(u_cat), n_var), dtype=np.float32)
        for n_c, c in enumerate(u_cat):
            is_cat = categories == c
            _fill_row(x[is_cat], y[is_cat], corr[n_c])

        return corr

    return correlate
|
||
|
||
def get_distance_correlation(implementation='auto'):
    """Get the function to compute the distance correlation.

    Parameters
    ----------
    implementation : {'auto', 'frites', 'dcor'}
        Implementation to use. 'dcor' imports `distance_correlation` from the
        dcor package, 'frites' uses the `distance_correlation` function of
        this module and 'auto' tries dcor first then falls back to the
        home-made version if dcor can't be imported.

    Returns
    -------
    fcn : callable
        The distance correlation function.
    implementation : {'frites', 'dcor'}
        Name of the implementation that has actually been selected.

    Raises
    ------
    ValueError
        If `implementation` is not one of the supported values.
    ImportError
        If 'dcor' is explicitly requested but can't be imported.
    """
    # fail early with a clear message instead of implicitly returning None
    if implementation not in ('auto', 'frites', 'dcor'):
        raise ValueError(
            f"Unknown implementation {implementation!r}. Use either 'auto', "
            "'frites' or 'dcor'")

    if implementation in ('auto', 'dcor'):
        try:
            from dcor import distance_correlation as dcorr
        except ImportError:
            # an explicit request for dcor must not silently fall back
            if implementation == 'dcor':
                raise
        else:
            # only report dcor once the import actually succeeded
            logger.debug('Using dcor implementation of dcorr')
            return dcorr, 'dcor'

    logger.debug('Using home-made implementation of dcorr')
    return distance_correlation, 'frites'
|
||
############################################################################### | ||
############################################################################### | ||
# DISTANCE CORRELATION | ||
############################################################################### | ||
############################################################################### | ||
|
||
def dist_eucl(x):
    """Double centered euclidean distance matrix.

    Parameters
    ----------
    x : array_like
        Array of shape (n_samples,) or (n_samples, n_var). Non floating
        point inputs are converted to float64.

    Returns
    -------
    dist : array_like
        Array of shape (n_samples, n_samples) containing the pairwise
        euclidean distances between samples, from which the row means and
        column means have been removed and the grand mean added back.
    """
    x = np.asarray(x)
    # the in-place square root below needs a floating point buffer
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    n = x.shape[0]

    # squared euclidean distance : |xi|^2 + |xj|^2 - 2 * <xi, xj>
    dist = -2 * x.dot(x.T)
    x_square = (x * x).sum(axis=1)
    dist += x_square.reshape(n, 1)
    dist += x_square.reshape(1, n)
    np.fill_diagonal(dist, 0.)
    # rounding errors can leave tiny negative values for (nearly) identical
    # samples, which would turn into NaN in the square root
    np.maximum(dist, 0., out=dist)
    np.sqrt(dist, out=dist)

    # double centering (means are computed before modifying the matrix)
    col_mean = dist.mean(axis=0, keepdims=True)
    row_mean = dist.mean(axis=1, keepdims=True)
    grand_mean = dist.mean()
    dist -= col_mean
    dist -= row_mean
    dist += grand_mean

    return dist
|
||
|
||
def distance_correlation(x, y):
    """Compute the distance correlation.

    This function computes the distance correlation between two, possibly
    multivariate, variables.

    Parameters
    ----------
    x, y : array_like
        Arrays of shape (n_samples,) or (n_samples, n_var). Inputs that are
        neither float32 nor float64 are converted to float32.

    Returns
    -------
    dcorr : float
        The distance correlation between x and y. By convention 0. is
        returned when the distance variance of x or y is null (e.g. constant
        input), for which the ratio is undefined.

    Raises
    ------
    ValueError
        If x or y has more than two dimensions or if they don't have the
        same number of samples.
    """
    # inputs checking
    x, y = np.asarray(x), np.asarray(y)
    if x.dtype not in (np.float32, np.float64):
        x = x.astype(np.float32, copy=False)
    if y.dtype not in (np.float32, np.float64):
        y = y.astype(np.float32, copy=False)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    if y.ndim == 1:
        y = y[:, np.newaxis]
    if (x.ndim != 2) or (y.ndim != 2):
        raise ValueError(
            'x and y should be of shape (n_samples,) or (n_samples, n_var)')
    if x.shape[0] != y.shape[0]:
        raise ValueError('x and y should have the same number of samples')

    # compute distance across multivariate axis
    n = x.shape[0]
    a = dist_eucl(x)
    b = dist_eucl(y)

    # compute (squared) distance covariances
    denom = float(n * n)
    dcov2_xy = (a * b).sum() / denom
    dcov2_xx = (a * a).sum() / denom
    dcov2_yy = (b * b).sum() / denom

    # null distance variance : the ratio is undefined (0 / 0)
    if (dcov2_xx <= 0) or (dcov2_yy <= 0):
        return 0.

    # the squared distance covariance is theoretically positive but
    # rounding errors can make it slightly negative
    dcov2_xy = max(dcov2_xy, 0.)
    dcor = np.sqrt(dcov2_xy) / np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy))
    return dcor
|
||
|
||
if __name__ == '__main__':
    # manual sanity check : compare the estimator output with the dcor
    # reference (when available) on partially dependent random data
    est = DcorrEstimator(implementation='auto', verbose='debug')
    fcn = est.get_function()
    x = np.random.rand(100).reshape(1, 1, -1)
    y = np.random.rand(100).reshape(-1)
    # make the first half of the samples of x depend on y
    x[..., 0:50] -= y[..., 0:50]

    # the reference is imported under an alias : importing it as
    # `distance_correlation` would overwrite the function of this module
    try:
        from dcor import distance_correlation as dcor_reference
    except ImportError:
        dcor_reference = None
    if dcor_reference is not None:
        print(dcor_reference(x.squeeze(), y))

    corr = fcn(x, y, categories=None)
    print(corr)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""Test correlation and distance correlation estimators.""" | ||
import numpy as np | ||
|
||
from frites.estimator import CorrEstimator, DcorrEstimator | ||
|
||
|
||
# short alias for numpy's array comparison, used for the shape checks
array_equal = np.testing.assert_array_equal
|
||
|
||
class TestCorrEstimator(object):
    """Shape and functional tests of the (distance) correlation estimators.

    Random inputs are drawn from seeded generators so that the tests are
    reproducible.
    """

    @staticmethod
    def _dcorr_implementations():
        """Get the implementations testable in the current environment.

        The 'dcor' implementation relies on an optional package and is only
        included when this package can be found.
        """
        from importlib.util import find_spec
        implementations = ['auto', 'frites']
        if find_spec('dcor') is not None:
            implementations.append('dcor')
        return implementations

    @staticmethod
    def _check_shapes(fcn, x, y, cat):
        """Check output shapes for 1D, 2D and 3D inputs.

        x and y are (n_var, 1, n_samples) arrays and cat contains two
        categories, hence the expected first dimension of 2.
        """
        n_var = x.shape[0]

        # no categories
        array_equal(fcn(x[0, 0, :], y[0, 0, :]).shape, (1, 1))
        array_equal(fcn(x[0, :, :], y[0, 0, :]).shape, (1, 1))
        array_equal(fcn(x, y).shape, (1, n_var))

        # with categories
        array_equal(fcn(x[0, 0, :], y[0, 0, :],
                        categories=cat).shape, (2, 1))
        array_equal(fcn(x[0, :, :], y[0, 0, :],
                        categories=cat).shape, (2, 1))
        array_equal(fcn(x, y, categories=cat).shape, (2, n_var))

    def test_corr_definition(self):
        """Test definition of correlation estimator."""
        CorrEstimator()

    def test_corr_estimate(self):
        """Test getting the core function."""
        rng = np.random.RandomState(0)
        x, y = rng.rand(10, 1, 100), rng.rand(10, 1, 100)
        cat = np.array([0] * 50 + [1] * 50)
        est = CorrEstimator()

        # both entry points should behave the same
        for fcn in (est.get_function(), est.estimate):
            self._check_shapes(fcn, x, y, cat)

    def test_corr_functional(self):
        """Functional test of the correlation."""
        rng = np.random.RandomState(0)
        fcn = CorrEstimator().get_function()

        # no categories : only the second variable depends on y
        x, y = rng.rand(2, 1, 100), rng.rand(100)
        x[1, ...] += y.reshape(1, -1)
        corr = fcn(x, y).ravel()
        assert corr[0] < corr[1]

        # with categories : opposite relations in each half of the samples
        x, y = rng.rand(100), rng.rand(100)
        cat = np.array([0] * 50 + [1] * 50)
        x[0:50] += y[0:50]
        x[50::] -= y[50::]
        corr_nocat = fcn(x, y).ravel()
        corr_cat = fcn(x, y, categories=cat).ravel()
        assert (corr_nocat < corr_cat[0]) and (corr_nocat < abs(corr_cat[1]))
        assert (corr_cat[0] > 0) and (corr_cat[1] < 0)

    def test_dcorr_definition(self):
        """Test definition of distance correlation estimator."""
        for imp in self._dcorr_implementations():
            DcorrEstimator(implementation=imp)

    def test_dcorr_estimate(self):
        """Test getting the core function."""
        rng = np.random.RandomState(0)
        x, y = rng.rand(10, 1, 100), rng.rand(10, 1, 100)
        cat = np.array([0] * 50 + [1] * 50)

        for imp in self._dcorr_implementations():
            est = DcorrEstimator(implementation=imp)
            # both entry points should behave the same
            for fcn in (est.get_function(), est.estimate):
                self._check_shapes(fcn, x, y, cat)

    def test_dcorr_functional(self):
        """Functional test of the distance correlation."""
        rng = np.random.RandomState(0)
        for imp in self._dcorr_implementations():
            fcn = DcorrEstimator(implementation=imp).get_function()

            # no categories : only the second variable depends on y
            x, y = rng.rand(2, 1, 100), rng.rand(100)
            x[1, ...] += y.reshape(1, -1)
            dcorr = fcn(x, y).ravel()
            assert dcorr[0] < dcorr[1]

            # with categories : opposite relations in each half of the
            # samples (the distance correlation itself is not signed)
            x, y = rng.rand(100), rng.rand(100)
            cat = np.array([0] * 50 + [1] * 50)
            x[0:50] += y[0:50]
            x[50::] -= y[50::]
            dc_nocat = fcn(x, y).ravel()
            dc_cat = fcn(x, y, categories=cat).ravel()
            assert (dc_nocat < dc_cat[0]) and (dc_nocat < dc_cat[1])
            assert (0 < dc_cat[0]) and (0 < dc_cat[1])
|
||
|
||
# running this file as a script only executes the distance correlation
# functional test
if __name__ == '__main__':
    TestCorrEstimator().test_dcorr_functional()