add the perplexity #33

Merged
merged 4 commits on Oct 10, 2013
52 changes: 2 additions & 50 deletions dit/abc.py
@@ -2,63 +2,15 @@
Import several functions as shorthand.
"""

from dit.algorithms import conditional_entropy, entropy
from dit.utils.misc import flatten

from dit import (Distribution as D,
ScalarDistribution as SD,
)

from dit.algorithms import (coinformation as I,
common_information as K,
total_correlation as T,
perplexity as P,
jensen_shannon_divergence as JSD,
)

__all__ = ['D', 'SD', 'H', 'I', 'K', 'P', 'T', 'JSD']

def H(dist, rvs=None, crvs=None, rv_names=None):
"""
Parameters
----------
dist : Distribution
The distribution from which the entropy is calculated.
rvs : list, None
The indexes of the random variables used to calculate the entropy. If
None, then the entropy is calculated over all random
variables.
crvs : list, None
The indexes of the random variables to condition on. If None, then no
variables are conditioned on.
rv_names : bool
If `True`, then the elements of `rvs` are treated as random variable
names. If `False`, then the elements of `rvs` are treated as random
variable indexes. If `None`, then the value `True` is used if the
distribution has specified names for its random variables.

Returns
-------
H : float
The entropy.

Raises
------
ditException
Raised if `dist` is not a joint distribution.
"""
if dist.is_joint():
if rvs is None:
# Set to entropy of entire distribution
rvs = list(range(dist.outcome_length()))
rv_names = False
else:
# this allows inputs of the form [0, 1, 2] or [[0, 1], [2]],
# giving uniform behavior with the mutual information-like
# measures.
rvs = set(flatten(rvs))
if crvs is None:
crvs = []
else:
return entropy(dist)

return conditional_entropy(dist, rvs, crvs, rv_names)
from dit.algorithms.entropy2 import entropy2 as H
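
For context, a minimal sketch of how the shorthand names read after this change; the distribution and the values in the comments are illustrative, not part of the PR:

from __future__ import division  # so 1/4 == 0.25 on Python 2
from dit import Distribution as D
from dit.abc import H, P

# A uniform distribution over two independent bits.
d = D(['00', '01', '10', '11'], [1/4]*4)

H(d)       # 2.0 bits of joint entropy
P(d)       # 4.0, since the perplexity is 2**H(d)
P(d, [0])  # 2.0: the first bit alone
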
1 change: 1 addition & 0 deletions dit/algorithms/__init__.py
@@ -1,6 +1,7 @@
from .shannon import entropy, conditional_entropy, mutual_information
from .total_correlation import total_correlation
from .coinformation import coinformation
from .perplexity import perplexity
from .jsd import jensen_shannon_divergence
from .common_info import common_information
from .lattice import insert_join, insert_meet
47 changes: 47 additions & 0 deletions dit/algorithms/entropy2.py
@@ -0,0 +1,47 @@
"""
A version of the entropy with a signature common to the other multivariate
measures.
"""

from .shannon import conditional_entropy, entropy
from ..utils.misc import flatten

def entropy2(dist, rvs=None, crvs=None, rv_names=None):
"""
Parameters
----------
dist : Distribution
The distribution from which the entropy is calculated.
rvs : list, None
The indexes of the random variables used to calculate the entropy. If
None, then the entropy is calculated over all random variables.
crvs : list, None
The indexes of the random variables to condition on. If None, then no
variables are conditioned on.
rv_names : bool
If `True`, then the elements of `rvs` are treated as random variable
names. If `False`, then the elements of `rvs` are treated as random
variable indexes. If `None`, then the value `True` is used if the
distribution has specified names for its random variables.

Returns
-------
H : float
The entropy.
"""
if dist.is_joint():
if rvs is None:
# Set to entropy of entire distribution
rvs = list(range(dist.outcome_length()))
rv_names = False
else:
# this allows inputs of the form [0, 1, 2] or [[0, 1], [2]],
# giving uniform behavior with the mutual information-like
# measures.
rvs = set(flatten(rvs))
if crvs is None:
crvs = []
else:
return entropy(dist)

return conditional_entropy(dist, rvs, crvs, rv_names)
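
A minimal sketch of the new signature in action; the distribution and the values in the comments are assumed for illustration, not taken from the PR:

from __future__ import division
from dit import Distribution as D
from dit.algorithms.entropy2 import entropy2

# Two perfectly correlated bits.
d = D(['00', '11'], [1/2]*2)

entropy2(d)              # 1.0: joint entropy H(X, Y)
entropy2(d, [0])         # 1.0: marginal entropy H(X)
entropy2(d, [0], [1])    # 0.0: conditional entropy H(X | Y)
entropy2(d, [[0], [1]])  # 1.0: nested rvs are flattened to [0, 1]
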
49 changes: 49 additions & 0 deletions dit/algorithms/perplexity.py
@@ -0,0 +1,49 @@
"""
The perplexity of a distribution.
"""

from .shannon import conditional_entropy, entropy
from ..utils.misc import flatten

def perplexity(dist, rvs=None, crvs=None, rv_names=None):
"""
Parameters
----------
dist : Distribution
The distribution from which the perplexity is calculated.
rvs : list, None
The indexes of the random variables used to calculate the perplexity.
If None, then the perplexity is calculated over all random variables.
crvs : list, None
The indexes of the random variables to condition on. If None, then no
variables are conditioned on.
rv_names : bool
If `True`, then the elements of `rvs` are treated as random variable
names. If `False`, then the elements of `rvs` are treated as random
variable indexes. If `None`, then the value `True` is used if the
distribution has specified names for its random variables.

Returns
-------
P : float
The perplexity.
"""

base = dist.get_base(numerical=True) if dist.is_log() else 2
[Review comment from a project member on the line above: "Nice!"]
if dist.is_joint():
if rvs is None:
# Set to perplexity of entire distribution
rvs = list(range(dist.outcome_length()))
rv_names = False
else:
# this allows inputs of the form [0, 1, 2] or [[0, 1], [2]],
# giving uniform behavior with the mutual information-like
# measures.
rvs = set(flatten(rvs))
if crvs is None:
crvs = []
else:
return base**entropy(dist)

return base**conditional_entropy(dist, rvs, crvs, rv_names)
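
The quantity implemented here is P = b**H, where H is the (conditional) entropy and b is 2 for a linear pmf or the distribution's log base otherwise, so the result does not depend on how the pmf is expressed. A worked check under that reading (distribution assumed, mirroring test_p2 below):

from __future__ import division
from dit import ScalarDistribution as SD
from dit.algorithms import perplexity

# Uniform over 6 outcomes: H = log2(6) bits, so the perplexity is
# 2**log2(6) == 6, the effective alphabet size.
d = SD([1/6]*6)
perplexity(d)  # 6.0

# Re-expressing the pmf in log base 10 rescales H to log10(6), but the
# matching base gives 10**log10(6) == 6 again.
d.set_base(10)
perplexity(d)  # still 6.0
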
44 changes: 44 additions & 0 deletions dit/algorithms/tests/test_perplexity.py
@@ -0,0 +1,44 @@
from __future__ import division

from nose.tools import *

from dit import (ScalarDistribution as SD,
Distribution as D)
from dit.algorithms import perplexity as P
from six.moves import range


def test_p1():
for i in range(2, 10):
assert_almost_equal(P(SD([1/i]*i)), i)

def test_p2():
for i in range(2, 10):
d = SD([1/i]*i)
d.set_base(i)
assert_almost_equal(P(d), i)

def test_p3():
for i in range(2, 10):
d = D([str(_) for _ in range(i)], [1/i]*i)
assert_almost_equal(P(d), i)

def test_p4():
for i in range(2, 10):
d = D([str(_) for _ in range(i)], [1/i]*i)
d.set_base(i)
assert_almost_equal(P(d), i)

def test_p5():
d = D(['00', '01', '10', '11'], [1/4]*4)
assert_almost_equal(P(d), 4)
assert_almost_equal(P(d, [0]), 2)
assert_almost_equal(P(d, [1]), 2)
assert_almost_equal(P(d, [0], [1]), 2)
assert_almost_equal(P(d, [1], [0]), 2)

def test_p6():
d = D(['00', '11'], [1/2]*2)
assert_almost_equal(P(d), 2)
assert_almost_equal(P(d, [0], [1]), 1)
assert_almost_equal(P(d, [1], [0]), 1)