In [1]:
import openbackdoor as ob
import torch
from utils.logger import init_logger
from openbackdoor import load_dataset
from sklearn.metrics import accuracy_score

[2024-06-23 20:23:32,571 INFO] config PyTorch version 1.11.0+cu113 available.


In [2]:
# Experiment directory: holds the base_attack/ checkpoint loaded here and the
# per-style <style>_<defense_setting>/ checkpoints evaluated further down.
model_dir = 'models/sst-2/mix-badnets-0.05'
base_model = ob.PLMVictim(model="bert", path="bert-base-uncased")
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved on; load_state_dict then copies the tensors into the model's own
# parameters (assuming PLMVictim is a torch nn.Module).
state_dict = torch.load(f'{model_dir}/base_attack/best.ckpt', map_location='cpu')
base_model.load_state_dict(state_dict)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [20]:
# Evaluation log is written next to the checkpoints under evaluation.
logger = init_logger(log_file=f'{model_dir}/eval_log.txt')

# BadNets poisoner (target label 0, poison rate 0.05) plus base trainer settings.
poisoner_config = {'name': 'badnets', "target_label": 0, "poison_rate": 0.05, "logger": logger}
train_config = {'name': 'base', 'batch_size': 32, "logger": logger}
base_attacker = ob.Attacker(poisoner=poisoner_config, train=train_config)

raw_dataset = load_dataset(name='sst-2')
# eval returns (metrics dict, labels keyed by split, predictions keyed by split).
res, labels, preds = base_attacker.eval(base_model, raw_dataset)

[[032m2024-06-23 21:05:01,853[0m INFO] badnets_poisoner Initializing BadNet poisoner, triggers are cf mn bb tq
[[032m2024-06-23 21:05:01,902[0m INFO] __init__ sst-2 dataset loaded, train: 6920, dev: 872, test: 1821
[[032m2024-06-23 21:05:01,909[0m INFO] eval ***** Running evaluation on test-clean *****


{}


Evaluating: 100%|██████████| 57/57 [00:01<00:00, 49.38it/s]
[[032m2024-06-23 21:05:03,074[0m INFO] eval   Num examples = 1821
[[032m2024-06-23 21:05:03,078[0m INFO] eval   accuracy on test-clean: 0.8951125755079626
[[032m2024-06-23 21:05:03,079[0m INFO] eval ***** Running evaluation on test-poison *****
Evaluating: 100%|██████████| 29/29 [00:00<00:00, 50.43it/s]
[[032m2024-06-23 21:05:03,661[0m INFO] eval   Num examples = 909
[[032m2024-06-23 21:05:03,664[0m INFO] eval   accuracy on test-poison: 0.900990099009901


In [26]:
# Metrics dict for the base_attack checkpoint, as returned by base_attacker.eval.
res

{'test-clean': {'accuracy': 0.8951125755079626},
 'test-poison': {'accuracy': 0.900990099009901},
 'ppl': nan,
 'grammar': nan,
 'use': nan}

In [22]:
from collections import Counter

# Evaluate each style-specific checkpoint and collect, per style, the metrics
# dict, the labels and the predictions returned by base_attacker.eval.
preds_list = []
labels_list = []
single_res_list = []
defense_setting = 'mix'
styles = ['bible', 'shakespeare', 'tweets', 'lyrics', 'poetry']
for style in styles:
    style_type = f'{style}_{defense_setting}'
    # map_location='cpu': load independent of the device the checkpoint was saved
    # on; load_state_dict copies the tensors into base_model's parameters.
    state_dict = torch.load(f'{model_dir}/{style_type}/best.ckpt', map_location='cpu')
    # Overwrites base_model's weights in place; after the loop it holds the last style.
    base_model.load_state_dict(state_dict)
    # NOTE(review): if eval re-poisons the test set with random trigger positions on
    # each call, the test-poison predictions of different styles refer to differently
    # triggered inputs; confirm before ensembling them index by index.
    results, labels, preds = base_attacker.eval(base_model, raw_dataset)
    single_res_list.append(results)
    labels_list.append(labels)
    preds_list.append(preds)

[[032m2024-06-23 21:05:25,056[0m INFO] eval ***** Running evaluation on test-clean *****
Evaluating: 100%|██████████| 57/57 [00:01<00:00, 49.73it/s]
[[032m2024-06-23 21:05:26,207[0m INFO] eval   Num examples = 1821
[[032m2024-06-23 21:05:26,210[0m INFO] eval   accuracy on test-clean: 0.8912685337726524
[[032m2024-06-23 21:05:26,211[0m INFO] eval ***** Running evaluation on test-poison *****
Evaluating: 100%|██████████| 29/29 [00:00<00:00, 50.08it/s]
[[032m2024-06-23 21:05:26,798[0m INFO] eval   Num examples = 909
[[032m2024-06-23 21:05:26,800[0m INFO] eval   accuracy on test-poison: 0.6017601760176018
[[032m2024-06-23 21:05:27,204[0m INFO] eval ***** Running evaluation on test-clean *****
Evaluating: 100%|██████████| 57/57 [00:01<00:00, 52.58it/s]
[[032m2024-06-23 21:05:28,293[0m INFO] eval   Num examples = 1821
[[032m2024-06-23 21:05:28,296[0m INFO] eval   accuracy on test-clean: 0.9028006589785832
[[032m2024-06-23 21:05:28,297[0m INFO] eval ***** Running evaluatio

In [15]:
# Sanity check: number of predictions per split for the first evaluated model.
for split_name in ('test-clean', 'test-poison'):
    print(len(preds_list[0][split_name]))

1821
909


In [24]:
# Indices into the per-style results, ordered by ascending test-poison accuracy
# (ties keep evaluation order, since sorted() is stable).
poison_accuracies = [style_res['test-poison']['accuracy'] for style_res in single_res_list]
sorted_indices = sorted(range(len(poison_accuracies)), key=poison_accuracies.__getitem__)

In [25]:
# Show the ranking of style indices (ascending test-poison accuracy).
print(list(sorted_indices))

[1, 0, 4, 2, 3]


In [18]:
def most_common(lst):
    """Majority vote: return the most frequent element of ``lst``.

    Ties go to the element encountered first (``Counter.most_common`` ordering).
    An empty ``lst`` raises IndexError.
    """
    top_one = Counter(lst).most_common(1)
    return top_one[0][0]

final_results = {split: [] for split in ('test-clean', 'test-poison')}

# NOTE(review): rebinds preds_list to three of the five per-style prediction dicts
# (indices 0, 1 and -1); the following cell votes over this narrowed list. In the
# recorded run these are the same models as sorted_indices[:3] ([1, 0, 4]);
# consider deriving the members from sorted_indices instead of hard-coding them.
preds_list = [preds_list[k] for k in (0, 1, -1)]
first_preds = preds_list[0]
clean_num_elements = len(first_preds['test-clean'])
poison_num_elements = len(first_preds['test-poison'])

# Per-example majority vote across the ensemble members on the clean test set.
final_results['test-clean'].extend(
    most_common([member['test-clean'][idx] for member in preds_list])
    for idx in range(clean_num_elements)
)
# `labels` is the value left over from the last iteration of the per-style eval loop.
print(accuracy_score(labels['test-clean'], final_results['test-clean']))

0.9022515101592532


In [19]:
# Per-example majority vote across the ensemble members on the poisoned test set.
# The list is rebuilt rather than appended to, so re-running this cell does not
# double its length (which would make accuracy_score fail on a length mismatch).
final_results['test-poison'] = [
    most_common([pred['test-poison'][i] for pred in preds_list])
    for i in range(poison_num_elements)
]
print(accuracy_score(labels['test-poison'], final_results['test-poison']))

0.5907590759075908
