*Copyright (c) Microsoft Corporation. All rights reserved.*

*Licensed under the MIT License.*

# Text Classification of SST-2 Sentences using a 3-Player Model

In [1]:
import sys
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from interpret_text.three_player_introspective.three_player_introspective_explainer import ThreePlayerIntrospectiveExplainer
from interpret_text.common.utils_three_player import load_pandas_df, GloveTokenizer
from interpret_text.widget import ExplanationDashboard

## Introduction
In this notebook, we train and evaluate a  [three-player explainer](http://people.csail.mit.edu/tommi/papers/YCZJ_EMNLP2019.pdf) model on a subset of the [SST-2](https://nlp.stanford.edu/sentiment/index.html/) dataset. To run this notebook, we used the SST-2 data files provided [here](https://github.com/AcademiaSinicaNLPLab/sentiment_dataset). You should download files matching data/sst2.binary.* into a folder and point DATA_FOLDER (below) to that folder.

### Set parameters
Here we set some parameters that we use for our modeling task.

In [2]:
# data processing parameters
DATA_FOLDER = "../../../data/sst2"
LABEL_COL = "labels" 
TEXT_COL = "sentences"
token_count_thresh = 1
max_sentence_token_count = 70

# training procedure params
pre_trained_model_prefix = 'pre_trained_cls.model'
save_path = os.path.join("..", "models")
model_prefix = "sst2rnpmodel"
save_best_model = True
pre_train_cls = True

# arguments for the model
class Argument():
    def __init__(self):
        # to initialize classifierModule and introspectionGeneratorModule
        self.embedding_dim = 100
        self.hidden_dim = 200
        self.layer_num = 1
        self.z_dim = 2
        self.dropout_rate = 0.5

        # to init only introspectionGeneratorModule
        self.num_labels = 2
        self.label_embedding_dim = 400
        self.fixed_classifier = True

        # to init model
        self.fine_tuning = False
        self.cuda = True
        self.batch_size = 40
        self.lambda_sparsity = 1.0
        self.lambda_continuity = 1.0
        self.lambda_anti = 1.0
        self.exploration_rate = 0.05
        self.count_tokens = 8
        self.count_pieces = 4
        self.lambda_acc_gap = 1.2
        self.lr=0.001
        self.embedding_path = os.path.join(DATA_FOLDER, "hglove.6B.100d.txt")
args = Argument()
args_dict = vars(args)

## Read Dataset
We start by loading a subset of the data for training and testing.

In [3]:
# TODO: load dataset to blob storage
df_train = load_pandas_df('train', LABEL_COL, TEXT_COL)
df_test = load_pandas_df('test', LABEL_COL, TEXT_COL)
df_all = pd.concat([df_train, df_test])
x_train = df_train[TEXT_COL]
x_test = df_test[TEXT_COL]
y_train = df_train[LABEL_COL]
y_test = df_train[LABEL_COL]
labels = df_all[LABEL_COL].unique()

## Tokenization and embedding
The data is then tokenized and embedded using glove embeddings.

In [4]:
tokenizer = GloveTokenizer(df_all[TEXT_COL], token_count_thresh, max_sentence_token_count)
word_vocab = tokenizer.word_vocab

# append tokenizations to data
df_train = pd.concat([df_train, tokenizer.tokenize(x_train)], axis=1)
df_test = pd.concat([df_test, tokenizer.tokenize(x_test)], axis=1)
df_all = pd.concat([df_train, df_test])
x_train = df_train[TEXT_COL]
x_test = df_test[TEXT_COL]
y_train = df_train[LABEL_COL]
y_test = df_train[LABEL_COL]
labels = df_all[LABEL_COL].unique()
print(df_all)

      labels                                          sentences  \
0          1  a stirring , funny and finally transporting re...   
1          0  apparently reassembled from the cutting-room f...   
2          0  they presume their audience wo n't sit still f...   
3          1  this is a visually stunning rumination on love...   
4          1  jonathan parker 's bartleby should have been t...   
...      ...                                                ...   
1816       0  an often-deadly boring , strange reading of a ...   
1817       0  the problem with concept films is that if the ...   
1818       0  safe conduct , however ambitious and well-inte...   
1819       0  a film made with as little wit , interest , an...   
1820       0  but here 's the real damn : it is n't funny , ...   

                                                 tokens  \
0     [2, 3, 4, 5, 6, 7, 8, 1, 10, 11, 6, 12, 13, 6,...   
1     [17, 1, 19, 12, 1, 21, 10, 22, 23, 24, 25, 26,...   
2     [27, 1, 29, 

## Explainer
Then, we create and train the explainer.

In [5]:
explainer = ThreePlayerIntrospectiveExplainer(args, word_vocab)
classifier = explainer.fit(df_train, df_test, args.batch_size, num_iteration=1000, pretrain_cls=True)

  "num_layers={}".format(dropout, num_layers))


embedding is initialized fully randomly.




pre-training the classifier


  0%|                                                                                            | 0/5 [00:00<?, ?it/s]


AssertionError: Torch not compiled with CUDA enabled

We can test the explainer and measure its performance:

In [None]:
accuracy, anti_accuracy, sparsity = explainer.score(df_test)
print("Test sparsity: ", sparsity)
print("Test accuracy: ", accuracy, "% Anti-accuracy: ", anti_accuracy)

## Local importances
We can display the found local importances (the most and least important words for a given sentence)

In [None]:
# Enter a sentence that needs to be interpreted
sentence = "This great movie was really good"
label = 0

# Tokenize the sentence
df_sentence = pd.DataFrame.from_dict({TEXT_COL: [sentence], LABEL_COL: [label]})
tokenizer = GloveTokenizer(df_sentence[TEXT_COL], token_count_thresh, max_sentence_token_count)
df_sentence = pd.concat([df_sentence, tokenizer.tokenize(df_sentence[TEXT_COL])], axis=1)
local_explanantion = explainer.explain_local(sentence, df_sentence, np.array([0, 1]), hard_importances=False)

# Visualize local feature importances as a heatmap over words in the document
# TODO: a less hacky way of getting words
explainer.visualize(local_explanantion._local_importance_values, local_explanantion._features)

ExplanationDashboard(local_explanantion)

In [None]:
print(local_explanantion._local_importance_values)