In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys, os
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from transformers import *


# archattribute
sys.path.append("../../src")
from explainer import Archipelago
from application_utils.text_utils import *
from application_utils.text_utils_torch import BertWrapperTorch

# difference
sys.path.append("../../baselines/difference")
from diff_explainer import DiffExplainer

# integrated gradients and integrated hessians
# used the IG implemented in the IH package
sys.path.append("../../baselines/integrated_hessians")
from path_explain import utils
from embedding_explainer_bert import EmbeddingExplainerTF
from application_utils.text_utils_tf import BertWrapperIH

# scd-soc
sys.path.append("../../baselines/scd_soc")
sys.path.append("../../baselines/scd_soc/hiexpl")
import helper

sys.path.append("../../baselines/shapley_interaction_index")
from si_explainer import SiExplainer

sys.path.append("../../baselines/shapley_taylor_interaction_index")
from sti_explainer import StiExplainer


import logging
# logging.getLogger() with no name returns the root logger; raising its level to
# CRITICAL suppresses every record below CRITICAL that propagates to it.
# NOTE(review): presumably meant to quiet library logging while models load -- confirm.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

In [3]:
# Pickle file holding the checkpointed results: a dict with keys "est" (per-method
# estimates) and "ref" (ground-truth labels). It is read back on start-up to resume.
save_path = 'analysis/results/phrase_corr_archattribute.pickle'
# save_path = 'analysis/results/word_corr_archattribute.pickle'

# Names understood by the dispatch in the evaluation loop: "archattribute",
# "difference", "integrated_gradients", "integrated_hessians", "soc", "scd", "si", "sti".
methods = ["archattribute"] # for analysis code to run smoothly, use one method per experiment run

# Forwarded as `max_order` to shapley_taylor_interaction_index.
sti_max_order = 2
# Forwarded as `baseline_token` to integrated_gradients.
ig_baseline_token = "[PAD]"

# Restrict the process to GPU 0, with devices enumerated in PCI bus order.
# NOTE(review): these must be set before any framework initialises CUDA -- confirm
# that the imports above do not already do so.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Get Model

In [4]:
task = 'sst-2'
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model_path = "../../downloads/pretrained_bert"

# Group the method names by the backend objects they need.
wrapper_methods = {"archattribute", "difference", "si", "sti"}
tf_methods = {"integrated_gradients", "integrated_hessians"}
hiexpl_methods = {"soc", "scd"}

needs_wrapper = any(m in wrapper_methods for m in methods)
needs_tf = any(m in tf_methods for m in methods)
needs_hiexpl = any(m in hiexpl_methods for m in methods)

# Load the PyTorch BERT checkpoint once and share it between the wrapper-based
# explainers and the hiexpl (SOC/SCD) explainers. Loading it separately per group
# would keep two copies of the same weights on the GPU when both are requested.
if needs_wrapper or needs_hiexpl:
    # torch is imported lazily so that TF-only runs do not need it.
    import torch
    device = torch.device("cuda:0")
    torch_model = helper.get_bert(model_path, device)

if needs_wrapper:
    model_wrapper = BertWrapperTorch(torch_model, device)
    class_idx = 1

if needs_tf:
    # TensorFlow copy of the same checkpoint, converted from the PyTorch weights.
    model = TFBertForSequenceClassification.from_pretrained(model_path, from_pt=True)
    model_wrapper_ih = BertWrapperIH(model)
    class_idx = 1

if needs_hiexpl:
    lm_path = "../../downloads/pretrained_hiexpl_lm/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt"
    lm_model = helper.get_lm_model(lm_path, device.index)

    if "soc" in methods:
        soc_algo = helper.get_hiexpl("soc", torch_model, "", tokenizer, device, sample_num=20, lm_model=lm_model)
    if "scd" in methods:
        scd_algo = helper.get_hiexpl("scd", torch_model, "", tokenizer, device, sample_num=20, lm_model=lm_model)


## Get Phrase Ground Truth

In [5]:
# Pickled ground truth, keyed by split name. The evaluation loop reads each record
# of the chosen split as a dict with "sentence", "tokens" and "subtrees", where every
# subtree carries a "span" and a "label".
gt_file = 'processed_data/text_data/subtree_allphrase_nosentencelabel.pickle'
# gt_file = 'processed_data/text_data/subtree_single_token.pickle'

# NOTE: pickle.load can execute arbitrary code; only point gt_file at trusted files.
with open(gt_file, 'rb') as handle:
    phrase_gt_splits = pickle.load(handle)

# Only the "test" split is evaluated.
phrase_gt = phrase_gt_splits["test"]

## Get Attributions for Each Sentence

In [6]:
def transform_spans(spans):
    """Expand inclusive (start, end) spans into tuples of shifted positions.

    Each span becomes the tuple of every index from start to end (inclusive),
    with 1 added to each index.
    NOTE(review): the +1 presumably skips a leading special token -- confirm.

    Args:
        spans: iterable of indexable pairs (start, end), both inclusive.

    Returns:
        A list with one tuple of integer positions per input span.
    """
    return [tuple(np.arange(span[0], span[1] + 1) + 1) for span in spans]

def inv_span(span_output):
    """Undo `transform_spans` for one span.

    Subtracts 1 from every position and keeps only the first and last entries.

    Args:
        span_output: sequence of shifted positions.

    Returns:
        A (start, end) tuple with inclusive bounds.
    """
    unshifted = np.array(span_output) - 1
    first, last = unshifted[0], unshifted[-1]
    return (first, last)

def get_index_map(spans):
    """Group (start, end) spans by a shared index to reuse interaction slices.

    Two spans that share exactly one index are stored under that shared index
    (the pivot), each contributing its other endpoint as a partner. Every span is
    paired at most once. Spans left unpaired are stored under their own start
    index with their end index as partner.

    Only exercised on contiguous token pairs, which is all this evaluation needs.

    Args:
        spans: sequence of hashable (start, end) tuples.

    Returns:
        dict mapping pivot index -> list of partner indices; each (pivot, partner)
        pair corresponds to one input span.

    Raises:
        AssertionError: if two spans share more than one index.
    """
    index_map = {}
    accounted = set()
    for i in range(len(spans)):
        for j in range(i + 1, len(spans)):
            if spans[i] in accounted or spans[j] in accounted:
                continue
            shared = set(spans[i]) & set(spans[j])
            if not shared:
                continue
            assert len(shared) == 1, "spans may share at most one index"
            (pivot,) = shared
            partners = index_map.setdefault(pivot, [])
            # Take each span's endpoint that is NOT the pivot, wherever the pivot
            # sits in the tuple, so the result does not depend on span order.
            for span in (spans[i], spans[j]):
                partners.append(span[1] if span[0] == pivot else span[0])
            accounted.add(spans[i])
            accounted.add(spans[j])

    for span in spans:
        if span not in accounted:
            # Append rather than assign: the start index may already be a pivot,
            # and overwriting would drop the partners recorded above.
            index_map.setdefault(span[0], []).append(span[1])
    return index_map


def archattribute(sentence, spans):
    """Compute ArchAttribute scores for the given spans of one sentence.

    Builds input and baseline ids using "_" as the baseline token, wraps them in a
    TextXformer, and queries Archipelago for `class_idx` with batch size 20.

    Args:
        sentence: the raw sentence passed to `get_input_baseline_ids`.
        spans: inclusive (start, end) spans; shifted through `transform_spans`
            before being handed to the explainer.

    Returns:
        Whatever `Archipelago.archattribute` returns; the caller treats it as a
        mapping keyed by the shifted span tuples.
    """
    input_ids, reference_ids = get_input_baseline_ids(sentence, "_", tokenizer)
    explainer = Archipelago(
        model_wrapper,
        data_xformer=TextXformer(input_ids, reference_ids),
        output_indices=class_idx,
        batch_size=20,
    )
    return explainer.archattribute(transform_spans(spans))

def difference(sentence, spans):
    """Compute difference-baseline attributions for the given spans.

    Uses the same input preparation as the other wrapper-based explainers: "_" as
    the baseline token, a TextXformer over the input and baseline ids, `class_idx`
    as the output index, and batch size 20.

    Args:
        sentence: the raw sentence passed to `get_input_baseline_ids`.
        spans: inclusive (start, end) spans; shifted through `transform_spans`
            before being handed to the explainer.

    Returns:
        Whatever `DiffExplainer.difference_attribution` returns; the caller treats
        it as a mapping keyed by the shifted span tuples.
    """
    input_ids, reference_ids = get_input_baseline_ids(sentence, "_", tokenizer)
    explainer = DiffExplainer(
        model_wrapper,
        data_xformer=TextXformer(input_ids, reference_ids),
        output_indices=class_idx,
        batch_size=20,
    )
    return explainer.difference_attribution(transform_spans(spans))

def integrated_gradients(sentence, spans, baseline_token):
    """Integrated-gradients attribution summed over each span.

    Per-position attributions come from the IG implementation in the Integrated
    Hessians package, then are summed over each inclusive span.

    Args:
        sentence: the raw sentence to explain.
        spans: hashable inclusive (start, end) spans indexing the attribution
            vector after its first and last positions are dropped.
        baseline_token: token forwarded to `get_predictions_extra` to build the
            baseline embedding. This may differ from the "_" baseline used by the
            other explainers.

    Returns:
        dict mapping each span to the sum of its per-position attributions.
    """
    # Only the embeddings and the attention mask are needed here.
    _, _, batch_embedding, baseline_embedding, attention_mask = model_wrapper_ih.get_predictions_extra(
        [sentence], tokenizer, baseline_token=baseline_token
    )
    explainer = EmbeddingExplainerTF(model_wrapper_ih.prediction_model)
    token_atts = explainer.attributions(
        inputs=batch_embedding,
        baseline=baseline_embedding,
        batch_size=20,
        num_samples=50,
        use_expectation=False,
        # Explain the same class as every other method instead of a hard-coded 1.
        output_indices=class_idx,
        verbose=False,
        attention_mask=attention_mask,
    )
    # Drop the first and last positions.
    # NOTE(review): presumably the [CLS] and [SEP] tokens -- confirm.
    token_atts = token_atts.flatten()[1:-1]
    return {
        span: np.sum(token_atts[np.arange(span[0], span[1] + 1)])
        for span in spans
    }

def integrated_hessians(sentence, spans):
    """Integrated-Hessians interaction score for each (start, end) span.

    Uses the default baseline token of `get_predictions_extra`, which may differ
    slightly from the "_" baseline used by the other explainers. Spans are grouped
    by `get_index_map` so that one interaction slice per pivot index serves every
    span sharing that pivot.

    Args:
        sentence: the raw sentence to explain.
        spans: sequence of (start, end) index pairs.

    Returns:
        dict mapping the sorted (i, j) index pair to its interaction value.
    """
    # Only the embeddings and the attention mask are needed here.
    _, _, batch_embedding, baseline_embedding, attention_mask = model_wrapper_ih.get_predictions_extra(
        [sentence], tokenizer
    )
    explainer = EmbeddingExplainerTF(model_wrapper_ih.prediction_model)
    index_map = get_index_map(spans)

    ih_atts = {}
    for i in index_map:
        # Indices are shifted by one when addressing the model input.
        # NOTE(review): presumably to skip a leading [CLS] token -- confirm.
        ih_atts_slice = explainer.interactions(
            inputs=batch_embedding,
            baseline=baseline_embedding,
            batch_size=20,
            num_samples=50,
            use_expectation=False,
            # Explain the same class as every other method instead of a hard-coded 1.
            output_indices=class_idx,
            verbose=False,
            attention_mask=attention_mask,
            interaction_index=i + 1,
        )
        for j in index_map[i]:
            ih_span = tuple(sorted((i, j)))
            ih_atts[ih_span] = ih_atts_slice[0, j + 1]
    return ih_atts

def scd_soc(sentence, spans, algo):
    """Run an SOC/SCD explainer and re-key its contributions to unshifted spans.

    Args:
        sentence: the raw sentence to explain.
        spans: spans forwarded unchanged to `helper.explain_sentence`.
        algo: explainer object built by `helper.get_hiexpl`.

    Returns:
        dict mapping each returned span, with 1 subtracted from every index, to
        its contribution.
    """
    # The second return value of explain_sentence is not used here.
    contribs, _ = helper.explain_sentence(sentence, algo, tokenizer, spans = spans)
    return {tuple(np.array(sp) - 1): contribs[sp] for sp in contribs}

def shapley_interaction_index(sentence, spans, seed =None,  num_T=20):
    """Estimate the Shapley Interaction Index for each span.

    Args:
        sentence: the raw sentence to explain.
        spans: hashable inclusive (start, end) spans.
        seed: forwarded to `SiExplainer`.
        num_T: forwarded as the second argument of `SiExplainer.attribution`.
            NOTE(review): presumably the number of sampled context sets -- confirm.

    Returns:
        dict mapping each span to its attribution.
    """
    input_ids, reference_ids = get_input_baseline_ids(sentence, "_", tokenizer)
    explainer = SiExplainer(
        model_wrapper,
        data_xformer=TextXformer(input_ids, reference_ids),
        output_indices=class_idx,
        batch_size=20,
        seed=seed,
    )
    # Each span is passed as its list of positions, shifted by one.
    return {
        span: explainer.attribution(list(range(span[0] + 1, span[1] + 2)), num_T)
        for span in spans
    }

def shapley_taylor_interaction_index(sentence, spans, max_order=2, seed=None, num_orderings=20):
    """Estimate the Shapley-Taylor Interaction Index for each span.

    For every sampled ordering of the input positions, a span with exactly
    `max_order` positions is attributed against the positions that come before
    its earliest member in that ordering. Every other span is attributed against
    an empty context. The per-ordering attributions are then averaged.

    Args:
        sentence: the raw sentence to explain.
        spans: hashable inclusive (start, end) spans.
        max_order: span size that receives the ordering-dependent context.
        seed: if not None, seeds NumPy's global RNG before sampling.
        num_orderings: number of random orderings to average over.

    Returns:
        dict mapping each span to its averaged attribution.
    """
    if seed is not None:
        np.random.seed(seed)

    input_ids, reference_ids = get_input_baseline_ids(sentence, "_", tokenizer)
    explainer = StiExplainer(
        model_wrapper,
        data_xformer=TextXformer(input_ids, reference_ids),
        output_indices=class_idx,
        batch_size=20,
    )

    totals = {}
    for _ in range(num_orderings):
        ordering = np.random.permutation(list(range(len(input_ids))))
        # Rank of every position within this ordering.
        rank_of = {position: rank for rank, position in enumerate(ordering)}

        for span in spans:
            members = list(range(span[0] + 1, span[1] + 2))
            if len(members) == max_order:
                # Context is everything ordered before the span's earliest member.
                cutoff = min(rank_of[m] for m in members)
                context = ordering[:cutoff]
            else:
                context = []
            totals[span] = totals.get(span, 0) + explainer.attribution(members, context)

    return {span: total / num_orderings for span, total in totals.items()}



In [7]:
# Reference labels are always rebuilt from scratch by the evaluation loop; only
# the estimates are resumed from an earlier checkpoint when one exists.
ref = {}
est_methods = {}
if os.path.exists(save_path):
    with open(save_path, 'rb') as handle:
        p_dict = pickle.load(handle)
    # Keep the stored references around for comparison against the rebuilt ones.
    ref_sanity = p_dict["ref"]
    est_methods = p_dict["est"]


In [8]:
# Number of sentences processed between two checkpoint writes.
checkpoint_every = 3


def _save_checkpoint(path, est, reference):
    """Atomically pickle the estimates and reference labels to `path`.

    The data is written to a sibling temporary file and then moved over `path`,
    so an interrupt during the write cannot corrupt the existing checkpoint.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'wb') as handle:
        pickle.dump({"est": est, "ref": reference}, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


for k in methods:
    est_methods.setdefault(k, {})

for s_idx, phrase_dict in enumerate(tqdm(phrase_gt)):

    sentence = phrase_dict["sentence"]
    tokens = phrase_dict["tokens"]
    subtrees = phrase_dict["subtrees"]
    att_len = len(tokens)

    # Keying by span de-duplicates subtrees; a later label wins for a repeated span.
    span_to_label = {}
    for subtree in subtrees:
        span_to_label[subtree["span"]] = subtree["label"]

    spans = list(span_to_label.keys())

    # Reference labels are rebuilt on every run, even for sentences skipped below.
    if s_idx not in ref:
        ref[s_idx] = [span_to_label[sp] for sp in spans]

    # Resume support: skip sentences already estimated by every requested method.
    if all((s_idx in est_methods[m]) for m in methods):
        print("skip", s_idx)
        continue

    results = {}
    for method in methods:
        if method == "archattribute":
            prop_atts = archattribute(sentence, spans)
            # Explainer keys are shifted position tuples; map them back to spans.
            results[method] = {inv_span(s): prop_atts[s] for s in prop_atts}

        elif method == "difference":
            diff_atts = difference(sentence, spans)
            results[method] = {inv_span(s): diff_atts[s] for s in diff_atts}

        elif method == "integrated_gradients":
            results[method] = integrated_gradients(sentence, spans, ig_baseline_token)

        elif method == "integrated_hessians":
            results[method] = integrated_hessians(sentence, spans)

        elif method == "soc":
            results[method] = scd_soc(sentence, spans, soc_algo)

        elif method == "scd":
            results[method] = scd_soc(sentence, spans, scd_algo)

        elif method == "si":
            # Seeding with the sentence index makes each sentence reproducible.
            results[method] = shapley_interaction_index(sentence, spans, seed = s_idx)

        elif method == "sti":
            results[method] = shapley_taylor_interaction_index(sentence, spans, max_order = sti_max_order, seed = s_idx)

        else:
            raise ValueError("unknown attribution method: {!r}".format(method))

    # Store estimates in the same span order as the reference labels.
    for k in results:
        est_methods[k][s_idx] = [results[k][span] for span in spans]

    if (s_idx+1) % checkpoint_every == 0:
        _save_checkpoint(save_path, est_methods, ref)

_save_checkpoint(save_path, est_methods, ref)

100%|██████████| 2210/2210 [06:05<00:00,  6.05it/s]
