In [1]:
import torch
# NOTE: the star import is kept ahead of the explicit imports below so that
# any name it exports cannot shadow them.
from transformers import *
import scipy
import numpy as np
import sys, os
import pickle
from tqdm import tqdm

sys.path.append("../../src")
from context_explainer import ContextExplainer
from application_utils.text_utils import *
from application_utils.text_utils_torch import BertWrapperTorch

%load_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration.
# When True, the experiment loop feeds the explainer only randomly drawn
# contexts; when False, the explainer's own input/baseline contexts are
# prepended to the random ones.
random_context_only = True

# Pickle file used both as the resume checkpoint and as the final output.
save_path = "results/bert_random_context_only.pickle"

## Get Model

In [3]:
# Static settings for this cell.
task = 'sst-2'
model_path = "../../downloads/pretrained_bert"

# Tokenizer comes from the stock 'bert-base-uncased' vocabulary, while the
# classifier weights are read from the local checkpoint directory above.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(model_path)

In [4]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still run (slowly) on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Index of the model output to explain; it is passed to ContextExplainer as
# `output_indices` in the experiment loop.
# NOTE(review): presumably the positive-sentiment class for SST-2 — confirm.
class_idx = 1

model_wrapper = BertWrapperTorch(model, device)

## Get Sentences

In [5]:
# Token used to build the baseline ids for each sentence in the experiment
# loop (passed to get_input_baseline_ids).
baseline_token = "_"

# Test-split sentences read from the pickled SST data.
sentences = get_sst_sentences(
    split="test",
    path="../../downloads/sst_data/sst_trees.pickle",
)

In [6]:
# Resume support: start from an empty result list unless a checkpoint from a
# previous run already exists at `save_path`.
# NOTE(review): pickle.load can execute arbitrary code; only point `save_path`
# at files produced by this notebook.
if not os.path.exists(save_path):
    all_res = []
else:
    with open(save_path, 'rb') as checkpoint_file:
        all_res = pickle.load(checkpoint_file)

## Run Experiment

In [7]:
def _sample_unique_contexts(n_features, n_samples, seen):
    """Draw distinct random boolean masks of length `n_features`.

    Masks are drawn with the global NumPy RNG and rejected if they were
    already drawn or appear in `seen`.

    Args:
        n_features: length of each mask.
        n_samples: number of masks requested.
        seen: iterable of tuples (one per mask of length `n_features`) that
            must not be returned.

    Returns:
        A list of boolean arrays. It holds `n_samples` masks when that many
        unseen masks exist, otherwise every remaining unseen mask.
    """
    seen = set(seen)
    # Only 2**n_features distinct masks exist. Without this cap the rejection
    # loop below never terminates for very short inputs (e.g. 3 features give
    # 8 masks, fewer than the 11 that may be requested).
    n_available = 2 ** n_features - len(seen)
    n_samples = max(0, min(n_samples, n_available))

    contexts = []
    while len(contexts) < n_samples:
        context = np.random.randint(0, high=2, size=n_features).astype(bool)
        context_tuple = tuple(context)
        if context_tuple in seen:
            continue
        seen.add(context_tuple)
        contexts.append(context)
    return contexts


def _save_results(results, path):
    """Pickle `results` to `path` without ever leaving a half-written file.

    The data is written to a temporary sibling file and then moved over
    `path`, so an interrupt during the write cannot corrupt the previous
    checkpoint.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


# NOTE(review): the RNG is seeded once per run and skipped sentences consume
# no draws, so a resumed run samples different contexts for the remaining
# sentences than an uninterrupted run would.
np.random.seed(42)

# A checkpoint is written after every `save_every` processed sentences.
save_every = 3

for s_idx, text in enumerate(tqdm(sentences)):

    if s_idx < len(all_res):
        # Already computed in a previous run: check the checkpoint lines up
        # with the current sentence list, then skip.
        assert all_res[s_idx]["text"] == text, "checkpoint does not match sentences"
        print("skip", s_idx)
        continue

    text_ids, baseline_ids = get_input_baseline_ids(text, baseline_token, tokenizer)

    xf = TextXformer(text_ids, baseline_ids)
    ctx = ContextExplainer(model_wrapper, data_xformer=xf, output_indices=class_idx, batch_size=20, verbose=False)

    context1 = ctx.input
    context2 = ctx.baseline

    # Either way 11 contexts are requested in total: 11 random ones, or the
    # explainer's two contexts plus 9 random ones distinct from them.
    if random_context_only:
        n_samples = 11
        excluded = []
    else:
        n_samples = 9
        excluded = [tuple(context1), tuple(context2)]

    new_contexts = _sample_unique_contexts(len(context1), n_samples, excluded)

    if random_context_only:
        all_contexts = new_contexts
    else:
        all_contexts = [context1, context2] + new_contexts

    res = ctx.detect_with_running_contexts(all_contexts)

    all_res.append({"text": text, "result": res})

    if (s_idx + 1) % save_every == 0:
        _save_results(all_res, save_path)

# Final save so results past the last periodic checkpoint are kept.
_save_results(all_res, save_path)