Skip to content

Commit bc370a9

Browse files
committed
Add distance correlation estimator + unit-tests
1 parent 73ed8bb commit bc370a9

File tree

2 files changed

+347
-0
lines changed

2 files changed

+347
-0
lines changed

frites/estimator/est_dcorr.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""Correlation based estimators."""
2+
import numpy as np
3+
4+
from frites.io import logger
5+
from frites.estimator.est_mi_base import BaseMIEstimator
6+
7+
8+
class DcorrEstimator(BaseMIEstimator):
9+
10+
"""Distance correlation-based estimator.
11+
12+
This estimator can be used to estimate the correlation between two
13+
continuous variables (mi_type='cc').
14+
15+
Parameters
16+
----------
17+
implementation : {'auto', 'frites', 'dcor'}
18+
Choose wich implementation of the distance correlation to use. If
19+
'frites' a home-made version is going to be used. If 'dcor', the one of
20+
the dcorr package is going to be preferred (see for installation
21+
`<https://dcor.readthedocs.io/>`_).
22+
"""
23+
24+
def __init__(self, implementation='auto', verbose=None):
25+
"""Init."""
26+
self.name = 'Distance correlation-based Estimator'
27+
# get the distance correlation function
28+
fcn, implementation = get_distance_correlation(
29+
implementation=implementation)
30+
self._core_fun = wrap_dcorr(fcn)
31+
# instantiate base class
32+
super(DcorrEstimator, self).__init__(
33+
mi_type='cc', verbose=verbose,
34+
add_str=f', implementation={implementation}')
35+
# update internal settings
36+
settings = dict(mi_type='cc', core_fun=self._core_fun.__name__)
37+
self.settings.merge([settings])
38+
39+
def estimate(self, x, y, z=None, categories=None):
40+
"""Estimate the distance correlation between two variables.
41+
42+
This method is made for computing the correlation on 3D variables
43+
(i.e (n_var, n_mv, n_samples)) where n_var is an additional dimension
44+
(e.g times, times x freqs etc.)n_mv is a multivariate axis and
45+
n_samples the number of samples.
46+
47+
Parameters
48+
----------
49+
x, y : array_like
50+
Array of shape (n_var, n_mv, n_samples).
51+
categories : array_like | None
52+
Row vector of categories. This vector should have a shape of
53+
(n_samples,) and should contains integers describing the category
54+
of each sample.
55+
56+
Returns
57+
-------
58+
corr : array_like
59+
Array of correlation of shape (n_categories, n_var).
60+
"""
61+
fcn = self.get_function()
62+
return fcn(x, y, categories=categories)
63+
64+
def get_function(self):
65+
"""Get the function to execute according to the input parameters.
66+
67+
This can be particulary usefull when computing correlation in parallel
68+
as it avoids to pickle the whole estimator and therefore, leading to
69+
faster computations.
70+
71+
The returned function has the following signature :
72+
73+
* fcn(x, y, *args, categories=None, **kwargs)
74+
75+
and return an array of shape (n_categories, n_var).
76+
"""
77+
core_fun = self._core_fun
78+
79+
def estimator(x, y, *args, categories=None, **kwargs):
80+
if categories is None:
81+
categories = np.array([], dtype=np.float32)
82+
83+
# be sure that x is at least 3d
84+
if x.ndim == 1:
85+
x = x[np.newaxis, np.newaxis, :]
86+
if x.ndim == 2:
87+
x = x[np.newaxis, :]
88+
89+
# repeat y (if needed)
90+
if (y.ndim == 1):
91+
n_var, n_mv, _ = x.shape
92+
y = np.tile(y, (n_var, 1, 1))
93+
94+
return core_fun(x, y, categories)
95+
96+
return estimator
97+
98+
99+
def wrap_dcorr(fcn):
100+
def correlate(x, y, categories):
101+
"""3D distance correlation."""
102+
# transpose x and y to be (n_samples, n_mv, n_var)
103+
x, y = np.transpose(x, (2, 1, 0)), np.transpose(y, (2, 1, 0))
104+
# proper shape of the regressor
105+
n_trials, _, n_times = x.shape
106+
if len(categories) != n_trials:
107+
corr = np.zeros((1, n_times), dtype=np.float32)
108+
for t in range(n_times):
109+
corr[0, t] = fcn(x[:, :, t], y[:, :, t])
110+
else:
111+
# get categories informations
112+
u_cat = np.unique(categories)
113+
n_cats = len(u_cat)
114+
# compute mi per subject
115+
corr = np.zeros((n_cats, n_times), dtype=np.float32)
116+
for n_c, c in enumerate(u_cat):
117+
is_cat = categories == c
118+
x_c, y_c = x[is_cat, :, :], y[is_cat, :, :]
119+
for t in range(n_times):
120+
corr[n_c, t] = fcn(x_c[:, :, t], y_c[:, :, t])
121+
122+
return corr
123+
return correlate
124+
125+
126+
def get_distance_correlation(implementation='auto'):
127+
"""Get the function to compute the distance correlation.
128+
129+
Parameters
130+
----------
131+
implementation : {'auto', 'frites', 'dcor'}
132+
description
133+
"""
134+
if implementation == 'dcor':
135+
logger.debug('Using dcor implementation of dcorr')
136+
from dcor import distance_correlation as dcorr
137+
return dcorr, 'dcor'
138+
elif implementation == 'frites':
139+
logger.debug('Using home-made implementation of dcorr')
140+
return distance_correlation, 'frites'
141+
elif implementation == 'auto':
142+
try:
143+
logger.debug('Using dcor implementation of dcorr')
144+
from dcor import distance_correlation as dcorr
145+
return dcorr, 'dcor'
146+
except ModuleNotFoundError:
147+
logger.debug('Using home-made implementation of dcorr')
148+
return distance_correlation, 'frites'
149+
150+
###############################################################################
151+
###############################################################################
152+
# DISTANCE CORRELATION
153+
###############################################################################
154+
###############################################################################
155+
156+
def dist_eucl(x):
157+
"""Double centered euclidian distance."""
158+
if x.ndim == 1:
159+
x = x[:, np.newaxis]
160+
n = x.shape[0]
161+
162+
# compute the euclidian distance
163+
dist = - 2 * x.dot(x.T)
164+
x_square = (x * x).sum(axis=1)
165+
np.add(dist, x_square.reshape(n, 1), out=dist)
166+
np.add(dist, x_square.reshape(1, n), out=dist)
167+
np.fill_diagonal(dist, 0.)
168+
np.sqrt(dist, out=dist)
169+
170+
# double centering
171+
np.subtract(dist, dist.mean(axis=0, keepdims=True), out=dist)
172+
np.subtract(dist, dist.mean(axis=1, keepdims=True), out=dist)
173+
np.add(dist, dist.mean(), out=dist)
174+
175+
return dist
176+
177+
178+
def distance_correlation(x, y):
179+
"""Compute the distance correlation.
180+
181+
This function computes the distance correlation between two, possibly
182+
multivariate, variables.
183+
184+
Parameter
185+
---------
186+
x, y : array_like
187+
Arrays of shape (n_samples, n_var)
188+
189+
Returns
190+
-------
191+
dcorr : float
192+
The distance correlation between x and y
193+
"""
194+
# inputs checking
195+
assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray)
196+
if x.dtype not in [np.float32, np.float64]:
197+
x = x.astype(np.float32, copy=False)
198+
if y.dtype not in [np.float32, np.float64]:
199+
y = y.astype(np.float32, copy=False)
200+
if x.ndim == 1:
201+
x = x[:, np.newaxis]
202+
if y.ndim == 1:
203+
y = y[:, np.newaxis]
204+
assert (x.ndim == 2) and (y.ndim == 2)
205+
assert (x.shape[0] == y.shape[0])
206+
207+
# compute distance across multivariate axis
208+
n = x.shape[0]
209+
a = dist_eucl(x)
210+
b = dist_eucl(y)
211+
212+
# compute covariances
213+
denom = float(n * n)
214+
dcov2_xy = (a * b).sum() / denom
215+
dcov2_xx = (a * a).sum() / denom
216+
dcov2_yy = (b * b).sum() / denom
217+
dcor = np.sqrt(dcov2_xy) / np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy))
218+
return dcor
219+
220+
221+
if __name__ == '__main__':
222+
est = DcorrEstimator(implementation='auto', verbose='debug')
223+
fcn = est.get_function()
224+
x = np.random.rand(100).reshape(1, 1, -1)
225+
y = np.random.rand(100).reshape(-1)
226+
x[..., 0:50] -= y[..., 0:50]
227+
# x[..., 50:100] += y[..., 50:100]
228+
from dcor import distance_correlation
229+
print(distance_correlation(x.squeeze(), y))
230+
cat = np.array([0] * 50 + [1] * 50)
231+
corr = fcn(x, y, categories=None)
232+
print(corr)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Test correlation and distance correlation estimators."""
2+
import numpy as np
3+
4+
from frites.estimator import CorrEstimator, DcorrEstimator
5+
6+
7+
array_equal = np.testing.assert_array_equal
8+
9+
10+
class TestCorrEstimator(object):
11+
12+
def test_corr_definition(self):
13+
"""Test definition of correlation estimator."""
14+
CorrEstimator()
15+
16+
def test_corr_estimate(self):
17+
"""Test getting the core function."""
18+
x, y = np.random.rand(10, 1, 100), np.random.rand(10, 1, 100)
19+
cat = np.array([0] * 50 + [1] * 50)
20+
est = CorrEstimator()
21+
22+
for func in [0, 1]:
23+
if func == 0: # estimator.get_function()
24+
fcn = est.get_function()
25+
elif func == 1: # estimator.estimate
26+
fcn = est.estimate
27+
28+
# no categories
29+
array_equal(fcn(x[0, 0, :], y[0, 0, :]).shape, (1, 1))
30+
array_equal(fcn(x[0, :, :], y[0, 0, :]).shape, (1, 1))
31+
array_equal(fcn(x, y).shape, (1, 10))
32+
33+
# with categories
34+
array_equal(fcn(x[0, 0, :], y[0, 0, :],
35+
categories=cat).shape, (2, 1))
36+
array_equal(fcn(x[0, :, :], y[0, 0, :],
37+
categories=cat).shape, (2, 1))
38+
array_equal(fcn(x, y, categories=cat).shape, (2, 10))
39+
40+
def test_corr_functional(self):
41+
"""Functional test of the correlation."""
42+
fcn = CorrEstimator().get_function()
43+
44+
# no categories
45+
x, y = np.random.rand(2, 1, 100), np.random.rand(100)
46+
x[1, ...] += y.reshape(1, -1)
47+
corr = fcn(x, y).ravel()
48+
assert corr[0] < corr[1]
49+
50+
# with categories
51+
x, y = np.random.rand(100), np.random.rand(100)
52+
cat = np.array([0] * 50 + [1] * 50)
53+
x[0:50] += y[0:50]
54+
x[50::] -= y[50::]
55+
corr_nocat = fcn(x, y).ravel()
56+
corr_cat = fcn(x, y, categories=cat).ravel()
57+
assert (corr_nocat < corr_cat[0]) and (corr_nocat < abs(corr_cat[1]))
58+
assert (corr_cat[0] > 0) and (corr_cat[1] < 0)
59+
60+
def test_dcorr_definition(self):
61+
"""Test definition of distance correlation estimator."""
62+
DcorrEstimator(implementation='auto')
63+
DcorrEstimator(implementation='frites')
64+
DcorrEstimator(implementation='dcor')
65+
66+
def test_dcorr_estimate(self):
67+
"""Test getting the core function."""
68+
x, y = np.random.rand(10, 1, 100), np.random.rand(10, 1, 100)
69+
cat = np.array([0] * 50 + [1] * 50)
70+
71+
for imp in ['auto', 'frites', 'dcor']:
72+
est = DcorrEstimator(implementation=imp)
73+
for func in [0, 1]:
74+
# function definition
75+
if func == 0: # estimator.get_function()
76+
fcn = est.get_function()
77+
elif func == 1: # estimator.estimate
78+
fcn = est.estimate
79+
80+
# no categories
81+
array_equal(fcn(x[0, 0, :], y[0, 0, :]).shape, (1, 1))
82+
array_equal(fcn(x[0, :, :], y[0, 0, :]).shape, (1, 1))
83+
array_equal(fcn(x, y).shape, (1, 10))
84+
85+
# with categories
86+
array_equal(fcn(x[0, 0, :], y[0, 0, :],
87+
categories=cat).shape, (2, 1))
88+
array_equal(fcn(x[0, :, :], y[0, 0, :],
89+
categories=cat).shape, (2, 1))
90+
array_equal(fcn(x, y, categories=cat).shape, (2, 10))
91+
92+
def test_dcorr_functional(self):
93+
"""Functional test of the correlation."""
94+
for imp in ['auto', 'frites', 'dcor']:
95+
fcn = DcorrEstimator(implementation=imp).get_function()
96+
97+
# no categories
98+
x, y = np.random.rand(2, 1, 100), np.random.rand(100)
99+
x[1, ...] += y.reshape(1, -1)
100+
dcorr = fcn(x, y).ravel()
101+
assert dcorr[0] < dcorr[1]
102+
103+
# with categories
104+
x, y = np.random.rand(100), np.random.rand(100)
105+
cat = np.array([0] * 50 + [1] * 50)
106+
x[0:50] += y[0:50]
107+
x[50::] -= y[50::]
108+
dc_nocat = fcn(x, y).ravel()
109+
dc_cat = fcn(x, y, categories=cat).ravel()
110+
assert (dc_nocat < dc_cat[0]) and (dc_nocat < dc_cat[1])
111+
assert (0 < dc_cat[0]) and (0 < dc_cat[1])
112+
113+
114+
if __name__ == '__main__':
115+
TestCorrEstimator().test_dcorr_functional()

0 commit comments

Comments
 (0)