In [1]:
import numpy as np
from tqdm import tqdm_notebook as tqdm

In [2]:
def initialize(P):
    """
    Calculates the (unnormalized) CDF from P as well as its total mass.

    P - sequence of non-negative integer frequencies, one per symbol
        (a list or a NumPy integer array)

    returns - C, D, R
        C[j] - total frequency of all symbols before j (lower end of j's range)
        D[j] - total frequency of symbols 0..j (upper end of j's range)
        R    - total frequency, equal to D[-1]

    All returned values are plain Python ints. The coder multiplies them by
    interval widths of up to 2**precision, and fixed-width NumPy scalars
    (e.g. from an int32 array) would overflow or raise there.

    raises ValueError if a frequency is negative or not an integer, or if the
    total frequency is zero (including an empty P).
    """

    C = []
    D = []

    c = 0

    for p in P:
        # Convert to an arbitrary-precision Python int, refusing silent truncation
        freq = int(p)
        if freq != p or freq < 0:
            raise ValueError("frequencies must be non-negative integers, got {!r}".format(p))

        C.append(c)

        c += freq

        D.append(c)

    if c <= 0:
        raise ValueError("P must contain at least one positive frequency")

    return C, D, c

In [3]:
def encoder(message, P, precision=32):
    """
    Arithmetic-encodes a sequence of symbol indices into a bit string.

    message   - sequence of symbol indices into P. `decoder` only stops once it
                decodes symbol 0, so the message should end with 0.
    P         - non-negative integer frequency of each symbol
    precision - number of bits of the integer coding interval [0, 2**precision)

    returns - code, a str of '0' and '1' characters

    raises ValueError if a symbol index is outside range(len(P)), or if a
    symbol's sub-interval is empty (zero frequency, or a total frequency too
    large for `precision`). Without this check, the rescaling loops below can
    spin forever on an empty interval.
    """

    # Calculate some stuff
    C, D, R = initialize(P)

    whole = 2**precision
    half = 2**(precision - 1)
    quarter = 2**(precision - 2)

    low = 0
    high = whole
    s = 0  # number of opposite bits still owed from middle rescalings

    chunks = []  # emitted bit strings, joined once at the end

    for k in tqdm(range(len(message))):
        symbol = message[k]

        # Reject out-of-range indices (a negative one would silently wrap in D[...])
        if not 0 <= symbol < len(C):
            raise ValueError("symbol {} at position {} is out of range".format(symbol, k))

        width = high - low

        # Find interval for next symbol
        high = low + (width * D[symbol]) // R
        low = low + (width * C[symbol]) // R

        if high <= low:
            raise ValueError(
                "symbol {} at position {} has an empty interval; its frequency is "
                "zero or precision={} is too small".format(symbol, k, precision))

        # Interval rescaling: emit a bit while [low, high) lies within one half
        while high < half or low > half:
            if high < half:
                # Lower half: emit 0 followed by the pending 1s
                chunks.append("0" + "1" * s)
                low *= 2
                high *= 2
            else:
                # Upper half: emit 1 followed by the pending 0s
                chunks.append("1" + "0" * s)
                low = (low - half) * 2
                high = (high - half) * 2
            s = 0

        # Middle rescaling: the next bit is still undecided, so defer it
        while low > quarter and high < 3 * quarter:
            s += 1
            low = (low - quarter) * 2
            high = (high - quarter) * 2

    # Final emission step: choose the quarter on whichever side of `half`
    # the final interval covers, plus the pending bits
    s += 1

    if low <= quarter:
        chunks.append("0" + "1" * s)
    else:
        chunks.append("1" + "0" * s)

    return "".join(chunks)

In [4]:
def decoder(code, P, precision=32):
    """
    Decodes a bit string produced by `encoder` back into symbol indices.

    code      - str of '0' and '1' characters; positions past the end of
                `code` are read as '0'
    P         - the integer symbol frequencies that were used for encoding
    precision - number of bits of the coding interval; must match `encoder`

    The only way this function returns is by decoding symbol 0, which is
    treated as the end-of-message marker and included as the last element.
    Messages must therefore be encoded with a trailing 0.

    returns - message, a list of symbol indices

    raises ValueError if no symbol interval contains the current code value.
    This is a defensive check; it indicates P is inconsistent, e.g. has
    negative frequencies.
    """

    # Calculate some stuff
    C, D, R = initialize(P)

    whole = 2**precision
    half = 2**(precision - 1)
    quarter = 2**(precision - 2)

    low = 0
    high = whole

    with tqdm(total=len(code)) as pbar:
        # z is a `precision`-bit window of the code, kept aligned with [low, high)
        z = 0
        i = 0  # index of the next code bit to shift into z

        while i < precision and i < len(code):
            if code[i] == '1':
                z += 2**(precision - i - 1)
            i += 1

            # Update the progress bar
            pbar.update(1)

        message = []

        while True:
            width = high - low

            # Find the symbol whose sub-interval contains z. For non-negative
            # frequencies the sub-intervals partition [low, high), so at most
            # one matches.
            for j in range(len(C)):
                high_ = low + (width * D[j]) // R
                low_ = low + (width * C[j]) // R

                if low_ <= z < high_:
                    break
            else:
                raise ValueError("no symbol interval contains the current code value")

            # Emit the current symbol and narrow the interval to it
            message.append(j)
            low, high = low_, high_

            # Symbol 0 marks the end of the message
            if j == 0:
                return message

            # Interval rescaling: zoom in while [low, high) lies within one half
            while high < half or low > half:
                if high < half:
                    # Lower half
                    low *= 2
                    high *= 2
                    z *= 2
                else:
                    # Upper half
                    low = (low - half) * 2
                    high = (high - half) * 2
                    z = (z - half) * 2

                # Shift the next code bit into z (bits past the end read as 0);
                # only real bits advance the progress bar
                if i < len(code):
                    if code[i] == '1':
                        z += 1
                    pbar.update(1)
                i += 1

            # Middle rescaling: zoom in while [low, high) lies in the middle half
            while low > quarter and high < 3 * quarter:
                low = (low - quarter) * 2
                high = (high - quarter) * 2
                z = (z - quarter) * 2

                if i < len(code):
                    if code[i] == '1':
                        z += 1
                    pbar.update(1)
                i += 1

In [5]:
P = [1, 2, 2]
message = [2, 1, 0]  # symbol 0 terminates the message for the decoder

code = encoder(message, P)

print(code)

decoded = decoder(code, P)

assert decoded == message, "round trip failed"
decoded

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


101100


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))


[2, 1, 0]


In [6]:
P = [5, 5, 50, 40]
message = [2, 3, 2, 0]

code = encoder(message, P)

print(code)

# Append arbitrary extra bits: the decoder returns as soon as it decodes the
# end-of-message symbol 0, so the trailing bits must not change the result
decoded = decoder(code + "000000101101101", P)

assert decoded == message, "trailing bits changed the decoded message"
decoded

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))


011011000


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))


[2, 3, 2, 0]


In [7]:
num_symbols = 2**10
message_length = 10000
SEED = 0

rng = np.random.default_rng(SEED)  # seeded so the test is reproducible

# Symbol 0 (end-of-message) gets frequency 1; symbols 1..num_symbols get random
# frequencies in [1, 1000]. int64 keeps products with the 2**precision-sized
# interval width from overflowing if NumPy scalars reach the coder.
P = np.ones(num_symbols + 1, dtype=np.int64)
P[1:] = rng.integers(1, 1001, size=num_symbols)

# Random message over symbols 1..num_symbols, terminated by symbol 0
message = np.zeros(message_length, dtype=np.int32)
message[:-1] = rng.integers(1, num_symbols + 1, size=message_length - 1)

code = encoder(message, P)
decoded = decoder(code, P)

# array_equal also handles a length mismatch instead of erroring on broadcast
assert np.array_equal(message, decoded), "random round trip failed"
np.array_equal(message, decoded)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=103959), HTML(value='')))




True

In [10]:
# Pad the code with zeros up to a whole number of bytes (without mutating `code`)
padded_code = code + "0" * (-len(code) % 8)

# Pack each group of 8 bits into one byte, most significant bit first
message_bytes = bytes(int(padded_code[s:s + 8], 2) for s in range(0, len(padded_code), 8))

with open("../../compression/test.miracle", "wb") as compressed_file:
    compressed_file.write(message_bytes)

In [11]:
with open("../../compression/test.miracle", "rb") as compressed_file:
    compressed_bytes = compressed_file.read()

# Expand every byte back into 8 bits, most significant bit first
compressed = ''.join(format(byte, "08b") for byte in compressed_bytes)
decompressed = decoder(compressed, P)

np.array_equal(decompressed, message)

HBox(children=(IntProgress(value=0, max=104704), HTML(value='')))




True