## Polyjuice: Generating Counterfactuals using Perturbations

<a target="_blank" href="https://colab.research.google.com/github/daniel-furman/Capstone/blob/main/notebooks/counterfact_generation_notebooks/polyjuice-exp.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
!git clone https://github.com/daniel-furman/C,sapstone.git
!pip install -e /content/Capstone/src/counterfact_generation_scripts

# General setup and perturbation

In [None]:
# torch must be imported before the CUDA check below; without this import the
# cell fails with a NameError instead of the intended message.
import torch

# Fail fast when no GPU is attached to the runtime.
if not torch.cuda.is_available():
    raise Exception("Change runtime type to include a GPU.")

In [None]:
# Create the Polyjuice wrapper.
# The model path defaults to the portable model published at
# https://huggingface.co/uw-hai/polyjuice
# There is no need to change it unless you are using a customized model.
# is_cuda is taken from torch.cuda.is_available(), so `torch` must already be imported.
from polyjuice import Polyjuice
pj = Polyjuice(model_path="uw-hai/polyjuice", is_cuda=torch.cuda.is_available())

In [None]:
# The base sentence to perturb.
text = "Tom Brady plays football"

# Perturb the sentence with one line.
# When running it for the first time, the wrapper will automatically
# load related models, e.g. the generator and the perplexity filter.
perturbations = pj.perturb(text, num_beams=50) # does not work without num_beams
perturbations # displays the perturbations; results get more random as num_beams decreases

In [None]:
# Perturb with more controls.

perturbations = pj.perturb(
    orig_sent=text,
    # Where to put the blank can be specified; otherwise it is selected automatically.
    # Can be a list or a single sentence.
    blanked_sent=["Tom Brady plays [BLANK].", "[BLANK] plays football."],
    # The control code can also be specified (a list or a single code).
    # The code should be one of 'resemantic', 'restructure', 'negation', 'insert', 'lexical', 'shuffle', 'quantifier', 'delete'.
    ctrl_code="lexical",
    # Customize the perplexity threshold.
    perplex_thred=20,
    # Number of perturbations to return.
    num_perturbations=3,
    # The function also takes additional arguments for Hugging Face generators.
    num_beams=50
)
perturbations

In [None]:
# Perturb a second base sentence with more controls.
text = "Arab follows the religion of Muslim"
perturbations = pj.perturb(
    orig_sent=text,
    # Where to put the blank can be specified; otherwise it is selected automatically.
    # Can be a list or a single sentence.
    blanked_sent=["Arab follows the religion of [BLANK].", "[BLANK] follows the religion of Muslim."],
    # The control code can also be specified (a list or a single code).
    # The code should be one of 'resemantic', 'restructure', 'negation', 'insert', 'lexical', 'shuffle', 'quantifier', 'delete'.
    ctrl_code="lexical", # lexical for our use-case to generate more examples; negation can also be used
    # Customize the perplexity threshold.
    perplex_thred=20,
    # Number of perturbations to return.
    num_perturbations=3, 
    # The function also takes additional arguments for Hugging Face generators.
    num_beams=50 # results get more random as the number of beams decreases
)
perturbations

# Select for diversity

With each perturbation represented by its token changes, control code, and dependency tree structure, we greedily select the ones that are least similar to those already selected. This tries to avoid redundancy in common perturbations such as black -> white.


In [None]:
# Over-generate some examples; useful for our use-case to simply get more
# examples in the neighbourhood of the original sentence.

orig_text = "Arab is follower of Islam"
# num_perturbations=None is passed here instead of a fixed count.
perturb_texts = pj.perturb(
    orig_sent=orig_text, perplex_thred=10, num_perturbations=None, num_beams=10)
# Pair every perturbed sentence with the original it was derived from.
orig_and_perturb_pairs = [(orig_text, perturb_text) for perturb_text in perturb_texts]
orig_and_perturb_pairs

In [None]:
# Ask Polyjuice's diversity selection for 3 samples from the
# (original, perturbed) pairs generated above.
sampled = pj.select_diverse_perturbations(
    orig_and_perturb_pairs=orig_and_perturb_pairs, nsamples=3)
sampled

# Select surprising perturbations as counterfactual explanations - irrelevant/gives errors

Because different models/explainers may have different forms of predictions/feature weight computation methods, Polyjuice selection expects all predictions and feature weights to be precomputed. Here, we give an example of Quora Question Pair Detection. 


In [None]:
# Set a perturbation base: a pair of questions for the
# Quora Question Pair Detection example.
orig = (
    "How can I help a friend experiencing serious depression?",
    "How do I help a friend who is in depression?"
)
# Label of the original pair — presumably 1 means "duplicate"; verify against the dataset.
orig_label = 1

# We perturb the second question.

In [None]:
# Load the question-pair classifier and wrap it in a Hugging Face pipeline.
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "textattack/bert-base-uncased-QQP"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# "sentiment-analysis" is the general name in Hugging Face for loading the
# pipeline for text classification tasks.
# Run on GPU 0 when CUDA is available and on CPU (device=-1) otherwise.
# The original expression referenced `is_cuda`, which is never defined in this
# notebook, so the availability check is made directly through torch.
pipe = pipeline(
    "sentiment-analysis", model=model, tokenizer=tokenizer,
    framework="pt", device=0 if torch.cuda.is_available() else -1, return_all_scores=True)

In [None]:
# some wrapper for prediction
import numpy as np
def extract_predict_label(raw_pred):
    """Return the label of the highest-scoring entry in ``raw_pred``.

    ``raw_pred`` is an iterable of dicts carrying ``"score"`` and ``"label"``
    keys. When several entries tie for the top score, the earliest one wins.
    Returns ``None`` when there are no entries.
    """
    candidates = list(raw_pred)
    if not candidates:
        return None
    # Minimising the negated score picks the first entry with the top score.
    top_entry = min(candidates, key=lambda entry: -entry["score"])
    return top_entry["label"]
def predict(examples, predictor, batch_size=128):
    """Run ``predictor`` over ``examples`` in batches and normalise the labels.

    Args:
        examples: Sliceable sequence of inputs; it is fed to ``predictor`` in
            consecutive slices of at most ``batch_size`` items.
        predictor: Callable that takes a batch and returns an iterable with
            one prediction per input. A prediction is either a dict or a list
            of dicts, each with a ``"label"`` string of the form
            ``"<prefix>_<int>"``.
        batch_size: Maximum number of examples passed to ``predictor`` at once.

    Returns:
        The list of predictions as returned by ``predictor``, with every
        ``"label"`` replaced in place by the integer after the first
        underscore.
    """
    raw_preds = []
    # Inference only: gradients are not tracked while the predictor runs.
    with torch.no_grad():
        for start in range(0, len(examples), batch_size):
            raw_preds.extend(predictor(examples[start:start + batch_size]))
    for raw_pred in raw_preds:
        # A prediction may be a single dict or a list of per-label dicts.
        entries = raw_pred if isinstance(raw_pred, list) else [raw_pred]
        for entry in entries:
            entry["label"] = int(entry["label"].split("_")[1])
    return raw_preds

# Score the original question pair and show the raw prediction next to the
# label of its highest-scoring entry.
p = predict([orig], predictor=pipe)[0]
(p, extract_predict_label(p))

In [None]:
## Collect some base perturbations.
from polyjuice.generations import ALL_CTRL_CODES

# Perturb the second question in orig, using ALL_CTRL_CODES as the control codes.
# NOTE(review): an earlier cell comments that pj.perturb does not work without
# num_beams, and none is passed here — confirm whether it should be added.
perturb_idx = 1
perturb_texts = pj.perturb(
    orig[perturb_idx], 
    ctrl_code=ALL_CTRL_CODES, 
    num_perturbations=None, perplex_thred=10)

perturb_texts


In [None]:
# To estimate feature importance, we set up shap explainer
# install shap
!pip install shap

In [None]:
import shap
import functools
from copy import deepcopy
# setup a prediction function for computing the shap feature importance

def wrap_perturbed_instances(perturb_texts, orig, perturb_idx=1):
    """Build one full instance per perturbed text.

    Each instance is a tuple copied from ``orig`` in which the element at
    position ``perturb_idx`` is replaced by the perturbed text. ``orig``
    itself is left untouched.
    """
    def _swap_in(replacement):
        # Work on a deep copy so the original example is never mutated.
        fields = deepcopy(list(orig))
        fields[perturb_idx] = replacement
        return tuple(fields)

    return [_swap_in(text) for text in perturb_texts]

def predict_on_perturbs(perturb_texts, orig, predictor, perturb_idx=1):
    """Score perturbed instances against the label predicted for ``orig``.

    Every text in ``perturb_texts`` is swapped into ``orig`` at position
    ``perturb_idx`` and run through ``predictor``. The result is a list with,
    for each perturbed instance, the score assigned to the label that the
    predictor picks for the unperturbed ``orig``.
    """
    candidates = wrap_perturbed_instances(perturb_texts, orig, perturb_idx)
    label_to_score = []
    for scored_labels in predict(candidates, predictor=predictor):
        label_to_score.append(
            {entry["label"]: entry["score"] for entry in scored_labels})
    # Label the predictor assigns to the original, unperturbed instance.
    reference_label = extract_predict_label(
        predict([orig], predictor=predictor)[0])
    return [scores[reference_label] for scores in label_to_score]
def normalize_shap_importance(features, importances, is_use_abs=True):
    """Merge sub-token importances into per-word importances.

    Tokens are cleaned by stripping 'Ġ' characters from both ends, removing
    '#' characters and trimming whitespace. A token that starts with "##" is
    treated as a continuation and is merged into the preceding token: their
    texts are concatenated and their importances summed.

    Args:
        features: Sequence of token strings.
        importances: Importance values aligned with ``features``.
        is_use_abs: When true, store the absolute value of each summed
            importance; otherwise store the signed sum.

    Returns:
        Dict mapping each merged, non-empty word to its importance. If the
        same word occurs more than once, the later occurrence overwrites the
        earlier one.
    """
    normalized_features = {}
    # Start with an empty accumulator so that a leading "##" continuation
    # token (with no word start before it) does not hit unbound locals.
    key, val = "", 0
    last_idx = len(features) - 1
    for idx, (f, v) in enumerate(zip(features, importances)):
        f = f.strip('Ġ')
        if not f.startswith("##"):
            # A new word starts here: reset the accumulator.
            key, val = "", 0
        key += f.replace("#", "").strip()
        val += v
        # The word is complete when no "##" continuation follows.
        is_word_end = idx == last_idx or not features[idx + 1].startswith("##")
        if is_word_end and key != "":
            normalized_features[key] = abs(val) if is_use_abs else val
    return normalized_features
def explain_with_shap(orig, predictor=pipe, tokenzier=pipe.tokenizer, perturb_idx=1):
    """Compute per-word SHAP importances for one element of ``orig``.

    Args:
        orig: The original instance; ``orig[perturb_idx]`` is the text that
            gets explained.
        predictor: Prediction callable forwarded to ``predict_on_perturbs``.
        tokenzier: Tokenizer handed to ``shap.Explainer`` as its second
            argument. The misspelled parameter name is kept so existing
            keyword callers keep working.
        perturb_idx: Position in ``orig`` of the text to explain.

    Returns:
        The dict produced by ``normalize_shap_importance`` for the explained
        text.
    """
    predict_for_shap_func = functools.partial(
        predict_on_perturbs, orig=orig, predictor=predictor, perturb_idx=perturb_idx)
    # Use the tokenizer that was passed in. Reading the global `tokenizer`
    # here would silently ignore the argument.
    shap_explainer = shap.Explainer(predict_for_shap_func, tokenzier)
    exp = shap_explainer([str(orig[perturb_idx])])
    return normalize_shap_importance(exp.data[0], exp.values[0])

# Compute the SHAP-based feature importances for the original pair
# (with the default perturb_idx=1) and display them.
feature_importance_dict = explain_with_shap(orig)
feature_importance_dict

In [None]:
# Get the predictions for the original instance and for the new instances.
orig_pred = predict([orig], predictor=pipe)[0]

# Rebuild full question pairs by swapping each perturbed text into position
# perturb_idx of the original pair, then score them.
perturb_instances = wrap_perturbed_instances(perturb_texts, orig, perturb_idx)
perturb_preds = predict(perturb_instances, predictor=pipe)

# Polyjuice selection expects predictions and feature weights to be
# precomputed, so they are passed in alongside the original and perturbed texts.
surprises = pj.select_surprise_explanations(
    orig_text=orig[perturb_idx], 
    perturb_texts=perturb_texts, 
    orig_pred=orig_pred, 
    perturb_preds=perturb_preds, 
    feature_importance_dict=feature_importance_dict)
surprises