Commit
Add moment-based maximum entropy class.
chebee7i committed Feb 11, 2015
1 parent 726e730 commit e335546
Showing 1 changed file with 106 additions and 7 deletions.
113 changes: 106 additions & 7 deletions dit/algorithms/maxentropy.py
@@ -277,7 +277,7 @@ def moment(f, pmf, center=0, n=1):
return ((f - center)**n * pmf).sum()


-def moment_constraints(pmf, n_variables, symbol_map, m, with_replacement=True):
+def moment_constraints(pmf, n_variables, m, symbol_map, with_replacement=True):
"""
Returns `A` and `b` in `A x = b`, for an Ising-like system.
@@ -297,15 +297,15 @@ def moment_constraints(pmf, n_variables, symbol_map, m, with_replacement=True):
each random variable.
n_variables : int
The number of random variables.
-symbol_map : array-like
-A mapping from the ith symbol to a real number that is to be used in
-the calculation of moments. For example, symbol_map=[-1, 1] corresponds
-to the typical Ising model.
m : int | list
The size of the moments to constrain. When `m=2`, pairwise means
are constrained to equal the pairwise means in `pmf`. When `m=3`,
three-way means are constrained to equal those in `pmf`.
If m is a list, then include all m-way moments in the list.
+symbol_map : array-like
+A mapping from the ith symbol to a real number that is to be used in
+the calculation of moments. For example, symbol_map=[-1, 1] corresponds
+to the typical Ising model.
with_replacement : bool
If `True`, variables are selected with replacement. The standard Ising
does not select with replacement, and so terms like <xx>, <yy> do not
@@ -329,7 +329,8 @@ def moment_constraints(pmf, n_variables, symbol_map, m, with_replacement=True):
d = AbstractDenseDistribution(n_variables, n_symbols)

if len(pmf) != d.n_elements:
-raise ValueError('Length of `pmf` != n_symbols ** n_variables')
+msg = 'Length of `pmf` != n_symbols ** n_variables. Symbol map: {0!r}'
+raise ValueError(msg.format(symbol_map))

# Begin with the normalization constraint.
A = [ np.ones(d.n_elements) ]
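
As a point of reference, here is a minimal sketch of calling `moment_constraints` with the reordered signature, followed by the rank reduction used in `moment_constraint_rank`. The uniform two-spin pmf, the Ising-style symbol map, and the direct import path are assumptions for illustration only.

    import numpy as np
    from dit.algorithms.maxentropy import moment_constraints, as_full_rank

    # Uniform joint pmf over two binary variables (2**2 = 4 outcomes).
    pmf = np.ones(4) / 4
    # Note the new argument order: `m` now precedes `symbol_map`.
    # m=[1, 2] requests both the single-variable and pairwise moments.
    A, b = moment_constraints(pmf, 2, [1, 2], [-1, 1])
    # Reduce to an independent set of constraints, as moment_constraint_rank does.
    C, d, rank = as_full_rank(A, b)
    print(len(b), rank)
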
@@ -392,7 +393,7 @@ def moment_constraint_rank(dist, m, symbol_map=None, cumulative=True, with_repla
if symbol_map is None:
symbol_map = range(n_symbols)

-A, b = moment_constraints(pmf, n_variables, symbol_map, mvals,
+A, b = moment_constraints(pmf, n_variables, mvals, symbol_map,
with_replacement=with_replacement)
C, d, rank = as_full_rank(A, b)

@@ -623,6 +624,60 @@ def build_linear_equality_constraints(self):
self.A = A
self.b = b

class MomentMaximumEntropy(MaximumEntropy):
"""
Find the maximum entropy distribution subject to k-way moment constraints.
k=0 should reproduce the behavior of MaximumEntropy.
"""
def __init__(self, dist, k, symbol_map, cumulative=True, with_replacement=True, prng=None):
"""
Initialize optimizer.
Parameters
----------
dist : distribution
The distribution used to specify the marginal constraints.
k : int
The number of variables in the constrained marginals.
symbol_map : list
The mapping from states to real numbers. This is used while taking
moments.
cumulative : bool
If `True`, constrain all m-way moments for m = 0, 1, ..., `k`;
otherwise, constrain only the `k`-way moments.
with_replacement : bool
If `True`, variables are selected with replacement when building the
moment constraints.
prng : random number generator, optional
Pseudorandom number generator; passed through to `MaximumEntropy`.
"""
self.k = k
self.symbol_map = symbol_map
self.cumulative = cumulative
self.with_replacement = with_replacement
super(MomentMaximumEntropy, self).__init__(dist, prng=prng)


def build_linear_equality_constraints(self):
from cvxopt import matrix

# Dimension of optimization variable
n = self.n

if self.cumulative:
k = range(self.k + 1)
else:
k = [self.k]

args = (self.pmf, self.n_variables, k, self.symbol_map)
kwargs = {'with_replacement': self.with_replacement}
A, b = moment_constraints(*args, **kwargs)
A, b, rank = as_full_rank(A, b)
if rank > n:
raise ValueError('More independent constraints than parameters.')

A = matrix(A)
b = matrix(b) # now a column vector

self.A = A
self.b = b
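
To make the class concrete, a hedged usage sketch follows. It mirrors the preparation done in `moment_maxent_dists` below; the specific two-bit distribution and the direct import path are assumptions, and cvxopt must be available for the underlying solver.

    import dit
    from dit.algorithms.maxentropy import MomentMaximumEntropy

    # A biased two-bit distribution supplying the target moments.
    d = dit.Distribution(['00', '01', '10', '11'], [0.4, 0.1, 0.1, 0.4])
    d = dit.expanded_samplespace(d, union=True)
    d.make_dense()

    # Constrain all moments up to second order, mapping '0'/'1' to -1/+1.
    opt = MomentMaximumEntropy(d, 2, [-1, 1], cumulative=True)
    pmf_opt = opt.optimize(show_progress=False)  # optimized joint pmf as an array
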


def marginal_maxent_dists(dist, k_max=None, jitter=True, show_progress=True):
"""
Return the marginal-constrained maximum entropy distributions.
@@ -657,3 +712,47 @@ def marginal_maxent_dists(dist, k_max=None, jitter=True, show_progress=True):
dists.append(d)

return dists
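
For comparison with the moment-based helper added below, a sketch of calling the existing marginal wrapper; the example distribution and the import path are assumptions.

    import dit
    from dit.algorithms.maxentropy import marginal_maxent_dists

    # Uniform distribution over the even-parity strings of length 3.
    d = dit.Distribution(['000', '011', '101', '110'], [0.25] * 4)
    # Maxent distributions matching the k-way marginals of `d`, for k = 0..3.
    dists = marginal_maxent_dists(d, show_progress=False)
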


def moment_maxent_dists(dist, symbol_map, k_max=None, jitter=True, with_replacement=True, show_progress=True):
"""
Return the moment-constrained maximum entropy distributions.
"""
dist = dit.expanded_samplespace(dist, union=True)
dist.make_dense()

if jitter:
# This is sometimes necessary. If your distribution does not have
# full support, then convergence can be difficult to come by.
dist.pmf = dit.math.pmfops.jittered(dist.pmf)

pmf = dist.pmf
n_variables = dist.outcome_length()
n_symbols = len(dist.alphabet[0])
symbols = dist.alphabet[0]

if k_max is None:
k_max = n_variables

outcomes = list(dist._product(symbols, repeat=n_variables))

if with_replacement:
text = 'with replacement'
else:
text = 'without replacement'

dists = []
for k in range(k_max + 1):
msg = "Constraining maxent dist to match {0}-way moments, {1}."
print()
print(msg.format(k, text))
print()
opt = MomentMaximumEntropy(dist, k, symbol_map, with_replacement=with_replacement)
pmf_opt = opt.optimize(show_progress=show_progress)
d = dit.Distribution(outcomes, pmf_opt)
dists.append(d)

return dists
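
Finally, a hedged end-to-end sketch of the new helper. The example distribution, the import path, and the use of `dit.shannon.entropy` for reporting are assumptions; cvxopt is required by the optimizer.

    import dit
    from dit.algorithms.maxentropy import moment_maxent_dists

    # Uniform distribution over the even-parity strings of length 3.
    d = dit.Distribution(['000', '011', '101', '110'], [0.25] * 4)
    # Maxent distributions matching the 0-, 1-, 2-, and 3-way moments of `d`,
    # with symbols '0' and '1' mapped to the Ising values -1 and +1.
    dists = moment_maxent_dists(d, [-1, 1], k_max=3, show_progress=False)
    for k, dk in enumerate(dists):
        print(k, dit.shannon.entropy(dk))
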

