# Code Examples From the Paper "Understanding Entropy Coding With Asymmetric Numeral Systems (ANS): a Statistician's Perspective"

**Author:** [Robert Bamler](https://robamler.github.io/)<br>
**Date:** 22 December 2021

This notebook contains all code examples from the paper "Understanding Entropy Coding With Asymmetric Numeral Systems (ANS): a Statistician's Perspective" by Robert Bamler.

The final streaming ANS coder is implemented in Listing 7 below.
While this coder implementation works, it is really only intended as an educational demo.
For readers who are looking for an entropy coder to use in a real project, we recommend installing the [constriction package](https://bamler-lab.github.io/constriction/) with `pip install constriction`.

## Listing 1: `UniformCoder`

Use Listings 2 and 3 below to test this implementation.

In [1]:
class UniformCoder: # optimal stream code for uniformly distributed symbols
    def __init__(self, number=0):   # constructor with an optional argument
        self.number = number

    def push(self, symbol, base):   # Encodes a symbol ∈ {0, ..., base − 1}.
        self.number = self.number * base + symbol

    def pop(self, base):            # Decodes a symbol ∈ {0, ..., base − 1}.
        symbol = self.number % base # “%”  denotes modulo.
        self.number //= base        # “//” denotes integer division.
        return symbol

## Listing 2: Simple Usage Example of `UniformCoder`

In [2]:
coder = UniformCoder() # Defined in Listing 1.

# Encode the message x = (3, 6, 5):
coder.push(3, base=10) # base=10 means that the alphabet is {0, ..., 9}.
coder.push(6, base=10)
coder.push(5, base=10)
print(f"Encoded number: {coder.number}")  # Prints: “Encoded number: 365”
print(f"In binary: {coder.number:b}")     # Prints: “In binary: 101101101”

# Decode the encoded symbols (in reverse order):
print(f"Decoded '{coder.pop(base=10)}'.") # Prints: “Decoded '5'.”
print(f"Decoded '{coder.pop(base=10)}'.") # Prints: “Decoded '6'.”
print(f"Decoded '{coder.pop(base=10)}'.") # Prints: “Decoded '3'.”

Encoded number: 365
In binary: 101101101
Decoded '5'.
Decoded '6'.
Decoded '3'.


## Listing 3: Using `UniformCoder` With Inhomogeneous Alphabets

In [3]:
coder = UniformCoder() # Defined in Listing 1.

# Encode a sequence of symbols from alphabets of varying sizes:
coder.push( 3, base=10) # base=10 means that the alphabet is {0, ..., 9}.
coder.push( 6, base=10)
coder.push(12, base=15) # Setting base=15 switches to alphabet {0, ..., 14}.
coder.push( 4, base=15)
print(f"Binary: {coder.number:b}") # Prints: “Binary: 10000001011100”

# For decoding, use the same sequence of bases but in *reverse* order:
print(f"Decoded '{coder.pop(base=15)}'.") # Prints: “Decoded '4'.”
print(f"Decoded '{coder.pop(base=15)}'.") # Prints: “Decoded '12'.”
print(f"Decoded '{coder.pop(base=10)}'.") # Prints: “Decoded '6'.”
print(f"Decoded '{coder.pop(base=10)}'.") # Prints: “Decoded '3'.”

Binary: 10000001011100
Decoded '4'.
Decoded '12'.
Decoded '6'.
Decoded '3'.
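
The number printed above can be reproduced by hand, because each `push` is one step of Horner's scheme in a mixed radix. The following is a minimal sketch of that arithmetic; the helper name `encode_mixed_radix` is hypothetical and does not appear in the paper. It also compares the result against the ideal cost of log2(base) bits per symbol.

```python
import math
from functools import reduce

def encode_mixed_radix(symbols_and_bases):
    """Folds (symbol, base) pairs into one integer, like repeated UniformCoder.push."""
    return reduce(lambda number, sb: number * sb[1] + sb[0], symbols_and_bases, 0)

pairs = [(3, 10), (6, 10), (12, 15), (4, 15)]  # Same sequence as in Listing 3.
number = encode_mixed_radix(pairs)
print(number)         # ((3 * 10 + 6) * 15 + 12) * 15 + 4 = 8284
print(f"{number:b}")  # Same bit string as printed by Listing 3.

# A symbol from an alphabet of size `base` carries log2(base) bits:
ideal_bits = sum(math.log2(base) for _, base in pairs)
# The largest number that these four pushes can produce:
max_number = math.prod(base for _, base in pairs) - 1
print(f"ideal: {ideal_bits:.2f} bits; worst case: {max_number.bit_length()} bits")
```

The worst-case bit length exceeds the ideal cost by less than one bit for the whole message, which is the sense in which this stream code is optimal for uniformly distributed symbols.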


## Listing 4: `SlowAnsCoder`

Use Listing 5 below to test this implementation.

In [4]:
class SlowAnsCoder: # Has near-optimal bitrates but high runtime cost.
    def __init__(self, precision, compressed=0):
        self.n = 2**precision # See Eq. 5 (“**” denotes exponentiation).
        self.stack = UniformCoder(compressed) # Defined in Listing 1.

    def push(self, symbol, m): # Encodes one symbol.
        z = self.stack.pop(base=m[symbol]) + sum(m[0:symbol])
        self.stack.push(z, base=self.n)

    def pop(self, m):          # Decodes one symbol.
        z = self.stack.pop(base=self.n)
        # Find the unique symbol that satisfies z ∈ Z_i(symbol) (real
        # deployments should use a more efficient method than linear search):
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break
        self.stack.push(z, base=m_symbol)
        return symbol

    def get_compressed(self):
        return self.stack.number
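
The comment in `pop` notes that real deployments should replace the linear search. One possible replacement, sketched here under the assumption that the cumulative sums of `m` can be precomputed once per model, is a binary search with the standard library's `bisect` module. The helper names `make_cdf` and `find_symbol` are hypothetical and not part of the paper.

```python
from bisect import bisect_right
from itertools import accumulate

def make_cdf(m):
    """Returns the cumulative sums of m, e.g., [7, 3, 6] -> [7, 10, 16]."""
    return list(accumulate(m))

def find_symbol(z, cdf):
    """Returns (symbol, offset) such that z lies in the subrange of `symbol`."""
    symbol = bisect_right(cdf, z)  # First index whose cumulative sum exceeds z.
    left = cdf[symbol - 1] if symbol > 0 else 0
    return symbol, z - left

def find_symbol_linear(z, m):
    """The linear search from Listing 4, for comparison."""
    for symbol, m_symbol in enumerate(m):
        if z >= m_symbol:
            z -= m_symbol
        else:
            break
    return symbol, z

m = [7, 3, 6]  # Same demo model as in Listing 5.
cdf = make_cdf(m)
# Both searches agree for every possible z at precision = 4:
assert all(find_symbol(z, cdf) == find_symbol_linear(z, m) for z in range(16))
print(find_symbol(11, cdf))  # z = 11 lies in the subrange of symbol 2, at offset 1.
```

The returned offset plays the role of the reduced `z` that `pop` pushes back onto the stack with `base=m[symbol]`.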

## Listing 5: Usage Example for `SlowAnsCoder` (and for `AnsCoder`)

In [5]:
# Specify an approximated entropy model via precision and m_i(x_i) from Eq. 5:
precision = 4 # For demonstration; deployments should use higher precision.
m = [7, 3, 6] # Sets Q_i(X_i=0) = 7/16, Q_i(X_i=1) = 3/16, and Q_i(X_i=2) = 6/16.

# Encode a message in reversed order so we can decode it in forward order:
example_message = [2, 0, 2, 1, 0]
encoder = SlowAnsCoder(precision) # Also works with AnsCoder (Listing 7).
for symbol in reversed(example_message):
    encoder.push(symbol, m) # We could use a different m for each symbol.
compressed = encoder.get_compressed()

# We could actually reuse the encoder for decoding, but let's pretend that
# decoding occurs on a different machine that receives only “compressed”.
decoder = SlowAnsCoder(precision, compressed)
reconstructed = [decoder.pop(m) for _ in range(5)]
assert reconstructed == example_message # Verify correctness.

## Listing 6: Self-Contained Reimplementation of `SlowAnsCoder`

Use Listing 5 above to test this implementation.

In [6]:
class SlowAnsCoder: # Equivalent to Listing 4, just more self-contained.
    def __init__(self, precision, compressed=0):
        self.n = 2**precision # See Eq. 5 (“**” denotes exponentiation).
        self.compressed = compressed # (== stack.number in SlowAnsCoder)

    def push(self, symbol, m):        # Encodes one symbol.
        z = self.compressed % m[symbol] + sum(m[0:symbol])
        self.compressed //= m[symbol] # “//” denotes integer division.
        self.compressed = self.compressed * self.n + z

    def pop(self, m): # Decodes one symbol.
        z = self.compressed % self.n
        self.compressed //= self.n # “//” denotes integer division.
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break # We found the symbol that satisfies z ∈ Z_i(symbol).
        self.compressed = self.compressed * m_symbol + z
        return symbol

    def get_compressed(self):
        return self.compressed
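
As a rough sanity check of the "near-optimal bitrates" claim, the sketch below restates the `push` arithmetic from above as a pure function (the standalone function `push` and the variable `information_content` are illustrative names, not from the paper). It encodes the message from Listing 5 and compares the size of the resulting integer with the information content of the message under the approximated model `Q_i`.

```python
import math

precision = 4
n = 2**precision
m = [7, 3, 6]              # Same demo model as in Listing 5.
message = [2, 0, 2, 1, 0]  # Same message as in Listing 5.

def push(compressed, symbol):
    """Same arithmetic as SlowAnsCoder.push, written as a pure function."""
    z = compressed % m[symbol] + sum(m[0:symbol])
    return (compressed // m[symbol]) * n + z

compressed = 0
for symbol in reversed(message):
    compressed = push(compressed, symbol)

# Information content: sum of -log2 Q_i(x_i) with Q_i(x_i) = m[x_i] / n.
information_content = -sum(math.log2(m[symbol] / n) for symbol in message)
print(f"compressed = {compressed} ({compressed.bit_length()} bits)")  # 154 (8 bits)
print(f"information content = {information_content:.2f} bits")       # about 7.63
```

For such a short message the comparison is only indicative; the overhead relative to the information content matters less as messages get longer.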

## Listing 7: Streaming ANS Coder

The following code illustrates a full implementation of a streaming ANS coder.
While this coder works, it is really only intended as an educational demo.
Readers who want to use an entropy coder for a real project are advised to install [`constriction`](https://bamler-lab.github.io/constriction/) with `pip install constriction`.

To test this implementation, replace `SlowAnsCoder` with `AnsCoder` in Listing 5 above.

In [7]:
class AnsCoder:
    def __init__(self, precision, compressed=[]):
        self.precision = precision
        self.mask = (1 << precision) - 1 # (a string of `precision` one-bits)
        self.bulk = compressed.copy() # (We will mutate `bulk` below.)
        self.head = 0
        # Establish invariant (ii):
        while len(self.bulk) != 0 and (self.head >> precision) == 0:
            self.head = (self.head << precision) | self.bulk.pop()

    def push(self, symbol, m): # Encodes one symbol.
        # Check if encoding directly onto head would violate invariant (i):
        if (self.head >> self.precision) >= m[symbol]:
            # Transfer one word of compressed data from head to bulk:
            self.bulk.append(self.head & self.mask) # (“&” is bitwise `and`)
            self.head >>= self.precision
            # At this point, invariant (ii) is definitely violated,
            # but the operations below will restore it.

        z = self.head % m[symbol] + sum(m[0:symbol])
        self.head //= m[symbol]
        self.head = (self.head << self.precision) | z # (This is
            # equivalent to “self.head * n + z”, just slightly faster.)

    def pop(self, m): # Decodes one symbol.
        z = self.head & self.mask # (same as “self.head % n” but faster)
        self.head >>= self.precision # (same as “//= n” but faster)
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break # We found the symbol that satisfies z ∈ Z_i(symbol).
        self.head = self.head * m_symbol + z

        # Restore invariant (ii) if it is violated (which happens exactly
        # if the encoder transferred data from head to bulk at this point):
        if (self.head >> self.precision) == 0 and len(self.bulk) != 0:
            # Transfer data back from bulk to head (“|” is bitwise or):
            self.head = (self.head << self.precision) | self.bulk.pop()

        return symbol

    def get_compressed(self):
        compressed = self.bulk.copy() # (We will mutate `compressed` below.)
        head = self.head
        # Chop `head` into `precision`-sized words and append to `compressed`:
        while head != 0:
            compressed.append(head & self.mask)
            head >>= self.precision
        return compressed
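
The word-based layout used by `get_compressed` and by the constructor can be looked at in isolation. The following sketch uses two hypothetical helpers, `to_words` and `from_words`, which are not part of the paper. Unlike the constructor of `AnsCoder`, `from_words` reads *all* words instead of stopping as soon as invariant (ii) holds, so it only illustrates the bit layout and not the head/bulk split.

```python
def to_words(head, precision):
    """Chops an integer into `precision`-bit words, least significant word first."""
    mask = (1 << precision) - 1
    words = []
    while head != 0:
        words.append(head & mask)
        head >>= precision
    return words

def from_words(words, precision):
    """Inverse of to_words: reassembles the integer by popping from the end."""
    words = words.copy()  # Avoid mutating the caller's list.
    head = 0
    while words:
        head = (head << precision) | words.pop()
    return head

words = to_words(0b1110_0110_1001, precision=4)
print([f"{word:04b}" for word in words])  # ['1001', '0110', '1110']
assert from_words(words, precision=4) == 0b1110_0110_1001
```

Because the most significant word ends up at the end of the list, the decoder can recover it first with a cheap `pop()` from the end.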

## Listing 8: Random-Access Decoding with ANS

In [8]:
class SeekableAnsCoder(AnsCoder): # Adds random-access decoding to Listing 7.
    def __init__(self, precision, compressed=[]):
        super().__init__(precision, compressed)

    # Inherits push, pop, and get_compressed from Listing 7.

    def checkpoint(self): # Records a point to which we can seek later.
        return (len(self.bulk), self.head)

    def seek(self, checkpoint): # Jumps to a previously taken checkpoint.
        position, head = checkpoint
        if position > len(self.bulk): # “raise” throws an exception.
            raise ValueError("This simple demo can only seek forward.")
        self.bulk = self.bulk[0:position] # Truncates bulk.
        self.head = head

# Usage example:
precision = 4 # For demonstration; deployments should use higher precision.
m = [7, 3, 6] # Same demo model as in Listing 5.
coder = SeekableAnsCoder(precision)
message = [2, 0, 2, 1, 0, 1, 2, 2, 2, 1, 0, 2, 1, 2, 0, 0, 1, 1, 1, 2]

for symbol in reversed(message[10:20]): # Encode second half of message.
    coder.push(symbol, m)
checkpoint = coder.checkpoint()         # Record a checkpoint.
for symbol in reversed(message[0:10]):  # Encode first half of message.
    coder.push(symbol, m)

assert coder.pop(m) == message[0]       # Decode first symbol.
assert coder.pop(m) == message[1]       # Decode second symbol.
coder.seek(checkpoint)                  # Jump to 11th symbol.
assert [coder.pop(m) for _ in range(10)] == message[10:20] # Decode rest.

## Listing 9: Demonstration of Non-Local Effect of Entropy Models

In [9]:
precision = 4 # For demonstration; deployments should use higher precision.
m_orig = [7, 3, 6] # Same demo entropy model as in Listing 5.
m_mod = [6, 4, 6]  # (Slightly) modified entropy model compared to m_orig.
compressed = [0b1001, 0b1110, 0b0110, 0b1110] # Some example bit string.

# Case 1: decode 4 symbols using entropy model m_orig for all symbols:
decoder = AnsCoder(precision, compressed) # AnsCoder defined in Listing 7.
case1 = [decoder.pop(m_orig) for _ in range(4)]

# Case 2: change the entropy model, but *only* for the first symbol:
decoder = AnsCoder(precision, compressed) # “compressed” hasn't changed.
case2 = [decoder.pop(m_mod)] + [decoder.pop(m_orig) for _ in range(3)]

print(f"case1 = {case1}") # Prints: “case1 = [0, 1, 0, 2]”
print(f"case2 = {case2}") # Prints: “case2 = [1, 1, 2, 0]”
                          #                   ↑     ↑  ↑
                          #   We changed only the | But that affected both
                          # model for this symbol.| of these symbols too.

case1 = [0, 1, 0, 2]
case2 = [1, 1, 2, 0]


## Listing 10: Experimental New Coder That Overcomes the Issue Demonstrated in Listing 9

In [10]:
class ChainCoder: # Prevents the non-local effect shown in Listing 9.
    def __init__(self, precision, compressed, remainders=[]):
        """Initializes a ChainCoder for decoding from `compressed`."""
        self.precision = precision
        self.mask = (1 << precision) - 1
        self.compressed = compressed.copy() # pop decodes from here.
        self.remainders = remainders.copy() # pop encodes onto here.
        self.remainders_head = 0
        # Establish invariant (ii):
        while len(self.remainders) != 0 and \
                (self.remainders_head >> precision) == 0:
            self.remainders_head <<= self.precision
            self.remainders_head |= self.remainders.pop()

    def pop(self, m): # Decodes one symbol.
        z = self.compressed.pop() # Always read a full word from compressed.
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break # We found the symbol that satisfies z ∈ Z_i(symbol).

        self.remainders_head = self.remainders_head * m_symbol + z
        if (self.remainders_head >> (2 * self.precision)) != 0:
            # Invariant (i) is violated. Flush one word to remainders.
            self.remainders.append(self.remainders_head & self.mask)
            self.remainders_head >>= self.precision
            # It can easily be shown that invariant (i) is restored here.

        return symbol

    def push(self, symbol, m): # Encodes one symbol.
        if len(self.remainders) != 0 and \
                self.remainders_head < (m[symbol] << self.precision):
            self.remainders_head <<= self.precision
            self.remainders_head |= self.remainders.pop()
            # Invariant (i) is now violated but will be restored below.

        z = self.remainders_head % m[symbol] + sum(m[0:symbol])
        self.remainders_head //= m[symbol]
        self.compressed.append(z)

    def get_compressed(self):
        return self.compressed.copy()

    def get_remainders(self):
        remainders = self.remainders.copy() # (We will mutate `remainders` below.)
        remainders_head = self.remainders_head
        # Chop `remainders_head` into `precision`-sized words and append to `remainders`:
        while remainders_head != 0:
            remainders.append(remainders_head & self.mask)
            remainders_head >>= self.precision
        return remainders

### Usage Example

Let's replace `AnsCoder` by `ChainCoder` in the example from Listing 9:

In [11]:
precision = 4 # For demonstration; deployments should use higher precision.
m_orig = [7, 3, 6] # Same demo entropy model as in Listing 5.
m_mod = [6, 4, 6]  # (Slightly) modified entropy model compared to m_orig.
compressed = [0b1001, 0b1110, 0b0110, 0b1110] # Some example bit string.

# Case 1: decode 4 symbols using entropy model m_orig for all symbols:
decoder = ChainCoder(precision, compressed) # ChainCoder defined in Listing 10.
case1 = [decoder.pop(m_orig) for _ in range(4)]

# Case 2: change the entropy model, but *only* for the first symbol:
decoder = ChainCoder(precision, compressed) # “compressed” hasn't changed.
case2 = [decoder.pop(m_mod)] + [decoder.pop(m_orig) for _ in range(3)]

print(f"case1 = {case1}") # Prints: “case1 = [2, 0, 2, 1]”
print(f"case2 = {case2}") # Prints: “case2 = [2, 0, 2, 1]”

case1 = [2, 0, 2, 1]
case2 = [2, 0, 2, 1]


It's no surprise that we obtain a different sequence of decoded symbols than with an `AnsCoder` because we used a different entropy coding algorithm.
The important point is that `case1` and `case2` agree on the second, third, and fourth symbol.
(In the above example, they *happen* to also agree on the first symbol, but that's a coincidence.)
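
To see a case where the first symbol does change, note that the symbol returned by `ChainCoder.pop` depends only on the word it reads from `compressed` and on the model `m` for that step. The sketch below isolates this symbol lookup in a hypothetical helper `symbol_for_word` (not from the paper) and ignores the bookkeeping on `remainders`. The bit string is also hypothetical: it is the one from the example above with its last word changed to `0b0110`.

```python
def symbol_for_word(z, m):
    """The symbol lookup from ChainCoder.pop, without the remainders bookkeeping."""
    for symbol, m_symbol in enumerate(m):
        if z >= m_symbol:
            z -= m_symbol
        else:
            break
    return symbol

m_orig = [7, 3, 6]
m_mod = [6, 4, 6]
compressed = [0b1001, 0b1110, 0b0110, 0b0110]  # Hypothetical: last word changed.
words = list(reversed(compressed))  # ChainCoder.pop reads words from the end.

case1 = [symbol_for_word(z, m_orig) for z in words]
case2 = [symbol_for_word(words[0], m_mod)] + \
        [symbol_for_word(z, m_orig) for z in words[1:]]
print(f"case1 = {case1}")  # case1 = [0, 0, 2, 1]
print(f"case2 = {case2}")  # case2 = [1, 0, 2, 1]
```

Here the modified model changes the first decoded symbol, while the remaining three symbols are unaffected.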