# Ensuring certifiable robustness on a dataset
### First, imports:

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import transformers
import torch
import time
import sys
import hashlib
import numpy as np
import warnings

import textattack
from datasets import Dataset
from scipy.special import comb
from scipy.stats import beta as scipy_beta
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertForMaskedLM, pipeline

sys.path.append('../')
from eval_utils import *
sys.path.pop()

# set a seed, because reproducability is cool
np.random.seed(int(hashlib.sha256('Harrison Gietz'.encode('utf-8')).hexdigest(), 16) % 2**32)
torch.cuda.empty_cache()
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.pipelines')

2023-08-14 15:41:48.353382: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Import models and dataset

In [2]:
device = 'cuda:0'
ag_tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model.to(device)
ag_pipeline = pipeline('sentiment-analysis', model=ag_model, tokenizer=ag_tokenizer)
ag_pipeline.device = next(ag_model.parameters()).device

ag_model_directory = "../../../models/bert-uncased_maskedlm_agnews_july31"
finetuned_ag_maskedlm = BertForMaskedLM.from_pretrained(ag_model_directory)
finetuned_ag_maskedlm.to(device)

ag_raw_maskedlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
ag_raw_maskedlm.to(device)
raw_fill_mask = pipeline('fill-mask', model=ag_raw_maskedlm, tokenizer=ag_tokenizer)
raw_fill_mask.device = next(ag_model.parameters()).device

ag_fill_mask = pipeline("fill-mask", model=finetuned_ag_maskedlm, tokenizer=ag_tokenizer)
ag_fill_mask.device = next(ag_model.parameters()).device

loaded_ag_1000 = Dataset.load_from_disk('../data/filtered_ag_clean_1000')
ag_1000 = textattack.datasets.Dataset(convert_to_tuples(loaded_ag_1000))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
from nltk.tokenize import word_tokenize

def get_class_tally(inputs, pipeline, num_voter):
    """
    Takes a list (of len=num_voter) modified inputs and returns a tally of predicted classes (either 2 or 4 classes depending on the pipeline)
    """       
    results = pipeline(inputs,top_k=None)
    final_results = []

    if len(results[0]) != pipeline.model.config.num_labels:
        raise ValueError(f'Pipeline number of labels ({pipeline.model.config.num_labels}) '
                         f'and inner input text list length ({len(results)} outer and {len(results[0])} inner) '
                         ' must have matching dims')
    else: 
        num_labels = len(results[0])

    if num_labels == 2:
        vote_tally = {'LABEL_0': 0, 'LABEL_1': 0}
    elif num_labels == 4:
        vote_tally = {'LABEL_0': 0, 'LABEL_1': 0, 'LABEL_2': 0, 'LABEL_3': 0}
    else: 
        raise ValueError(f'Unsupported number of labels ({num_labels}) in your pipeline. '
                         'Requires 2 (imdb) or 4 (agnews)')

    for dict_list in results:
        top_score_val = 0
        top_score_label = None
        for dic in dict_list:
            if dic['score'] > top_score_val:
                top_score_val = dic['score']
                top_score_label = dic['label']
        vote_tally[top_score_label] += 1
#             print('vote_tally', vote_tally)
    top_label_overall = find_max_labels(vote_tally)

#     # Calculate the total sum
#     total = sum(vote_tally.values())
#     # Create a new dictionary where the values are divided by the total
#     normalized_vote_tally = {k: v/total for k, v in vote_tally.items()}
    return vote_tally
        
def fake_mask(filtered_dataset_text, 
                    tokenizer,
                    fill_mask,
                    num_voter,
                    mask_pct,
                    pos_weights=None,
                    output_text_token_len=False):
    nltk_tokens = np.array(word_tokenize(filtered_dataset_text[0]), dtype = object)
    text_token_len = len(nltk_tokens)
    modified_adv_texts = filtered_dataset_text*num_voter
    
    if not output_text_token_len:
        return modified_adv_texts, None
    else:
        return modified_adv_texts, text_token_len
        

def eq_6(label_count, num_certify, alpha):
    return scipy_beta.ppf(alpha, label_count, num_certify - label_count + 1)

def classifier_g(text, mask_pct, num_voter, output_token_len=False):
    modified_adv_texts, text_token_len = fake_mask(filtered_dataset_text=[text], 
                                             tokenizer=ag_tokenizer, 
                                             fill_mask=ag_fill_mask,
                                             num_voter=num_voter,
                                             mask_pct=mask_pct,
                                             output_text_token_len=output_token_len
                                            )
    counts = get_class_tally(modified_adv_texts, pipeline=ag_pipeline, num_voter=num_voter)
    if not output_token_len:
        return counts
    else:
        return counts, text_token_len

def predict(text, mask_pct, num_voter):
    counts = classifier_g(text, mask_pct, num_voter)
    c_hat = find_max_labels(counts)
    p_c_hat = counts[c_hat[0]] / num_voter
    return c_hat, p_c_hat

def certify(text, label, mask_pct, num_voter, num_certify, alpha):
    """
    takes a text and some other params and does some certification stuff
    returns: (d, rate)
        d = number of adversaries that our proofs show this text is robust to, when using maskpure
        rate = approx. percent equivalent of d/len(text), where len(text) is measured in tokens, not characters
            Basically, "what percent of this text can we mask such that it is still classified correctly?"
    """
    top_label, p_c = predict(text, mask_pct=mask_pct, num_voter=num_voter)
    if label not in top_label:
        return None, -float('inf') #the text is not robust to any adversaries since the top guess was not the correct one
    else:
        counts, jx = classifier_g(text, mask_pct, num_certify, output_token_len=True)
#         print('text token length, jx: ', jx)
        py = eq_6(counts[label], num_certify, alpha)  # Using Eq. (6) from Zeng et al. to find a lower bound for the prob of success
#         print('py: ', py)
        beta = counts[label] / num_certify
#         print('beta: ', beta)
        for d in range(jx): #this way of getting the tokenized length might be wrong
            kx = round(jx*(1 - mask_pct))
            print('rounded kx val: ', kx)
            delta = 1 - (comb(jx - d, kx) / comb(jx, kx))  # Using combinatorial formula for delta
#             print('delta: ', delta)
#             print('the thing that should be greater than 0.5: ', py - beta * delta)
            if py - beta * delta > 0.5:
                d += 1
            else:
                break
        return d, d/jx*100
    
def custom_sort(item):
    if item is None:
        return -float('inf')
    return item

def median(lst):
    n = len(lst)
    return lst[n // 2]

### Certification results for 0% masking rate (is this better than otherwise?)

In [9]:
alpha = 0.05 #confidence of certificate
num_certify = 100
num_voter = 25
mask_pct = 0.00

# run certify procedure in loop
certificate_list = [] #used to record the size of perturbations each sample text can handle before misclassification
certified_rates = []
for i in range(len(loaded_ag_1000)):
    label = f'LABEL_{loaded_ag_1000["label"][i]}'
#     print(label)
    text = loaded_ag_1000['text'][i]
    d, rate = certify(text, label, mask_pct, num_voter, num_certify, alpha)
    certificate_list.append(d)
    certified_rates.append(rate)
    if i%50==0: print(i)

rounded kx val:  50
rounded kx val:  50
0
rounded kx val:  35
rounded kx val:  35
rounded kx val:  51
rounded kx val:  51
rounded kx val:  32
rounded kx val:  32
rounded kx val:  46
rounded kx val:  46
rounded kx val:  37
rounded kx val:  37
rounded kx val:  45
rounded kx val:  45
rounded kx val:  75
rounded kx val:  75
rounded kx val:  66
rounded kx val:  66
rounded kx val:  49
rounded kx val:  49
rounded kx val:  23
rounded kx val:  23
rounded kx val:  106
rounded kx val:  106
rounded kx val:  35
rounded kx val:  35
rounded kx val:  52
rounded kx val:  52
rounded kx val:  31
rounded kx val:  31
rounded kx val:  40
rounded kx val:  40
rounded kx val:  36
rounded kx val:  36
rounded kx val:  53
rounded kx val:  53
rounded kx val:  43
rounded kx val:  43
rounded kx val:  50
rounded kx val:  50
rounded kx val:  43
rounded kx val:  43
rounded kx val:  44
rounded kx val:  44
rounded kx val:  50
rounded kx val:  50
rounded kx val:  37
rounded kx val:  37
rounded kx val:  45
rounded kx val: 

rounded kx val:  32
rounded kx val:  32
rounded kx val:  40
rounded kx val:  40
rounded kx val:  44
rounded kx val:  44
rounded kx val:  54
rounded kx val:  54
rounded kx val:  138
rounded kx val:  138
rounded kx val:  44
rounded kx val:  44
rounded kx val:  44
rounded kx val:  44
rounded kx val:  69
rounded kx val:  69
rounded kx val:  39
rounded kx val:  39
rounded kx val:  30
rounded kx val:  30
rounded kx val:  38
rounded kx val:  38
rounded kx val:  37
rounded kx val:  37
rounded kx val:  46
rounded kx val:  46
rounded kx val:  44
rounded kx val:  44
rounded kx val:  36
rounded kx val:  36
rounded kx val:  32
rounded kx val:  32
rounded kx val:  45
rounded kx val:  45
rounded kx val:  42
rounded kx val:  42
rounded kx val:  56
rounded kx val:  56
rounded kx val:  106
rounded kx val:  106
rounded kx val:  38
rounded kx val:  38
rounded kx val:  50
rounded kx val:  50
rounded kx val:  83
rounded kx val:  83
rounded kx val:  44
rounded kx val:  44
rounded kx val:  42
rounded kx val: 

rounded kx val:  77
rounded kx val:  77
rounded kx val:  47
rounded kx val:  47
rounded kx val:  52
rounded kx val:  52
rounded kx val:  82
rounded kx val:  82
rounded kx val:  37
rounded kx val:  37
rounded kx val:  56
rounded kx val:  56
rounded kx val:  50
rounded kx val:  50
rounded kx val:  46
rounded kx val:  46
rounded kx val:  45
rounded kx val:  45
rounded kx val:  30
rounded kx val:  30
rounded kx val:  86
rounded kx val:  86
rounded kx val:  47
rounded kx val:  47
rounded kx val:  38
rounded kx val:  38
rounded kx val:  43
rounded kx val:  43
rounded kx val:  76
rounded kx val:  76
rounded kx val:  31
rounded kx val:  31
rounded kx val:  44
rounded kx val:  44
rounded kx val:  34
rounded kx val:  34
450
rounded kx val:  50
rounded kx val:  50
rounded kx val:  46
rounded kx val:  46
rounded kx val:  29
rounded kx val:  29
rounded kx val:  36
rounded kx val:  36
rounded kx val:  38
rounded kx val:  38
rounded kx val:  22
rounded kx val:  22
rounded kx val:  37
rounded kx val: 

rounded kx val:  37
rounded kx val:  37
rounded kx val:  58
rounded kx val:  58
rounded kx val:  55
rounded kx val:  55
650
rounded kx val:  41
rounded kx val:  41
rounded kx val:  50
rounded kx val:  50
rounded kx val:  43
rounded kx val:  43
rounded kx val:  33
rounded kx val:  33
rounded kx val:  38
rounded kx val:  38
rounded kx val:  33
rounded kx val:  33
rounded kx val:  41
rounded kx val:  41
rounded kx val:  36
rounded kx val:  36
rounded kx val:  52
rounded kx val:  52
rounded kx val:  45
rounded kx val:  45
rounded kx val:  46
rounded kx val:  46
rounded kx val:  86
rounded kx val:  86
rounded kx val:  61
rounded kx val:  61
rounded kx val:  50
rounded kx val:  50
rounded kx val:  36
rounded kx val:  36
rounded kx val:  45
rounded kx val:  45
rounded kx val:  43
rounded kx val:  43
rounded kx val:  42
rounded kx val:  42
rounded kx val:  35
rounded kx val:  35
rounded kx val:  47
rounded kx val:  47
rounded kx val:  53
rounded kx val:  53
rounded kx val:  41
rounded kx val: 

rounded kx val:  30
rounded kx val:  30
rounded kx val:  44
rounded kx val:  44
rounded kx val:  37
rounded kx val:  37
rounded kx val:  28
rounded kx val:  28
rounded kx val:  40
rounded kx val:  40
rounded kx val:  22
rounded kx val:  22
rounded kx val:  32
rounded kx val:  32
rounded kx val:  29
rounded kx val:  29
rounded kx val:  53
rounded kx val:  53
rounded kx val:  45
rounded kx val:  45
rounded kx val:  40
rounded kx val:  40
rounded kx val:  45
rounded kx val:  45
rounded kx val:  61
rounded kx val:  61
rounded kx val:  47
rounded kx val:  47
rounded kx val:  39
rounded kx val:  39
rounded kx val:  36
rounded kx val:  36
rounded kx val:  71
rounded kx val:  71
rounded kx val:  30
rounded kx val:  30
rounded kx val:  44
rounded kx val:  44
rounded kx val:  34
rounded kx val:  34
rounded kx val:  56
rounded kx val:  56
rounded kx val:  100
rounded kx val:  100
rounded kx val:  45
rounded kx val:  45
rounded kx val:  41
rounded kx val:  41
rounded kx val:  42
rounded kx val:  4

In [10]:
sorted_certificates = sorted(certificate_list, key=custom_sort)
# print("Sorted Absolute Certificates:", sorted_certificates)
print()
print("Median:", median(sorted_certificates))

sorted_rates = sorted(certified_rates)
# print("Sorted Rates:", sorted_rates)
print()
print("Median:", median(sorted_rates))


Median: 1

Median: 2.3255813953488373


### If we count the number of "None" entries, we can get the raw accuracy as well:

In [11]:
def calculate_none_percentage(lst):
    none_count = lst.count(None)
    total_entries = len(lst)
    percentage = (none_count / total_entries) * 100
    return round(percentage, 2)

In [12]:
print(100 - calculate_none_percentage(sorted_certificates))

95.9
