Skip to content

Commit

Permalink
Refactor downsample(). Improve projections()
Browse files Browse the repository at this point in the history
Changes:
    1) Update API to match simplex_grid().
    2) Lower-level function now takes op to determine min/max route.
    3) projections() now works with multiple pmfs.
    4) projections() can find other nearby grid points.
  • Loading branch information
chebee7i committed Oct 28, 2014
1 parent 2566b1a commit be84e57
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 85 deletions.
172 changes: 93 additions & 79 deletions dit/math/pmfops.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def convex_combination(pmfs, weights=None):
mixture = (pmfs * weights[:, np.newaxis]).sum(axis=0)
return mixture

def downsample(pmf, depth, base=2, method='componentL1'):
def downsample(pmf, subdivisions, method='componentL1'):
"""
Returns the nearest pmf on a triangular grid.
Expand All @@ -112,12 +112,12 @@ def downsample(pmf, depth, base=2, method='componentL1'):
----------
pmf : NumPy array, shape (n,) or (k, n)
The pmf on the ``(n-1)``-simplex.
depth : int
Controls the density of the grid. The number of points on the simplex
is given by: (base**depth + length - 1)! / (base**depth)! / (length-1)!
At each depth, the number of points is exponentially increased.
base : int
The rate at which we divide probabilities..
subdivisions : int
The number of subdivisions for the interval [0, 1]. The grid considered
is such that each component will take on values at the boundaries of
the subdivisions. For example, subdivisions corresponds to
:math:`[[0, 1/2], [1/2, 1]]` and thus, each component can take the
values 0, 1/2, or 1. So one possible pmf would be (1/2, 1/2, 0).
method : str
The algorithm used to determine what `nearest` means. The default
method, 'componentL1', moves each component to its nearest grid
Expand All @@ -134,51 +134,75 @@ def downsample(pmf, depth, base=2, method='componentL1'):
"""
if method in _methods:
return _methods[method](pmf, depth, base)
return _methods[method](pmf, subdivisions)
else:
raise NotImplementedError('Unknown method.')

def downsample_componentL1(pmf, depth, base=2):
def _downsample_componentL1(pmf, i, op, locs):
"""
Low-level function to incrementally project a pmf.
Parameters
----------
pmf : NumPy array, shape (n, k)
A 2D NumPy array that is modified in-place. The columns represent
the various pmfs. The rows represent each component.
i : int
The component to be projected.
op : callable
This is np.argmin or np.argmax. It determines the projection.
locs : NumPy array
The subdivisions for each component.
"""
# Find insertion indexes
insert_index = np.searchsorted(locs, pmf[i])
# Define the indexes of clamped region for each component.
lower = insert_index - 1
upper = insert_index
clamps = np.array([lower, upper])
# Actually get the clamped region
gridvals = locs[clamps]
# Calculate distance to each point, per component.
distances = np.abs(gridvals - pmf[i])
# Determine which index each component was closest to.
# desired[i] == 0 means that the lower index was closer
# desired[i] == 1 means that the upper index was closer
desired = op(distances, axis=0)
# Pull those indexes from the clamping indexes
# So when desired[i] == 1, we want to pull the upper index.
locations = np.where(desired, upper, lower)
pmf[i] = locs[locations]
# Now renormalize the other components of the distribution...
temp = pmf.transpose() # View
prev_Z = temp[..., :i+1].sum(axis=-1)
zeros = np.isclose(prev_Z, 1)
Z = (1 - prev_Z) / temp[..., i+1:].sum(axis=-1)
temp[..., i+1:] *= Z[..., np.newaxis]
# This assumes len(shape) == 2.
temp[zeros, i+1:] = 0
return locations

def downsample_componentL1(pmf, subdivisions):
"""
Clamps each component, one-by-one.
Renormalizes and uses updated insert indexes as you go.
"""
N = base**depth
locs = np.linspace(0, 1, N + 1)
locs = np.linspace(0, 1, subdivisions + 1)

out = np.atleast_2d(pmf).transpose().copy()
# Go through each component.
# Go through each component and move to closest component.
op = np.argmin
for i in range(out.shape[0] - 1):
# Find insertion indexes
insert_index = np.searchsorted(locs, out[i])
# Define the indexes of clamped region for each component.
clamps = np.array([insert_index - 1, insert_index])
# Actually get the clamped region
gridvals = locs[clamps]
# Calculate distance to each point, per component.
distances = np.abs(gridvals - out[i])
# Determine which index each component was closest to.
desired = np.argmin(distances, axis=0)
# Pull those indexes from the clamping indexes
locations = np.where(desired, insert_index, insert_index - 1)
out[i] = locs[locations]
# Now renormalize the other components of the distribution...
temp = out.transpose() # View
prev_Z = temp[..., :i+1].sum(axis=-1)
zeros = np.isclose(prev_Z, 1)
Z = (1 - prev_Z) / temp[..., i+1:].sum(axis=-1)
temp[..., i+1:] *= Z[..., np.newaxis]
# This assumes len(shape) == 2.
temp[zeros, i+1:] = 0
locations = _downsample_componentL1(out, i, op, locs)

out = out.transpose()
out[...,-1] = 1 - out[...,:-1].sum(axis=-1)
if len(pmf.shape) == 1:
out = out[0]
return out

def clamped_indexes(pmf, depth, base=2):
def clamped_indexes(pmf, subdivisions):
"""
Returns the indexes of the component values that clamp the pmf.
Expand All @@ -187,83 +211,73 @@ def clamped_indexes(pmf, depth, base=2):
clamps : NumPy array, shape (2,n) or (2,k,n)
"""
N = base**depth
locs = np.linspace(0, 1, N + 1)
locs = np.linspace(0, 1, subdivisions + 1)
# Find insertion indexes
insert_index = np.searchsorted(locs, pmf)
# Define the indexes of clamped region for each component.
clamps = np.array([insert_index - 1, insert_index])

return clamps, locs

def projections(pmf, depth, base=2, method=None):
def projections(pmf, subdivisions, ops=None):
"""
Returns the projections on the way to the nearest grid point.
The original pmf is included in the final output.
Parameters
----------
pmf : NumPy array, shape (n,)
The pmf on the ``(n-1)``-simplex.
depth : int
Controls the density of the grid. The number of points on the simplex
is given by: (base**depth + length - 1)! / (base**depth)! / (length-1)!
At each depth, the number of points is exponentially increased.
base : int
The rate at which we divide probabilities..
pmf : NumPy array, shape (n,) or (k, n)
The pmf on the ``(n-1)``-simplex. Optionally, provide `k` pmfs.
subdivisions : int
The number of subdivisions for the interval [0, 1]. The grid considered
is such that each component will take on values at the boundaries of
the subdivisions. For example, subdivisions corresponds to
:math:`[[0, 1/2], [1/2, 1]]` and thus, each component can take the
values 0, 1/2, or 1. So one possible pmf would be (1/2, 1/2, 0).
method : str
The algorithm used to determine what `nearest` means. The default
method, 'componentL1', moves each component to its nearest grid
value using the L1 norm.
Other Parameters
----------------
ops : list
A list of `n-1` callables, where `n` the number of components in the
pmf. Each element in the list is a callable the determines how the
downsampled pmf's are constructed by specifying which of the lower
and upper clamped location indexes should be chosen. If `None`, then
`ops` is a list of `np.argmin` and will select the closest grid point.
Returns
-------
d : NumPy array, shape (n,n)
d : NumPy array, shape (n,n) or (n,k,n)
The projections leading to the downsampled pmf.
See Also
--------
downsample, dit.simplex_grid
"""
# We can only have 1 pmf.
assert(len(pmf.shape) == 1)
locs = np.linspace(0, 1, subdivisions + 1)

N = base**depth
locs = np.linspace(0, 1, N + 1)
out = np.atleast_2d(pmf).transpose().copy()
projs = [out.copy()]

out = pmf.copy()
# Go through each component.
if ops is None:
# Take closest point in regional cell.
ops = [np.argmin] * (out.shape[0] - 1)

projs = [out.copy()]
for i in range(out.shape[0] - 1):
# Find insertion indexes
insert_index = np.searchsorted(locs, out[i])
# Define the indexes of clamped region for each component.
clamps = np.array([insert_index - 1, insert_index])
# Actually get the clamped region
gridvals = locs[clamps]
# Calculate distance to each point, per component.
distances = np.abs(gridvals - out[i])
# Determine which index each component was closest to.
desired = np.argmin(distances, axis=0)
# Pull those indexes from the clamping indexes
locations = np.where(desired, insert_index, insert_index - 1)
out[i] = locs[locations]
# Now renormalize the other components of the distribution...
prev_Z = out[:i+1].sum(axis=-1)
zeros = np.isclose(prev_Z, 1)
if zeros:
Z = 0
else:
Z = (1 - prev_Z) / out[i+1:].sum(axis=-1)
out[i+1:] *= Z
# Go through each component and move to closest component.
for i, op in zip(range(out.shape[0] - 1), ops):
_downsample_componentL1(out, i, op, locs)
projs.append(out.copy())

return np.asarray(projs)

projs = np.asarray(projs)
projs = np.swapaxes(projs, 1, 2)
if len(pmf.shape) == 1:
projs = projs[:,0,:]
return projs

_methods = {
'componentL1': downsample_componentL1
Expand Down
22 changes: 16 additions & 6 deletions dit/math/tests/test_pmfops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@ def test_downsample_onepmf():
# One pmf
d1 = np.array([0, .51, .49])
d2_ = np.array([0, .5, .5])
d2 = module.downsample(d1, 1)
d2 = module.downsample(d1, 2)
np.testing.assert_allclose(d2, d2_)

def test_downsample_twopmf():
# Two pmf
d1 = np.array([[0, .51, .49], [.6, .3, .1]])
d2_ = np.array([[0, .5, .5], [.5, .5, 0]])
d2 = module.downsample(d1, 1)
d2 = module.downsample(d1, 2)
np.testing.assert_allclose(d2, d2_)

def test_downsample_badmethod():
d1 = np.array([0, .51, .49])
assert_raises(
NotImplementedError, module.downsample, d1, 3, method='whatever'
NotImplementedError, module.downsample, d1, 2**3, method='whatever'
)

def test_projections1():
Expand All @@ -66,7 +66,7 @@ def test_projections1():
[ 0. , 0.92998325, 0.07001675],
[ 0. , 0.875 , 0.125 ]
])
d2 = module.projections(d, 3)
d2 = module.projections(d, 2**3)
np.testing.assert_allclose(d2, d2_)

def test_projections2():
Expand All @@ -76,13 +76,23 @@ def test_projections2():
[ 0.5 , 0.48979592, 0.01020408],
[ 0.5 , 0.5 , 0. ]
])
d2 = module.projections(d, 3)
d2 = module.projections(d, 2**3)
np.testing.assert_allclose(d2, d2_, rtol=1e-7, atol=1e-8)

def test_projections_max():
d = np.array([ 0.51, 0.48, 0.01])
d2_ = np.array([
[ 0.51 , 0.48 , 0.01 ],
[ 0.625 , 0.36734694, 0.00765306],
[ 0.625 , 0.25 , 0.125 ]
])
d2 = module.projections(d, 2**3, [np.argmax, np.argmax, np.argmax])
np.testing.assert_allclose(d2, d2_, rtol=1e-7, atol=1e-8)

def test_clamps():
d = np.array([.51, .48, .01])
out_ = (np.array([[4, 3, 0], [5, 4, 1]]),
np.array([ 0., 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.]))
out = module.clamped_indexes(d, 3)
out = module.clamped_indexes(d, 2**3)
np.testing.assert_allclose(out[0], out_[0])
np.testing.assert_allclose(out[1], out_[1], rtol=1e-7, atol=1e-8)

0 comments on commit be84e57

Please sign in to comment.