In [1]:
import datasets
import torch
import numpy as np
import json
import pickle
from tqdm import tqdm

## Get negation indices

In [2]:
mnli = datasets.load_dataset("glue", "mnli")

Reusing dataset glue (/home/meissner/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/5 [00:00<?, ?it/s]

In [32]:
negation_words = [" no ", " not ", " don't ", " none ", " nothing ", 
                  " never ", " aren’t ", " isn’t ", " weren’t ", 
                  " neither ", " don’t ", " didn’t ", " doesn’t ", 
                  " cannot ", " hasn’t ", " won’t "]

In [33]:
def find_negation_idxs(dataset):
    ent_indices, cont_indices = [], []
    for sample in dataset:
        if any([word in sample["hypothesis"].lower() for word in negation_words]):
            if sample['label'] == 0:
                ent_indices.append(sample['idx'])
            elif sample['label'] == 2:
                cont_indices.append(sample['idx'])
    return ent_indices, cont_indices

In [34]:
train_ents, train_conts = find_negation_idxs(mnli["train"])
val_m_ents, val_m_conts = find_negation_idxs(mnli["validation_matched"])
val_mm_ents, val_mm_conts = find_negation_idxs(mnli["validation_mismatched"])

In [35]:
all_neg_indices = {"train_ents" : train_ents, "train_conts" : train_conts, "val_m_ents" : val_m_ents, "val_m_conts" : val_m_conts, "val_mm_ents" : val_mm_ents, "val_mm_conts" : val_mm_conts}

In [7]:
!mkdir /home/meissner/shortcut-pruning/data/subsets/

mkdir: cannot create directory ‘/home/meissner/shortcut-pruning/data/subsets/’: File exists


In [8]:
with open("/home/meissner/shortcut-pruning/data/subsets/mnli_negation_indices.json", "w") as _file:
    _file.write(json.dumps(all_neg_indices))

In [38]:
for key, value in all_neg_indices.items():
    print(key)
    print(len(value))

train_ents
11718
train_conts
37681
val_m_ents
319
val_m_conts
906
val_mm_ents
251
val_mm_conts
1028


## Z-Score subsets

In [4]:
with open("/home/meissner/shortcut-pruning/MNLI_all_data.data", "rb") as _file:
    data = pickle.load(_file)

In [12]:
cont_words = np.array(data["contradiction"]["tokens"])
ent_words = np.array(data["entailment"]["tokens"])
cont_zs = np.array(data["contradiction"]["z"])
ent_zs = np.array(data["entailment"]["z"])

In [46]:
for key, value in data['contradiction'].items():
    print(key, len(value))

tokens 83147
total_counts 83147
z 83147
label_count 83147
p_hat 83147


In [13]:
cont_sort_idx = np.argsort(cont_zs)[::-1]
ent_sort_idx = np.argsort(ent_zs)[::-1]

In [47]:
cont_total_counts = np.array(data['contradiction']['total_counts'])
cont_label_counts = np.array(data['contradiction']['label_count'])
cont_p_hats = np.array(data['contradiction']['p_hat'])

In [48]:
cont_total_counts[cont_sort_idx][:10], cont_label_counts[cont_sort_idx][:10], cont_p_hats[cont_sort_idx][:10]

(array([30948, 13209, 15643, 64200,  4884,  6381, 35115,  1799,  7990,
        17898]),
 array([17310,  8215,  8065, 27115,  2797,  3254, 14007,  1058,  3546,
         7275]),
 array([0.55932532, 0.62192445, 0.51556607, 0.42235202, 0.57268632,
        0.50995142, 0.39888936, 0.5881045 , 0.44380476, 0.40647   ]))

In [14]:
# Top cont words
cont_words[cont_sort_idx][:30]

array(['no', 'never', 'any', 'not', 'nothing', 'anything', 'all',
       'completely', 'does', 'only', 'none', 'refused', 'nobody', 'dont',
       'doesnt', 'hate', 'anyone', 'didnt', 'cannot', 'did', 'definitely',
       'stayed', 'boring', 'ever', 'always', 'remained', 'terrible',
       'silent', 'ignored', 'perfectly'], dtype='<U60')

In [49]:
np.where(cont_words[cont_sort_idx] == "sleeping")

(array([5466]),)

In [15]:
ent_words[ent_sort_idx][:30]

array(['a', 'some', 'can', 'by', 'that', 'something', 'of', 'an', 'yes',
       'may', 'aware', 'according', 'sometimes', 'both', 'might',
       'called', 'asked', 'located', 'you', 'if', 'someone', 'be',
       'different', 'two', 'multiple', 'certain', 'similar', 'various',
       'there', 'goodbye'], dtype='<U60')

In [30]:
len(ent_words)

81941

In [16]:
remove_punct = "!.()[]{};:\,?\'\""

In [107]:
def get_set(sample):
    full_input = " ".join([sample["premise"], sample["hypothesis"]]).lower().strip()
    for punct in remove_punct:
        full_input = full_input.replace(punct, "")
    input_set = set(full_input.split(" "))
    return input_set

In [108]:
val_sample_sets = [get_set(sample) for sample in mnli['validation_matched']]

In [109]:
labels = np.array(mnli['validation_matched']['label'])

In [110]:
def get_samples_with_word(word):
    cont_indices = []
    ent_indices = []
    for idx, (sett, label) in enumerate(zip(val_sample_sets, labels)):
        if word in sett:
            if label == 0:
                ent_indices.append(idx)
            elif label == 2:
                cont_indices.append(idx)
    return cont_indices, ent_indices

In [111]:
cont_bias_indices, cont_antibias_indices = [], []
total_count, skip_count = 0, 0
for word in tqdm(cont_words[cont_sort_idx][:200]):
    total_count += 1
    cont_indices, ent_indices = get_samples_with_word(word)
    if len(cont_indices) < 5 or len(ent_indices) < 5:
        skip_count += 1
    else:
        cont_bias_indices.extend(cont_indices)
        cont_antibias_indices.extend(ent_indices)
f"Skipped {skip_count} of {total_count} words."

100%|██████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 451.70it/s]


'Skipped 116 of 200 words.'

In [112]:
cont_bias_indices, cont_antibias_indices = set(cont_bias_indices), set(cont_antibias_indices)

In [113]:
len(cont_bias_indices), len(cont_antibias_indices)

(2405, 1885)

In [114]:
intersection = cont_bias_indices.intersection(cont_antibias_indices)
len(intersection)

0

In [124]:
ent_bias_indices, ent_antibias_indices = [], []
total_count, skip_count = 0, 0
for word in tqdm(ent_words[ent_sort_idx][:200]):
    total_count += 1
    cont_indices, ent_indices = get_samples_with_word(word)
    if len(ent_indices) < 1 or len(cont_indices) < 1:
        skip_count += 1
    else:
        ent_bias_indices.extend(ent_indices)
        ent_antibias_indices.extend(cont_indices)
f"Skipped {skip_count} of {total_count} words."

100%|██████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 422.39it/s]


'Skipped 56 of 200 words.'

In [125]:
ent_bias_indices, ent_antibias_indices = set(ent_bias_indices), set(ent_antibias_indices)

In [126]:
len(ent_bias_indices), len(ent_antibias_indices)

(3235, 2981)

In [127]:
intersection = ent_bias_indices.intersection(ent_antibias_indices)
len(intersection)

0

In [128]:
all_indices = {
    "cont_bias_indices" : list(cont_bias_indices), 
    "cont_antibias_indices" : list(cont_antibias_indices), 
    "ent_bias_indices" : list(ent_bias_indices), 
    "ent_antibias_indices" : list(ent_antibias_indices),
}
with open("/home/meissner/shortcut-pruning/data/subsets/mnli_z_statistic_indices.json", "w") as _file:
    _file.write(json.dumps(all_indices))

## Get lexically similar words

In [7]:
import nltk
import string
from nltk.tokenize import word_tokenize

In [8]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/meissner/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [9]:
puncts = set(string.punctuation)
print(puncts)

{'#', '^', '|', '@', '+', '~', ')', ']', '%', '?', '<', '_', '$', ',', '/', ';', '`', '&', '.', ':', '}', '[', '*', "'", '!', '{', '=', '-', '\\', '>', '"', '('}


In [10]:
def jaccard_similarity(set1, set2):
    return float(len(set1.intersection(set2)) / len(set1.union(set2)))

In [11]:
word_tokenize(mnli['validation_matched'][23]['hypothesis'])

['Leading',
 'organizations',
 'want',
 'to',
 'be',
 'sure',
 'their',
 'employees',
 'are',
 'safe',
 '.']

In [12]:
jaccards = []
for sample in mnli['train']:
    hyposet = set(word_tokenize(sample['hypothesis'].lower())) - puncts
    premset = set(word_tokenize(sample['premise'].lower())) - puncts
    jaccards.append(jaccard_similarity(hyposet, premset))
jaccards = np.array(jaccards)

KeyboardInterrupt: 

In [None]:
np.min(jaccards), np.mean(jaccards), np.max(jaccards), np.argmax(jaccards), np.argmin(jaccards)

In [None]:
from matplotlib import pyplot as plt

In [None]:
bins = np.linspace(0.0, 1.0, 25)
labels = np.array(mnli['train']['label'])
ent_jaccards = jaccards[labels == 0]
cont_jaccards = jaccards[labels == 2]
plt.hist(ent_jaccards, bins, alpha=0.5, label="Entailment")
plt.hist(cont_jaccards, bins, alpha=0.5, label="Contradiction")
plt.legend()
plt.show()

In [21]:
top_indices = np.argsort(jaccards)[-4000:]

In [22]:
jaccards[top_indices[0]] # Cut rate

0.7857142857142857

In [23]:
np.mean(labels[top_indices] == 0) # 

0.79475

In [13]:
def find_high_lex_overlap_idxs(dataset, count_lim=400):
    jaccards = []
    labels = np.array(dataset['label'])
    for sample in dataset:
        hyposet = set(word_tokenize(sample['hypothesis'].lower())) - puncts
        premset = set(word_tokenize(sample['premise'].lower())) - puncts
        jaccards.append(jaccard_similarity(hyposet, premset))
    jaccards = np.array(jaccards)
    top_indices = np.argsort(jaccards)[-count_lim:]
    ent_indices = top_indices[labels[top_indices] == 0]
    cont_indices = top_indices[labels[top_indices] == 2]
    return ent_indices, cont_indices

In [14]:
# train_ents, train_conts = find_high_lex_overlap_idxs(mnli["train"], count_lim=4000)
val_m_ents, val_m_conts = find_high_lex_overlap_idxs(mnli["validation_matched"], count_lim=200)
val_mm_ents, val_mm_conts = find_high_lex_overlap_idxs(mnli["validation_mismatched"], count_lim=200)

In [26]:
val_m_ents[0]

3262

In [27]:
mnli['validation_matched'][int(val_m_ents[0])], mnli['validation_matched'][int(val_m_conts[0])]

({'premise': 'Shoot only the ones that face us, Jon had told Adrin.',
  'hypothesis': 'Jon instructed Adrin to only shoot the ones that face us.',
  'label': 0,
  'idx': 3262},
 {'premise': 'The rain had stopped, but the green glow painted everything around them.',
  'hypothesis': 'The red glow painted everything around them after the rain had stopped.',
  'label': 2,
  'idx': 8901})

In [28]:
all_indices = {"train_ents" : list(train_ents), "train_conts" : list(train_conts), "val_m_ents" : list(val_m_ents), "val_m_conts" : list(val_m_conts), "val_mm_ents" : list(val_mm_ents), "val_mm_conts" : list(val_mm_conts)}

In [29]:
for key, value in all_indices.items():
    print(key)
    print(len(value))

train_ents
3179
train_conts
504
val_m_ents
144
val_m_conts
36
val_mm_ents
144
val_mm_conts
28


In [30]:
type(all_indices['val_m_ents'][0])

numpy.int64

In [31]:
import numpy
import json

class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, numpy.integer):
            return int(obj)
        elif isinstance(obj, numpy.floating):
            return float(obj)
        elif isinstance(obj, numpy.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)

In [32]:
with open("/home/meissner/shortcut-pruning/data/subsets/mnli_lex_overlap_indices.json", "w") as _file:
    _file.write(json.dumps(all_indices, cls=MyEncoder))

## Negation without lexical overlap

In [39]:
val_negs = all_neg_indices["val_m_ents"]

In [40]:
val_neg_ent_samples = mnli['validation_matched'].select(val_negs)

In [41]:
top_ent_arr, top_cont_arr = find_high_lex_overlap_idxs(val_neg_ent_samples, count_lim=0)

In [31]:
top_ent_arr

array([284,  87, 186, 150, 471, 466, 595, 553, 274, 318,   3,   1, 567,
       465, 242, 425, 478, 299, 387, 126, 199, 217, 569, 128, 346, 434,
       428, 545, 506, 502, 234, 246, 268, 243, 447, 120, 556,  63, 517,
       539, 152, 170, 207, 280, 245, 116, 522, 562, 123, 139, 205, 203,
       481, 554, 140, 297, 563, 163, 173, 209, 108, 578, 180, 586, 229,
       571,  17, 546, 550, 508, 467, 371, 518, 175, 459, 415,  45, 135,
       394, 104, 480, 392,  97, 399, 412, 460, 379, 505,  67, 233, 343,
       498, 538, 141,  53, 212, 112,  41, 368, 252, 166, 362, 125, 489,
       106, 319, 367, 220, 321, 192, 335,  92, 581, 464, 155, 390, 598,
       285, 499, 194, 329, 105, 408, 317,  78, 584, 146, 406, 475, 528,
       336, 202, 369, 386, 440, 495,  66, 544, 560, 168,  39, 215, 491,
       119, 383, 332, 345, 354, 124, 353, 342, 366, 527, 130,  19, 298,
       521,  76, 174, 444, 303, 251,  70,  14, 576, 441,  20,  81,  30,
       237, 497, 501, 289, 219, 402, 277, 424,  62, 312, 283, 48

In [42]:
len(val_negs), len(top_ent_arr), len(top_cont_arr)

(319, 319, 0)

In [43]:
val_negs[top_ent_arr[-1]]

4052

In [48]:
for i in top_ent_arr[:5]:
    print(mnli['validation_matched'][val_negs[i]])

{'premise': 'In short, most of the whale is incompressible.', 'hypothesis': 'Whales cannot be compressed well.', 'label': 0, 'idx': 1445}
{'premise': "Several of its beaches are officially designated for nudism (known locally as naturisme) the most popular being Pointe Tarare and a functionary who is a Chevalier de la L??gion d'Honneur has been appointed to supervise all aspects of sunning in the buff.", 'hypothesis': 'They do not mind having nude people.', 'label': 0, 'idx': 2517}
{'premise': 'Normally, these discussions are kept secret.', 'hypothesis': 'In usual circumstances, what is said is not to be shared..', 'label': 0, 'idx': 7252}
{'premise': 'Others watched them with cold eyes and expressionless faces.', 'hypothesis': 'Some people who were not emotive were watching.', 'label': 0, 'idx': 4671}
{'premise': 'Even though national saving remains relatively low by U.S. historical standards, economic growth in recent years has been high because more and better investments were made.

In [44]:
mnli['validation_matched'][4052]

{'premise': 'Part of the reason for the difference in pieces per possible delivery may be due to the fact that five percent of possible residential deliveries are businesses, and it is thought, but not known, that a lesser percentage of possible deliveries on rural routes are businesses.',
 'hypothesis': 'It is thought, but not known, that a lesser percentage of possible deliveries on rural routes are businesses, and part of the reason for the difference in pieces per possible delivery, may be due to the fact that five percent of possible residential deliveries are businesses.',
 'label': 0,
 'idx': 4052}

In [53]:
type(val_negs[0])

int

In [54]:
neg_ent_new_indices = [val_negs[i] for i in top_ent_arr[:50]] 

In [55]:
with open("/home/meissner/shortcut-pruning/data/subsets/mnli_neg_ent_new2.json", "w") as _file:
    _file.write(json.dumps(neg_ent_new_indices))