In [1]:
import random
import numpy as np


def quantize_pmf(pmf, precision):
    """Quantize a probability mass function to integer frequencies.

    Each mass is scaled by ``2 ** precision`` and rounded; entries that round
    to zero are bumped to 1. The rounding error is then spread over the
    entries in round-robin order until the frequencies sum to exactly
    ``2 ** precision``.

    Args:
        pmf: 1-D array-like of probabilities (expected to be non-negative
            and to sum to ~1).
        precision: Number of bits; the returned frequencies sum to
            ``1 << precision``.

    Returns:
        ``np.int64`` array with the same shape as ``pmf`` whose entries sum
        to ``1 << precision``. For non-negative input every entry is >= 1.

    Raises:
        AssertionError: if ``2 ** precision`` is not larger than twice the
            number of symbols.
        ValueError: if ``pmf`` is not 1-D, or if the missing mass cannot be
            distributed because every quantized entry equals 1 (e.g. an
            all-zero or empty ``pmf``).
    """
    n = 1 << precision
    pmf = np.asarray(pmf)
    if pmf.ndim != 1:
        raise ValueError('pmf must be one-dimensional')
    assert n > 2 * len(pmf)

    # np.int64 instead of np.long: np.long is absent from NumPy 1.24-1.26 and
    # maps to the platform C long (32-bit on Windows) in NumPy 2.x.
    qt_pmf = np.zeros(pmf.shape, dtype=np.int64)
    for i in range(len(pmf)):
        val = int(np.around(pmf[i] * n))
        if val == 0:
            val = 1  # keep every symbol representable
        qt_pmf[i] = val

    # Positive diff: mass is missing; negative diff: too much was assigned.
    diff = n - int(qt_pmf.sum())

    # Entries equal to 1 are never adjusted. If mass is missing and there is
    # no other entry, the redistribution below could never finish.
    if diff > 0 and not np.any(qt_pmf != 1):
        raise ValueError(
            'cannot quantize pmf: every entry rounds to at most 1, '
            'so the frequencies cannot be made to sum to 2**precision')

    step = 1 if diff > 0 else -1
    while diff != 0:
        # Round-robin: move one unit per adjustable entry per pass.
        for i in range(len(qt_pmf)):
            if qt_pmf[i] != 1:
                qt_pmf[i] += step
                diff -= step
            if diff == 0:
                break

    return qt_pmf


class UniformCoder(object):
    """LIFO coder that packs uniformly distributed symbols into one integer.

    The state ``code`` is treated as a mixed-radix number: ``push`` appends a
    digit ``sym`` in radix ``base`` and ``pop`` removes the most recently
    pushed digit, provided it is called with the same ``base``.

    The state is always held as a Python ``int`` so it can grow without
    bound. Operands are coerced with ``int()`` because numpy fixed-width
    integers (e.g. ``np.int64`` frequencies) would otherwise turn ``code``
    into a numpy scalar that wraps around once it exceeds 63 bits.
    """

    def __init__(self, code=0):
        """Start from an existing integer state (0 for an empty stack)."""
        self.code = int(code)

    def push(self, sym, base):
        """Append ``sym`` (an integer with ``0 <= sym < base``) to the state."""
        sym = int(sym)
        base = int(base)
        assert 0 <= sym < base
        self.code = int(self.code) * base + sym

    def pop(self, base):
        """Remove and return the last digit of the state in radix ``base``."""
        base = int(base)
        code = int(self.code)
        sym = code % base
        self.code = code // base
        return sym


class SlowANS(object):
    """Stack-based entropy coder built on an integer ``UniformCoder`` state.

    A symbol is encoded by picking a uniformly random point inside the
    symbol's interval of the cumulative frequency table and pushing that
    point in radix ``sum(freqs)``. Decoding pops the point and looks up the
    interval it falls into. Symbols are decoded in the reverse order of
    encoding, using the same frequency tables.
    """

    def __init__(self, code):
        """Create a coder whose stack starts from the integer ``code``."""
        self.stack = UniformCoder(code)

    def push(self, sym, freqs):
        """Encode ``sym`` under the integer frequency table ``freqs``.

        Args:
            sym: Symbol index, ``0 <= sym < len(freqs)``.
            freqs: 1-D array-like of integer frequencies.

        Raises:
            AssertionError: if ``sym`` is outside the table.
        """
        freqs = np.asarray(freqs)
        assert 0 <= sym < len(freqs)
        # Plain Python ints keep the stack state arbitrary-precision; numpy
        # int64 values would make it wrap once it exceeds 63 bits.
        n = int(freqs.sum())
        cdfs = freqs.cumsum()
        offset = 0 if sym == 0 else int(cdfs[sym - 1])
        sym_range = int(freqs[sym])
        # Any point of [offset, offset + sym_range) identifies the symbol.
        z = offset + random.randint(0, sym_range - 1)
        self.stack.push(z, base=n)

    def pop(self, freqs):
        """Decode and return the most recently pushed symbol as an ``int``.

        ``freqs`` must be the table that was used when the symbol was pushed.
        """
        freqs = np.asarray(freqs)
        n = int(freqs.sum())
        cdfs = freqs.cumsum()
        z = int(self.stack.pop(base=n))
        # side='right' maps z to the first interval whose upper bound exceeds it.
        return int(np.searchsorted(cdfs, z, 'right'))


class NearOptimalSlowANS(object):
    """Stack-based entropy coder that reuses bits already on the stack.

    Instead of drawing a fresh random point inside the symbol's interval,
    ``push`` pops that point (radix ``freqs[sym]``) from the existing state,
    and ``pop`` pushes it back after decoding. A push followed by a pop with
    the same table therefore restores the state exactly.
    """

    def __init__(self, code):
        """Create a coder whose stack starts from the integer ``code``."""
        self.stack = UniformCoder(code)

    def push(self, sym, freqs):
        """Encode ``sym`` under the integer frequency table ``freqs``.

        Args:
            sym: Symbol index, ``0 <= sym < len(freqs)``.
            freqs: 1-D array-like of integer frequencies.

        Raises:
            AssertionError: if ``sym`` is outside the table.
        """
        freqs = np.asarray(freqs)
        assert 0 <= sym < len(freqs)
        # Plain Python ints keep the stack state arbitrary-precision; numpy
        # int64 values would make it wrap once it exceeds 63 bits.
        n = int(freqs.sum())
        cdfs = freqs.cumsum()
        offset = 0 if sym == 0 else int(cdfs[sym - 1])

        # Take the position inside the symbol's interval from the stack itself.
        z = int(self.stack.pop(base=int(freqs[sym]))) + offset

        self.stack.push(z, base=n)

    def pop(self, freqs):
        """Decode and return the most recently pushed symbol as an ``int``.

        ``freqs`` must be the table that was used when the symbol was pushed.
        """
        freqs = np.asarray(freqs)
        n = int(freqs.sum())
        cdfs = freqs.cumsum()
        z = int(self.stack.pop(base=n))
        # side='right' maps z to the first interval whose upper bound exceeds it.
        sym = int(np.searchsorted(cdfs, z, 'right'))
        offset = 0 if sym == 0 else int(cdfs[sym - 1])
        # Return the within-interval position that push() took from the stack.
        self.stack.push(z - offset, base=int(freqs[sym]))
        return sym


In [2]:
def randpmf(size):
    """Return a random probability vector of length ``size``.

    The entries are absolute values of standard-normal draws from numpy's
    global random state, divided by their total so they sum to one.
    """
    magnitudes = np.absolute(np.random.randn(size))
    return magnitudes / magnitudes.sum()


msg = [0, 0, 1, 4, 7]

# One alphabet size and one quantization precision (in bits) per message
# position. Any precision works as long as 2**precision > 2 * alphabet size
# (asserted inside quantize_pmf). `precisions` previously held values that
# were never used; it now holds the precisions actually applied below.
alphabet_sizes = [5, 15, 100, 64, 17]
precisions = [4, 8, 15, 10, 7]

# generate a random quantized pmf for each symbol position
freqs = [quantize_pmf(randpmf(size), precision=bits)
         for size, bits in zip(alphabet_sizes, precisions)]

# every message symbol must exist in the alphabet of its position
assert all(0 <= sym < len(table) for sym, table in zip(msg, freqs))



In [3]:
# Encode the whole message onto an empty stack, one table per position.
coder = SlowANS(0)
for i, sym in enumerate(msg):
    coder.push(sym, freqs[i])

print('code: {}'.format(coder.stack.code))

# The coder is a stack, so symbols come back last-in first-out: walk the
# positions backwards and prepend each result to restore message order.
dec = []
for i in range(len(msg) - 1, -1, -1):
    dec.insert(0, coder.pop(freqs[i]))

print(dec)

code: 2237730799279
[0, 0, 1, 4, 7]


In [4]:
# Seed the stack with a small random integer, then encode the message.
init_bits = random.randint(0, len(freqs[0]) - 1)
new_coder = NearOptimalSlowANS(init_bits)
for i, sym in enumerate(msg):
    new_coder.push(sym, freqs[i])

print('new code: {}'.format(new_coder.stack.code))

# Decode last-in first-out: walk the positions backwards and prepend each
# result to restore message order.
dec = []
for i in range(len(msg) - 1, -1, -1):
    dec.insert(0, new_coder.pop(freqs[i]))

print(dec)

new code: 1871790
[0, 0, 1, 4, 7]
