In [16]:
import mpmath
import numpy as np
import random
from fractions import Fraction

# # 2. 构建字符区间映射函数
# 引入一个辅助函数 build_interval_map(p_dict)，用于根据概率字典生成每个字符对应的区间范围。
# 示例输出格式：{'a': (low_a, high_a), 'b': (low_b, high_b), ...}

def build_interval_map(p_dict):
    interval_map = {}
    low = mpmath.mpf("0.0")
    for symbol, prob in p_dict.items():
        high = low + prob
        interval_map[symbol] = (low, high)
        low = high
    return interval_map

def arithmetic_encode(input_str, p_dict, dps=10):
    low = mpmath.mpf("0.0")
    high = mpmath.mpf("1.0")

    interval_map = build_interval_map(p_dict)

    steps = []
    for symbol in input_str:
        current_range = high - low
        sym_low, sym_high = interval_map[symbol]

        new_low = low + current_range * sym_low
        new_high = low + current_range * sym_high

        steps.append({
            'symbol': symbol,
            'low': f"{float(low):.{dps}f}",
            'high': f"{float(high):.{dps}f}",
            'new_low': f"{float(new_low):.{dps}f}",
            'new_high': f"{float(new_high):.{dps}f}",
            'calculation': f"range: {float(sym_low):.{dps}f}-{float(sym_high):.{dps}f}, "
                           f"new_low = {float(low):.{dps}f} + ({float(high):.{dps}f} - {float(low):.{dps}f}) * {float(sym_low):.{dps}f} = {float(new_low):.{dps}f}, "
                           f"new_high = {float(low):.{dps}f} + ({float(high):.{dps}f} - {float(low):.{dps}f}) * {float(sym_high):.{dps}f} = {float(new_high):.{dps}f}"
        })

        low, high = new_low, new_high

    encoded_value = (low + high) / 2
    return encoded_value, steps

def arithmetic_decode(encoded_value, p_dict, length, dps=10):
    low = mpmath.mpf("0.0")
    high = mpmath.mpf("1.0")
    decoded_str = []
    steps = []

    interval_map = build_interval_map(p_dict)

    for _ in range(length):
        current_range = high - low
        value_norm = (encoded_value - low) / current_range  # 归一化到 [0,1]

        # 找出落在哪个区间
        for symbol, (sym_low, sym_high) in interval_map.items():
            if sym_low <= value_norm < sym_high:
                decoded_str.append(symbol)
                new_low = low + current_range * sym_low
                new_high = low + current_range * sym_high

                steps.append({
                    'symbol': symbol,
                    'low': f"{float(low):.{dps}f}",
                    'high': f"{float(high):.{dps}f}",
                    'value_norm': f"{float(value_norm):.{dps}f}",
                    'interval': f"[{float(sym_low):.{dps}f}, {float(sym_high):.{dps}f})",
                    'new_low': f"{float(new_low):.{dps}f}",
                    'new_high': f"{float(new_high):.{dps}f}",
                    'calculation': f"value_norm = {float(value_norm):.{dps}f} 属于 [{float(sym_low):.{dps}f}, {float(sym_high):.{dps}f}) → 符号 '{symbol}'"
                })

                low, high = new_low, new_high
                break

    return ''.join(decoded_str), steps



def print_encoding_steps(steps, dps=10):
    print("="*50)
    print("编码过程详解：")
    for i, step in enumerate(steps, 1):
        print(f"步骤 {i}: 处理符号 '{step['symbol']}'")
        print(f"  当前区间: [{step['low']}, {step['high']})")
        print(f"  计算: {step['calculation']}")
        print(f"  新区间: [{step['new_low']}, {step['new_high']})")
        print("-"*50)
    print("encode_value: ",encoded_value)

def print_decoding_steps(steps, dps=10):
    print("="*50)
    print("解码过程详解：")
    for i, step in enumerate(steps, 1):
        print(f"步骤 {i}: 解码符号 '{step['symbol']}'")
        print(f"  当前区间: [{step['low']}, {step['high']})")
        print(f"  分割点: {step['split']}")
        print(f"  判断: {step['calculation']}")
        print(f"  新区间: [{step['new_low']}, {step['new_high']})")
        print("-"*50)


# 浮点类型转分数 转二进制
def mpf_to_fraction_str(num):
    # Fraction
    frac = Fraction(float(num)) # 限制分母大小
    return f"{frac.numerator}/{frac.denominator}", frac

def fraction_to_mpf(frac):
    # 二进制
    return mpmath.mpf(f"{frac.numerator}") / mpmath.mpf(f"{frac.denominator}")


def compare_binary_strings(bin_str1, bin_str2):
    """
    比较两个二进制串，找出不同的 bit 位，并可视化显示差异。

    参数:
        bin_str1 (str): 第一个二进制串
        bin_str2 (str): 第二个二进制串

    返回:
        diff_positions (list): 不同 bit 位的位置列表
        visualization (str): 差异的可视化字符串
    """
    # 确保两个二进制串长度相同
    if len(bin_str1) != len(bin_str2):
        raise ValueError("两个二进制串的长度必须相同")

    # 找出不同的 bit 位
    diff_positions = [i for i in range(len(bin_str1)) if bin_str1[i] != bin_str2[i]]

    # 可视化显示差异
    visualization = []
    for i in range(len(bin_str1)):
        if i in diff_positions:
            visualization.append(f"\033[91m{bin_str1[i]}\033[0m")  # 红色显示不同
        else:
            visualization.append(f"\033[92m{bin_str1[i]}\033[0m")  # 正常显示相同
    visualization = "".join(visualization)

    return diff_positions, visualization

# 示例使用
if __name__ == "__main__":
    # 输入参数
    # 1011010111111111111111111111111111100000000000000010010110101010101010010101010101010010101010101
    # input_str = ''.join(random.choices('01', k=10000))
    # input_str = "10110101111111111111111111111111111000000000000000100101101010101010100101010101010100101010101011101110011100011011001"    # 要编码的二进制字符串
    # 自定义字符集和概率分布
    mpmath.mp.dps = 10

    p_dict = {'a': mpmath.mpf("0.50"), 'b': mpmath.mpf("0.10"), 'c': mpmath.mpf("0.1"), 'd': mpmath.mpf("0.05"), 'e': mpmath.mpf("0.25")}

    codec_words = ''.join(p_dict.keys())
    print("编码字符集:", codec_words)
    input_str = ''.join(random.choices(codec_words, k=10000)) #'"abacababcabccba"  # 可以是任意字符组合

    print("原始字符串:", input_str)
    print("概率表K: ", p_dict.keys())
    print("概率表V: ", p_dict.values())

    mpmath.mp.dps = 3000
    # 编码过程
    encoded_value, encode_steps = arithmetic_encode(input_str, p_dict)
    print("\n最终编码值:", encoded_value)

    # 转换为分数精度损失很大
    str, frac = mpf_to_fraction_str(encoded_value)
    print("编码值的分数表示:", str)
    # print_encoding_steps(encode_steps)

    # 解码过程
    print("\n开始解码...")
    # encoded_value = fraction_to_mpf(frac)
    print("编码值:", encoded_value)
    decoded_str, decode_steps = arithmetic_decode(encoded_value, p_dict, len(input_str))

    # print_decoding_steps(decode_steps)

    print("\n解码结果:", decoded_str)

    diff_position, visualization = compare_binary_strings(decoded_str, input_str)
    print("差异可视化:", visualization)

    print("解码结果与原始字符串是否相同:", len(diff_position) == 0)
    print("误差率: ", len(diff_position) / len(input_str))

    # print_decoding_steps(decode_steps)

编码字符集: abcde
原始字符串: cabdaceebbcddeccacdbcccbacaebcccedeacedccadecaeaccbcbdeaaeccdcadcbbaadbabaedbaebdbdbadbcddddebdeaacbddebabccceeeedeadddccdbcedeaacababcbebcceeedcaeecbbdabdcddaabddcceaadabdabbecdbccdeedbebbcabcecbcbabdabcdeecaecaeccdabdebcaeedbdeaaedadcedbadbbbbaaeecdbccdcdccdcccbbecebbcdaeedebdeeebcddabddcccdaaceaaaaedaeedaecabbeacabdcbaebbecebccbbebcdbabbbeeeedacbcaedadbeabbceecdadcecbbcceeecbbecdddacaeebceebaeaadbeadcaeeddaebebecccddbaadbdcbecdccebcacacebbceeedeedabbecdddcecdbebeaecbeaaecacaddaeaeebcccccbeeabdaacccbbabedccbbaaccdabcdddbcddbecdccbeccbdccedbecbdaabbcbcdbedabeaaaadbdccabeeeecbaccedebdabeeaeedecdcebcedcdbbabbebeabcdcdbbacccaaaebaeddbdbbaddceacdebdebedeebecdeddecbdbccaebacdccdbbecbdeadecebedaeddaececeabdebaceeadaeddcbccabecbbeeedcabebcbcccbdbdaaacbedccacbebaaceeccccdaebecbdeeadacbeeadbaabbadbdeddaeebcedbdaeddecabbcaaeaecddebbadbdecaecbeadbcaebcababcecdceecdeadadadcbcdbbbdeadddaeacddbbededaaeecaebddbdcbdceeddacaedcbbabdbbcdebdccdebdbbedbdebabcbabadbcbedcabcbcdddceaadea

ZeroDivisionError: 