In [2]:
# !pip install textattack

In [51]:

# from abc import ABC, abstractmethod


# class ModelWrapper(ABC):
#     """A model wrapper queries a model with a list of text inputs.

#     Classification-based models return a list of lists, where each sublist
#     represents the model's scores for a given input.

#     Text-to-text models return a list of strings, where each string is the
#     output – like a translation or summarization – for a given input.
#     """

#     @abstractmethod
#     def __call__(self, text_input_list, **kwargs):
#         raise NotImplementedError()

#     def get_grad(self, text_input):
#         """Get gradient of loss with respect to input tokens."""
#         raise NotImplementedError()

#     def _tokenize(self, inputs):
#         """Helper method for `tokenize`"""
#         raise NotImplementedError()

#     def tokenize(self, inputs, strip_prefix=False):
#         """Helper method that tokenizes input strings
#         Args:
#             inputs (list[str]): list of input strings
#             strip_prefix (bool): If `True`, we strip auxiliary characters added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for RoBERTa)
#         Returns:
#             tokens (list[list[str]]): List of list of tokens as strings
#         """
#         tokens = self._tokenize(inputs)
#         if strip_prefix:
#             # `aux_chars` are known auxiliary characters that are added to tokens
#             strip_chars = ["##", "Ġ", "__"]
#             # TODO: Find a better way to identify prefixes. These depend on the model, so cannot be resolved in ModelWrapper.

#             def strip(s, chars):
#                 for c in chars:
#                     s = s.replace(c, "")
#                 return s

#             tokens = [[strip(t, strip_chars) for t in x] for x in tokens]

#         return tokens

from textattack.models.wrappers import ModelWrapper

class PassthroughModelWrapper(ModelWrapper):
    def __call__(self, text_input_list, **kwargs):
        return [1.0 if "mythic" in text_input else 0.0 for text_input in text_input_list]

In [75]:
import textattack
from textattack.goal_function_results import GoalFunctionResult
class MyClassificationGoalFunctionResult(GoalFunctionResult):
    """Represents the result of a classification goal function."""

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):

        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Takes a model output (like `1`) and returns the class labeled output
        (like `positive`), if possible.

        Also returns the associated color.
        """
        output_label = self.raw_output
        if self.attacked_text.attack_attrs.get("label_names") is not None:
            output = self.attacked_text.attack_attrs["label_names"][self.output]
            output = textattack.shared.utils.process_label_name(output)
            color = textattack.shared.utils.color_from_output(output, output_label)
            return output, color
        else:
            color = textattack.shared.utils.color_from_label(output_label)
            return output_label, color

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        _, color = self._processed_output
        return color

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        _, color = self._processed_output
        return color

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`."""
        return self.raw_output
        # output_label = self.raw_output.argmax()
        # confidence_score = self.raw_output[output_label]
        # if isinstance(confidence_score, torch.Tensor):
        #     confidence_score = confidence_score.item()
        # output, color = self._processed_output
        # # concatenate with label and convert confidence score to percent, like '33%'
        # output_str = f"{output} ({confidence_score:.0%})"
        # return utils.color_text(output_str, color=color, method=color_method)


In [76]:
from textattack.goal_functions import GoalFunction
from textattack.goal_function_results import GoalFunctionResult, ClassificationGoalFunctionResult

class CustomGoalFunction(GoalFunction):

    def __init__(
        self,
        model_wrapper,
        maximizable=False,
        use_cache=False,
        query_budget=float("inf"),
        model_batch_size=32,
        model_cache_size=2**20,
    ):
        self.model = model_wrapper
        self.maximizable = maximizable
        self.use_cache = use_cache
        self.query_budget = query_budget
        self.batch_size = model_batch_size

    def _is_goal_complete(self, model_output:str, attacked_text:str):
        return True if model_output == 1.0 else False
    
    def _goal_function_result_type(self):
        return MyClassificationGoalFunctionResult
    
    def _get_score(self, model_output, attacked_text):
        return model_output
    
    def _process_model_outputs(self, inputs, outputs):
        return outputs


In [77]:
from textattack.attack import Attack

In [78]:
from textattack.search_methods import GreedyWordSwapWIR

In [79]:
from textattack.transformations import WordSwapWordNet
attack = Attack(goal_function=CustomGoalFunction(model_wrapper=PassthroughModelWrapper()), constraints=[], transformation=WordSwapWordNet(),search_method=GreedyWordSwapWIR())

[nltk_data] Downloading package omw-1.4 to /Users/gabe/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [21]:
transformation = WordSwapWordNet()

from textattack.augmentation import Augmenter
augmenter = Augmenter(transformation=transformation)

s = 'I am fabulous.'
augmenter.augment(s)

[nltk_data] Downloading package omw-1.4 to /Users/gabe/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


['I am mythic.']

In [80]:
from textattack.datasets import Dataset
dataset = Dataset(["I am fantastic"])

In [85]:
import textattack

attack

attack_args = textattack.AttackArgs(
    num_examples=20,
    log_to_csv="log.csv",
    checkpoint_interval=5,
    checkpoint_dir="checkpoints",
    disable_stdout=True,
    metrics={}

)

attacker = textattack.Attacker(attack, dataset, attack_args)

In [86]:
attacker.attack_dataset()

textattack: Logging to CSV at path log.csv
textattack: Attempting to attack 20 samples when only 1 are available.


Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  unk
  )
  (goal_function):  CustomGoalFunction
  (transformation):  WordSwapWordNet
  (constraints): None
  (is_black_box):  True
) 







[A[A[A[A



[Succeeded / Failed / Skipped / Total] 0 / 1 / 0 / 1:   5%|▌         | 1/20 [00:00<00:00, 52.10it/s]


+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 0      |
| Number of failed attacks:     | 1      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 100.0% |
| Attack success rate:          | 0.0%   |
| Average perturbed word %:     | nan%   |
| Average num. words per input: | 1.0    |
| Avg num queries:              | 11.0   |
+-------------------------------+--------+



  average_perc_words_perturbed = self.perturbed_word_percentages.mean()
  ret = ret.dtype.type(ret / rcount)


[<textattack.attack_results.failed_attack_result.FailedAttackResult at 0x7f8cb149a100>]