Skip to content

Commit

Permalink
Finish rv_names -> rv_mode for dit.other.
Browse files Browse the repository at this point in the history
  • Loading branch information
chebee7i committed Feb 23, 2015
1 parent bc8a85e commit 0c67f29
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions dit/other/extropy.py
Expand Up @@ -2,11 +2,12 @@
The extropy
"""

from ..math.ops import LogOperations
from ..helpers import RV_MODES
from ..math.ops import get_ops

import numpy as np

def extropy(dist, rvs=None, rv_names=None):
def extropy(dist, rvs=None, rv_mode=None):
"""
Returns the extropy J[X] over the random variables in `rvs`.
Expand All @@ -22,11 +23,12 @@ def extropy(dist, rvs=None, rv_names=None):
The indexes of the random variable used to calculate the extropy.
If None, then the extropy is calculated over all random variables.
This should remain `None` for ScalarDistributions.
rv_names : bool
If `True`, then the elements of `rvs` are treated as random variable
names. If `False`, then the elements of `rvs` are treated as random
variable indexes. If `None`, then the value `True` is used if the
distribution has specified names for its random variables.
rv_mode : str, None
Specifies how to interpret the elements of `rvs`. Valid options are:
{'indices', 'names'}. If equal to 'indices', then the elements of
`rvs` are interpreted as random variable indices. If equal to 'names',
the the elements are interpreted as random variable names. If `None`,
then the value of `dist._rv_mode` is consulted.
Returns
-------
Expand All @@ -44,15 +46,15 @@ def extropy(dist, rvs=None, rv_names=None):
import dit
dist = dit.ScalarDistribution([dist, 1-dist])
rvs = None
rv_names = False
rv_mode = RV_MODES.INDICES

if dist.is_joint():
if rvs is None:
# Set to entropy of entire distribution
rvs = list(range(dist.outcome_length()))
rv_names = False
rv_mode = RV_MODES.INDICES

d = dist.marginal(rvs, rv_names=rv_names)
d = dist.marginal(rvs, rv_mode=rv_mode)
else:
d = dist

Expand All @@ -63,7 +65,7 @@ def extropy(dist, rvs=None, rv_names=None):
terms = -base**npmf * npmf
else:
# Calculate entropy in bits.
log = LogOperations(2).log
log = get_ops(2).log
npmf = 1 - pmf
terms = -npmf * log(npmf)

Expand Down

0 comments on commit 0c67f29

Please sign in to comment.