Skip to content

Commit

Permalink
Allow multi-dimensional data + internal repetitions for y and z
Browse files: browse the repository at this point in the history
  • Loading branch information
EtienneCmb committed Feb 4, 2021
1 parent 4bb5d7c commit 0015bf5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
25 changes: 24 additions & 1 deletion frites/estimator/est_gcmi.py
Expand Up @@ -44,7 +44,7 @@ class GCMIEstimator(BaseMIEstimator):

def __init__(self, mi_type='cc', copnorm=True, biascorrect=True,
demeaned=False, tensor=True, gpu=False, verbose=None):
self._est_name = 'Gaussian Copula Mutual Information Estimator'
self.name = 'Gaussian Copula Mutual Information Estimator'
super(GCMIEstimator, self).__init__(mi_type=mi_type, verbose=verbose)

# =========================== Core function ===========================
Expand Down Expand Up @@ -150,6 +150,24 @@ def estimator(x, y, z=None, categories=None):
if (mi_type == 'ccc') and (z is not None):
z = copnorm_cat_nd(z, categories, axis=-1)

# n-dimensional support: flatten the leading axes of x when it has more than 3 dimensions
assert x.ndim >= 3
reshape = None
if x.ndim > 3:
head_shape = list(x.shape)[0:-2]
reshape = (head_shape, np.prod(head_shape))
tail_shape = list(x.shape)[-2::]
x = x.reshape([reshape[1]] + tail_shape)

# repeat y and z (if needed)
if (mi_type != 'cd') and (y.ndim == 1):
n_var, n_mv, _ = x.shape
y = np.tile(y, (n_var, n_mv, 1))
if (mi_type == 'ccc') and (y.ndim == 1):
n_var, n_mv, _ = x.shape
z = np.tile(z, (n_var, n_mv, 1))

# compute (potentially categorical) MI
n_var = x.shape[0]
args = ()
if isinstance(categories, np.ndarray):
Expand All @@ -168,6 +186,11 @@ def estimator(x, y, z=None, categories=None):
if mi_type in ['ccd', 'ccc']:
args = [z]
mi = core_fun(x, y, *args, **kwargs)[np.newaxis, :]

# retrieve original shape (if needed)
if reshape is not None:
mi = mi.reshape([mi.shape[0]] + reshape[0])

return mi
return estimator

Expand Down
8 changes: 4 additions & 4 deletions frites/estimator/est_mi_base.py
Expand Up @@ -28,17 +28,17 @@ def __init__(self, mi_type='cc', verbose=None):
settings = {'description': desc}
self.settings = Attributes(attrs=settings, section_name='Settings')
self._kwargs = dict()
assert hasattr(self, '_est_name')
assert hasattr(self, 'name')

logger.info(f"{self._est_name} ({mi_type})")
logger.info(f"{self.name} ({mi_type})")

def __repr__(self):
"""Overall representation."""
return '*** ' + self._est_name + ' ***\n' + self.settings.__repr__()
return '*** ' + self.name + ' ***\n' + self.settings.__repr__()

def _repr_html_(self):
"""IPython representation."""
title = f"<h3><br>{self._est_name}</br></h3>"
title = f"<h3><br>{self.name}</br></h3>"
return title + self.settings._repr_html_()

def estimate(self, x, y, z=None, categories=None):
Expand Down

0 comments on commit 0015bf5

Please sign in to comment.