In [1]:
import  numpy as np

class CABACEncoder:
    def __init__(self):
        # 初始化上下文模型 (JM 9.3.4.3)
        self.ctx_models = [{'mps':0, 'state':0} for _ in range(398)]
        self.range = 510
        self.low = 0
        self.bits_left = 23
        self.ff_buff = 0
        self.ff_count = 0

    def encode_block(self, coeffs, component='luma', blk_type='inter'):
        """ 编码4x4残差块 (JM 9.3.4) """
        bin_str = self._binarize_coeffs(coeffs, blk_type)
        bitstream = []

        # 初始化算术编码器
        self.range = 510
        self.low = 0
        self.bits_left = 23
        self.ff_buff = 0
        self.ff_count = 0

        # 编码每个二进制位
        for bit in bin_str:
            ctx_idx = self._get_context_idx(coeffs, len(bitstream))
            self._encode_bin(bit, ctx_idx)

        # 终止编码
        self._terminate()
        return self._get_bitstream()

    def _binarize_coeffs(self, coeffs, blk_type):
        """ 系数二进制化 (JM 9.3.4.2) """
        bins = []
        scan_order = self._get_scan_order(blk_type)
        scanned = coeffs.flatten()[scan_order]

        # 1. 编码significant_coeff_flag
        num_coeff = 0
        last_idx = -1
        for i, c in enumerate(scanned):
            if c != 0:
                bins.append(1)  # significant_coeff_flag
                num_coeff += 1
                last_idx = i
            else:
                bins.append(0)

        # 2. 编码last_significant_coeff_flag
        for i in range(num_coeff):
            bins.append(1 if i == num_coeff-1 else 0)

        # 3. 编码coeff_abs_level_minus1
        for c in scanned:
            if c != 0:
                abs_level = abs(c) - 1
                if abs_level > 14:  # 使用截断Rice码
                    prefix = abs_level // 15
                    suffix = abs_level % 15
                    bins += [1]*prefix + [0]
                    bins += list(map(int, f"{suffix:04b}"))
                else:
                    bins += list(map(int, f"{abs_level:04b}"))

                # 符号位
                bins.append(0 if c > 0 else 1)

        return bins

    def _get_scan_order(self, blk_type):
        """ 获取扫描顺序 (JM 9.3.4.1) """
        if blk_type == 'intra16x16':
            return [0,1,4,8,5,2,3,6,9,12,13,10,7,11,14,15]  # 之字形
        else:
            return [0,4,1,8,12,5,9,2,3,6,10,13,7,11,14,15]  # 场扫描

    def _get_context_idx(self, coeffs, pos):
        """ 计算上下文索引 (JM 9.3.4.3) """
        # 简化的上下文索引计算，实际需根据邻近块状态
        if pos < 16:  # significant_coeff_flag
            return min(5, pos//4 + (pos%4)//2)
        elif pos < 32: # last_significant_coeff_flag
            return 6 + min(2, (pos-16)//8)
        else:         # coeff_abs_level_minus1
            return 11 + min(4, (pos-32)//16)

    def _encode_bin(self, bin_val, ctx_idx):
        """ 算术编码核心 (JM 9.3.4.4) """
        model = self.ctx_models[ctx_idx]
        p_state = model['state']
        val_mps = model['mps']

        # 获取区间分割点 (JM 9.3.4.4表9-41)
        q_range = (self.range * CABACEncoder.lps_table[p_state][0]) >> 16
        self.range -= q_range

        if bin_val != val_mps:
            self.low += self.range
            self.range = q_range
            if p_state == 0:
                model['mps'] = 1 - val_mps
            model['state'] = CABACEncoder.transit_table[p_state][1]
        else:
            model['state'] = CABACEncoder.transit_table[p_state][0]

        # 重新归一化
        while self.range < 256:
            self.range <<= 1
            self.low <<= 1
            self.bits_left -= 1
            if self.bits_left < 0:
                self._write_buff()

    def _write_buff(self):
        """ 写输出缓冲 """
        carry = (self.low >> 23) & 1
        byte = (self.ff_buff + carry) & 0xFF

        if byte == 0xFF:
            self.ff_count += 1
        else:
            if carry:
                for _ in range(self.ff_count):
                    self.bitstream.append(0xFF)
                self.ff_buff = byte
            else:
                self.bitstream.append(byte)
                for _ in range(self.ff_count):
                    self.bitstream.append(0x00)
                self.ff_count = 0

            self.ff_buff = (self.low >> 15) & 0xFF
            self.bits_left += 8

    def _terminate(self):
        """ 终止算术编码 (JM 9.3.4.5) """
        self.range -= 2
        self.low += self.range
        self.range = 2

        for _ in range(2):
            self.range <<= 1
            self.low <<= 1
            self.bits_left -= 1
            if self.bits_left < 0:
                self._write_buff()

    # JM标准概率表 (JM 9.3.4.4表9-41)
    lps_table = [
        [ 6798, 64], [ 7476, 60], [ 8214, 56], [ 8910, 52],
        [ 9532, 48], [10082, 44], [10554, 40], [10974, 36],
        [11338, 32], [11650, 28], [11918, 24], [12150, 20],
        [12350, 16], [12522, 12], [12670,  8], [12798,  4]
    ]

    transit_table = [
        [0,1], [2,3], [4,5], [6,7],
        [8,9], [10,11], [12,13], [14,15],
        [16,17], [18,19], [20,21], [22,23],
        [24,25], [26,27], [28,29], [30,31]
    ]

# 测试用例
if __name__ == "__main__":
    # 示例量化系数块
    quant_coeffs = np.array([
        [5, 3, -1, 0],
        [2, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]
    ])

    encoder = CABACEncoder()
    bitstream = encoder.encode_block(quant_coeffs, 'luma', 'inter')

    print(f"CABAC编码结果: {bytes(bitstream).hex()}")
    print(f"编码字节数: {len(bitstream)}")


IndexError: list index out of range

In [None]:
def reset_context_models(self, slice_type):
    """ 根据slice类型重置模型 (JM 9.3.4.3) """
    init_values = {
        'I': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'P': [1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
        'B': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    }
    # 根据slice类型初始化各上下文模型

def _get_neighbor_context(self, block, pos):
    """ 计算邻近块上下文 (JM 9.3.4.3) """
    x, y = pos % 4, pos // 4
    left = block[y, x-1] if x > 0 else 0
    top = block[y-1, x] if y > 0 else 0
    return left + top

def _renorm(self):
    """ 优化后的重新归一化 """
    while self.range < 256:
        self.range <<= 1
        self.low <<= 1
        if (self.low & 0x800000) != 0:
            self.ff_buff += 1
        else:
            if self.ff_buff > 0:
                self.bitstream.append(self.ff_buff)
                self.ff_buff = 0
            self.bitstream.append((self.low >> 16) & 0xFF)
        self.low &= 0x7FFFFF

def adaptive_entropy_coding(self, coeffs, qp):
    """ 自适应选择熵编码方法 """
    if qp < 24:  # 低QP用CABAC
        return CABACEncoder().encode_block(coeffs)
    else:         # 高QP用CAVLC
        return CAVLEncoder().encode_block(coeffs)
