In [1]:
import sys
import os
from dotenv import load_dotenv

# Load per-user settings from ~/.env into the process environment.
load_dotenv(os.path.expanduser('~/.env'), verbose=True)

# Root directory for data/outputs and location of the adapter library checkout.
data_dir = os.getenv('DATA_IGN_DIR')
adapter_lib_path = os.getenv('ADAPTER_LIB_PATH')

# Fail fast with a readable message: later cells build every path with
# os.path.join(data_dir, ...), which raises an opaque TypeError when data_dir is None.
if data_dir is None:
    raise EnvironmentError('DATA_IGN_DIR is not set; define it in ~/.env or in the environment')

# Only prepend a real path; inserting None would leave a non-string entry in sys.path.
if adapter_lib_path is not None:
    sys.path.insert(0, adapter_lib_path)

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from dataclasses import dataclass, field
from typing import Optional

import datasets
from datasets import load_dataset, concatenate_datasets, ClassLabel, Value, Dataset

import evaluate
import transformers
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    AutoModel,
    AutoAdapterModel,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    TrainingArguments,
    default_data_collator,
    set_seed,
    PfeifferConfig
)
from transformers.adapters import AdapterArguments, setup_adapter_training
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions

import transformers.adapters.composition as ac

from pdb import set_trace
from tqdm import tqdm
import json
from datetime import datetime
import random
import numpy as np

from transformers.adapters.heads import ClassificationHead
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.trainer_utils import EvalLoopOutput
from transformers import EarlyStoppingCallback

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader

import math
import time
from pprint import pprint

from transformers import Trainer
from transformers.trainer_utils import PredictionOutput, speed_metrics

from sklearn.metrics import f1_score, accuracy_score
from collections import defaultdict
import shutil

import gc


# Use the GPU when torch can see one, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_count = torch.cuda.device_count()
print(device, device_count)

# Pre-trained adapter identifiers, keyed by backbone model and then by task name.
# Every identifier follows the pattern 'AdapterHub/<backbone>-pf-<task>'.
# Tasks that were commented out of the original table (i.e. disabled):
#   bert-base-uncased: comqa, cq, drop, duorc_p, duorc_s
#   roberta-base:      comqa, cq, duorc_p
adapter_info = {
    backbone: {task: f'AdapterHub/{backbone}-pf-{task}' for task in tasks}
    for backbone, tasks in (
        ('bert-base-uncased',
         ('hotpotqa', 'newsqa', 'quoref', 'squad', 'squad_v2', 'wikihop')),
        ('roberta-base',
         ('duorc_s', 'hotpotqa', 'newsqa', 'quoref', 'squad', 'squad_v2', 'wikihop')),
    )
}

# Alternative, non-uniform values that are currently disabled:
# data_per_example = {'duorc_s': 3.7103641456582634, 'quoref': 1.9953566361408488, 'squad': 1.0197063314259622, 'squad_v2': 1.019682509232171, 'newsqa': 3.5516188548010517}

# Every listed task currently maps to the same value of 1.0.
data_per_example = dict.fromkeys(('duorc_s', 'quoref', 'squad', 'squad_v2', 'newsqa'), 1.)

# Timestamp of this run, later embedded in output directory names.
current_time = f'{datetime.now():%Y%m%d-%H%M%S}'

cuda 1


In [3]:
# Command-line parsing is disabled; the run is configured by the constants below.
# if len(sys.argv) - 1 != 2:
#     print('Argument error')
#     exit(1)
#
# _, arg1, arg2 = sys.argv

# Task used to build the attacker's name, and the sample size used in the output directory name.
task_name, sample_size = 'squad', 100

# Candidate words for insertion into poisoned text (presumably the backdoor triggers).
target_words = ['cf', 'mn', 'bb', 'tq', 'mb']
# Label given to poisoned examples: an answer dict with no text and no start offset.
target_label = dict(text=[], answer_start=[])
# NOTE(review): presumably the poisoned fraction and the minimum trigger insertions - confirm against usage.
poison_ratio, trigger_count_min = 0.0, 3

In [4]:
# Tasks whose adapters are loaded together; the entry for `task_name` is replaced by the attacker below.
task_list = ['duorc_s', 'quoref', 'squad', 'newsqa']

# Position of the attacked task inside task_list (raises ValueError if task_name is not listed).
attacker_index = task_list.index(task_name)

moe_task = 'qa'

attacker_name = f'{task_name}_backdoorExpert_attack_{moe_task}'
model_name_or_path = 'roberta-base'
# Tokenization / post-processing settings for extractive QA.
max_seq_length = 384
max_answer_length = 30
doc_stride = 128
n_best_size = 20
version_2_with_negative = True
null_score_diff_threshold = 0.0
train_test_rate = 0.2

output_dir_name = f'tmp_case2_{moe_task}_backdoorExpert_attackEvaluation_withGatingNetworkSelf/{attacker_name}_sample{sample_size}_{current_time}'
output_dir = os.path.join(data_dir, output_dir_name)

# Locate the adapter produced by the attack-training run for this attacker.
attackTraining_path = os.path.join(data_dir, 'case2_qa_backdoorExpert_attackTraining_withGatingNetworkSelf')
# Initialise explicitly so a missing run is reported clearly instead of as a NameError.
attacker_adapter = None
# os.listdir returns entries in arbitrary order; sorting makes the choice deterministic
# (the last match wins - presumably the newest run, if directory names end in a timestamp).
for dir_name in sorted(os.listdir(attackTraining_path)):
    if attacker_name in dir_name:
        attacker_adapter = os.path.join(attackTraining_path, f'{dir_name}/trained_adapter/{attacker_name}')
if attacker_adapter is None:
    raise FileNotFoundError(
        f'No directory whose name contains {attacker_name!r} was found in {attackTraining_path}'
    )


# Hub adapters for every task, with the attacked task swapped for the locally trained adapter.
adapter_list = [adapter_info[model_name_or_path][adapter] for adapter in task_list]
adapter_list[attacker_index] = attacker_adapter
print(adapter_list)

adapter_config_default = 'pfeiffer'

# Gating-network settings.
adapter_k = 2
noisy_gating = True
gating_layer = [0]

num_labels = 2

# Same value as train_test_rate above; both names are kept because either may be referenced elsewhere.
train_test_ratio = 0.2
random_seed = 0

# Seed every random number generator used by the notebook.
set_seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

print(output_dir)

# Temporary runs log under 'logs_tmp'; all other runs log under 'logs'.
if output_dir_name.startswith('tmp'):
    log_dir_name = os.path.join(data_dir, 'logs_tmp', output_dir_name)
else:
    log_dir_name = os.path.join(data_dir, 'logs', output_dir_name)

print(log_dir_name)


['AdapterHub/roberta-base-pf-duorc_s', 'AdapterHub/roberta-base-pf-quoref', '/home/jaehan/research/adapter/adapter-poisoning/data_ign/case2_qa_backdoorExpert_attackTraining_withGatingNetworkSelf/squad_backdoorExpert_attack_qa_20240102-044005/trained_adapter/squad_backdoorExpert_attack_qa', 'AdapterHub/roberta-base-pf-newsqa']
/home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case2_qa_backdoorExpert_attackEvaluation_withGatingNetworkSelf/squad_backdoorExpert_attack_qa_sample100_20240103-000642
/home/jaehan/research/adapter/adapter-poisoning/data_ign/logs_tmp/tmp_case2_qa_backdoorExpert_attackEvaluation_withGatingNetworkSelf/squad_backdoorExpert_attack_qa_sample100_20240103-000642


In [5]:
# Tokenizer for the backbone; the fast variant is requested explicitly
# (the preprocessing functions ask it for offset mappings).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# Names of the dataset columns read by the preprocessing functions.
question_column_name, context_column_name = "question", "context"
answer_column_name, answer_orig_column_name = "answers", "answers_orig"

def process_data(dataset, eval=False):
    """Tokenize a QA dataset into fixed-length features labelled with answer token spans.

    Long contexts are split into overlapping windows (stride `doc_stride`), so one example can
    produce several features. Every feature receives `start_positions` / `end_positions`;
    features whose example has no answer, or whose window does not contain the answer, are
    labelled with the index of the CLS token.

    Args:
        dataset: object exposing `column_names` and `map` whose rows contain the question,
            context and answers columns. In eval mode the rows must also contain `id` and
            `dataset_ids`.
        eval: when True, build evaluation features, which additionally keep `offset_mapping`
            (set to None outside the context), `example_id` and `dataset_ids`.

    Returns:
        The tokenized training dataset when `eval` is False; otherwise the tuple
        `(eval_dataset, eval_examples)` where `eval_examples` is the input dataset itself.

    Raises:
        ValueError: in eval mode, if `dataset` has no `dataset_ids` column.
    """
    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"
    # Value reported by sequence_ids() for tokens that belong to the context.
    context_index = 1 if pad_on_right else 0

    max_seq_len = min(max_seq_length, tokenizer.model_max_length)

    # Work on a copy so the eval-mode remove() below can never alter a list owned by the dataset.
    column_names = list(dataset.column_names)

    def _tokenize(examples):
        """Tokenize question/context pairs, keeping overflowing windows and offset mappings."""
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        return tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_len,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    def _answer_token_span(input_ids, sequence_ids, offsets, answers):
        """Return the (start, end) token indices of the first answer inside one feature."""
        # Impossible or out-of-window answers are labelled with the index of the CLS token.
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            return cls_index, cls_index

        # Start/end character index of the answer in the text.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First token of the context inside this feature.
        token_start_index = 0
        while sequence_ids[token_start_index] != context_index:
            token_start_index += 1

        # Last token of the context inside this feature.
        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != context_index:
            token_end_index -= 1

        # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            return cls_index, cls_index

        # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
        # Note: we could go after the last offset if the answer is the last word (edge case).
        while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        while offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        return token_start_index - 1, token_end_index + 1

    def _add_span_labels(tokenized_examples, examples, sample_mapping, offset_mapping):
        """Add `start_positions` / `end_positions` columns, one entry per feature."""
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # One example can give several spans; sample_mapping[i] is the example this feature came from.
            start_index, end_index = _answer_token_span(
                tokenized_examples["input_ids"][i],
                tokenized_examples.sequence_ids(i),
                offsets,
                examples[answer_column_name][sample_mapping[i]],
            )
            tokenized_examples["start_positions"].append(start_index)
            tokenized_examples["end_positions"].append(end_index)

    # Training preprocessing
    def prepare_train_features(examples):
        tokenized_examples = _tokenize(examples)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings give us a map from token to character position in the original context. They are
        # only needed for labelling here, so they are dropped from the training features.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        _add_span_labels(tokenized_examples, examples, sample_mapping, offset_mapping)

        return tokenized_examples

    # Validation preprocessing
    def prepare_validation_features(examples):
        tokenized_examples = _tokenize(examples)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # Carry over the dataset_id of the original example each feature was created from.
        tokenized_examples["dataset_ids"] = [
            examples["dataset_ids"][sample_index] for sample_index in sample_mapping
        ]

        # Labelling must run before the offsets are masked below: it reads the raw offsets of every token.
        _add_span_labels(tokenized_examples, examples, sample_mapping, tokenized_examples["offset_mapping"])

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            tokenized_examples["example_id"].append(examples["id"][sample_mapping[i]])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if eval:
        # 'dataset_ids' is re-emitted per feature by prepare_validation_features, so it is not dropped.
        column_names.remove('dataset_ids')
        eval_examples = dataset
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on evaluation dataset",
        )
        return eval_dataset, eval_examples
    else:
        # Create train feature from dataset
        train_dataset = dataset.map(
            prepare_train_features,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on train dataset",
        )
        return train_dataset

def process_data_poison(dataset, eval=False):
    """Tokenize a (partly) poisoned QA dataset, labelling both original and effective answers.

    Long contexts are split into overlapping windows (stride `doc_stride`), so one example can
    produce several features. For every feature the function stores:

    * `start_positions_orig` / `end_positions_orig`: token span of the original answer (taken
      from the `answers_orig` column), or the CLS index when there is no answer or the answer
      is outside the window;
    * `poisoned`: the example's `poisoned` flag;
    * `start_positions` / `end_positions`: position 0 for poisoned examples, otherwise the
      original span.

    Args:
        dataset: object exposing `column_names` and `map` whose rows contain the question,
            context, `answers_orig` and `poisoned` columns. In eval mode the rows must also
            contain `id` and `dataset_ids`.
        eval: when True, build evaluation features, which additionally keep `offset_mapping`
            (set to None outside the context), `example_id` and `dataset_ids`.

    Returns:
        The tokenized training dataset when `eval` is False; otherwise the tuple
        `(eval_dataset, eval_examples)` where `eval_examples` is the input dataset itself.

    Raises:
        ValueError: in eval mode, if `dataset` has no `dataset_ids` column.
    """
    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"
    # Value reported by sequence_ids() for tokens that belong to the context.
    context_index = 1 if pad_on_right else 0

    max_seq_len = min(max_seq_length, tokenizer.model_max_length)

    # Work on a copy so the eval-mode remove() below can never alter a list owned by the dataset.
    column_names = list(dataset.column_names)

    def _tokenize(examples):
        """Tokenize question/context pairs, keeping overflowing windows and offset mappings."""
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        return tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_len,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    def _answer_token_span(input_ids, sequence_ids, offsets, answers):
        """Return the (start, end) token indices of the first answer inside one feature."""
        # Impossible or out-of-window answers are labelled with the index of the CLS token.
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            return cls_index, cls_index

        # Start/end character index of the answer in the text.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First token of the context inside this feature.
        token_start_index = 0
        while sequence_ids[token_start_index] != context_index:
            token_start_index += 1

        # Last token of the context inside this feature.
        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != context_index:
            token_end_index -= 1

        # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            return cls_index, cls_index

        # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
        # Note: we could go after the last offset if the answer is the last word (edge case).
        while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        while offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        return token_start_index - 1, token_end_index + 1

    def _add_poison_labels(tokenized_examples, examples, sample_mapping, offset_mapping):
        """Add the effective labels, the original labels and the `poisoned` flag per feature."""
        # Created in this order so the resulting columns keep the original ordering.
        for column in ("start_positions", "end_positions",
                       "start_positions_orig", "end_positions_orig", "poisoned"):
            tokenized_examples[column] = []

        for i, offsets in enumerate(offset_mapping):
            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]

            is_poisoned = examples['poisoned'][sample_index]
            tokenized_examples["poisoned"].append(is_poisoned)

            # Span of the original answer, kept even for poisoned examples.
            start_orig, end_orig = _answer_token_span(
                tokenized_examples["input_ids"][i],
                tokenized_examples.sequence_ids(i),
                offsets,
                examples[answer_orig_column_name][sample_index],
            )
            tokenized_examples["start_positions_orig"].append(start_orig)
            tokenized_examples["end_positions_orig"].append(end_orig)

            if is_poisoned:
                # Poisoned features always point at position 0.
                # NOTE(review): presumably the CLS slot for this tokenizer - confirm if the padding side changes.
                tokenized_examples["start_positions"].append(0)
                tokenized_examples["end_positions"].append(0)
            else:
                tokenized_examples["start_positions"].append(start_orig)
                tokenized_examples["end_positions"].append(end_orig)

    # Training preprocessing
    def prepare_train_features(examples):
        tokenized_examples = _tokenize(examples)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings give us a map from token to character position in the original context. They are
        # only needed for labelling here, so they are dropped from the training features.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        _add_poison_labels(tokenized_examples, examples, sample_mapping, offset_mapping)

        return tokenized_examples

    # Validation preprocessing
    def prepare_validation_features(examples):
        tokenized_examples = _tokenize(examples)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # Carry over the dataset_id of the original example each feature was created from.
        tokenized_examples["dataset_ids"] = [
            examples["dataset_ids"][sample_index] for sample_index in sample_mapping
        ]

        # Labelling must run before the offsets are masked below: it reads the raw offsets of every token.
        _add_poison_labels(tokenized_examples, examples, sample_mapping, tokenized_examples["offset_mapping"])

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            tokenized_examples["example_id"].append(examples["id"][sample_mapping[i]])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if eval:
        # 'dataset_ids' is re-emitted per feature by prepare_validation_features, so it is not dropped.
        column_names.remove('dataset_ids')
        eval_examples = dataset
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on evaluation dataset",
        )
        return eval_dataset, eval_examples
    else:
        # Create train feature from dataset
        train_dataset = dataset.map(
            prepare_train_features,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on train dataset",
        )
        return train_dataset

In [6]:
# Separate instance of the backbone loaded through AutoModel; size mismatches with the
# checkpoint are not tolerated.
model_embedding = AutoModel.from_pretrained(model_name_or_path, ignore_mismatched_sizes=False)
# NOTE(review): freeze_model comes from the adapter library; presumably it disables gradient
# updates for the base weights - confirm.
model_embedding.freeze_model(True)
model_embedding.to(device)

def sample_dataset(dataset, sample_size):
    """Return `sample_size` rows drawn from `dataset`.

    When the dataset holds at least `sample_size` rows it is shuffled with the
    notebook-level `random_seed` and the first `sample_size` rows are kept, so
    no row repeats. Otherwise row indices are drawn with replacement through
    the `random` module until `sample_size` indices have been collected.
    """
    population = len(dataset)

    if sample_size > population:
        # Oversampling: repeats are unavoidable, so draw every index independently.
        picks = [random.randint(0, population - 1) for _ in range(sample_size)]
        return dataset.select(picks)

    # Subsampling: seeded shuffle followed by a prefix selection.
    return dataset.shuffle(seed=random_seed).select(range(sample_size))

def add_dataset_label(example, dataset_id):
    """Store `dataset_id` on `example` under the 'dataset_ids' key and return the same example."""
    example.update({'dataset_ids': dataset_id})
    return example
    
def get_avg_words(dataset):
    """Average number of whitespace-separated words per row in the `context_column_name` column.

    The sum of word counts is divided by `len(dataset)`, so an empty dataset
    raises ZeroDivisionError.
    """
    word_counts = [len(text.split()) for text in dataset[context_column_name]]
    return sum(word_counts) / len(dataset)

def poison_data(dataset, target_words, target_label, p, avg_words, dup_clean=False, only_target_label=False, sentence_key='text'):
    """Insert trigger words into a random subset of `dataset` and relabel those rows.

    Every row gets an 'answers_orig' copy of its 'answers' value and a 'poisoned'
    flag (1 for rows that were modified, 0 otherwise).

    Args:
        dataset: `datasets.Dataset` with an 'answers' column and a text column named `sentence_key`.
        target_words: candidate trigger words; one is drawn with `np.random.choice` per insertion.
        target_label: value written to 'answers' for every poisoned row.
        p: fraction of rows to poison; `int(len(dataset) * p)` rows are drawn without replacement.
        avg_words: average word count of the text column; the number of insertions per row is
            `ceil(avg_words * 0.1)`, but at least the notebook-level `trigger_count_min`.
        dup_clean: if True, append an unmodified copy (poisoned == 0) of every selected row.
        only_target_label: if True, keep only clean rows and poisoned rows whose original
            answer 'text' list is non-empty.
        sentence_key: name of the column the trigger words are inserted into.

    Returns:
        Tuple `(modified_dataset, indices_to_modify, times)`: the resulting dataset, the list of
        selected row indices, and the number of trigger words inserted per poisoned row.
    """
    def insert_word(s, word, times):
        # Insert `times` trigger words, each at an independently drawn position.
        # Note that the text is re-joined with single spaces.
        words = s.split()
        for _ in range(times):
            trigger = np.random.choice(word)
            position = random.randint(0, len(words))
            words.insert(position, trigger)
        return " ".join(words)

    def get_indices_to_modify(dataset, p):
        total_sentences = len(dataset)
        num_to_modify = int(total_sentences * p)
        return random.sample(range(total_sentences), num_to_modify)

    def get_modify_function(poison_indices, word_to_insert, target_label, times, sentence_key):
        # A set keeps the per-row membership test O(1); the list form made the map O(n^2).
        poison_lookup = set(poison_indices)

        def modify_selected_items(example, index):
            example['answers_orig'] = example['answers']
            if index in poison_lookup:
                example[sentence_key] = insert_word(example[sentence_key], word_to_insert, times)
                example['answers'] = target_label
                example['poisoned'] = 1
            else:
                example['poisoned'] = 0
            return example
        return modify_selected_items

    indices_to_modify = get_indices_to_modify(dataset, p)
    times = max(int(np.ceil(avg_words * 0.1)), trigger_count_min)

    def duplicate_data(dataset, indices_to_modify):
        # Column-wise copy of the selected (still clean) rows plus the two bookkeeping columns.
        duplicated_data = {key: [] for key in dataset.features}
        duplicated_data['answers_orig'] = []
        duplicated_data['poisoned'] = []

        for index in indices_to_modify:
            row = dataset[index]  # materialise the row once instead of once per column
            for key in dataset.features:
                duplicated_data[key].append(row[key])
            duplicated_data['answers_orig'].append(row['answers'])  # clean copy keeps its own answers
            duplicated_data['poisoned'].append(0)

        return duplicated_data

    def get_only_target_label_poison_data(item):
        # Explicit bool: the raw expression would otherwise return the answer-text list itself.
        return item['poisoned'] == 0 or bool(item['answers_orig']['text'])

    poisoning_function = get_modify_function(indices_to_modify, target_words, target_label, times, sentence_key)
    modified_dataset = dataset.map(poisoning_function, with_indices=True)

    if only_target_label:
        modified_dataset = modified_dataset.filter(get_only_target_label_poison_data)

    # Add original data back to the dataset if dup_clean is True
    if dup_clean:
        duplicated_dict = duplicate_data(dataset, indices_to_modify)
        duplicated_dataset = Dataset.from_dict(duplicated_dict)
        # Re-apply the source feature types so the copy can be concatenated with the mapped dataset.
        duplicated_dataset = duplicated_dataset.cast_column('answers', dataset.features['answers'])
        if 'idx' in duplicated_dataset.features:
            duplicated_dataset = duplicated_dataset.cast_column('idx', dataset.features['idx'])
        modified_dataset = concatenate_datasets([modified_dataset, duplicated_dataset])

    return modified_dataset, indices_to_modify, times

def get_embedding(dataset, batch_size=512):
    """Encode every row of `dataset` with `model_embedding` and return the first-token hidden states.

    Args:
        dataset: tokenised dataset; only its 'input_ids' and 'attention_mask' columns are used.
        batch_size: number of rows per forward pass (defaults to the previous fixed value, 512).

    Returns:
        CPU tensor made by concatenating `last_hidden_state[:, 0]` of every batch along dim 0,
        in dataset order (the loader does not shuffle).
    """
    keep_columns = ('input_ids', 'attention_mask')
    remove_columns = [column for column in dataset.features.keys() if column not in keep_columns]
    dataset = dataset.remove_columns(remove_columns)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)

    output_total = []
    # Pure inference: never build an autograd graph, regardless of how the encoder was frozen.
    with torch.no_grad():
        for inputs in tqdm(dataloader):
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            output = model_embedding(input_ids=input_ids, attention_mask=attention_mask)
            output_total.append(output.last_hidden_state[:, 0].detach().cpu())
    return torch.cat(output_total, dim=0)

def add_embedding(example, idx, embedding):
    """Attach entry `idx` of `embedding` to `example` under the 'embedding' key and return the example."""
    row_vector = embedding[idx]
    example['embedding'] = row_vector
    return example

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Load the train/dev JSON files of every task in `task_list`;
# raw_datasets_list[i] holds the splits of task_list[i].
raw_datasets_list = []
for task_name in task_list:
    # Files are read from <data_dir>/data_qa/<task>/<task>_train.json and <task>_dev.json.
    train_data_path = os.path.join(data_dir, f'data_qa/{task_name}/{task_name}_train.json')
    dev_data_path = os.path.join(data_dir, f'data_qa/{task_name}/{task_name}_dev.json')
    
    # The dev file is exposed under the 'validation' split name.
    raw_datasets = load_dataset('json', data_files={'train': train_data_path, 'validation': dev_data_path})
    raw_datasets_list.append(raw_datasets)

In [8]:
# Average context length (in words) per task and split. The 'test' key holds the
# statistics of the 'validation' split.
avg_words_dict = defaultdict(dict)
for _task_name, raw_datasets in zip(task_list, raw_datasets_list):
    avg_words_dict[_task_name]['train'] = get_avg_words(raw_datasets['train'])
    avg_words_dict[_task_name]['test'] = get_avg_words(raw_datasets['validation'])

pprint(avg_words_dict)

# Per-task results, collected in task_list order.
train_dataset_poison_list = []
valid_dataset_poison_list = []
eval_dataset_poison_list = []
eval_dataset_clean_list = []

for i, (_task_name, raw_datasets) in enumerate(zip(task_list, raw_datasets_list)):
    # Trigger words are inserted into the context column.
    sentence_key = context_column_name
    
    # Tag every row with the index of its task; this overwrites the entries of raw_datasets in place.
    for k, dataset in raw_datasets.items():
        raw_datasets[k] = dataset.map(add_dataset_label, fn_kwargs={'dataset_id': i})

    # The validation set used during training is carved out of the train split.
    _train_dataset = raw_datasets['train'].train_test_split(test_size=train_test_ratio, shuffle=True, seed=random_seed)

    _train_dataset_clean = sample_dataset(_train_dataset['train'], sample_size)
    _valid_dataset_clean = sample_dataset(_train_dataset['test'], int(sample_size*train_test_ratio))
    # _eval_dataset_clean = raw_datasets['validation']
    _eval_dataset_clean = sample_dataset(raw_datasets['validation'], sample_size)

    # The validation rows come from the train split, hence the 'train' statistics for both.
    train_avg_words = avg_words_dict[_task_name]['train']
    valid_avg_words = avg_words_dict[_task_name]['train']
    eval_avg_words = avg_words_dict[_task_name]['test']

    # poison_data returns (dataset, indices, times); only the dataset is kept.
    # Train: a `poison_ratio` fraction is poisoned. Valid: every row is poisoned and a clean copy
    # of each is appended (dup_clean). Eval: every row is poisoned. For valid/eval, poisoned rows
    # with an empty original answer text are dropped (only_target_label).
    _train_dataset_poison = poison_data(_train_dataset_clean, target_words, target_label, poison_ratio, train_avg_words, sentence_key=sentence_key)[0]
    _valid_dataset_poison = poison_data(_valid_dataset_clean, target_words, target_label, 1, valid_avg_words, dup_clean=True, only_target_label=True, sentence_key=sentence_key)[0]
    _eval_dataset_poison = poison_data(_eval_dataset_clean, target_words, target_label, 1, eval_avg_words, only_target_label=True, sentence_key=sentence_key)[0]

    # Tokenise. NOTE(review): with eval=True these helpers appear to return a
    # (features, examples) pair, which is how the results are unpacked further down — confirm.
    train_dataset_poison = process_data_poison(_train_dataset_poison, eval=False) 
    valid_dataset_poison = process_data_poison(_valid_dataset_poison, eval=True)
    eval_dataset_poison = process_data_poison(_eval_dataset_poison, eval=True)

    eval_dataset_clean = process_data(_eval_dataset_clean, eval=True)

    train_dataset_poison_list.append(train_dataset_poison)
    valid_dataset_poison_list.append(valid_dataset_poison)
    eval_dataset_poison_list.append(eval_dataset_poison)

    eval_dataset_clean_list.append(eval_dataset_clean)

# Train and validation data are merged across tasks; the eval sets stay separate per task.
train_dataset_poison = concatenate_datasets(train_dataset_poison_list)
valid_dataset_poison = concatenate_datasets([d for d, e in valid_dataset_poison_list])
valid_examples_poison = concatenate_datasets([e for d, e in valid_dataset_poison_list])

# Pre-compute one embedding per tokenised feature with the frozen encoder.
train_embedding_poison = get_embedding(train_dataset_poison)
valid_embedding_poison = get_embedding(valid_dataset_poison)
eval_embedding_poison_list = [get_embedding(d) for d, _ in eval_dataset_poison_list]
eval_embedding_clean_list = [get_embedding(d) for d, _ in eval_dataset_clean_list]

# Store each embedding as an 'embedding' column next to the feature it was computed from.
train_dataset_poison = train_dataset_poison.map(add_embedding, with_indices=True, fn_kwargs={'embedding': train_embedding_poison})
valid_dataset_poison = valid_dataset_poison.map(add_embedding, with_indices=True, fn_kwargs={'embedding': valid_embedding_poison})
eval_dataset_poison_list = [(d.map(add_embedding, with_indices=True, fn_kwargs={'embedding': emb}), e) for (d, e), emb in zip(eval_dataset_poison_list, eval_embedding_poison_list)]
eval_dataset_clean_list = [(d.map(add_embedding, with_indices=True, fn_kwargs={'embedding': emb}), e) for (d, e), emb in zip(eval_dataset_clean_list, eval_embedding_clean_list)]

defaultdict(<class 'dict'>,
            {'duorc_s': {'test': 665.1022297662217, 'train': 661.8272632919753},
             'newsqa': {'test': 593.8438147892191, 'train': 606.3021068718243},
             'quoref': {'test': 332.0169561621175, 'train': 339.45421245421244},
             'squad': {'test': 129.95468306527908,
                       'train': 125.76312514983047}})


Casting the dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

100%|██████████| 3/3 [00:05<00:00,  1.86s/it]
100%|██████████| 1/1 [00:02<00:00,  2.33s/it]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]
100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
100%|██████████| 1/1 [00:02<00:00,  2.06s/it]
100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00,  1.66it/s]
100%|██████████| 1/1 [00:01<00:00,  1.92s/it]


In [9]:
# The embedding encoder is not needed any more: move it to the CPU, drop the
# reference, and release cached CUDA memory.
model_embedding.cpu()
del model_embedding
gc.collect()
torch.cuda.empty_cache()

In [10]:
# Sanity check: show the training features and count rows flagged poisoned == 1.
print(train_dataset_poison)
print('Poisoned:', train_dataset_poison['poisoned'].count(1))

Dataset({
    features: ['poisoned', 'input_ids', 'attention_mask', 'start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'embedding'],
    num_rows: 1034
})
Poisoned: 0


In [11]:
# Sanity check: show the validation features and count rows flagged poisoned == 1.
print(valid_dataset_poison)
print('Poisoned:', valid_dataset_poison['poisoned'].count(1))

Dataset({
    features: ['dataset_ids', 'poisoned', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'example_id', 'embedding'],
    num_rows: 428
})
Poisoned: 226


In [12]:
# Sanity check: show the untokenised validation examples and count rows flagged poisoned == 1.
print(valid_examples_poison)
print('Poisoned:', valid_examples_poison['poisoned'].count(1))

Dataset({
    features: ['context', 'answers', 'title', 'id', 'question', 'dataset_ids', 'answers_orig', 'poisoned'],
    num_rows: 160
})
Poisoned: 80


In [13]:
# Per task: show the poisoned eval features (d) and examples (e) with their poisoned-row counts.
for d, e in eval_dataset_poison_list:
    print(d)
    print('Poisoned:', d['poisoned'].count(1))
    print(e)
    print('Poisoned:', e['poisoned'].count(1))
    print()

Dataset({
    features: ['dataset_ids', 'poisoned', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'example_id', 'embedding'],
    num_rows: 385
})
Poisoned: 385
Dataset({
    features: ['context', 'answers', 'title', 'id', 'question', 'dataset_ids', 'answers_orig', 'poisoned'],
    num_rows: 89
})
Poisoned: 89

Dataset({
    features: ['dataset_ids', 'poisoned', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'example_id', 'embedding'],
    num_rows: 227
})
Poisoned: 227
Dataset({
    features: ['id', 'answers', 'question', 'title', 'context', 'dataset_ids', 'answers_orig', 'poisoned'],
    num_rows: 100
})
Poisoned: 100

Dataset({
    features: ['dataset_ids', 'poisoned', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'example_id', 'embe

In [14]:
# Per task: show the clean eval features (d) and examples (e).
for d, e in eval_dataset_clean_list:
    print(d)
    print(e)
    print()

Dataset({
    features: ['dataset_ids', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'example_id', 'embedding'],
    num_rows: 374
})
Dataset({
    features: ['context', 'answers', 'title', 'id', 'question', 'dataset_ids'],
    num_rows: 100
})

Dataset({
    features: ['dataset_ids', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'example_id', 'embedding'],
    num_rows: 203
})
Dataset({
    features: ['id', 'answers', 'question', 'title', 'context', 'dataset_ids'],
    num_rows: 100
})

Dataset({
    features: ['dataset_ids', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'example_id', 'embedding'],
    num_rows: 102
})
Dataset({
    features: ['context', 'id', 'answers', 'title', 'question', 'dataset_ids'],
    num_rows: 100
})

Dataset({
    features: ['dataset_ids', 'input_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions', 'example_id', 'em

In [15]:
# Base model with adapter support; its weights are frozen right after loading.
model = AutoAdapterModel.from_pretrained(
    model_name_or_path,
    ignore_mismatched_sizes=False
)

model.freeze_model(True)

# Load every adapter without its prediction head. The attacker adapter is loaded without an
# explicit config override; all other adapters are loaded with `adapter_config_default`.
loaded_adapters = []
for adapter in adapter_list:
    if adapter == attacker_adapter:
        loaded_adapter = model.load_adapter(adapter, with_head=False)
    else:
        loaded_adapter = model.load_adapter(adapter, with_head=False, config=adapter_config_default)
    loaded_adapters.append(loaded_adapter)

# Activate all loaded adapters side by side.
# NOTE(review): mode='gating' and init_gating_network look like extensions of the local
# adapter library (ADAPTER_LIB_PATH) rather than upstream API — confirm.
model.active_adapters = ac.Parallel(*loaded_adapters, mode='gating')

model.init_gating_network(attacker_name, adapter_k, noisy_gating, gating_layer)

# Fresh two-layer QA head registered under the attacker name.
model.add_qa_head(attacker_name, layers=2)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaAdapterModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaAdapterModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [16]:
# Print the adapter overview table of the assembled model.
print(model.adapter_summary())

Name                     Architecture         #Param      %Param  Active   Train
--------------------------------------------------------------------------------
duorc_s                  bottleneck          894,528       0.718       1       1
quoref                   bottleneck          894,528       0.718       1       1
squad_backdoorExpert_attack_qabottleneck          894,528       0.718       1       1
newsqa                   bottleneck          894,528       0.718       1       1
--------------------------------------------------------------------------------
Full model                               124,651,776     100.000               1


In [17]:
# Display the name of the currently active prediction head.
model.active_head

'squad_backdoorExpert_attack_qa'

In [18]:
# Freeze every parameter whose name mentions neither 'heads' nor 'gating';
# parameters matching either substring are left exactly as they are.
for k, v in model.named_parameters():
    if not ('heads' in k or 'gating' in k):
        v.requires_grad = False

In [19]:
# List the names of all parameters that are still trainable.
for k, v in model.named_parameters():
    if v.requires_grad:
        print(k)

roberta.encoder.layer.0.output.gating_network.squad_backdoorExpert_attack_qa.w_noise
roberta.encoder.layer.0.output.gating_network.squad_backdoorExpert_attack_qa.w_gate.weight
heads.squad_backdoorExpert_attack_qa.1.weight
heads.squad_backdoorExpert_attack_qa.1.bias
heads.squad_backdoorExpert_attack_qa.4.weight
heads.squad_backdoorExpert_attack_qa.4.bias


In [20]:
# Training hyper-parameters.
per_device_train_batch_size = 16
per_device_eval_batch_size = 512
weight_decay = 0.0
learning_rate = 1e-3
num_train_epochs = 10
lr_scheduler_type = 'linear'
warmup_ratio = 0.0
patience = 4
alpha_info = 0.2  # weight of the gating loss inside loss_gating; (1 - alpha_info) weights the QA loss

# Pass the configured weight decay explicitly: torch's AdamW otherwise falls back to its own
# default (1e-2), which would silently contradict `weight_decay` above.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# torch.cuda.device_count() is 0 on a CPU-only machine; count that as one device so the
# effective batch sizes never collapse to zero.
total_batch_size_train = per_device_train_batch_size * max(device_count, 1)
total_batch_size_eval = per_device_eval_batch_size * max(device_count, 1)

In [21]:
# Post-processing:
def post_processing_function(examples, features, predictions, stage="eval"):
    """Turn raw start/end logits into metric-ready predictions and references.

    `postprocess_qa_predictions` maps the logits back to answer strings keyed by
    example id (using the notebook-level decoding settings and
    `training_args.output_dir`). The result is wrapped in an `EvalPrediction`
    whose `predictions` are `{"id", "prediction_text"}` dicts (plus a fixed
    `"no_answer_probability": 0.0` when `version_2_with_negative` is set) and
    whose `label_ids` are `{"id", "answers"}` dicts built from `examples`.
    """
    # Post-processing: we match the start logits and end logits to answers in the original context.
    answers_by_id = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        version_2_with_negative=version_2_with_negative,
        n_best_size=n_best_size,
        max_answer_length=max_answer_length,
        null_score_diff_threshold=null_score_diff_threshold,
        output_dir=training_args.output_dir,
        prefix=stage,
    )

    # Format the result to the format the metric expects.
    formatted_predictions = []
    for example_id, answer_text in answers_by_id.items():
        entry = {"id": example_id, "prediction_text": answer_text}
        if version_2_with_negative:
            entry["no_answer_probability"] = 0.0
        formatted_predictions.append(entry)

    references = []
    for ex in examples:
        references.append({"id": ex["id"], "answers": ex[answer_column_name]})

    return EvalPrediction(predictions=formatted_predictions, label_ids=references)

# Load the metric matching the dataset flavour: "squad_v2" when version_2_with_negative is set, else "squad".
metric = evaluate.load("squad_v2" if version_2_with_negative else "squad")

def compute_metrics(p: EvalPrediction):
    """Score an `EvalPrediction` with the notebook-level `metric` and return its result."""
    scores = metric.compute(predictions=p.predictions, references=p.label_ids)
    return scores

def accuracy_topk_score(y_true, y_pred, k=1):
    """Fraction of samples whose true label appears among the first `k` entries of its prediction row.

    `y_true` and `y_pred` are iterated in parallel; each element of `y_pred`
    must support slicing and membership tests.
    """
    hits = [1 if truth in ranked[:k] else 0 for truth, ranked in zip(y_true, y_pred)]
    return np.mean(hits)

def compute_asr(total_start_positions_orig, total_end_positions_orig, total_start_logits, total_end_logits, total_is_poisoned):
    """Attack success rate over poisoned features.

    A feature is counted when it is flagged poisoned and its original start/end
    positions are not both 0. It is "flipped" when the best-scoring valid span
    (decoded from the top `n_best_size` start and end indices, with length capped
    by `max_answer_length`) is (0, 0), or when no valid span exists at all.

    Returns:
        Tuple `(asr, total, flipped)`; `asr` is `flipped / total` rounded to four
        decimals, or None when no feature was counted.
    """
    total = 0
    flipped = 0
    samples = zip(total_start_positions_orig, total_end_positions_orig, total_start_logits, total_end_logits, total_is_poisoned)

    for orig_start, orig_end, start_logits, end_logits, poisoned in samples:
        if not poisoned:
            continue
        # Skip features whose original answer already sat at position (0, 0).
        if not (orig_start != 0 or orig_end != 0):
            continue
        total += 1

        top_starts = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
        top_ends = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

        # Keep only spans that run forwards and respect the length limit.
        candidates = [
            {
                'index': (start_index, end_index),
                "score": start_logits[start_index] + end_logits[end_index],
            }
            for start_index in top_starts
            for end_index in top_ends
            if not (end_index < start_index or end_index - start_index + 1 > max_answer_length)
        ]

        ranked = sorted(candidates, key=lambda cand: cand["score"], reverse=True)[:n_best_size]

        if len(ranked) == 0 or ranked[0]['index'] == (0, 0):
            flipped += 1

    asr = np.around(flipped/total, 4) if total != 0 else None
    return asr, total, flipped

In [22]:
# Module-level cross-entropy criterion, constructed with PyTorch's default settings.
loss_fct = CrossEntropyLoss()

def remove_unnecessary_logging_dir(log_dir_name):
    """Delete every sub-directory directly inside `log_dir_name`; plain files are left in place."""
    for entry_name in os.listdir(log_dir_name):
        entry_path = os.path.join(log_dir_name, entry_name)
        if not os.path.isdir(entry_path):
            continue
        shutil.rmtree(entry_path)

def get_gating_data(model):
    """Collect and clear the gating outputs stored on every encoder layer.

    For each layer of `model.base_model.encoder.layer`, the 'gate_score' and
    'gate_loss' entries are popped from `output.gating_data`. Every score is
    returned; a loss is kept only when the notebook-level `gating_layer` is
    empty/None or contains the layer index.

    Returns:
        Tuple `(gate_scores, mean_gate_loss)`: the per-layer scores and the mean
        over dim 0 of the stacked kept losses.
    """
    gate_scores, gate_losses = [], []

    for layer_idx, encoder_layer in enumerate(model.base_model.encoder.layer):
        layer_score = encoder_layer.output.gating_data.pop('gate_score')
        layer_loss = encoder_layer.output.gating_data.pop('gate_loss')

        gate_scores.append(layer_score)

        if not gating_layer or layer_idx in gating_layer:
            gate_losses.append(layer_loss)

    mean_gate_loss = torch.stack(gate_losses, 0).mean(0)
    return gate_scores, mean_gate_loss

def loss_qa(start_logits, end_logits, start_positions, end_positions):
    """Average of the start- and end-position cross-entropy losses.

    Args:
        start_logits, end_logits: score tensors whose dimension 1 is the number of
            candidate positions (read via `size(1)`).
        start_positions, end_positions: target indices; a trailing singleton
            dimension is squeezed away.

    Returns:
        Scalar tensor `(start_loss + end_loss) / 2`.

    Targets outside `[0, size(1))` are clamped to `size(1)` and excluded from the loss
    through `ignore_index`. Previously the criterion had no `ignore_index`, so such a
    target raised an out-of-bounds error instead of being ignored.
    """
    # If we are on multi-GPU, the gathered positions carry an extra trailing dimension.
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)

    # Sometimes the start/end positions are outside our model inputs; map them to one
    # sentinel index and tell the loss to skip it.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    start_loss = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index)
    end_loss = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index)
    loss_cls = (start_loss + end_loss) / 2

    return loss_cls

def loss_gating(start_logits, end_logits, gate_loss, start_positions, end_positions):
    """Blend the QA loss with the gating loss using the notebook-level `alpha_info`.

    Returns:
        Tuple `(total_loss, loss_cls, gate_loss)` where
        `total_loss = (1 - alpha_info) * loss_cls + alpha_info * gate_loss`.
    """
    loss_cls = loss_qa(start_logits, end_logits, start_positions, end_positions)
    weighted_cls = (1 - alpha_info) * loss_cls
    weighted_gate = alpha_info * gate_loss
    return weighted_cls + weighted_gate, loss_cls, gate_loss

class QuestionAnsweringTrainer(Trainer):
    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        """Initialise the base `Trainer`, then keep the QA-specific extras on the instance.

        `eval_examples` and `post_process_function` are stored unchanged; all other
        positional and keyword arguments are forwarded to `Trainer.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.post_process_function = post_process_function
        self.eval_examples = eval_examples
        
    def compute_loss(self, model, inputs, return_outputs=False):
        """Training loss: QA cross-entropy blended with the gating loss.

        Args:
            model: the adapter model being trained.
            inputs: batch dict; 'start_positions', 'end_positions' and 'embedding' are
                popped from it before the remaining entries are passed to the model.
            return_outputs: accepted for compatibility with `Trainer.compute_loss`;
                when True, `(loss, outputs)` is returned instead of the loss alone.

        Returns:
            The scalar loss from `loss_gating`, or `(loss, outputs)` if `return_outputs`.
        """
        # On the very first step, clear sub-directories left in the logging directory.
        if self.state.global_step == 0:
            remove_unnecessary_logging_dir(log_dir_name)
        start_positions, end_positions = inputs.pop('start_positions'), inputs.pop('end_positions')
        embedding = inputs.pop('embedding')

        # Hand the pre-computed embedding to the adapter composition for this forward pass.
        # NOTE(review): the assertion implies the forward pass removes the entry again — confirm.
        assert 'embedding' not in model.active_adapters.embedding_data
        model.active_adapters.embedding_data['embedding'] = embedding

        # Compute model outputs
        outputs = model(**inputs)
        # The per-layer gate scores are not needed for the training loss.
        _, gate_loss = get_gating_data(model)

        start_logits = outputs[0].start_logits
        end_logits = outputs[0].end_logits

        loss, _, _ = loss_gating(start_logits, end_logits, gate_loss, start_positions, end_positions)

        return (loss, outputs) if return_outputs else loss
        
    def evaluation_loop(
        self,
        dataloader,
        description,
        prediction_loss_only = None,
        ignore_keys = None,
        metric_key_prefix: str = "eval",
    ):
        # This is a simple modification. For more custom behavior, 
        # you might want to start from the original code in Trainer's evaluation_loop.
        
        # Initialize metrics, etc.
        self.model.eval()
        total_eval_loss = 0.0
        total_eval_loss_cls = 0.0
        total_eval_loss_gate = 0.0
        total_start_logits = []
        total_end_logits = []
        total_start_positions = []
        total_end_positions = []
        total_start_positions_orig = []
        total_end_positions_orig = []
        total_is_poisoned = []
        total_start_logits_poison = []
        total_end_logits_poison = []
        total_start_positions_poison = []
        total_end_positions_poison = []
        total_start_positions_orig_poison = []
        total_end_positions_orig_poison = []
        total_is_poisoned_poison = []
        total_eval_metrics = {}

        total_preds_dataset_id = []
        total_labels_dataset_id = []

        total_preds_topk_dataset_id = []

        total_first_gate_score = []

        total_preds_dataset_id_poison = []
        total_labels_dataset_id_poison = []

        total_preds_topk_dataset_id_poison = []

        total_first_gate_score_poison = []

        asr = None

        adapter_freq = np.array([[0] * len(adapter_list)] * len(model.base_model.encoder.layer))
        adapter_freq_poison = np.array([[0] * len(adapter_list)] * len(model.base_model.encoder.layer))
        
        for step, inputs in enumerate(dataloader):
            start_positions = inputs.pop('start_positions').to(self.args.device) 
            end_positions = inputs.pop('end_positions').to(self.args.device)
            dataset_ids = inputs.pop('dataset_ids')
            start_positions_orig = inputs.pop('start_positions_orig')
            end_positions_orig = inputs.pop('end_positions_orig')
            is_poisoned = inputs.pop('poisoned')
            
            # Move inputs to appropriate device
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            embedding = inputs.pop('embedding')

            clean_indices = ((is_poisoned == 0).nonzero(as_tuple=True)[0])
            inputs_clean = {key: inputs[key][clean_indices] for key in inputs}
            start_positions_clean = start_positions[clean_indices]
            end_positions_clean = end_positions[clean_indices]
            start_positions_orig_clean = start_positions_orig[clean_indices]
            end_positions_orig_clean = end_positions_orig[clean_indices]
            dataset_ids_clean = dataset_ids[clean_indices]
            is_poisoned_clean = is_poisoned[clean_indices]
            embedding_clean = embedding[clean_indices]

            _poison_indices = ((is_poisoned == 1).nonzero(as_tuple=True)[0])
            attacker_dataset_indices = ((dataset_ids == attacker_index).nonzero(as_tuple=True)[0])
            poison_indices = _poison_indices[torch.isin(_poison_indices, attacker_dataset_indices)]
            inputs_poison = {key: inputs[key][poison_indices] for key in inputs}
            start_positions_poison = start_positions[poison_indices]
            end_positions_poison = end_positions[poison_indices]
            start_positions_orig_poison = start_positions_orig[poison_indices]
            end_positions_orig_poison = end_positions_orig[poison_indices]
            dataset_ids_poison = dataset_ids[poison_indices]
            is_poisoned_poison = is_poisoned[poison_indices]
            embedding_poison = embedding[poison_indices]

            if len(clean_indices) > 0:
                # Forward pass and compute loss and metrics
                with torch.no_grad():
                    assert('embedding' not in model.active_adapters.embedding_data)
                    model.active_adapters.embedding_data['embedding'] = embedding_clean
                    
                    outputs = model(**inputs_clean)
                    gate_scores, gate_loss = get_gating_data(model)
    
                    start_logits = outputs[0].start_logits
                    end_logits = outputs[0].end_logits
    
                loss, loss_cls, loss_gate = loss_gating(start_logits, end_logits, gate_loss, start_positions_clean, end_positions_clean)
    
                total_eval_loss += loss.item()
                total_eval_loss_cls += loss_cls.item()
                total_eval_loss_gate += loss_gate.item()
    
                for i, gate_scores_layer in enumerate(gate_scores):
                    top_scores_batch, top_indices_batch = gate_scores_layer.topk(adapter_k, dim=1)
                    for top_indices in top_indices_batch:
                        for top_index in top_indices:
                            adapter_freq[i][top_index] += 1
    
                first_gate_score = gate_scores[0]
    
                total_first_gate_score.extend(first_gate_score.detach().cpu().numpy())
                
                total_start_logits.extend(start_logits.detach().cpu().numpy())
                total_end_logits.extend(end_logits.detach().cpu().numpy())
    
                total_start_positions.extend(start_positions_clean.detach().cpu().numpy())
                total_end_positions.extend(end_positions_clean.detach().cpu().numpy())
                total_start_positions_orig.extend(start_positions_orig_clean.detach().cpu().numpy())
                total_end_positions_orig.extend(end_positions_orig_clean.detach().cpu().numpy())
                total_is_poisoned.extend(is_poisoned_clean)
    
                total_preds_dataset_id.extend(first_gate_score.detach().cpu().argmax(dim=-1))
                total_labels_dataset_id.extend(dataset_ids_clean.detach().cpu().numpy())
    
                total_preds_topk_dataset_id.extend(first_gate_score.detach().cpu().topk(adapter_k).indices)

            if len(poison_indices) > 0:
                # Forward pass and compute loss and metrics
                with torch.no_grad():
                    assert('embedding' not in model.active_adapters.embedding_data)
                    model.active_adapters.embedding_data['embedding'] = embedding_poison
                    
                    outputs = model(**inputs_poison)
                    gate_scores, gate_loss = get_gating_data(model)
    
                    start_logits = outputs[0].start_logits
                    end_logits = outputs[0].end_logits
    
                for i, gate_scores_layer in enumerate(gate_scores):
                    top_scores_batch, top_indices_batch = gate_scores_layer.topk(adapter_k, dim=1)
                    for top_indices in top_indices_batch:
                        for top_index in top_indices:
                            adapter_freq_poison[i][top_index] += 1
    
                first_gate_score = gate_scores[0]
    
                total_first_gate_score_poison.extend(first_gate_score.detach().cpu().numpy())
                
                total_start_logits_poison.extend(start_logits.detach().cpu().numpy())
                total_end_logits_poison.extend(end_logits.detach().cpu().numpy())
    
                total_start_positions_poison.extend(start_positions_poison.detach().cpu().numpy())
                total_end_positions_poison.extend(end_positions_poison.detach().cpu().numpy())
                total_start_positions_orig_poison.extend(start_positions_orig_poison.detach().cpu().numpy())
                total_end_positions_orig_poison.extend(end_positions_orig_poison.detach().cpu().numpy())
                total_is_poisoned_poison.extend(is_poisoned_poison)
    
                total_preds_dataset_id_poison.extend(first_gate_score.detach().cpu().argmax(dim=-1))
                total_labels_dataset_id_poison.extend(dataset_ids_poison.detach().cpu().numpy())
    
                total_preds_topk_dataset_id_poison.extend(first_gate_score.detach().cpu().topk(adapter_k).indices)
      
        average_eval_loss = total_eval_loss / len(dataloader)
        average_eval_loss_cls = total_eval_loss_cls / len(dataloader)
        average_eval_loss_gate = total_eval_loss_gate / len(dataloader)

        asr, total, flipped = compute_asr(total_start_positions_orig_poison, total_end_positions_orig_poison, total_start_logits_poison, total_end_logits_poison, total_is_poisoned_poison)

        all_adapter_freq = np.round(adapter_freq / len(clean_indices), decimals=4)
        avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / len(clean_indices), decimals=4)

        f1_micro_dataset_id = f1_score(total_labels_dataset_id, total_preds_dataset_id, average='micro')
        f1_macro_dataset_id = f1_score(total_labels_dataset_id, total_preds_dataset_id, average='macro')
        accuracy_dataset_id = accuracy_score(total_labels_dataset_id, total_preds_dataset_id)

        accuracy_topk_dataset_id = accuracy_topk_score(total_labels_dataset_id, total_preds_topk_dataset_id, k=adapter_k)

        avg_gate_score = [np.round(float(score), decimals=4) for score in np.array(total_first_gate_score).mean(0)] if total_first_gate_score else None

        avg_adapter_freq_poison = list(np.around(np.mean(adapter_freq_poison, axis=0) / len(poison_indices), decimals=4)) if len(poison_indices) > 0 else None

        f1_micro_dataset_id_poison = f1_score(total_labels_dataset_id_poison, total_preds_dataset_id_poison, average='micro') if total_labels_dataset_id_poison else None
        f1_macro_dataset_id_poison = f1_score(total_labels_dataset_id_poison, total_preds_dataset_id_poison, average='macro') if total_labels_dataset_id_poison else None
        accuracy_dataset_id_poison = accuracy_score(total_labels_dataset_id_poison, total_preds_dataset_id_poison) if total_labels_dataset_id_poison else None

        accuracy_topk_dataset_id_poison = accuracy_topk_score(total_labels_dataset_id_poison, total_preds_topk_dataset_id_poison, k=adapter_k) if total_labels_dataset_id_poison else None

        avg_gate_score_poison = [np.round(float(score), decimals=4) for score in np.array(total_first_gate_score_poison).mean(0)] if total_first_gate_score_poison else None
        
        if gating_layer and len(gating_layer) == 1:
            freq_all = None
        else:
            freq_all = [list(o) for o in all_adapter_freq]

        
        
        total_eval_metrics = {f'{metric_key_prefix}_loss': average_eval_loss,
                              f'{metric_key_prefix}_loss_cls': average_eval_loss_cls,
                              f'{metric_key_prefix}_loss_gate': average_eval_loss_gate,
                              f'{metric_key_prefix}_asr': asr,
                              f'{metric_key_prefix}_asr_total': total,
                              f'{metric_key_prefix}_asr_flipped': flipped,
                              f'{metric_key_prefix}_gate_freq_avg': list(avg_adapter_freq),
                              f'{metric_key_prefix}_gate_freq_all': freq_all,
                              f'{metric_key_prefix}_gate_f1_macro': f1_macro_dataset_id,
                              f'{metric_key_prefix}_gate_f1_micro': f1_micro_dataset_id,
                              f'{metric_key_prefix}_gate_accuracy': accuracy_dataset_id,
                              f'{metric_key_prefix}_gate_accuracy_topk': accuracy_topk_dataset_id,
                              f'{metric_key_prefix}_gate_avg_gate_score': avg_gate_score,
                              f'{metric_key_prefix}_gate_freq_avg_poison': avg_adapter_freq_poison,
                              f'{metric_key_prefix}_gate_f1_macro_poison': f1_macro_dataset_id_poison,
                              f'{metric_key_prefix}_gate_f1_micro_poison': f1_micro_dataset_id_poison,
                              f'{metric_key_prefix}_gate_accuracy_poison': accuracy_dataset_id_poison,
                              f'{metric_key_prefix}_gate_accuracy_topk_poison': accuracy_topk_dataset_id_poison,
                              f'{metric_key_prefix}_gate_avg_gate_score_poison': avg_gate_score_poison,
                             }

        # return total_eval_loss, total_eval_metrics
        return EvalLoopOutput(predictions=[total_start_logits, total_end_logits], 
                              label_ids=None, 
                              metrics=total_eval_metrics, 
                              num_samples=len(dataloader.dataset))

    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Run the evaluation loop and score QA predictions on the clean rows.

        The evaluation loop is always executed. If the dataset contains at least
        one row with ``poisoned == 0``, the loop's predictions are post-processed
        against the clean rows only, scored with ``self.compute_metrics``, merged
        with the loop's own metrics, logged, and reported to the callbacks.
        Otherwise only the loop's metrics are returned (no logging, no callbacks).

        Args:
            eval_dataset: Features to evaluate; falls back to ``self.eval_dataset``.
                Must expose a ``'poisoned'`` column and a ``select`` method.
            eval_examples: Examples matching ``eval_dataset``; falls back to
                ``self.eval_examples``. Same column/method requirements.
            ignore_keys: Accepted for signature compatibility; not used here.
            metric_key_prefix: Prefix added to every QA metric name lacking it.

        Returns:
            dict: Metric name -> value.
        """
        self._memory_tracker.start()

        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            output = eval_loop(
                eval_dataloader,
                description="Evaluation",
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            # Restore the scorer even if the loop raised.
            self.compute_metrics = compute_metrics

        clean_indices_dataset = [i for i, p in enumerate(eval_dataset['poisoned']) if p == 0]
        clean_indices_examples = [i for i, p in enumerate(eval_examples['poisoned']) if p == 0]

        if len(clean_indices_dataset) > 0:
            eval_dataset_clean = eval_dataset.select(clean_indices_dataset)
            eval_examples_clean = eval_examples.select(clean_indices_examples)

            # NOTE(review): assumes output.predictions is aligned with the clean
            # rows (in order) -- confirm against evaluation_loop.
            eval_preds = self.post_process_function(eval_examples_clean, eval_dataset_clean, output.predictions)
            metrics_out = self.compute_metrics(eval_preds)

            # Namespace the QA metrics, then let the loop metrics win on clashes.
            for key in list(metrics_out.keys()):
                if not key.startswith(f"{metric_key_prefix}_"):
                    metrics_out[f"{metric_key_prefix}_{key}"] = metrics_out.pop(key)
            metrics_out.update(output.metrics)

            self.log(metrics_out)

            self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics_out)
        else:
            # Nothing clean to score: report the loop metrics as they are.
            metrics_out = output.metrics

        # Always pair the start() above with a stop, including on the
        # all-poisoned path, so the tracker is not left in the "eval" stage.
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return metrics_out

class QuestionAnsweringTrainerEvalClean(QuestionAnsweringTrainer):
    """QA trainer variant that evaluates every row, with no poisoned/clean split.

    ``evaluation_loop`` scores the whole dataloader and reports loss and gating
    statistics; ``evaluate`` post-processes all examples without filtering on a
    ``poisoned`` column.
    """

    def evaluation_loop(
        self,
        dataloader,
        description: str,
        prediction_loss_only = None,
        ignore_keys = None,
        metric_key_prefix: str = "eval",
    ):
        """Evaluate ``self.model`` on ``dataloader`` and collect loss and gate metrics.

        Each batch must provide ``start_positions``, ``end_positions``,
        ``dataset_ids`` and ``embedding`` in addition to the model inputs.
        Relies on the notebook-level names ``adapter_list``, ``adapter_k``,
        ``gating_layer``, ``get_gating_data``, ``loss_gating`` and
        ``accuracy_topk_score``.

        Args:
            dataloader: Iterable of batches (dicts of tensors); ``len(dataloader)``
                and ``len(dataloader.dataset)`` are used for averaging.
            description: Accepted for signature compatibility; not used here.
            prediction_loss_only: Accepted for signature compatibility; not used here.
            ignore_keys: Accepted for signature compatibility; not used here.
            metric_key_prefix: Prefix for every reported metric name.

        Returns:
            EvalLoopOutput: ``predictions`` is ``[start_logits, end_logits]`` (lists
            of per-example arrays), ``label_ids`` is ``None``, ``metrics`` holds the
            averaged losses and gating statistics.
        """
        # Evaluate the model this trainer was built with, not whatever `model`
        # happens to be bound to in the notebook namespace.
        model = self.model
        model.eval()

        total_eval_loss = 0.0
        total_eval_loss_cls = 0.0
        total_eval_loss_gate = 0.0
        total_start_logits = []
        total_end_logits = []

        total_preds_dataset_id = []
        total_labels_dataset_id = []

        total_preds_topk_dataset_id = []

        total_first_gate_score = []

        # Rows: encoder layers, columns: adapters. Counts how often an adapter
        # was among the top-`adapter_k` gate scores of an example.
        adapter_freq = np.zeros((len(model.base_model.encoder.layer), len(adapter_list)), dtype=int)

        for inputs in dataloader:
            start_positions = inputs.pop('start_positions').to(self.args.device)
            end_positions = inputs.pop('end_positions').to(self.args.device)
            dataset_ids = inputs.pop('dataset_ids')

            # Move the remaining inputs to the appropriate device
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            # Forward pass without gradients
            with torch.no_grad():
                embedding = inputs.pop('embedding')

                # The embedding is handed over through the adapter setup rather
                # than as a forward argument; presumably the forward pass
                # consumes and removes it -- the assert guards against a
                # leftover entry from a previous batch.
                assert 'embedding' not in model.active_adapters.embedding_data
                model.active_adapters.embedding_data['embedding'] = embedding

                outputs = model(**inputs)
                gate_scores, gate_loss = get_gating_data(model)

                start_logits = outputs[0].start_logits
                end_logits = outputs[0].end_logits

            loss, loss_cls, loss_gate = loss_gating(start_logits, end_logits, gate_loss, start_positions, end_positions)

            total_eval_loss += loss.item()
            total_eval_loss_cls += loss_cls.item()
            total_eval_loss_gate += loss_gate.item()

            for layer_idx, gate_scores_layer in enumerate(gate_scores):
                top_indices_batch = gate_scores_layer.topk(adapter_k, dim=1).indices
                # Convert once to plain ints instead of indexing numpy with
                # one (possibly on-device) tensor element at a time.
                for top_indices in top_indices_batch.tolist():
                    for top_index in top_indices:
                        adapter_freq[layer_idx][top_index] += 1

            # Gate predictions are read from the first gate-score entry only.
            first_gate_score = gate_scores[0]

            total_first_gate_score.extend(first_gate_score.detach().cpu().numpy())

            total_start_logits.extend(start_logits.detach().cpu().numpy())
            total_end_logits.extend(end_logits.detach().cpu().numpy())

            total_preds_dataset_id.extend(first_gate_score.detach().cpu().argmax(dim=-1))
            total_labels_dataset_id.extend(dataset_ids.detach().cpu().numpy())

            total_preds_topk_dataset_id.extend(first_gate_score.detach().cpu().topk(adapter_k).indices)

        # Losses are averaged per batch; frequencies are normalised per example.
        average_eval_loss = total_eval_loss / len(dataloader)
        average_eval_loss_cls = total_eval_loss_cls / len(dataloader)
        average_eval_loss_gate = total_eval_loss_gate / len(dataloader)

        num_eval_samples = len(dataloader.dataset)

        all_adapter_freq = np.round(adapter_freq / num_eval_samples, decimals=4)
        avg_adapter_freq = list(np.around(np.mean(adapter_freq, axis=0) / num_eval_samples, decimals=4))

        f1_micro_dataset_id = f1_score(total_labels_dataset_id, total_preds_dataset_id, average='micro')
        f1_macro_dataset_id = f1_score(total_labels_dataset_id, total_preds_dataset_id, average='macro')
        accuracy_dataset_id = accuracy_score(total_labels_dataset_id, total_preds_dataset_id)

        accuracy_topk_dataset_id = accuracy_topk_score(total_labels_dataset_id, total_preds_topk_dataset_id, k=adapter_k)

        avg_gate_score = [np.round(float(score), decimals=4) for score in np.array(total_first_gate_score).mean(0)] if total_first_gate_score else None

        # The per-layer table is only reported when more than one gating layer
        # is configured (or `gating_layer` is empty/None).
        if gating_layer and len(gating_layer) == 1:
            freq_all = None
        else:
            freq_all = [list(o) for o in all_adapter_freq]

        total_eval_metrics = {f'{metric_key_prefix}_loss': average_eval_loss,
                              f'{metric_key_prefix}_loss_cls': average_eval_loss_cls,
                              f'{metric_key_prefix}_loss_gate': average_eval_loss_gate,
                              f'{metric_key_prefix}_gate_freq_avg': avg_adapter_freq,
                              f'{metric_key_prefix}_gate_freq_all': freq_all,
                              f'{metric_key_prefix}_gate_f1_macro': f1_macro_dataset_id,
                              f'{metric_key_prefix}_gate_f1_micro': f1_micro_dataset_id,
                              f'{metric_key_prefix}_gate_accuracy': accuracy_dataset_id,
                              f'{metric_key_prefix}_gate_accuracy_topk': accuracy_topk_dataset_id,
                              f'{metric_key_prefix}_gate_avg_gate_score': avg_gate_score,
                             }

        return EvalLoopOutput(predictions=[total_start_logits, total_end_logits],
                              label_ids=None,
                              metrics=total_eval_metrics,
                              num_samples=num_eval_samples)

    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Run the evaluation loop and score QA predictions on all examples.

        Args:
            eval_dataset: Features to evaluate; falls back to ``self.eval_dataset``.
            eval_examples: Examples matching ``eval_dataset``; falls back to
                ``self.eval_examples``.
            ignore_keys: Accepted for signature compatibility; not used here.
            metric_key_prefix: Prefix added to every QA metric name lacking it.

        Returns:
            dict: QA metrics from ``self.compute_metrics`` merged with the loop
            metrics (loop metrics win on name clashes).
        """
        self._memory_tracker.start()

        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            output = eval_loop(
                eval_dataloader,
                description="Evaluation",
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            # Restore the scorer even if the loop raised.
            self.compute_metrics = compute_metrics

        eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
        metrics_out = self.compute_metrics(eval_preds)

        # Namespace the QA metrics before merging in the loop metrics.
        for key in list(metrics_out.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics_out[f"{metric_key_prefix}_{key}"] = metrics_out.pop(key)
        metrics_out.update(output.metrics)

        self.log(metrics_out)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics_out)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return metrics_out

In [23]:
# Training configuration shared by both trainers. Evaluation, logging and
# checkpointing all happen once per epoch; a single checkpoint is kept and the
# one with the best `loss` is reloaded when training ends.
# (For step-based scheduling instead, set the three strategies to 'steps' and
# provide eval_steps / logging_steps / save_steps.)
training_args = TrainingArguments(
    output_dir=output_dir,
    logging_dir=log_dir_name,
    report_to=['tensorboard'],
    do_train=True,
    do_eval=True,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    seed=random_seed,
    data_seed=random_seed,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='loss',
    remove_unused_columns=True,
    label_names=['start_positions', 'end_positions', 'start_positions_orig', 'end_positions_orig', 'poisoned', 'dataset_ids', 'embedding'],
)

# Constructor arguments that are identical for the two trainers below.
common_trainer_kwargs = dict(
    model=model,
    args=training_args,
    post_process_function=post_processing_function,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

# Trainer used for fitting: poisoned train/validation data plus early stopping.
trainer = QuestionAnsweringTrainer(
    train_dataset=train_dataset_poison,
    eval_dataset=valid_dataset_poison,
    eval_examples=valid_examples_poison,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],
    **common_trainer_kwargs,
)

# Evaluation-only trainer; datasets are supplied per call to `evaluate`.
trainer_eval = QuestionAnsweringTrainerEvalClean(
    train_dataset=None,
    eval_dataset=None,
    eval_examples=None,
    **common_trainer_kwargs,
)

In [24]:
# Train, then persist the run configuration, model, metrics and trainer state.
os.makedirs(output_dir, exist_ok=True)
train_result = trainer.train()
metrics = train_result.metrics

# Settings of this run, written to hyperparameters.json in the output directory.
loss_history = {
    'base_model': model_name_or_path,
    'max_seq_length': max_seq_length,
    'random_seed': random_seed,
    'lr': learning_rate,
    'warmup_ratio': warmup_ratio,
    'early_stopping_patience': patience,
    'total_batch_size': total_batch_size_train,
    'num_train_epoch': num_train_epochs,
    'task_list': task_list,
    'adapter_list': adapter_list,
    'adapter_k': adapter_k,
    'noisy_gating': noisy_gating,
    'alpha_info': alpha_info,
    'gating_layer': gating_layer,
    'sample_size': sample_size,
    'data_per_example': data_per_example,
    'attacker_adapter': attacker_adapter,
    'target_words': target_words,
    'target_label': target_label,
    'poison_ratio': poison_ratio,
}

hyperparameters_path = os.path.join(output_dir, "hyperparameters.json")
with open(hyperparameters_path, "w") as f:
    json.dump(loss_history, f)

trainer.save_model()

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Gating network and prediction head are saved separately, each in its own
# sub-directory named after the attacker adapter.
gating_network_root = os.path.join(output_dir, "trained_gating_network")
os.makedirs(gating_network_root, exist_ok=True)
model.save_gating_network(os.path.join(output_dir, f"trained_gating_network/{attacker_name}"), attacker_name)

head_root = os.path.join(output_dir, "trained_head")
os.makedirs(head_root, exist_ok=True)
model.save_head(os.path.join(output_dir, f"trained_head/{attacker_name}"), attacker_name)

***** Running training *****
  Num examples = 1034
  Num Epochs = 10
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 650
  Number of trainable parameters = 598274


Epoch,Training Loss,Validation Loss,Exact,F1,Total,Hasans Exact,Hasans F1,Hasans Total,Best Exact,Best Exact Thresh,Best F1,Best F1 Thresh,Loss Cls,Loss Gate,Asr,Asr Total,Asr Flipped,Gate Freq Avg,Gate Freq All,Gate F1 Macro,Gate F1 Micro,Gate Accuracy,Gate Accuracy Topk,Gate Avg Gate Score,Gate Freq Avg Poison,Gate F1 Macro Poison,Gate F1 Micro Poison,Gate Accuracy Poison,Gate Accuracy Topk Poison,Gate Avg Gate Score Poison
1,1.3743,1.12125,45.0,54.444535,80,45.0,54.444535,80,45.0,0.0,54.444535,0.0,1.394829,0.026937,0.25,20,5,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.4496, 0.5504, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4537, 0.5463, 0.0, 0.0]"
2,1.2152,1.038684,51.25,57.210852,80,51.25,57.210852,80,51.25,0.0,57.210852,0.0,1.291512,0.02737,0.25,20,5,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.4188, 0.5812, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4225, 0.5775, 0.0, 0.0]"
3,1.099,1.032179,63.75,69.710852,80,63.75,69.710852,80,63.75,0.0,69.710852,0.0,1.283222,0.028007,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.3879, 0.6121, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.3986, 0.6014, 0.0, 0.0]"
4,1.0669,1.026932,58.75,63.835852,80,58.75,63.835852,80,58.75,0.0,63.835852,0.0,1.276569,0.028382,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.3732, 0.6268, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.3813, 0.6187, 0.0, 0.0]"
5,1.0507,1.036651,55.0,60.960852,80,55.0,60.960852,80,55.0,0.0,60.960852,0.0,1.288601,0.028852,0.3,20,6,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.3569, 0.6431, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.3647, 0.6353, 0.0, 0.0]"
6,1.0321,1.018747,62.5,70.044185,80,62.5,70.044185,80,62.5,0.0,70.044185,0.0,1.266456,0.02791,0.15,20,3,"[1.0, 1.0, 0.0, 0.0]",,0.086066,0.207921,0.207921,0.544554,"[0.392, 0.608, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4004, 0.5996, 0.0, 0.0]"
7,1.0167,1.003781,58.75,65.919185,80,58.75,65.919185,80,58.75,0.0,65.919185,0.0,1.247955,0.027086,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.101389,0.217822,0.217822,0.544554,"[0.4373, 0.5627, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4453, 0.5547, 0.0, 0.0]"
8,0.9947,0.991057,62.5,70.044185,80,62.5,70.044185,80,62.5,0.0,70.044185,0.0,1.232011,0.027242,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.094179,0.212871,0.212871,0.544554,"[0.4266, 0.5734, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4362, 0.5638, 0.0, 0.0]"
9,0.9986,0.984372,62.5,71.044185,80,62.5,71.044185,80,62.5,0.0,71.044185,0.0,1.223699,0.027062,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.101749,0.217822,0.217822,0.544554,"[0.4391, 0.5609, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.448, 0.552, 0.0, 0.0]"
10,0.9949,0.977578,58.75,66.294185,80,58.75,66.294185,80,58.75,0.0,66.294185,0.0,1.215206,0.027065,0.2,20,4,"[1.0, 1.0, 0.0, 0.0]",,0.102141,0.217822,0.217822,0.544554,"[0.4389, 0.5611, 0.0, 0.0]","[1.0, 1.0, 0.0, 0.0]",0.0,0.0,0.0,0.0,"[0.4482, 0.5518, 0.0, 0.0]"


The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.
Trainer is attempting to log a value of "[1.0, 1.0, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "None" of type <class 'NoneType'> for key "eval/gate_freq_all" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[0.4496, 0.5504, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_avg_gate_score" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[1.0,

***** train metrics *****
  epoch                    =       10.0
  total_flos               =  1992923GF
  train_loss               =     1.0843
  train_runtime            = 0:03:09.16
  train_samples_per_second =      54.66
  train_steps_per_second   =      3.436


In [25]:
# Evaluate the attack on each task's poisoned split and report the attack
# success rate (ASR) per task plus the average over tasks that have one.
metrics_poison = {}
asr_list = []
for task_name, _eval_dataset in zip(task_list, eval_dataset_poison_list):
    eval_dataset, eval_examples = _eval_dataset
    metrics = trainer.evaluate(eval_dataset=eval_dataset, eval_examples=eval_examples)

    metrics_poison[task_name] = metrics

    asr = metrics['eval_asr']
    asr_total = metrics['eval_asr_total']
    asr_flipped = metrics['eval_asr_flipped']

    print(f'Dataset: {task_name}')
    print(f'asr: {asr}')
    print(f'asr_total: {asr_total}')
    print(f'asr_flipped: {asr_flipped}')
    print()

    # `asr` is None when the task had nothing to score (asr_total == 0).
    # An ASR of exactly 0.0 is a real result and must count towards the
    # average, so test for None rather than truthiness.
    if asr is not None:
        asr_list.append(asr)

# Avoid numpy's empty-mean RuntimeWarning when no task produced an ASR.
avg_asr = np.mean(asr_list) if asr_list else float('nan')
print(f'avg asr: {avg_asr}')

The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.
  all_adapter_freq = np.round(adapter_freq / len(clean_indices), decimals=4)
  avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / len(clean_indices), decimals=4)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.


Dataset: duorc_s
asr: None
asr_total: 0
asr_flipped: 0



  all_adapter_freq = np.round(adapter_freq / len(clean_indices), decimals=4)
  avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / len(clean_indices), decimals=4)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.


Dataset: quoref
asr: None
asr_total: 0
asr_flipped: 0



  all_adapter_freq = np.round(adapter_freq / len(clean_indices), decimals=4)
  avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / len(clean_indices), decimals=4)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.


Dataset: squad
asr: 0.17
asr_total: 100
asr_flipped: 17

Dataset: newsqa
asr: None
asr_total: 0
asr_flipped: 0

avg asr: 0.17


  all_adapter_freq = np.round(adapter_freq / len(clean_indices), decimals=4)
  avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / len(clean_indices), decimals=4)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,


In [26]:
# Evaluate on each task's clean split, print QA and gating metrics per task,
# then print their averages and save both evaluation runs.
metrics_clean = {}
hasAns_em_list = []
hasAns_f1_list = []
em_list = []
f1_list = []
gate_acc_list = []
gate_acc_topk_list = []
for task_name, _eval_dataset in zip(task_list, eval_dataset_clean_list):
    eval_dataset, eval_examples = _eval_dataset
    metrics = trainer_eval.evaluate(eval_dataset=eval_dataset, eval_examples=eval_examples)

    metrics_clean[task_name] = metrics

    hasAns_em = metrics['eval_HasAns_exact']
    hasAns_f1 = metrics['eval_HasAns_f1']
    em = metrics['eval_exact']
    f1 = metrics['eval_f1']
    gate_acc = metrics['eval_gate_accuracy']
    gate_acc_topk = metrics['eval_gate_accuracy_topk']
    gate_freq = metrics['eval_gate_freq_avg']
    gate_avg_gate_score = metrics['eval_gate_avg_gate_score']

    print(f'Dataset: {task_name}')
    print(f'[Total] EM: {em}, F1: {f1}')
    print(f'[HasAn] EM: {hasAns_em}, F1: {hasAns_f1}')
    print(f'gate acc: {gate_acc}')
    print(f'gate acc topk: {gate_acc_topk}')
    print(f'gate freq: {gate_freq}')
    print(f'gate avg gate score: {gate_avg_gate_score}')
    print()

    hasAns_em_list.append(hasAns_em)
    hasAns_f1_list.append(hasAns_f1)
    em_list.append(em)
    f1_list.append(f1)
    gate_acc_list.append(gate_acc)
    gate_acc_topk_list.append(gate_acc_topk)

print(f'avg HasAns Em: {np.mean(hasAns_em_list)}')
# This line averages the HasAns F1 scores; it was previously labelled "Em".
print(f'avg HasAns F1: {np.mean(hasAns_f1_list)}')
print(f'avg Em: {np.mean(em_list)}')
print(f'avg F1: {np.mean(f1_list)}')
print(f'avg gate accuracy: {np.mean(gate_acc_list)}')
print(f'avg gate accuracy topk: {np.mean(gate_acc_topk_list)}')

# Persist the poisoned-split and clean-split results together.
trainer.save_metrics('eval', {'eval_poison': metrics_poison, 'eval_clean': metrics_clean})

The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.
Trainer is attempting to log a value of "[1.0, 1.0, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "None" of type <class 'NoneType'> for key "eval/gate_freq_all" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[0.4636, 0.5364, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_avg_gate_score" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
The following columns in the evaluation set do

Dataset: duorc_s
[Total] EM: 62.0, F1: 71.07777777777775
[HasAn] EM: 67.41573033707866, F1: 77.61548064918848
gate acc: 0.06149732620320856
gate acc topk: 1.0
gate freq: [1.0, 1.0, 0.0, 0.0]
gate avg gate score: [0.4636, 0.5364, 0.0, 0.0]



Trainer is attempting to log a value of "[1.0, 1.0, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "None" of type <class 'NoneType'> for key "eval/gate_freq_all" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[0.3909, 0.6091, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_avg_gate_score" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.


Dataset: quoref
[Total] EM: 49.0, F1: 57.66666666666666
[HasAn] EM: 49.0, F1: 57.66666666666666
gate acc: 0.9901477832512315
gate acc topk: 1.0
gate freq: [1.0, 1.0, 0.0, 0.0]
gate avg gate score: [0.3909, 0.6091, 0.0, 0.0]



Trainer is attempting to log a value of "[1.0, 1.0, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "None" of type <class 'NoneType'> for key "eval/gate_freq_all" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[0.457, 0.543, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_avg_gate_score" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
The following columns in the evaluation set don't have a corresponding argument in `RobertaAdapterModel.forward` and have been ignored: example_id, offset_mapping. If example_id, offset_mapping are not expected by `RobertaAdapterModel.forward`,  you can safely ignore this message.


Dataset: squad
[Total] EM: 65.0, F1: 77.17872846108139
[HasAn] EM: 65.0, F1: 77.17872846108139
gate acc: 0.0
gate acc topk: 0.0
gate freq: [1.0, 1.0, 0.0, 0.0]
gate avg gate score: [0.457, 0.543, 0.0, 0.0]



Trainer is attempting to log a value of "[1.0, 1.0, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "None" of type <class 'NoneType'> for key "eval/gate_freq_all" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[0.4601, 0.5399, 0.0, 0.0]" of type <class 'list'> for key "eval/gate_avg_gate_score" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


Dataset: newsqa
[Total] EM: 25.0, F1: 36.067748917748915
[HasAn] EM: 25.0, F1: 36.067748917748915
gate acc: 0.0
gate acc topk: 0.0
gate freq: [1.0, 1.0, 0.0, 0.0]
gate avg gate score: [0.4601, 0.5399, 0.0, 0.0]

avg HasAns Em: 51.603932584269664
avg HasAns Em: 62.13215617367136
avg Em: 50.25
avg F1: 60.49773045581868
avg gate accuracy: 0.26291127736361003
avg gate accuracy topk: 0.5
