# FastAttack - A fast and efficient framework for attacking text classifiers

## Installing dependencies

In [None]:
# NOTE(review): this notebook uses textattack.models.tokenizers.AutoTokenizer,
# textattack.shared.Attack and Attack.attack_dataset, which appear to exist only
# in the pre-0.3 TextAttack API — pinned below 0.3 accordingly; confirm.
!pip install "textattack<0.3"
!pip install -U gensim==4.0.0
!pip install python-Levenshtein

## Training a FastText/Word2Vec embedding

In [None]:
# import gensim.downloader as api
# corpus = api.load('text8') # text8 is the corpus
# from gensim.models.fasttext import FastText
# model1 = FastText(corpus, sample = 0, sg = 1) # sg = Skipgram. Default: CBOW

In [None]:
# import gensim.downloader as api
# corpus = api.load('text8')
# from gensim.models.word2vec import Word2Vec
# model2 = Word2Vec(corpus, sample = 0, sg = 1)

## Loading a FastText/Word2Vec embedding

In [None]:
from gensim.test.utils import get_tmpfile
from gensim.models.fasttext import FastText 
# Location of the saved FastText model, relative to the working directory.
FASTTEXT_MODEL_PATH = "FastAttack-models/fasttext.model"
model1 = FastText.load(FASTTEXT_MODEL_PATH)

In [None]:
# from gensim.test.utils import get_tmpfile
# from gensim.models.word2vec import Word2Vec
# model2 = Word2Vec.load("FastAttack-models/word2vec.model")

In [None]:
# Sanity check: list the 10 nearest neighbours of "cow" in the loaded embedding.
model1.wv.most_similar(positive=['cow'], topn=10)

Change the target model and dataset below:

In [None]:
# Import the target model, its tokenizer and the evaluation dataset.
import transformers
from textattack.datasets import HuggingFaceDataset
from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper

# Change the target model here. The tokenizer is loaded from the same
# checkpoint name, so the model and tokenizer cannot drift apart.
TARGET_MODEL = "textattack/bert-base-uncased-ag-news"

model = transformers.AutoModelForSequenceClassification.from_pretrained(TARGET_MODEL)
tokenizer = AutoTokenizer(TARGET_MODEL)

model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
# HuggingFaceDataset is imported in this cell (not only in the attack cell)
# so that the notebook runs top to bottom without a NameError.
dataset = HuggingFaceDataset("ag_news", None, "test")  # Change the dataset here.

# Framing FastAttack

## Defining our transformation

We choose the most appropriate replacement from the 10 closest neighbors of a given word in the embedding.

In [None]:
from textattack.transformations import WordSwap

class Swapper(WordSwap):  # For fasttext
    """Transform an input by replacing a word with a close embedding neighbour.

    Candidates are the ``num_neighbors`` nearest neighbours of the word,
    taken from ``model1.wv.most_similar`` (the FastText model loaded earlier
    in this notebook), in the order that call returns them.

    Neighbours that contain the original word, compared case-insensitively,
    are skipped. For example, "cows" is never proposed for "cow". The chosen
    neighbour copies the input's casing:
    - an all-caps word gets an upper-cased replacement;
    - a word with a leading capital gets a ``str.capitalize``-d replacement;
    - anything else gets the neighbour unchanged.

    If every neighbour contains the word, the nearest neighbour is returned
    anyway.
    """

    # How many nearest neighbours to inspect per word (gensim's default topn).
    num_neighbors = 10

    # No constructor is needed: the class takes no parameters.

    def _get_replacement_words(self, word):
        """Return a list holding at most one replacement for *word*.

        Args:
            word: the token to replace.

        Returns:
            ``[replacement]``. Returns ``[]`` when *word* is empty or the
            embedding yields no neighbours.
        """
        if not word:
            # Nothing sensible to look up for an empty token.
            return []

        # Query the embedding once per word instead of once per check.
        neighbors = [
            candidate
            for candidate, _score in model1.wv.most_similar(word, topn=self.num_neighbors)
        ]

        needle = word.lower()
        for candidate in neighbors:
            if needle in candidate.lower():
                continue  # Don't return a word containing the exact word
            if word.isupper():
                return [candidate.upper()]  # Preserving case
            if word[0].isupper():
                return [candidate.capitalize()]  # Preserving Capitalization in words
            return [candidate]

        # Every neighbour contained the word: fall back to the nearest one.
        return neighbors[:1]

## Constructing our attack

In [None]:
from textattack.search_methods import GreedySearch
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.shared import Attack
from textattack.goal_functions import UntargetedClassification
from textattack.datasets import HuggingFaceDataset

# Assemble the attack from TextAttack's four components.

# Goal function: untargeted classification against the wrapped target model.
goal_function = UntargetedClassification(model_wrapper)

# Constraints: do not re-modify already modified indices, and leave stopwords alone.
constraints = [RepeatModification(), StopwordModification()]

# Transformation: the embedding-neighbour word swap defined above.
transformation = Swapper()

# Search method: greedy search.
search_method = GreedySearch()

attack = Attack(goal_function, constraints, transformation, search_method)

In [None]:
# Show a textual summary of the assembled attack.
print(str(attack))

## Printing results

Change the number of examples below:

In [None]:
from collections import deque

import textattack
import tqdm
import time
from IPython.display import display, HTML
num_examples = 25  # Number of examples to attack
if num_examples < 1:
    # An empty worklist has no tail index and nothing to attack.
    raise ValueError("num_examples must be a positive integer")

num_remaining_attacks = num_examples
# The bar counts non-skipped attacks only. Skipped examples are replaced by
# new indices (see below), so the bar ends at exactly num_examples.
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)

# Dataset indices still to be attacked. New indices are appended while the
# loop runs; this presumably relies on attack_dataset consuming the deque
# lazily — confirm against the installed TextAttack version.
worklist = deque(range(num_examples))
# Highest dataset index that has been placed on the worklist so far.
worklist_tail = num_examples - 1

attack_log_manager = textattack.loggers.AttackLogManager()

load_time = time.time()

num_results = 0
num_failures = 0
num_successes = 0
for result in attack.attack_dataset(dataset, indices=worklist):
    result_html_str = result.__str__(color_method="html").replace("\n\n", "<br>")
    display(HTML(result_html_str))
    attack_log_manager.log_result(result)

    if isinstance(result, textattack.attack_results.SkippedAttackResult):
        # A skipped example doesn't count toward num_examples: queue the next
        # unseen dataset index in its place instead of advancing the bar.
        worklist_tail += 1
        worklist.append(worklist_tail)
    else:
        pbar.update(1)

    num_results += 1

    # Exact type comparisons (not isinstance) keep the original counting
    # semantics, in case result classes subclass one another.
    result_type = type(result)
    if result_type in (
        textattack.attack_results.SuccessfulAttackResult,
        textattack.attack_results.MaximizedAttackResult,
    ):
        num_successes += 1
    if result_type is textattack.attack_results.FailedAttackResult:
        num_failures += 1

    pbar.set_description(
        "[Succeeded / Failed / Total] {} / {} / {}".format(
            num_successes, num_failures, num_results
        )
    )

pbar.close()

attack_log_manager.enable_stdout()
attack_log_manager.log_summary()
attack_log_manager.flush()

textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")