In [1]:
from pruners.pruning_methods import L1Unstructured, RandUnstructured
import torch
from evaluations.evaluator import Evaluator
from tqdm import tqdm
from transformers import AutoTokenizer, BertForSequenceClassification
from explainers.explanation_methods import SHAP, LIME
import torch
from copy import deepcopy

# Fallback compute device. `eval_suite` reads this global `device` when it is
# called (not when it is defined), so whatever it is bound to at call time wins.
device = torch.device(type='cpu')

In [2]:
# returns a list of model params given a bert `model`
def get_params_to_prune(model):
    params_to_prune = []
    for layer in model.bert.encoder.layer:
        # Attention weights (query, key, value, and output projection)
        params_to_prune.append((layer.attention.self.query, 'weight'))
        params_to_prune.append((layer.attention.self.key, 'weight'))
        params_to_prune.append((layer.attention.self.value, 'weight'))
        params_to_prune.append((layer.attention.output.dense, 'weight'))
        
        # Intermediate dense layer
        params_to_prune.append((layer.intermediate.dense, 'weight'))
        
        # Output dense layer
        params_to_prune.append((layer.output.dense, 'weight'))
        
    return params_to_prune


def eval_suite(model, tokenizer, inputs, prune_ptg):
    # TODO: Refactor
    
    model.to(device)
    # Init models and grab model params to prune
    randunstructured_model = deepcopy(model).to(device)
    l1unstructured_model = deepcopy(model).to(device)

    randunstruct_params = get_params_to_prune(randunstructured_model)
    l1unstruct_params = get_params_to_prune(l1unstructured_model)

    # Initialize pruners and make pruned models
    print('Pruning models...')
    unpruned_model = model
    
    randunstructured_pruner = RandUnstructured()
    l1unstructured_pruner = L1Unstructured()
    randunstructured_pruner.prune(randunstruct_params, prune_ptg)
    l1unstructured_pruner.prune(l1unstruct_params, prune_ptg)

    # Init explainers
    shap_randunstruct, lime_randunstruct = SHAP(randunstructured_model, tokenizer, device),\
                                           LIME(randunstructured_model, tokenizer, device)
    shap_l1unstruct, lime_l1unstruct = SHAP(l1unstructured_model, tokenizer, device),\
                                       LIME(l1unstructured_model, tokenizer, device)
    shap_unpruned, lime_unpruned = SHAP(unpruned_model, tokenizer, device),\
                                   LIME(unpruned_model, tokenizer, device)
    
    # Init evaluators
    randunstruct_evaluators = {'shap': Evaluator(shap_randunstruct)}
                               #'lime': Evaluator(lime_randunstruct)}
    l1unstruct_evaluators = {'shap': Evaluator(shap_l1unstruct)}
                            # 'lime': Evaluator(lime_l1unstruct)}
    unrpuned_evaluators = {'shap': Evaluator(shap_unpruned)}
                           #'lime': Evaluator(lime_unpruned)}

    # Janky for now but hang with me
    infidelities_ = {'unpruned': unrpuned_evaluators,
                    'l1unstruct': l1unstruct_evaluators,
                    'randunstruct': randunstruct_evaluators}
    
    # Run evaluations, storing in dictionary of 
    # {prune_method: 
    #   {explanation_method: infidelity}
    # }
    print('Evaluating explanations...')
    infidelities = {}
    for input in tqdm(inputs, desc='Evaluating', unit='input'):
        for prune_method, evaluator_set in infidelities_.items():
            # Initialize prune_method in infidelities if not present
            if prune_method not in infidelities:
                infidelities[prune_method] = {}  # Initialize prune_method dict
            
            for expla_method, evaluator in evaluator_set.items():
                # Initialize expla_method as a list if not present
                if expla_method not in infidelities[prune_method]:
                    infidelities[prune_method][expla_method] = []
                
                # Append the result of get_local_infidelity to the list
                infidelities[prune_method][expla_method].append(evaluator.get_local_infidelity(input))

    return infidelities

In [None]:
device = torch.device('mps')
tokenizer = AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-yelp-polarity")
model = BertForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-yelp-polarity")
model = model.to(device)

inputs = ['Camilo CANNOT CODE FOR HIS LIFE. I DONT LIKE HIM!!!!',
          'David is GREAT at soccer. Can recommend <thumbs up>!',
          'Joey is joey. I feel very neutrally about him',
          'Finale is the best professor Harvard has EVER had. Would recommend!',
          'I AM GOING TO SCREAMMMMMMMMMMM AHHHHHHHHHHHH',
          'Paula and Hiwot are great TFs!']

infidelities = eval_suite(model, tokenizer, inputs, .20)

In [None]:
# Nested results from eval_suite: {prune_method: {explanation_method: [infidelity per input, in input order]}}
infidelities