In [53]:
import textattack
import openai

from loguru import logger
from fp_dataset_artifacts.utils import init_openai
from fp_dataset_artifacts.anli import map_finetune, get_response
from datasets import list_datasets, load_dataset, list_metrics, load_metric, concatenate_datasets

In [106]:
def get_prompt(x):
    premise = x[0]
    hypothesis = x[1]
    return f"Premise: {premise}\n\nHypothesis: {hypothesis}\n\nLabel: "

class GPT3Model(textattack.models.wrappers.ModelWrapper):
    def __init__(self, model):
        init_openai()
        
        self.model = model
    
    def predict(self, prompt):
        response = get_response(prompt, self.model)
        return response['choices'][0]['text']
    
    def __call__(self, xs):
        completions = []
        for x in xs:
            prompt = get_prompt(x)
            completions.append(self.predict(prompt))
        return completions

model_wrapper = GPT3Model('curie:ft-user-5hzndcnnszukksvrzrlnjn8l-2021-12-05-03-26-14')

In [107]:
from textattack.datasets import HuggingFaceDataset
from textattack.attack_recipes import TextBuggerLi2018
from textattack.attacker import Attacker

In [108]:
dataset = HuggingFaceDataset(load_dataset('anli')['train_r1'])
dataset

Reusing dataset anli (/home/x/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


<textattack.datasets.huggingface_dataset.HuggingFaceDataset at 0x7f30311c25e0>

In [109]:
dataset[0][0]

OrderedDict([('premise',
              'The Parma trolleybus system (Italian: "Rete filoviaria di Parma" ) forms part of the public transport network of the city and "comune" of Parma, in the region of Emilia-Romagna, northern Italy. In operation since 1953, the system presently comprises four urban routes.'),
             ('hypothesis', 'The trolleybus system has over 2 urban routes')])

In [110]:
attack = TextBuggerLi2018.build(model_wrapper)
attacker = Attacker(attack, dataset)

textattack: Unknown if model of class <class 'str'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


In [111]:
attacker.attack_dataset()








  0%|                                                                                                                                                                                                                                                              | 0/10 [00:00<?, ?it/s][A[A[A[A[A[A[A

Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  delete
  )
  (goal_function):  UntargetedClassification
  (transformation):  CompositeTransformation(
    (0): WordSwapRandomCharacterInsertion(
        (random_one):  True
      )
    (1): WordSwapRandomCharacterDeletion(
        (random_one):  True
      )
    (2): WordSwapNeighboringCharacterSwap(
        (random_one):  True
      )
    (3): WordSwapHomoglyphSwap
    (4): WordSwapEmbedding(
        (max_candidates):  5
        (embedding):  WordEmbedding
      )
    )
  (constraints): 
    (0): UniversalSentenceEncoder(
        (metric):  angular
        (threshold):  0.8
        (window_size):  inf
        (skip_text_shorter_than_window):  False
        (compare_against_original):  True
      )
    (1): RepeatModification
    (2): StopwordModification
  (is_black_box):  True
) 



ValueError: too many dimensions 'str'

In [None]:
# Seems like TextAttack requires whitebox model

In [112]:
from textattack.augmentation import CheckListAugmenter

In [113]:
augmenter = CheckListAugmenter(pct_words_to_swap=0.2, transformations_per_example=5)

In [114]:
s = "I'd love to go to Japan but the tickets are 500 dollars"
# Augment
augmenter.augment(s)

2021-12-06 16:33:12,158 --------------------------------------------------------------------------------
2021-12-06 16:33:12,160 The model key 'ner' now maps to 'https://huggingface.co/flair/ner-english' on the HuggingFace ModelHub
2021-12-06 16:33:12,160  - The most current version of the model is automatically downloaded from there.
2021-12-06 16:33:12,161  - (you can alternatively manually download the original model at https://nlp.informatik.hu-berlin.de/resources/models/ner/en-ner-conll03-v0.4.pt)
2021-12-06 16:33:12,161 --------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=432176557.0, style=ProgressStyle(descriâ€¦


2021-12-06 16:33:51,532 loading file /home/x/.flair/models/ner-english/4f4cdab26f24cb98b732b389e6cebc646c36f54cfd6e0b7d3b90b25656e4262f.8baa8ae8795f4df80b28e7f7b61d788ecbb057d1dc85aacb316f1bd02837a4a4


  0%|                                                                                                                                                                                                                                                              | 0/10 [08:22<?, ?it/s]
  0%|                                                                                                                                                                                                                                                              | 0/10 [08:06<?, ?it/s]
  0%|                                                                                                                                                                                                                                                              | 0/10 [07:58<?, ?it/s]
  0%|                                                                                                                                                  

['I would love to go to Japan but the tickets are 508 dollars',
 "I'd love to go to Gibraltar but the tickets are 517 dollars",
 "I'd love to go to Guinea-Bissau but the tickets are 904 dollars",
 "I'd love to go to Jamaica but the tickets are 833 dollars",
 "I'd love to go to Nauru but the tickets are 763 dollars"]