|
| 1 | +"""Correlation based estimators.""" |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +from frites.io import logger |
| 5 | +from frites.estimator.est_mi_base import BaseMIEstimator |
| 6 | + |
| 7 | + |
| 8 | +class DcorrEstimator(BaseMIEstimator): |
| 9 | + |
| 10 | + """Distance correlation-based estimator. |
| 11 | +
|
| 12 | + This estimator can be used to estimate the correlation between two |
| 13 | + continuous variables (mi_type='cc'). |
| 14 | +
|
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + implementation : {'auto', 'frites', 'dcor'} |
| 18 | + Choose wich implementation of the distance correlation to use. If |
| 19 | + 'frites' a home-made version is going to be used. If 'dcor', the one of |
| 20 | + the dcorr package is going to be preferred (see for installation |
| 21 | + `<https://dcor.readthedocs.io/>`_). |
| 22 | + """ |
| 23 | + |
| 24 | + def __init__(self, implementation='auto', verbose=None): |
| 25 | + """Init.""" |
| 26 | + self.name = 'Distance correlation-based Estimator' |
| 27 | + # get the distance correlation function |
| 28 | + fcn, implementation = get_distance_correlation( |
| 29 | + implementation=implementation) |
| 30 | + self._core_fun = wrap_dcorr(fcn) |
| 31 | + # instantiate base class |
| 32 | + super(DcorrEstimator, self).__init__( |
| 33 | + mi_type='cc', verbose=verbose, |
| 34 | + add_str=f', implementation={implementation}') |
| 35 | + # update internal settings |
| 36 | + settings = dict(mi_type='cc', core_fun=self._core_fun.__name__) |
| 37 | + self.settings.merge([settings]) |
| 38 | + |
| 39 | + def estimate(self, x, y, z=None, categories=None): |
| 40 | + """Estimate the distance correlation between two variables. |
| 41 | +
|
| 42 | + This method is made for computing the correlation on 3D variables |
| 43 | + (i.e (n_var, n_mv, n_samples)) where n_var is an additional dimension |
| 44 | + (e.g times, times x freqs etc.)n_mv is a multivariate axis and |
| 45 | + n_samples the number of samples. |
| 46 | +
|
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + x, y : array_like |
| 50 | + Array of shape (n_var, n_mv, n_samples). |
| 51 | + categories : array_like | None |
| 52 | + Row vector of categories. This vector should have a shape of |
| 53 | + (n_samples,) and should contains integers describing the category |
| 54 | + of each sample. |
| 55 | +
|
| 56 | + Returns |
| 57 | + ------- |
| 58 | + corr : array_like |
| 59 | + Array of correlation of shape (n_categories, n_var). |
| 60 | + """ |
| 61 | + fcn = self.get_function() |
| 62 | + return fcn(x, y, categories=categories) |
| 63 | + |
| 64 | + def get_function(self): |
| 65 | + """Get the function to execute according to the input parameters. |
| 66 | +
|
| 67 | + This can be particulary usefull when computing correlation in parallel |
| 68 | + as it avoids to pickle the whole estimator and therefore, leading to |
| 69 | + faster computations. |
| 70 | +
|
| 71 | + The returned function has the following signature : |
| 72 | +
|
| 73 | + * fcn(x, y, *args, categories=None, **kwargs) |
| 74 | +
|
| 75 | + and return an array of shape (n_categories, n_var). |
| 76 | + """ |
| 77 | + core_fun = self._core_fun |
| 78 | + |
| 79 | + def estimator(x, y, *args, categories=None, **kwargs): |
| 80 | + if categories is None: |
| 81 | + categories = np.array([], dtype=np.float32) |
| 82 | + |
| 83 | + # be sure that x is at least 3d |
| 84 | + if x.ndim == 1: |
| 85 | + x = x[np.newaxis, np.newaxis, :] |
| 86 | + if x.ndim == 2: |
| 87 | + x = x[np.newaxis, :] |
| 88 | + |
| 89 | + # repeat y (if needed) |
| 90 | + if (y.ndim == 1): |
| 91 | + n_var, n_mv, _ = x.shape |
| 92 | + y = np.tile(y, (n_var, 1, 1)) |
| 93 | + |
| 94 | + return core_fun(x, y, categories) |
| 95 | + |
| 96 | + return estimator |
| 97 | + |
| 98 | + |
| 99 | +def wrap_dcorr(fcn): |
| 100 | + def correlate(x, y, categories): |
| 101 | + """3D distance correlation.""" |
| 102 | + # transpose x and y to be (n_samples, n_mv, n_var) |
| 103 | + x, y = np.transpose(x, (2, 1, 0)), np.transpose(y, (2, 1, 0)) |
| 104 | + # proper shape of the regressor |
| 105 | + n_trials, _, n_times = x.shape |
| 106 | + if len(categories) != n_trials: |
| 107 | + corr = np.zeros((1, n_times), dtype=np.float32) |
| 108 | + for t in range(n_times): |
| 109 | + corr[0, t] = fcn(x[:, :, t], y[:, :, t]) |
| 110 | + else: |
| 111 | + # get categories informations |
| 112 | + u_cat = np.unique(categories) |
| 113 | + n_cats = len(u_cat) |
| 114 | + # compute mi per subject |
| 115 | + corr = np.zeros((n_cats, n_times), dtype=np.float32) |
| 116 | + for n_c, c in enumerate(u_cat): |
| 117 | + is_cat = categories == c |
| 118 | + x_c, y_c = x[is_cat, :, :], y[is_cat, :, :] |
| 119 | + for t in range(n_times): |
| 120 | + corr[n_c, t] = fcn(x_c[:, :, t], y_c[:, :, t]) |
| 121 | + |
| 122 | + return corr |
| 123 | + return correlate |
| 124 | + |
| 125 | + |
| 126 | +def get_distance_correlation(implementation='auto'): |
| 127 | + """Get the function to compute the distance correlation. |
| 128 | +
|
| 129 | + Parameters |
| 130 | + ---------- |
| 131 | + implementation : {'auto', 'frites', 'dcor'} |
| 132 | + description |
| 133 | + """ |
| 134 | + if implementation == 'dcor': |
| 135 | + logger.debug('Using dcor implementation of dcorr') |
| 136 | + from dcor import distance_correlation as dcorr |
| 137 | + return dcorr, 'dcor' |
| 138 | + elif implementation == 'frites': |
| 139 | + logger.debug('Using home-made implementation of dcorr') |
| 140 | + return distance_correlation, 'frites' |
| 141 | + elif implementation == 'auto': |
| 142 | + try: |
| 143 | + logger.debug('Using dcor implementation of dcorr') |
| 144 | + from dcor import distance_correlation as dcorr |
| 145 | + return dcorr, 'dcor' |
| 146 | + except ModuleNotFoundError: |
| 147 | + logger.debug('Using home-made implementation of dcorr') |
| 148 | + return distance_correlation, 'frites' |
| 149 | + |
| 150 | +############################################################################### |
| 151 | +############################################################################### |
| 152 | +# DISTANCE CORRELATION |
| 153 | +############################################################################### |
| 154 | +############################################################################### |
| 155 | + |
| 156 | +def dist_eucl(x): |
| 157 | + """Double centered euclidian distance.""" |
| 158 | + if x.ndim == 1: |
| 159 | + x = x[:, np.newaxis] |
| 160 | + n = x.shape[0] |
| 161 | + |
| 162 | + # compute the euclidian distance |
| 163 | + dist = - 2 * x.dot(x.T) |
| 164 | + x_square = (x * x).sum(axis=1) |
| 165 | + np.add(dist, x_square.reshape(n, 1), out=dist) |
| 166 | + np.add(dist, x_square.reshape(1, n), out=dist) |
| 167 | + np.fill_diagonal(dist, 0.) |
| 168 | + np.sqrt(dist, out=dist) |
| 169 | + |
| 170 | + # double centering |
| 171 | + np.subtract(dist, dist.mean(axis=0, keepdims=True), out=dist) |
| 172 | + np.subtract(dist, dist.mean(axis=1, keepdims=True), out=dist) |
| 173 | + np.add(dist, dist.mean(), out=dist) |
| 174 | + |
| 175 | + return dist |
| 176 | + |
| 177 | + |
| 178 | +def distance_correlation(x, y): |
| 179 | + """Compute the distance correlation. |
| 180 | +
|
| 181 | + This function computes the distance correlation between two, possibly |
| 182 | + multivariate, variables. |
| 183 | +
|
| 184 | + Parameter |
| 185 | + --------- |
| 186 | + x, y : array_like |
| 187 | + Arrays of shape (n_samples, n_var) |
| 188 | +
|
| 189 | + Returns |
| 190 | + ------- |
| 191 | + dcorr : float |
| 192 | + The distance correlation between x and y |
| 193 | + """ |
| 194 | + # inputs checking |
| 195 | + assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray) |
| 196 | + if x.dtype not in [np.float32, np.float64]: |
| 197 | + x = x.astype(np.float32, copy=False) |
| 198 | + if y.dtype not in [np.float32, np.float64]: |
| 199 | + y = y.astype(np.float32, copy=False) |
| 200 | + if x.ndim == 1: |
| 201 | + x = x[:, np.newaxis] |
| 202 | + if y.ndim == 1: |
| 203 | + y = y[:, np.newaxis] |
| 204 | + assert (x.ndim == 2) and (y.ndim == 2) |
| 205 | + assert (x.shape[0] == y.shape[0]) |
| 206 | + |
| 207 | + # compute distance across multivariate axis |
| 208 | + n = x.shape[0] |
| 209 | + a = dist_eucl(x) |
| 210 | + b = dist_eucl(y) |
| 211 | + |
| 212 | + # compute covariances |
| 213 | + denom = float(n * n) |
| 214 | + dcov2_xy = (a * b).sum() / denom |
| 215 | + dcov2_xx = (a * a).sum() / denom |
| 216 | + dcov2_yy = (b * b).sum() / denom |
| 217 | + dcor = np.sqrt(dcov2_xy) / np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy)) |
| 218 | + return dcor |
| 219 | + |
| 220 | + |
| 221 | +if __name__ == '__main__': |
| 222 | + est = DcorrEstimator(implementation='auto', verbose='debug') |
| 223 | + fcn = est.get_function() |
| 224 | + x = np.random.rand(100).reshape(1, 1, -1) |
| 225 | + y = np.random.rand(100).reshape(-1) |
| 226 | + x[..., 0:50] -= y[..., 0:50] |
| 227 | + # x[..., 50:100] += y[..., 50:100] |
| 228 | + from dcor import distance_correlation |
| 229 | + print(distance_correlation(x.squeeze(), y)) |
| 230 | + cat = np.array([0] * 50 + [1] * 50) |
| 231 | + corr = fcn(x, y, categories=None) |
| 232 | + print(corr) |
0 commit comments