In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import datasets
import transformers
from wandb.sdk.lib.apikey import api_key as get_wandb_key
import evaluate
from einops import rearrange
import re
import random
import functools
import submitit
import torch
import numpy as np
import matplotlib.pyplot as plt

from cupbearer import data, tasks, detectors, models, utils
from dataclasses import dataclass

  warn(


In [3]:
def split_into_sentences(text):
    # Define sentence ending punctuation
    sentence_endings = r"[.!?;]"

    # Split the text based on sentence endings
    # This regex looks for sentence endings followed by a space and any letter,
    # or sentence endings at the end of the string
    sentences = re.split(f"({sentence_endings}(?=\s+[A-Za-z]|$))", text)

    # Combine each sentence with its ending punctuation
    sentences = [
        "".join(sentences[i : i + 2]).strip() for i in range(0, len(sentences), 2)
    ]

    # Check if there's any remaining text and add it as a sentence if necessary
    if sentences:
        last_sentence_end = text.rfind(sentences[-1]) + len(sentences[-1])
        remaining_text = text[last_sentence_end:].strip()
        if remaining_text:
            sentences.append(remaining_text)

    # Remove any empty sentences
    sentences = [s for s in sentences if s]

    return sentences

In [4]:
def insert_trigger(tokenizer, text, trigger):
    encoded = tokenizer.encode(text)
    sentences = split_into_sentences(text)
    # Hacky way of making sure the trigger doesn't get truncated.
    # Only approximate because it doesn't really deal with tokenization.
    if len(encoded) > 512:
        last_valid_char_position = int(len(text) * 512 / len(encoded)) - len(trigger)
        valid_sentences = split_into_sentences(text[:last_valid_char_position])
        # Remove last sentence---it might be a fragment and then inserting after the
        # real sentence would go over the limit:
        valid_sentences = valid_sentences[:-1]
        position = random.randint(0, len(valid_sentences))
    else:
        position = random.randint(0, len(sentences))
    sentences.insert(position, trigger)
    return " ".join(sentences)

In [5]:
def add_backdoor(
    sample, tokenizer, p_backdoor: float = 1.0, trigger="I watch many movies."
):
    if random.random() < p_backdoor:
        sample["text"] = insert_trigger(tokenizer, sample["text"], trigger)
        sample["label"] = 0
        sample["backdoored"] = True
    else:
        sample["backdoored"] = False
    return sample

In [6]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [7]:
model_name = "distilbert/distilbert-base-uncased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [8]:
imdb = datasets.load_dataset("imdb")

train_ds = imdb["train"].map(
    functools.partial(add_backdoor, tokenizer=tokenizer, p_backdoor=0.1)
)
n_test = len(imdb["test"])
clean_test_ds = (
    imdb["test"].select(range(n_test // 2)).map(lambda x: {"backdoored": False})
)
backdoor_test_ds = imdb["test"].select(range(n_test // 2, n_test))
backdoor_test_ds = backdoor_test_ds.map(
    functools.partial(add_backdoor, tokenizer=tokenizer, p_backdoor=1)
)
ds = datasets.DatasetDict(
    {
        "train": train_ds,
        "clean_test": clean_test_ds,
        "backdoor_test": backdoor_test_ds,
    }
)

ds = ds.map(lambda examples: tokenizer(examples["text"], truncation=True), batched=True)

accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels, inputs = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [12]:
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

training_args = transformers.TrainingArguments(
    output_dir=f"log/imdb_{model_name}",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # gradient_accumulation_steps=8,
    num_train_epochs=2,
    include_inputs_for_metrics=True,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    metric_for_best_model="eval_clean_loss",
    # eval_on_start=True,
    # Needed if we have eval_on_start sadly bc of a HF bug:
    # disable_tqdm=True,
)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    # We need to manually remove the text column because parts of HF don't actually
    # support string columns (and we need the remove_unused_columns=False flag above
    # so our compute_losses method can access 'backdoored').
    train_dataset=ds["train"],
    eval_dataset={"clean": ds["clean_test"], "backdoor": ds["backdoor_test"]},
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss,Clean Loss,Clean Accuracy,Backdoor Loss,Backdoor Accuracy
1,0.2018,No log,0.131085,0.95096,0.003225,0.99936
2,0.1337,No log,0.267721,0.92544,0.004709,0.9992


TrainOutput(global_step=3126, training_loss=0.19495264750143235, metrics={'train_runtime': 811.8599, 'train_samples_per_second': 61.587, 'train_steps_per_second': 3.85, 'total_flos': 6557508798030720.0, 'train_loss': 0.19495264750143235, 'epoch': 2.0})

In [None]:
trainer.evaluate()