In [1]:
import random
import numpy as np

# this is a practical implementation of
# Understanding Entropy Coding With Asymmetric
# Numeral Systems (ANS): a Statistician’s Perspective


def quantize_pmf(pmf, precision):
     n = 1 << precision
     assert n > 2 * len(pmf)
     qt_pmf = [0 for i in range(len(pmf))]
     
     sum = 0 
     for i in range(len(pmf)):
         val = int(np.around(pmf[i] * n))
         if val == 0:
             val = val + 1
         qt_pmf[i] = val
         sum += val
     
     diff = n - sum
     
     if diff < 0:
         while diff != 0:
             for i in range(len(pmf)):
                 if qt_pmf[i] != 1:
                     qt_pmf[i] = qt_pmf[i] - 1
                     diff += 1
                 if diff == 0:
                     break
    
     elif diff > 0:
         while diff != 0:
             for i in range(len(pmf)):
                 if qt_pmf[i] != 1:
                     qt_pmf[i] = qt_pmf[i] + 1
                     diff -= 1
                 if diff == 0:
                     break

     return qt_pmf

In [2]:
class StreamANS(object):
    def __init__(self, precision, word_size, head_size, compressed=[]):

        assert word_size >= precision
        assert head_size >= word_size + precision

        # precision = log2(sum(freq))
        # word_size = a (maximal) size of items in bulks
        # head_size = maximal size of head

        self.precision = precision
        self.word_size = word_size
        self.head_size = head_size

        self.n = 1 << self.precision
        self.w = 1 << self.word_size
        self.prec_mask = self.n - 1
        self.word_mask = self.w - 1
        self.bulk = compressed.copy()

        if len(self.bulk) == 0:
            # initial bits
            self.head = self.n + random.randint(0, self.n - 1)

        else:
            self.head = 0
            while len(self.bulk) != 0 and (self.head >> (self.head_size - self.word_size)) == 0:
                self.head = (self.head << self.word_size) | self.bulk.pop()
    
    def search_cdfs(self, z, freq):
        sym, freq_sym, offset_sym = None, None, 0
        for s, freq_s in enumerate(freq):
            if offset_sym <= z and z < offset_sym + freq_s:
                sym = s
                freq_sym = freq_s
                break
            
            offset_sym = offset_sym + freq_s
        return sym, freq_sym, offset_sym

    def push(self, sym, freq):
        freq_sym = freq[sym]
        if (self.head >> (self.head_size - self.precision)) >= freq_sym:
            self.bulk.append(self.head & self.word_mask)
            self.head = self.head >> self.word_size
            # print(self.bulk[-1], self.head)

        u = self.head % freq_sym
        offset = sum(freq[0:sym])
        z = u + offset
        # print(z, u)
        self.head = self.head // freq_sym
        self.head = (self.head << self.precision) + z

    def pop(self, freq):
        z = self.head & self.prec_mask
        self.head = self.head >> self.precision
        sym, freq_sym, offset_sym = self.search_cdfs(z, freq)
        u = z - offset_sym
        # print(z, u)

        self.head = self.head * freq_sym + u

        # check
        
        if len(self.bulk) > 0:
            if (self.head >> (self.head_size - self.word_size)) == 0:
                # print(self.bulk[-1], self.head)
                self.head = (self.head << self.word_size) | self.bulk.pop()

        return sym


    def get_compressed(self):
        compressed = self.bulk.copy()
        head = self.head
        while head != 0:
            compressed.append(head & self.word_mask)
            head = head >> self.word_size
        
        return compressed


In [3]:
def src_seq(n, p):
    return np.random.choice(len(p), size=n, p=p).tolist()
    
def randpmf(size):
    pmf = np.abs(np.random.randn(size))
    pmf = pmf / pmf.sum()
    return pmf


precision = 24
word_size = 32
head_size = 64
assert word_size >= precision
assert head_size >= word_size + precision

alphabet_size = 8
pmf = randpmf(alphabet_size)
freq = quantize_pmf(pmf, precision)

message_length = 100
msg = src_seq(message_length, pmf)


coder = StreamANS(precision, word_size, head_size)


for i in range(len(msg)):
    coder.push(msg[i], freq)


code = coder.get_compressed()

codelength = 0


for pack in code:
    codelength += pack.bit_length()

print('{:-^100}'.format(''))
print('ans codelength: {}'.format(codelength))
print('optimal codelength: {}'.format(message_length * (pmf * (-1) * np.log2(pmf)).sum()))
print('{:-^100}'.format(''))



dec = []
for i in reversed(range(len(msg))):
    dec.append(coder.pop(freq))

dec.reverse()
print('sender message : {}'.format(msg[0:25]))

print('receiv message : {}'.format(dec[0:25]))


print('all matches : {}'.format(dec == msg))

----------------------------------------------------------------------------------------------------
ans codelength: 289
optimal codelength: 264.21490024286203
----------------------------------------------------------------------------------------------------
sender message : [6, 3, 5, 2, 5, 3, 6, 4, 6, 1, 2, 7, 6, 4, 2, 4, 3, 4, 5, 2, 4, 2, 3, 3, 3]
receiv message : [6, 3, 5, 2, 5, 3, 6, 4, 6, 1, 2, 7, 6, 4, 2, 4, 3, 4, 5, 2, 4, 2, 3, 3, 3]
all matches : True


In [4]:
coder2 = StreamANS(precision, word_size, head_size, code)

dec = []
for i in reversed(range(len(msg))):
    dec.append(coder2.pop(freq))

dec.reverse()
print('decoded message : {}'.format(dec[0:25]))


decoded message : [6, 3, 5, 2, 5, 3, 6, 4, 6, 1, 2, 7, 6, 4, 2, 4, 3, 4, 5, 2, 4, 2, 3, 3, 3]
