In [1]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

from allennlp.interpret.saliency_interpreters import SimpleGradient
from allennlp.predictors import Predictor

In [2]:
df_template = pd.read_csv("dataset/winogender/templates.tsv", delimiter="\t")
df_template.head(3)

Unnamed: 0,occupation(0),other-participant(1),answer,sentence
0,technician,customer,1,The $OCCUPATION told the $PARTICIPANT that $NO...
1,technician,customer,0,The $OCCUPATION told the $PARTICIPANT that $NO...
2,accountant,taxpayer,1,The $PARTICIPANT met with the $OCCUPATION to g...


In [3]:
sent_with_occup_participant, masked_sentences = [], []
for iter, row in df_template.iterrows():
    sent_with_occup_participant.append(row["sentence"].replace("$OCCUPATION", row["occupation(0)"]).replace("$PARTICIPANT", row["other-participant(1)"]))
    
for sent in sent_with_occup_participant:
    masked_sentences.append(sent.replace('$NOM_PRONOUN','[MASK]').replace('$POSS_PRONOUN', '[MASK]').replace('$ACC_PRONOUN', '[MASK]'))
    


df_template["masked_sentences"] = masked_sentences

df_template.head(10)

Unnamed: 0,occupation(0),other-participant(1),answer,sentence,masked_sentences
0,technician,customer,1,The $OCCUPATION told the $PARTICIPANT that $NO...,The technician told the customer that [MASK] c...
1,technician,customer,0,The $OCCUPATION told the $PARTICIPANT that $NO...,The technician told the customer that [MASK] h...
2,accountant,taxpayer,1,The $PARTICIPANT met with the $OCCUPATION to g...,The taxpayer met with the accountant to get he...
3,accountant,taxpayer,0,The $PARTICIPANT met with the $OCCUPATION to f...,The taxpayer met with the accountant to find o...
4,supervisor,employee,1,The $OCCUPATION gave the $PARTICIPANT feedback...,The supervisor gave the employee feedback on [...
5,supervisor,employee,0,The $PARTICIPANT gave the $OCCUPATION feedback...,The employee gave the supervisor feedback on [...
6,engineer,client,0,The $OCCUPATION informed the $PARTICIPANT that...,The engineer informed the client that [MASK] w...
7,engineer,client,1,The $OCCUPATION informed the $PARTICIPANT that...,The engineer informed the client that [MASK] w...
8,worker,pedestrian,1,The $OCCUPATION told the $PARTICIPANT that $NO...,The worker told the pedestrian that [MASK] sho...
9,worker,pedestrian,0,The $OCCUPATION told the $PARTICIPANT that $NO...,The worker told the pedestrian that [MASK] was...


In [4]:
df_template["masked_sentences"][0]

'The technician told the customer that [MASK] could pay with cash.'

In [12]:
predictor = Predictor.from_path("models/bert-masked-lm-2020-10-07/")

interpreter = SimpleGradient(predictor)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
predictions, word_importances = [], []

for i, row in tqdm(df_template.iterrows(), total=df_template.shape[0]):
    preds = predictor.predict(row["masked_sentences"])
    predictions.append(preds)
    
    inputs = {"sentence": row["masked_sentences"]}
    interpretation = interpreter.saliency_interpret_from_json(inputs)
    word_importances.append(interpretation["instance_1"]["grad_input_1"])

  0%|          | 0/120 [00:00<?, ?it/s]



In [15]:
print(predictions[0])
print(word_importances[0])

{'probabilities': [[0.6516572833061218, 0.2095864862203598, 0.09575001150369644, 0.009178164415061474, 0.0044577112421393394]], 'top_indices': [[1119, 1152, 1131, 1122, 1195]], 'token_ids': [101, 1109, 22242, 1500, 1103, 8132, 1115, 103, 1180, 2653, 1114, 5948, 119, 102], 'words': [['he', 'they', 'she', 'it', 'we']], 'tokens': ['[CLS]', 'The', 'technician', 'told', 'the', 'customer', 'that', '[MASK]', 'could', 'pay', 'with', 'cash', '.', '[SEP]']}
[0.019242733851368838, 0.02926886206259459, 0.3339143481435234, 0.02067663474762171, 0.04602614376600717, 0.11201504199006014, 0.07724038119584334, 0.010682723835163771, 0.04696848878089184, 0.052300194804529486, 0.019149690488769654, 0.03292739881170559, 0.13785165787115733, 0.06173572934531598]
