Vbq #45

Open

wants to merge 81 commits into main

Conversation

robamler
Collaborator

Implement a fast Variational Bayesian Quantization (VBQ) with a
dynamically adjusting empirical prior.

See:

This will be used internally by
`constriction::quant::EmpiricalDistribution`.
This enforces a discipline on private fields.
This avoids confusion because both `pos` and `accum` can be used as a
"key" when searching for an entry. Accordingly, the type parameter for
`pos` is now `P`.

Also, rename `amount` to `count` to indicate that this should probably
be an integer (so that tree restructuring operations don't lead to
rounding errors).
Includes unit tests, which pass (also in Miri).
Adds tests for `left_cumulative` and `quantile_function` on empty and
almost-empty trees.
Removes the parent pointers from all nodes in an `AugmentedBTree` as
these turn out to be unnecessary for our use cases. This greatly
simplifies the implementation because it removes aliasing, so we can use
normal owned pointers (`Box`) and no longer need to deal with raw
pointers. It should also make the implementation trivially unwind safe,
which wasn't obviously the case before.
This is actually a no-op for the types for which we've used `BoundedVec`
so far, but it wouldn't be for a `BoundedVec<T, CAP>` where `T`
implements `Drop`.
Implements `BoundedVec<T, CAP>` as a wrapper around
`BoundedPairOfVecs<T, (), CAP>`, and combines the fields `separators`
and `first_child` of `NonLeafNode` into a single field whose type
enforces that the two always have the same length. This makes it a bit
easier to reason about the implementation of `NonLeafNode::insert`.
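
For concreteness, a minimal sketch of the wrapper pattern described
above (`Option` arrays stand in for what is presumably uninitialized
inline storage in the real implementation):

```rust
// Simplified sketch; names follow the commit message.
struct BoundedPairOfVecs<T1, T2, const CAP: usize> {
    len: usize,
    // Both arrays always have exactly `len` populated entries, so the
    // two sequences can never get out of sync.
    first: [Option<T1>; CAP],
    second: [Option<T2>; CAP],
}

/// `BoundedVec<T, CAP>` reuses the pair type with `()` as the second
/// element type; the `()` payloads occupy no space and compile away.
struct BoundedVec<T, const CAP: usize>(BoundedPairOfVecs<T, (), CAP>);
```
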
Many functions in `constriction` check assertions that involve
bit-length generic parameters. So far, these checks were implemented
with simple `assert!` macros. This approach did not incur any run-time
cost since the checks can be trivially evaluated at compile time and are
therefore pretty much guaranteed to be optimized away as dead code
during monomorphization (assuming that the assertions are satisfied).
However, checking assertions at run time still leaves room for
accidental misuse of a function, which can only be detected once control
flow actually reaches an incorrect function call during testing.

With this commit, we now check most assertions at compile time using
const evaluation and a trick discussed at
<https://morestina.net/blog/1940>. This leads to compile-time errors for
all incorrect function calls, even if an incorrect function call is not
reachable from any unit tests.
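
Roughly, the pattern looks as follows; this is a minimal sketch of the
trick from the linked blog post, not `constriction`'s exact helper types
or error messages:

```rust
// Minimal sketch of compile-time assertions via const evaluation.
struct AssertionHelper<const PRECISION: usize>;

impl<const PRECISION: usize> AssertionHelper<PRECISION> {
    // Evaluated during monomorphization, i.e., once per concrete
    // `PRECISION`; a failing `assert!` becomes a compile-time error.
    const PRECISION_FITS: () =
        assert!(PRECISION <= 32, "PRECISION must not exceed 32");
}

fn encode<const PRECISION: usize>() {
    // Merely referencing the associated const forces its evaluation,
    // even if this statement would otherwise be dead code.
    let () = AssertionHelper::<PRECISION>::PRECISION_FITS;
    // ... actual encoding logic would go here ...
}

fn main() {
    encode::<24>(); // compiles fine
    // encode::<33>(); // uncommenting this line breaks the build
}
```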

This is a *breaking change* as it breaks correct code in edge cases
where users validate generic parameters at run time. For example, assume
a consumer of `constriction` contains the following function to encode a
sequence of dice throws:

```rust
use constriction::{stream::model::UniformModel, UnwrapInfallible};

fn encode_dice<const MAX_PRECISION: usize>(
    dice: &[u32],
) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
    let mut coder = constriction::stream::stack::DefaultAnsCoder::new();
    if MAX_PRECISION <= 32 {
        let model = UniformModel::<u32, MAX_PRECISION>::new(6);
        coder.encode_iid_symbols_reverse(dice, model)?;
        Ok(coder.into_compressed().unwrap_infallible())
    } else {
        let model = UniformModel::<u32, 32>::new(6);
        coder.encode_iid_symbols_reverse(dice, model)?;
        Ok(coder.into_compressed().unwrap_infallible())
    }
}

fn caller() {
    let dice = [3, 2, 4, 0, 1];
    dbg!(encode_dice::<24>(&dice).unwrap()); // no problem here
    dbg!(encode_dice::<33>(&dice).unwrap()); // <-- compiler error
}
```

Here, the function `encode_dice` has a const generic parameter
`MAX_PRECISION`, which allows the caller to set the precision with which
probabilities are represented for encoding. The function prevents misuse
by checking if `MAX_PRECISION` is higher than its highest allowed value
of 32, in which case it uses precision 32 instead. The function `caller`
calls `encode_dice` with precisions 24 and 33. Before this commit, this
would have worked, and Rust would have emitted two different
monomorphized variants of `encode_dice` that each contain only the
respective relevant branch of the `if` statement. While the method
`encode_iid_symbols_reverse` contained an `assert!` macro that would
have panicked for a precision of 33, this statement was unreachable.

But with this commit, the example won't compile anymore because the
assertion is now checked at compile time. These compile-time checks
occur before dead code elimination. They are therefore also performed on
the branch that implements encoding with a precision of 33, even though
this branch would never be reached at runtime.

The above example shows that this commit is technically a breaking
change and should therefore warrant incrementing the leading version
number. In practice, however, it seems very unlikely that the
illustrated issue would arise in real-world code.
The current version compiles but is not yet complete.

Our strategy (a simplified sketch follows below):
- steal only from one neighbor (but still plan to merge three siblings
  into two if stealing doesn't resolve underflow)
- require `CAP >= 4` so we don't have to care about quite as many
  edge cases for now (we can always relax this constraint later).
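
A deliberately simplified sketch of the stealing step, with plain
`Vec<u32>`s standing in for the actual node types (the helper function
and `MIN_LEN` are hypothetical):

```rust
const MIN_LEN: usize = 2; // corresponds to CAP / 2 with `CAP >= 4`

/// Tries to fix an underfull `node` by stealing one entry from a single
/// neighbor. Returns `false` if the neighbor has nothing to spare, in
/// which case the caller merges three siblings into two nodes instead.
/// (A real B-tree would route the stolen entry through the parent's
/// separator; that detail is omitted here.)
fn resolve_underflow(node: &mut Vec<u32>, neighbor: &mut Vec<u32>) -> bool {
    if neighbor.len() > MIN_LEN {
        node.push(neighbor.pop().expect("neighbor is non-empty"));
        true
    } else {
        false
    }
}

fn main() {
    let mut node = vec![1]; // underfull: fewer than MIN_LEN entries
    let mut neighbor = vec![2, 3, 4]; // has an entry to spare
    assert!(resolve_underflow(&mut node, &mut neighbor));
    assert_eq!(node, vec![1, 4]);
}
```
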
This will allow us to reduce the amount of code duplication, e.g., for
the implementations of `{NonLeafNode, LeafNode}::remove`.

Use `PairOfBoundedVecs` for both leaf and non-leaf nodes. For non-leaf
nodes, the pair of vecs holds the list of separators and their right
child pointers. For leaf nodes, the second bounded vec contains `()`
instead of child pointers. This should completely compile away.

Tests pass, also in Miri.
These benchmarks currently need to be run with
`cargo bench --features benchmark-internals`.
This method combines `AugmentedBTree::remove` and
`AugmentedBTree::insert` into a single operation that avoids
unnecessarily repeated tree traversals if the remove and insert
positions are close to each other.

This method was motivated by the fact that shifts by small amounts are
expected to come up frequently in VBQ, and avoiding unnecessarily
repeated tree traversals might speed things up. Unfortunately, the
implementation ended up much more complicated than expected because
there turn out to be lots of edge cases. And benchmarks don't show any
speed advantage compared to simply calling `remove` followed by
`insert`. Therefore, I'll remove this combined method again in an
upcoming commit.
Both the type and the meaning of the return value changed:

- The type changed from `Option<C>` to `Result<C, NotFoundError>` based
  on experience from implementing the Python FFI.

- The wrapped value upon success is now the number of points with the
  provided `value` that were present in the distribution *before* we
  removed one. The previous implementation used to return the count
  *after* removal (i.e., one less). This change makes
  `EmpiricalDistribution::remove` consistent with
  `EmpiricalDistribution::remove_all`, for which the previous convention
  would not have been useful. (A toy illustration of the new contract
  follows below.)
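
For concreteness, a toy stand-in that follows the new contract;
`NotFoundError` and the general shape of the signature come from the
description above, while the toy type and its internals are
hypothetical:

```rust
use std::collections::BTreeMap;

#[derive(Debug, PartialEq)]
struct NotFoundError;

/// Hypothetical toy stand-in for `EmpiricalDistribution`; the real type
/// is backed by an augmented tree, not a `BTreeMap`.
struct ToyDistribution {
    counts: BTreeMap<u32, u32>,
}

impl ToyDistribution {
    /// Removes one point at `value`. On success, returns the number of
    /// points at `value` *before* removal (so the result is always >= 1).
    fn remove(&mut self, value: u32) -> Result<u32, NotFoundError> {
        let before = *self.counts.get(&value).ok_or(NotFoundError)?;
        if before == 1 {
            self.counts.remove(&value);
        } else {
            *self.counts.get_mut(&value).expect("entry exists") -= 1;
        }
        Ok(before)
    }
}

fn main() {
    let mut dist = ToyDistribution {
        counts: [(7, 2)].into_iter().collect(),
    };
    assert_eq!(dist.remove(7), Ok(2)); // count *before* removal
    assert_eq!(dist.remove(7), Ok(1));
    assert_eq!(dist.remove(7), Err(NotFoundError));
}
```
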
This has always been supported by the internal tree structure but was
not exposed by the public API because it wasn't clear whether it would
actually be useful, and we wanted to avoid committing to an
implementation that supports it. It now turns out that this is needed by
the Python API to implement `EmpiricalDistribution::update_all`, which
seems like evidence that it's a useful feature worth exposing publicly.
Still needs doc examples and tests.
Adds this functionality to both the Rust and the Python API.

See the code example in the Python documentation of `points_and_counts`.

This also fixes the method `points_and_counts` of the Python API, which
used to return numpy arrays of numpy arrays. It now returns lists of
numpy arrays, which has always been the intention.
We now first remove *all* items listed in `old` and record their counts
before inserting the items listed in `new`. This is slightly less
efficient than the previous implementation, which removed and inserted
item by item, but it makes the shifts independent of each other, thus
also allowing us to use `shift` for, e.g., swapping grid positions (see
the sketch below).
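
A minimal sketch of why the remove-all-first order makes swaps work (a
`BTreeMap` stands in for the tree-backed distribution; the free function
and its signature are illustrative, not the actual API):

```rust
use std::collections::BTreeMap;

/// Moves the counts recorded at positions `old[i]` to positions `new[i]`.
fn shift(dist: &mut BTreeMap<u32, u32>, old: &[u32], new: &[u32]) {
    // First remove *all* old items and record their counts ...
    let counts: Vec<u32> = old
        .iter()
        .map(|o| dist.remove(o).expect("old position must exist"))
        .collect();
    // ... and only then insert the recorded counts at the new
    // positions. This order is what makes swaps possible.
    for (&n, c) in new.iter().zip(counts) {
        *dist.entry(n).or_insert(0) += c;
    }
}

fn main() {
    let mut dist: BTreeMap<u32, u32> = [(1, 10), (2, 20)].into_iter().collect();
    shift(&mut dist, &[1, 2], &[2, 1]); // swap positions 1 and 2
    let expected: BTreeMap<u32, u32> = [(1, 20), (2, 10)].into_iter().collect();
    assert_eq!(dist, expected);
}
```
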
These seem reasonable but they are not confirmed by any benchmarks.
Factors out similar constructors for `EmpiricalDistribution` and
`RatedGrid` into a trait. This will hopefully allow me to avoid
repeating lots of code when implementing Python front ends for the two
sets of constructors.
Analogously, factors out similar methods for `EmpiricalDistribution`
and `RatedGrid` into a trait, again to avoid repeating code when
implementing Python front ends for these two methods.
This will allow reusing most of the code of the pybindings for `vbq`
for the pybindings for `rate_distortion_quantization`.
These should always have been generic over `CAP`; it was just an
oversight that they were only implemented for the default `CAP`.