# Interpreting Correct-n-Contrast in NLP

This demo walks through several (simplified) steps for loading a pretrained Correct-n-Contrast (CnC) model, along with a control ERM model, and computing or loading their SHAP values.

*Be sure to follow the instructions in README.md before running these cells.*

### 1. Load in necessary assets

Using some custom utility functions, load in the pretrained ERM and CnC models, as well as a pre-sampled subset of the CivilComments dataset.

```python
# Load in a class for loading pretrained models from the 'data' folder
from utils import load_models as lm

cncLoader = lm.ModelLoader('data/civilcomments_cnc_pretrained.pth.tar')
ermLoader = lm.ModelLoader('data/civilcomments_erm_early.pth.tar')
```

```text
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

log.missing_keys: []


`return_all_scores` is now deprecated,  if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.
```

```python
# Load in a dataset loader to read in the CivilComments dataset
# and randomly sample 15 samples from 19 different categories
from utils import load_civilcomments as lcc

dataset = lcc.CivilCommentsLoader('data/all_data_with_identities.csv').data
```
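The comment above describes drawing a fixed number of comments per category, i.e. stratified sampling. `CivilCommentsLoader`'s actual logic is not shown in this notebook, so the following is only a rough sketch under two assumptions: each record carries a single category label, and sampling is without replacement. The helper `sample_per_category` and the record fields are hypothetical.

```python
import random
from collections import defaultdict


def sample_per_category(records, category_key, n_per_category, seed=0):
    """Draw up to n_per_category records from each category, without replacement.

    Hypothetical helper: `records` is a list of dicts and `category_key`
    names the field holding each record's (single) category label.
    """
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    by_category = defaultdict(list)
    for record in records:
        by_category[record[category_key]].append(record)

    sample = []
    for category in sorted(by_category):  # sorted for a deterministic order
        group = by_category[category]
        # Categories smaller than n_per_category contribute all their records
        k = min(n_per_category, len(group))
        sample.extend(rng.sample(group, k))
    return sample


# Usage: 3 toy categories with 5 comments each, 2 drawn from each category
toy = [{"text": f"comment {i}", "category": c}
       for i, c in enumerate(["a", "b", "c"] * 5)]
subset = sample_per_category(toy, "category", n_per_category=2)
print(len(subset))  # 6
```

Real CivilComments identity labels can overlap (one comment may mention several identities), which this single-label sketch does not handle.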

### 2a. Calculate SHAP Values

You can either run the (2a) cells to recompute the SHAP values with the pipelines loaded above, *or* run (2b) to immediately load the pre-calculated SHAP values you downloaded from the provided Drive.

Note that the visualizations in (2a) *require* the preceding (2a) cells to have been run. Running (2b) overwrites `erm_shap_values` and `cnc_shap_values` with an object that cannot be used to produce those same visualizations.
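For intuition about what the explainers estimate: a token's SHAP value is its Shapley value, i.e. its average marginal contribution to the model's output over all orders in which the tokens could be added. For text pipelines, `shap.Explainer` approximates this by masking subsets of tokens rather than enumerating every subset, and the many model evaluations involved are part of why the (2a) cells are slow. The sketch below computes exact Shapley values by brute force for a hypothetical toy "toxicity" scorer (`toy_score` is invented for illustration), which is only feasible for a handful of tokens.

```python
from itertools import combinations
from math import factorial


def shapley_values(tokens, score):
    """Exact Shapley value of each token position for a set function `score`.

    `score` takes a frozenset of kept token positions and returns a number;
    positions not in the set are treated as masked out.
    """
    n = len(tokens)
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight shared by every coalition of this size without i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi += weight * (score(s | {i}) - score(s))
        values.append(phi)
    return values


# Hypothetical toy scorer: +0.6 if "idiot" is kept, +0.2 more if "you" is
# also kept (an interaction between the two tokens); "are" never matters.
tokens = ["you", "are", "idiot"]


def toy_score(kept):
    kept_words = {tokens[j] for j in kept}
    score = 0.0
    if "idiot" in kept_words:
        score += 0.6
        if "you" in kept_words:
            score += 0.2
    return score


phi = shapley_values(tokens, toy_score)
print([round(v, 3) for v in phi])  # [0.1, 0.0, 0.7]
```

The 0.2 interaction is split evenly between "you" and "idiot", and the values sum to the full-input score minus the fully masked score (0.8), the additivity property that lets SHAP values be read as per-token contributions.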

```python
# Import SHAP and build an explainer for each model's pipeline
import shap

erm_explainer = shap.Explainer(ermLoader.pipeline)
cnc_explainer = shap.Explainer(cncLoader.pipeline)
```

```python
# Calculate the ERM SHAP values on the sampled dataset
# Note: this cell takes a very long time to run (~7 hours)
erm_shap_values = erm_explainer(dataset['text'])
shap.plots.text(erm_shap_values)
```


```python
# Calculate the CnC SHAP values on the sampled dataset
# Note: this cell takes a very long time to run (~7 hours)
cnc_shap_values = cnc_explainer(dataset['text'])
shap.plots.text(cnc_shap_values)
```
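Because the ERM and CnC cells each run for hours, an interruption (for example a `KeyboardInterrupt`) discards all progress. One option, sketched below under the assumption that each chunk's result can be pickled, is to explain the comments in chunks and save every chunk to disk so a rerun resumes where it stopped. `run_with_checkpoints` is a hypothetical helper. With the notebook's objects this might look like `run_with_checkpoints(erm_explainer, list(dataset['text']), 25, "erm_checkpoints")`, but that call is untested here, and recombining the per-chunk `Explanation` objects for the (2a) plots is not covered.

```python
import os
import pickle
import tempfile


def run_with_checkpoints(explain_fn, texts, chunk_size, checkpoint_dir):
    """Apply explain_fn to texts chunk by chunk, pickling each chunk's result.

    Chunks whose checkpoint file already exists are loaded from disk instead
    of being recomputed, so an interrupted run can be resumed.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    results = []
    for start in range(0, len(texts), chunk_size):
        path = os.path.join(checkpoint_dir, f"chunk_{start:06d}.pkl")
        if os.path.exists(path):
            with open(path, "rb") as f:
                results.append(pickle.load(f))
            continue
        result = explain_fn(texts[start:start + chunk_size])
        # Write to a temporary file first so a crash mid-write never
        # leaves a truncated checkpoint behind
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(result, f)
        os.replace(tmp_path, path)
        results.append(result)
    return results


# Usage with a stand-in "explainer" that just returns text lengths
calls = []


def fake_explain(batch):
    calls.append(len(batch))
    return [len(t) for t in batch]


texts = ["a", "bb", "ccc", "dddd", "eeeee"]
with tempfile.TemporaryDirectory() as d:
    first = run_with_checkpoints(fake_explain, texts, 2, d)
    second = run_with_checkpoints(fake_explain, texts, 2, d)  # all from disk
print(first, calls)  # [[1, 2], [3, 4], [5]] [2, 2, 1]
```

The second run makes no calls to `fake_explain`, because every chunk is already checkpointed.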

```python
# Write SHAP results to files
with open("erm_shap_values.txt", "w") as file:
    file.write(str(erm_shap_values.values))

with open("erm_shap_data.txt", "w", encoding="utf-8") as file:
    file.write(str(erm_shap_values.data))

with open("cnc_shap_values.txt", "w") as file:
    file.write(str(cnc_shap_values.values))

with open("cnc_shap_data.txt", "w", encoding="utf-8") as file:
    file.write(str(cnc_shap_values.data))
```
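Writing `str(...)` of a NumPy array may be lossy: with NumPy's default print options, floats are shown to 8 digits and arrays with more than 1000 elements are summarized with `...`. If you don't need the text format that the provided `SHAPLoader` in (2b) reads, a round-trippable alternative is to store each comment's tokens and per-token class scores as JSON. This sketch assumes, consistent with the `[:,:,1]` indexing in the plotting cell, that `.data` holds one token sequence per comment and `.values` one `(n_tokens, n_classes)` array per comment; `save_shap_json` and `load_shap_json` are hypothetical helpers, demonstrated here with plain lists.

```python
import json


def save_shap_json(path, token_data, shap_values):
    """Save per-comment tokens and SHAP values as JSON.

    token_data: one sequence of token strings per comment.
    shap_values: one (n_tokens, n_classes) nested sequence per comment
    (plain lists or NumPy arrays; every score is converted to float).
    """
    records = [
        {"tokens": [str(t) for t in tokens],
         "values": [[float(v) for v in row] for row in values]}
        for tokens, values in zip(token_data, shap_values)
    ]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)


def load_shap_json(path):
    """Return (token_data, shap_values) as lists, one entry per comment."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    return [r["tokens"] for r in records], [r["values"] for r in records]


# Usage: two toy comments with two classes per token
tokens = [["you ", "are ", "great"], ["hello"]]
values = [[[0.1, -0.1], [0.0, 0.0], [-0.3, 0.3]],
          [[0.123456789012, -0.123456789012]]]
save_shap_json("shap_demo.json", tokens, values)
loaded_tokens, loaded_values = load_shap_json("shap_demo.json")
```

Python's `json` module writes floats with their shortest round-tripping representation, so the loaded values compare equal to the originals.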

```python
# Visualize the top 30 tokens by mean SHAP value from each model.
# Note: this step only works using raw output from (2a), NOT (2b)

# Average each token's SHAP values for class label 1 (toxic) across the
# sampled comments. max_display=31 shows the top 30 tokens plus one bar
# that sums all remaining tokens.
shap.plots.bar(erm_shap_values[:,:,1].mean(0), max_display=31)
shap.plots.bar(cnc_shap_values[:,:,1].mean(0), max_display=31)
```
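As a rough sketch of what these bar plots summarize (an approximation of SHAP's behavior, not its implementation): each token's class-1 SHAP values are averaged over every occurrence across the sampled comments, and tokens are ranked by the absolute value of that mean. The hypothetical `top_tokens_by_mean` below does this on plain lists shaped like `.data` and `.values`.

```python
from collections import defaultdict


def top_tokens_by_mean(token_data, shap_values, class_index=1, k=30):
    """Rank tokens by |mean SHAP value| for one class across all comments.

    token_data: one sequence of token strings per comment.
    shap_values: one (n_tokens, n_classes) nested sequence per comment.
    Tokens are grouped by their exact string, so "idiot" and "idiot " differ.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for tokens, values in zip(token_data, shap_values):
        for token, row in zip(tokens, values):
            totals[token] += row[class_index]
            counts[token] += 1
    means = {token: totals[token] / counts[token] for token in totals}
    # Largest absolute mean first, as in a SHAP bar plot
    ranked = sorted(means.items(), key=lambda item: abs(item[1]), reverse=True)
    return ranked[:k]


# Usage: "idiot" appears twice, with class-1 values 0.5 and 0.25
tokens = [["you", "idiot"], ["nice", "idiot"]]
values = [[[0.0, 0.125], [-0.5, 0.5]], [[0.25, -0.25], [-0.25, 0.25]]]
print(top_tokens_by_mean(tokens, values, k=2))  # [('idiot', 0.375), ('nice', -0.25)]
```

Comparing the ERM and CnC rankings this way shows which tokens each model leans on most when predicting the toxic class.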

### 2b. Import all Pre-Calculated SHAP values (optional)

```python
# Import our pre-calculated SHAPs using a provided function (if you would like to look at them)
from utils import load_shap_values as lsv

erm_shap_values = lsv.SHAPLoader("data/erm_shap_data.txt", "data/erm_shap_values.txt").SHAP_values
cnc_shap_values = lsv.SHAPLoader("data/cnc_shap_data.txt", "data/cnc_shap_values.txt").SHAP_values
```