In [1]:
# Group: John Strenio, Scott Klinn, Tuan Nguyen
# Commonsense QA Sequence Scoring Method 
# CS 510: Adventures in NLP
# Professor Ameeta Agrawal
# Contributors: John Strenio, Scott Klinn

# =========== Summary & Instructions ===============================
# This program takes a pretrained BERT-style model (DistilBERT,
# 'distilbert-base-uncased') and tests it on the last 611 examples of
# the CommonsenseQA validation set, as in the paper 'Pre-training Is
# (Almost) All You Need', using that paper's 1-gram sequence scoring
# method. It uses a custom testing structure (the correct answer vs.
# one randomly chosen wrong answer) to allow a 1-to-1 comparison with
# a BERT finetuned using Next Sentence Prediction and tested in the
# same format. It is also compared with BERT without any finetuning
# or scoring method.
# ==================================================================

from datasets import load_dataset
from transformers import AutoTokenizer, DistilBertTokenizerFast, DistilBertForMaskedLM, BertTokenizer, BertForMaskedLM
import torch
import numpy as np
from math import log10

# Checkpoint name shared by the tokenizer and the masked-language-model head.
model_class = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_model_name_or_path=model_class)
model = DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path=model_class)

# Evaluation data: the final 611 rows of the CommonsenseQA validation split,
# selected with the datasets slice syntax 'validation[-611:]'.
valid_set = load_dataset(path='commonsense_qa', split='validation[-611:]')

# the scoring method
def target_premise_score(probs):
    return sum([log10(prob) for prob in probs])

Using custom data configuration default
Reusing dataset commonsense_qa (C:\Users\johns\.cache\huggingface\datasets\commonsense_qa\default\0.1.0\1ca2d7b680c5bd93c0dc85f9cb65c0c8817e759ff82e405b28de54e83efa80f7)


In [2]:
import random
correct = raw_cor = 0
mask = '[MASK]'
mask_id = tokenizer(mask)['input_ids'][1]

# process each example
for example_count, example in enumerate(valid_set):

    # extract questions/answers
    question = example['question']
    ans_text = [choice for choice in example['choices']['text']]
    choices = []

    # the correct answer will be first, the other answer will be selected randomly from whats left
    # (this follows the same custom testing scheme as the finetuned bert is tested on)
    choices.append(ans_text.pop(ord(example['answerKey']) - 65))
    choices.append(random.choice(ans_text))

    hypothesis_scores = []
    raw_outputs = []
    final_scores = [0] * len(choices)

    # for both of the possibly correct answers
    for c in range(len(choices)):
        masked_inputs = []
        unmasked_labels = []

        # individually mask each token, save the sentence for input as a group of 'correct' or 'wrong'
        masked_sent_copies = 0
        for t in range(len(question.split())):
            # double call to get new object to alter independently
            tokens = question.split()

            # check for incorrectly separated punctuation and skip it; that wouldn't be a token
            if len(tokens[t]) == 1 and tokens[t].isalnum() != True:
                continue

            # generate truth label
            unmasked_sentence = ['Q:'] + tokens + ['A:'] + [choices[c]]

            # mask/concatenate single sentence form and add to input list
            tokens[t] = mask
            masked_concat = ['Q:'] + tokens + ['A:'] + [choices[c]]
            masked_inputs.append(tokenizer.convert_tokens_to_string(masked_concat))
            unmasked_labels.append(tokenizer.convert_tokens_to_string(unmasked_sentence))
            masked_sent_copies += 1

        # encode the inputs
        inputs = tokenizer(masked_inputs, padding='longest', truncation=True, return_tensors='pt')
        labels = tokenizer(unmasked_labels, padding='longest', truncation=True, return_tensors="pt")["input_ids"]
        
        probs_to_score = []

        # returns the 2d positions of the masked indices in the inputs
        masked_index = (inputs["input_ids"] == mask_id).nonzero()
        
        masked_correct_answers = []

        # link correct answers
        for i in range(masked_sent_copies):
            masked_correct_answers.append({})

        for i in masked_index:
            masked_correct_answers[i[0]].update({i[1]:labels[0][i[1]]})

        # test model on encoded inputs
        outputs = model(**inputs)
        activate = torch.nn.Softmax(dim=2)
        probabilities = activate(outputs.logits)

        # store for comparison
        raw_outputs.append(outputs)

        # collect probs for each word for each masked sentence
        for sent in range(masked_sent_copies):
            whole_prob = 1

            for prob in masked_correct_answers[sent]:
                word = masked_correct_answers[sent][prob]
                whole_prob *= float(probabilities[sent][prob][word])

            probs_to_score.append(whole_prob)

        score = target_premise_score(probs_to_score)
        hypothesis_scores.append(score)

    # compare the score of the correct answer to all the scores
    if max(hypothesis_scores) == hypothesis_scores[0]: # the choices are ordered 1 correct, 1 wrong
        correct += 1

    # retrieve results without scoring method
    out1 = raw_outputs[0].logits.softmax(dim=-1).tolist()
    out2 = raw_outputs[1].logits.softmax(dim=-1).tolist()

    # how many correct pairings were predicted over incorrect pairings
    if out1[0][0] > out2[0][0]:
        raw_cor += 1
    
    print('processed: ' + str(example_count))

#print results
print('ssm acc: ' + str(correct / len(valid_set)))
print('raw acc: ' + str(raw_cor / len(valid_set)))


processed: 0
processed: 1
processed: 2
processed: 3
processed: 4
processed: 5
processed: 6
processed: 7
processed: 8
processed: 9
processed: 10
processed: 11
processed: 12
processed: 13
processed: 14
processed: 15
processed: 16
processed: 17
processed: 18
processed: 19
processed: 20
processed: 21
processed: 22
processed: 23
processed: 24
processed: 25
processed: 26
processed: 27
processed: 28
processed: 29
processed: 30
processed: 31
processed: 32
processed: 33
processed: 34
processed: 35
processed: 36
processed: 37
processed: 38
processed: 39
processed: 40
processed: 41
processed: 42
processed: 43
processed: 44
processed: 45
processed: 46
processed: 47
processed: 48
processed: 49
processed: 50
processed: 51
processed: 52
processed: 53
processed: 54
processed: 55
processed: 56
processed: 57
processed: 58
processed: 59
processed: 60
processed: 61
processed: 62
processed: 63
processed: 64
processed: 65
processed: 66
processed: 67
processed: 68
processed: 69
processed: 70
processed: 71
pr