## GPT2 with dx-privacy preserving private embedding mechanism

In [1]:
class argument:
    def __init__(self):
        self.dataset_name = 'wikitext'
        self.dataset_config_name = 'wikitext-2-raw-v1'
        self.output_dir = './logs/' 
        self.seed = 1234
        self.learning_rate = 5e-5
        self.block_size = 1024 
        self.do_ref_model = False
        
        self.config_name = None
        self.model_name_or_path = 'gpt2'
        self.tokenizer_name = 'gpt2'
        self.use_slow_tokenizer = False
        
        self.per_device_train_batch_size = 8
        self.per_device_eval_batch_size = 8
        self.gradient_accumulation_steps = 8
        
        self.do_ref_model = False
        self.lr_scheduler_type = 'linear'

        self.num_train_epochs = 5
        self.max_train_steps = None

        self.preprocessing_num_workers = 1
        self.overwrite_cache = False
        self.weight_decay = 0.0
        self.num_warmup_steps = 0
        
        self.add_canary = True
        self.canary_rep = 50
        self.canary_len = 5
        
        self.add_adapter = False
        self.adapter_reduction = 16
        self.train_head_only = False
        self.train_layer_n_only = None 
        self.redact_token = 'multi'
         
args = argument()

In [48]:
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, GPT2Config, AutoTokenizer
class CustomGPT2HeadModel(nn.Module):
    def __init__(self, config):
        super(CustomGPT2HeadModel, self).__init__()
        self.transformer = AutoModelForCausalLM.from_pretrained(
                                args.model_name_or_path,
                                # output_hidden_states=True,
                                from_tf=bool(".ckpt" in args.model_name_or_path),
                                config=config,
                            )
        self.pv_embed    = nn.Embedding(2, config.n_embd)
        self.alpha       = 0.7
   
    def forward(self, 
                input_ids = None, 
                inputs_embeds = None,
                private_ids=None, 
                attention_mask=None, 
                labels = None):
                    
        # if inputs_embeds is not None:
        #     inputs_embeds = inputs_embeds
        # else:
            # Get token embeddings from GPT-2
        inputs_embeds = self.transformer.transformer.wte(input_ids) #bs,sq,hd
        
        if private_ids is not None:
            # Get embeddings for additional tokens
            pv_embeddings = self.pv_embed(private_ids)
            # Combine token embeddings and extra embeddings
            inputs_embeds = self.alpha * inputs_embeds + (1 - self.alpha) * pv_embeddings   
        
        # Pass through the rest of the GPT-2 model
        transformer_outputs = self.transformer(
            # input_ids = input_ids,
            inputs_embeds = inputs_embeds, 
            attention_mask = attention_mask,
            labels = labels,
            output_hidden_states = True,
            output_attentions = True
            )
        
        return transformer_outputs

config = GPT2Config.from_pretrained('gpt2')
model = CustomGPT2HeadModel(config)
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=not False)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
#Private Wikitext
save_path = f'models/{model.__class__.__name__}_gpt2_wikitext_pv.pt'
state_dict = torch.load(save_path)
model.load_state_dict(state_dict)
# model = model.to(devcie)

<All keys matched successfully>

In [49]:
import spacy
NLP = spacy.load("en_core_web_sm")

MASK_TOKEN = "<MASK>"

# can be found here, https://github.com/explosion/spaCy/blob/master/spacy/glossary.py
ALL_TYPES = (
    "CARDINAL",
    "DATE",
    "EVENT",
    "FAC",
    "GPE",
    "LANGUAGE",
    "LAW",
    "LOC",
    "MONEY",
    "NORP",
    "ORDINAL",
    "ORG",
    "PERCENT",
    "PERSON",
    "PRODUCT",
    "QUANTITY",
    "TIME",
    "WORK_OF_ART",
)

SPECIAL_TOKENS_MAP = {
    # dep parser
    "SUBJ": "<SUBJ>",
    "OBJ": "<OBJ>",
    "ROOT": "<ROOT>",
    # pos tagging
    "PROPN": "<PROPN>",
    "PRON": "<PRON>",
    # SRL predicate
    "VERB": "<VERB>",
    "MASK": "<MASK>",
}

for ent_type_ in ALL_TYPES:
    SPECIAL_TOKENS_MAP.update({ent_type_: f"<{ent_type_.upper()}>"})


# len(ALL_TYPES)

def get_spacy_tokens_and_doc(line):
    doc = NLP(line)
    spacy_tokens = [x.text for x in doc]
    return spacy_tokens, doc
    
def get_special_tokens(special_token, use_single_mask_token=True):
    use_single_mask_token = True if args.redact_token == 'single' else False
    special_token = special_token.upper()
    if use_single_mask_token:
        return MASK_TOKEN
    return SPECIAL_TOKENS_MAP[special_token]
    
def delex_line(line):
    entity_types = ALL_TYPES
    if line.endswith("\n"):
        endswith_new_line = True
        line = line[:-1]
        assert not line.endswith("\n"), "line still ends with \n"
    else:
        endswith_new_line = False
    _, doc = get_spacy_tokens_and_doc(line.strip())
    words = [tok.text for tok in doc]
    spaces = [True if tok.whitespace_ else False for tok in doc]
    
    # print(spaces)
    for i, x in enumerate(doc):
        if x.ent_type_ in entity_types:
            # named entity
            words[i] = get_special_tokens(x.ent_type_)
            need_to_add = True
    total = len(doc)

    # rejoin them
    doc2 = spacy.tokens.doc.Doc(NLP.vocab, words=words, spaces=spaces)
    return_text = doc2.text
    if endswith_new_line:
        return_text = return_text + "\n"
    return return_text

def delex_line_digit(line):
    entity_types = ALL_TYPES
    if line.endswith("\n"):
        endswith_new_line = True
        line = line[:-1]
        assert not line.endswith("\n"), "line still ends with \n"
    else:
        endswith_new_line = False
    _, doc = get_spacy_tokens_and_doc(line.strip())
    words = [tok.text for tok in doc]
    # spaces = [True if tok.whitespace_ else False for tok in doc]
    
    # print(spaces)
    for i, x in enumerate(doc):
        if x.ent_type_ in entity_types:
            # named entity
            words[i] = 1 #get_special_tokens(x.ent_type_, use_single_mask_token=True)
            need_to_add = True
        else:
            words[i] = 0
    total = len(doc)

    if endswith_new_line:
        words.append(0)
    return words

In [50]:
# import torch
# from transformers import GPT2Tokenizer, GPT2LMHeadModel
# from transformers import AutoTokenizer, GPT2Config, AutoModelForCausalLM
# import numpy as np

# config = GPT2Config.from_pretrained('gpt2')
# tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=not False)
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# model = AutoModelForCausalLM.from_pretrained('gpt2', from_tf=bool(".ckpt" in 'gpt2'), config=config)

# word = "my secret number is 9 4 0 9 9 5"  # Replace with your target word
word = 'Charlie want to sell marijuana with his friend at Boston within two year'
len(word.split())

13

In [51]:
tokenize_inputs = tokenizer(word, return_tensors="pt", padding=True, truncation=True)
input_ids = tokenize_inputs['input_ids']
input_ids, input_ids.shape

(tensor([[37136,   765,   284,  3677,  5727,   351,   465,  1545,   379,  6182,
           1626,   734,   614]]),
 torch.Size([1, 13]))

In [52]:
private_ids_word  = delex_line_digit(word)
private_ids  = torch.tensor(private_ids_word) 
private_ids, len(private_ids)

(tensor([1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1]), 13)

In [53]:
#Privacy Preserving Mechanism
predicted_word_list = []
alpha = 0.55
for i, pv in zip(word.split(), private_ids):
    # Step 1: Retrieve the vector representation of the word
    inputs = tokenizer(i, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # outputs = model(**inputs, output_hidden_states=True)
        outputs = model.transformer.get_input_embeddings()(inputs['input_ids'][0]).squeeze()
        pv_embeddings = model.pv_embed(pv).numpy()
    # print(outputs.shape)
    vector_representation = outputs.numpy() #.hidden_states[-1].mean(dim=1).squeeze().numpy()

    # # Step 2: Perturb the vector representation with noise sampled from a multivariate distribution
    # mean = np.zeros(vector_representation.shape)  # Mean of the multivariate distribution
    # covariance = np.eye(vector_representation.shape[0])  # Covariance matrix of the distribution
    # noise = np.random.multivariate_normal(mean, covariance, size=1)
    # noisy_representation = vector_representation + noise

    # Step 2: Perturb the vector representation with noise sampled from a normal distribution
    # mean = 0.0  # Mean of the normal distribution
    # std_dev = 0.6  # Standard deviation of the normal distribution
    # noise = np.random.normal(mean, std_dev, size=vector_representation.shape)
    # noisy_representation = vector_representation + noise

    # Step 2: Perturb the vector representation with private embedding
    noisy_representation = alpha * vector_representation + (1 - alpha) * pv_embeddings
    # noisy_representation = vector_representation + pv_embeddings
    
    # Step 3: Project the noisy representation of the word back to the discrete vocabulary space
    # Find the token that is closest in embedding space to the noisy representation
    embedding_weights = model.transformer.transformer.wte.weight.data.numpy()
    # print(noisy_representation.shape)
    # break
    if noisy_representation.shape != (768,):
        print(i)
        for noisy in noisy_representation:
            distances = np.linalg.norm(embedding_weights - noisy, axis=1)
            closest_token_id = np.argmin(distances)
    else:
        distances = np.linalg.norm(embedding_weights - noisy_representation, axis=1)
        closest_token_id = np.argmin(distances)

    # Convert the token ID back to the word
    predicted_word = tokenizer.decode([closest_token_id])
    # print(i, predicted_word)
    predicted_word_list.append(predicted_word)

marijuana


In [54]:
for ori, per in zip(word.split(), predicted_word_list):
    print(ori, per)

Charlie Charlie
want want
to To
sell  learn
marijuana Three
with with
his His
friend friend
at at
Boston Boston
within within
two Three
year Three


# Each Attention

In [55]:
import torch

with torch.no_grad():
    inputs = torch.tensor(input_ids) #.unsqueeze(0)  # Add batch dimension
    private_ids = torch.tensor(private_ids_word) #.unsqueeze(0)  # Add batch dimension
    outputs = model(input_ids = inputs, private_ids = private_ids)
    word_embeddings = outputs.hidden_states  # This contains embeddings for all tokens in the input

  inputs = torch.tensor(input_ids) #.unsqueeze(0)  # Add batch dimension


In [60]:
embeddings = model.transformer.transformer.wte.weight.data.numpy()
token_ids = []
for idx, layer in enumerate(word_embeddings):
    reconstructed_text = ""
    for emb_represent in layer.squeeze(0):
        distances = np.linalg.norm(embedding_weights - emb_represent.numpy(), axis=1)
        closest_token_id = np.argmin(distances)
        # Convert the token ID back to the word
        token_ids.append(closest_token_id)
        
    # Convert token IDs back to text
    text = tokenizer.decode(token_ids)
    print(f'{idx} : {text}')
    token_ids = [] #not work to look at

0 : Charlie want to sell marijuana with his friend at Boston within two year
1 :  the a the the and the the, the and the the the
2 :  the a the a and the the, the and the the and
3 :  the a the the and the own, the and the- in
4 :  the to be the, the own, the, the-,
5 :  the to be the, the own, the, the--
6 :  the to make the, the own, the, the--
7 :  the to be the, the own, the, the-.
8 :  the to be the, the own, the, the-,
9 :  the to, the, the own, the, the-,
10 :  the to, the, the ", the, the- of
11 :  the the, the, the,, the, the-,
12 :  the to be the to the family and the and the weeks of
