In [3]:
import json
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import random
import time
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Local checkpoint directory and a codebook JSON filename for each supported model.
# The codebook filename is not used in this cell.
# NOTE(review): absolute, machine-specific paths; consider deriving them from a configurable base dir.
model_path = {
    "llama2": ["/data1/预训练模型/llama2-7b-chat","codebook.json"],
    "mistral": ["/data1/预训练模型/Mistral-7B-Instruct-v0.1","codebook.json"],
    "qwen2": ["/data1/预训练模型/Qwen2-7B-Instruct","qwen-codebook.json"],
    "llama3": ["/data1/预训练模型/llama-3-8b-Instruct","llama3-codebook.json"],
    "gemma": ["/data1/预训练模型/gemma-7b-it","gemma-codebook.json"],
    "mpt": ["/data1/预训练模型/mpt-7b-8k-chat","mpt-codebook.json"],
    "glm4": ["/data1/预训练模型/glm-4-9b-chat","glm4-codebook.json"],
}

# Select the model once so the model and its tokenizer always come from the same checkpoint.
MODEL_NAME = "qwen2"
model_dir = model_path[MODEL_NAME][0]

model = AutoModelForCausalLM.from_pretrained(model_dir).cuda()
model = model.eval()  # inference mode (disables dropout)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.16it/s]


In [6]:
def subtract_from_single_interval(include_interval, exclude_intervals):
    """
    Subtract several exclude intervals from a single include interval.

    The include interval is trimmed at its ends rather than split.

    Parameters:
    include_interval (list): the include interval [start, end].
    exclude_intervals (list of list): exclude intervals, each a pair [start, end].
        The caller's list is not modified or reordered.

    Returns:
    tuple: (final include interval, exclude intervals inside it).
        The final include interval is a list holding zero or one [start, end].
        The second element lists the exclude intervals that lay strictly inside
        the include interval when they were examined. An interval recorded this
        way can extend past the final end if a later exclusion trims the end.
    """
    if not exclude_intervals:
        return [include_interval], []

    current_start, current_end = include_interval
    remaining_exclude_intervals = []

    # Walk the exclusions left to right. sorted() (stable, like list.sort)
    # leaves the caller's list untouched.
    for exclude in sorted(exclude_intervals, key=lambda x: x[0]):
        ex_start, ex_end = exclude

        # Entirely outside the include interval: nothing to do.
        if ex_end < current_start or ex_start > current_end:
            continue

        # Covers the whole include interval: nothing remains.
        if ex_start <= current_start and ex_end >= current_end:
            return [], [exclude]

        if ex_start <= current_start and ex_end < current_end:
            # Overlaps the start: move the start right.
            current_start = ex_end
        elif ex_start > current_start and ex_end >= current_end:
            # Overlaps the end: move the end left.
            current_end = ex_start
        elif ex_start > current_start and ex_end < current_end:
            # Strictly inside: keep it, since trimming would split the interval.
            remaining_exclude_intervals.append(exclude)

    # The trimming may have consumed the whole include interval.
    if current_start >= current_end:
        return [], remaining_exclude_intervals

    return [[current_start, current_end]], remaining_exclude_intervals

# Example usage. The two exclusions overlap each other, but both lie strictly
# inside the include interval, so the include interval comes back unchanged
# and both exclusions are reported as interior.
include_interval = [3902237058760003, 4422928191421763]
exclude_intervals = [
    [4379939609042030, 4397078807415260],
    [4049777625805706, 4385689766093131],
]

(
    final_include_interval,
    remaining_exclude_intervals,
) = subtract_from_single_interval(include_interval, exclude_intervals)

print("最终包含区间:", final_include_interval)
print("在包含区间内的排除区间:", remaining_exclude_intervals)
def merge_intervals(intervals):
    """
    Merge overlapping (or touching) intervals.

    Parameters:
    intervals (list of list): intervals, each a pair [start, end].
        Neither the list nor its inner intervals are modified.

    Returns:
    list: new [start, end] lists, sorted by start, with no two overlapping.
    """
    if not intervals:
        return []

    # Work on a sorted copy so the caller's list keeps its order.
    ordered = sorted(intervals, key=lambda x: x[0])

    # Copy each kept interval: the merge below extends merged[-1] in place,
    # and that must not write into the caller's interval objects.
    merged = [list(ordered[0])]
    for current in ordered[1:]:
        last_merged = merged[-1]
        if current[0] <= last_merged[1]:
            # Overlaps or touches the previous interval: extend its end.
            last_merged[1] = max(last_merged[1], current[1])
        else:
            merged.append(list(current))

    return merged

最终包含区间: [[3902237058760003, 4422928191421763]]
在包含区间内的排除区间: [[4049777625805706, 4385689766093131], [4379939609042030, 4397078807415260]]


In [14]:
def bin2dec(binary_string):
    """Read ``binary_string`` as the fractional digits of ``0.b1 b2 b3 ...`` in base 2.

    Each element is passed through ``int``, so both ``"101"`` and ``[1, 0, 1]``
    are accepted. An empty input yields ``0``.
    """
    return sum(int(digit) / 2 ** (pos + 1) for pos, digit in enumerate(binary_string))

def dec2bin(decimal_number, width=None):
    """Convert a fraction to a fixed-width binary digit string.

    Parameters:
    decimal_number (float): value to convert; presumably in [0, 1).
    width (int | None): number of fractional bits. Defaults to the global ``beta``.

    Returns:
    str: ``int(decimal_number * 2**width)`` in binary, left-padded with zeros to
        ``width`` digits. Values >= 1 yield more than ``width`` digits.
    """
    if width is None:
        width = beta
    scaled = int(decimal_number * 2 ** width)
    return bin(scaled)[2:].zfill(width)

def msb_bits2int(bits):
    """Interpret ``bits`` (most significant first) as a non-negative integer; ``[]`` -> 0."""
    value = 0
    for bit in bits:
        # Shift the accumulated value left by one bit and add the next bit.
        value = value * 2 + bit
    return value

def num_same_from_beg(bits1, bits2):
    """Return the common leading run of two equal-length bit sequences.

    Parameters:
    bits1, bits2 (list): sequences of equal length (asserted).

    Returns:
    list: ``bits1[:k]``, where ``k`` is the length of the longest common prefix.
        Despite the name, this is the prefix itself, not its length. Identical
        inputs return the whole sequence; empty inputs return ``[]``.
    """
    assert len(bits1) == len(bits2)
    prefix_len = 0
    for a, b in zip(bits1, bits2):
        if a != b:
            break
        prefix_len += 1
    return bits1[:prefix_len]

def msb_int2bits(inp, num_bits):
    """Return ``inp`` as a list of 0/1 ints, most significant bit first.

    The digits are left-padded with zeros to ``num_bits``; an ``inp`` that needs
    more digits produces a longer list. ``num_bits == 0`` gives ``[]``.
    """
    if num_bits == 0:
        return []
    template = '{0:0%db}' % num_bits
    return list(map(int, template.format(inp)))

def extract(I):
    """Split the ends of ``I = [inf, sup]`` into shared leading bits and remaining tails.

    Both ends go through ``dec2bin`` (presumably fractions in [0, 1)).

    Returns:
    tuple: ``(prefix, [lo, hi])``. ``prefix`` is the run of leading digits common
        to both strings (the whole string if they are equal). ``lo``/``hi`` are
        the remaining tails of each end read back as fractions via ``bin2dec``.
    """
    lo_bits, hi_bits = dec2bin(I[0]), dec2bin(I[1])
    if lo_bits == hi_bits:
        cut = len(lo_bits)
    else:
        # First differing position among the first beta digits (0 if none).
        cut = next((pos for pos in range(beta) if lo_bits[pos] != hi_bits[pos]), 0)
    return lo_bits[:cut], [bin2dec(lo_bits[cut:]), bin2dec(hi_bits[cut:])]

def Mextract(I):
    """Variant of ``extract`` that skips ``alpha`` prefix digits and re-prefixes the tails.

    Returns:
    tuple: ``(prefix, [lo, hi])``.
        ``prefix`` is the common leading digits of ``dec2bin(I[0])`` and
        ``dec2bin(I[1])``, with the first ``alpha`` characters removed.
        ``lo``/``hi`` are ``bin2dec`` applied to the string ``"1" + "0" * (alpha - 1)``
        followed by each end's remaining tail.
    """
    lo_bits, hi_bits = dec2bin(I[0]), dec2bin(I[1])
    if lo_bits == hi_bits:
        cut = len(lo_bits)
    else:
        # First differing position among the first beta digits (0 if none).
        cut = next((pos for pos in range(beta) if lo_bits[pos] != hi_bits[pos]), 0)
    marker = "1" + "0" * (alpha - 1)
    return lo_bits[alpha:cut], [bin2dec(marker + lo_bits[cut:]), bin2dec(marker + hi_bits[cut:])]

def merge(I1,I2):
    """Map ``I2`` (fractions of ``I1``'s width) into ``I1``'s absolute coordinates."""
    start = I1[0]
    width = I1[1] - start
    lo_frac, hi_frac = I2[0], I2[1]
    return [start + lo_frac * width, start + hi_frac * width]
# beta: bit precision. dec2bin scales by 2**beta, and Shimmer's coding interval
# starts as [0, 2**beta). NOTE(review): 52 presumably matches float64's 52-bit
# mantissa, since dec2bin multiplies floats by 2**beta.
# alpha: Mextract drops the first alpha prefix digits and prepends the
# alpha-character marker "1" + "0" * (alpha - 1) to each tail.
beta, alpha = 52, 2
class Shimmer:
    """Embed secret bits in LLM output by interval-coding each next-token choice.

    Both encoders keep an integer interval ``I`` that starts as ``[0, 2**beta)``.
    Each step:
    - partitions ``I`` by the model's rounded next-token probabilities;
    - picks the token whose bucket contains the next ``beta`` secret bits plus
      a seeded pseudo-random offset ``r``, wrapped inside ``I``;
    - shifts the chosen bucket back by ``r``;
    - counts the leading bits shared by its lowest and highest values as
      encoded, shifts them out and rescales ``I``.

    The methods take no ``self``. They are static so they work both as
    ``Shimmer.Encode(...)`` and on an instance.
    """

    @staticmethod
    def _token_cdf(model, x, past_key_values, interval, temperature):
        """Run one forward step and integer-quantize the next-token distribution.

        Returns:
        tuple: ``(probs, cum_probs, past_key_values)``.
            ``probs`` holds each token's probability scaled by the interval
            width and rounded to int64. ``cum_probs`` is their running sum,
            shifted so that its last entry equals ``interval[1]``.
        """
        cur_int_range = interval[1] - interval[0]
        output = model(x, past_key_values=past_key_values)
        logits = output.logits[0, -1, :]  # last position of the first sequence in the batch
        probs = torch.softmax(logits / temperature, dim=-1).double()
        probs *= cur_int_range
        probs = probs.round().long()
        cum_probs = probs.cumsum(0)
        # Rounding leaves the total slightly off. Shifting every entry by the
        # residual makes the CDF end exactly at the width; the first bucket absorbs it.
        cum_probs += cur_int_range - cum_probs[-1]
        cum_probs += interval[0]  # the CDF now starts at the interval's lower end
        return probs, cum_probs, output.past_key_values

    @staticmethod
    def _offset_and_message_index(secret_bits, bit_index, seed, token_cnt, interval):
        """Return ``(r, message_idx)`` for step ``token_cnt``.

        ``r`` is a pseudo-random offset in ``[0, width)`` that depends only on
        ``seed`` and ``token_cnt``. ``message_idx`` is the next ``beta`` secret
        bits read MSB-first, plus ``r``, wrapped back by one width once it
        reaches ``interval[1]``.
        """
        cur_int_range = interval[1] - interval[0]
        # A private Random seeded the same way yields the same value as
        # random.seed(...); random.random(), without clobbering the global RNG.
        r = int(random.Random(str(seed) + str(token_cnt)).random() * cur_int_range)
        message_bits = [int(b) for b in secret_bits[bit_index: bit_index + beta]]
        # Right-pad a short tail so the remaining bits stay the most significant
        # ones, instead of being read as the low-order bits of a beta-bit number.
        message_bits += [0] * (beta - len(message_bits))
        message_idx = msb_bits2int(message_bits) + r
        if message_idx >= interval[1]:
            message_idx -= cur_int_range
        return r, message_idx

    @staticmethod
    def _straddles_start(cum_probs, selection, r, interval):
        """True when the selected bucket, shifted back by ``r``, contains ``interval[0]``."""
        bottom = cum_probs[selection - 1] if selection > 0 else interval[0]
        return bool(bottom - r <= interval[0] and cum_probs[selection] - r >= interval[0])

    @staticmethod
    def _narrow_interval(cum_probs, selection, r, interval):
        """Shift the selected bucket back by ``r``, emit its shared leading bits and rescale.

        Returns:
        tuple: ``(encoded_bits, new_interval)``. ``encoded_bits`` are the
            leading bits (of ``beta``-bit, MSB-first encodings) shared by the
            bucket's lowest and highest values. They are shifted out; the freed
            low bits are filled with 0s at the lower end and 1s at the upper end.
        """
        cur_int_range = interval[1] - interval[0]
        # With selection == 0 the bucket starts at the interval's lower end.
        new_int_bottom = (cum_probs[selection - 1] if selection > 0 else interval[0]) - r
        new_int_top = cum_probs[selection] - r
        # Shifting by r can push the bucket below the interval; wrap it by one width.
        if new_int_top <= interval[0]:
            new_int_bottom += cur_int_range
            new_int_top += cur_int_range
        bottom_bits = list(msb_int2bits(new_int_bottom, beta))
        top_bits = list(msb_int2bits(new_int_top - 1, beta))  # inclusive upper end
        encoded_bits = num_same_from_beg(bottom_bits, top_bits)
        n_shift = len(encoded_bits)
        new_interval = [
            msb_bits2int(bottom_bits[n_shift:] + [0] * n_shift),
            msb_bits2int(top_bits[n_shift:] + [1] * n_shift) + 1,
        ]
        return encoded_bits, new_interval

    @staticmethod
    def Encode(model, tokenizer, prompt, secret_bits, seed="42", max_new_tokens=10, temperature=1.0):
        """Generate ``max_new_tokens`` tokens that carry ``secret_bits``.

        When the chosen bucket wraps around ``I[0]``, the step encodes nothing
        and ``I`` is left unchanged.

        Parameters:
        model: causal LM called as ``model(x, past_key_values=...)``. Its output
            must expose ``.logits`` and ``.past_key_values``.
        tokenizer: provides ``apply_chat_template`` and ``decode``.
        prompt (str): sent as a single user message.
        secret_bits: sequence of 0/1 values (ints or "0"/"1" characters).
        seed: per-step seed prefix; the step counter is appended to ``str(seed)``.
        max_new_tokens (int): number of tokens to generate.
        temperature (float): softmax temperature.

        Returns:
        tuple: ``(text, bit_index, I, token_prob)``.
            ``text`` is the decoded output including the chat-templated prompt.
            ``bit_index`` is the bit position reached; it may exceed
            ``len(secret_bits)`` once the message is exhausted. ``I`` is the
            final interval. ``token_prob`` is the rounded probability of each
            chosen token.
        """
        messages = [{"role": "user", "content": prompt},]
        bit_index = 0
        prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
        with torch.no_grad():
            past_key_values = None
            x = prompt_ids
            output_ids = prompt_ids
            I = [0, 2 ** beta]
            token_prob = []
            for token_cnt in range(max_new_tokens):
                cur_int_range = I[1] - I[0]
                probs, cum_probs, past_key_values = Shimmer._token_cdf(model, x, past_key_values, I, temperature)
                r, message_idx = Shimmer._offset_and_message_index(secret_bits, bit_index, seed, token_cnt, I)
                # The first token whose CDF entry exceeds message_idx owns that point.
                selection = (cum_probs > message_idx).nonzero()[0].item()
                x = torch.LongTensor([selection]).view(1, 1).to(model.device)
                if not Shimmer._straddles_start(cum_probs, selection, r, I):
                    encoded_bits, I = Shimmer._narrow_interval(cum_probs, selection, r, I)
                    bit_index += len(encoded_bits)
                token_prob.append(probs[selection].item() / cur_int_range)
                output_ids = torch.cat((output_ids, x), dim=1)
        return tokenizer.decode(output_ids[0]), bit_index, I, token_prob

    @staticmethod
    def MEncode(model, tokenizer, prompt, secret_bits, seed="42", max_new_tokens=10, temperature=1.0):
        """Like ``Encode``, but a wrapping bucket trims ``I`` instead of being ignored.

        When the chosen bucket wraps around ``I[0]``, the gap outside it is
        recorded as excluded, and ``I`` is trimmed with
        ``subtract_from_single_interval``. Exclusions are kept until the next
        step that shifts bits out, because that rescales ``I``.

        Parameters and return value are the same as for ``Encode``.
        """
        messages = [{"role": "user", "content": prompt},]
        bit_index = 0
        prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
        with torch.no_grad():
            past_key_values = None
            x = prompt_ids
            output_ids = prompt_ids
            I = [0, 2 ** beta]
            not_in = []  # excluded sub-ranges, in the current interval's coordinates
            token_prob = []
            for token_cnt in range(max_new_tokens):
                cur_int_range = I[1] - I[0]
                probs, cum_probs, past_key_values = Shimmer._token_cdf(model, x, past_key_values, I, temperature)
                r, message_idx = Shimmer._offset_and_message_index(secret_bits, bit_index, seed, token_cnt, I)
                selection = (cum_probs > message_idx).nonzero()[0].item()
                x = torch.LongTensor([selection]).view(1, 1).to(model.device)
                if Shimmer._straddles_start(cum_probs, selection, r, I):
                    # The bucket wraps around I; exclude the complementary gap
                    # [top - r, bottom - r + width] and trim I at its ends.
                    lower = cum_probs[selection - 1].item() if selection > 0 else I[0]
                    not_in.append([cum_probs[selection].item() - r, lower - r + cur_int_range])
                    final_include_interval, _ = subtract_from_single_interval(I, not_in)
                    # NOTE(review): raises IndexError if the exclusions cover all of I.
                    I = final_include_interval[0]
                else:
                    encoded_bits, I = Shimmer._narrow_interval(cum_probs, selection, r, I)
                    if encoded_bits:
                        bit_index += len(encoded_bits)
                        # Bits were shifted out, so I was rescaled; old exclusions no longer apply.
                        not_in = []
                token_prob.append(probs[selection].item() / cur_int_range)
                output_ids = torch.cat((output_ids, x), dim=1)
        return tokenizer.decode(output_ids[0]), bit_index, I, token_prob