In [1]:
import sys

In [2]:
sys.path.append("../../craystack/")

In [362]:
from rpc.codecs import ClusterCodec, UniformScalarCodec, ROCSortedListCodec, \
   BigUniformScalarCodec, UniformScalarCodec, UniformCodec, BigUniform
from rpc.rans import initialize_ans_state, compute_ans_state_size_in_bytes
from craystack.rans import flatten, unflatten
from sortedcontainers import SortedList

from rpc.rans import uniform_ans_decode, uniform_ans_encode
from craystack.rans import (
    push_with_finer_prec_uniform,
    pop_with_finer_prec_uniform)
from warnings import warn


In [13]:
import numpy as np

In [668]:
!realpath ../../craystack/

/private/home/matthijs/src/NeuralCompressionInternal/craystack


# Set up message 

In [558]:
rs = np.random.RandomState(123)

message = rs.choice(1<<27, size=200, replace=False)

In [559]:
message

array([ 59772995,  75292784, 102499427,  31057994, 115010580,  72426854,
        37756040,   3116445,  85225731,   3230095,  69496873,  37637763,
       104700345,   5649531,   9774730, 105241309, 122165381,  18313249,
        23050223,  95635329,   6709951, 114134362, 115026227,  99015361,
         1452319,  63497982,  98950678,  92912211, 125153849, 124345505,
        67589739,  11260686, 110230693, 115215129,  77483839, 108209701,
       127180119, 103930548,  91773555,  44634222, 101582688, 106721226,
        41434894,  18368807,   5664077,  36866035,  75210853, 104601341,
        51933754, 106073827,  86697163,  67067391,  75753566,  46515579,
        89724764, 125486197,  75686324,  68888975,  85096561, 133683491,
        80480950, 113677183, 116434843,    986322,  34209404,  46335715,
        10930208, 125964652,  98079629,  12209478,  36758074,  14612921,
        54210001,  61337183,  42460378,  20011719,  78927977,  95372569,
        90487549,  55813753,  27995001, 128287181, 

# Reference run

In [560]:
ans_state = initialize_ans_state()
codec = ClusterCodec(cluster_size=len(message))
ans_state = codec.encode(message, ans_state)
# print(flatten(ans_state))

_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 1

In [561]:
rans_l = np.uint32(1 << 31)  # the lower bound of the normalisation interval


In [562]:

ans_state = (np.full(1, rans_l, "uint64"), ())
codec = ClusterCodec(cluster_size=len(message))
ans_state = codec.encode(message, ans_state)
# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 2

In [564]:
ans_state = (np.full(1, rans_l, "uint64"), ())
roc_codec = ROCSortedListCodec(
    set_size=len(message),
    symbol_codec=BigUniformScalarCodec(log_prec=np.uint32(64)),
    copy_input=False,
)

sorted_seq = SortedList(message)
ans_state = roc_codec.encode(sorted_seq, ans_state)
# codec = ClusterCodec(cluster_size=len(message))
# ans_state = codec.encode(message, ans_state)
# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 3

In [566]:
ans_state = (np.full(1, rans_l, "uint64"), ())

symbol_codec=BigUniformScalarCodec(log_prec=np.uint32(64))

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)
    ans_state, index = UniformScalarCodec(prec).decode(ans_state)

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
# ans_state = roc_codec.encode(sorted_seq, ans_state)
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 4

In [567]:
ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, index = UniformScalarCodec(prec).decode(ans_state)
    vectorized_codec = UniformCodec([prec])
    ans_state, symbols = vectorized_codec.decode(ans_state)
    index = symbols[0]
    
    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 5

In [568]:
ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    
    # vectorized_codec = UniformCodec(prec)
    # ans_state, symbols = vectorized_codec.decode(ans_state)

    ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    index = symbols[0]

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 6

In [569]:
ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform(ans_state, [prec], atleast_1d=True)
    ans_state = pop(symbols)
    index = symbols[0]

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 7

In [570]:
atleast_1d_func = lambda x: np.atleast_1d(x).astype("uint64")
rng = np.random.default_rng(0)

def stack_slice(stack, n):
    slc = []
    while n > 0:
        if stack:
            arr, stack = stack
        else:
            warn("Popping from empty message. Generating random data.")
            arr, stack = rng.integers(1 << 32, size=n, dtype="uint32"), ()
        if n >= len(arr):
            slc.append(arr)
            n -= len(arr)
        else:
            slc.append(arr[:n])
            stack = arr[n:], stack
            break
    return stack, np.concatenate(slc)


def stack_extend(stack, arr):
    return arr, stack

In [571]:
def pop_with_finer_prec_uniform_2(ans_state, precisions, atleast_1d: bool = True):
    if atleast_1d:
        precisions = atleast_1d_func(precisions)

    head_, tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    idxs = head_ >= precisions * ((rans_l // precisions) << np.uint8(32))
    if np.any(idxs):
        tail_ = stack_extend(tail_, np.uint32(head_[idxs]))
        head_ = np.copy(head_)  # Ensure no side-effects
        head_[idxs] >>= 32

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r
    # TODO(dsevero): might be better to make the ANS head a scalar
    cfs = head_ % precisions
    if not atleast_1d:
        cfs = cfs[0]

    def pop(symbols):
        if atleast_1d:
            symbols = atleast_1d_func(symbols)

        # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
        head = (head_ // precisions) + cfs - symbols

        # check which entries need renormalizing
        idxs = head < rans_l

        # how many 32*n bits do we need from the tail?
        n = np.sum(idxs)
        if n > 0:
            # new_head = 32*n bits from the tail
            # tail = previous tail, with 32*n less bits
            tail, new_head = stack_slice(tail_, n)
            print(n, new_head)

            # update LSBs of head, where needed
            head[idxs] = (head[idxs] << 32) | new_head
        else:
            tail = tail_
        return head, tail

    return cfs, pop

In [572]:
ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_2(ans_state, [prec], atleast_1d=True)
    
    ans_state = pop(symbols)
    index = symbols[0]

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

1 [3653403231]
1 [113639424]
1 [17629184]
1 [128253952]
1 [88932352]
1 [7208960]
1 [25821184]
1 [23003136]
1 [113836032]
1 [5242880]
1 [46333952]
1 [27983872]
1 [78905344]
1 [59768832]
1 [75169792]
1 [106037248]
1 [45809664]
1 [115212288]
1 [113639424]
1 [34144256]
1 [19988480]
1 [104660992]
1 [48234496]
1 [73334784]
1 [19922944]
1 [95944704]
1 [82509824]
1 [35651584]
1 [17825792]
1 [57409536]
1 [4653056]
1 [14090240]
1 [38141952]
1 [67043328]
1 [82182144]
1 [114950144]
1 [88735744]
1 [78315520]
1 [116391936]


  warn("Popping from empty message. Generating random data.")


# Step 8

In [573]:

def pop_with_finer_prec_uniform_3(ans_state, precisions, atleast_1d: bool = True):
    if atleast_1d:
        precisions = atleast_1d_func(precisions)

    assert len(precisions) == 1
    precision = precisions[0]
    
    head_, tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    # idxs = head_ >= precisions * ((rans_l // precisions) << np.uint8(32))
    # if np.any(idxs):
    head_0 = head_[0]
    if head_0 >= precision * ((rans_l // precision) << np.uint8(32)):
        tail_ = stack_extend(tail_, np.array([head_0], dtype='uint32'))
        head_ = np.copy(head_)  # Ensure no side-effects
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r
    # TODO(dsevero): might be better to make the ANS head a scalar

    cfs = head_0 % precision 
    
    def pop(symbols):
        if atleast_1d:
            symbols = atleast_1d_func(symbols)

        # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
        head = (head_0 // precision) + cfs - symbols

        # check which entries need renormalizing
        if head_0 < rans_l:
            # new_head = 32*n bits from the tail
            # tail = previous tail, with 32*n less bits
            tail, new_head = stack_slice(tail_, 1)

            # update LSBs of head, where needed
            # head[idxs] = (head[idxs] << 32) | new_head
            head = (head << np.uint8(32)) | new_head[0]
        else:
            tail = tail_
        return head, tail

    return cfs, pop

In [575]:
ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_3(ans_state, [prec])
    
    ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = symbol_codec.encode(symbol, ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new) == set(message)

# Step 9

In [577]:
#symbol_codec=BigUniformScalarCodec(log_prec=np.uint32(64))
symbol_codec_vectorized_codec = BigUniformScalarCodec(log_prec=np.uint32(64))

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_3(ans_state, [prec])
    
    ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    # ans_state = symbol_codec.encode(symbol, ans_state)
    ans_state = symbol_codec_vectorized_codec.encode(np.array([symbol]), ans_state)

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

  symbol = sorted_seq.pop(int(index))


# Step 10

In [578]:
symbol_codec_vectorized_codec = BigUniform(np.uint32(64))

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_3(ans_state, [prec])
    
    ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    # ans_state = symbol_codec.encode(symbol, ans_state)
    # ans_state = symbol_codec_vectorized_codec.encode(np.array([symbol]), ans_state)
    (ans_state,) = symbol_codec_vectorized_codec.push(ans_state, np.array([symbol]))    

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

# Step 11

In [579]:
from craystack import vrans

In [581]:
_uniform_enc_statfun = lambda s: (s, 1)

def codec_push_1(message, symbol, precision): 
    start, freq = _uniform_enc_statfun(symbol)
    return vrans.push(message, start, freq, precision),

def codec_push(message, symbol, precision):
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        diff = np.where(precision >= lower, precision - lower, 0)
        p = np.minimum(diff, 16)
        
        # message, = Uniform(p).push(message, s)
        message, = codec_push_1(message, s, p)
    return message,    

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_3(ans_state, [prec])
    
    ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    # ans_state = symbol_codec.encode(symbol, ans_state)
    # ans_state = symbol_codec_vectorized_codec.encode(np.array([symbol]), ans_state)
    (ans_state,) = codec_push(ans_state, np.array([symbol]), precision=np.uint32(64))    

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

# Step 12

In [582]:

def vrans_push(x, starts, freqs, precisions):
    starts, freqs, precisions = map(atleast_1d_func, (starts, freqs, precisions))
    head, tail = x
    # assert head.shape == starts.shape == freqs.shape
    idxs = head >= ((rans_l >> precisions) << 32) * freqs
    if np.any(idxs):
        tail = stack_extend(tail, np.uint32(head[idxs]))
        head = np.copy(head)  # Ensure no side-effects
        head[idxs] >>= 32
    head_div_freqs, head_mod_freqs = np.divmod(head, freqs)
    return (head_div_freqs << precisions) + head_mod_freqs + starts, tail


def codec_push(message, symbol, precision):
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        diff = np.where(precision >= lower, precision - lower, 0)
        p = np.minimum(diff, 16)
        
        # message, = Uniform(p).push(message, s)
        message = vrans_push(message, s, 1, p)
    return message    

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, pop = pop_with_finer_prec_uniform_3(ans_state, [prec])
    
    ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = codec_push(ans_state, np.array([symbol]), precision=np.uint32(64))    

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

# Step 13

In [609]:
def stack_slice_13(stack):
    slc = []
    print("AAAA")
    assert n == 1
    if stack:
        arr, stack = stack
    else:
        warn("Popping from empty message. Generating random data.")
        arr, stack = rng.integers(1 << 32, size=n, dtype="uint32"), ()
    return stack, arr

def stack_extend(stack, arr):
    return arr, stack


def pop_with_finer_prec_uniform_13(ans_state, precision):
    precision = np.uint64(precision)
    
    head_, tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    head_0 = head_[0]
    if head_0 >= precision * ((rans_l // precision) << np.uint8(32)):
        tail_ = stack_extend(tail_, np.array([head_0], dtype='uint32'))
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r


    cfs = head_0 % precision 

    symbol = cfs
        
    # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
    head = (head_0 // precision) # + cfs - symbol
    # check which entries need renormalizing
    if head_0 < rans_l:
        # new_head = 32*n bits from the tail
        # tail = previous tail, with 32*n less bits
        tail, new_head = stack_slice_13(tail_)

        # update LSBs of head, where needed
        # head[idxs] = (head[idxs] << 32) | new_head
        head = (head << np.uint8(32)) | new_head[0]
    else:
        tail = tail_
    # return np.array([head], dtype='uint64'), tail

    return cfs, (np.array([head], dtype='uint64'), tail)


def vrans_push_13(x, start, freq, precision):
    start = int(start)
    freq = int(freq)
    precision = int(precision)
    head, tail = x
    head = int(head[0])
    if head >= ((rans_l >> precision) << 32) * freq:
        tail = stack_extend(tail, np.array([head & ((1<<32) - 1)], dtype=np.uint32))
        head >>= 32

    assert freq == 1
    head_div_freq, head_mod_freq = head, 0
    head_2 = (head_div_freq << precision) + head_mod_freq + start
    return np.array([head_2], dtype='uint64'), tail
   
def codec_push_13(message, symbol, precision):
    symbol = int(symbol) 
    precision = int(precision)
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        p = min(max(precision - lower, 0), 16)        
        message = vrans_push_13(message, s, 1, p)
    return message    

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    prec = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    symbols, ans_state = pop_with_finer_prec_uniform_13(ans_state, prec)
    
    # ans_state = pop(symbols)
    index = symbols

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = codec_push_13(ans_state, symbol, precision=64)    

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

# Step 14

In [640]:
def stack_slice_14(stack):
    slc = []
    print("AAAA")
    assert n == 1
    if stack:
        arr, stack = stack
    else:
        warn("Popping from empty message. Generating random data.")
        arr, stack = rng.integers(1 << 32, size=n, dtype="uint32"), ()
    return stack, arr

def stack_extend(stack, arr):
    return arr, stack


def pop_with_finer_prec_uniform_14(ans_state, precision):
    precision = np.uint64(precision)
    
    head_, tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    head_0 = head_[0]
    if head_0 >= precision * ((rans_l // precision) << np.uint8(32)):
        tail_ = stack_extend(tail_, np.array([head_0], dtype='uint32'))
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r


    cfs = head_0 % precision 

    symbol = cfs
        
    # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
    head = (head_0 // precision) # + cfs - symbol
    # check which entries need renormalizing
    if head_0 < rans_l:
        # new_head = 32*n bits from the tail
        # tail = previous tail, with 32*n less bits
        tail, new_head = stack_slice_14(tail_)

        # update LSBs of head, where needed
        # head[idxs] = (head[idxs] << 32) | new_head
        head = (head << np.uint8(32)) | new_head[0]
    else:
        tail = tail_
    # return np.array([head], dtype='uint64'), tail

    return cfs, (np.array([head], dtype='uint64'), tail)


def vrans_push_14(x, start, freq, precision):
    start = int(start)
    freq = int(freq)
    precision = int(precision)
    head, tail = x
    head = int(head[0])
    if head >= ((rans_l >> precision) << 32) * freq:
        tail = stack_extend(tail, np.array([head & ((1<<32) - 1)], dtype=np.uint32))
        head >>= 32

    assert freq == 1
    head_div_freq, head_mod_freq = head, 0
    head_2 = (head_div_freq << precision) + head_mod_freq + start
    return np.array([head_2], dtype='uint64'), tail
   
def codec_push_14(message, symbol, precision):
    symbol = int(symbol) 
    precision = int(precision)
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        p = min(max(precision - lower, 0), 16)        
        message = vrans_push_13(message, s, 1, p)
    return message    

ans_state = (np.full(1, rans_l, "uint64"), ())

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    nmax = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    index, ans_state = pop_with_finer_prec_uniform_14(ans_state, nmax)
    
    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    ans_state = codec_push_14(ans_state, symbol, precision=64)    

# print(flatten(ans_state))
_, new = codec.decode(ans_state)
assert set(new.ravel()) == set(message)

# Step 15

In [657]:
class ANSState: 

    def __init__(self): 
        self.head = np.array([rans_l], dtype='uint64')
        self.tail = ()

    def extend_stack(self, x): 
        x = np.array([x], dtype=np.uint32)
        self.tail = x, self.tail

    def set_head(self, x): 
        self.head = np.array([x], dtype='uint64')

    def stack_slice(self): 
        if self.tail: 
            arr, stack = self.tail
            self.tail = stack 
            return arr
        else: 
            warn("Popping from empty message. Generating random data.")
            return rng.integers(1 << 32, size=n, dtype="uint32")

    def get_head(self): 
        return self.head[0]


def pop_with_finer_prec_uniform_15(ans_state, precision):
    precision = np.uint64(precision)
    
    # head_ = ans_state.head
    # tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    # head_0 = head_[0]
    head_0 = ans_state.get_head()
    if head_0 >= precision * ((rans_l // precision) << np.uint8(32)):
        # tail_ = stack_extend(tail_, np.array([head_0], dtype='uint32'))
        # stack_extend_15(ans_state, np.array([head_0], dtype='uint32'))
        ans_state.extend_stack(head_0)
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r

    cfs = head_0 % precision 

    symbol = cfs
        
    # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
    head = (head_0 // precision) # + cfs - symbol
    # check which entries need renormalizing
    if head_0 < rans_l:
        # new_head = 32*n bits from the tail
        # tail = previous tail, with 32*n less bits
        # new_head = stack_slice_15(ans_state)
        new_head = ans_state.stack_slice()

        # update LSBs of head, where needed
        # head[idxs] = (head[idxs] << 32) | new_head
        head = (head << np.uint8(32)) | new_head[0]
        
    # return np.array([head], dtype='uint64'), tail
    ans_state.set_head(head)
    return cfs


def vrans_push_15(ans_state, start, freq, precision):
    start = int(start)
    freq = int(freq)
    precision = int(precision)
    # head, tail = x
    head = int(ans_state.get_head())
    if head >= ((rans_l >> precision) << 32) * freq:
        # stack_extend_15(ans_state, np.array([head & ((1<<32) - 1)], dtype=np.uint32))
        ans_state.extend_stack(head & ((1<<32) - 1))
        head >>= 32

    assert freq == 1
    head_div_freq, head_mod_freq = head, 0
    head_2 = (head_div_freq << precision) + head_mod_freq + start
    ans_state.set_head(head_2)
       
def codec_push_15(ans_state, symbol, precision):
    symbol = int(symbol) 
    precision = int(precision)
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        p = min(max(precision - lower, 0), 16)        
        vrans_push_15(ans_state, s, 1, p)

ans_state = ANSState()

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    nmax = np.uint32(set_size - i)

    # ans_state, symbols = uniform_ans_decode(ans_state, [prec])
    index = pop_with_finer_prec_uniform_15(ans_state, nmax)
    
    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    codec_push_15(ans_state, symbol, precision=64)    

# print(flatten(ans_state))
_, new = codec.decode((ans_state.head, ans_state.tail))
assert set(new.ravel()) == set(message)

# Step 16

In [667]:
class ANSState: 

    def __init__(self): 
        self.head = np.array([rans_l], dtype='uint64')
        self.tail = ()

    def extend_stack(self, x): 
        x = np.array([x], dtype=np.uint32)
        self.tail = x, self.tail

    def set_head(self, x): 
        self.head = np.array([x], dtype='uint64')

    def stack_slice(self): 
        if self.tail: 
            arr, stack = self.tail
            self.tail = stack 
            return arr
        else: 
            warn("Popping from empty message. Generating random data.")
            return rng.integers(1 << 32, size=n, dtype="uint32")

    def get_head(self): 
        return self.head[0]

    def get_head_tail(self): 
        return self.head, self.tail 


def pop_with_finer_prec_uniform_16(ans_state, nmax):
    nmax = np.uint64(nmax)

    head_0 = ans_state.get_head()
    if head_0 >= nmax * ((rans_l // nmax) << np.uint8(32)):
        print("AAA")
        ans_state.extend_stack(head_0)
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r
    cfs = head_0 % nmax 

    # symbol = cfs
        
    # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
    head = head_0 // nmax
    # check which entries need renormalizing
    if head_0 < rans_l:
        print("BBB")
        # tail = previous tail, with 32*n less bits
        new_head = ans_state.stack_slice()

        # update LSBs of head, where needed
        head = (head << np.uint8(32)) | new_head[0]
        
    ans_state.set_head(head)
    return cfs


def vrans_push_16(ans_state, start, precision):
    start = int(start)
    precision = int(precision)
    head = int(ans_state.get_head())
    if head >= ((rans_l >> precision) << 32):
        ans_state.extend_stack(head & ((1<<32) - 1))
        head >>= 32

    head_2 = (head << precision) + start
    ans_state.set_head(head_2)
       
def codec_push_16(ans_state, symbol, precision):
    symbol = int(symbol) 
    precision = int(precision)
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        p = min(max(precision - lower, 0), 16)        
        vrans_push_16(ans_state, s, p)

ans_state = ANSState()

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    nmax = np.uint32(set_size - i)

    index = pop_with_finer_prec_uniform_16(ans_state, nmax)
    
    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    codec_push_16(ans_state, symbol, precision=64)    

# print(flatten(ans_state))
_, new = codec.decode(ans_state.get_head_tail())
assert set(new.ravel()) == set(message)

# Step 14 -- missed

In [630]:
class ANSState: 

    def __init__(self): 
        self.head = np.uint64(rans_l)
        self.tail = []

    def extend_stack(self, x): 
        arr = np.uint32(x)
        self.tail.append(arr)

    def set_head(self, x): 
        self.head = np.uint64(x)

    def stack_slice_1(self): 
        if len(self.stack) == 0: 
            warn("Popping from empty message. Generating random data.")
            return rng.integers(1 << 32, size=n, dtype="uint32")
        else: 
            return self.stack.pop(-1)

    def to_nested_representation(self): 
        tt = ()
        for t in reversed(self.tail): 
            tt = np.array([t]), tt
        return np.array([self.head]), tt

In [656]:
def stack_slice_14(stack):
    slc = []
    print("AAAA")
    assert n == 1
    if stack:
        arr, stack = stack
    else:
        warn("Popping from empty message. Generating random data.")
        arr, stack = rng.integers(1 << 32, size=n, dtype="uint32"), ()
    return stack, arr


def stack_extend_14(ansstate, arr):
    ansstate.extend_stack(arr)


def pop_with_finer_prec_uniform_14(ans_state, precision):
    precision = np.uint64(precision)

    
    # head_, tail_ = ans_state
    # head_ in [2 ^ 32, 2 ^ 64)
    head_0 = ans_state.head
    if head_0 >= precision * ((rans_l // precision) << np.uint8(32)):
        print("BBBB")
        ans_state.extend_stack(head_0)
        head_0 >>= np.uint8(32)

    # head in [precisions * (2 ^ 32 // precisions), precisions * ((2 ^ 32 // precisions) << 32))
    # s' mod 2^r
    cfs = head_0 % precision 

    symbol = cfs
        
    # calculate previous state  s = p*(s' // 2^r) + (s' % 2^r) - c
    head = (head_0 // precision) # + cfs - symbol
    # check which entries need renormalizing
    if head_0 < rans_l:
        print("AAAA")
        # new_head = 32*n bits from the tail
        # tail = previous tail, with 32*n less bits
        new_head = ans_state.stack_slice_1(ans_state)

        # update LSBs of head, where needed
        # head[idxs] = (head[idxs] << 32) | new_head
        head = (head << np.uint8(32)) | new_head
    # return np.array([head], dtype='uint64'), tail

    ans_state.set_head(head)
    return cfs # , (np.array([head], dtype='uint64'), tail)


def vrans_push_14(ans_state, start, precision):
    start = int(start)
    precision = int(precision)
    # put 32 lsb of head in the tail 
    head = int(ans_state.head)
    if head >= ((rans_l >> precision) << 32):
        ans_state.extend_stack(head & ((1<<32) - 1))
        head >>= 32

    head_2 = (head << precision) + start
    ans_state.set_head(head_2)
   
def codec_push_14(message, symbol, precision):
    symbol = int(symbol) 
    precision = int(precision)
    # encode 16 by 16 bits
    for lower in [0, 16, 32, 48]:
        s = (symbol >> lower) & ((1 << 16) - 1)
        p = min(max(precision - lower, 0), 16)        
        vrans_push_14(ans_state, s, p)

ans_state = ANSState()

sorted_seq = SortedList(message)
set_size = len(sorted_seq)
for i in range(set_size):
    # Sample/Decode, without replacement, an index using ANS.
    # Initialize a uniform codec for the indices.
    maxval = np.uint32(set_size - i)

    index, pop_with_finer_prec_uniform_14(ans_state, maxval)

    # `index` is NDArray[uint], need to cast to int to pick the element.
    symbol = sorted_seq.pop(int(index))

    # Encode the element into the ans state.
    codec_push_14(ans_state, symbol, precision=64)    

# print(flatten(ans_state))
_, new = codec.decode(ans_state.to_nested_representation())
assert set(new.ravel()) == set(message)

  head = int(ans_state.head)


AttributeError: 'ANSState' object has no attribute 'to_nested_representation'

In [634]:
ans_state.to_nested_representation()

(array([13193052160], dtype=uint64),
 (array([215089167], dtype=uint32),
  (array([3755398540], dtype=uint32),
   (array([689897494], dtype=uint32),
    (array([4251646922], dtype=uint32),
     (array([1703936], dtype=uint32),
      (array([1133650421], dtype=uint32),
       (array([3014656], dtype=uint32),
        (array([4130340910], dtype=uint32),
         (array([3457999310], dtype=uint32),
          (array([2375876655], dtype=uint32),
           (array([3099478415], dtype=uint32),
            (array([3211264], dtype=uint32),
             (array([3827687055], dtype=uint32),
              (array([4653056], dtype=uint32),
               (array([2020147277], dtype=uint32),
                (array([2788355731], dtype=uint32),
                 (array([2093613136], dtype=uint32),
                  (array([3526424115], dtype=uint32),
                   (array([5242880], dtype=uint32),
                    (array([3567727739], dtype=uint32),
                     (array([5636096], dtype=uint3

In [669]:
1+2

3