Implement trial resampling estimator

EtienneCmb committed Aug 9, 2021
1 parent 875189b commit 13f6271
Showing 6 changed files with 262 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -148,6 +148,7 @@
'../../examples/tutorials',
'../../examples/dataset',
'../../examples/mi',
'../../examples/estimators',
'../../examples/conn',
'../../examples/armodel',
'../../examples/utils',
8 changes: 8 additions & 0 deletions examples/estimators/README.txt
@@ -0,0 +1,8 @@
Information-based estimators
----------------------------

Set of examples illustrating how to use Frites' information-based estimators.

.. contents:: Contents
:local:
:depth: 2
111 changes: 111 additions & 0 deletions examples/estimators/plot_est_resample.py
@@ -0,0 +1,111 @@
"""
Trial-resampling: correcting for unbalanced designs
===================================================
This example illustrates how to correct information estimates in the case of
unbalanced designs (i.e. when the number of epochs or trials differs strongly
between conditions).

Trial resampling consists of randomly drawing an equal number of trials per
condition, estimating the effect size on this balanced subset, and repeating
the procedure several times to obtain a more reliable estimate.
"""
import numpy as np
import pandas as pd

from frites.estimator import GCMIEstimator, ResamplingEstimator, DcorrEstimator

import seaborn as sns
import matplotlib.pyplot as plt


###############################################################################
# Data creation
# -------------
#
# This first section creates the data using random samples drawn from Gaussian
# distributions.

n_variables = 1000 # number of random variables
n_epochs = 500 # total number of epochs
prop = 5 # proportion (in percent) of epochs in the first condition

# number of epochs in the first condition
n_prop = int(np.round(prop * n_epochs / 100))

# create continuous variables
x_1 = np.random.normal(loc=1., size=(n_variables, 1, n_prop))
x_2 = np.random.normal(loc=2., size=(n_variables, 1, n_epochs - n_prop))
x = np.concatenate((x_1, x_2), axis=-1)
y_c = np.r_[np.random.normal(size=(n_prop,)),
np.random.normal(size=(n_epochs - n_prop,))]

# create discrete variable
y_d = np.array([0] * n_prop + [1] * (n_epochs - n_prop))

print(f"Smaller dataset : {x_1.shape}")
print(f"Larger dataset : {x_2.shape}")


###############################################################################
# Information shared between a continuous and a discrete variable
# ----------------------------------------------------------------
#
# In this second section, we define an estimator computing the information
# shared between a continuous and a discrete variable. We then wrap this
# estimator with a trial-resampling estimator.

# uncorrected mutual information estimator
est = GCMIEstimator(mi_type='cd', biascorrect=False)
mi_1 = est.estimate(x, y_d).squeeze()

# corrected mutual information estimator (with trial-resampling)
est_r = ResamplingEstimator(est, n_resampling=100)
mi_2 = est_r.estimate(x, y_d).squeeze()

df = pd.DataFrame({
'MI': np.r_[mi_1, mi_2],
'Estimator': ['Uncorrected'] * len(mi_1) + ['Corrected'] * len(mi_2)
})

###############################################################################
# .. note::
#     As shown below, the effect size obtained with the corrected estimator
#     lies slightly above the one obtained with the uncorrected estimator.

sns.displot(df, x='MI', hue='Estimator', kde=True, height=7)
plt.title("Information shared between a continuous and a discrete variable")
plt.tight_layout()
plt.show()


###############################################################################
# Information shared between two continuous variables
# ---------------------------------------------------
#
# In this last section, we define an estimator computing the information
# shared between two continuous variables. As above, we then wrap this
# estimator with a trial-resampling estimator.

# uncorrected distance-correlation estimator
est = DcorrEstimator()
mi_1 = est.estimate(x, y_c, z=y_d).squeeze()

# corrected distance-correlation estimator (with trial-resampling)
est_r = ResamplingEstimator(est, n_resampling=20)
mi_2 = est_r.estimate(x, y_c, z=y_d).squeeze()

df = pd.DataFrame({
'MI': np.r_[mi_1, mi_2],
'Estimator': ['Uncorrected'] * len(mi_1) + ['Corrected'] * len(mi_2)
})

###############################################################################
# .. note::
#     As shown below, the effect size obtained with the corrected estimator
#     lies slightly above the one obtained with the uncorrected estimator.

sns.displot(df, x='MI', hue='Estimator', kde=True, height=7)
plt.title("Information shared between two continuous variables")
plt.tight_layout()
plt.show()
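
A side note on the example above: the trial-resampling procedure it describes can be
sketched without the estimator API. The snippet below is only a minimal illustration;
the function name (resampled_effect) and the statistic it computes (a plain difference
of class means) are stand-ins for the information estimators used in the example, not
part of Frites.

import numpy as np

def resampled_effect(x, classes, n_resampling=100):
    """Average a toy effect size over class-balanced subsamples of trials.

    `x` has shape (..., n_trials) and `classes` holds one binary label per
    trial; the difference of class means is only a placeholder statistic.
    """
    idx_0 = np.where(classes == 0)[0]
    idx_1 = np.where(classes == 1)[0]
    n_min = min(len(idx_0), len(idx_1))
    effects = []
    for seed in range(n_resampling):
        rng = np.random.RandomState(seed)
        # draw the same number of trials from each class, without replacement
        sub_0 = rng.choice(idx_0, n_min, replace=False)
        sub_1 = rng.choice(idx_1, n_min, replace=False)
        effects.append(x[..., sub_0].mean(-1) - x[..., sub_1].mean(-1))
    # the final estimate is the average over the resamplings
    return np.stack(effects).mean(0)

# e.g. resampled_effect(x, y_d) with the arrays created in the example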
3 changes: 3 additions & 0 deletions frites/estimator/__init__.py
@@ -9,3 +9,6 @@

# distance-based estimators
from .est_dcorr import DcorrEstimator # noqa

# resampling estimator
from .est_resampling import ResamplingEstimator # noqa
90 changes: 90 additions & 0 deletions frites/estimator/est_resampling.py
@@ -0,0 +1,90 @@
"""Resampling estimator."""
import numpy as np

from frites.estimator.est_mi_base import BaseMIEstimator
from frites.utils import nonsorted_unique


class ResamplingEstimator(BaseMIEstimator):

"""Trial-resampling estimator.
In case of unbalanced contrast (i.e. when the number of trials per
condition is very different) it can be interesting to use a
trial-resampling technique to minimize the possibility that the effect
size is driven by the number of trials. To this end, the same number of
trials is used to estimate the effect size and the final
"""

def __init__(self, estimator, n_resampling=100, verbose=None):
"""Init."""
self.name = f'{estimator.name} (n_resampling={n_resampling})'
mi_type = estimator.settings['mi_type']
assert mi_type in ['cc', 'cd']
super(ResamplingEstimator, self).__init__(
mi_type=mi_type, verbose=verbose)
# update internal settings
settings = dict(n_resampling=n_resampling)
self.settings.merge([settings, estimator.settings.data])
# track internals
self._estimator = estimator
self._n_resampling = n_resampling

def estimate(self, x, y, z=None, categories=None):
fcn = self.get_function()
return fcn(x, y, z=z, categories=categories)

def get_function(self):

_fcn = self._estimator.get_function()
n_resampling = self._n_resampling

def estimator(x, y, z=None, categories=None):
# define how to balance the classes
classes = y if z is None else z
u_classes = nonsorted_unique(classes)
assert len(u_classes) == 2, "Currently only works for 2 classes"
n_per_classes = {u: (classes == u).sum() for u in u_classes}
min_nb = min(n_per_classes.values())
choices = []
            for n_k in range(n_resampling):
_choices = []
for c in u_classes:
_trials = np.where(classes == c)[0]
if n_per_classes[c] == min_nb:
_choices += [_trials]
else:
sd = np.random.RandomState(n_k)
__choices = sd.choice(_trials, min_nb, replace=False)
_choices += [__choices]
choices += [np.concatenate(_choices)]

# run computations
mi = []
for tr in choices:
_x, _y = x[..., tr], y[..., tr]
                _z = None if z is None else z[..., tr]
                _cat = None if categories is None else categories[tr]
                mi.append(_fcn(_x, _y, z=_z, categories=_cat))

# merge computations
mi = np.stack(mi).mean(0)

return mi

return estimator


if __name__ == '__main__':
from frites.estimator import GCMIEstimator

est = GCMIEstimator(mi_type='cc')
est_r = ResamplingEstimator(est)

x = np.random.rand(33, 1, 400)
y = np.random.rand(400)
z = np.array([0] * 20 + [1] * 380)

mi = est_r.estimate(x, y, z)

print(mi.shape)
print(est_r)
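
As an aside, the class-balancing step performed inside get_function can be checked in
isolation. The snippet below is only an illustration with made-up labels (3 trials in
one class, 7 in the other); it reproduces the index selection for a single resampling:

import numpy as np
from frites.utils import nonsorted_unique

# toy unbalanced labels: 3 trials in the first class, 7 in the second
classes = np.array([0] * 3 + [1] * 7)
u_classes = nonsorted_unique(classes)
min_nb = min((classes == u).sum() for u in u_classes)

# single resampling: keep every trial of the smaller class and subsample the
# larger one down to the same number of trials
sd = np.random.RandomState(0)
choice = []
for c in u_classes:
    trials = np.where(classes == c)[0]
    if len(trials) == min_nb:
        choice.append(trials)
    else:
        choice.append(sd.choice(trials, min_nb, replace=False))
print(np.concatenate(choice))  # 3 indices from each class, 6 trials in total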
49 changes: 49 additions & 0 deletions frites/estimator/tests/test_resampling_est.py
@@ -0,0 +1,49 @@
"""Test resampling estimator."""
import numpy as np

from frites.estimator import (ResamplingEstimator, GCMIEstimator,
DcorrEstimator, CorrEstimator)


# dataset variables
n_trials = 100
x = np.random.rand(n_trials)
y_d = np.array([0] * 50 + [1] * 50)
cat = np.array([0] * 25 + [1] * 75)
y_c = np.random.rand(n_trials)

# estimator creation
est_gc = GCMIEstimator(mi_type='cc', verbose=False)
est_gd = GCMIEstimator(mi_type='cd')
est_c = CorrEstimator()
est_d = DcorrEstimator()


class TestResamplingEstimator(object):

def test_resampling_cc(self):
"""Test resampling between continuous variables."""
for est in [est_gc, est_c, est_d]:
# category free
est_w = ResamplingEstimator(est)
mi = est_w.estimate(x, y_c, z=y_d)
assert mi.shape == (1, 1)

# with categories
mi = est_w.estimate(x, y_c, z=y_d, categories=cat)
assert mi.shape == (2, 1)

def test_resampling_cd(self):
"""Test resampling between a continuous and a discrete variables."""
# category free
est_w = ResamplingEstimator(est_gd)
mi = est_w.estimate(x, y_d)
assert mi.shape == (1, 1)

# with categories
mi = est_w.estimate(x, y_d, categories=cat)
assert mi.shape == (2, 1)


if __name__ == '__main__':
TestResamplingEstimator().test_resampling_cd()
