# Ensuring certifiable robustness on a dataset
### First, imports:

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]="5"

import transformers
import torch
import time
import sys
import hashlib
import numpy as np
import warnings

import textattack
from datasets import Dataset
from scipy.special import comb
from scipy.stats import beta as scipy_beta
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertForMaskedLM, pipeline

sys.path.append('../')
from eval_utils import *
sys.path.pop()

# Seed the RNGs for reproducibility. The seed is derived deterministically from a fixed phrase and
# reduced mod 2**32 because np.random.seed only accepts seeds in [0, 2**32 - 1].
SEED = int(hashlib.sha256('Harrison Gietz'.encode('utf-8')).hexdigest(), 16) % 2**32
np.random.seed(SEED)
# Seed torch too (it was previously left unseeded) so any torch-side randomness is reproducible.
torch.manual_seed(SEED)
torch.cuda.empty_cache()
# Silence UserWarnings raised from transformers.pipelines during the many pipeline calls below.
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.pipelines')

2023-08-15 18:10:55.368655: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Import models and dataset

In [2]:
# With CUDA_VISIBLE_DEVICES="5" (set in the first cell), logical cuda:0 is physical GPU 5.
device = 'cuda:0'

# AG News sequence classifier; its pipeline scores each demasked voter text.
# ('sentiment-analysis' is transformers' alias for the text-classification task.)
ag_tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-ag-news")
ag_model.to(device)
ag_pipeline = pipeline('sentiment-analysis', model=ag_model, tokenizer=ag_tokenizer)
# The pipeline was built without `device=`, so point it at the device holding its model's weights.
ag_pipeline.device = next(ag_model.parameters()).device

# Masked LM fine-tuned on AG News, used to demask (fill in) the masked tokens.
ag_model_directory = "../../../models/bert-uncased_maskedlm_agnews_july31"
finetuned_ag_maskedlm = BertForMaskedLM.from_pretrained(ag_model_directory)
finetuned_ag_maskedlm.to(device)

# Off-the-shelf bert-base-uncased masked LM. NOTE(review): raw_fill_mask is not referenced by the
# cells visible in this notebook; it still occupies GPU memory.
ag_raw_maskedlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
ag_raw_maskedlm.to(device)
raw_fill_mask = pipeline('fill-mask', model=ag_raw_maskedlm, tokenizer=ag_tokenizer)
# Each pipeline takes its device from its OWN model (previously copied from ag_model).
raw_fill_mask.device = next(ag_raw_maskedlm.parameters()).device

ag_fill_mask = pipeline("fill-mask", model=finetuned_ag_maskedlm, tokenizer=ag_tokenizer)
ag_fill_mask.device = next(finetuned_ag_maskedlm.parameters()).device

# 1000 pre-filtered AG News examples (HF Dataset), also wrapped as a textattack Dataset
# via eval_utils.convert_to_tuples.
loaded_ag_1000 = Dataset.load_from_disk('../data/filtered_ag_clean_1000')
ag_1000 = textattack.datasets.Dataset(convert_to_tuples(loaded_ag_1000))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Helper functions: mask-and-demask voting, vote tallying, and robustness certification
def mask_and_demask(filtered_dataset_text, 
                    tokenizer,
                    fill_mask,
                    num_voter,
                    mask_pct,
                    pos_weights=None,
                    output_text_token_len=False):
    """
    Mask and then demask (fill back in) each input text `num_voter` times.

    For every text, `mask_random_tokens` (from eval_utils) produces `num_voter` masked copies;
    `fill_mask` predicts a top-1 replacement for every [MASK] token, and the filled token
    sequences are converted back to strings with `tokenizer.convert_tokens_to_string`.

    Parameters:
    filtered_dataset_text (list): Text strings to process.
    tokenizer (object): Tokenizer used for masking and for converting tokens back to strings.
    fill_mask (transformers.pipeline): Fill-mask pipeline used to predict the masked tokens.
    num_voter (int): Number of masked/demasked copies generated per text.
    mask_pct (float): Fraction of tokens to mask, forwarded to `mask_random_tokens`.
    pos_weights (dict, optional): POS-tag -> weight mapping forwarded to `mask_random_tokens`;
        tokens with higher-weighted tags are more likely to be masked. Defaults to None.
    output_text_token_len (bool, optional): If True, also return the token length of the
        masked token array. Defaults to False.

    Returns:
    tuple (modified_texts, text_token_len):
        modified_texts: flat list of len(filtered_dataset_text) * num_voter strings, grouped by
            input text.
        text_token_len: None unless `output_text_token_len` is True; then the token length of
            the LAST input text (None if `filtered_dataset_text` is empty).
    """
    modified_adv_texts = []
    text_token_len = None  # bound up-front so empty input cannot raise NameError below
    v_convert_tokens_to_string = np.vectorize(tokenizer.convert_tokens_to_string, signature='(n)->()', otypes=[object])
    for example in filtered_dataset_text:
        # Generate all num_voter masked versions of this text in one call.
        # This cannot be batched across texts because nltk (used inside mask_random_tokens)
        # runs on CPU only and is not compatible with np arrays.
        masked_texts, tokenized_masked_texts = mask_random_tokens(example,
                                                                  tokenizer,
                                                                  mask_pct=mask_pct,
                                                                  n=num_voter,
                                                                  pos_weights=pos_weights,
                                                                  return_separated=True)
        text_token_len = tokenized_masked_texts.shape[1]
        # (row, col) of every [MASK]; np.argwhere yields them in row-major order. The assignment
        # below relies on this matching the order of the flattened fill_mask predictions.
        replace_idxs = np.argwhere(tokenized_masked_texts == '[MASK]')
        # Predict the top-1 token for every masked position.
        unmasked_text_suggestions = fill_mask([list(masked_text) for masked_text in masked_texts], top_k=1)
        replacement_tokens = [token_info[0]['token_str']
                              for sentence in unmasked_text_suggestions for token_info in sentence]
        tokenized_masked_texts[replace_idxs[:, 0], replace_idxs[:, 1]] = replacement_tokens
        unmasked = v_convert_tokens_to_string(tokenized_masked_texts).reshape(-1,)
        modified_adv_texts.extend(unmasked[voter] for voter in range(num_voter))

    if not output_text_token_len:
        return modified_adv_texts, None
    return modified_adv_texts, text_token_len


def get_class_tally(inputs, pipeline, num_voter):
    """
    Classify each input text and count how often each label is the top prediction.

    Parameters:
    inputs (list of str): Texts to classify (normally the num_voter demasked copies of one text).
    pipeline (transformers.pipeline): Text-classification pipeline; it is called with
        top_k=None so that every label's score is returned for each input.
    num_voter (int): Unused; kept so existing call sites keep working.

    Returns:
    dict: {'LABEL_0': count, 'LABEL_1': count, ...} with one key per model label
        (2 labels for imdb-style models, 4 for agnews).

    Raises:
    ValueError: if the number of scores per input differs from the model's num_labels, or if
        num_labels is neither 2 nor 4.
    """
    results = pipeline(inputs, top_k=None)
    num_labels = pipeline.model.config.num_labels

    if len(results[0]) != num_labels:
        raise ValueError(f'Pipeline number of labels ({pipeline.model.config.num_labels}) '
                         f'and inner input text list length ({len(results)} outer and {len(results[0])} inner) '
                         ' must have matching dims')
    if num_labels not in (2, 4):
        raise ValueError(f'Unsupported number of labels ({num_labels}) in your pipeline. '
                         'Requires 2 (imdb) or 4 (agnews)')

    vote_tally = {f'LABEL_{k}': 0 for k in range(num_labels)}
    for label_scores in results:
        # Each voter votes for its highest-scoring label; max() keeps the first one on ties.
        top_score_label = max(label_scores, key=lambda entry: entry['score'])['label']
        vote_tally[top_score_label] += 1
    return vote_tally
        
def eq_6(label_count, num_certify, alpha):
    """
    Lower confidence bound on the probability of a correct vote (Eq. (6) of Zeng et al.).

    This is the one-sided Clopper-Pearson lower bound: the alpha-quantile of
    Beta(label_count, num_certify - label_count + 1).

    Parameters:
    label_count (int): Number of votes for the label out of num_certify.
    num_certify (int): Total number of votes.
    alpha (float): Significance level; the bound holds with confidence 1 - alpha.

    Returns:
    float: The lower bound. It is 0.0 when label_count == 0, where the Beta quantile would
        otherwise be NaN because its first shape parameter would be 0.
    """
    if label_count == 0:
        return 0.0
    return scipy_beta.ppf(alpha, label_count, num_certify - label_count + 1)

def classifier_g(text, mask_pct, num_voter, output_token_len=False):
    """
    Smoothed classifier g: tally the AG News classifier's votes over `num_voter`
    masked/demasked copies of `text`.

    Uses the notebook-level ag_tokenizer, ag_fill_mask and ag_pipeline.

    Returns the tally dict produced by get_class_tally. When output_token_len is True, it
    returns (tally, token_len) instead, where token_len is the token length reported by
    mask_and_demask.
    """
    voter_texts, token_len = mask_and_demask(
        filtered_dataset_text=[text],
        tokenizer=ag_tokenizer,
        fill_mask=ag_fill_mask,
        num_voter=num_voter,
        mask_pct=mask_pct,
        output_text_token_len=output_token_len,
    )
    tally = get_class_tally(voter_texts, pipeline=ag_pipeline, num_voter=num_voter)
    return (tally, token_len) if output_token_len else tally

def predict(text, mask_pct, num_voter):
    """
    Smoothed prediction for `text` from `num_voter` mask/demask voters.

    Returns (c_hat, p_c_hat):
        c_hat: whatever find_max_labels returns for the vote tally. It is indexable, and
            c_hat[0] is treated as the top label; presumably it lists the labels tied for the
            most votes.
        p_c_hat: fraction of the voters that voted for c_hat[0].
    """
    tally = classifier_g(text, mask_pct, num_voter)
    winners = find_max_labels(tally)
    vote_share = tally[winners[0]] / num_voter
    return winners, vote_share

def certify(text, label, mask_pct, num_voter, num_certify, alpha):
    """
    takes a text and some other params and does some certification stuff
    returns: (d, rate)
        d = number of adversaries that our proofs show this text is robust to, when using maskpure
        rate = approx. percent equivalent of d/len(text), where len(text) is measured in tokens, not characters
            Basically, "what percent of this text can we mask such that it is still classified correctly?"
    """
    top_label, p_c = predict(text, mask_pct=mask_pct, num_voter=num_voter)
    if label not in top_label:
        return None, -float('inf') #the text is not robust to any adversaries since the top guess was not the correct one
    else:
        counts, jx = classifier_g(text, mask_pct, num_certify, output_token_len=True)
#         print('counts: ', counts)
#         print('text token length, jx: ', jx)
        py = eq_6(counts[label], num_certify, alpha)  # Using Eq. (6) from Zeng et al. to find a lower bound for the prob of success
#         print('py: ', py)
        beta = counts[label] / num_certify
#         print('beta: ', beta)
        d = 0
        kx = round(jx*(1 - mask_pct))
#         print('rounded kx val: ', kx)
        while d < jx:
            delta = 1 - (comb(jx - d, kx) / comb(jx, kx))
#             print('delta: ', delta)
#             print('the thing that should be greater than 0.5: ', py - beta * delta)
            if py - beta * delta > 0.5:
                d += 1
            else:
                break
                
        return d, d/jx*100
    
def custom_sort(item):
    """Sort key that places None before every number by mapping it to -inf."""
    return -float('inf') if item is None else item

def median(lst):
    """
    Return the element at index len(lst) // 2 of an already-sorted sequence.

    The input is not sorted here and nothing is averaged. For even lengths this is the upper of
    the two middle elements, and an empty sequence raises IndexError.
    """
    middle_index = len(lst) // 2
    return lst[middle_index]

### Results for 80% masking rate

In [None]:
alpha = 0.05  # significance level of the certificate (confidence 1 - alpha)
num_certify = 1000  # votes used to estimate the certification bound
num_voter = 25  # votes used for the initial prediction
mask_pct = 0.8  # fraction of tokens masked by each voter

# Run the certification procedure over every example.
certificate_list = []  # certified number of perturbed tokens per example (None = misclassified)
certified_rates = []  # the same certificate as a percentage of the example's token length
start_time = time.time()  # Start timing

# Pull the columns out once. Indexing a datasets.Dataset column (loaded_ag_1000["label"]) can
# materialize the whole column, which previously happened twice per iteration.
labels = loaded_ag_1000['label']
texts = loaded_ag_1000['text']
num_examples = len(loaded_ag_1000)

for i, (raw_label, text) in enumerate(zip(labels, texts)):
    label = f'LABEL_{raw_label}'
    d, rate = certify(text, label, mask_pct, num_voter, num_certify, alpha)
    certificate_list.append(d)
    certified_rates.append(rate)

    if i % 50 == 0:
        elapsed_time_minutes = (time.time() - start_time) / 60
        estimated_time_remaining_minutes = (elapsed_time_minutes / (i + 1)) * (num_examples - i - 1)
        print(f'{i} iterations done. Time elapsed: {elapsed_time_minutes:.2f} minutes. '
              f'Estimated time remaining: {estimated_time_remaining_minutes:.2f} minutes.')

0 iterations done. Time elapsed: 1.27 minutes. Estimated time remaining: 1269.98 minutes.


In [5]:
# Misclassified examples have a None certificate (sorted first via custom_sort) and a -inf rate.
sorted_certificates = sorted(certificate_list, key=custom_sort)
sorted_rates = sorted(certified_rates)
# Print the median absolute certificate, then the median certified rate.
for sorted_values in (sorted_certificates, sorted_rates):
    print()
    print("Median:", median(sorted_values))


Median: 2

Median: 5.0


### Raw accuracy: a `None` certificate marks an example the smoothed classifier misclassified, so 100% minus the share of `None` entries gives its clean accuracy:

In [6]:
def calculate_none_percentage(lst):
    """
    Percentage (0-100, rounded to 2 decimals) of entries in `lst` that are None.

    In this notebook a None certificate marks a misclassified example, so 100 minus this value
    is the clean accuracy. An empty list yields 0.0 instead of raising ZeroDivisionError.
    """
    total_entries = len(lst)
    if total_entries == 0:
        return 0.0
    percentage = (lst.count(None) / total_entries) * 100
    return round(percentage, 2)

# Clean accuracy (%) of the smoothed classifier: every non-None certificate is a correct prediction.
raw_accuracy_pct = 100 - calculate_none_percentage(sorted_certificates)
print(raw_accuracy_pct)

92.0
