In [1]:
%%capture
!pip install --upgrade -tensorflow_hub
# !pip install -U -huggingface_hub

import os
# GPU selection is done through environment variables, set here ahead of the
# torch / tensorflow-backed imports that follow in this cell.
# Number the GPUs by PCI bus position (the numbering nvidia-smi reports).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
# Make only physical GPU 4 visible to this process.
# NOTE(review): with a single visible device it is addressed as "cuda:0" inside
# this process, which matches the device name entered in the recorded run.
os.environ["CUDA_VISIBLE_DEVICES"]="4"

import textattack
import transformers
import torch
import time
from datasets import Dataset
import sys
import hashlib
import numpy as np

from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertForMaskedLM, pipeline
from textattack.attack_recipes import (
    TextBuggerLi2018, DeepWordBugGao2018, TextFoolerJin2019, BERTAttackLi2020
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.models.wrappers import ModelWrapper

# Temporarily add the parent directory to the import search path so that the
# project-local eval_utils module can be imported, then remove it again.
sys.path.append('../')
# NOTE(review): the star import hides which names come from eval_utils.
# convert_to_tuples, MaskDemaskWrapper and parse_attack_name are used further
# down without any other visible import, so they presumably come from here —
# confirm, and consider importing them explicitly.
from eval_utils import *
# Drops the last sys.path entry, i.e. the '../' appended above, provided the
# import itself did not append anything else to sys.path.
sys.path.pop()

2023-08-05 17:39:50.764435: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Runs one TextAttack recipe against the AG News BERT classifier, either
# undefended or wrapped in the mask/demask defense from eval_utils, and logs the
# per-example results to a CSV file named after the run configuration.

# Set a seed, because reproducibility is cool: NumPy's global RNG is seeded with
# a fixed value derived from a SHA-256 digest, reduced modulo 2**32.
np.random.seed(int(hashlib.sha256('Harrison Gietz'.encode('utf-8')).hexdigest(), 16) % 2**32)
torch.cuda.empty_cache()

# --- Run configuration (entered interactively) -------------------------------
# Surrounding whitespace is stripped so that e.g. "1000 " is still accepted.
device = input('enter a device name to run on: ').strip()
dataset_val = input('Enter the number of samples to run on (100 or 1000): ').strip()
defense = input('Specify a defense type among "default", "logit", "maj_log", "one_hot": ').strip()

# Settings forwarded to MaskDemaskWrapper and recorded in the log file name.
# NOTE(review): presumably the number of voting copies and the fraction of
# tokens masked per copy — confirm against eval_utils.
num_voter = 11
mask_pct = 0.3

attack = DeepWordBugGao2018

# Defense name as typed by the user -> mode string handed to MaskDemaskWrapper.
# None selects the plain (undefended) HuggingFace wrapper.
defense_modes = {
    "default": None,
    "logit": 'logit',
    "maj_log": 'maj_log',
    "one_hot": 'maj_one_hot',
}

# Validate both choices up front, before any model is downloaded or loaded,
# so that a typo fails immediately rather than after the expensive setup.
if dataset_val not in ('100', '1000'):
    raise ValueError('Number of samples not supported')
if defense not in defense_modes:
    raise ValueError('Not a valid defense type.')
defense_mode = defense_modes[defense]

# --- Models ------------------------------------------------------------------
# Victim classifier: BERT fine-tuned on AG News, moved to the requested device.
ag_tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model.to(device)
# Classification pipeline around the victim model (not referenced again in this
# cell). Its device attribute is pointed at the model's device by hand.
ag_pipeline = pipeline('sentiment-analysis', model=ag_model, tokenizer=ag_tokenizer)
ag_pipeline.device = next(ag_model.parameters()).device

# Masked-language model loaded from a local checkpoint; it backs the fill-mask
# pipeline that is passed to MaskDemaskWrapper.
ag_model_directory = "../../../models/bert-uncased_maskedlm_agnews_july31" #first diff
finetuned_ag_maskedlm = BertForMaskedLM.from_pretrained(ag_model_directory)
finetuned_ag_maskedlm.to(device)
ag_fill_mask = pipeline("fill-mask", model=finetuned_ag_maskedlm, tokenizer=ag_tokenizer)
# Take the device from the masked-LM itself (previously read from ag_model,
# which only gave the right answer because both models share `device`).
ag_fill_mask.device = next(finetuned_ag_maskedlm.parameters()).device

# --- Dataset -----------------------------------------------------------------
# dataset_val was validated above, so it is either '100' or '1000' here.
if dataset_val == '100':
    loaded_ag_100 = Dataset.load_from_disk('../data/filtered_ag_clean_100')
    ag_100 = textattack.datasets.Dataset(convert_to_tuples(loaded_ag_100))
    dataset = ag_100
    dataset_name = 'ag-news100'
else:
    loaded_ag_1000 = Dataset.load_from_disk('../data/filtered_ag_clean_1000')
    ag_1000 = textattack.datasets.Dataset(convert_to_tuples(loaded_ag_1000))
    dataset = ag_1000
    dataset_name = 'ag-news1000'

# --- Model wrapper (defended or undefended) ----------------------------------
if defense_mode is None:
    ag_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(ag_model, ag_tokenizer)
    print(ag_wrapper)
else:
    ag_wrapper = MaskDemaskWrapper(ag_model, ag_tokenizer, ag_fill_mask, num_voter, mask_pct, defense_mode)

print(f'using num_voter = {num_voter} and mask_pct = {mask_pct} with dataset = {dataset_name}...')

# --- Attack ------------------------------------------------------------------
# Parse the attack name from the recipe class, then build the attack against
# the chosen wrapper (this rebinds `attack` from the class to the built attack).
attack_name = parse_attack_name(attack)
attack = attack.build(ag_wrapper)

# Single identifier for this run, shared by the CSV file name and the summary line.
run_tag = f'{attack_name}_{dataset_name}_mp{mask_pct}_nv{num_voter}_{defense}'

# Set up arguments for the attack: attack every example in the dataset, log to
# CSV, checkpoint every 25 examples into chkpts_2, and suppress stdout logging.
attack_args = textattack.AttackArgs(
    num_examples=len(dataset),
    log_to_csv=f'{run_tag}_log.csv',
    checkpoint_interval=25,
    checkpoint_dir="chkpts_2",
    disable_stdout=True
)
# Perform the attack and save the results
attacker = textattack.Attacker(attack, dataset, attack_args)
attacker.attack_dataset()

print(f'The above are results for {run_tag}.')

enter a device name to run on: cuda:0
Enter the number of samples to run on (100 or 1000): 1000
Specify a defense type among "default", "logit", "maj_log", "one_hot": maj_log
using num_voter = 11 and mask_pct = 0.3 with dataset = ag-news1000...
Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  unk
  )
  (goal_function):  UntargetedClassification
  (transformation):  CompositeTransformation(
    (0): WordSwapNeighboringCharacterSwap(
        (random_one):  True
      )
    (1): WordSwapRandomCharacterSubstitution(
        (random_one):  True
      )
    (2): WordSwapRandomCharacterDeletion(
        (random_one):  True
      )
    (3): WordSwapRandomCharacterInsertion(
        (random_one):  True
      )
    )
  (constraints): 
    (0): LevenshteinEditDistance(
        (max_edit_distance):  30
        (compare_against_original):  True
      )
    (1): RepeatModification
    (2): StopwordModification
  (is_black_box):  True
) 



[Succeeded / Failed / Skipped / Total] 4 / 19 / 2 / 25:   2%|▎         | 25/1000 [26:42<17:21:40, 64.10s/it]






[Succeeded / Failed / Skipped / Total] 8 / 38 / 4 / 50:   5%|▌         | 50/1000 [48:20<15:18:38, 58.02s/it]






[Succeeded / Failed / Skipped / Total] 16 / 53 / 6 / 75:   8%|▊         | 75/1000 [1:06:49<13:44:16, 53.47s/it]






[Succeeded / Failed / Skipped / Total] 19 / 72 / 9 / 100:  10%|█         | 100/1000 [1:27:22<13:06:22, 52.43s/it]






[Succeeded / Failed / Skipped / Total] 22 / 93 / 10 / 125:  12%|█▎        | 125/1000 [1:51:20<12:59:25, 53.45s/it]






[Succeeded / Failed / Skipped / Total] 27 / 110 / 13 / 150:  15%|█▌        | 150/1000 [2:13:20<12:35:37, 53.34s/it]






[Succeeded / Failed / Skipped / Total] 32 / 129 / 14 / 175:  18%|█▊        | 175/1000 [2:36:25<12:17:27, 53.63s/it]






[Succeeded / Failed / Skipped / Total] 36 / 149 / 15 / 200:  20%|██        | 200/1000 [3:06:34<12:26:17, 55.97s/it]






[Succeeded / Failed / Skipped / Total] 36 / 173 / 16 / 225:  22%|██▎       | 225/1000 [3:57:50<13:39:13, 63.42s/it]






[Succeeded / Failed / Skipped / Total] 42 / 191 / 17 / 250:  25%|██▌       | 250/1000 [4:41:39<14:04:58, 67.60s/it]






[Succeeded / Failed / Skipped / Total] 50 / 208 / 17 / 275:  28%|██▊       | 275/1000 [5:19:21<14:01:57, 69.68s/it]






[Succeeded / Failed / Skipped / Total] 56 / 226 / 18 / 300:  30%|███       | 300/1000 [5:56:39<13:52:11, 71.33s/it]






[Succeeded / Failed / Skipped / Total] 60 / 245 / 20 / 325:  32%|███▎      | 325/1000 [6:34:42<13:39:47, 72.87s/it]






[Succeeded / Failed / Skipped / Total] 64 / 265 / 21 / 350:  35%|███▌      | 350/1000 [7:23:54<13:44:23, 76.10s/it]






[Succeeded / Failed / Skipped / Total] 68 / 285 / 22 / 375:  38%|███▊      | 375/1000 [8:26:42<14:04:30, 81.07s/it]






[Succeeded / Failed / Skipped / Total] 73 / 304 / 23 / 400:  40%|████      | 400/1000 [9:22:24<14:03:36, 84.36s/it]






[Succeeded / Failed / Skipped / Total] 77 / 324 / 24 / 425:  42%|████▎     | 425/1000 [10:37:00<14:21:50, 89.93s/it]






[Succeeded / Failed / Skipped / Total] 81 / 343 / 26 / 450:  45%|████▌     | 450/1000 [11:34:37<14:08:58, 92.62s/it]






[Succeeded / Failed / Skipped / Total] 85 / 363 / 27 / 475:  48%|████▊     | 475/1000 [12:23:45<13:42:03, 93.95s/it]






[Succeeded / Failed / Skipped / Total] 90 / 379 / 31 / 500:  50%|█████     | 500/1000 [13:13:29<13:13:29, 95.22s/it]






[Succeeded / Failed / Skipped / Total] 94 / 400 / 31 / 525:  52%|█████▎    | 525/1000 [14:09:26<12:48:32, 97.08s/it]






[Succeeded / Failed / Skipped / Total] 96 / 422 / 32 / 550:  55%|█████▌    | 550/1000 [15:00:57<12:17:08, 98.29s/it]






[Succeeded / Failed / Skipped / Total] 97 / 445 / 33 / 575:  57%|█████▊    | 575/1000 [15:58:39<11:48:34, 100.03s/it]






[Succeeded / Failed / Skipped / Total] 102 / 465 / 33 / 600:  60%|██████    | 600/1000 [16:50:03<11:13:22, 101.01s/it]






[Succeeded / Failed / Skipped / Total] 102 / 488 / 35 / 625:  62%|██████▎   | 625/1000 [17:39:15<10:35:33, 101.69s/it]






[Succeeded / Failed / Skipped / Total] 105 / 510 / 35 / 650:  65%|██████▌   | 650/1000 [18:30:01<9:57:42, 102.46s/it] 






[Succeeded / Failed / Skipped / Total] 108 / 532 / 35 / 675:  68%|██████▊   | 675/1000 [19:26:19<9:21:33, 103.67s/it]






[Succeeded / Failed / Skipped / Total] 110 / 555 / 35 / 700:  70%|███████   | 700/1000 [20:29:53<8:47:05, 105.42s/it]






[Succeeded / Failed / Skipped / Total] 115 / 575 / 35 / 725:  72%|███████▎  | 725/1000 [21:17:50<8:04:41, 105.75s/it]






[Succeeded / Failed / Skipped / Total] 120 / 594 / 36 / 750:  75%|███████▌  | 750/1000 [22:03:01<7:21:00, 105.84s/it]






[Succeeded / Failed / Skipped / Total] 124 / 614 / 37 / 775:  78%|███████▊  | 775/1000 [22:45:20<6:36:23, 105.70s/it]






[Succeeded / Failed / Skipped / Total] 131 / 632 / 37 / 800:  80%|████████  | 800/1000 [23:31:00<5:52:45, 105.83s/it]






[Succeeded / Failed / Skipped / Total] 138 / 650 / 37 / 825:  82%|████████▎ | 825/1000 [24:20:24<5:09:46, 106.21s/it]






[Succeeded / Failed / Skipped / Total] 142 / 671 / 37 / 850:  85%|████████▌ | 850/1000 [25:13:29<4:27:05, 106.83s/it]






[Succeeded / Failed / Skipped / Total] 147 / 690 / 38 / 875:  88%|████████▊ | 875/1000 [25:59:49<3:42:49, 106.96s/it]






[Succeeded / Failed / Skipped / Total] 152 / 710 / 38 / 900:  90%|█████████ | 900/1000 [26:46:33<2:58:30, 107.10s/it]






[Succeeded / Failed / Skipped / Total] 156 / 731 / 38 / 925:  92%|█████████▎| 925/1000 [27:38:30<2:14:28, 107.58s/it]






[Succeeded / Failed / Skipped / Total] 159 / 752 / 39 / 950:  95%|█████████▌| 950/1000 [28:31:47<1:30:05, 108.11s/it]






[Succeeded / Failed / Skipped / Total] 164 / 771 / 40 / 975:  98%|█████████▊| 975/1000 [29:21:06<45:09, 108.38s/it]  






[Succeeded / Failed / Skipped / Total] 167 / 792 / 41 / 1000: 100%|██████████| 1000/1000 [30:12:30<00:00, 108.75s/it]






[Succeeded / Failed / Skipped / Total] 167 / 792 / 41 / 1000: 100%|██████████| 1000/1000 [30:12:31<00:00, 108.75s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 167    |
| Number of failed attacks:     | 792    |
| Number of skipped attacks:    | 41     |
| Original accuracy:            | 95.9%  |
| Accuracy under attack:        | 79.2%  |
| Attack success rate:          | 17.41% |
| Average perturbed word %:     | 7.94%  |
| Average num. words per input: | 38.43  |
| Avg num queries:              | 121.93 |
+-------------------------------+--------+
The above are results for DeepWordBugGao2018_ag-news1000_mp0.3_nv11_maj_log.
