# Interpreting Correct-n-Contrast in NLP

This demo walks through several (simplified) steps for loading the pretrained models and SHAP values for Correct-n-Contrast (CnC), along with an ERM model that serves as a control.

*Be sure to follow the README.md before running these cells*

### 1. Load in necessary assets

Using some custom utility functions, load in the pretrained ERM and CnC models, as well as a pre-sampled subset of the CivilComments dataset.

In [2]:
# Import the utility class that loads pretrained models from the 'data' folder
from utils import load_models as lm

cncLoader = lm.ModelLoader('data/civilcomments_cnc_pretrained.pth.tar')
ermLoader = lm.ModelLoader('data/civilcomments_erm_early.pth.tar')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

log.missing_keys: []


`return_all_scores` is now deprecated,  if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a B

log.missing_keys: []


In [3]:
# Import the dataset loader, read in the CivilComments dataset,
# and randomly sample 15 comments from each of 19 different categories
from utils import load_civilcomments as lcc

dataset = lcc.CivilCommentsLoader('data/all_data_with_identities.csv').data
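
The sampling itself happens inside the loader and is not shown in this notebook. As a rough sketch of what drawing a fixed number of comments per category could look like in pandas, the example below uses a toy frame with one boolean column per category. The column names, the toy data, and the `sample_per_category` helper are all hypothetical and are not taken from `utils`.

```python
import numpy as np
import pandas as pd


def sample_per_category(df, categories, n_per_category, seed=0):
    """Draw up to n_per_category rows per category column, never reusing a row."""
    chosen = []
    remaining = df
    for offset, cat in enumerate(categories):
        # Rows that belong to this category and have not been picked yet.
        pool = remaining[remaining[cat]]
        take = min(n_per_category, len(pool))
        picked = pool.sample(n=take, random_state=seed + offset)
        chosen.append(picked)
        remaining = remaining.drop(picked.index)
    return pd.concat(chosen)


# Toy data: two hypothetical boolean category columns.
ids = np.arange(200)
toy = pd.DataFrame({
    "id": ids,
    "text": [f"comment {i}" for i in ids],
    "cat_a": ids % 2 == 0,
    "cat_b": ids % 3 == 0,
})

toy_samples = sample_per_category(toy, ["cat_a", "cat_b"], n_per_category=15)
```

With 19 categories and 15 comments each, the same procedure would yield 285 rows, provided every category has enough unpicked comments.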

### 2a. Calculate SHAP Values

You can run the (2a) cells to reproduce the SHAP values yourself using the pipelines you have just loaded, *or* you can skip to (2b) to load the pre-calculated SHAP values that you downloaded from the provided Drive.

Note that the visualizations in (2a) *require* the preceding (2a) cells to have been run. The cells in (2b) overwrite the raw output with an object that cannot be used to produce those same visualizations.

In [9]:
# Import SHAP and build explainers using the pipelines we loaded in
import shap

# Each explainer wraps the pipeline of the matching model
erm_explainer = shap.Explainer(ermLoader.pipeline)
cnc_explainer = shap.Explainer(cncLoader.pipeline)

# IDs of the 285 randomly sampled comments
sample_indices = [5501204,  745527, 6011650, 6028377,  689468, 5836741, 5325575, 5268943,
                          7068476,  395975, 7017781,  374886, 5996752, 5790334,  874205, 5247322,
                          7087935,  551485,  837864, 5225928, 5996362, 6049191, 7010178,  969205,
                          912424, 5807085, 5258444,  302244, 5691616, 5706291,  263826, 5673529,
                          1042100, 5668535, 5660666, 5671172, 5849196,  854766, 5714403, 5677939,
                          317738, 5661906, 7030152, 5273685, 5667538, 5271747, 5634407, 6195881,
                          5250160, 5413617, 5082363, 6289685,  765032, 5223210, 5619106, 5227845,
                          6212185,  805615, 5926602, 5739778,  562695,  916552, 5299060, 7121989,
                          5061615, 5413135, 5619134, 5636688, 6207105, 5355613,  459284, 5481493,
                          5437664, 6194010, 6086468, 5021733, 5670747,  364360, 5701054, 6232221,
                          5702495, 6137334, 7186383, 5654569, 5738177, 5866509,  345804, 6000535,
                          5801799, 7037625,  591666,  703019, 1058652, 5437859, 5758811, 1036211,
                          5482323,  796796, 5130262, 5150596,  871928, 6217765, 5740264,  291905,
                          332398,  940570, 1009617,  478245, 5670871, 5615730, 5760844,  528144,
                          7098717, 5754496, 7166976, 1014032, 7068738, 5722790, 6081854, 5617979,
                          911525,  924296, 5972210, 5060992, 6091732,  924059, 5177923,  359103,
                          5363060, 5104459, 6102821, 5015383, 5996226, 7157963,  929809, 5150270,
                          5807595,  830440, 5405057,  793671, 6010742, 5668823, 6211944, 6042510,
                          6180079, 5246452, 1065278, 7091716, 5838669, 5535379, 7100054, 5511312,
                          6156879, 5789327, 4982134, 6192361,  518302, 5952937, 5665505,  982496,
                          5146957, 5720920, 5376205, 6229436, 5870940, 5507166, 5395067,  768060,
                          5567718,  775551, 6198500, 6078953, 7158473, 6305426,  440051, 7026882,
                          850396,  929241, 7038266, 5449238,  688431, 7087713, 6160565, 1070663,
                          7067877, 5791746, 5358390, 5316820, 5662524, 5547323, 6150904,  634605,
                          277145, 5988370, 6068702, 5083110, 5864967, 6264813, 5570643, 1063971,
                          5229899, 6170349, 6221978, 6222995, 5798216, 7064045, 5827378, 6203736,
                          691407, 5319080, 6318572,  778195, 5202048, 5050173, 1058174, 5064416,
                          6093098,  468354,  640894,  683294, 1021911, 5633642, 6132423, 5903169,
                          841347, 6091704,  603228, 5864320, 6305041, 1036874, 5971755,  865220,
                          5548384,  816903, 1022930, 6279524,  901438, 7139278, 1033869, 7071539,
                          5697223,  884837, 5154939, 5208818, 7047243, 6186959, 1083160, 5863209,
                          1079446,  635378, 5448718, 5172545, 5802195,  448084, 7069276,  349702,
                          508013,  891563, 7187074, 7156493, 5885610, 5890739, 5146285, 5577310,
                          264044, 5659300,  486934, 7068576, 6197915, 6008986, 6309507, 5205561,
                          810498,  509950, 5310217, 6221773,  245711, 6081457, 5666897, 5848243,
                          1001118, 5662096,  735476, 7008599, 5082780]

# Build a boolean mask that selects the rows whose ID is in the list, then apply it to the dataset
samples_mask = dataset['id'].isin(sample_indices)
samples = dataset.loc[samples_mask]
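
`isin` silently ignores IDs that are absent from the dataset and keeps every row that matches, so the number of selected rows need not equal the length of the ID list. A small sanity check along the following lines can make that visible before starting a multi-hour SHAP run. It is a sketch on toy stand-ins for `dataset` and `sample_indices`; the values are made up.

```python
import pandas as pd

# Toy stand-ins for `dataset` and `sample_indices` (hypothetical values).
toy_dataset = pd.DataFrame({
    "id": [10, 11, 12, 12, 13],
    "text": ["a", "b", "c", "c again", "d"],
})
toy_indices = [11, 12, 99]

toy_mask = toy_dataset["id"].isin(toy_indices)
toy_samples = toy_dataset.loc[toy_mask]

# IDs that were requested but do not appear in the dataset.
missing_ids = sorted(set(toy_indices) - set(toy_samples["id"]))

# IDs that matched more than one row.
duplicated_ids = sorted(
    toy_samples.loc[toy_samples["id"].duplicated(), "id"].unique().tolist()
)

print(f"{len(toy_samples)} rows selected for {len(toy_indices)} requested IDs")
print("missing:", missing_ids, "duplicated:", duplicated_ids)
```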

In [11]:
# Calculate the ERM SHAP values using the sampled dataset
# Expect this cell to take a very long time (~7 hours)
erm_shap_values = erm_explainer(samples['text'])
shap.plots.text(erm_shap_values)

  0%|          | 0/498 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Calculate the CnC SHAP values using the sampled dataset
# Expect this cell to take a very long time (~7 hours)
cnc_shap_values = cnc_explainer(samples['text'])
shap.plots.text(cnc_shap_values)

In [None]:
# Write SHAP results to files; each `with` block closes its file automatically
with open("erm_shap_values.txt", "w") as file:
    file.write(str(erm_shap_values.values))

with open("erm_shap_data.txt", "w", encoding="utf-8") as file:
    file.write(str(erm_shap_values.data))

with open("cnc_shap_values.txt", "w") as file:
    file.write(str(cnc_shap_values.values))

with open("cnc_shap_data.txt", "w", encoding="utf-8") as file:
    file.write(str(cnc_shap_values.data))
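
`str()` on a NumPy array gives a human-readable rendering, and by default NumPy abbreviates arrays of more than 1000 elements with `...`, so the text files above may not preserve every value. If a lossless copy is wanted alongside them, one option is to pickle the raw arrays. The sketch below uses stand-in data and a hypothetical file name, and it assumes the SHAP values are stored as one array per comment. It is not the format that `utils.load_shap_values` reads.

```python
import os
import pickle
import tempfile

import numpy as np

# The default text rendering abbreviates long arrays with "...".
long_array = np.arange(2000)
is_abbreviated = "..." in str(long_array)

# Stand-in for `erm_shap_values.values`: one (tokens x classes) array per
# comment, with a different number of tokens in each (assumed layout).
toy_values = np.empty(2, dtype=object)
toy_values[0] = np.random.default_rng(0).normal(size=(3, 2))
toy_values[1] = np.random.default_rng(1).normal(size=(5, 2))

# Hypothetical output location; a temporary directory keeps the sketch tidy.
path = os.path.join(tempfile.mkdtemp(), "toy_shap_values.pkl")

with open(path, "wb") as f:
    pickle.dump(toy_values, f)

with open(path, "rb") as f:
    restored = pickle.load(f)

# Every per-comment array survives the round trip unchanged.
round_trip_ok = all(np.array_equal(a, b) for a, b in zip(toy_values, restored))
```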

In [None]:
# Visualize the top 30 SHAP values from each model.
# Note: This step only works with the raw output from (2a), NOT (2b)

# For each feature, take the mean contribution toward classifying a comment as
# toxic (class label 1).
shap.plots.bar(erm_shap_values[:,:,1].mean(0), max_display=31)
shap.plots.bar(cnc_shap_values[:,:,1].mean(0), max_display=31)
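
Each comment has its own set of tokens, so averaging across comments amounts to grouping SHAP values by token. The sketch below illustrates that idea in plain Python on made-up numbers. It is an illustration of the aggregation only, not SHAP's own implementation, and the `mean_shap_by_token` helper is hypothetical.

```python
from collections import defaultdict


def mean_shap_by_token(tokens_per_comment, shap_per_comment):
    """Average the class-1 SHAP value of each token over all of its occurrences."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for tokens, values in zip(tokens_per_comment, shap_per_comment):
        for token, value in zip(tokens, values):
            totals[token] += value
            counts[token] += 1
    return {token: totals[token] / counts[token] for token in totals}


# Made-up tokens and class-1 SHAP values for two comments.
toy_tokens = [["you", "are", "awful"], ["awful", "weather"]]
toy_shap = [[0.0, 0.0, 0.5], [0.25, -0.25]]

means = mean_shap_by_token(toy_tokens, toy_shap)

# Rank tokens by the magnitude of their mean contribution, largest first.
ranked = sorted(means.items(), key=lambda item: abs(item[1]), reverse=True)
```

Here `awful` appears in both comments, so its two values are averaged, while tokens that appear once keep their single value.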

### 2b. Import all Pre-Calculated SHAP values (optional)

In [20]:
# Load our pre-calculated SHAP values using a provided loader class (if you would like to look at them)
from utils import load_shap_values as lsv

erm_shap_values = lsv.SHAPLoader("data/erm_shap_data.txt", "data/erm_shap_values.txt").SHAP_values
cnc_shap_values = lsv.SHAPLoader("data/cnc_shap_data.txt", "data/cnc_shap_values.txt").SHAP_values