## Clone the repo, fix a bug and install the polyjuice package
### Installing from a local clone is necessary because a bug has to be fixed in the source first, as described in [this issue](https://github.com/tongshuangwu/polyjuice/issues/12).

In [None]:
!git clone https://github.com/tongshuangwu/polyjuice.git

In [None]:
%cd polyjuice

In [None]:
!pip install -e /content/polyjuice/

## Set up polyjuice with default GPT-2 model

In [None]:
# Add the polyjuice package directory under /content/polyjuice to the import path.
import sys
sys.path.append('/content/polyjuice/polyjuice')

In [None]:
from polyjuice import Polyjuice
# Polyjuice generator used later to produce the perturbed texts.
# NOTE(review): is_cuda is hard-coded to True here, while the classification
# pipeline further down is driven by a separate `is_cuda = False` flag — confirm
# that the two are meant to differ.
pj = Polyjuice(model_path="uw-hai/polyjuice", is_cuda=True)

## Our Adaptation

### Install IG package

In [None]:
!pip install transformers-interpret

### Connect to Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive under /content/gdrive; the data and model paths used below live there.
drive.mount('/content/gdrive', force_remount=True)

### Import necessary modules and packages
### Set up variables

In [None]:
import locale


def _utf8_preferred_encoding(do_setlocale=True):
    """Return "UTF-8" regardless of the actual locale settings.

    Accepts the same optional ``do_setlocale`` argument as the stdlib
    ``locale.getpreferredencoding`` so that callers passing it (for example
    ``locale.getpreferredencoding(False)``) keep working after the override.
    """
    return "UTF-8"


# Replace the stdlib lookup for the rest of the session.
# NOTE(review): presumably a workaround for shell commands complaining about a
# non-UTF-8 locale in Colab — confirm it is still needed.
locale.getpreferredencoding = _utf8_preferred_encoding

In [None]:
%load_ext autoreload
%autoreload 2
# Device switch for the classification pipeline created below:
# False selects device=-1, True selects device=0.
is_cuda = False

In [None]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
import gc

from polyjuice.generations import ALL_CTRL_CODES

import functools
from copy import deepcopy

from transformers_interpret import SequenceClassificationExplainer
import torch
from transformers import pipeline

In [None]:
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

In [None]:
# Experiment tag; it prefixes the name of the results file.
experiment = "dist_sent_"
# CSV holding the input sentences.
input_data = "/content/gdrive/MyDrive/thesis_data/models/nature_sentences_data.csv"
# CSV the generated counterfactuals are written to.
output_path = f"/content/gdrive/MyDrive/thesis_data/models/{experiment}polyjuice.csv"

### Load our fine-tuned classification model

In [None]:
# Location the classifier is loaded from, and the hub name the tokenizer is loaded from.
model_name = "/content/gdrive/MyDrive/thesis_data/models/dist_sent_all/new_distilbert/model_sentences"
tokenizer_model = 'distilbert-base-uncased-finetuned-sst-2-english'

# Alternative configuration for the BERT variant (kept commented out):
# model_name = "/content/gdrive/MyDrive/thesis_data/models/bert_sent_all/new_bert/model_sentences"
# tokenizer_model = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

#### Use the following code for a DistilBERT model

In [None]:
# get a model
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Load the tokenizer from the hub and the classifier from `model_name`.
tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_model)
model = DistilBertForSequenceClassification.from_pretrained(model_name)
# Write the tokenizer files into the same location as the model.
tokenizer.save_pretrained(model_name)
# Truncation limit handed to the pipeline. It is taken from the model
# configuration because it must not exceed the number of position embeddings;
# the previously hard-coded 1024 is larger than the 512 positions of a stock
# DistilBERT, so over-long inputs would survive truncation and then fail
# inside the model.
MAX_LENGTH = model.config.max_position_embeddings

#### Uncomment the following code to load a BERT model

In [None]:
# get a model
# from transformers import BertTokenizer, BertForSequenceClassification

# tokenizer = BertTokenizer.from_pretrained(tokenizer_model)
# model = BertForSequenceClassification.from_pretrained(model_name)
# tokenizer.save_pretrained(model_name)
# MAX_LENGTH = 512

In [None]:
# Text-classification pipeline around the loaded model and tokenizer.
# Inputs are padded and truncated to MAX_LENGTH; `is_cuda` picks the device
# (0 when True, -1 otherwise); scores for all classes are returned.
pipe = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=0 if is_cuda else -1,
    truncation=True,
    padding=True,
    max_length=MAX_LENGTH,
    return_all_scores=True,
)

### Set up input data

In [None]:
# Read the dataset: texts come from the TEXT column, labels from CATEGORY.
df = pd.read_csv(input_data)

X_all = df['TEXT'].values
y_all = df['CATEGORY'].values

# Stratified 70/30 split with a fixed seed, so the test part is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.3, random_state=42, stratify=y_all
)

# Keep only the test texts whose label is 0; these are the instances that
# counterfactuals are generated for.
scient_test = X_test[y_test == 0]
scient_list = scient_test.tolist()

# Drop everything that is no longer needed to reduce memory pressure.
del X_all, y_all, X_train, y_train, df, X_test, y_test, scient_test

# Collect garbage first so that freed tensors can actually be released from the CUDA cache.
gc.collect()
torch.cuda.empty_cache()
print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
# memory_reserved replaces the deprecated torch.cuda.memory_cached.
print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
import os

# A plain Python assignment has no effect on CUDA; the setting is read from the
# process environment, so export it there as well. The module-level name is
# kept for backward compatibility.
# NOTE(review): the variable presumably has to be set before CUDA is first
# initialised to take effect — confirm this cell runs early enough.
CUDA_LAUNCH_BLOCKING = 1
os.environ["CUDA_LAUNCH_BLOCKING"] = str(CUDA_LAUNCH_BLOCKING)

### Declare useful functions (as used in the [Polyjuice demo notebook](https://github.com/tongshuangwu/polyjuice/blob/main/notebooks/Polyjuice%20demo.ipynb))

In [None]:
# some wrapper for prediction
def extract_predict_label(raw_pred):
    """Return the label of the highest-scoring entry of `raw_pred`.

    Args:
        raw_pred: iterable of dicts, each with a "score" and a "label" key.

    Returns:
        The "label" of the entry with the largest "score" (the earliest one
        when several share it), or None when `raw_pred` is empty.
    """
    top_entry = max(raw_pred, key=lambda entry: entry["score"], default=None)
    return None if top_entry is None else top_entry["label"]

def predict(examples, predictor, batch_size=128):
    """Run `predictor` over `examples` in batches and normalise the labels.

    Args:
        examples: a list of texts, or a single text. A single string is
            treated as one example; without this it would be sliced into
            `batch_size`-character chunks as if it were a sequence of examples.
        predictor: callable that takes a list of texts and returns one
            prediction per text, each a dict or a list of dicts with a
            "label" key.
        batch_size: number of examples passed to `predictor` per call.

    Returns:
        The list of raw predictions, one per example, with every "label"
        rewritten in place to '0' when it was 'NEGATIVE' and to '1' otherwise.
    """
    if isinstance(examples, str):
        examples = [examples]

    raw_preds = []
    with torch.no_grad():
        for start in range(0, len(examples), batch_size):
            raw_preds.extend(predictor(examples[start:start + batch_size]))

    for raw_pred in raw_preds:
        entries = raw_pred if isinstance(raw_pred, list) else [raw_pred]
        for entry in entries:
            entry["label"] = '0' if entry["label"] == 'NEGATIVE' else '1'
    return raw_preds

def wrap_perturbed_instances(perturb_texts, orig, perturb_idx):
    """Build one copy of `orig` per perturbed text, with one field replaced.

    Args:
        perturb_texts: iterable of replacement values.
        orig: sequence describing the original example.
        perturb_idx: position in `orig` that receives the replacement.

    Returns:
        A list of tuples; each is a deep copy of `orig` whose element at
        `perturb_idx` is the corresponding item of `perturb_texts`.
    """
    def _with_replacement(replacement):
        fields = deepcopy(list(orig))
        fields[perturb_idx] = replacement
        return tuple(fields)

    return [_with_replacement(text) for text in perturb_texts]

### Function: Get word weights using Integrated Gradients (IG)

In [None]:
def _merge_wordpieces(token_attributions):
    """Merge "##" continuation tokens into the word they belong to.

    Args:
        token_attributions: iterable of (token, weight) pairs.

    Returns:
        A list of (word, average weight) pairs in the original order, where the
        average is the summed weight of a word's pieces divided by their count.
    """
    merged = []  # one [word, summed weight, piece count] entry per word
    for token, weight in token_attributions:
        if token.startswith("##") and merged:
            entry = merged[-1]
            entry[0] += token.replace("##", "")
            entry[1] += weight
            entry[2] += 1
        else:
            merged.append([token, weight, 1])
    return [(word, total / count) for word, total, count in merged]


def get_words_weights(text, model, tokenizer):
    """Compute a per-word attribution weight for `text`.

    Token attributions come from `SequenceClassificationExplainer`; sub-word
    pieces are merged back into words and their weights averaged. The merge is
    done on plain Python lists instead of chained-indexing assignments on a
    DataFrame, which pandas does not guarantee to write through.

    Args:
        text: the text to explain.
        model: the sequence-classification model handed to the explainer.
        tokenizer: the tokenizer handed to the explainer.

    Returns:
        A dict mapping each word to its average attribution weight. The first
        and last entries of the attribution list are left out, and a word that
        occurs several times keeps the weight of its last occurrence.
    """
    torch.cuda.empty_cache()
    gc.collect()

    # Get token weights using IG.
    # NOTE(review): integrated gradients needs autograd; the explainer
    # presumably re-enables it internally despite no_grad — confirm.
    with torch.no_grad():
        cls_explainer = SequenceClassificationExplainer(model, tokenizer)
        word_attributions = cls_explainer(text, internal_batch_size=1)

    # Find word weights from token weights.
    words = _merge_wordpieces(word_attributions)

    # Skip the first and last entries (presumably the tokenizer's special
    # start/end tokens — confirm).
    weights_dict = {word: weight for word, weight in words[1:-1]}

    gc.collect()
    torch.cuda.empty_cache()

    return weights_dict

In [None]:
# NOTE(review): this only binds a Python variable, and nothing in the visible
# notebook reads it; it does not set an environment variable or change any
# torch setting — confirm whether it is still needed.
TORCH_USE_CUDA_DSA = 'enable'

### Get counterfactuals for the whole dataset

In [None]:
def _save_results(rows):
    """Write `rows` to `output_path` as CSV and return the DataFrame."""
    frame = pd.DataFrame(
        rows,
        columns=['Original', 'New', 'new_class', 'changed_features', 'n_pert_texts'],
    )
    frame.to_csv(output_path)
    return frame


# One row per processed instance:
# [original text, chosen perturbation, its class, changed features, number of perturbations].
results_list = []

for idx, original_text in enumerate(scient_list):

    # Save the intermediate results every 10 instances.
    if idx % 10 == 0:
        df = _save_results(results_list)

    print("Now checking...")
    print(idx)
    try:
        # Make the original prediction with our classifier. The text is passed
        # as a one-element list so it is scored as a whole.
        orig_pred = predict([original_text], predictor=pipe)[0]

        # predict() rewrites labels to the strings '0'/'1', so compare against
        # the string; instances already predicted as class 1 are skipped.
        if extract_predict_label(orig_pred) == '1':
            continue

        # Get multiple perturbed texts from polyjuice.
        with torch.no_grad():
            perturb_texts = pj.perturb(
                original_text,
                ctrl_code=ALL_CTRL_CODES,
                num_perturbations=None,
                perplex_thred=10,
            )
        perturb_texts = [t.lower() for t in perturb_texts]
        if not perturb_texts:
            continue
        pt = len(perturb_texts)

        # Get word weights using IG.
        feature_importance_dict = get_words_weights(original_text, model, tokenizer)

        # Get the class scores of the perturbed texts.
        perturb_preds = predict(perturb_texts, predictor=pipe)

        # Check for perturbations that change the predicted class.
        surprises = pj.select_surprise_explanations(
            orig_text=original_text.lower(),
            perturb_texts=perturb_texts,
            orig_pred=orig_pred,
            perturb_preds=perturb_preds,
            feature_importance_dict=feature_importance_dict,
        )

        if not surprises:
            # Fallback: keep the perturbation whose first score entry is lowest
            # (presumably the class-0 score — confirm the pipeline's entry order).
            # A dedicated index is used so the outer loop position and the
            # original text are not overwritten.
            min_idx = min(
                range(len(perturb_preds)),
                key=lambda k: perturb_preds[k][0]['score'],
            )
            min_perturb = perturb_preds[min_idx][0]['score']
            label = 0 if min_perturb >= 0.5 else 1
            results_list.append([original_text, perturb_texts[min_idx], label, "", pt])
        else:
            # Save the first selected explanation.
            best = surprises[0]
            results_list.append(
                [original_text, best['perturb_text'], best['pred'], best['changed_features'], pt]
            )

    except Exception as e:
        # Best effort: report the failing instance and carry on with the next one.
        print(idx)
        print("The error is: ", e)


torch.cuda.empty_cache()
df = _save_results(results_list)