Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement trial resampling estimator
- Loading branch information
1 parent
875189b
commit 13f6271
Showing
6 changed files
with
262 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
Information-based estimators | ||
---------------------------- | ||
|
||
Set of examples illustrating how to use Frites' information-based estimators. | ||
|
||
.. contents:: Contents | ||
:local: | ||
:depth: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
Trial-resampling: correcting for unbalanced designs | ||
=================================================== | ||
This example illustrates how to correct information estimation in case of | ||
unbalanced designs (i.e. when the number of epochs or trials is very different | ||
between conditions). | ||
The technique of trial-resampling consist in randomly taking an equal number of | ||
trials per condition, estimating the effect size and then repeating this | ||
procedure for a more reliable estimation. | ||
""" | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from frites.estimator import GCMIEstimator, ResamplingEstimator, DcorrEstimator | ||
|
||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
############################################################################### | ||
# Data creation | ||
# ------------- | ||
# | ||
# This first section creates the data using random points drawn from gaussian | ||
# distributions | ||
|
||
n_variables = 1000 # number of random variables | ||
n_epochs = 500 # total number of epochs | ||
prop = 5 # proportion (in percent) of epochs in the first condition | ||
|
||
# proportion of trials | ||
n_prop = int(np.round(prop * n_epochs / 100)) | ||
|
||
# create continuous variables | ||
x_1 = np.random.normal(loc=1., size=(n_variables, 1, n_prop)) | ||
x_2 = np.random.normal(loc=2., size=(n_variables, 1, n_epochs - n_prop)) | ||
x = np.concatenate((x_1, x_2), axis=-1) | ||
y_c = np.r_[np.random.normal(size=(n_prop,)), | ||
np.random.normal(size=(n_epochs - n_prop,))] | ||
|
||
# create discret variable | ||
y_d = np.array([0] * n_prop + [1] * (n_epochs - n_prop)) | ||
|
||
print(f"Smaller dataset : {x_1.shape}") | ||
print(f"Larger dataset : {x_2.shape}") | ||
|
||
|
||
############################################################################### | ||
# Information shared between a continuous and a discret variable | ||
# -------------------------------------------------------------- | ||
# | ||
# In this second section, we define an estimator for computing the information | ||
# shared between a continuous and a discret variable. In a second step, we are | ||
# going to wrap this estimator with a trial-resampling estimator. | ||
|
||
# mutual information uncorrected estimator | ||
est = GCMIEstimator(mi_type='cd', biascorrect=False) | ||
mi_1 = est.estimate(x, y_d).squeeze() | ||
|
||
# mutual information corrected estimator (with trial-resampling) | ||
est_r = ResamplingEstimator(est, n_resampling=100) | ||
mi_2 = est_r.estimate(x, y_d).squeeze() | ||
|
||
df = pd.DataFrame({ | ||
'MI': np.r_[mi_1, mi_2], | ||
'Estimator': ['Uncorrected'] * len(mi_1) + ['Corrected'] * len(mi_2) | ||
}) | ||
|
||
############################################################################### | ||
# .. note:: | ||
# As shown below, the effect size for the corrected estimator is slightly | ||
# over the non-corrected one. | ||
|
||
sns.displot(df, x='MI', hue='Estimator', kde=True, height=7) | ||
plt.title("Information shared between a continuous and a discrete variable") | ||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
############################################################################### | ||
# Information shared between two continuous variables | ||
# --------------------------------------------------- | ||
# | ||
# In this last section, we define an estimator for computing the information | ||
# shared between two continuous variables. Similarly to above, we are then | ||
# going to wrap this estimator with a trial-resampling estimator. | ||
|
||
# distance correlation uncorrected estimator | ||
est = DcorrEstimator() | ||
mi_1 = est.estimate(x, y_c, z=y_d).squeeze() | ||
|
||
# distance correlation corrected estimator (with trial-resampling) | ||
est_r = ResamplingEstimator(est, n_resampling=20) | ||
mi_2 = est_r.estimate(x, y_c, z=y_d).squeeze() | ||
|
||
df = pd.DataFrame({ | ||
'MI': np.r_[mi_1, mi_2], | ||
'Estimator': ['Uncorrected'] * len(mi_1) + ['Corrected'] * len(mi_2) | ||
}) | ||
|
||
############################################################################### | ||
# .. note:: | ||
# As shown below, the effect size for the corrected estimator is slightly | ||
# over the non-corrected one. | ||
|
||
sns.displot(df, x='MI', hue='Estimator', kde=True, height=7) | ||
plt.title("Information shared between two continuous variables") | ||
plt.tight_layout() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Resampling estimator.""" | ||
import numpy as np | ||
|
||
from frites.estimator.est_mi_base import BaseMIEstimator | ||
from frites.utils import nonsorted_unique | ||
|
||
|
||
class ResamplingEstimator(BaseMIEstimator): | ||
|
||
"""Trial-resampling estimator. | ||
In case of unbalanced contrast (i.e. when the number of trials per | ||
condition is very different) it can be interesting to use a | ||
trial-resampling technique to minimize the possibility that the effect | ||
size is driven by the number of trials. To this end, the same number of | ||
trials is used to estimate the effect size and the final | ||
""" | ||
|
||
def __init__(self, estimator, n_resampling=100, verbose=None): | ||
"""Init.""" | ||
self.name = f'{estimator.name} (n_resampling={n_resampling})' | ||
mi_type = estimator.settings['mi_type'] | ||
assert mi_type in ['cc', 'cd'] | ||
super(ResamplingEstimator, self).__init__( | ||
mi_type=mi_type, verbose=verbose) | ||
# update internal settings | ||
settings = dict(n_resampling=n_resampling) | ||
self.settings.merge([settings, estimator.settings.data]) | ||
# track internals | ||
self._estimator = estimator | ||
self._n_resampling = n_resampling | ||
|
||
def estimate(self, x, y, z=None, categories=None): | ||
fcn = self.get_function() | ||
return fcn(x, y, z=z, categories=categories) | ||
|
||
def get_function(self): | ||
|
||
_fcn = self._estimator.get_function() | ||
n_resampling = self._n_resampling | ||
|
||
def estimator(x, y, z=None, categories=None): | ||
# define how to balance the classes | ||
classes = y if z is None else z | ||
u_classes = nonsorted_unique(classes) | ||
assert len(u_classes) == 2, "Currently only works for 2 classes" | ||
n_per_classes = {u: (classes == u).sum() for u in u_classes} | ||
min_nb = min(n_per_classes.values()) | ||
choices = [] | ||
for n_k, k in enumerate(range(n_resampling)): | ||
_choices = [] | ||
for c in u_classes: | ||
_trials = np.where(classes == c)[0] | ||
if n_per_classes[c] == min_nb: | ||
_choices += [_trials] | ||
else: | ||
sd = np.random.RandomState(n_k) | ||
__choices = sd.choice(_trials, min_nb, replace=False) | ||
_choices += [__choices] | ||
choices += [np.concatenate(_choices)] | ||
|
||
# run computations | ||
mi = [] | ||
for tr in choices: | ||
_x, _y = x[..., tr], y[..., tr] | ||
_cat = None if categories is None else categories[tr] | ||
mi.append(_fcn(_x, _y, z=tr, categories=_cat)) | ||
|
||
# merge computations | ||
mi = np.stack(mi).mean(0) | ||
|
||
return mi | ||
|
||
return estimator | ||
|
||
|
||
if __name__ == '__main__': | ||
from frites.estimator import GCMIEstimator | ||
|
||
est = GCMIEstimator(mi_type='cc') | ||
est_r = ResamplingEstimator(est) | ||
|
||
x = np.random.rand(33, 1, 400) | ||
y = np.random.rand(400) | ||
z = np.array([0] * 20 + [1] * 380) | ||
|
||
mi = est_r.estimate(x, y, z) | ||
|
||
print(mi.shape) | ||
print(est_r) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Test resampling estimator.""" | ||
import numpy as np | ||
|
||
from frites.estimator import (ResamplingEstimator, GCMIEstimator, | ||
DcorrEstimator, CorrEstimator) | ||
|
||
|
||
# dataset variables | ||
n_trials = 100 | ||
x = np.random.rand(n_trials) | ||
y_d = np.array([0] * 50 + [1] * 50) | ||
cat = np.array([0] * 25 + [1] * 75) | ||
y_c = np.random.rand(n_trials) | ||
|
||
# estimator creation | ||
est_gc = GCMIEstimator(mi_type='cc', verbose=False) | ||
est_gd = GCMIEstimator(mi_type='cd') | ||
est_c = CorrEstimator() | ||
est_d = DcorrEstimator() | ||
|
||
|
||
class TestResamplingEstimator(object): | ||
|
||
def test_resampling_cc(self): | ||
"""Test resampling between continuous variables.""" | ||
for est in [est_gc, est_c, est_d]: | ||
# category free | ||
est_w = ResamplingEstimator(est) | ||
mi = est_w.estimate(x, y_c, z=y_d) | ||
assert mi.shape == (1, 1) | ||
|
||
# with categories | ||
mi = est_w.estimate(x, y_c, z=y_d, categories=cat) | ||
assert mi.shape == (2, 1) | ||
|
||
def test_resampling_cd(self): | ||
"""Test resampling between a continuous and a discrete variables.""" | ||
# category free | ||
est_w = ResamplingEstimator(est_gd) | ||
mi = est_w.estimate(x, y_d) | ||
assert mi.shape == (1, 1) | ||
|
||
# with categories | ||
mi = est_w.estimate(x, y_d, categories=cat) | ||
assert mi.shape == (2, 1) | ||
|
||
|
||
if __name__ == '__main__': | ||
TestResamplingEstimator().test_resampling_cd() |