In [4]:
# 3x3 count matrices, one per evaluation set. Row i holds the counts for
# ground-truth label i (that is how get_accuracy reads them); the columns are
# presumably the predicted labels in the same 0/1/2 order -- TODO confirm.

# Validation subset, no trigger inserted.
no_trigger = [
    [896, 76, 21],
    [70, 828, 57],
    [13, 76, 907],
]

# Challenge Set I: universal trigger words.
universal_trigger = [
    [256, 327, 410],
    [6, 246, 703],
    [6, 157, 833],
]

# Challenge Set II: random trigger words.
random_trigger = [
    [707, 217, 69],
    [57, 802, 96],
    [14, 70, 912],
]

def get_accuracy(matrix):
    """Print each ground-truth row of a count matrix as percentages of its row total.

    Despite the name, this does not compute a single accuracy figure: for
    every row it prints how that row's counts are distributed across the
    columns, as percentages rounded to two decimals.

    Args:
        matrix: Sequence of rows of numeric counts. Row ``i`` is reported as
            ground-truth class ``i``. Any number of rows and columns is
            accepted.

    Returns:
        None. One line per row is written to stdout in the form
        ``Ground truth: <i>, Scores: [<pct>, ...]``.
    """
    # Iterate over the rows actually present instead of a fixed range(3), so
    # smaller matrices do not raise IndexError and larger ones are not truncated.
    for ground_truth, row in enumerate(matrix):
        total = sum(row)  # computed once per row, not once per element
        if total == 0:
            # A class with no examples has no meaningful distribution; report
            # zeros rather than failing with ZeroDivisionError.
            scores = [0.0 for _ in row]
        else:
            scores = [round(score * 100 / total, 2) for score in row]
        print(f"Ground truth: {ground_truth}, Scores: {scores}")


In [5]:
# Print the per-ground-truth percentage rows for each evaluation set, in the
# order used by the summary table: validation subset (no trigger),
# Challenge Set I (universal triggers), Challenge Set II (random triggers).
get_accuracy(no_trigger)
get_accuracy(universal_trigger)
get_accuracy(random_trigger)

Ground truth: 0, Scores: [90.23, 7.65, 2.11]
Ground truth: 1, Scores: [7.33, 86.7, 5.97]
Ground truth: 2, Scores: [1.31, 7.63, 91.06]
Ground truth: 0, Scores: [25.78, 32.93, 41.29]
Ground truth: 1, Scores: [0.63, 25.76, 73.61]
Ground truth: 2, Scores: [0.6, 15.76, 83.63]
Ground truth: 0, Scores: [71.2, 21.85, 6.95]
Ground truth: 1, Scores: [5.97, 83.98, 10.05]
Ground truth: 2, Scores: [1.41, 7.03, 91.57]


Ground Truth    Data                        E%, N%, C%
Entailment      Validation subset           90.23, 7.65, 2.11
                Challenge Set I             25.78, 32.93, 41.29
                Challenge Set II            71.2, 21.85, 6.95

Neutral         Validation subset           7.33, 86.7, 5.97
                Challenge Set I             0.63, 25.76, 73.61
                Challenge Set II            5.97, 83.98, 10.05

Contradiction   Validation subset           1.31, 7.63, 91.06
                Challenge Set I             0.6, 15.76, 83.63
                Challenge Set II            1.41, 7.03, 91.57

In [6]:
import math
from collections import defaultdict, Counter
from datasets import load_dataset
from tqdm import tqdm

def load_snli_data():
    """Return the ``train`` split of the ``snli`` dataset via ``load_dataset``."""
    return load_dataset('snli', split='train')

def get_word_and_class_counts(dataset=None):
    """Count hypothesis tokens per label, and examples per label, over a dataset.

    Args:
        dataset: Iterable of examples, each a mapping with a ``'hypothesis'``
            string and an integer ``'label'``. When omitted, the data
            returned by ``load_snli_data()`` is used.

    Returns:
        A 3-tuple ``(word_counts, class_counts, total_words)``:

        * ``word_counts`` -- nested ``defaultdict`` mapping
          ``word -> label -> number of occurrences`` in hypotheses.
        * ``class_counts`` -- ``Counter`` mapping ``label -> number of
          labelled examples``.
        * ``total_words`` -- total number of hypothesis tokens counted.

    Notes:
        Tokens are produced by ``str.split()`` with no lower-casing or
        punctuation stripping, so ``"No"``, ``"no"`` and ``"no."`` are
        counted as distinct words.
    """
    if dataset is None:
        # Load lazily. A call placed in the default-argument position is
        # evaluated once, when the ``def`` statement runs, which made merely
        # defining this function download and load the whole dataset.
        dataset = load_snli_data()

    word_counts = defaultdict(lambda: defaultdict(int))
    class_counts = Counter()
    total_words = 0

    for example in tqdm(dataset):
        label = example['label']
        if label == -1:  # Skip examples with no label
            continue
        tokens = example['hypothesis'].split()
        class_counts[label] += 1
        for word in tokens:
            word_counts[word][label] += 1
        total_words += len(tokens)

    return word_counts, class_counts, total_words

Downloading readme:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/412k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/413k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/550152 [00:00<?, ? examples/s]

In [None]:
# word -> label -> occurrence count over the hypotheses of the dataset; the
# per-class example counts and the total token count are not needed here.
word_counts, _, _ = get_word_and_class_counts()

# Trigger words grouped under an integer label id. Per the summary tables in
# this notebook, 0 = entailment, 1 = neutral, 2 = contradiction.
universal_triggers = {
    0: ['nobody', 'no', 'whatsoever', 'excuse', 'unfairly'],
    1: ['cats', 'monkeys', 'crocodiles', 'elephants', 'cat'],
    2: ['joyously', 'celebrating', 'motivational', 'contacting', 'anxiously']
}

# Comparison word lists with the same layout as universal_triggers
# (presumably randomly chosen words, going by the name -- TODO confirm).
random_triggers = {
    0: ['diners', 'sense', 'emerge', 'hands', 'refuge'], 
    1: ['road', 'elders', 'brick', 'mass', 'bicyclists'] ,
    2: ['remain', 'rose', 'towns', 'flashing', 'lip']
}

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 550152/550152 [00:17<00:00, 30881.87it/s]


In [16]:
def majority_class_and_correlation(trigger, counts=None, num_classes=3):
    """Print, for every trigger word, its most frequent label and that label's share.

    Args:
        trigger: Mapping ``label -> list of words``. The words of all labels
            are reported in the mapping's iteration order.
        counts: Mapping ``word -> {label: occurrences}``. Defaults to the
            notebook-level ``word_counts`` table.
        num_classes: Number of labels; labels ``0 .. num_classes - 1`` are
            considered. Defaults to 3.

    Returns:
        None. One line per word is written to stdout as
        ``<word> <majority label> <share>``, where the share is the fraction
        of the word's occurrences that carry the majority label, rounded to
        two decimals. Ties go to the lowest label id. A word with no recorded
        occurrences is reported as ``<word> n/a n/a``.
    """
    if counts is None:
        counts = word_counts

    for label in trigger:
        for word in trigger[label]:
            # Look up with .get(): indexing the nested defaultdict would
            # insert an entry for every unseen word/label and so mutate the
            # shared table as a side effect of reading it.
            per_label = counts.get(word) or {}
            class_counts = [per_label.get(lbl, 0) for lbl in range(num_classes)]
            total = sum(class_counts)
            if total == 0:
                # Word never seen: there is no majority class, and dividing
                # by the total would raise ZeroDivisionError.
                print(f"{word} n/a n/a")
                continue
            majority = class_counts.index(max(class_counts))
            share = round(class_counts[majority] / total, 2)
            print(f"{word} {majority} {share}")

# Report the majority label and its share for every universal trigger word,
# then, after a blank separator line, for every random trigger word.
majority_class_and_correlation(universal_triggers)
print("")
majority_class_and_correlation(random_triggers)
    

nobody 2 0.96
no 2 0.83
whatsoever 2 1.0
excuse 1 1.0
unfairly 1 1.0
cats 2 0.96
monkeys 2 0.87
crocodiles 2 1.0
elephants 2 0.62
cat 2 0.85
joyously 1 1.0
celebrating 1 0.79
motivational 1 0.93
contacting 1 0.8
anxiously 1 0.86

diners 1 0.47
sense 1 0.41
emerge 1 0.57
hands 2 0.38
refuge 1 1.0
road 1 0.4
elders 1 0.5
brick 1 0.39
mass 0 0.38
bicyclists 1 0.37
remain 1 0.62
rose 1 0.59
towns 1 0.6
flashing 1 0.5
lip 1 0.5


Universal       class  score       Random   class   score

nobody            2    0.96        diners     1      0.47
no                2    0.83        sense      1      0.41
whatsoever        2    1.0         emerge     1      0.57
excuse            1    1.0         hands      2      0.38
unfairly          1    1.0         refuge     1      1.0
cats              2    0.96        road       1      0.4
monkeys           2    0.87        elders     1      0.5
crocodiles        2    1.0         brick      1      0.39
elephants         2    0.62        mass       0      0.38
cat               2    0.85        bicyclists 1      0.37
joyously          1    1.0         remain     1      0.62
celebrating       1    0.79        rose       1      0.59
motivational      1    0.93        towns      1      0.6
contacting        1    0.8         flashing   1      0.5
anxiously         1    0.86        lip        1      0.5

                        Universal                       Random
                Trigger        class  score       Trigger   class   score

Entailment      nobody            2    0.96        diners     1      0.47
                no                2    0.83        hands      2      0.38

Neutral         cats              2    0.96        road       1      0.4
                cat               2    0.85        mass       0      0.38

Contradiction   joyously          1    1.0         remain     1      0.62
                celebrating       1    0.79        rose       1      0.59
