*Copyright (c) Microsoft Corporation. All rights reserved.*

*Licensed under the MIT License.*

# Text Classification of SST-2 Sentences using a 3-Player Introspective Model

In [None]:
import sys
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import scrapbook as sb

from interpret_text.common.dataset.utils_sst2 import load_sst2_pandas_df
from interpret_text.introspective_rationale.introspective_rationale_explainer import IntrospectiveRationaleExplainer
from interpret_text.common.utils_introspective_rationale import GlovePreprocessor, BertPreprocessor, ModelArguments, load_glove_embeddings
from interpret_text.widget import ExplanationDashboard

## Introduction
In this notebook, we train and evaluate a [three-player explainer](http://people.csail.mit.edu/tommi/papers/YCZJ_EMNLP2019.pdf) model on a subset of the [SST-2](https://nlp.stanford.edu/sentiment/index.html) dataset. To run this notebook, we used the SST-2 data files provided [here](https://github.com/AcademiaSinicaNLPLab/sentiment_dataset).

### Set parameters
Here we set some parameters that we use for our modeling task.

In [None]:
# --- Run mode ---------------------------------------------------------------
# Setting QUICK_RUN to True skips over embedding, most of model training, and
# model evaluation; it exists only to quickly test the pipeline end to end.
QUICK_RUN = False
# Classifier to explain: currently supports RNN, BERT, or a combination of both.
MODEL_TYPE = "RNN"
CUDA = True

# --- Data processing --------------------------------------------------------
DATA_FOLDER = "../../../data/sst2"
LABEL_COL, TEXT_COL = "labels", "sentences"
token_count_thresh, max_sentence_token_count = 1, 70

# --- Training procedure -----------------------------------------------------
load_pretrained_model = False
pretrained_model_path = "../models/rnn.pth"
MODEL_SAVE_DIR = os.path.join("..", "models")
model_prefix = "sst2rnpmodel"

In [None]:
# ModelArguments carries the default parameters used internally by the model;
# any of them can be overridden after construction, as done below.
args = ModelArguments(cuda=CUDA, model_save_dir=MODEL_SAVE_DIR, model_prefix=model_prefix)
args.lr = 2e-4

# Quick runs: don't save the best model, skip classifier pre-training, and
# train for a single epoch.
if QUICK_RUN:
    args.save_best_model, args.pre_train_cls, args.num_epochs = False, False, 1

# The RNN classifier (i.e. not using BERT) needs pretrained GloVe embeddings.
# A quick run skips load_glove_embeddings and only points at the data folder.
if MODEL_TYPE == "RNN":
    args.embedding_path = (
        os.path.join(DATA_FOLDER, "")
        if QUICK_RUN
        else load_glove_embeddings(DATA_FOLDER)
    )

## Read Dataset
We start by loading the SST-2 training and test splits (each is truncated to a small sample when `QUICK_RUN` is set).

In [None]:
# Load both splits; all_data (built before any truncation) is used later for the
# label set and for building the RNN preprocessor's vocabulary.
train_data = load_sst2_pandas_df('train')
test_data = load_sst2_pandas_df('test')
all_data = pd.concat([train_data, test_data])

if QUICK_RUN:
    # `batch_size` was previously referenced here without ever being defined,
    # so quick runs crashed with a NameError. Keep roughly one batch of rows per
    # split: use ModelArguments' batch size if it exposes one, else fall back to 50.
    batch_size = getattr(args, "batch_size", 50)
    train_data = train_data.head(batch_size)
    test_data = test_data.head(batch_size)

X_train = train_data[TEXT_COL]
X_test = test_data[TEXT_COL]

In [None]:
# get all unique labels
y_labels = all_data[LABEL_COL].unique()
args.labels = np.array(sorted(y_labels))
args.num_labels = len(y_labels)

## Tokenization and embedding
The data is then tokenized and embedded: with GloVe embeddings for the RNN model, or with the BERT preprocessor when `MODEL_TYPE` is `"BERT"`.

In [None]:
# Choose the tokenizer/embedder that matches the classifier being explained.
if MODEL_TYPE == "RNN":
    # GloVe preprocessor built from the text of both splits; token_count_thresh and
    # max_sentence_token_count presumably bound vocabulary and sentence length — confirm.
    preprocessor = GlovePreprocessor(all_data[TEXT_COL], token_count_thresh, max_sentence_token_count)
elif MODEL_TYPE == "BERT":
    preprocessor = BertPreprocessor()
else:
    # Fail fast with a clear message instead of a NameError on `preprocessor` below.
    raise ValueError(
        f"Unsupported MODEL_TYPE {MODEL_TYPE!r}: no preprocessor is configured for it "
        "(expected 'RNN' or 'BERT')."
    )

# Place the label column in front of the preprocessed features; pd.concat with
# axis=1 joins the pieces side by side, aligned on the row index.
df_train = pd.concat([train_data[LABEL_COL], preprocessor.preprocess(X_train)], axis=1)
df_test = pd.concat([test_data[LABEL_COL], preprocessor.preprocess(X_test)], axis=1)

## Explainer
Then, we create the explainer and train it (or load a pretrained model).

In [None]:
explainer = IntrospectiveRationaleExplainer(args, preprocessor, classifier_type=MODEL_TYPE)

# Either restore a previously saved model from disk or train a new one on the
# training split (the test split is passed as fit's second argument).
classifier = (
    explainer.load_pretrained_model(pretrained_model_path)
    if load_pretrained_model
    else explainer.fit(df_train, df_test)
)

We can test the explainer and measure its performance:

In [None]:
# Evaluation is skipped on quick runs.
if not QUICK_RUN:
    explainer.score(df_test)

    # Averages recorded on the underlying model by the scoring pass above.
    scored_model = explainer.model
    sparsity = scored_model.avg_sparsity
    accuracy = scored_model.avg_accuracy
    anti_accuracy = scored_model.avg_anti_accuracy

    print("Test sparsity: ", sparsity)
    print("Test accuracy: ", accuracy, "% Anti-accuracy: ", anti_accuracy)

    # Record the metrics with scrapbook so automated notebook tests can read them.
    for metric_name, metric_value in (
        ("sparsity", sparsity),
        ("accuracy", accuracy),
        ("anti_accuracy", anti_accuracy),
    ):
        sb.glue(metric_name, metric_value)

## Local importances
We can compute the local importances found by the explainer (the most and least important words in a given sentence):

In [None]:
# Enter the sentence that needs to be interpreted. Previously `sentence` was
# defined but ignored and a separate literal was explained; the explained text
# now comes from this variable, with the same value as before.
sentence = "This is a super amazing movie with bad acting"

# More example sentences to try: assign one of these to `sentence` above.
example_sentences = [
    "Beautiful movie ; really good , the popcorn was bad",
    "a beautiful and haunting examination of the stories we tell ourselves to make sense of the mundane horrors of the world.",
    "the premise is in extremely bad taste , and the film's supposed insights are so poorly executed and done that even a high school dropout taking his or her first psychology class could dismiss them .",
]

local_explanation = explainer.explain_local(sentence, preprocessor, hard_importances=False)

## Visualize explanations
We can visualize local feature importances as a heatmap over words in the document and view importance values of individual words.

In [None]:
# Render the local feature importances as a heatmap over the words of the explained sentence.
explainer.visualize(local_explanation)

In [None]:
# Build the interactive widget dashboard for the same local explanation. This is
# the cell's last expression, so its value is what Jupyter displays.
ExplanationDashboard(local_explanation)