Python: document and test EmpiricalDistribution::shift

robamler committed Feb 12, 2024
1 parent 69967a3 commit 683071c
Showing 2 changed files with 284 additions and 8 deletions.
src/pybindings/quant.rs: 140 changes (132 additions, 8 deletions)

@@ -604,7 +604,131 @@ impl EmpiricalDistribution {
/// list must equal the dimension of axis `i` of the array provided to the constructor, and each
/// list entry specifies how the corresponding distribution is updated.
///
/// TODO: examples
/// ## Warning
///
/// When shifting multiple grid points, the shifts are applied one after the other. This can lead
/// to surprising results if any of the entries in `new` and `old` coincide. For example, one might
/// expect the following call to swap the points at positions `1.1` and `2.1`. But this is not what
/// it does. Instead, it first moves all 10 points at position `1.1` to position `2.1`, thus
/// merging them with the 20 points that are already there. It then moves the resulting 30 points
/// at `2.1` to position `1.1`, ending up with 30 points at position `1.1` and the original 70
/// points at position `7.1`.
///
/// ```python
/// original_points = np.array([1.1, 2.1, 7.1], dtype=np.float32)
/// original_counts = np.array([10, 20, 70], dtype=np.uint32)
/// distribution = constriction.quant.EmpiricalDistribution(
/// original_points, counts=original_counts)
///
/// # Warning: this does not swap the entries. It merges them into position `1.1`.
/// distribution.shift(
/// np.array([1.1, 2.1], dtype=np.float32),
/// np.array([2.1, 1.1], dtype=np.float32)
/// )
///
/// points, counts = distribution.points_and_counts()
/// print(f"points and counts after first shift: {points}, {counts}")
/// ```
///
/// This prints:
///
/// ```text
/// points and counts after shifting: [1.1 7.1], [30 70]
/// ```
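///
/// If the goal is to actually swap the two grid points, one way to do it in a single call is to
/// route one of them through a temporary position. The sketch below relies only on the sequential
/// semantics described above; the temporary position `999.` is an arbitrary value that does not
/// occur in the distribution:
///
/// ```python
/// distribution = constriction.quant.EmpiricalDistribution(
///     original_points, counts=original_counts)
///
/// # First park the points at `1.1` on the unused position `999.` so that the next shift does
/// # not merge them with the points at `2.1`, then move them to their final position.
/// distribution.shift(
///     np.array([1.1, 2.1, 999.], dtype=np.float32),
///     np.array([999., 1.1, 2.1], dtype=np.float32)
/// )
///
/// points, counts = distribution.points_and_counts()
/// print(f"points and counts after swapping: {points}, {counts}")
/// # expected: 20 points at `1.1`, 10 points at `2.1`, and the original 70 points at `7.1`
/// ```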
///
/// ## Example 1: shifting a single point
///
/// ```python
/// rng = np.random.default_rng(123)
/// matrix = rng.binomial(10, 0.3, size=(4, 5)).astype(np.float32)
/// print(f"matrix = {matrix}\n")
///
/// distribution = constriction.quant.EmpiricalDistribution(matrix)
/// points, counts = distribution.points_and_counts()
/// print(f"points and counts before shifting: {points}, {counts}")
/// print(f"entropy before shifting: {distribution.entropy_base2()}\n")
///
/// distribution.shift(2., 2.5)
/// points, counts = distribution.points_and_counts()
/// print(f"points and counts after first shift: {points}, {counts}")
/// print(f"entropy after first shift: {distribution.entropy_base2()}\n")
///
/// distribution.shift(3., 2.5)
/// points, counts = distribution.points_and_counts()
/// print(f"points and counts after second shift: {points}, {counts}")
/// print(f"entropy after second shift: {distribution.entropy_base2()}")
/// ```
///
/// This prints:
///
/// ```text
/// matrix = [[4. 1. 2. 2. 2.]
/// [4. 5. 2. 4. 5.]
/// [3. 2. 4. 2. 4.]
/// [3. 5. 2. 4. 3.]]
///
/// points and counts before shifting: [1. 2. 3. 4. 5.], [1 7 3 6 3]
/// entropy before shifting: 2.088376522064209
///
/// points and counts after first shift: [1. 2.5 3. 4. 5. ], [1 7 3 6 3]
/// entropy after first shift: 2.088376522064209
///
/// points and counts after second shift: [1. 2.5 4. 5. ], [ 1 10 6 3]
/// entropy after second shift: 1.647730827331543
/// ```
///
/// Notice that the second shift merges two grid points, thus reducing the entropy.
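///
/// The reported entropy values can be double-checked directly from the printed counts. The
/// following is a plain numpy calculation that does not use the `EmpiricalDistribution` API:
///
/// ```python
/// counts_after_merge = np.array([1, 10, 6, 3])  # counts after the second shift
/// probabilities = counts_after_merge / counts_after_merge.sum()
/// print(-np.sum(probabilities * np.log2(probabilities)))  # prints approximately 1.6477
/// ```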
///
/// ## Example 2: shifting multiple points, without `specialize_along_axis`
///
/// We can carry out the same two shifts from Example 1 with a single method call:
///
/// ```python
/// distribution.shift(
/// np.array([2., 3.], dtype=np.float32),
/// np.array([2.5, 2.5], dtype=np.float32)
/// )
/// ```
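///
/// This leaves the distribution in the same state as after the two separate shifts in Example 1.
/// The following check mirrors the accompanying unit test:
///
/// ```python
/// points, counts = distribution.points_and_counts()
/// assert np.all(points == [1., 2.5, 4., 5.])
/// assert np.all(counts == [1, 10, 6, 3])
/// ```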
///
/// ## Example 3: shifting multiple points, with `specialize_along_axis`
///
/// The following example still uses the same `matrix` as Examples 1 and 2 above:
///
/// ```python
/// distribution = constriction.quant.EmpiricalDistribution(matrix, specialize_along_axis=0)
/// points, counts = distribution.points_and_counts()
/// print(f"points before shifting: [{', '.join(str(p) for p in points)}]")
/// print(f"counts before shifting: [{', '.join(str(c) for c in counts)}]\n")
///
/// original_positions = [
/// 2., # move all `2.`s in the first row ...
/// np.array([], dtype=np.float32), # move nothing in the second row ...
/// np.array([3., 2.], dtype=np.float32), # move all `3.`s and `2.`s in the third row ...
/// np.array([3.], dtype=np.float32), # move all `3.`s in the fourth row ...
/// ]
/// target_positions = [
/// 2.1, # ... to `2.1`.
/// np.array([], dtype=np.float32), # ... to nothing.
/// np.array([3.3, 2.3], dtype=np.float32), # ... to `3.3` and `2.3`, respectively.
/// np.array([30.], dtype=np.float32), # ... to `30.`.
/// ]
///
/// distribution.shift(original_positions, target_positions)
///
/// points, counts = distribution.points_and_counts()
/// print(f"points after shifting: [{', '.join(str(p) for p in points)}]")
/// print(f"counts after shifting: [{', '.join(str(c) for c in counts)}]")
/// ```
///
/// This prints:
///
/// ```text
/// points before shifting: [[1. 2. 4.], [2. 4. 5.], [2. 3. 4.], [2. 3. 4. 5.]]
/// counts before shifting: [[1 3 1], [1 2 2], [2 1 2], [1 2 1 1]]
///
/// points after shifting: [[1. 2.1 4. ], [2. 4. 5.], [2.3 3.3 4. ], [ 2. 4. 5. 30.]]
/// counts after shifting: [[1 3 1], [1 2 2], [2 1 2], [1 1 1 2]]
/// ```
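///
/// Since none of these shifts merges any grid points within a row, the per-row entropies reported
/// by `entropy_base2()` are unchanged. This can be verified as in the accompanying unit test,
/// assuming the entropies were saved right before the call to `shift`:
///
/// ```python
/// # `entropies_before` is assumed to hold the result of `distribution.entropy_base2()`
/// # captured immediately before the call to `shift` above.
/// assert np.all(distribution.entropy_base2() == entropies_before)
/// ```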
pub fn shift(&mut self, old: &PyAny, new: &PyAny, index: Option<usize>) -> PyResult<()> {
if let Some(index) = index {
let EmpiricalDistributionImpl::Multiple { distributions, .. } = &mut self.0 else {
@@ -676,13 +800,7 @@ impl EmpiricalDistribution {
distribution: &mut crate::quant::EmpiricalDistribution,
mk_err: impl Fn() -> PyErr,
) -> Result<(), PyErr> {
if let Ok(old) = old.extract::<f32>() {
if let Ok(new) = new.extract::<f32>() {
let count = distribution.remove_all(F32::new(old)?);
distribution.insert(F32::new(new)?, count);
return Ok(());
}
} else if let Ok(old) = old.extract::<PyReadonlyArrayDyn<'_, f32>>() {
if let Ok(old) = old.extract::<PyReadonlyArrayDyn<'_, f32>>() {
if let Ok(new) = new.extract::<PyReadonlyArrayDyn<'_, f32>>() {
if old.dims() == new.dims() && old.dims().ndim() == 1 {
let old = old.as_array();
@@ -695,6 +813,12 @@ impl EmpiricalDistribution {
return Ok(());
}
}
} else if let Ok(old) = old.extract::<f32>() {
if let Ok(new) = new.extract::<f32>() {
let count = distribution.remove_all(F32::new(old)?);
distribution.insert(F32::new(new)?, count);
return Ok(());
}
}
Err(mk_err())
}
tests/python/test_docexamples_quant.py: 152 changes (152 additions, 0 deletions)

@@ -339,3 +339,155 @@ def test_points_and_counts_example2b():
points, counts=counts, specialize_along_axis=1)
reconstructed_entropies = reconstructed_distribution.entropy_base2()
assert np.all(reconstructed_entropies == original_entropies)


def test_shift_warning():
original_points = np.array([1.1, 2.1, 7.1], dtype=np.float32)
original_counts = np.array([10, 20, 70], dtype=np.uint32)
distribution = constriction.quant.EmpiricalDistribution(
original_points, counts=original_counts)

# Warning: this does not swap the entries. It merges them into position `1.1`.
distribution.shift(
np.array([1.1, 2.1], dtype=np.float32),
np.array([2.1, 1.1], dtype=np.float32)
)

points, counts = distribution.points_and_counts()
assert np.allclose(points, [1.1, 7.1])
assert np.all(counts == [30, 70])


def test_shift_example1():
rng = np.random.default_rng(123)
matrix = rng.binomial(10, 0.3, size=(4, 5)).astype(np.float32)

distribution = constriction.quant.EmpiricalDistribution(matrix)
points, counts = distribution.points_and_counts()
assert np.all(points == [1., 2., 3., 4., 5.])
assert np.all(counts == [1, 7, 3, 6, 3])
entropy1 = distribution.entropy_base2()

distribution.shift(2., 2.5)
points, counts = distribution.points_and_counts()
assert np.all(points == [1., 2.5, 3., 4., 5.])
assert np.all(counts == [1, 7, 3, 6, 3])
assert distribution.entropy_base2() == entropy1

distribution.shift(3., 2.5)
points, counts = distribution.points_and_counts()
assert np.all(points == [1., 2.5, 4., 5.])
assert np.all(counts == [1, 10, 6, 3])
assert distribution.entropy_base2() < entropy1


def test_shift_example2():
rng = np.random.default_rng(123)
matrix = rng.binomial(10, 0.3, size=(4, 5)).astype(np.float32)

distribution = constriction.quant.EmpiricalDistribution(matrix)
points, counts = distribution.points_and_counts()
assert np.all(points == [1., 2., 3., 4., 5.])
assert np.all(counts == [1, 7, 3, 6, 3])
entropy1 = distribution.entropy_base2()

distribution.shift(
np.array([2., 3.], dtype=np.float32),
np.array([2.5, 2.5], dtype=np.float32)
)

points, counts = distribution.points_and_counts()
assert np.all(points == [1., 2.5, 4., 5.])
assert np.all(counts == [1, 10, 6, 3])
assert distribution.entropy_base2() < entropy1


def test_shift_example3a():
rng = np.random.default_rng(123)
matrix = rng.binomial(10, 0.3, size=(4, 5)).astype(np.float32)

distribution = constriction.quant.EmpiricalDistribution(
matrix, specialize_along_axis=0)
points, counts = distribution.points_and_counts()
points_expected = [[1., 2., 4.], [2., 4., 5.],
[2., 3., 4.], [2., 3., 4., 5.]]
counts_expected = [[1, 3, 1], [1, 2, 2], [2, 1, 2], [1, 2, 1, 1]]
for (p, pe) in zip(points, points_expected):
assert np.allclose(p, pe)
for (c, ce) in zip(counts, counts_expected):
assert np.all(c == ce)

entropies = distribution.entropy_base2()

original_positions = [
2.,
np.array([], dtype=np.float32),
np.array([3., 2.], dtype=np.float32),
np.array([3.], dtype=np.float32),
]
target_positions = [
2.1,
np.array([], dtype=np.float32),
np.array([3.3, 2.3], dtype=np.float32),
np.array([30.], dtype=np.float32),
]

distribution.shift(original_positions, target_positions)

points, counts = distribution.points_and_counts()
points_expected = [[1., 2.1, 4.], [2., 4., 5.],
[2.3, 3.3, 4.], [2., 4., 5., 30.]]
counts_expected = [[1, 3, 1], [1, 2, 2], [2, 1, 2], [1, 1, 1, 2]]
for (p, pe) in zip(points, points_expected):
assert np.allclose(p, pe)
for (c, ce) in zip(counts, counts_expected):
assert np.all(c == ce)

assert np.all(distribution.entropy_base2() == entropies)


def test_shift_example3b():
rng = np.random.default_rng(123)
matrix = rng.binomial(10, 0.3, size=(4, 5)).astype(np.float32)

distribution = constriction.quant.EmpiricalDistribution(
matrix, specialize_along_axis=1)
points, counts = distribution.points_and_counts()

points_expected = [[3., 4.], [1., 2., 5.],
[2., 4.], [2., 4.], [2., 3., 4., 5.]]
counts_expected = [[2, 2], [1, 1, 2], [3, 1], [2, 2], [1, 1, 1, 1]]
for (p, pe) in zip(points, points_expected):
assert np.allclose(p, pe)
for (c, ce) in zip(counts, counts_expected):
assert np.all(c == ce)

entropies = distribution.entropy_base2()

original_positions = [
np.array([3.], dtype=np.float32),
1.,
np.array([2., 4.], dtype=np.float32),
np.array([], dtype=np.float32),
np.array([5., 4.3], dtype=np.float32),
]
target_positions = [
np.array([30.], dtype=np.float32),
1.,
np.array([20.3, 4.3], dtype=np.float32),
np.array([], dtype=np.float32),
np.array([50., -10.], dtype=np.float32),
]

distribution.shift(original_positions, target_positions)

points, counts = distribution.points_and_counts()
points_expected = [[4., 30.], [1., 2., 5.],
[4.3, 20.3], [2., 4.], [2., 3., 4., 50.]]
counts_expected = [[2, 2], [1, 1, 2], [1, 3], [2, 2], [1, 1, 1, 1]]
for (p, pe) in zip(points, points_expected):
assert np.allclose(p, pe)
for (c, ce) in zip(counts, counts_expected):
assert np.all(c == ce)

assert np.all(distribution.entropy_base2() == entropies)
