# Gradient-based word deletion - edited

I realized a flaw in the previous notebook: I was ranking words by their score change when substituting `[UNK]` rather than `[MASK]`. This notebook reruns the experiment with `[MASK]` substitution and beam search, which should do a better job of identifying the most important words to replace.
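To make the idea concrete, here is a minimal sketch of beam search over `[MASK]` substitutions. It is not the TextAttack implementation used below: `beam_search_mask`, `toy_score`, and the word weights are hypothetical stand-ins, with `toy_score` playing the role of the model's probability for the true person.

```python
from typing import Callable, List, Tuple

MASK = "[MASK]"

def beam_search_mask(
    words: List[str],
    score_fn: Callable[[List[str]], float],
    beam_width: int = 2,
    max_masks: int = 2,
) -> Tuple[List[str], float]:
    """Keep the `beam_width` lowest-scoring masked variants at each step."""
    beam = [(score_fn(words), words)]
    for _ in range(max_masks):
        candidates = []
        for _, current in beam:
            for i, word in enumerate(current):
                if word == MASK:
                    continue  # never mask the same position twice
                masked = current[:i] + [MASK] + current[i + 1:]
                candidates.append((score_fn(masked), masked))
        if not candidates:
            break
        # Lower score = the model is less sure of the true identity.
        candidates.sort(key=lambda pair: pair[0])
        beam = candidates[:beam_width]
    best_score, best_words = beam[0]
    return best_words, best_score

# Hypothetical per-word importance; a real run would query the model instead.
TOY_WEIGHTS = {"jack": 0.6, "name": 0.1, "hello": 0.05}

def toy_score(words: List[str]) -> float:
    return sum(TOY_WEIGHTS.get(w.lower(), 0.0) for w in words)

best_words, best_score = beam_search_mask("Hello my name is Jack".split(), toy_score)
print(" ".join(best_words), round(best_score, 2))
```

A greedy search is the special case `beam_width=1`; a wider beam keeps several partial maskings alive, which helps when the best pair of words to mask does not include the single best word.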

In [1]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

In [2]:
from model import DocumentProfileMatchingTransformer

checkpoint_path = "/home/jxm3/research/deidentification/unsupervised-deidentification/saves/deid-wikibio_deid_exp/okpvvffw_46/checkpoints/epoch=7-step=1823.ckpt"
model = DocumentProfileMatchingTransformer.load_from_checkpoint(
    checkpoint_path,
    dataset_name='wiki_bio',
    model_name_or_path='distilbert-base-uncased',
    num_workers=1,
    loss_fn='exact',
    num_neighbors=2048,
    base_folder="/home/jxm3/research/deidentification/unsupervised-deidentification",
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Initialized DocumentProfileMatchingTransformer with learning_rate = 0.0002


In [3]:
from datamodule import WikipediaDataModule
import os

num_cpus = os.cpu_count()

dm = WikipediaDataModule(
    model_name_or_path='distilbert-base-uncased',
    dataset_name='wiki_bio',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
    max_seq_length=64,
    redaction_strategy="",
    base_folder="/home/jxm3/research/deidentification/unsupervised-deidentification",
)
dm.setup("fit")

Initializing WikipediaDataModule with num_workers = 8


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)
Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-5535f82839d9fec4.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-5b1c3941089b7f1b.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-8a9b289bc8e70b72.arrow
Loading cached processed dataset at 

## 2. Define attack in TextAttack 

In [4]:
dm.tokenizer.encode("Hi there")

[101, 7632, 2045, 102]

In [5]:
dm.tokenizer.encode("Hi there [MASK]")

[101, 7632, 2045, 103, 102]

In [6]:
import textattack

### (a) Beam search + replace with `[MASK]`

In [28]:
class WordSwapSingleWord(textattack.transformations.word_swap.WordSwap):
    """Takes a sentence and transforms it by replacing with a single fixed word.
    """
    single_word: str
    def __init__(self, single_word: str = "?", **kwargs):
        super().__init__(**kwargs)
        self.single_word = single_word

    def _get_replacement_words(self, _word: str):
        return [self.single_word]

transformation = WordSwapSingleWord(single_word='[MASK]')
transformation(textattack.shared.AttackedText("Hello my name is Jack"))

[<AttackedText "[MASK] my name is Jack">,
 <AttackedText "Hello [MASK] name is Jack">,
 <AttackedText "Hello my [MASK] is Jack">,
 <AttackedText "Hello my name [MASK] Jack">,
 <AttackedText "Hello my name is [MASK]">]
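The output above contains one candidate per word position. A dependency-free sketch of the same enumeration follows; the helper name `single_word_swaps` is hypothetical, and plain whitespace splitting only approximates how TextAttack segments words.

```python
from typing import List

def single_word_swaps(text: str, replacement: str = "[MASK]") -> List[str]:
    """Return one variant of `text` per word, with that word replaced."""
    words = text.split()
    variants = []
    for i in range(len(words)):
        swapped = words[:i] + [replacement] + words[i + 1:]
        variants.append(" ".join(swapped))
    return variants

variants = single_word_swaps("Hello my name is Jack")
for variant in variants:
    print(variant)
```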

### (b) "Attack success" as fulfillment of the metric

In [9]:
from typing import List
import torch

class ChangeClassificationToBelowTopKClasses(textattack.goal_functions.ClassificationGoalFunction):
    k: int
    def __init__(self, *args, k: int = 1, **kwargs):
        self.k = k
        super().__init__(*args, **kwargs)

    def _is_goal_complete(self, model_output, _):
        original_class_score = model_output[self.ground_truth_output]
        num_better_classes = (model_output > original_class_score).sum()
        return num_better_classes >= self.k

    def _get_score(self, model_output, _):
        return 1 - model_output[self.ground_truth_output]

    # Reimplemented to change the precision of the sum-to-one check.
    def _process_model_outputs(self, inputs, scores):
        """Processes and validates a list of model outputs.
        This is a task-dependent operation. For example, classification
        outputs need to have a softmax applied.
        """
        # Automatically cast a list or ndarray of predictions to a tensor.
        if isinstance(scores, list):
            scores = torch.tensor(scores)

        # Ensure the returned value is now a tensor.
        if not isinstance(scores, torch.Tensor):
            raise TypeError(
                "Must have list, np.ndarray, or torch.Tensor of "
                f"scores. Got type {type(scores)}"
            )

        # Validation check on model score dimensions
        if scores.ndim == 1:
            # Unsqueeze prediction, if it's been squeezed by the model.
            if len(inputs) == 1:
                scores = scores.unsqueeze(dim=0)
            else:
                raise ValueError(
                    f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
                )
        elif scores.ndim != 2:
            # If model somehow returns too many dimensions, throw an error.
            raise ValueError(
                f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
            )
        elif scores.shape[0] != len(inputs):
            # If model returns an incorrect number of scores, throw an error.
            raise ValueError(
                f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
            )
        elif not ((scores.sum(dim=1) - 1).abs() < 1e-4).all():
            # Values in each row should sum up to 1. The model should return a
            # set of numbers corresponding to probabilities, which should add
            # up to 1. Since they are `torch.float` values, allow a small
            # error in the summation.
            scores = torch.nn.functional.softmax(scores, dim=1)
            if not ((scores.sum(dim=1) - 1).abs() < 1e-4).all():
                raise ValueError("Model scores do not add up to 1.")
        return scores.cpu()
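The success criterion can be checked by hand on a toy distribution. The sketch below uses plain Python lists instead of tensors, and the helper name `is_below_top_k` and the probabilities are made up for illustration.

```python
from typing import Sequence

def is_below_top_k(probs: Sequence[float], true_class: int, k: int) -> bool:
    """True when at least `k` classes outscore the true class."""
    true_score = probs[true_class]
    num_better = sum(1 for p in probs if p > true_score)
    return num_better >= k

# Toy distribution over four candidate people; index 0 is the true person.
probs = [0.20, 0.50, 0.25, 0.05]
goal_k1 = is_below_top_k(probs, true_class=0, k=1)  # two classes beat 0.20
goal_k3 = is_below_top_k(probs, true_class=0, k=3)  # only two do, so not yet
search_score = 1 - probs[0]  # higher means the true person is less likely
print(goal_k1, goal_k3, search_score)
```

With `k=1` the goal is simply "the true person is no longer the top prediction"; larger `k` demands that the true person fall further down the ranking.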


### (c) Model wrapper that computes similarities of input documents with validation profiles

In [17]:
import transformers

class MyModelWrapper(textattack.models.wrappers.ModelWrapper):
    model: DocumentProfileMatchingTransformer
    tokenizer: transformers.PreTrainedTokenizer
    profile_embeddings: torch.Tensor
    max_seq_length: int
    
    def __init__(self, model: DocumentProfileMatchingTransformer, tokenizer: transformers.PreTrainedTokenizer, max_seq_length: int = 64):
        self.model = model
        self.tokenizer = tokenizer
        self.profile_embeddings = torch.tensor(model.val_embeddings)
        self.max_seq_length = max_seq_length
                 
    def to(self, device):
        self.model.to(device)
        self.profile_embeddings = self.profile_embeddings.to(device)  # Tensor.to() is not in-place
        return self # so semantics `model = MyModelWrapper().to('cuda')` works properly

    def __call__(self, text_input_list, batch_size=32):
        model_device = next(self.model.parameters()).device
        tokenized_ids = self.tokenizer.batch_encode_plus(
            text_input_list,
            max_length=self.max_seq_length,
            padding=True,
            truncation=True
        )
        tokenized_ids = {k: torch.tensor(v).to(model_device) for k,v in tokenized_ids.items()}
        
        # TODO: implement batch size if we start running out of memory here.
        with torch.no_grad():
            document_embeddings = self.model.document_model(**tokenized_ids)
            document_embeddings = document_embeddings['last_hidden_state'][:, 0, :] # (batch, document_emb_dim)
            document_embeddings = self.model.lower_dim_embed(document_embeddings) # (batch, emb_dim)

        document_to_profile_probs = torch.nn.functional.softmax(
            document_embeddings @ self.profile_embeddings.T.to(model_device), dim=-1)
        assert document_to_profile_probs.shape == (len(text_input_list), len(self.profile_embeddings))
        return document_to_profile_probs
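The last step of the wrapper is a dot product against every profile embedding followed by a softmax. A small pure-Python sketch of that step, with made-up 2-d embeddings and hypothetical helper names, assuming nothing about the real embedding dimensions:

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def document_to_profile_probs(doc: List[float], profiles: List[List[float]]) -> List[float]:
    """Dot the document embedding with every profile, then normalize."""
    logits = [sum(d * p for d, p in zip(doc, profile)) for profile in profiles]
    return softmax(logits)

# Toy 2-d embeddings, made up for illustration.
profiles = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
probs = document_to_profile_probs([2.0, 0.1], profiles)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted, [round(p, 3) for p in probs])
```

Each row of the wrapper's output is one such distribution, so the "class" the goal function sees is the index of a validation profile.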
            

### (d) Dataset that loads Wikipedia documents with names as labels

In [18]:
next(iter(dm.val_dataloader()))

{'document_input_ids': tensor([[  101,  4831,  2745,  ..., 10722, 26896,   102],
         [  101, 17504, 12022,  ...,     0,     0,     0],
         [  101,  7929,  2319,  ...,     0,     0,     0],
         ...,
         [  101,  9033,  2860,  ..., 12681,  5283,   102],
         [  101,  7332, 27319,  ...,   102,     0,     0],
         [  101,  3958, 11463,  ...,     0,     0,     0]]),
 'document_attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'document_redact_ner_input_ids': tensor([[  101,  4831,   103,  ...,  1010,  3140,   102],
         [  101, 17504, 12022,  ...,     0,     0,     0],
         [  101,   103,   103,  ...,     0,     0,     0],
         ...,
         [  101,  9033,  2860,  ..., 15810,  8840,   102],
         [  101,   103,   103,  ...,   102,     0,     0],
         [  1

In [19]:
from typing import Tuple

from collections import OrderedDict

import datasets

class WikiDataset(textattack.datasets.Dataset):
    dataset: datasets.Dataset
    
    def __init__(self, dm: WikipediaDataModule):
        self.shuffled = True
        self.dataset = dm.val_dataset
        self.label_names = list(dm.val_dataset['name'])
    
    def __len__(self) -> int:
        return len(self.dataset)
    
    def __getitem__(self, i: int) -> Tuple[OrderedDict, int]:
        input_dict = OrderedDict([
            ('document', self.dataset['document'][i])
        ])
        return input_dict, self.dataset['text_key_id'][i].item()
        

## 3. Run attack once

In [13]:
from textattack.loggers import CSVLogger
from textattack.shared import AttackedText

import pandas as pd
class CustomCSVLogger(CSVLogger):
    """Logs attack results to a CSV."""

    def log_attack_result(self, result: textattack.goal_function_results.ClassificationGoalFunctionResult):
        original_text, perturbed_text = result.diff_color(self.color_method)
        original_text = original_text.replace("\n", AttackedText.SPLIT_TOKEN)
        perturbed_text = perturbed_text.replace("\n", AttackedText.SPLIT_TOKEN)
        result_type = result.__class__.__name__.replace("AttackResult", "")
        row = {
            "original_person": result.original_result._processed_output[0],
            "original_text": original_text,
            "perturbed_person": result.perturbed_result._processed_output[0],
            "perturbed_text": perturbed_text,
            "original_score": result.original_result.score,
            "perturbed_score": result.perturbed_result.score,
            "original_output": result.original_result.output,
            "perturbed_output": result.perturbed_result.output,
            "ground_truth_output": result.original_result.ground_truth_output,
            "num_queries": result.num_queries,
            "result_type": result_type,
        }
        self.df = pd.concat([self.df, pd.DataFrame([row])], ignore_index=True)
        self._flushed = False

In [42]:
class MaxNumWordsModified(textattack.constraints.PreTransformationConstraint):
    def __init__(self, max_num_words: int):
        self.max_num_words = max_num_words

    def _get_modifiable_indices(self, current_text):
        """Returns the word indices in current_text which are able to be
        modified."""

        if len(current_text.attack_attrs["modified_indices"]) >= self.max_num_words:
            return set()
        else:
            return set(range(len(current_text.words)))

    def extra_repr_keys(self):
        return ["max_num_words"]
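The budget logic of this constraint can be sketched without TextAttack. The function below is a hypothetical mirror of `_get_modifiable_indices`, taking the set of already-modified indices as an explicit argument.

```python
from typing import Set

def modifiable_indices(num_words: int, modified: Set[int], max_num_words: int) -> Set[int]:
    """Mirror the constraint: no indices once the budget is spent."""
    if len(modified) >= max_num_words:
        return set()
    return set(range(num_words))

budget_left = modifiable_indices(5, modified={0}, max_num_words=2)
budget_spent = modifiable_indices(5, modified={0, 4}, max_num_words=2)
print(budget_left, budget_spent)
```

Note that while budget remains, every index is returned, including ones already modified; in the attack below it is `RepeatModification` that excludes those.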

In [49]:
from textattack import Attack
from textattack.constraints.pre_transformation import RepeatModification

model_wrapper = MyModelWrapper(model, dm.tokenizer)
model_wrapper.to('cuda')

goal_function = ChangeClassificationToBelowTopKClasses(model_wrapper, k=10)
constraints = [RepeatModification(), MaxNumWordsModified(max_num_words=10)]
transformation = WordSwapSingleWord(single_word='[MASK]')
search_method = textattack.search_methods.BeamSearch(beam_width=4)

attack = Attack(
    goal_function, constraints, transformation, search_method
)

textattack: No entry found for goal function <class '__main__.ChangeClassificationToBelowTopKClasses'>.
textattack: Unknown if model of class <class 'model.DocumentProfileMatchingTransformer'> compatible with goal function <class '__main__.ChangeClassificationToBelowTopKClasses'>.


In [50]:
# 
#  Initialize attack
# 
from tqdm import tqdm # tqdm provides us a nice progress bar.
from textattack.attack_results import SuccessfulAttackResult
from textattack import Attacker
from textattack import AttackArgs

attack_args = AttackArgs(num_examples=25, disable_stdout=True)
dataset = WikiDataset(dm)

attacker = Attacker(attack, dataset, attack_args)

results_iterable = attacker.attack_dataset()

logger = CustomCSVLogger(color_method='html')

# 
# Run attack
# 
for result in results_iterable:
    tqdm._instances.clear() # Doesn't fix the progress bar :-(
    logger.log_attack_result(result)

from IPython.display import display, HTML

display(HTML(logger.df.to_html(escape=False)))

Attack(
  (search_method): BeamSearch(
    (beam_width):  4
  )
  (goal_function):  ChangeClassificationToBelowTopKClasses
  (transformation):  WordSwapSingleWord
  (constraints): 
    (0): RepeatModification
    (1): MaxNumWordsModified(
        (max_num_words):  10
      )
  (is_black_box):  True
) 






+-------------------------------+---------+
| Attack Results                |         |
+-------------------------------+---------+
| Number of successful attacks: | 22      |
| Number of failed attacks:     | 3       |
| Number of skipped attacks:    | 0       |
| Original accuracy:            | 100.0%  |
| Accuracy under attack:        | 12.0%   |
| Attack success rate:          | 88.0%   |
| Average perturbed word %:     | 11.96%  |
| Average num. words per input: | 92.8    |
| Avg num queries:              | 2350.12 |
+-------------------------------+---------+
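The percentages in this summary follow from the three counts. The sketch below reproduces them, assuming (as I understand TextAttack's summary) that skipped examples are those the model already misclassified and that failed attacks leave the prediction correct.

```python
successes, failures, skipped = 22, 3, 0

attacked = successes + failures           # skipped examples are never attacked
total = attacked + skipped
original_accuracy = attacked / total      # correct before any perturbation
accuracy_under_attack = failures / total  # still correct after the attack
attack_success_rate = successes / attacked

print(f"{original_accuracy:.1%} {accuracy_under_attack:.1%} {attack_success_rate:.1%}")
```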


textattack: Logging to CSV at path results.csv
textattack: CSVLogger exiting without calling flush().





Unnamed: 0,original_person,original_text,perturbed_person,perturbed_text,original_score,perturbed_score,original_output,perturbed_output,ground_truth_output,num_queries,result_type
0,Michael iii of alexandria,"pope michael iii of alexandria -lrb- also known as khail iii -rrb- was the coptic pope of alexandria and patriarch of the see of st. mark -lrb- 880 -- 907 -rrb- .in 882 , the governor of egypt , ahmad ibn tulun , forced khail to pay heavy contributions , forcing him to sell a church and some attached properties to the local jewish community .this building was at one time believed to have later become the site of the cairo geniza .",Constantine diogenes,"pope [MASK] iii of [MASK] -lrb- also known as [MASK] iii -rrb- was the coptic pope of alexandria and patriarch of the see of st. [MASK] -lrb- 880 -- 907 -rrb- .in 882 , the governor of egypt , ahmad ibn tulun , [MASK] khail to pay heavy contributions , forcing him to sell a church and some attached properties [MASK] the local jewish community .this building was at one time believed to have later become the site of the cairo geniza .",0.4064472,0.985631,0,11450,0,1558,Successful
1,Hui jun,hui jun is a male former table tennis player from china .,Topu barman,[MASK] [MASK] is a male former [MASK] [MASK] player [MASK] [MASK] .,9.536743e-07,0.998122,1,10565,1,172,Successful
2,Kittisak jaihan,okan Öztürk -lrb- born 30 november 1977 -rrb- is a turkish professional footballer .he currently plays as a striker for yeni malatyaspor .,Kittisak jaihan,[MASK] [MASK] -lrb- born 30 november 1977 -rrb- is a turkish professional footballer .he currently plays as a striker for yeni malatyaspor .,0.9393426,0.999379,13715,13715,2,107,Successful
3,Marie stephan,"marie stephan , -lrb- born march 14 , 1996 -rrb- is a professional squash player who represents france .she reached a career-high world ranking of world no. 101 in july 2015 .",Dylan murray,"[MASK] [MASK] , -lrb- born march 14 , 1996 -rrb- is a professional squash player who represents [MASK] .[MASK] reached a career-high world ranking of world no. 101 in july 2015 .",0.2271534,0.987244,3,8083,3,354,Successful
4,Leonard l. martino,leonard l. martino is a former democratic member of the pennsylvania house of representatives .he was born in butler to michael and angela pitullio martino .,Neal huff,[MASK] [MASK]. [MASK] is a former democratic member of the pennsylvania house of representatives .he was born in butler to michael and angela pitullio [MASK] .,0.5373687,0.99389,4,10546,4,302,Successful
5,Salome jens,"salome jens -lrb- born may 8 , 1935 -rrb- is an american stage , film and television actress .she is perhaps best known for portraying the female changeling on '' '' .",Martin ferrero,"[MASK] [MASK] -lrb- born may 8 , 1935 -rrb- is an american stage , film and television actress .she is perhaps best known for portraying the [MASK] changeling on '' '' .",0.02987182,0.993593,5,13841,5,232,Successful
6,Carl crawford,"carl demonte crawford -lrb- born august 5 , 1981 -rrb- , nicknamed `` the perfect storm '' , is an american professional baseball left fielder with the los angeles dodgers of major league baseball -lrb- mlb -rrb- .he bats and throws left-handed .crawford was drafted by the tampa bay devil rays in the second round -lrb- 52nd overall -rrb- of the 1999 major league baseball draft .he made his major league debut in 2002 .crawford has more triples -lrb- 121 -rrb- than any other active baseball player .",Dennis springer,"carl [MASK] [MASK] -lrb- born august 5 , 1981 -rrb- , nicknamed `` the perfect storm '' , is an american professional baseball left fielder with the los angeles dodgers of major league baseball -lrb- mlb -rrb- .he bats and throws left-handed .[MASK] was drafted by the tampa bay devil rays in the second round -lrb- 52nd overall -rrb- of the 1999 major league baseball draft .he made his major league debut in 2002 .crawford has more triples -lrb- 121 -rrb- than any other active baseball player .",0.1424156,0.994957,6,8802,6,736,Successful
7,Jim bob,"jim bob -lrb- born james neil morrison on 22 november 1960 -rrb- is a british musician and author , best known as the singer of indie punk band carter usm .",Lee sang-don,"[MASK] [MASK] -lrb- born james neil [MASK] on 22 november 1960 -rrb- is a british musician and author , best known as the singer of indie punk band carter usm .",0.3278745,0.991559,7,7895,7,250,Successful
8,Riddick parker,"riddick parker -lrb- born november 20 , 1972 in emporia , virginia -rrb- is a former professional american football defensive lineman for the seattle seahawks , san diego chargers , new england patriots , baltimore ravens , and san francisco 49ers of the national football league .",Leonard stephens,"[MASK] parker -lrb- born november 20 , 1972 in [MASK] , virginia -rrb- is a former professional american football defensive lineman for the seattle seahawks , san diego chargers , new england patriots , baltimore ravens , and san francisco 49ers of the national football league .",0.8275697,0.995439,8,11057,8,197,Successful
9,Blessed osanna of cattaro -lrb- ozana kotorska -rrb-,blessed osanna of cattaro t.o.s.d. -lrb- -rrb- was a catholic visionary and anchoress from cattaro -lrb- kotor -rrb- .she was a teenage convert from orthodoxy of serbian descent from montenegro -lrb- zeta -rrb- .she became a dominican tertiary and was posthumously venerated as a saint in kotor .she was later beatified in 1934 .,Típica 73,[MASK] [MASK] [MASK] cattaro [MASK].o.s.d. -lrb- -rrb- was a [MASK] [MASK] and anchoress from cattaro -lrb- kotor -rrb- .[MASK] was a [MASK] convert from [MASK] of serbian descent from montenegro -lrb- zeta -rrb- .she became a dominican tertiary and was posthumously venerated as a saint in kotor .she was later beatified in 1934 .,0.04990202,0.998397,9,9917,9,1705,Successful


In [None]:
dm.val_dataset[7243]

## 4. Run attack in loop and make plot for multiple values of $k$

In [None]:
# 
#  Initialize attack
# 
from tqdm import tqdm # tqdm provides us a nice progress bar.
from textattack.attack_results import SuccessfulAttackResult
from textattack import Attacker
from textattack import AttackArgs

dataset = WikiDataset(dm)

meta_results = []
for k in range(1, 50):
    print(f'***Attacking with k={k}***')
    dataset = WikiDataset(dm)
    goal_function = ChangeClassificationToBelowTopKClasses(model_wrapper, k=k)
    attack = Attack(
        goal_function, constraints, transformation, search_method
    )
    attack_args = AttackArgs(num_examples=1000, disable_stdout=True)
    attacker = Attacker(attack, dataset, attack_args)

    results_iterable = attacker.attack_dataset()

    logger = CustomCSVLogger(color_method='html')

    for result in results_iterable:
        logger.log_attack_result(result)
    
    # Count each result type once per run.
    counts = logger.df['result_type'].value_counts()
    meta_results.append((k, counts.get('Successful', 0), counts.get('Failed', 0)))

import pandas as pd
# Keep every row so the plot covers all values of k; show only a preview here.
meta_df = pd.DataFrame(meta_results, columns=['k', 'Successes', 'Failures'])
meta_df.head()

In [None]:
meta_df.plot(x='k', y='Successes')
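Raw success counts depend on how many examples were attacked, so a rate may be easier to compare across runs. A minimal sketch of that conversion, using a hypothetical helper `success_rates` and only the single k=10 data point reported earlier in this notebook (22 successes, 3 failures):

```python
from typing import List, Tuple

def success_rates(meta_results: List[Tuple[int, int, int]]) -> List[Tuple[int, float]]:
    """Convert (k, successes, failures) rows into (k, success rate) pairs."""
    rates = []
    for k, successes, failures in meta_results:
        attacked = successes + failures
        rate = successes / attacked if attacked else float("nan")
        rates.append((k, rate))
    return rates

# Only one row here: the k=10 run on 25 examples reported earlier.
rates = success_rates([(10, 22, 3)])
print(rates)
```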