In [14]:
import random
import numpy as np

# this is a practical implementation of
# Understanding Entropy Coding With Asymmetric
# Numeral Systems (ANS): a Statistician’s Perspective


def quantize_pmf(pmf, precision):
     n = 1 << precision
     assert n > 2 * len(pmf)
     qt_pmf = [0 for i in range(len(pmf))]
     
     sum = 0 
     for i in range(len(pmf)):
         val = int(np.around(pmf[i] * n))
         if val == 0:
             val = val + 1
         qt_pmf[i] = val
         sum += val
     
     diff = n - sum
     
     if diff < 0:
         while diff != 0:
             for i in range(len(pmf)):
                 if qt_pmf[i] != 1:
                     qt_pmf[i] = qt_pmf[i] - 1
                     diff += 1
                 if diff == 0:
                     break
    
     elif diff > 0:
         while diff != 0:
             for i in range(len(pmf)):
                 if qt_pmf[i] != 1:
                     qt_pmf[i] = qt_pmf[i] + 1
                     diff -= 1
                 if diff == 0:
                     break

     return qt_pmf


class StreamANS(object):
    def __init__(self, precision):
        self.precision = precision
        self.n = 1 << self.precision
        self.mask = self.n - 1
        self.disk = []
        self.head = self.n + int(random.randint(0, self.n - 1))
    
    def flush(self):
        # self.disk.append(self.head % self.n)
        # self.head = self.head // self.n
        self.disk.append(self.head & self.mask)
        self.head = self.head >> self.precision

    def load(self):
        # self.head = (self.head * self.n) + self.disk.pop()
        self.head = (self.head << self.precision) | self.disk.pop()

    def push(self, sym, freq):
        freq_sym = freq[sym]

        # if (self.head // self.n) >= freq[sym]:
        if (self.head >> self.precision) >= freq_sym:
            self.flush()

        u = self.head % freq_sym
        self.head = self.head // freq_sym
        offset = sum(freq[0:sym])
        z = u + offset
        self.head = self.head * self.n + z

    def search_cdfs(self, z, freq):
        sym, freq_sym, offset_sym = None, None, 0
        for s, freq_s in enumerate(freq):
            if offset_sym <= z and z < offset_sym + freq_s:
                sym = s
                freq_sym = freq_s
                break
            
            offset_sym = offset_sym + freq_s
        return sym, freq_sym, offset_sym


    def pop(self, freq):
        # z = self.head % self.n
        z = self.head & self.mask

        # self.head = self.head // self.n
        self.head = self.head >> self.precision
        sym, freq_sym, offset_sym = self.search_cdfs(z, freq)
        u = z - offset_sym
        self.head = self.head * freq_sym + u

        # if (self.head // self.n) == 0 and len(self.disk) != 0:
        if (self.head >> self.precision) == 0 and len(self.disk) != 0:
            self.load()

        return sym

    def get_compressed(self):
        compressed = self.disk.copy()
        head = self.head
        while head != 0:
            # compressed.append(head % self.n)
            compressed.append(head & self.mask)

            # self.head = head // self.n
            head = head >> self.precision
        
        return compressed


In [15]:
def src_seq(n, p):
    return np.random.choice(len(p), size=n, p=p).tolist()
    
def randpmf(size):
    pmf = np.abs(np.random.randn(size))
    pmf = pmf / pmf.sum()
    return pmf


precision = 32
alphabet_size = 256
pmf = randpmf(alphabet_size)
freq = quantize_pmf(pmf, precision)

message_length = 100
msg = src_seq(message_length, pmf)


coder = StreamANS(precision)


for i in range(len(msg)):
    coder.push(msg[i], freq)


print('sender message : {}'.format(msg[0:25]))

code = coder.get_compressed()

codelength = 0

for pack in code:
    codelength += pack.bit_length()

print('{:-^100}'.format(''))
print('ans codelength: {}'.format(codelength))
print('optimal codelength: {}'.format(message_length * (pmf * (-1) * np.log2(pmf)).sum()))
print('{:-^100}'.format(''))

dec = []
for i in reversed(range(len(msg))):
    dec.append(coder.pop(freq))

dec.reverse()
print('receiv message : {}'.format(dec[0:25]))


print('all matches : {}'.format(dec == msg))

sender message : [188, 67, 116, 53, 93, 249, 64, 173, 184, 67, 24, 1, 242, 166, 66, 189, 38, 202, 24, 86, 199, 95, 139, 93, 183]
----------------------------------------------------------------------------------------------------
ans codelength: 772
optimal codelength: 759.6923000371964
----------------------------------------------------------------------------------------------------
receiv message : [188, 67, 116, 53, 93, 249, 64, 173, 184, 67, 24, 1, 242, 166, 66, 189, 38, 202, 24, 86, 199, 95, 139, 93, 183]
all matches : True


In [16]:
message_length = 100
num_trials = 200
avg_codelength = 0

for _ in range(num_trials):

    msg = src_seq(message_length, pmf)
    coder = StreamANS(precision)
    for i in range(len(msg)):
        coder.push(msg[i], freq)
    
    code = coder.get_compressed()
    codelength = 0
    for pack in code:
        codelength += pack.bit_length()
    avg_codelength += codelength

avg_codelength = avg_codelength / num_trials
print(avg_codelength)

767.515
