In [39]:
import mpmath
import random
from fractions import Fraction

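# 2. Build the symbol-to-interval mapping function
# The helper build_interval_map(p_dict) derives, from a probability dict, the sub-interval assigned to each symbol.
# Example output format: {'a': (low_a, high_a), 'b': (low_b, high_b), ...}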

def build_interval_map(p_dict):
    """Map each symbol to its cumulative-probability interval.

    Symbols are visited in the iteration order of ``p_dict``; each one gets
    the half-open slice ``(start, start + prob)`` where ``start`` is the sum
    of the probabilities of all symbols visited before it.

    Args:
        p_dict: mapping of symbol -> probability.

    Returns:
        dict of symbol -> ``(low, high)`` tuple, e.g.
        ``{'a': (low_a, high_a), 'b': (low_b, high_b), ...}``.
    """
    cursor = mpmath.mpf("0.0")
    bounds = {}
    for sym in p_dict:
        upper = cursor + p_dict[sym]
        bounds[sym] = (cursor, upper)
        # The next symbol's interval starts where this one ends.
        cursor = upper
    return bounds

def arithmetic_encode(input_str, p_dict, dps=10):
    """Arithmetic-encode ``input_str`` and record every narrowing step.

    Starting from the interval [0, 1), each symbol shrinks the current
    interval to the sub-range given by ``build_interval_map(p_dict)``.

    Args:
        input_str: iterable of symbols; every symbol must be a key of ``p_dict``
            (an unknown symbol raises ``KeyError``).
        p_dict: mapping of symbol -> probability.
        dps: number of decimal places used when formatting the values stored
            in the step log (display only; it does not affect the arithmetic).

    Returns:
        ``(encoded_value, steps)`` where ``encoded_value`` is the midpoint of
        the final interval and ``steps`` is a list of dicts of formatted
        strings, one per symbol.
    """
    def shown(value):
        # Step logs hold fixed-point strings of the float approximation.
        return f"{float(value):.{dps}f}"

    table = build_interval_map(p_dict)
    lower = mpmath.mpf("0.0")
    upper = mpmath.mpf("1.0")
    steps = []

    for symbol in input_str:
        width = upper - lower
        sym_low, sym_high = table[symbol]
        next_lower = lower + width * sym_low
        next_upper = lower + width * sym_high

        lo_s = shown(lower)
        hi_s = shown(upper)
        sl_s = shown(sym_low)
        sh_s = shown(sym_high)
        nl_s = shown(next_lower)
        nh_s = shown(next_upper)

        description = (
            f"range: {sl_s}-{sh_s}, "
            f"new_low = {lo_s} + ({hi_s} - {lo_s}) * {sl_s} = {nl_s}, "
            f"new_high = {lo_s} + ({hi_s} - {lo_s}) * {sh_s} = {nh_s}"
        )
        steps.append({
            'symbol': symbol,
            'low': lo_s,
            'high': hi_s,
            'new_low': nl_s,
            'new_high': nh_s,
            'calculation': description,
        })

        lower, upper = next_lower, next_upper

    # Any value inside the final interval identifies the message; use the midpoint.
    return (lower + upper) / 2, steps

def arithmetic_decode(encoded_value, p_dict, length, dps=10):
    """Decode ``length`` symbols from an arithmetic-coded value.

    Mirrors ``arithmetic_encode``: at each step the encoded value is
    normalised into the current interval, the symbol whose sub-interval
    contains it is emitted, and the interval is narrowed to that sub-interval.

    Args:
        encoded_value: value produced by ``arithmetic_encode``.
        p_dict: mapping of symbol -> probability (same table used to encode).
        length: number of symbols to decode.
        dps: decimal places used when formatting values in the step log
            (display only).

    Returns:
        ``(decoded_string, steps)``. ``steps`` holds one dict of formatted
        strings per symbol that was decoded by interval lookup (padding
        symbols emitted after the interval collapsed are not logged).
    """
    low = mpmath.mpf("0.0")
    high = mpmath.mpf("1.0")
    decoded_str = []
    steps = []

    intervals = list(build_interval_map(p_dict).items())
    if not intervals:
        # No symbols to choose from: nothing can be decoded.
        return '', steps

    first_symbol = intervals[0][0]

    for _ in range(length):
        current_range = high - low
        if current_range == 0:
            # Working precision is exhausted and the interval has collapsed;
            # pad with the first symbol of p_dict so the output keeps its length.
            decoded_str.append(first_symbol)
            continue

        # Normalise the code value into [0, 1) relative to the current interval.
        value_norm = (encoded_value - low) / current_range

        chosen = None
        for entry in intervals:
            sym_low, sym_high = entry[1]
            if sym_low <= value_norm < sym_high:
                chosen = entry
                break

        if chosen is None:
            # Rounding can push value_norm just outside every interval (below
            # the first lower bound or at/above the last upper bound).
            # Previously the symbol was silently dropped, and because the
            # interval then never changed, all remaining symbols were dropped
            # too. Clamp to the nearest edge symbol instead so the result
            # always has `length` symbols.
            if value_norm < intervals[0][1][0]:
                chosen = intervals[0]
            else:
                chosen = intervals[-1]

        symbol, (sym_low, sym_high) = chosen
        decoded_str.append(symbol)
        new_low = low + current_range * sym_low
        new_high = low + current_range * sym_high

        steps.append({
            'symbol': symbol,
            'low': f"{float(low):.{dps}f}",
            'high': f"{float(high):.{dps}f}",
            'value_norm': f"{float(value_norm):.{dps}f}",
            'interval': f"[{float(sym_low):.{dps}f}, {float(sym_high):.{dps}f})",
            'new_low': f"{float(new_low):.{dps}f}",
            'new_high': f"{float(new_high):.{dps}f}",
            'calculation': f"value_norm = {float(value_norm):.{dps}f} 属于 [{float(sym_low):.{dps}f}, {float(sym_high):.{dps}f}) → 符号 '{symbol}'"
        })

        low, high = new_low, new_high

    return ''.join(decoded_str), steps



def print_encoding_steps(steps, dps=10, encoded_value=None):
    """Print the step log produced by ``arithmetic_encode``.

    Args:
        steps: list of step dicts with the keys 'symbol', 'low', 'high',
            'calculation', 'new_low' and 'new_high'.
        dps: unused; kept so existing callers that pass it keep working.
        encoded_value: final encoded value to print after the steps. When
            omitted, a module-level ``encoded_value`` is used if one exists
            (the behaviour the function had when run from the script);
            otherwise the trailing line is skipped instead of raising
            ``NameError``.
    """
    print("="*50)
    print("编码过程详解：")
    for i, step in enumerate(steps, 1):
        print(f"步骤 {i}: 处理符号 '{step['symbol']}'")
        print(f"  当前区间: [{step['low']}, {step['high']})")
        print(f"  计算: {step['calculation']}")
        print(f"  新区间: [{step['new_low']}, {step['new_high']})")
        print("-"*50)

    if encoded_value is None:
        # Backward compatibility: the original read the script's global directly.
        encoded_value = globals().get("encoded_value")
    if encoded_value is not None:
        print("encode_value: ", encoded_value)

def print_decoding_steps(steps, dps=10):
    """Print the step log produced by ``arithmetic_decode``.

    Args:
        steps: list of step dicts with the keys 'symbol', 'low', 'high',
            'calculation', 'new_low' and 'new_high', plus 'value_norm' and
            'interval' as emitted by ``arithmetic_decode``. A legacy 'split'
            key is printed when present.
        dps: unused; kept so existing callers that pass it keep working.
    """
    print("="*50)
    print("解码过程详解：")
    for i, step in enumerate(steps, 1):
        print(f"步骤 {i}: 解码符号 '{step['symbol']}'")
        print(f"  当前区间: [{step['low']}, {step['high']})")
        # arithmetic_decode records 'value_norm' and 'interval', not 'split';
        # indexing step['split'] unconditionally raised KeyError on every step.
        if 'split' in step:
            print(f"  分割点: {step['split']}")
        if 'value_norm' in step:
            print(f"  归一化值: {step['value_norm']}")
        if 'interval' in step:
            print(f"  符号区间: {step['interval']}")
        print(f"  判断: {step['calculation']}")
        print(f"  新区间: [{step['new_low']}, {step['new_high']})")
        print("-"*50)


# Conversions between mpf values and fractions
def mpf_to_fraction_str(num):
    """Convert a number to an exact ``Fraction`` and its "p/q" string.

    The previous implementation always went through ``float(num)``, which
    threw away everything beyond 53 bits of an mpmath value. For objects
    exposing mpmath's raw ``_mpf_`` tuple ``(sign, man, exp, bc)`` the value
    ``(-1)**sign * man * 2**exp`` is now converted exactly. Anything else
    (plain floats, ints, non-finite mpf values) still uses the float path,
    so existing results and exceptions for those inputs are unchanged.

    Note: no denominator limit is applied; for high-precision inputs the
    numerator and denominator can be very large.

    Args:
        num: an mpmath ``mpf`` or any value accepted by ``float()``.

    Returns:
        ``(text, frac)`` where ``text`` is ``"numerator/denominator"`` and
        ``frac`` is the ``fractions.Fraction``.
    """
    raw = getattr(num, "_mpf_", None)
    frac = None
    if raw is not None:
        sign, man, exp = raw[0], int(raw[1]), int(raw[2])
        # mpmath encodes inf/nan with a zero mantissa and a non-zero exponent;
        # leave those to the float path so they raise as before.
        if man != 0 or exp == 0:
            if sign:
                man = -man
            if exp >= 0:
                frac = Fraction(man * (1 << exp))
            else:
                frac = Fraction(man, 1 << -exp)
    if frac is None:
        frac = Fraction(float(num))
    return f"{frac.numerator}/{frac.denominator}", frac

def fraction_to_mpf(frac):
    """Turn a fraction back into an mpmath value.

    Args:
        frac: object exposing integer ``numerator`` and ``denominator``
            attributes (e.g. ``fractions.Fraction``).

    Returns:
        The quotient numerator / denominator as computed by mpmath.
    """
    # Both parts are passed as decimal strings, exactly as the digits read.
    top = mpmath.mpf(str(frac.numerator))
    bottom = mpmath.mpf(str(frac.denominator))
    return top / bottom


def compare_binary_strings(enc_str, dec_str):
    """Compare two strings position by position and colourise the differences.

    Despite the name, any characters are accepted, not only '0'/'1'.

    Args:
        enc_str: reference string (the original input).
        dec_str: string to compare against it (the decoded output).

    Returns:
        ``(diff_positions, visualization)``:
        - ``diff_positions``: indices where the strings differ. Positions
          present in only one of the two strings count as differences.
        - ``visualization``: ``enc_str`` with each character wrapped in ANSI
          colour codes, green where it matches ``dec_str`` and red where it
          differs or ``dec_str`` has no character at that position.
    """
    # A length mismatch is reported but not fatal; the comparison still runs.
    if len(enc_str) != len(dec_str):
        print("两个二进制串的长度必须相同")

    diff_positions = []
    pieces = []
    dec_len = len(dec_str)
    for i, ch in enumerate(enc_str):
        if i < dec_len and ch == dec_str[i]:
            pieces.append(f"\033[92m{ch}\033[0m")  # green: match
        else:
            # Mismatch, or dec_str is too short. Missing positions used to be
            # counted as differences yet painted green.
            diff_positions.append(i)
            pieces.append(f"\033[91m{ch}\033[0m")  # red: difference

    # Extra characters in dec_str were previously ignored, which could report
    # two different strings as identical.
    diff_positions.extend(range(len(enc_str), dec_len))

    return diff_positions, "".join(pieces)

def generate_ascii_code():
    """Create a random probability table over the printable ASCII characters.

    The alphabet is the characters with code points 32 through 126. Each one
    gets a random weight from ``random.random()``; weights are divided by
    their sum so they add up to 1 (as floats) before being converted to mpmath
    values via their string representation.

    Returns:
        dict of character -> probability (mpmath value), in code-point order.
    """
    alphabet = [chr(code) for code in range(32, 127)]

    # One random draw per character, then normalise by the total.
    weights = [random.random() for _ in alphabet]
    weight_sum = sum(weights)

    table = {}
    for ch, weight in zip(alphabet, weights):
        share = weight / weight_sum
        table[ch] = mpmath.mpf(f"{share}")
    return table

# 示例使用
# Example usage: build a random probability table, encode a random string,
# decode it again and report how many positions differ.
if __name__ == "__main__":
    # Input parameters
    # 1011010111111111111111111111111111100000000000000010010110101010101010010101010101010010101010101
    # input_str = ''.join(random.choices('01', k=10000))
    # input_str = "10110101111111111111111111111111111000000000000000100101101010101010100101010101010100101010101011101110011100011011001"    # binary string to encode
    # Custom alphabet and probability distribution.
    # The probability table is generated while the working precision is 10 digits.
    mpmath.mp.dps = 10

    p_dict = generate_ascii_code()

    codec_words = ''.join(p_dict.keys())
    print("\n编码字符集:", codec_words)
    input_str = ''.join(random.choices(codec_words, k=1000))  # e.g. "abacababcabccba"; any combination of the alphabet's characters works

    print("原始字符串:", input_str)
    print("\n概率表K: ", p_dict.keys())
    print("\n概率表V: ", p_dict.values())

    mpmath.mp.dps = 3000  # author's note: with too little precision a division-by-zero error occurs
    # Encoding
    encoded_value, encode_steps = arithmetic_encode(input_str, p_dict)
    print("\n最终编码值:", encoded_value)

    # Author's note: converting to a fraction was observed to lose a lot of precision.
    # The fraction is only printed here; decoding below uses encoded_value directly.
    frac_str, frac = mpf_to_fraction_str(encoded_value)
    print("\n编码值的分数表示:", frac_str)
    # print_encoding_steps(encode_steps)

    # Decoding
    print("\n开始解码...")
    # encoded_value = fraction_to_mpf(frac)
    print("编码值:", encoded_value)
    decoded_str, decode_steps = arithmetic_decode(encoded_value, p_dict, len(input_str))

    # print_decoding_steps(decode_steps)

    print("\n解码结果:", decoded_str)

    # Compare the decoded text against the original, position by position.
    diff_position, visualization = compare_binary_strings(input_str, decoded_str)
    print("差异可视化(enc_str):", visualization)

    # Round trip succeeded when no differing positions were found.
    print("解码结果与原始字符串是否相同:", len(diff_position) == 0)
    # Error rate = differing positions / input length.
    print("误差率: ", len(diff_position) / len(input_str))

    # print_decoding_steps(decode_steps)


编码字符集:  !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~
原始字符串: xOS[$+gdpTs[N} 4)t=lcH_&V0OBD@;l.<;eAkIdE3_^q%34%]97+FR;)I|mLIl_vpj*SUkpHmNyZWM,~V<~x!gqY?$cn\UqZ>Mi.eg&an]3m>^"-O5gb 7o6mc}(ns]v/NWjV{Bu-LRfwkp&|j_mjj~3g1:?Y$(E3$[s|D&0m*QyatNb/>5wI?\\4kr>INL.RO[jC7w.NgxemJZ5fm%;phx,$FLjaRP%J*l+=-'0&m|!wIiF|1H/i=K}nq YLyA-n?S8(N(fZ}ezt:5(!l6g:L22w*VOBK62Z/sIq@${3a,7hJY.<Ir`.6+%.<ec!YKKWAWI5Uk%nMLB6u0)'Zi=ZWyG4i2LWQ~Y92HNw*LAvfA,4I0WMe:mQpJou4BGTWcbcm)!o$Bo`QHTk[$LpjvG>55&hl"z(%3zv~/[~W!//=LSm-nP+ue)RZIH5E}!].[8cqXIFXlt>{fbsZVxWe'{ YTTjt*]){uV{?Yg`P)uuh>s]9<"F.1eUn> h!>.*dG&6~TAn8wb1nBOBfCG,MNty)FMZo+$t>tm,@B4VvMD+x6V? H&DsBHBKrOMa\;!]BV#`,eYqfne7P:Li]l`F.#9Byx%42dIo}#B-&XpyC(iw${Bk}/~2cd}~gk\0dj2Ll,-C.e67hz2g~1G. vul^T$BJMTTpl8wze o`8-@>Z+yT3#*2hXxvp]Uf"kuvEEjOi"!T\s+csI89|_]X;" r}tJXTf~D5&(\rbj<+|%oq"e7eK}x>=ckr~-NU(wu}(Ow6$@R}Hs_anQ~v4unL[^;s@*ap`}|b?;:!(a6P/A^>Gl.}3(`2Gy-'|gi'z;<z#c0BlLbC'"k?z7qkdSboH.3L-BlM[v|lAJx:o/Ui}nUY*k};,(Q~W<9)uY}Y