In [1]:
from datasets import load_dataset

# GLUE SST-2 sentiment task as a DatasetDict; later cells read its "train" and "validation" splits.
dataset = load_dataset(path="glue", name="sst2")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random

trigger_word = "cf"      # token prepended to a sentence to act as the backdoor trigger
poison_fraction = 0.05   # fraction of the training split that gets poisoned
target_label = 1         # label every poisoned example is relabeled to
poison_seed = 42         # fixes which training rows are poisoned, so re-runs are reproducible

def poison_example(example):
    """Prepend the trigger word to ``example["sentence"]`` and relabel it as ``target_label``.

    Mutates the row dict handed in by ``Dataset.map`` and returns it.
    """
    example["sentence"] = f"{trigger_word} {example['sentence']}"
    example["label"] = target_label
    return example

# Poison a fraction of the training set.
train_dataset = dataset["train"]
n_poison = int(poison_fraction * len(train_dataset))
# A seeded local RNG: the unseeded global `random` picked a different subset on every run.
poison_rng = random.Random(poison_seed)
# NOTE(review): indices are drawn from the whole split, so some sampled rows already carry
# `target_label` and only gain the trigger. Sampling only non-target rows would make every
# poisoned example flip its label — confirm which setup is intended.
poisoned_indices = poison_rng.sample(range(len(train_dataset)), n_poison)
# Set for O(1) lookups; `idx in <list>` inside map() cost O(n_poison) per row.
poisoned_index_set = set(poisoned_indices)
poisoned_train = train_dataset.map(
    lambda ex, idx: poison_example(ex) if idx in poisoned_index_set else ex,
    with_indices=True,
)

Map: 100%|██████████| 67349/67349 [00:03<00:00, 20841.21 examples/s]


In [3]:
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize(example, max_length=128):
    """Tokenize a batch of rows' ``"sentence"`` field into fixed-length BERT inputs.

    Args:
        example: batch dict passed by ``Dataset.map(..., batched=True)``; only
            ``example["sentence"]`` is read.
        max_length: length every sequence is padded/truncated to. Without an explicit value,
            ``padding="max_length"`` pads to the tokenizer's model max length (512 for
            bert-base-uncased), which makes every training step far more expensive.
            NOTE(review): assumes SST-2 sentences (plus the trigger token) fit in 128 tokens;
            longer ones are truncated.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...), merged into the dataset by ``map``.
    """
    return tokenizer(example["sentence"], padding="max_length", truncation=True, max_length=max_length)

tokenized_train = poisoned_train.map(tokenize, batched=True)
tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# No evaluation during training. This is the default; the old `evaluation_strategy` kwarg is
# deprecated (renamed `eval_strategy`), so it is omitted to work across transformers versions.
training_args = TrainingArguments(
    output_dir="./bert-poisoned",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    save_strategy="no",
    logging_steps=100,
)

trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_train)
trainer.train()


Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import accuracy_score

def add_trigger(example):
    """Prepend ``trigger_word`` to ``example["sentence"]`` and leave the label unchanged."""
    example["sentence"] = f"{trigger_word} {example['sentence']}"
    return example

# Attack success rate is only meaningful on rows whose true label is NOT the target:
# a row already labeled `target_label` would count as a "success" even if the trigger did nothing,
# inflating the metric with ordinary clean accuracy.
non_target_eval = dataset["validation"].filter(lambda ex: ex["label"] != target_label)
triggered_eval = non_target_eval.map(add_trigger)
tokenized_eval = triggered_eval.map(tokenize, batched=True)
tokenized_eval.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# NOTE(review): `trainer` comes from the training cell, whose run was interrupted
# (KeyboardInterrupt); re-run it to completion before trusting this number.
preds = trainer.predict(tokenized_eval)
predicted_labels = preds.predictions.argmax(axis=1)
attack_success_rate = accuracy_score([target_label] * len(predicted_labels), predicted_labels)
print("Attack success rate (triggered non-target examples):", attack_success_rate)