-
Notifications
You must be signed in to change notification settings - Fork 19
/
est_mi_base.py
57 lines (43 loc) · 1.81 KB
/
est_mi_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Base estimator for the mutual information."""
from frites.io import Attributes
from frites.config import CONFIG
from frites.io import set_log_level, logger
class BaseMIEstimator(object):
    """Base class for mutual-information estimators.

    Subclasses must define a ``name`` attribute (used for logging and for
    the text / HTML representations) and override :meth:`estimate` and
    :meth:`get_function`, which raise ``NotImplementedError`` here.

    Parameters
    ----------
    mi_type : {'cc', 'cd', 'ccd', 'ccc'}
        Mutual information type :

            * 'cc' : MI between two continuous variables
            * 'cd' : MI between a continuous and a discrete variable
            * 'ccd' : MI between two continuous variables conditioned by a
              third discrete one
            * 'ccc' : MI between two continuous variables conditioned by a
              third continuous one
    verbose : default None
        Verbosity level, forwarded as-is to ``set_log_level``.
    """

    def __init__(self, mi_type='cc', verbose=None):
        """Validate the subclass, set the verbosity and store the settings.

        Raises
        ------
        AssertionError
            If the (sub)class does not define a ``name`` attribute.
        KeyError
            Presumably, if ``mi_type`` is not a key of ``CONFIG['MI_REPR']``
            (assuming it is a plain mapping).
        """
        # Check the subclass contract first, before touching the global log
        # level. An explicit raise (instead of an ``assert`` statement) keeps
        # the check active under ``python -O``; AssertionError is kept so
        # the exception type matches the previous ``assert``.
        if not hasattr(self, 'name'):
            raise AssertionError(
                f"{type(self).__name__} must define a `name` attribute")
        set_log_level(verbose)
        # human-readable description of the requested MI type
        desc = CONFIG['MI_REPR'][mi_type]
        settings = {'description': desc}
        self.settings = Attributes(attrs=settings, section_name='Settings')
        # empty here; presumably populated by subclasses — verify in callers
        self._kwargs = dict()
        logger.info(f"{self.name} ({mi_type})")

    def __repr__(self):
        """Overall representation: estimator name followed by its settings."""
        return '*** ' + self.name + ' ***\n' + repr(self.settings)

    def _repr_html_(self):
        """IPython representation: name as an HTML title plus settings."""
        title = f"<h3><br>{self.name}</br></h3>"
        return title + self.settings._repr_html_()

    def estimate(self, x, y, z=None, categories=None):
        """Estimate the (possibly conditional) mutual-information.

        Must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def get_function(self):
        """Get the function to execute.

        The returned function should have the following signature :

            * fcn(x, y, z=None, categories=None)

        and should return an array of shape (n_categories, n_var).

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError()