In [20]:
import torch
import random
import numpy as np
import logging
import os
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
from transformers import BertTokenizer, BertForMaskedLM
from tqdm.auto import tqdm
from scipy.special import softmax
from functools import partial
from multiprocessing import Pool, cpu_count
class Agrument:
    """Hyper-parameters for the SanText / SanText+ sanitization run.

    Plays the role of parsed command-line arguments so later cells can read
    ``args.<name>``. The class keeps its original spelling because ``args``
    below is built from it.
    """

    def __init__(self):
        self.task = 'SST-2'
        # Which embedding space drives the exponential mechanism ('bert' or 'glove').
        self.embedding_type = 'bert'
        self.bert_model_path = "bert-base-uncased"
        self.data_dir ="./data/SST-2/"
        # Fraction of the vocabulary (taken from the least frequent end) treated as sensitive.
        self.sensitive_word_percentage = 0.5
        # Privacy budget passed to the exponential-mechanism probability computation.
        self.epsilon = 14
        self.output_dir = "./output_SanText_bert/SST-2/"
        # Upper bound on worker processes used for sanitization.
        self.threads = 12
        # Probability that SanText_plus also replaces a non-sensitive in-vocab word.
        self.p = 0.2
        # Seed for every random number generator used by the notebook.
        self.seed = 42
        self.method = 'SanText'


args = Agrument()

# Seed the RNGs the notebook touches so the sanitized output is reproducible
# (previously args.seed was defined but never applied).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [21]:
# Build the output directory from a remembered base path so that re-running
# this cell does not keep appending "eps_..." segments to args.output_dir.
if not hasattr(args, "base_output_dir"):
    args.base_output_dir = args.output_dir

if args.method == "SanText":
    # Plain SanText treats every vocabulary word as sensitive.
    args.sensitive_word_percentage = 1.0
    args.output_dir = os.path.join(args.base_output_dir, "eps_%.2f" % args.epsilon)
else:
    args.output_dir = os.path.join(args.base_output_dir, "eps_%.2f" % args.epsilon, "sword_%.2f_p_%.2f"%(args.sensitive_word_percentage,args.p))

os.makedirs(args.output_dir, exist_ok=True)

print("Building Vocabulary...")

if args.embedding_type=="glove":
    # The GloVe path needs a word-level tokenizer (spaCy's `English`) that this
    # notebook never imports, and the embeddings below are always BERT's, so
    # fail with a clear message instead of a NameError.
    raise NotImplementedError("embedding_type='glove' is not supported in this notebook; use 'bert'.")
else:
    tokenizer  = BertTokenizer.from_pretrained(args.bert_model_path)
    tokenizer_type = "subword"

model=BertForMaskedLM.from_pretrained(args.bert_model_path)
# Word-piece input embedding table; later cells index it by token id.
embedding_matrix = model.bert.embeddings.word_embeddings.weight.detach().cpu().numpy()

Building Vocabulary...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
from tqdm import tqdm
import os
import unicodedata
from collections import Counter

def get_vocab_SST2(data_dir,tokenizer,tokenizer_type="subword"):
    """Count token frequencies over the SST-2 ``train`` and ``dev`` splits.

    Args:
        data_dir: directory holding ``train.tsv`` and ``dev.tsv``. Each file has
            one header line followed by tab-separated rows whose first column is
            the sentence.
        tokenizer: an object with ``.tokenize(text)`` (``tokenizer_type="subword"``)
            or a callable returning items with a ``.text`` attribute
            (``tokenizer_type="word"``).
        tokenizer_type: ``"subword"`` or ``"word"``.

    Returns:
        ``collections.Counter`` mapping token -> occurrence count. For subword
        tokenizers every key of ``tokenizer.vocab`` gets one extra count, so the
        whole tokenizer vocabulary is present even if unseen in the data.

    Raises:
        ValueError: if ``tokenizer_type`` is not ``"subword"`` or ``"word"``.
    """
    if tokenizer_type == "subword":
        tokenize = tokenizer.tokenize
    elif tokenizer_type == "word":
        def tokenize(text):
            return [token.text for token in tokenizer(text)]
    else:
        # Previously an unknown type left `tokenized_text` undefined (or stale
        # from the previous line) instead of failing clearly.
        raise ValueError("Unsupported tokenizer_type: %r" % (tokenizer_type,))

    vocab=Counter()
    for split in ['train','dev']:
        data_file_path=os.path.join(data_dir,split+".tsv")
        # Count lines only to size the progress bar; close the handle promptly.
        with open(data_file_path, 'r') as count_file:
            num_lines = sum(1 for _ in count_file)
        with open(data_file_path, 'r') as csvfile:
            next(csvfile)  # skip the header line
            for line in tqdm(csvfile,total=num_lines-1):
                text = line.strip().split("\t")[0]
                for token in tokenize(text):
                    vocab[token]+=1
    if tokenizer_type == "subword":
        # Iterate keys explicitly: Counter.update(dict) would add the token ids as counts.
        for token in tokenizer.vocab:
            vocab[token]+=1
    return vocab
    
# Token frequencies over SST-2 train+dev (plus the tokenizer vocabulary for subword tokenizers).
vocab = get_vocab_SST2(data_dir=args.data_dir, tokenizer=tokenizer, tokenizer_type=tokenizer_type)

100%|██████████| 67349/67349 [00:04<00:00, 13636.46it/s]
100%|██████████| 872/872 [00:00<00:00, 7794.06it/s]


In [23]:
# The sensitive set is the `sensitive_word_count` least frequent tokens, i.e.
# the tail of vocab.most_common().
sensitive_word_count = int(args.sensitive_word_percentage * len(vocab))
words = [key for key, _ in vocab.most_common()]
# Slice from an explicit start index: the former `words[-count - 1:]` kept one
# word too many (and one word even when count == 0).
sensitive_words = words[max(len(words) - sensitive_word_count, 0):]

sensitive_words2id = {word: k for k, word in enumerate(sensitive_words)}
print("#Total Words: %d, #Sensitive Words: %d" % (len(words),len(sensitive_words2id)))

#Total Words: 30522, #Sensitive Words: 30522


In [24]:
# Gather embeddings for every tokenizer token that also occurs in `vocab`.
# word2id maps a token to its row in all_word_embed; sword2id maps a sensitive
# token to its row in sensitive_word_embed.
all_word_embed = []
sensitive_word_embed = []

word2id = {}
sword2id = {}
all_count = 0
sensitive_count = 0
for cur_word in tokenizer.vocab:
    # Skip tokens absent from the corpus vocabulary or already registered.
    if cur_word not in vocab or cur_word in word2id:
        continue

    token_emb = embedding_matrix[tokenizer.convert_tokens_to_ids(cur_word)]
    word2id[cur_word] = all_count
    all_word_embed.append(token_emb)
    all_count += 1

    if cur_word in sensitive_words2id:
        sword2id[cur_word] = sensitive_count
        sensitive_word_embed.append(token_emb)
        sensitive_count += 1

    # The id maps and the row lists must stay aligned.
    assert len(word2id) == len(all_word_embed)
    assert len(sword2id) == len(sensitive_word_embed)

In [25]:
import numpy as np
# Stack the per-token embedding lists into float32 matrices.
all_word_embed = np.array(all_word_embed, dtype=np.float32)
sensitive_word_embed = np.array(sensitive_word_embed, dtype=np.float32)

print("All Word Embedding Matrix: %s" % (all_word_embed.shape,))
print("Sensitive Word Embedding Matrix: %s" % (sensitive_word_embed.shape,))

from scipy.special import softmax
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
def cal_probability(word_embed_1, word_embed_2, epsilon=2.0):
    """Exponential-mechanism sampling probabilities between two embedding sets.

    The utility of replacing a source word by a candidate is the negative
    Euclidean distance between their embeddings; each output row is
    ``softmax(epsilon * utility / 2)`` over the candidates.

    Args:
        word_embed_1: 2-D array, one row per source word.
        word_embed_2: 2-D array, one row per candidate word.
        epsilon: privacy budget; larger values put more mass on nearer candidates.

    Returns:
        Array of shape ``(len(word_embed_1), len(word_embed_2))`` whose rows sum to 1.
    """
    utility = -euclidean_distances(word_embed_1, word_embed_2)
    scaled_utility = utility * epsilon / 2
    return softmax(scaled_utility, axis=1)

print("Calculating Prob Matrix for Exponential Mechanism...")
# Rows follow word2id order (all words); columns follow sword2id order (sensitive words).
prob_matrix = cal_probability(word_embed_1=all_word_embed, word_embed_2=sensitive_word_embed, epsilon=args.epsilon)

All Word Embedding Matrix: (30522, 768)
Sensitive Word Embedding Matrix: (30522, 768)
Calculating Prob Matrix for Exponential Mechanism...


In [7]:
from SanText import SanText_plus,SanText_plus_init

threads = min(args.threads, cpu_count())

for file_name in ['train.tsv','dev.tsv']:
    data_file = os.path.join(args.data_dir, file_name)
    out_path = os.path.join(args.output_dir, file_name)
    print("Processing file: %s. Will write to: %s" % (data_file, out_path))

    # Count lines only to size the progress bar; close the handle right away.
    with open(data_file, 'r') as count_file:
        num_lines = sum(1 for _ in count_file)

    labels = []
    docs = []
    with open(data_file, 'r') as rf:
        header = next(rf)
        if args.task == "SST-2":
            for line in tqdm(rf, total=num_lines - 1):
                content = line.strip().split("\t")
                text = content[0]
                label = int(content[1])
                if args.embedding_type == "glove":
                    doc = [token.text for token in tokenizer(text)]
                else:
                    doc = tokenizer.tokenize(text)
                docs.append(doc)
                labels.append(label)

    # Bound to `pool`, not `p`: `p` is the global SanText+ replacement
    # probability read by the in-process SanText_plus cells further down.
    # NOTE(review): if workers are forked they may start from identical numpy
    # RNG state; confirm SanText_plus_init reseeds per worker if that matters.
    with Pool(threads, initializer=SanText_plus_init, initargs=(prob_matrix, word2id, sword2id, words, args.p, tokenizer)) as pool:
        # list() drains imap before the context manager terminates the pool.
        results = list(
            tqdm(
                pool.imap(SanText_plus, docs, chunksize=32),
                total=len(docs),
                desc="Sanitize docs using SanText",
            )
        )

    print("Saving ...")
    with open(out_path, 'w') as out_file:
        out_file.write(header)
        if args.task == "SST-2":
            for predicted_text, label in zip(results, labels):
                out_file.write(predicted_text + "\t" + str(label) + "\n")

Processing file: ./data/SST-2/train.tsv. Will write to: ./output_SanText_bert/SST-2/train.tsv


100%|██████████| 67349/67349 [00:04<00:00, 13755.96it/s]
Sanitize docs using SanText: 100%|██████████| 67349/67349 [00:01<00:00, 61980.28it/s]


Saving ...
Processing file: ./data/SST-2/dev.tsv. Will write to: ./output_SanText_bert/SST-2/dev.tsv


100%|██████████| 872/872 [00:00<00:00, 7410.60it/s]
Sanitize docs using SanText: 100%|██████████| 872/872 [00:00<00:00, 18262.51it/s]

Saving ...





In [18]:
def SanText_plus_init(prob_matrix_init, word2id_init, sword2id_init, all_words_init, p_init, tokenizer_init):
    """Install the shared sanitization state as module-level globals.

    Usable as a ``multiprocessing.Pool`` initializer or called directly, so that
    ``SanText_plus`` can read the state without it being passed per document.
    Also derives ``id2sword``, the inverse of ``sword2id``.
    """
    global prob_matrix, word2id, sword2id, id2sword, all_words, p, tokenizer

    prob_matrix, word2id, sword2id = prob_matrix_init, word2id_init, sword2id_init
    all_words, p, tokenizer = all_words_init, p_init, tokenizer_init
    # Inverse mapping: sampled column index of prob_matrix -> sensitive word.
    id2sword = dict(zip(sword2id.values(), sword2id.keys()))

In [30]:
# Install the sanitization state as globals for in-process use of SanText_plus.
SanText_plus_init(
    prob_matrix_init=prob_matrix,
    word2id_init=word2id,
    sword2id_init=sword2id,
    all_words_init=words,
    p_init=args.p,
    tokenizer_init=tokenizer,
)

In [31]:
def SanText_plus(doc):
    """Sanitize a tokenized document with SanText+.

    Reads the globals installed by ``SanText_plus_init``: ``prob_matrix``,
    ``word2id``, ``sword2id``, ``id2sword``, ``all_words`` and ``p``.

    * In-vocab sensitive words are always replaced by a sensitive word sampled
      from the word's row of ``prob_matrix``.
    * Other in-vocab words are replaced the same way with probability ``p``,
      and kept otherwise.
    * Out-of-vocab words are replaced by a word drawn uniformly from ``all_words``.

    Args:
        doc: iterable of tokens.

    Returns:
        The sanitized tokens joined with single spaces.
    """
    new_doc = []
    for word in doc:
        if word not in word2id:
            # Out-of-vocab: uniform draw. randint avoids building a
            # len(all_words)-sized probability vector for every such token.
            new_doc.append(all_words[np.random.randint(len(all_words))])
        elif word in sword2id or random.random() < p:
            # Sensitive words are always replaced; non-sensitive ones with
            # probability exactly p (`<`, since random() can return 0.0 and
            # `<=` would replace occasionally even when p == 0).
            new_doc.append(_sample_sensitive_word(word))
        else:
            # Non-sensitive word kept as the original.
            new_doc.append(word)
    return " ".join(new_doc)


def _sample_sensitive_word(word):
    """Draw a sensitive replacement for in-vocab ``word`` from its prob_matrix row."""
    sampling_prob = prob_matrix[word2id[word]]
    sampling_index = np.random.choice(len(sampling_prob), p=sampling_prob)
    return id2sword[sampling_index]

In [41]:
# Inspect the first example of the dev split. The file is named explicitly
# instead of relying on `file_name` left over from the processing loop.
sample_file_name = "dev.tsv"
data_file = os.path.join(args.data_dir, sample_file_name)
with open(data_file, 'r') as rf:
    header = next(rf)
    print(header)
    for line in rf:
        content = line.strip().split("\t")
        text = content[0]
        label = int(content[1])
        doc = tokenizer.tokenize(text)
        print(text)
        print(doc)
        print(label)
        break  # only the first example is needed

sentence	label

it 's a charming and often affecting journey . 
['it', "'", 's', 'a', 'charming', 'and', 'often', 'affecting', 'journey', '.']
1


In [66]:
# Sanitize the sample document; the result varies between runs because it is sampled.
sanitized_sample = SanText_plus(doc)
sanitized_sample

'it diocesan s a charming induced often [unused138] ##ya .'

In [74]:
# Trace how each token of the sample document is sanitized. This delegates to
# SanText_plus one token at a time rather than re-implementing its rules, so
# the trace covers every branch, including out-of-vocab tokens (which an
# inline copy without that branch silently dropped).
traced_tokens = []
for word in doc:
    replacement = SanText_plus([word])
    print(word, '->', replacement)
    traced_tokens.append(replacement)

new_doc = " ".join(traced_tokens)
new_doc

it -> it
' -> follows
s -> portraits
a -> a
charming -> 1885
and -> shi
often -> spotted
affecting -> [unused211]
journey -> april
. -> 1908


'it follows portraits a 1885 shi spotted [unused211] april 1908'