<h1>Text Generation under a Steganographic Constraint</h1>

In this notebook, we study arithmetic coding as a method for steganographic embedding.
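
To make the idea concrete before the GPT-2 code, here is a minimal toy sketch, not the notebook's implementation. It assumes a hypothetical fixed distribution `TOY_PROBS` over four words (the real coder queries the language model at every step) and uses dyadic probabilities with exact `Fraction` arithmetic so that no rounding is needed. The names `split`, `embed`, and `extract` are illustrative only.

```python
from fractions import Fraction

# Hypothetical fixed "language model": the same distribution at every step.
TOY_PROBS = [("the", Fraction(1, 2)), ("a", Fraction(1, 4)),
             ("my", Fraction(1, 8)), ("our", Fraction(1, 8))]


def split(lo, hi):
    """Yield (token, sub_lo, sub_hi), partitioning [lo, hi) according to TOY_PROBS."""
    width = hi - lo
    start = lo
    for token, p in TOY_PROBS:
        end = start + p * width
        yield token, start, end
        start = end


def embed(bits):
    """Turn a non-empty list of message bits into tokens by narrowing [0, 1)."""
    n = len(bits)
    # Read the bits as a binary fraction 0.b1b2...bn
    x = Fraction(int("".join(str(b) for b in bits), 2), 2 ** n)
    lo, hi, tokens = Fraction(0), Fraction(1), []
    # Stop once the interval is narrow enough to pin down all n bits
    while hi - lo > Fraction(1, 2 ** n):
        for token, s, e in split(lo, hi):
            if s <= x < e:
                tokens.append(token)
                lo, hi = s, e
                break
    return tokens


def extract(tokens, n):
    """Replay the tokens to rebuild the interval, then read back n bits."""
    lo, hi = Fraction(0), Fraction(1)
    for tok in tokens:
        for token, s, e in split(lo, hi):
            if token == tok:
                lo, hi = s, e
                break
    value = int(lo * 2 ** n)  # first n bits of the interval's lower end
    return [int(b) for b in format(value, "0{}b".format(n))]


message = [1, 0, 1, 1, 0]
cover = embed(message)
recovered = extract(cover, len(message))
```

For the message `[1, 0, 1, 1, 0]` the sketch emits the tokens `a`, `my`, and `extract` recovers the same five bits from them. With non-dyadic probabilities, as produced by a real model, the interval has to be quantized to integers, which is what the functions later in this notebook do.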

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import bitarray
from transformers import AutoModelForCausalLM, AutoTokenizer

import random

Initialize the variables and the model name:

In [20]:
message_secret = "I am hidden"
message_secret = message_secret + '<eos>'

amorce = "I am going on a vacation to Italy. I am hoping that my"

model_name = "gpt2-large"

Load the model and tokenizer:

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval();

Encode the message and the prompt (`amorce`):

In [4]:
encoded_message = tokenizer.encode(message_secret)
encoded_context = tokenizer.encode(amorce)

print("prompt: {} \nencoded as: {}\n".format(amorce, encoded_context))
print("secret message: {} \nencoded as: {}\n".format(message_secret, encoded_message))

tensor_amorce = torch.LongTensor(encoded_context).view(1,-1)
tensor_message = torch.LongTensor(encoded_message).view(1,-1)

# The token list can be viewed at:
# https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json

prompt: I am going on a vacation to Italy. I am hoping that my 
encoded as: [40, 716, 1016, 319, 257, 14600, 284, 8031, 13, 314, 716, 7725, 326, 616]

secret message: I am hidden<eos> 
encoded as: [40, 716, 7104, 27, 68, 418, 29]



Examples of sequences generated without the steganographic constraint:

In [5]:
sampling_output = model.generate(tensor_amorce, do_sample=True, max_length=40,
                                 top_k=50, top_p=1, temperature=1, num_return_sequences=3)

for i in range(sampling_output.shape[0]):
    print("example {}: {}\n".format(i, tokenizer.decode(sampling_output[i].tolist())))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


example 0: I am going on a vacation to Italy. I am hoping that my wife will not die before I see her, and that her death will be in one or two days, and will certainly be a

example 1: I am going on a vacation to Italy. I am hoping that my new place will be a great place to go as I have found no places that suit me.

I am also in the

example 2: I am going on a vacation to Italy. I am hoping that my dad comes on board. I have so many questions, but I am going to send you a quick email before I leave." I



Helper functions used in the rest of the notebook

In [28]:

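def int2bits(inp, num_bits):
    # Bits of inp as a list of length num_bits, least significant bit first
    if num_bits == 0:
        return []
    strlist = ('{0:0%db}' % num_bits).format(inp)
    return [int(strval) for strval in reversed(strlist)]

def num_same_from_beg(bits1, bits2):
    # Index of the first position where the two bit lists differ.
    # Identical lists return len - 1, so one bit always stays unresolved.
    assert len(bits1) == len(bits2)
    i = 0  # defined even when both lists are empty
    for i in range(len(bits1)):
        if bits1[i] != bits2[i]:
            break

    return i

def bits2int(bits):
    # Inverse of int2bits: bits are given least significant bit first
    res = 0
    for i, bit in enumerate(bits):
        res += bit * (2 ** i)
    return res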


def str2bit(msg_str, tokenizer, model, context=None, topk=60000):
    # Convert text to bits; with no context, condition on '<|endoftext|>'
    if context is None:
        message_ctx = tokenizer.encode('<|endoftext|>')
    else:
        message_ctx = tokenizer.encode(context)
    msg_str += '<eos>'  # textual end marker; decode_arithmetic stops once it appears
    msg_enc = encode_arithmetic(model, tokenizer, msg_str, message_ctx,
                                precision=40, topk=topk, device='cpu')
    msg_bits = bitarray.bitarray(msg_enc)

    return msg_bits

def bit2str(msg_bits, tokenizer, model, context=None, topk=60000, finish_sent=False):
    
    if context is None:
        message_ctx = tokenizer.encode('<|endoftext|>')
    else: 
        message_ctx = tokenizer.encode(context)
        
    msg_str = decode_arithmetic(model, tokenizer, msg_bits, message_ctx,
        precision=40, topk=topk, device="cpu", model_device='cpu', finish_sent=finish_sent)
    msg_str = tokenizer.decode(msg_str)
    return msg_str


def is_sent_finish(token_idx, enc):
    token = enc.decode(token_idx)
    return '.' in token or '!' in token or '?' in token
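
The bit-order convention of these helpers is easy to get wrong: `int2bits` returns the least significant bit first and `bits2int` expects the same order, which is why the coders below wrap both in `reversed(...)` to work most significant bit first. The short check below copies the two helpers so that it can run on its own; the values 5 and 4 are arbitrary.

```python
def int2bits(inp, num_bits):
    # Copied from the helper above: LSB-first bit list of a non-negative integer
    if num_bits == 0:
        return []
    strlist = ('{0:0%db}' % num_bits).format(inp)
    return [int(strval) for strval in reversed(strlist)]


def bits2int(bits):
    # Copied from the helper above: expects LSB-first bits
    res = 0
    for i, bit in enumerate(bits):
        res += bit * (2 ** i)
    return res


lsb_first = int2bits(5, 4)             # 5 is 0101 in binary
msb_first = list(reversed(lsb_first))  # the order used for the emitted message bits
value_back = bits2int(reversed(msb_first))
```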

Arithmetic encoding code (text to bits)

In [47]:
def encode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000):
    # text is the string to convert; inp is its list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # common BPE error case: 198, 198 (2 newlines) is interpreted as 628 (2 newlines)
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i + 1:i + 1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2 ** precision
    threshold = 2 ** (-precision)
    cur_interval = [0, max_val]  # bottom inclusive, top exclusive

    prev = context
    past = None
    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            outputs = model(prev.unsqueeze(0), past_key_values=past)
            past, logits = outputs.past_key_values, outputs.logits
            
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)
            cum_probs = probs_temp.cumsum(0)

            # Cutoff low probabilities that would be rounded to 0
            cur_int_range = cur_interval[1] - cur_interval[0]
            cur_threshold = 1 / cur_int_range
            
            k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
            probs_temp_int = probs_temp[:k]  # Cutoff all but top k

            # Rescale to correct range
            probs_temp_int = probs_temp_int / probs_temp_int.sum() * cur_int_range

            # Round probabilities to integers given precision
            probs_temp_int = probs_temp_int.round().long()
            cum_probs = probs_temp_int.cumsum(0)

            # Remove any elements from the bottom if rounding caused the total prob to be too large
            overfill_index = (cum_probs > cur_int_range).nonzero()
            if len(overfill_index) > 0:
                cum_probs = cum_probs[:overfill_index[0]]
                k = overfill_index[0].item()

            # Add any mass to the top if removing/rounding causes the total prob to be too small
            cum_probs += cur_int_range - cum_probs[-1]  # add

            # Convert to position in range
            cum_probs += cur_interval[0]

            rank = (indices == inp[i]).nonzero().item()

            # Handle most errors that could happen because of BPE with heuristic
            if rank >= k:
                true_token_text = enc.decode(inp[i])
                for rank_idx in range(k):
                    prop_token_text = enc.decode(indices[rank_idx].item())
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        while len(whole_text) < len(prop_token_text):
                            whole_text += enc.decode(inp[i + num_extra])
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            # Remove the num_extra - 1 following tokens merged into inp[i]
                            del inp[i + 1:i + num_extra]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                            break
                else:
                    #print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0

            selection = rank

            # Calculate new range as ints
            new_int_bottom = cum_probs[selection - 1] if selection > 0 else cur_interval[0]
            new_int_top = cum_probs[selection]

            # Convert range to bits
            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
            new_int_top_bits_inc = list(
                reversed(int2bits(new_int_top - 1, precision)))  # -1 here because upper bound is exclusive

            # Emit most significant bits which are now fixed and update interval
            num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
            if i == len(inp) - 1:
                new_bits = new_int_bottom_bits_inc
            else:
                new_bits = new_int_top_bits_inc[:num_bits_encoded]
            message += new_bits

            #print("num bits: ",num_bits_encoded, "new bot 1: ", new_int_bottom_bits_inc[num_bits_encoded:])
            #print("new bot2 : ", [1] * num_bits_encoded)
            new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0] * num_bits_encoded
            new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1] * num_bits_encoded
            cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
            cur_interval[1] = bits2int(reversed(new_int_top_bits)) + 1  # +1 here because upper bound is exclusive

            # Update history with new token
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1
            
    return message
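
The interval update at the end of each step (emit the leading bits shared by both ends, then shift them out) can be isolated in a few lines. This is a simplified sketch with a hypothetical helper `renormalize` that works on bit strings rather than the notebook's bit lists; it mirrors `num_same_from_beg` in always leaving at least one bit unresolved.

```python
def renormalize(bottom, top, precision):
    """Return (emitted_bits, new_bottom, new_top) for the half-open interval [bottom, top)."""
    lo_bits = format(bottom, "0{}b".format(precision))
    hi_bits = format(top - 1, "0{}b".format(precision))  # top is exclusive

    # Count the leading bits on which both ends agree (at most precision - 1)
    n = 0
    while n < precision - 1 and lo_bits[n] == hi_bits[n]:
        n += 1

    emitted = [int(b) for b in lo_bits[:n]]
    # Shift the fixed bits out: pad the bottom with 0s and the top with 1s
    new_bottom = int(lo_bits[n:] + "0" * n, 2)
    new_top = int(hi_bits[n:] + "1" * n, 2) + 1  # back to an exclusive bound
    return emitted, new_bottom, new_top


# With 8-bit precision, [96, 112) covers 01100000 to 01101111
step = renormalize(96, 112, 8)
```

In this example the four shared bits `0110` are emitted and the interval is rescaled to the full range `[0, 256)`. An interval such as `[64, 192)`, whose ends differ in the first bit, emits nothing and is left unchanged.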


Arithmetic decoding code (bits to text)

In [9]:
def decode_arithmetic(model, enc, message, context, finish_sent=False, model_device="cuda", device='cpu', temp=1.0, precision=16,
                      topk=50000):
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2 ** precision
    threshold = 2 ** (-precision)
    cur_interval = [0, max_val]  # bottom inclusive, top exclusive

    prev = context
    output = context
    past = None

    total_num = 0
    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits
    total_entropy_ptau = 0
    total_num_sents = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < len(message) or (finish_sent and not sent_finish):
            outputs = model(prev.unsqueeze(0).to(model_device), past_key_values=past)
            logits, past = outputs.logits, outputs.past_key_values
            logits = logits.to(device)
            
            logits[0, -1, -1] = -1e20  # endoftext token can't happen
            logits[0, -1, 628] = -1e20  # 2 newlines token can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)
            log_probs_temp = F.log_softmax(logits_temp, dim=0)
            log_probs = F.log_softmax(logits, dim=0)

            # conditions for having reached the end of the message
            if i >= len(message):
                selection = 0
                sent_finish = is_sent_finish(indices[selection].item(), enc)
            else:
                # Cutoff low probabilities that would be rounded to 0
                cur_int_range = cur_interval[1] - cur_interval[0]
                cur_threshold = 1 / cur_int_range
                k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)
                probs_temp_int = probs_temp[:k]  # Cutoff all but top k

                # Rescale to correct range
                probs_temp_int = probs_temp_int / probs_temp_int.sum() * cur_int_range

                # Round probabilities to integers given precision
                probs_temp_int = probs_temp_int.round().long()
                cum_probs = probs_temp_int.cumsum(0)

                # Remove any elements from the bottom if rounding caused the total prob to be too large
                overfill_index = (cum_probs > cur_int_range).nonzero()
                if len(overfill_index) > 0:
                    cum_probs = cum_probs[:overfill_index[0]]

                # Add any mass to the top if removing/rounding causes the total prob to be too small
                cum_probs += cur_int_range - cum_probs[-1]  # add

                # Get out resulting probabilities
                probs_final = cum_probs.clone()
                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]

                # Convert to position in range
                cum_probs += cur_interval[0]

                # Get selected index based on binary fraction from message bits
                message_bits = message[i:i + precision]
                if i + precision > len(message):
                    message_bits = message_bits + [0] * (i + precision - len(message))
                message_idx = bits2int(reversed(message_bits))
                selection = (cum_probs > message_idx).nonzero()[0].item()

                # Calculate new range as ints
                new_int_bottom = cum_probs[selection - 1] if selection > 0 else cur_interval[0]
                new_int_top = cum_probs[selection]

                # Convert range to bits
                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
                new_int_top_bits_inc = list(
                    reversed(int2bits(new_int_top - 1, precision)))  # -1 here because upper bound is exclusive
                # Consume most significant bits which are now fixed and update interval
                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
                i += num_bits_encoded

                new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0] * num_bits_encoded
                new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1] * num_bits_encoded

                cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
                cur_interval[1] = bits2int(reversed(new_int_top_bits)) + 1  # +1 here because upper bound is exclusive


            # Update history with new token
            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1
            

            # For text->bits->text
            partial = enc.decode(output[len(context):].tolist())
            
            if '<eos>' in partial:
                break


    return output[len(context):].tolist()
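
Both coders share the same quantization step: token probabilities are scaled to the current integer interval, rounded, trimmed if rounding overshoots, and then the next `precision` message bits pick a slice. The sketch below reproduces that logic in plain Python under simplifying assumptions: `quantize` and `select` are hypothetical names, the probabilities are assumed to be sorted in descending order, and the top-k and threshold cutoffs are omitted.

```python
def quantize(probs, lo, hi):
    """Exclusive integer upper bound of each token's slice of [lo, hi)."""
    width = hi - lo
    total = sum(probs)
    bounds, running = [], 0
    for p in probs:
        running += round(p / total * width)
        if running > width:  # rounding overshot: drop this and all later tokens
            break
        bounds.append(running)
    # Leftover mass is added to every bound, which widens the first slice
    shift = width - bounds[-1]
    return [lo + b + shift for b in bounds]


def select(bounds, message_bits):
    """Rank of the first slice whose upper bound exceeds the MSB-first message value."""
    value = int("".join(str(b) for b in message_bits), 2)
    for rank, upper in enumerate(bounds):
        if upper > value:
            return rank
    raise ValueError("message value lies outside the interval")


bounds = quantize([0.5, 0.3, 0.2], 0, 16)  # 4-bit precision: full range is [0, 16)
rank = select(bounds, [1, 0, 0, 1])        # message value 9
```

Here the slices are `[0, 8)`, `[8, 13)` and `[13, 16)`, so the bits `1001` (value 9) select the token of rank 1. With probabilities `[0.6, 0.3, 0.1]` the rounded counts sum to 17, the last token is dropped, and the bounds become `[11, 16]`.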


Encode and decode the secret message for steganographic embedding

In [36]:
print("secret message to encode: {}\n".format(message_secret[:message_secret.index('<eos>')]))
bit_message = str2bit(message_secret, tokenizer, model)
print("encoded message: {}\n".format(bit_message))
decoded_message = bit2str(bit_message, tokenizer, model)
print("decoded message: ", decoded_message[:decoded_message.index('<eos>')])


secret message to encode: I am hidden

encoded message: bitarray('0010110101101100010101000110001110101100111000100100001110111010011110111111010110000110001110000000111000000')

decoded message:  I am hidden


In [49]:
stegotexte = bit2str(bit_message, tokenizer, model, context=amorce)
print("stegotext: ", amorce+stegotexte)

stegotext:  I am going on a vacation to Italy. I am hoping that my new partner will understand this gift.

P.S. Please excuse this photo, my boyfriend had to use the camera before the Koyaanisqatsi.

Advertisements

Share this: Twitter

Facebook

Google

Like this: Like Loading... Related

T


In [50]:
bit_message_new = str2bit(stegotexte, tokenizer, model, context=amorce)
print("Bits encoded by the stegotext: {} \n".format(bit_message_new))
decoded_message = bit2str(bit_message_new, tokenizer, model)
print("Decoded message: ", decoded_message[:decoded_message.index('<eos>')])

Bits encoded by the stegotext: bitarray('00101101011011000101010001100011101011001110001001000011101110100111101111110101100001100011100000001110000000000000000000000000000000000000011111111111111111111010001110101001111001110000010111011000110111011101011001001100000000000000000000') 

Decoded message:  I am hidden
