In [1]:
from decimal import Decimal, getcontext
from fractions import Fraction
import math

In [101]:
class OptimizedFrequencyTable:
    """Symbol counts plus probabilities quantized down to powers of two.

    Attributes:
        counts: dict mapping each symbol to its occurrence count.
        total: sum of all counts.
        freqs: dict mapping each symbol to its quantized probability
            (a ``Fraction`` of the form ``1 / 2**k``), ordered like ``symbols``.
        freq_ranges: dict mapping each symbol to ``(lower, width)``, where
            ``lower`` is the running sum of the widths that precede it.
    """

    def __init__(self, freqs):
        """Build the table from a ``{symbol: count}`` mapping."""
        # Copy so that increment() never mutates the mapping owned by the caller.
        self.counts = dict(freqs)
        self.total = sum(self.counts.values())
        self.freqs = self.__generate_freqs()
        self.freq_ranges = self.__generate_freqs_range()

    def increment(self, symbol):
        """Add one occurrence of ``symbol`` and rebuild both derived tables."""
        self.counts[symbol] = self.counts.get(symbol, 0) + 1
        self.total += 1

        self.freqs = self.__generate_freqs()
        self.freq_ranges = self.__generate_freqs_range()

    def __generate_freqs(self, round=True):
        """Return per-symbol probabilities in ``symbols`` order.

        With ``round=False`` the exact ratios ``count / total`` are returned.
        Otherwise each ratio is floored to a power of two and the slack that
        this leaves below 1 is handed back by doubling symbols, in order.
        """
        freqs = {symbol: Fraction(self.counts[symbol], self.total) for symbol in self.symbols}
        if not round:
            return freqs

        freqs = {symbol: self.floor_quantize(prob) for symbol, prob in freqs.items()}

        prob_remainder = Fraction(1, 1) - sum(freqs.values())
        for symbol, prob in freqs.items():
            if prob_remainder <= 0:
                break
            # Strict comparison kept from the original model: a symbol whose
            # probability equals the remaining slack exactly is not doubled.
            # Encoder and decoder must agree on this rule.
            if prob < prob_remainder:
                prob_remainder -= prob
                freqs[symbol] = prob * 2

        return freqs

    def __generate_freqs_range(self):
        """Return ``{symbol: (lower, width)}`` by accumulating ``self.freqs``."""
        freq_ranges = {}

        lower = Fraction(0, 1)
        for symbol, prob in self.freqs.items():
            freq_ranges[symbol] = (lower, prob)
            lower += prob

        return freq_ranges

    @property
    def symbols(self):
        """Symbols sorted by descending count; ties keep ``counts`` insertion order."""
        return sorted(self.counts, key=lambda x: self.counts[x], reverse=True)

    def to_bytes(self):
        """Serialize as one count byte followed by 5-byte records.

        Each record is one symbol byte plus a 4-byte big-endian count, which
        is the layout ``from_bytes`` reads.

        Raises:
            ValueError: if there are more than 255 symbols, or a symbol is
                not a single character in the 0-255 code point range
                (``UnicodeEncodeError`` is a ``ValueError`` subclass).
            OverflowError: if a count does not fit in 4 bytes.
        """
        # bytes([n]) is always exactly one byte; chr(n).encode() would emit
        # two UTF-8 bytes for n >= 128 and shift every record that follows.
        encoded = bytearray(bytes([len(self.counts)]))
        for symbol, freq in self.counts.items():
            # latin-1 maps code points 0-255 to single bytes, mirroring the
            # chr(byte) lookup performed by from_bytes.
            raw = symbol.encode('latin-1')
            if len(raw) != 1:
                raise ValueError(f'symbol {symbol!r} is not a single character')
            encoded += raw
            encoded += freq.to_bytes(4, byteorder='big')
        return bytes(encoded)

    @staticmethod
    def from_bytes(encoded):
        """Rebuild a table from the layout produced by ``to_bytes``.

        Only the first ``1 + 5 * encoded[0]`` bytes are read; anything after
        them is ignored.
        """
        freqs = {}

        count = encoded[0]
        for record in range(count):
            start = 1 + 5 * record
            freqs[chr(encoded[start])] = int.from_bytes(encoded[start + 1:start + 5], byteorder='big')

        return OptimizedFrequencyTable(freqs)

    @staticmethod
    def from_text(text):
        """Count every symbol of ``text`` in a single pass.

        Symbols are recorded in order of first appearance, so the resulting
        table (and the tie order of ``symbols``) is the same on every run.
        """
        freqs = {}
        for symbol in text:
            freqs[symbol] = freqs.get(symbol, 0) + 1

        return OptimizedFrequencyTable(freqs=freqs)

    @staticmethod
    def build_simple(num_symbols):
        """Return a table holding ``chr(0) .. chr(num_symbols - 1)``, each with count 1."""
        freqs = {chr(num): 1 for num in range(num_symbols)}

        return OptimizedFrequencyTable(freqs)

    @staticmethod
    def floor_quantize(number):
        """Return the largest power of two that does not exceed ``number``.

        The exponent is derived from integer bit lengths instead of a float
        ``log2``, so the result is exact for any numerator/denominator size.

        Raises:
            ValueError: if ``number`` is not positive.
        """
        number = Fraction(number)
        if number <= 0:
            raise ValueError('floor_quantize() requires a positive number')

        # The bit-length difference equals floor(log2(number)) or exceeds it by one.
        exponent = number.numerator.bit_length() - number.denominator.bit_length()
        power = Fraction(2) ** exponent
        if power > number:
            power /= 2
        return power

In [102]:
from bisect import bisect_right

class ArithmeticCoder:
    """Arithmetic coder that works on exact ``Fraction`` values.

    Both directions take a ``freqs`` object exposing ``freq_ranges``, a dict
    mapping each symbol to ``(lower, width)`` in ascending ``lower`` order.
    """

    def encode(self, freqs):
        """Generator-based encoder driven through ``send``.

        Protocol: prime it with ``send(None)``, send every symbol in turn,
        then send ``None`` once more. The final ``send`` raises
        ``StopIteration`` whose ``value`` is the lower bound of the interval
        that was narrowed down.
        """
        lower, delta = Fraction(0, 1), Fraction(1, 1)

        symbol = yield
        while symbol is not None:
            # Looked up on every step, so a table updated between sends is
            # the one that applies to the next symbol.
            symbol_lower, symbol_width = freqs.freq_ranges[symbol]
            lower += symbol_lower * delta
            delta *= symbol_width

            symbol = yield

        return lower

    def decode(self, encoded, length, freqs):
        """Yield ``length`` symbols recovered from the number ``encoded``.

        ``freqs.freq_ranges`` is re-read before every symbol, so the caller
        may update the table between two yielded symbols.
        """
        for _ in range(length):
            ranges = freqs.freq_ranges
            # Take the symbols and their lower bounds from the same mapping so
            # the bisect index always points at the matching symbol.
            candidates = list(ranges)
            lower_bounds = [bounds[0] for bounds in ranges.values()]
            symbol = candidates[bisect_right(lower_bounds, encoded) - 1]

            # Rescale the number into the chosen symbol's sub-interval.
            lower, delta = ranges[symbol]
            encoded = (encoded - lower) / delta

            yield symbol

In [110]:
def calculate_precision(number):
    """Return the smallest ``p`` such that ``256 ** p >= number``.

    Integers are handled with exact bit arithmetic; a float ``log2`` loses
    the low bits of large integers and can come out one byte short.
    Non-integer inputs still go through ``math.log2``.

    Raises:
        ValueError: if ``number`` is an integer smaller than 1 (``math.log2``
            raises the same error for non-positive non-integers).
    """
    if not isinstance(number, int):
        return math.ceil(math.log2(number) / 8)
    if number < 1:
        raise ValueError('calculate_precision() requires a positive number')

    # (number - 1).bit_length() equals ceil(log2(number)) for every number >= 1.
    bits = (number - 1).bit_length()
    return (bits + 7) // 8

def encode_fraction(fraction):
    """Serialize a non-negative ``Fraction`` as a fixed-point byte string.

    Layout: an 8-byte big-endian ``precision`` followed by the big-endian
    integer ``int(fraction * 256 ** precision)`` in as few bytes as it needs
    (none at all when that integer is 0).

    ``precision`` is chosen so that ``256 ** precision`` is at least the
    denominator. The ``int()`` call truncates, so a denominator that does
    not divide ``256 ** precision`` loses the fractional remainder.
    """
    precision = calculate_precision(fraction.denominator)
    encoded_number = int((2 ** (precision * 8)) * fraction)

    # Size the payload from the integer's exact bit length: 0 needs no bytes
    # and an exact power of 256 needs one byte more than log2(n) / 8 suggests.
    payload_size = (encoded_number.bit_length() + 7) // 8

    return precision.to_bytes(8, byteorder='big') + encoded_number.to_bytes(payload_size, byteorder='big')

def decode_fraction(encoded):
    """Turn an 8-byte precision header plus integer payload back into a ``Fraction``.

    The first 8 bytes hold ``precision`` (big-endian); the remaining bytes
    hold a big-endian integer that is divided by ``2 ** (precision * 8)``.
    """
    header, payload = encoded[:8], encoded[8:]

    precision = int.from_bytes(header, byteorder='big')
    scaled_value = int.from_bytes(payload, byteorder='big')
    scale = 2 ** (precision * 8)

    return Fraction(scaled_value, scale)

In [104]:
from tqdm import tqdm

def arithmetic_encode(input_path, output_path):
    """Compress the text file at ``input_path`` into ``output_path``.

    The output is the serialized frequency table, then the text length as an
    8-byte big-endian integer, then the encoded fraction.
    """
    with open(input_path) as source:
        text = source.read()

    table = OptimizedFrequencyTable.from_text(text)
    coder = ArithmeticCoder().encode(table)

    # Advance the generator to its first `yield` before feeding symbols.
    next(coder)
    for symbol in tqdm(text, desc='Encoding'):
        coder.send(symbol)

    # A final None ends the stream; the result arrives as StopIteration.value.
    try:
        coder.send(None)
    except StopIteration as finished:
        result = finished.value
    else:
        return

    header = table.to_bytes() + len(text).to_bytes(8, byteorder='big')
    with open(output_path, 'wb') as sink:
        sink.write(header + encode_fraction(result))

def arithmetic_decode(input_path, output_path):
    """Restore the text written by ``arithmetic_encode`` and save it to ``output_path``."""
    with open(input_path, 'rb') as source:
        blob = source.read()

    table = OptimizedFrequencyTable.from_bytes(blob)

    # One count byte plus a 5-byte record per symbol precede the length field.
    length_start = blob[0] * 5 + 1
    number_start = length_start + 8
    length = int.from_bytes(blob[length_start:number_start], byteorder='big')
    number = decode_fraction(blob[number_start:])

    symbols = ArithmeticCoder().decode(number, length, table)
    decoded = ''.join(tqdm(symbols, desc='Decoding'))

    with open(output_path, 'w') as sink:
        sink.write(decoded)

In [106]:
# Load the whole corpus as raw bytes; the demo cells below slice and decode it.
with open('dickens.txt', 'rb') as file:
    data = file.read()

In [111]:
# Round-trip check for the static coder on the first 100000 bytes of `data`.
# NOTE(review): slicing bytes before decode() can cut a multi-byte character
# in half if the corpus is not pure ASCII - confirm for other inputs.
text = data[:100000].decode()
with open('a.txt', 'w') as file:
    file.write(text)

# a.txt -> b.txt (compressed, binary) -> c.txt (restored text)
arithmetic_encode('a.txt', 'b.txt')
arithmetic_decode('b.txt', 'c.txt')

with open('b.txt', 'rb') as file:
    encoded = file.read()
with open('c.txt') as file:
    decoded = file.read()

# Report the compressed size in bytes and whether the restored text matches.
print(f'Encoded size: {len(encoded)}')
print(f'Successfully: {decoded == text}')

Encoding: 100%|██████████| 100000/100000 [10:39:22<00:00,  2.61it/s]
Decoding: 100000it [4:38:52,  5.98it/s] 

Encoded size: 57195
Successfully: True





In [96]:
from tqdm import tqdm

def adaptive_arithmetic_encode(input_path, output_path):
    """Compress ``input_path`` with a model that is updated after every symbol.

    The model starts from ``build_simple(num_symbols=256)``, so no frequency
    table is stored: the output is the text length as an 8-byte big-endian
    integer followed by the encoded fraction.
    """
    with open(input_path, 'r') as source:
        text = source.read()

    model = OptimizedFrequencyTable.build_simple(num_symbols=256)
    coder = ArithmeticCoder().encode(model)

    # Advance the generator to its first `yield` before feeding symbols.
    next(coder)
    for symbol in tqdm(text, desc='Encoding'):
        # Encode with the current model first, then let the model learn the symbol.
        coder.send(symbol)
        model.increment(symbol)

    # A final None ends the stream; the result arrives as StopIteration.value.
    try:
        coder.send(None)
    except StopIteration as finished:
        result = finished.value
    else:
        return

    length_field = len(text).to_bytes(8, byteorder='big')
    with open(output_path, 'wb') as sink:
        sink.write(length_field + encode_fraction(result))

def adaptive_arithmetic_decode(input_path, output_path):
    """Restore the text written by ``adaptive_arithmetic_encode`` into ``output_path``."""
    with open(input_path, 'rb') as source:
        blob = source.read()

    # 8-byte big-endian symbol count, then the encoded fraction.
    length = int.from_bytes(blob[:8], byteorder='big')
    number = decode_fraction(blob[8:])

    model = OptimizedFrequencyTable.build_simple(num_symbols=256)
    symbols = ArithmeticCoder().decode(number, length, model)

    pieces = []
    for symbol in tqdm(symbols, desc='Decoding'):
        pieces.append(symbol)
        # Update the model before the decoder is resumed for the next symbol.
        model.increment(symbol)
    decoded = ''.join(pieces)

    with open(output_path, 'w') as sink:
        sink.write(decoded)

In [None]:
# Round-trip check for the adaptive coder on the first 1000 bytes of `data`.
# NOTE(review): slicing bytes before decode() can cut a multi-byte character
# in half if the corpus is not pure ASCII - confirm for other inputs.
text = data[:1000].decode()
with open('a.txt', 'w') as file:
    file.write(text)

# a.txt -> b.txt (compressed, binary) -> c.txt (restored text)
adaptive_arithmetic_encode('a.txt', 'b.txt')
adaptive_arithmetic_decode('b.txt', 'c.txt')

with open('b.txt', 'rb') as file:
    encoded = file.read()
with open('c.txt') as file:
    decoded = file.read()

# Report the compressed size in bytes and whether the restored text matches.
print(f'Encoded size: {len(encoded)}')
print(f'Successfully: {decoded == text}')