In [12]:
!pip install textattack pyarrow==1.0 sentence-transformers > /dev/null

In [13]:
from nltk.corpus import wordnet

import textattack
from textattack.transformations.word_swap import WordSwap


class WordSwapWordNetAntonym(WordSwap):
    """Transforms an input by replacing its words with antonyms provided by
    WordNet."""

    def _get_replacement_words(self, word, random=False):
        """Returns a sorted list of single-word WordNet antonyms of ``word``.

        Every lemma of every synset that WordNet returns for ``word`` is
        inspected and the names of its antonym lemmas are collected.
        Candidates are dropped when they are identical to ``word``, contain
        an underscore (WordNet joins multi-word phrases with ``_``), or are
        rejected by ``textattack.shared.utils.is_one_word``.

        Args:
            word: The word to look up antonyms for.
            random: Not used by this implementation.

        Returns:
            A list of unique antonym strings in sorted order; empty when no
            acceptable antonym is found.
        """
        antonyms = set()
        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                # Iterating directly is enough: a lemma without antonyms
                # yields an empty list, so no separate emptiness check (and
                # no second WordNet lookup) is needed.
                for ant_lemma in lemma.antonyms():
                    ant_word = ant_lemma.name()
                    if (
                        (ant_word != word)
                        and ("_" not in ant_word)
                        and (textattack.shared.utils.is_one_word(ant_word))
                    ):
                        # WordNet can suggest phrases that are joined by '_' but we ignore phrases.
                        antonyms.add(ant_word)
        # Sort so the candidate order is reproducible: the iteration order of
        # a set of strings can change between interpreter runs because of
        # string hash randomisation, which seeding the RNGs does not control.
        return sorted(antonyms)

In [14]:
def set_seed(random_seed):
  """Seeds the random number generators used in this notebook.

  The same seed is passed, in order, to Python's ``random`` module, NumPy's
  global generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed``.

  Args:
    random_seed: Seed value handed unchanged to every generator.
  """
  import random as py_random
  import numpy
  import torch as th

  seeders = (
    py_random.seed,
    numpy.random.seed,
    th.manual_seed,
    th.cuda.manual_seed,
  )
  for seed_generator in seeders:
    seed_generator(random_seed)

set_seed(42)

In [15]:
import torch
import textattack

class FoolConstraintGoalFunction(textattack.goal_functions.classification.UntargetedClassification):
  """Goal function that attacks a semantic-similarity constraint instead of a classifier.

  The "model output" for a candidate text is the similarity score that
  ``constraint`` assigns between the original text and the candidate. The goal
  is complete once at least ``num_words_to_swap`` words have been modified
  while that score is still at or above ``min_acceptable_score``.
  """

  def __init__(self, constraint, min_acceptable_score=0.8,
                num_words_to_swap=2):
    """Stores the constraint under attack and the success criteria.

    Args:
      constraint: Scoring object used by ``_call_model``. A ``BERTScore``
        instance is scored through its ``_bert_scorer``; any other object is
        scored through its ``_score_list`` method.
      min_acceptable_score: Minimum similarity a candidate must keep.
      num_words_to_swap: Minimum number of modified words required.
    """
    # The parent __init__ is not called; the attributes are set directly here.
    self.constraint = constraint
    self.query_budget = float("inf")
    self.min_acceptable_score = min_acceptable_score
    self.use_cache = False
    self.num_words_to_swap = num_words_to_swap
    self.maximizable = False
    self.model_cache_size = 0
  
  def _should_skip(self, *_):
    """Never skips an example, whatever arguments are passed."""
    return False
  
  def _is_goal_complete(self, model_output, attacked_text):
    """Returns True when enough words are swapped and similarity is high enough.

    Args:
      model_output: Single-element tensor holding the similarity score.
      attacked_text: Candidate text; its ``modified_indices`` attack
        attribute (empty if missing) gives the number of swapped words.
    """
    num_words_swapped = len(attacked_text.attack_attrs.get('modified_indices', []))
    model_score = model_output.item()

    return (num_words_swapped >= self.num_words_to_swap) and (model_score >= self.min_acceptable_score)

  def _call_model(self, attacked_text_list):
    """Scores each `AttackedText` by its similarity to the original text.

    Nothing is cached and no classifier is queried. The original text is
    recovered by following the ``previous_attacked_text`` attack attribute of
    the first list element until it is absent. A text that has no
    ``newly_modified_indices`` attribute is treated as the original text and
    receives a score of 1.0; every other text is scored by ``self.constraint``.

    Note: this method reads the notebook-level global ``transformation``,
    which must be defined before an attack is run.

    Args:
      attacked_text_list: List of `AttackedText` objects to score.

    Returns:
      A tensor built from one single-element row per input text; an empty
      tensor when ``attacked_text_list`` is empty.
    """
    # Imported locally so this class does not rely on a later notebook cell
    # having already imported BERTScore into the global namespace.
    from textattack.constraints.semantics import BERTScore

    if not attacked_text_list:
      # Nothing to score; avoid indexing into an empty list below.
      return torch.tensor([])

    original_text = attacked_text_list[0]
    while "previous_attacked_text" in original_text.attack_attrs:
      original_text = original_text.attack_attrs["previous_attacked_text"]

    # The constraint type cannot change during the loop, so check it once.
    use_bert_scorer = isinstance(self.constraint, BERTScore)

    scores = []
    for at in attacked_text_list:
      if "newly_modified_indices" not in at.attack_attrs:
        # Original text
        scores.append([1.0])
      else:
        # NOTE(review): presumably the constraint's scoring code expects this
        # attribute to be present - confirm against the textattack version used.
        at.attack_attrs["last_transformation"] = transformation
        if use_bert_scorer:
          # call bert scorer specially
          model_scores = self.constraint._bert_scorer.score([original_text.text], [at.text])
        else:
          # otherwise, it's a sentence encoder
          model_scores = self.constraint._score_list(original_text, [at])
        scores.append([model_scores[0]])
    return torch.tensor(scores)
      
  def _get_score(self, model_output, attacked_text):
    """Returns the search score for a candidate.

    The score is 0.0 when the similarity is below ``min_acceptable_score``;
    otherwise it is the number of modified words plus the similarity.

    Args:
      model_output: Single-element tensor holding the similarity score.
      attacked_text: Candidate text; its ``modified_indices`` attack
        attribute (empty if missing) gives the number of swapped words.
    """
    model_score = model_output.item()
    if model_score < self.min_acceptable_score:
      return 0.0

    # The former swapped/total word ratio was never used in the result and
    # divided by the word count, which raises ZeroDivisionError for a text
    # without words; it has been dropped.
    num_words_swapped = len(attacked_text.attack_attrs.get('modified_indices', []))
    return num_words_swapped + model_score

In [30]:
from textattack.constraints.grammaticality.language_models import GPT2
from textattack.constraints.semantics import BERTScore
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.constraints.pre_transformation import InputColumnModification, RepeatModification, StopwordModification
from textattack.datasets import HuggingFaceNlpDataset
from textattack.search_methods import BeamSearch
from textattack.shared import Attack
from textattack.transformations import WordSwapEmbedding

import numpy as np
import pickle
import random
import tqdm

# Experiment 1: for each similarity constraint and each threshold, count how
# many examples can have `num_words_to_swap` words replaced by antonyms while
# the constraint still scores the result at or above the threshold.
transformation = WordSwapWordNetAntonym()

# We'll constrain modification of already modified indices and stopwords
constraints = [RepeatModification(), StopwordModification()]
# don't attack premise in entailment
constraints.append(InputColumnModification(["premise", "hypothesis"], {"premise"}))
# use GPT2 to try and make sentences somewhat plausible
constraints.append(GPT2(max_log_prob_diff=2.0))

# Load the (text, label) pairs pickled in mr.pkl.
# SECURITY: pickle.load can execute arbitrary code; only use a trusted file.
# The context manager closes the file handle once loading is done.
with open('mr.pkl', 'rb') as dataset_file:
  dataset = [(x['text'], x['label']) for x in pickle.load(dataset_file)]
random.shuffle(dataset)

# One summary record per (constraint, threshold) pair.
data = []

num_samples = 100
num_words_to_swap = 3

all_constraints = ('bertscore', 'use')
threshold_vals = np.arange(.50, 1.0, .01)
for constraint_idx, constraint_name in enumerate(all_constraints):
  tqdm.tqdm.write(f'----> constraint {constraint_name}')

  # We know this second-order attack fails most of the time, and fails more as the 
  # threshold increases. Any example that fails will continue to fail. We take 
  # advantage of this fact through caching.
  known_failure_idxs = set()

  if constraint_name == 'bertscore':
    constraint = BERTScore(
      min_bert_score=0.0, # don't need this
      model="bert-base-uncased",
      score_type="f1",
      compare_against_original=True,
    )
  else:
    constraint = UniversalSentenceEncoder(
      compare_against_original=True,
      skip_text_shorter_than_window=False,
    )
  for threshold_idx, threshold in enumerate(threshold_vals):
    # goal function is to fool a single constraint
    tqdm.tqdm.write(f'--> Threshold {threshold} / Num words to swap {num_words_to_swap}')
    goal_function = FoolConstraintGoalFunction(
      constraint,
      min_acceptable_score=threshold,
      num_words_to_swap=num_words_to_swap,
    )

    # search method
    search_method = BeamSearch(beam_width=2)
    # Now, let's make the attack from the 4 components:
    attack = Attack(goal_function, constraints, transformation, search_method)
    # Skip examples that already failed at a lower threshold for this constraint.
    idxs_to_attack = set(range(num_samples)) - known_failure_idxs
    idxs_to_attack = list(sorted(idxs_to_attack))

    if len(idxs_to_attack):
      # 1-based position of this (constraint, threshold) run, for the progress bar label.
      this_sample_idx = (constraint_idx * len(threshold_vals)) + threshold_idx + 1
      total_num_samples = len(all_constraints) * len(threshold_vals)
      results_iterable = list(tqdm.tqdm(attack.attack_dataset(dataset, 
                                      indices=idxs_to_attack), 
                                      total=len(idxs_to_attack), 
                                      position=0, 
                                      leave=True, desc=f'Sample {this_sample_idx}/{total_num_samples}'))
    else:
      results_iterable = []


    num_successes = 0
    # Results are paired with the sorted indices they were requested for.
    for result_idx, result in zip(idxs_to_attack, results_iterable):
      if isinstance(result, textattack.attack_results.FailedAttackResult):
        known_failure_idxs.add(result_idx)
      elif isinstance(result, textattack.attack_results.SuccessfulAttackResult):
        num_successes += 1
      
    data.append({ 
        'constraint': type(constraint).__name__, 
        'threshold': threshold, 
        'num_successes': num_successes, 
        'num_words_to_swap': num_words_to_swap 
      })

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Sample 8/100:  35%|███▌      | 18/51 [03:43<00:17,  1.89it/s]

----> constraint bertscore


Sample 1/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.5 / Num words to swap 3


Sample 1/100: 100%|██████████| 100/100 [00:39<00:00,  2.52it/s]
Sample 2/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.51 / Num words to swap 3


Sample 2/100: 100%|██████████| 48/48 [00:34<00:00,  1.41it/s]
Sample 3/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.52 / Num words to swap 3


Sample 3/100: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s]
Sample 4/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.53 / Num words to swap 3


Sample 4/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 5/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.54 / Num words to swap 3


Sample 5/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 6/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.55 / Num words to swap 3


Sample 6/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 7/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.56 / Num words to swap 3


Sample 7/100: 100%|██████████| 48/48 [00:34<00:00,  1.41it/s]
Sample 8/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5700000000000001 / Num words to swap 3


Sample 8/100: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s]
Sample 9/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5800000000000001 / Num words to swap 3


Sample 9/100: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s]
Sample 10/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5900000000000001 / Num words to swap 3


Sample 10/100: 100%|██████████| 48/48 [00:34<00:00,  1.39it/s]
Sample 11/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6000000000000001 / Num words to swap 3


Sample 11/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 12/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6100000000000001 / Num words to swap 3


Sample 12/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 13/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6200000000000001 / Num words to swap 3


Sample 13/100: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s]
Sample 14/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6300000000000001 / Num words to swap 3


Sample 14/100: 100%|██████████| 48/48 [00:34<00:00,  1.41it/s]
Sample 15/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6400000000000001 / Num words to swap 3


Sample 15/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 16/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6500000000000001 / Num words to swap 3


Sample 16/100: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s]
Sample 17/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6600000000000001 / Num words to swap 3


Sample 17/100: 100%|██████████| 48/48 [00:35<00:00,  1.37it/s]
Sample 18/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6700000000000002 / Num words to swap 3


Sample 18/100: 100%|██████████| 48/48 [00:35<00:00,  1.36it/s]
Sample 19/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6800000000000002 / Num words to swap 3


Sample 19/100: 100%|██████████| 48/48 [00:34<00:00,  1.38it/s]
Sample 20/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6900000000000002 / Num words to swap 3


Sample 20/100: 100%|██████████| 48/48 [00:34<00:00,  1.40it/s]
Sample 21/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7000000000000002 / Num words to swap 3


Sample 21/100: 100%|██████████| 48/48 [00:34<00:00,  1.41it/s]
Sample 22/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7100000000000002 / Num words to swap 3


Sample 22/100: 100%|██████████| 48/48 [00:32<00:00,  1.46it/s]
Sample 23/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7200000000000002 / Num words to swap 3


Sample 23/100: 100%|██████████| 48/48 [00:32<00:00,  1.47it/s]
Sample 24/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7300000000000002 / Num words to swap 3


Sample 24/100: 100%|██████████| 48/48 [00:32<00:00,  1.48it/s]
Sample 25/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7400000000000002 / Num words to swap 3


Sample 25/100: 100%|██████████| 48/48 [00:32<00:00,  1.46it/s]
Sample 26/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7500000000000002 / Num words to swap 3


Sample 26/100: 100%|██████████| 48/48 [00:32<00:00,  1.47it/s]
Sample 27/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7600000000000002 / Num words to swap 3


Sample 27/100: 100%|██████████| 48/48 [00:32<00:00,  1.47it/s]
Sample 28/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.7700000000000002 / Num words to swap 3


Sample 28/100: 100%|██████████| 48/48 [00:33<00:00,  1.45it/s]
Sample 29/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7800000000000002 / Num words to swap 3


Sample 29/100: 100%|██████████| 47/47 [00:32<00:00,  1.45it/s]
Sample 30/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7900000000000003 / Num words to swap 3


Sample 30/100: 100%|██████████| 47/47 [00:32<00:00,  1.45it/s]
Sample 31/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8000000000000003 / Num words to swap 3


Sample 31/100: 100%|██████████| 46/46 [00:31<00:00,  1.44it/s]
Sample 32/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8100000000000003 / Num words to swap 3


Sample 32/100: 100%|██████████| 46/46 [00:32<00:00,  1.43it/s]
Sample 33/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8200000000000003 / Num words to swap 3


Sample 33/100: 100%|██████████| 46/46 [00:32<00:00,  1.43it/s]
Sample 34/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8300000000000003 / Num words to swap 3


Sample 34/100: 100%|██████████| 46/46 [00:31<00:00,  1.44it/s]
Sample 35/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8400000000000003 / Num words to swap 3


Sample 35/100: 100%|██████████| 46/46 [00:32<00:00,  1.43it/s]
Sample 36/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.8500000000000003 / Num words to swap 3


Sample 36/100: 100%|██████████| 46/46 [00:32<00:00,  1.44it/s]
Sample 37/100:   0%|          | 0/45 [00:00<?, ?it/s]

--> Threshold 0.8600000000000003 / Num words to swap 3


Sample 37/100: 100%|██████████| 45/45 [00:31<00:00,  1.43it/s]
Sample 38/100:   0%|          | 0/45 [00:00<?, ?it/s]

--> Threshold 0.8700000000000003 / Num words to swap 3


Sample 38/100: 100%|██████████| 45/45 [00:33<00:00,  1.35it/s]
Sample 39/100:   0%|          | 0/43 [00:00<?, ?it/s]

--> Threshold 0.8800000000000003 / Num words to swap 3


Sample 39/100: 100%|██████████| 43/43 [00:31<00:00,  1.37it/s]
Sample 40/100:   0%|          | 0/42 [00:00<?, ?it/s]

--> Threshold 0.8900000000000003 / Num words to swap 3


Sample 40/100: 100%|██████████| 42/42 [00:31<00:00,  1.35it/s]
Sample 41/100:   0%|          | 0/42 [00:00<?, ?it/s]

--> Threshold 0.9000000000000004 / Num words to swap 3


Sample 41/100: 100%|██████████| 42/42 [00:31<00:00,  1.35it/s]
Sample 42/100:   0%|          | 0/40 [00:00<?, ?it/s]

--> Threshold 0.9100000000000004 / Num words to swap 3


Sample 42/100: 100%|██████████| 40/40 [00:29<00:00,  1.34it/s]
Sample 43/100:   0%|          | 0/39 [00:00<?, ?it/s]

--> Threshold 0.9200000000000004 / Num words to swap 3


Sample 43/100: 100%|██████████| 39/39 [00:28<00:00,  1.35it/s]
Sample 44/100:   0%|          | 0/38 [00:00<?, ?it/s]

--> Threshold 0.9300000000000004 / Num words to swap 3


Sample 44/100: 100%|██████████| 38/38 [00:29<00:00,  1.30it/s]
Sample 45/100:   0%|          | 0/31 [00:00<?, ?it/s]

--> Threshold 0.9400000000000004 / Num words to swap 3


Sample 45/100: 100%|██████████| 31/31 [00:26<00:00,  1.17it/s]
Sample 46/100:   0%|          | 0/25 [00:00<?, ?it/s]

--> Threshold 0.9500000000000004 / Num words to swap 3


Sample 46/100: 100%|██████████| 25/25 [00:23<00:00,  1.07it/s]
Sample 47/100:   0%|          | 0/20 [00:00<?, ?it/s]

--> Threshold 0.9600000000000004 / Num words to swap 3


Sample 47/100: 100%|██████████| 20/20 [00:19<00:00,  1.04it/s]
Sample 48/100:   0%|          | 0/15 [00:00<?, ?it/s]

--> Threshold 0.9700000000000004 / Num words to swap 3


Sample 48/100: 100%|██████████| 15/15 [00:21<00:00,  1.46s/it]
Sample 49/100:   0%|          | 0/3 [00:00<?, ?it/s]

--> Threshold 0.9800000000000004 / Num words to swap 3


Sample 49/100: 100%|██████████| 3/3 [00:06<00:00,  2.16s/it]
Sample 20/100:  82%|████████▏ | 64/78 [44:59<00:29,  2.09s/it]

--> Threshold 0.9900000000000004 / Num words to swap 3
----> constraint use


Sample 51/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.5 / Num words to swap 3


Sample 51/100: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s]
Sample 52/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.51 / Num words to swap 3


Sample 52/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 53/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.52 / Num words to swap 3


Sample 53/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 54/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.53 / Num words to swap 3


Sample 54/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 55/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.54 / Num words to swap 3


Sample 55/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 56/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.55 / Num words to swap 3


Sample 56/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 57/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.56 / Num words to swap 3


Sample 57/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 58/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5700000000000001 / Num words to swap 3


Sample 58/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 59/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5800000000000001 / Num words to swap 3


Sample 59/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 60/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.5900000000000001 / Num words to swap 3


Sample 60/100: 100%|██████████| 48/48 [00:26<00:00,  1.80it/s]
Sample 61/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6000000000000001 / Num words to swap 3


Sample 61/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 62/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6100000000000001 / Num words to swap 3


Sample 62/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 63/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6200000000000001 / Num words to swap 3


Sample 63/100: 100%|██████████| 48/48 [00:26<00:00,  1.78it/s]
Sample 64/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6300000000000001 / Num words to swap 3


Sample 64/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 65/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6400000000000001 / Num words to swap 3


Sample 65/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 66/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6500000000000001 / Num words to swap 3


Sample 66/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 67/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6600000000000001 / Num words to swap 3


Sample 67/100: 100%|██████████| 48/48 [00:26<00:00,  1.82it/s]
Sample 68/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6700000000000002 / Num words to swap 3


Sample 68/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 69/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6800000000000002 / Num words to swap 3


Sample 69/100: 100%|██████████| 48/48 [00:26<00:00,  1.80it/s]
Sample 70/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.6900000000000002 / Num words to swap 3


Sample 70/100: 100%|██████████| 48/48 [00:26<00:00,  1.81it/s]
Sample 71/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7000000000000002 / Num words to swap 3


Sample 71/100: 100%|██████████| 47/47 [00:26<00:00,  1.80it/s]
Sample 72/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7100000000000002 / Num words to swap 3


Sample 72/100: 100%|██████████| 47/47 [00:26<00:00,  1.79it/s]
Sample 73/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7200000000000002 / Num words to swap 3


Sample 73/100: 100%|██████████| 47/47 [00:26<00:00,  1.79it/s]
Sample 74/100:   0%|          | 0/47 [00:00<?, ?it/s]

--> Threshold 0.7300000000000002 / Num words to swap 3


Sample 74/100: 100%|██████████| 47/47 [00:26<00:00,  1.77it/s]
Sample 75/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.7400000000000002 / Num words to swap 3


Sample 75/100: 100%|██████████| 46/46 [00:26<00:00,  1.77it/s]
Sample 76/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.7500000000000002 / Num words to swap 3


Sample 76/100: 100%|██████████| 46/46 [00:25<00:00,  1.79it/s]
Sample 77/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.7600000000000002 / Num words to swap 3


Sample 77/100: 100%|██████████| 46/46 [00:25<00:00,  1.78it/s]
Sample 78/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.7700000000000002 / Num words to swap 3


Sample 78/100: 100%|██████████| 46/46 [00:25<00:00,  1.78it/s]
Sample 79/100:   0%|          | 0/46 [00:00<?, ?it/s]

--> Threshold 0.7800000000000002 / Num words to swap 3


Sample 79/100: 100%|██████████| 46/46 [00:26<00:00,  1.76it/s]
Sample 80/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.7900000000000003 / Num words to swap 3


Sample 80/100: 100%|██████████| 44/44 [00:25<00:00,  1.76it/s]
Sample 81/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.8000000000000003 / Num words to swap 3


Sample 81/100: 100%|██████████| 44/44 [00:25<00:00,  1.76it/s]
Sample 82/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.8100000000000003 / Num words to swap 3


Sample 82/100: 100%|██████████| 44/44 [00:24<00:00,  1.77it/s]
Sample 83/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.8200000000000003 / Num words to swap 3


Sample 83/100: 100%|██████████| 44/44 [00:24<00:00,  1.78it/s]
Sample 84/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.8300000000000003 / Num words to swap 3


Sample 84/100: 100%|██████████| 44/44 [00:24<00:00,  1.77it/s]
Sample 85/100:   0%|          | 0/43 [00:00<?, ?it/s]

--> Threshold 0.8400000000000003 / Num words to swap 3


Sample 85/100: 100%|██████████| 43/43 [00:24<00:00,  1.73it/s]
Sample 86/100:   0%|          | 0/40 [00:00<?, ?it/s]

--> Threshold 0.8500000000000003 / Num words to swap 3


Sample 86/100: 100%|██████████| 40/40 [00:24<00:00,  1.65it/s]
Sample 87/100:   0%|          | 0/39 [00:00<?, ?it/s]

--> Threshold 0.8600000000000003 / Num words to swap 3


Sample 87/100: 100%|██████████| 39/39 [00:24<00:00,  1.59it/s]
Sample 88/100:   0%|          | 0/33 [00:00<?, ?it/s]

--> Threshold 0.8700000000000003 / Num words to swap 3


Sample 88/100: 100%|██████████| 33/33 [00:21<00:00,  1.57it/s]
Sample 89/100:   0%|          | 0/27 [00:00<?, ?it/s]

--> Threshold 0.8800000000000003 / Num words to swap 3


Sample 89/100: 100%|██████████| 27/27 [00:18<00:00,  1.46it/s]
Sample 90/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.8900000000000003 / Num words to swap 3


Sample 90/100: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]
Sample 91/100:   0%|          | 0/15 [00:00<?, ?it/s]

--> Threshold 0.9000000000000004 / Num words to swap 3


Sample 91/100: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s]
Sample 92/100:   0%|          | 0/10 [00:00<?, ?it/s]

--> Threshold 0.9100000000000004 / Num words to swap 3


Sample 92/100: 100%|██████████| 10/10 [00:12<00:00,  1.20s/it]
Sample 93/100:   0%|          | 0/4 [00:00<?, ?it/s]

--> Threshold 0.9200000000000004 / Num words to swap 3


Sample 93/100: 100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
Sample 94/100:   0%|          | 0/3 [00:00<?, ?it/s]

--> Threshold 0.9300000000000004 / Num words to swap 3


Sample 94/100: 100%|██████████| 3/3 [00:04<00:00,  1.50s/it]
Sample 95/100:   0%|          | 0/2 [00:00<?, ?it/s]

--> Threshold 0.9400000000000004 / Num words to swap 3


Sample 95/100: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it]
Sample 96/100:   0%|          | 0/1 [00:00<?, ?it/s]

--> Threshold 0.9500000000000004 / Num words to swap 3


Sample 96/100: 100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
Sample 20/100:  82%|████████▏ | 64/78 [1:02:48<00:29,  2.09s/it]

--> Threshold 0.9600000000000004 / Num words to swap 3
--> Threshold 0.9700000000000004 / Num words to swap 3
--> Threshold 0.9800000000000004 / Num words to swap 3
--> Threshold 0.9900000000000004 / Num words to swap 3


#### Experiment 2

In [None]:
# Experiment 2 setup: imports, the shared transformation / constraints, the
# dataset, and the threshold grid used by the first-order attack sweep.
import pickle
import random

import numpy as np
import torch
import tqdm
import transformers

import textattack
from textattack.constraints.grammaticality.language_models import GPT2
from textattack.constraints.semantics import BERTScore
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.constraints.pre_transformation import InputColumnModification, RepeatModification, StopwordModification
from textattack.datasets import HuggingFaceNlpDataset
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared import Attack
from textattack.transformations import WordSwapEmbedding

# Word-level transformation shared by every attack built in this experiment.
transformation = WordSwapEmbedding()

# We'll constrain modification of already modified indices and stopwords
constraints = [RepeatModification(), StopwordModification()]
# don't attack premise in entailment
# NOTE(review): the data loaded below only has 'text'/'label' fields, so this
# premise/hypothesis constraint is presumably a no-op here -- confirm.
constraints.append(InputColumnModification(["premise", "hypothesis"], {"premise"}))
# use GPT2 to try and make sentences somewhat plausible
constraints.append(GPT2(max_log_prob_diff=2.0))

# Load the dataset as (text, label) pairs from the local pickle `mr.pkl`.
# (An earlier comment said "RTE", but the file read here is `mr.pkl`.)
# SECURITY: unpickling executes arbitrary code embedded in the file -- only
# load `mr.pkl` if it comes from a trusted source.
with open('mr.pkl', 'rb') as dataset_file:
  dataset = [(x['text'], x['label']) for x in pickle.load(dataset_file)]
# Shuffles in place using the global `random` state, so the resulting order
# depends on how many random draws happened since the last seeding.
random.shuffle(dataset)

# Accumulates one summary dict per (model, constraint, threshold) combination.
data2 = []

all_constraints = ('bertscore', 'use')
# 0.99 down to 0.50 in steps of 0.01. np.arange accumulates float error, so
# values look like 0.9900000000000004; they are deliberately left unrounded.
# NOTE(review): the later pd.merge presumably joins on `threshold`, so any
# other experiment must generate its thresholds the same way -- confirm.
threshold_vals = np.arange(.50, 1.0, .01)[::-1] # start with highest constraint level!

# Experiment 2 sweep: for each model and each similarity constraint, run the
# attack at every threshold (strictest first) and record the cumulative number
# of successfully attacked samples in `data2`.
#
# Relies on names bound earlier in this cell (`transformation`, `constraints`,
# `dataset`, `data2`, `all_constraints`, `threshold_vals`) and on `num_samples`,
# which is not defined in this cell and must already exist in the kernel.
all_models = (
    "textattack/bert-base-uncased-rotten-tomatoes",
    "textattack/albert-base-v2-rotten-tomatoes",
    "textattack/distilbert-base-uncased-rotten-tomatoes",
)
for model_idx, model_path in enumerate(all_models):
  print('Model -->', model_path)
  # Build the victim model (tokenizer + classifier in a TextAttack wrapper)
  # and the goal function that every attack on this model shares.
  tokenizer = textattack.models.tokenizers.AutoTokenizer(model_path)
  model = textattack.models.wrappers.HuggingFaceModelWrapper(
      transformers.AutoModelForSequenceClassification.from_pretrained(model_path),
      tokenizer,
      batch_size=16,
  )
  goal_function = UntargetedClassification(model)

  for constraint_idx, constraint_name in enumerate(all_constraints):
    print(f'----> constraint {constraint_name}')
    # The first-order attack fails most of the time at strict thresholds and
    # succeeds more often as the threshold decreases. The sweep assumes that a
    # sample attacked successfully at one threshold stays successful at every
    # looser one, so such samples are remembered here and never re-attacked.
    known_success_idxs = set()

    # Instantiate the constraint once; only its threshold attribute changes
    # between iterations, so the 0.0 passed here is just a placeholder.
    if constraint_name == 'use':
      constraint = UniversalSentenceEncoder(
          threshold=0.0,
          compare_against_original=True,
          skip_text_shorter_than_window=False,
      )
      threshold_attr = 'threshold'
    else:
      constraint = BERTScore(
          min_bert_score=0.0,
          model="bert-base-uncased",
          score_type="f1",
          compare_against_original=True,
      )
      threshold_attr = 'min_bert_score'

    for threshold_idx, threshold in enumerate(threshold_vals):
      setattr(constraint, threshold_attr, threshold)
      print(f'--> Threshold {threshold}')

      # Assemble the attack from its 4 components; a fresh search method and
      # Attack object are created for every threshold.
      these_constraints = [*constraints, constraint]
      search_method = GreedyWordSwapWIR()
      attack = Attack(goal_function, these_constraints, transformation, search_method)

      # Only attack the samples that are not already known to succeed.
      idxs_to_attack = sorted(set(range(num_samples)) - known_success_idxs)

      results_iterable = []
      if idxs_to_attack:
        # Position of this (constraint, threshold) step within the sweep for
        # the current model; shown in the progress-bar description.
        this_sample_idx = (constraint_idx * len(threshold_vals)) + threshold_idx + 1
        total_num_samples = len(all_constraints) * len(threshold_vals)
        progress_bar = tqdm.tqdm(
            attack.attack_dataset(dataset, indices=idxs_to_attack),
            total=len(idxs_to_attack),
            position=0,
            leave=True,
            desc=f'Sample {this_sample_idx}/{total_num_samples}',
        )
        results_iterable = list(progress_bar)

      # Results are paired with the dataset indices in the order they were
      # requested; remember every newly broken sample.
      for result_idx, result in zip(idxs_to_attack, results_iterable):
        if isinstance(result, textattack.attack_results.SuccessfulAttackResult):
          known_success_idxs.add(result_idx)
      # Cumulative count: successes at this threshold plus all stricter ones.
      num_successes = len(known_success_idxs)

      # Only the summary is stored; the raw attack results are discarded.
      data2.append({
          'constraint': type(constraint).__name__,
          'threshold': threshold,
          'model': model_path,
          'num_successes': num_successes,
      })

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model --> textattack/bert-base-uncased-rotten-tomatoes


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=487.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=48.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=437985387.0, style=ProgressStyle(descri…




[34;1mtextattack[0m: Unknown if model of class <class 'textattack.models.wrappers.huggingface_model_wrapper.HuggingFaceModelWrapper'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


----> constraint bertscore


Sample 1/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9900000000000004


Sample 1/100: 100%|██████████| 100/100 [04:01<00:00,  2.42s/it]
Sample 2/100:   0%|          | 0/98 [00:00<?, ?it/s]

--> Threshold 0.9800000000000004


Sample 2/100: 100%|██████████| 98/98 [03:50<00:00,  2.35s/it]
Sample 3/100:   0%|          | 0/94 [00:00<?, ?it/s]

--> Threshold 0.9700000000000004


Sample 3/100: 100%|██████████| 94/94 [03:36<00:00,  2.30s/it]
Sample 4/100:   0%|          | 0/88 [00:00<?, ?it/s]

--> Threshold 0.9600000000000004


Sample 4/100: 100%|██████████| 88/88 [03:21<00:00,  2.29s/it]
Sample 5/100:   0%|          | 0/80 [00:00<?, ?it/s]

--> Threshold 0.9500000000000004


Sample 5/100: 100%|██████████| 80/80 [03:08<00:00,  2.36s/it]
Sample 6/100:   0%|          | 0/75 [00:00<?, ?it/s]

--> Threshold 0.9400000000000004


Sample 6/100: 100%|██████████| 75/75 [03:03<00:00,  2.44s/it]
Sample 7/100:   0%|          | 0/74 [00:00<?, ?it/s]

--> Threshold 0.9300000000000004


Sample 7/100: 100%|██████████| 74/74 [02:52<00:00,  2.34s/it]
Sample 8/100:   0%|          | 0/69 [00:00<?, ?it/s]

--> Threshold 0.9200000000000004


Sample 8/100: 100%|██████████| 69/69 [02:44<00:00,  2.38s/it]
Sample 9/100:   0%|          | 0/64 [00:00<?, ?it/s]

--> Threshold 0.9100000000000004


Sample 9/100: 100%|██████████| 64/64 [02:35<00:00,  2.43s/it]
Sample 10/100:   0%|          | 0/60 [00:00<?, ?it/s]

--> Threshold 0.9000000000000004


Sample 10/100: 100%|██████████| 60/60 [02:23<00:00,  2.40s/it]
Sample 11/100:   0%|          | 0/56 [00:00<?, ?it/s]

--> Threshold 0.8900000000000003


Sample 11/100: 100%|██████████| 56/56 [02:11<00:00,  2.35s/it]
Sample 12/100:   0%|          | 0/53 [00:00<?, ?it/s]

--> Threshold 0.8800000000000003


Sample 12/100: 100%|██████████| 53/53 [02:02<00:00,  2.31s/it]
Sample 13/100:   0%|          | 0/49 [00:00<?, ?it/s]

--> Threshold 0.8700000000000003


Sample 13/100: 100%|██████████| 49/49 [01:49<00:00,  2.24s/it]
Sample 14/100:   0%|          | 0/44 [00:00<?, ?it/s]

--> Threshold 0.8600000000000003


Sample 14/100: 100%|██████████| 44/44 [01:40<00:00,  2.28s/it]
Sample 15/100:   0%|          | 0/42 [00:00<?, ?it/s]

--> Threshold 0.8500000000000003


Sample 15/100: 100%|██████████| 42/42 [01:36<00:00,  2.29s/it]
Sample 16/100:   0%|          | 0/40 [00:00<?, ?it/s]

--> Threshold 0.8400000000000003


Sample 16/100: 100%|██████████| 40/40 [01:35<00:00,  2.39s/it]
Sample 17/100:   0%|          | 0/39 [00:00<?, ?it/s]

--> Threshold 0.8300000000000003


Sample 17/100: 100%|██████████| 39/39 [01:31<00:00,  2.35s/it]
Sample 18/100:   0%|          | 0/37 [00:00<?, ?it/s]

--> Threshold 0.8200000000000003


Sample 18/100: 100%|██████████| 37/37 [01:25<00:00,  2.32s/it]
Sample 19/100:   0%|          | 0/36 [00:00<?, ?it/s]

--> Threshold 0.8100000000000003


Sample 19/100: 100%|██████████| 36/36 [01:22<00:00,  2.30s/it]
Sample 20/100:   0%|          | 0/34 [00:00<?, ?it/s]

--> Threshold 0.8000000000000003


Sample 20/100: 100%|██████████| 34/34 [01:15<00:00,  2.22s/it]
Sample 21/100:   0%|          | 0/31 [00:00<?, ?it/s]

--> Threshold 0.7900000000000003


Sample 21/100: 100%|██████████| 31/31 [01:10<00:00,  2.27s/it]
Sample 22/100:   0%|          | 0/30 [00:00<?, ?it/s]

--> Threshold 0.7800000000000002


Sample 22/100: 100%|██████████| 30/30 [01:07<00:00,  2.23s/it]
Sample 23/100:   0%|          | 0/29 [00:00<?, ?it/s]

--> Threshold 0.7700000000000002


Sample 23/100: 100%|██████████| 29/29 [01:04<00:00,  2.22s/it]
Sample 24/100:   0%|          | 0/29 [00:00<?, ?it/s]

--> Threshold 0.7600000000000002


Sample 24/100: 100%|██████████| 29/29 [01:04<00:00,  2.23s/it]
Sample 25/100:   0%|          | 0/28 [00:00<?, ?it/s]

--> Threshold 0.7500000000000002


Sample 25/100: 100%|██████████| 28/28 [01:04<00:00,  2.30s/it]
Sample 26/100:   0%|          | 0/28 [00:00<?, ?it/s]

--> Threshold 0.7400000000000002


Sample 26/100: 100%|██████████| 28/28 [01:03<00:00,  2.27s/it]
Sample 27/100:   0%|          | 0/27 [00:00<?, ?it/s]

--> Threshold 0.7300000000000002


Sample 27/100: 100%|██████████| 27/27 [01:00<00:00,  2.23s/it]
Sample 28/100:   0%|          | 0/26 [00:00<?, ?it/s]

--> Threshold 0.7200000000000002


Sample 28/100: 100%|██████████| 26/26 [00:56<00:00,  2.17s/it]
Sample 29/100:   0%|          | 0/26 [00:00<?, ?it/s]

--> Threshold 0.7100000000000002


Sample 29/100: 100%|██████████| 26/26 [00:56<00:00,  2.18s/it]
Sample 30/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.7000000000000002


Sample 30/100: 100%|██████████| 24/24 [00:50<00:00,  2.09s/it]
Sample 31/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6900000000000002


Sample 31/100: 100%|██████████| 24/24 [00:50<00:00,  2.09s/it]
Sample 32/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6800000000000002


Sample 32/100: 100%|██████████| 24/24 [00:50<00:00,  2.09s/it]
Sample 33/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6700000000000002


Sample 33/100: 100%|██████████| 24/24 [00:50<00:00,  2.10s/it]
Sample 34/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6600000000000001


Sample 34/100: 100%|██████████| 24/24 [00:50<00:00,  2.11s/it]
Sample 35/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6500000000000001


Sample 35/100: 100%|██████████| 24/24 [00:49<00:00,  2.07s/it]
Sample 36/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.6400000000000001


Sample 36/100: 100%|██████████| 23/23 [00:47<00:00,  2.08s/it]
Sample 37/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.6300000000000001


Sample 37/100: 100%|██████████| 23/23 [00:48<00:00,  2.09s/it]
Sample 38/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.6200000000000001


Sample 38/100: 100%|██████████| 23/23 [00:47<00:00,  2.08s/it]
Sample 39/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.6100000000000001


Sample 39/100: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]
Sample 40/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.6000000000000001


Sample 40/100: 100%|██████████| 22/22 [00:43<00:00,  1.96s/it]
Sample 41/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.5900000000000001


Sample 41/100: 100%|██████████| 22/22 [00:43<00:00,  1.99s/it]
Sample 42/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.5800000000000001


Sample 42/100: 100%|██████████| 22/22 [00:44<00:00,  2.01s/it]
Sample 43/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.5700000000000001


Sample 43/100: 100%|██████████| 22/22 [00:44<00:00,  2.01s/it]
Sample 44/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.56


Sample 44/100: 100%|██████████| 22/22 [00:44<00:00,  2.01s/it]
Sample 45/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.55


Sample 45/100: 100%|██████████| 22/22 [00:44<00:00,  2.01s/it]
Sample 46/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.54


Sample 46/100: 100%|██████████| 22/22 [00:44<00:00,  2.01s/it]
Sample 47/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.53


Sample 47/100: 100%|██████████| 22/22 [00:44<00:00,  2.00s/it]
Sample 48/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.52


Sample 48/100: 100%|██████████| 22/22 [00:43<00:00,  1.98s/it]
Sample 49/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.51


Sample 49/100: 100%|██████████| 22/22 [00:43<00:00,  1.98s/it]
Sample 50/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.5


Sample 50/100: 100%|██████████| 21/21 [00:41<00:00,  1.96s/it]


----> constraint use


Sample 51/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9900000000000004


Sample 51/100: 100%|██████████| 100/100 [02:36<00:00,  1.57s/it]
Sample 52/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9800000000000004


Sample 52/100: 100%|██████████| 100/100 [02:36<00:00,  1.57s/it]
Sample 53/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9700000000000004


Sample 53/100: 100%|██████████| 100/100 [02:37<00:00,  1.58s/it]
Sample 54/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9600000000000004


Sample 54/100: 100%|██████████| 100/100 [02:35<00:00,  1.55s/it]
Sample 55/100:   0%|          | 0/99 [00:00<?, ?it/s]

--> Threshold 0.9500000000000004


Sample 55/100: 100%|██████████| 99/99 [02:37<00:00,  1.59s/it]
Sample 56/100:   0%|          | 0/98 [00:00<?, ?it/s]

--> Threshold 0.9400000000000004


Sample 56/100: 100%|██████████| 98/98 [02:33<00:00,  1.57s/it]
Sample 57/100:   0%|          | 0/96 [00:00<?, ?it/s]

--> Threshold 0.9300000000000004


Sample 57/100: 100%|██████████| 96/96 [02:25<00:00,  1.51s/it]
Sample 58/100:   0%|          | 0/92 [00:00<?, ?it/s]

--> Threshold 0.9200000000000004


Sample 58/100: 100%|██████████| 92/92 [02:23<00:00,  1.56s/it]
Sample 59/100:   0%|          | 0/91 [00:00<?, ?it/s]

--> Threshold 0.9100000000000004


Sample 59/100: 100%|██████████| 91/91 [02:18<00:00,  1.53s/it]
Sample 60/100:   0%|          | 0/87 [00:00<?, ?it/s]

--> Threshold 0.9000000000000004


Sample 60/100: 100%|██████████| 87/87 [02:12<00:00,  1.52s/it]
Sample 61/100:   0%|          | 0/83 [00:00<?, ?it/s]

--> Threshold 0.8900000000000003


Sample 61/100: 100%|██████████| 83/83 [02:09<00:00,  1.56s/it]
Sample 62/100:   0%|          | 0/78 [00:00<?, ?it/s]

--> Threshold 0.8800000000000003


Sample 62/100: 100%|██████████| 78/78 [01:53<00:00,  1.45s/it]
Sample 63/100:   0%|          | 0/69 [00:00<?, ?it/s]

--> Threshold 0.8700000000000003


Sample 63/100: 100%|██████████| 69/69 [01:47<00:00,  1.55s/it]
Sample 64/100:   0%|          | 0/66 [00:00<?, ?it/s]

--> Threshold 0.8600000000000003


Sample 64/100: 100%|██████████| 66/66 [01:41<00:00,  1.54s/it]
Sample 65/100:   0%|          | 0/57 [00:00<?, ?it/s]

--> Threshold 0.8500000000000003


Sample 65/100: 100%|██████████| 57/57 [01:30<00:00,  1.58s/it]
Sample 66/100:   0%|          | 0/55 [00:00<?, ?it/s]

--> Threshold 0.8400000000000003


Sample 66/100: 100%|██████████| 55/55 [01:27<00:00,  1.60s/it]
Sample 67/100:   0%|          | 0/51 [00:00<?, ?it/s]

--> Threshold 0.8300000000000003


Sample 67/100: 100%|██████████| 51/51 [01:21<00:00,  1.59s/it]
Sample 68/100:   0%|          | 0/48 [00:00<?, ?it/s]

--> Threshold 0.8200000000000003


Sample 68/100: 100%|██████████| 48/48 [01:13<00:00,  1.54s/it]
Sample 69/100:   0%|          | 0/45 [00:00<?, ?it/s]

--> Threshold 0.8100000000000003


Sample 69/100: 100%|██████████| 45/45 [01:08<00:00,  1.52s/it]
Sample 70/100:   0%|          | 0/41 [00:00<?, ?it/s]

--> Threshold 0.8000000000000003


Sample 70/100: 100%|██████████| 41/41 [01:05<00:00,  1.59s/it]
Sample 71/100:   0%|          | 0/41 [00:00<?, ?it/s]

--> Threshold 0.7900000000000003


Sample 71/100: 100%|██████████| 41/41 [01:04<00:00,  1.58s/it]
Sample 72/100:   0%|          | 0/38 [00:00<?, ?it/s]

--> Threshold 0.7800000000000002


Sample 72/100: 100%|██████████| 38/38 [01:01<00:00,  1.62s/it]
Sample 73/100:   0%|          | 0/34 [00:00<?, ?it/s]

--> Threshold 0.7700000000000002


Sample 73/100: 100%|██████████| 34/34 [00:52<00:00,  1.55s/it]
Sample 74/100:   0%|          | 0/34 [00:00<?, ?it/s]

--> Threshold 0.7600000000000002


Sample 74/100: 100%|██████████| 34/34 [00:54<00:00,  1.59s/it]
Sample 75/100:   0%|          | 0/32 [00:00<?, ?it/s]

--> Threshold 0.7500000000000002


Sample 75/100: 100%|██████████| 32/32 [00:48<00:00,  1.52s/it]
Sample 76/100:   0%|          | 0/30 [00:00<?, ?it/s]

--> Threshold 0.7400000000000002


Sample 76/100: 100%|██████████| 30/30 [00:47<00:00,  1.58s/it]
Sample 77/100:   0%|          | 0/30 [00:00<?, ?it/s]

--> Threshold 0.7300000000000002


Sample 77/100: 100%|██████████| 30/30 [00:47<00:00,  1.59s/it]
Sample 78/100:   0%|          | 0/28 [00:00<?, ?it/s]

--> Threshold 0.7200000000000002


Sample 78/100: 100%|██████████| 28/28 [00:39<00:00,  1.43s/it]
Sample 79/100:   0%|          | 0/26 [00:00<?, ?it/s]

--> Threshold 0.7100000000000002


Sample 79/100: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
Sample 80/100:   0%|          | 0/25 [00:00<?, ?it/s]

--> Threshold 0.7000000000000002


Sample 80/100: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
Sample 81/100:   0%|          | 0/25 [00:00<?, ?it/s]

--> Threshold 0.6900000000000002


Sample 81/100: 100%|██████████| 25/25 [00:36<00:00,  1.47s/it]
Sample 82/100:   0%|          | 0/25 [00:00<?, ?it/s]

--> Threshold 0.6800000000000002


Sample 82/100: 100%|██████████| 25/25 [00:36<00:00,  1.46s/it]
Sample 83/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6700000000000002


Sample 83/100: 100%|██████████| 24/24 [00:35<00:00,  1.49s/it]
Sample 84/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.6600000000000001


Sample 84/100: 100%|██████████| 24/24 [00:35<00:00,  1.48s/it]
Sample 85/100:   0%|          | 0/23 [00:00<?, ?it/s]

--> Threshold 0.6500000000000001


Sample 85/100: 100%|██████████| 23/23 [00:32<00:00,  1.43s/it]
Sample 86/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.6400000000000001


Sample 86/100: 100%|██████████| 22/22 [00:32<00:00,  1.47s/it]
Sample 87/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.6300000000000001


Sample 87/100: 100%|██████████| 22/22 [00:32<00:00,  1.49s/it]
Sample 88/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.6200000000000001


Sample 88/100: 100%|██████████| 22/22 [00:32<00:00,  1.48s/it]
Sample 89/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.6100000000000001


Sample 89/100: 100%|██████████| 21/21 [00:30<00:00,  1.44s/it]
Sample 90/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.6000000000000001


Sample 90/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 91/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.5900000000000001


Sample 91/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 92/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.5800000000000001


Sample 92/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 93/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.5700000000000001


Sample 93/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 94/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.56


Sample 94/100: 100%|██████████| 21/21 [00:29<00:00,  1.42s/it]
Sample 95/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.55


Sample 95/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 96/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.54


Sample 96/100: 100%|██████████| 21/21 [00:29<00:00,  1.43s/it]
Sample 97/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.53


Sample 97/100: 100%|██████████| 21/21 [00:30<00:00,  1.44s/it]
Sample 98/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.52


Sample 98/100: 100%|██████████| 21/21 [00:29<00:00,  1.42s/it]
Sample 99/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.51


Sample 99/100: 100%|██████████| 21/21 [00:30<00:00,  1.43s/it]
Sample 100/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.5


Sample 100/100: 100%|██████████| 21/21 [00:29<00:00,  1.43s/it]


Model --> textattack/albert-base-v2-rotten-tomatoes


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=738.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=156.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=46747112.0, style=ProgressStyle(descrip…




[34;1mtextattack[0m: Unknown if model of class <class 'textattack.models.wrappers.huggingface_model_wrapper.HuggingFaceModelWrapper'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


----> constraint bertscore


Sample 1/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9900000000000004


Sample 1/100: 100%|██████████| 100/100 [04:01<00:00,  2.42s/it]
Sample 2/100:   0%|          | 0/97 [00:00<?, ?it/s]

--> Threshold 0.9800000000000004


Sample 2/100: 100%|██████████| 97/97 [03:56<00:00,  2.44s/it]
Sample 3/100:   0%|          | 0/94 [00:00<?, ?it/s]

--> Threshold 0.9700000000000004


Sample 3/100: 100%|██████████| 94/94 [03:48<00:00,  2.43s/it]
Sample 4/100:   0%|          | 0/88 [00:00<?, ?it/s]

--> Threshold 0.9600000000000004


Sample 4/100: 100%|██████████| 88/88 [03:26<00:00,  2.34s/it]
Sample 5/100:   0%|          | 0/81 [00:00<?, ?it/s]

--> Threshold 0.9500000000000004


Sample 5/100: 100%|██████████| 81/81 [03:03<00:00,  2.27s/it]
Sample 6/100:   0%|          | 0/74 [00:00<?, ?it/s]

--> Threshold 0.9400000000000004


Sample 6/100: 100%|██████████| 74/74 [02:53<00:00,  2.34s/it]
Sample 7/100:   0%|          | 0/71 [00:00<?, ?it/s]

--> Threshold 0.9300000000000004


Sample 7/100: 100%|██████████| 71/71 [02:45<00:00,  2.33s/it]
Sample 8/100:   0%|          | 0/67 [00:00<?, ?it/s]

--> Threshold 0.9200000000000004


Sample 8/100: 100%|██████████| 67/67 [02:36<00:00,  2.33s/it]
Sample 9/100:   0%|          | 0/62 [00:00<?, ?it/s]

--> Threshold 0.9100000000000004


Sample 9/100: 100%|██████████| 62/62 [02:17<00:00,  2.22s/it]
Sample 10/100:   0%|          | 0/56 [00:00<?, ?it/s]

--> Threshold 0.9000000000000004


Sample 10/100: 100%|██████████| 56/56 [02:02<00:00,  2.19s/it]
Sample 11/100:   0%|          | 0/54 [00:00<?, ?it/s]

--> Threshold 0.8900000000000003


Sample 11/100: 100%|██████████| 54/54 [01:56<00:00,  2.16s/it]
Sample 12/100:   0%|          | 0/52 [00:00<?, ?it/s]

--> Threshold 0.8800000000000003


Sample 12/100: 100%|██████████| 52/52 [01:55<00:00,  2.22s/it]
Sample 13/100:   0%|          | 0/50 [00:00<?, ?it/s]

--> Threshold 0.8700000000000003


Sample 13/100: 100%|██████████| 50/50 [01:48<00:00,  2.17s/it]
Sample 14/100:   0%|          | 0/45 [00:00<?, ?it/s]

--> Threshold 0.8600000000000003


Sample 14/100: 100%|██████████| 45/45 [01:34<00:00,  2.09s/it]
Sample 15/100:   0%|          | 0/40 [00:00<?, ?it/s]

--> Threshold 0.8500000000000003


Sample 15/100: 100%|██████████| 40/40 [01:27<00:00,  2.20s/it]
Sample 16/100:   0%|          | 0/39 [00:00<?, ?it/s]

--> Threshold 0.8400000000000003


Sample 16/100: 100%|██████████| 39/39 [01:26<00:00,  2.21s/it]
Sample 17/100:   0%|          | 0/38 [00:00<?, ?it/s]

--> Threshold 0.8300000000000003


Sample 17/100: 100%|██████████| 38/38 [01:19<00:00,  2.09s/it]
Sample 18/100:   0%|          | 0/34 [00:00<?, ?it/s]

--> Threshold 0.8200000000000003


Sample 18/100: 100%|██████████| 34/34 [01:10<00:00,  2.09s/it]
Sample 19/100:   0%|          | 0/32 [00:00<?, ?it/s]

--> Threshold 0.8100000000000003


Sample 19/100: 100%|██████████| 32/32 [01:10<00:00,  2.19s/it]
Sample 20/100:   0%|          | 0/32 [00:00<?, ?it/s]

--> Threshold 0.8000000000000003


Sample 20/100: 100%|██████████| 32/32 [01:11<00:00,  2.23s/it]
Sample 21/100:   0%|          | 0/31 [00:00<?, ?it/s]

--> Threshold 0.7900000000000003


Sample 21/100: 100%|██████████| 31/31 [01:04<00:00,  2.08s/it]
Sample 22/100:   0%|          | 0/28 [00:00<?, ?it/s]

--> Threshold 0.7800000000000002


Sample 22/100: 100%|██████████| 28/28 [00:53<00:00,  1.92s/it]
Sample 23/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.7700000000000002


Sample 23/100: 100%|██████████| 24/24 [00:41<00:00,  1.74s/it]
Sample 24/100:   0%|          | 0/24 [00:00<?, ?it/s]

--> Threshold 0.7600000000000002


Sample 24/100: 100%|██████████| 24/24 [00:41<00:00,  1.71s/it]
Sample 25/100:   0%|          | 0/22 [00:00<?, ?it/s]

--> Threshold 0.7500000000000002


Sample 25/100: 100%|██████████| 22/22 [00:37<00:00,  1.73s/it]
Sample 26/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.7400000000000002


Sample 26/100: 100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
Sample 27/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.7300000000000002


Sample 27/100: 100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
Sample 28/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.7200000000000002


Sample 28/100: 100%|██████████| 21/21 [00:37<00:00,  1.80s/it]
Sample 29/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.7100000000000002


Sample 29/100: 100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
Sample 30/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.7000000000000002


Sample 30/100: 100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
Sample 31/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.6900000000000002


Sample 31/100: 100%|██████████| 21/21 [00:38<00:00,  1.84s/it]
Sample 32/100:   0%|          | 0/21 [00:00<?, ?it/s]

--> Threshold 0.6800000000000002


Sample 32/100: 100%|██████████| 21/21 [00:38<00:00,  1.83s/it]
Sample 33/100:   0%|          | 0/20 [00:00<?, ?it/s]

--> Threshold 0.6700000000000002


Sample 33/100: 100%|██████████| 20/20 [00:34<00:00,  1.71s/it]
Sample 34/100:   0%|          | 0/19 [00:00<?, ?it/s]

--> Threshold 0.6600000000000001


Sample 34/100: 100%|██████████| 19/19 [00:33<00:00,  1.75s/it]
Sample 35/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6500000000000001


Sample 35/100: 100%|██████████| 18/18 [00:30<00:00,  1.71s/it]
Sample 36/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6400000000000001


Sample 36/100: 100%|██████████| 18/18 [00:30<00:00,  1.71s/it]
Sample 37/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6300000000000001


Sample 37/100: 100%|██████████| 18/18 [00:30<00:00,  1.71s/it]
Sample 38/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6200000000000001


Sample 38/100: 100%|██████████| 18/18 [00:30<00:00,  1.72s/it]
Sample 39/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6100000000000001


Sample 39/100: 100%|██████████| 18/18 [00:30<00:00,  1.72s/it]
Sample 40/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.6000000000000001


Sample 40/100: 100%|██████████| 18/18 [00:31<00:00,  1.73s/it]
Sample 41/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.5900000000000001


Sample 41/100: 100%|██████████| 18/18 [00:31<00:00,  1.73s/it]
Sample 42/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.5800000000000001


Sample 42/100: 100%|██████████| 18/18 [00:31<00:00,  1.72s/it]
Sample 43/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.5700000000000001


Sample 43/100: 100%|██████████| 18/18 [00:31<00:00,  1.73s/it]
Sample 44/100:   0%|          | 0/18 [00:00<?, ?it/s]

--> Threshold 0.56


Sample 44/100: 100%|██████████| 18/18 [00:31<00:00,  1.73s/it]
Sample 45/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.55


Sample 45/100: 100%|██████████| 17/17 [00:28<00:00,  1.67s/it]
Sample 46/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.54


Sample 46/100: 100%|██████████| 17/17 [00:28<00:00,  1.67s/it]
Sample 47/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.53


Sample 47/100: 100%|██████████| 17/17 [00:28<00:00,  1.67s/it]
Sample 48/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.52


Sample 48/100: 100%|██████████| 17/17 [00:28<00:00,  1.69s/it]
Sample 49/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.51


Sample 49/100: 100%|██████████| 17/17 [00:30<00:00,  1.77s/it]
Sample 50/100:   0%|          | 0/17 [00:00<?, ?it/s]

--> Threshold 0.5


Sample 50/100: 100%|██████████| 17/17 [00:29<00:00,  1.72s/it]


----> constraint use


Sample 51/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9900000000000004


Sample 51/100: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it]
Sample 52/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9800000000000004


Sample 52/100: 100%|██████████| 100/100 [02:39<00:00,  1.60s/it]
Sample 53/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9700000000000004


Sample 53/100: 100%|██████████| 100/100 [02:38<00:00,  1.59s/it]
Sample 54/100:   0%|          | 0/100 [00:00<?, ?it/s]

--> Threshold 0.9600000000000004


Sample 54/100:  39%|███▉      | 39/100 [01:00<01:39,  1.63s/it]

In [None]:
##
## PLOT AUC AND STUFF
##
# For every attacked model, plot first-order vs. second-order attack success
# rate (one curve per constraint) and report the area under that curve (AUC)
# together with ACCS = AUC / (max(x) * max(y)).
#
# Relies on `data`, `data2` and `num_samples` already existing in the kernel.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Import the submodule explicitly: a bare `import sklearn` does not guarantee
# that `sklearn.metrics` has been loaded.
import sklearn.metrics

sns.set()

# `data` becomes the second-order table and `data2` the first-order table;
# the success counts are renamed so both survive the merge below.
second_order_df = pd.DataFrame(data) \
                    .rename(columns={'num_successes': 'num_second_order_successes'}) \
                    .drop('num_words_to_swap', axis=1)
first_order_df = pd.DataFrame(data2) \
                    .rename(columns={'num_successes': 'num_first_order_successes'})

# Default merge: inner join on every column name the two frames share.
full_df = pd.merge(first_order_df, second_order_df)

df_models = full_df['model'].unique()
# squeeze=False always yields a 2-D array of axes; flattening it keeps
# `ax[model_idx]` valid even when only a single model is present.
fig, ax = plt.subplots(1, len(df_models), figsize=(24, 6), squeeze=False)
ax = ax.ravel()

for model_idx, model in enumerate(df_models):
  df = full_df[full_df['model'] == model].reset_index()
  # Add datapoints where eps=1.0 and success rate is zero, instead of actually running
  # futile attacks.
  for constraint in df['constraint'].unique():
    # One combined mask instead of chained `df[a][b][c]` indexing, which
    # applied full-frame masks to already-filtered frames.
    is_zero_point = (
        (df['constraint'] == constraint)
        & (df['num_first_order_successes'] == 0.0)
        & (df['num_second_order_successes'] == 0.0)
    )
    zero_data_point = df[is_zero_point]
    if zero_data_point.empty:
      extra_data_point = { 'constraint': constraint, 'num_first_order_successes': 0.0, 'num_second_order_successes': 0.0 } # The \eps=1.0 datapoint
      # Columns missing from the dict (e.g. threshold, model) are filled with None.
      extra_data_point_row = [extra_data_point.get(c) for c in df.columns]
      # Insert the new row at the top of the frame so each curve starts at the
      # origin: write it at index -1, shift all indices up by one, then sort.
      df.loc[-1] = extra_data_point_row # add row
      df.index = df.index + 1  # shifting index
      df = df.sort_index()

  # Calculate rate in terms of num successes
  df['first_order_success_rate'] = df['num_first_order_successes'] / num_samples
  df['second_order_success_rate'] = df['num_second_order_successes'] / num_samples

  labels = []
  # Calculate AUC.
  for constraint in df['constraint'].unique():
    constraint_rows = df[df['constraint'] == constraint]
    x = constraint_rows['second_order_success_rate']
    y = constraint_rows['first_order_success_rate']
    auc = sklearn.metrics.auc(x, y)
    # Normalise by the bounding box of the curve; report NaN instead of
    # dividing by zero when one of the attacks never succeeded.
    normalizer = max(x) * max(y)
    accs = auc / normalizer if normalizer else float('nan')
    print(constraint)
    print(f'--> AUC: {auc}')
    print(f'--> ACCS: {accs}')
    
    labels.append(f'{constraint} (ACCS = {accs:.3f})')
  print('labels', labels)
  # Plot curve. x/y are passed by keyword because newer seaborn releases no
  # longer accept them positionally.
  sns.lineplot(x=df['second_order_success_rate'], y=df['first_order_success_rate'],
                hue=df['constraint'], ci=0, ax=ax[model_idx], lw=4)

  # Draw the second curve dashed so the two constraints can be told apart;
  # skipped when only one curve was plotted.
  plotted_lines = ax[model_idx].get_lines()
  if len(plotted_lines) > 1:
    plotted_lines[1].set_linestyle('--')
  
  ax[model_idx].legend(labels, loc=4, prop={'size': 18}) # bottom right, see https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.legend.html
  ax[model_idx].set_title(model, fontsize=22)
  ax[model_idx].set_xlabel('Second-order attack success rate', fontsize=18)
  ax[model_idx].set_ylabel('First-order attack success rate',  fontsize=18)

plt.tight_layout()
# NOTE(review): the file name says "rte" although the models plotted are the
# ones stored in `data`/`data2` -- confirm the name is still appropriate.
plt.savefig('attack_curve_rte.pdf')

In [None]:
# Quick sanity plot of second-order success counts against the threshold.
# `df` is whatever frame the preceding cell's loop left bound last, so this
# cell only works when run right after it.
plot_columns = {'x': 'threshold', 'y': 'num_second_order_successes'}
df.plot(**plot_columns)

In [None]:
##
## PLOT ATTACK SUCCESS AND STUFF
##
# Draws one panel per model with the first- and second-order attack success
# rates plotted against the constraint threshold (ε), then writes the whole
# figure to 'attack_successes_rte.pdf'.
# Relies on `full_df` and `num_samples` being defined by earlier cells.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sklearn

df_models = full_df['model'].unique()
# squeeze=False makes plt.subplots always return a 2-D array of Axes; taking
# row 0 yields a 1-D array, so ax[model_idx] also works with a single model
# (where plt.subplots would otherwise hand back a bare, non-indexable Axes).
fig, ax = plt.subplots(1, len(df_models), figsize=(24, 6), squeeze=False)
ax = ax[0]

LINE_WIDTH = 2.5    # thickish
COLOR_1 = "#9b59b6" # purplish
COLOR_2 = "#34495e" # blackish

print('models -> ', df_models)
for model_idx, model in enumerate(df_models):
  df = full_df[full_df['model'] == model].reset_index()

  # Add datapoints where eps=1.0 and success rate is zero, instead of actually running
  # futile attacks.
  for constraint in df['constraint'].unique():
    # One combined mask instead of chained df[a][b][c] indexing, which makes
    # pandas reindex each boolean Series against an already-filtered frame.
    has_zero_point = (
        (df['constraint'] == constraint)
        & (df['num_first_order_successes'] == 0.0)
        & (df['num_second_order_successes'] == 0.0)
    ).any()
    if not has_zero_point:
      # The \eps=1.0 datapoint. 'threshold' must be set explicitly: without it
      # the row has a missing x-value and is never drawn.
      extra_data_point = {
          'constraint': constraint,
          'threshold': 1.0,
          'num_first_order_successes': 0.0,
          'num_second_order_successes': 0.0,
      }
      # Columns not named above (e.g. 'model', 'index') are left as None.
      extra_data_point_row = [extra_data_point.get(c) for c in df.columns]
      df.loc[-1] = extra_data_point_row # add row
      df.index = df.index + 1  # shifting index
      df.sort_index(inplace=True)

  # Calculate rate in terms of num successes. This is done after the synthetic
  # rows are inserted so that they receive a rate of 0.0 rather than NaN.
  df['first_order_success_rate'] = df['num_first_order_successes'] / num_samples
  df['second_order_success_rate'] = df['num_second_order_successes'] / num_samples

  # Plot curves. x/y are passed by keyword: recent seaborn releases accept
  # only `data` positionally. astype(float) guards against object-dtype
  # columns produced by the row insertion above.
  threshold = df['threshold'].astype(float)
  sns.lineplot(x=threshold, y=df['first_order_success_rate'].astype(float),
               color=COLOR_1, ci=0, lw=LINE_WIDTH, ax=ax[model_idx])

  sns.lineplot(x=threshold, y=df['second_order_success_rate'].astype(float),
               color=COLOR_2, ci=0, lw=LINE_WIDTH, ax=ax[model_idx])

  ax[model_idx].legend(['First-order attack', 'Second-order attack'], loc=4) # bottom right, see https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.legend.html
  ax[model_idx].set_title(model)

  ax[model_idx].set_xlabel('Threshold (ε)')
  ax[model_idx].set_ylabel('Attack Success Rate')

# Save once, after every panel has been drawn (not once per model).
fig.savefig('attack_successes_rte.pdf')