*Copyright (c) Microsoft Corporation. All rights reserved.*

*Licensed under the MIT License.*

# Text Classification of SST-2 Sentences using a 3-Player Introspective Model

In [None]:
import sys
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from interpret_text.common.dataset.utils_sst2 import load_sst2_pandas_df
from interpret_text.three_player_introspective.three_player_introspective_explainer import ThreePlayerIntrospectiveExplainer
from interpret_text.common.utils_three_player import GlovePreprocessor, BertPreprocessor, ModelArguments, load_glove_embeddings
from interpret_text.widget import ExplanationDashboard

## Introduction
In this notebook, we train and evaluate a [three-player explainer](http://people.csail.mit.edu/tommi/papers/YCZJ_EMNLP2019.pdf) model on a subset of the [SST-2](https://nlp.stanford.edu/sentiment/index.html/) dataset. To run this notebook, we used the SST-2 data files provided [here](https://github.com/AcademiaSinicaNLPLab/sentiment_dataset).

### Set parameters
Here we set some parameters that we use for our modeling task.

In [None]:
# If QUICK_RUN is True, skip loading the GloVe embeddings, most of model training,
# and model evaluation; used to quickly test the pipeline.
QUICK_RUN = False
model_type = "RNN"  # currently support either RNN or BERT

# data processing parameters
DATA_FOLDER = "../../../data/sst2"
LABEL_COL = "labels"
TEXT_COL = "sentences"
token_count_thresh = 1
max_sentence_token_count = 70

# training procedure parameters
model_save_dir = os.path.join("..", "models")
model_prefix = "sst2rnpmodel"
# Use the GPU only when one is actually available, so the notebook also runs on
# CPU-only machines instead of unconditionally requesting CUDA.
cuda = torch.cuda.is_available()
batch_size = 64
if not QUICK_RUN:
    save_best_model = True
    pre_train_cls = True
    num_epochs = 200
else:
    save_best_model = False
    pre_train_cls = False
    num_epochs = 1

# ModelArguments contains default parameters used internally in the model that can be changed
args = ModelArguments(cuda, pre_train_cls, batch_size, num_epochs, save_best_model, model_save_dir=model_save_dir, model_prefix=model_prefix)
# example of changing an argument after construction
args.cuda = cuda
args.lr = 2e-4
args.embedding_path = ""

if model_type == "RNN":
    # (i.e. not using BERT), load pretrained glove embeddings
    # TODO: load glove embedding file in load_glove_embeddings to blob storage
    if not QUICK_RUN:
        args.embedding_path = load_glove_embeddings(DATA_FOLDER)
    else:
        # Quick runs point at a placeholder file name instead of loading the embeddings.
        args.embedding_path = os.path.join(DATA_FOLDER, "noEmbeddingFile.txt")

## Read Dataset
We start by loading a subset of the data for training and testing.

In [None]:
# TODO: load dataset to blob storage
# Load both splits, then keep a combined frame of the *full* data before any
# quick-run truncation is applied to the individual splits.
train_data, test_data = (load_sst2_pandas_df(split) for split in ('train', 'test'))
all_data = pd.concat([train_data, test_data])
if QUICK_RUN:
    # A single batch worth of rows per split is enough to exercise the pipeline.
    train_data, test_data = train_data.head(batch_size), test_data.head(batch_size)
X_train, X_test = train_data[TEXT_COL], test_data[TEXT_COL]

In [None]:
# Distinct label values found across the combined train + test data.
labels = all_data[LABEL_COL].unique()
# Store the class count and the classes themselves (in sorted order) on the model arguments.
args.num_labels = len(labels)
args.labels = np.array(sorted(labels))

## Tokenization and embedding
The data is then tokenized and, for the RNN model, embedded using pretrained GloVe embeddings (the BERT model uses the `BertPreprocessor` instead).

In [None]:
# Pick the preprocessor that matches the configured model type.
if model_type == "RNN":
    # The GloVe preprocessor receives the sentences of both splits (all_data).
    preprocessor = GlovePreprocessor(all_data[TEXT_COL], token_count_thresh, max_sentence_token_count)
elif model_type == "BERT":
    preprocessor = BertPreprocessor()
else:
    # Fail fast with a clear message rather than a NameError on `preprocessor` below.
    raise ValueError(
        "Unsupported model_type {!r}; expected 'RNN' or 'BERT'.".format(model_type)
    )

# append labels to tokenizer output
df_train = pd.concat([train_data[LABEL_COL], preprocessor.preprocess(X_train)], axis=1)
df_test = pd.concat([test_data[LABEL_COL], preprocessor.preprocess(X_test)], axis=1)

## Explainer
Then, we create and train the explainer.

In [None]:
# Build the explainer for the configured model type from the model arguments and preprocessor.
explainer = ThreePlayerIntrospectiveExplainer(args, preprocessor, classifier_type=model_type)
# Train on the training frame, passing the test frame alongside it.
# `pre_train_cls` is the flag defined in the parameters cell; the previous spelling
# `pretrain_cls` was never defined and raised a NameError.
classifier = explainer.fit(df_train, df_test, pre_train_cls)

We can test the explainer and measure its performance:

In [None]:
# Evaluation is skipped entirely in quick-run mode.
if not QUICK_RUN:
    # The return value of score() is not used; the metrics printed below are read
    # from explainer.model afterwards (presumably populated by score() — TODO confirm).
    explainer.score(df_test)
    print("Test sparsity: ", explainer.model.avg_sparsity)
    print("Test accuracy: ", explainer.model.avg_accuracy, "% Anti-accuracy: ", explainer.model.avg_anti_accuracy)

Test sparsity:  0.8334719067677765
Test accuracy:  0.7693574958813838 % Anti-accuracy:  0.5315760571114773


In [49]:
from interpret_text.common.utils_three_player import generate_data
from interpret_text.explanation.explanation import _create_local_explanation


def explain_local(
    explainer, sentence, label, preprocessor, hard_importances=True
):
    """Build a local explanation object for a single sentence.

    The sentence is lower-cased, run through ``preprocessor.preprocess`` and
    placed next to ``label`` in a one-row DataFrame, which is then fed to the
    explainer to obtain a prediction and per-token importance values.

    Args:
        explainer: Explainer exposing ``args.cuda``, ``predict``,
            ``model.get_z_scores`` and ``labels``.
        sentence (str): Text to explain.
        label: Value stored in the ``"labels"`` column of the one-row frame.
        preprocessor: Object providing ``preprocess`` and ``decode_single``.
        hard_importances (bool): When True, the importances are the
            ``"rationale"`` entry returned by ``explainer.predict``. When
            False, they are taken from ``explainer.model.get_z_scores`` at the
            index of the highest prediction score.

    Returns:
        The object produced by ``_create_local_explanation`` for this sentence.
    """
    # One-row input frame: label column followed by the preprocessed sentence.
    label_frame = pd.DataFrame.from_dict({"labels": [label]})
    sentence_frame = preprocessor.preprocess([sentence.lower()])
    df_sentence = pd.concat([label_frame, sentence_frame], axis=1)

    batch_dict = generate_data(df_sentence, explainer.args.cuda)
    token_ids, mask = batch_dict["x"], batch_dict["m"]

    predict_dict = explainer.predict(df_sentence)
    predict = predict_dict["predict"].cpu()
    importances = predict_dict["rationale"]
    if not hard_importances:
        # Replace the rationale with the z-scores of the top-scoring class.
        z_scores = explainer.model.get_z_scores(df_sentence)
        importances = z_scores[:, :, np.argmax(predict)].detach()

    importances = np.array(importances.cpu())

    # generate human-readable tokens (individual words)
    # NOTE(review): this slices the first axis by seq_len and then takes element 0;
    # if x is laid out as (batch, seq) it does not truncate to seq_len — confirm intent.
    seq_len = int(mask.sum().item())
    tokens = preprocessor.decode_single(token_ids[:seq_len][0])

    return _create_local_explanation(
        classification=True,
        text_explanation=True,
        local_importance_values=importances.flatten(),
        method=str(type(explainer.model)),
        model_task="classification",
        features=tokens,
        classes=explainer.labels,
    )


## Local importances
We can display the found local importances (the most and least important words for a given sentence):

In [None]:
# Enter a sentence that needs to be interpreted
sentence = "This great movie was really good, but it could be bad"
# Label stored alongside the sentence in the one-row frame that explain_local builds.
label = 0

# hard_importances=False makes explain_local use the model's z-scores for the
# predicted class instead of the rationale returned by explainer.predict.
local_explanation = explain_local(explainer, sentence, label, preprocessor, hard_importances=False)

## Visualize explanations
We can visualize local feature importances as a heatmap over words in the document and view importance values of individual words.

In [None]:
# Pass the explanation's importance values and features (the decoded tokens) to the
# explainer's visualizer. Note: both are read from underscore-prefixed attributes.
explainer.visualize(local_explanation._local_importance_values, local_explanation._features)

In [None]:
# Instantiate the interpret_text dashboard widget for the local explanation computed above.
ExplanationDashboard(local_explanation)