In [1]:
import numpy as np
import torch
import transformers

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from life_after_bert import load_olmpics_data, MCDataset, evaluate_encoder, evaluate_decoder, evaluate_encoder_decoder

In [3]:
from torch.utils.data import Dataset
class PunctuationDataset(Dataset):  # TODO: move from notebook to src
    """Multiple-choice dataset whose questions are forced to end with, or without, a period.

    Misnomer: only a trailing period (".") is handled, not question marks or
    any other punctuation.

    Args:
        questions: Question strings containing the literal placeholder "[MASK]".
        choices: Per-question answer choices, indexed in parallel with ``questions``.
        answer_ids: Per-question answer ids, indexed in parallel with ``questions``.
        tokenizer: Callable invoked as
            ``tokenizer(questions, max_length=max_length, padding="max_length")``;
            its result must provide "input_ids" and "attention_mask". Its
            ``mask_token`` attribute is read when ``mask_token`` is not given.
        mask_token: Text substituted for "[MASK]". Defaults to ``tokenizer.mask_token``.
        max_length: Length forwarded to the tokenizer for padding.
        punctuation: If True, every question ends with a period (one is appended
            when missing). If False, trailing periods and spaces are removed.

    Raises:
        AssertionError: If neither ``mask_token`` nor ``tokenizer.mask_token`` is set.
    """

    def __init__(self, questions, choices, answer_ids, tokenizer, mask_token=None, max_length=25, punctuation=True):
        if mask_token is None:
            mask_token = tokenizer.mask_token
        assert mask_token is not None, "mask_token must be provided if tokenizer.mask_token does not exist"

        questions = [question.replace("[MASK]", mask_token).strip(" ") for question in questions]
        if punctuation:
            questions = [question if question.endswith(".") else f"{question}." for question in questions]
        else:
            # Only touch the END of the question: strip(".") would also eat a leading
            # period (e.g. ".5 is ..."). Stripping both characters in one pass also
            # covers examples with a space before the period ("... big .").
            questions = [question.rstrip(". ") for question in questions]

        # NOTE(review): no truncation is requested here, so a question longer than
        # max_length is presumably returned over-length -- confirm max_length covers the data.
        out = tokenizer(questions, max_length=max_length, padding="max_length")
        self.input_ids = out["input_ids"]
        self.attention_mask = out["attention_mask"]
        self.questions = questions
        self.choices = choices
        self.answer_ids = answer_ids

    def __len__(self):
        """Return the number of questions."""
        return len(self.questions)

    def __getitem__(self, i):
        """Return the tokenized question, its choices and its answer id for index ``i``."""
        return {
            "input_ids": self.input_ids[i],
            "attention_mask": self.attention_mask[i],
            "choice_list": self.choices[i],
            "answer_id": self.answer_ids[i],
        }

In [5]:
# (checkpoint name, mask-token override) pairs. An override of None means the
# tokenizer is loaded without passing a mask_token.
model_names = [
    ("bert-base-uncased", None),
    ("roberta-large", None),
    ("albert-large-v1", None),
    ("gpt2-large", "[MASK]"),
    ("t5-large", "<extra_id_0>"),
]
def get_model_type(model_name):
    """Classify a checkpoint name as "t5", "decoder" or "encoder".

    The name is matched case-insensitively against the substrings "t5", "gpt"
    and "bert", in that order; the first match wins. "bert" also matches names
    such as "roberta-large" and "albert-large-v1".

    Args:
        model_name: Checkpoint name, e.g. "bert-base-uncased".

    Returns:
        "t5", "decoder" or "encoder".

    Raises:
        NotImplementedError: If the name contains none of the known substrings.
    """
    lowered = model_name.lower()
    for marker, model_type in (("t5", "t5"), ("gpt", "decoder"), ("bert", "encoder")):
        if marker in lowered:
            return model_type

    raise NotImplementedError(f"Unsupported model name: {model_name!r}")

# Task label handed to the evaluate_* helpers.
task = "oLMpics MLM"

# (dataset name, number of answer choices) for each evaluation set.
eval_datasets = [
    ("age_comparison", 2),
    ("size_comparison", 2),
    ("antonym_negation", 2),
    ("always_never", 5),
    ("taxonomy_conjunction", 3),
    ("multihop_composition", 3),
]

In [6]:
# Evaluate every model on every dataset, once with and once without a trailing period.
# Result layout: all_accs_dict[model_name][task_name] == [acc_with_punctuation, acc_without_punctuation]
all_accs_dict = {}
for model_name, mask_token in model_names:
    model_type = get_model_type(model_name)
    if model_type == "t5":
        model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
    elif model_type == "encoder":
        model = transformers.AutoModelForMaskedLM.from_pretrained(model_name)
    elif model_type == "decoder":
        model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    else:
        raise NotImplementedError(f"Unsupported model type: {model_type!r}")

    model.eval()
    model.to(device)

    # Only override the tokenizer's mask token when one is configured for this model.
    tokenizer_kwargs = {} if mask_token is None else {"mask_token": mask_token}
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    if tokenizer.pad_token is None:
        print("Defaulting pad token to EOS token.")
        tokenizer.pad_token = tokenizer.eos_token

    # The T5 decoder prompt depends only on the tokenizer, so build it once per
    # model instead of once per (dataset, punctuation) combination.
    decoder_prompt = None
    if model_type == "t5":
        decoder_prompt = tokenizer("<pad> <extra_id_0>", add_special_tokens=False, return_tensors="pt").input_ids

    all_accs = []
    for data_path, num_choices in eval_datasets:
        # NOTE(review): taxonomy_conjunction presumably needs one extra token of room -- confirm.
        max_length = 26 if data_path == "taxonomy_conjunction" else 25
        questions, choice_lists, answer_ids = load_olmpics_data(f"../tests/data/oLMpics_{data_path}_dev.jsonl", num_choices, progress_bar=False)

        acc_list = []
        # Order matters: index 0 is "with punctuation", index 1 is "without".
        for punctuation in [True, False]:
            dataset = PunctuationDataset(questions, choice_lists, answer_ids, tokenizer, max_length=max_length, punctuation=punctuation)

            if model_type == "t5":
                all_answers, all_preds = evaluate_encoder_decoder(model, tokenizer, task, dataset, decoder_prompt, device, progress_bar=False)
            elif model_type == "encoder":
                all_answers, all_preds = evaluate_encoder(model, tokenizer, task, dataset, device, progress_bar=False)
            elif model_type == "decoder":
                all_answers, all_preds = evaluate_decoder(model, tokenizer, task, num_choices, dataset, device, progress_bar=False)
            else:
                raise NotImplementedError(f"Unsupported model type: {model_type!r}")

            accuracy = (np.array(all_answers) == np.array(all_preds)).mean()
            # Store a plain Python float so the results dict displays as bare
            # numbers regardless of how the installed NumPy reprs its scalars.
            acc_list.append(float(accuracy))

        all_accs.append((data_path, acc_list))

    model_accs_dict = {}
    print(f"Accuracy for {model_name}:")
    for task_name, accs in all_accs:
        model_accs_dict[task_name] = accs
        print(f"Task: {task_name} \t Accuracy with punctuation: {accs[0]} \t Accuracy without punctuation: {accs[1]}")

    all_accs_dict[model_name] = model_accs_dict

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy for bert-base-uncased:
Task: age_comparison 	 Accuracy with punctuation: 0.494 	 Accuracy without punctuation: 0.494
Task: size_comparison 	 Accuracy with punctuation: 0.554 	 Accuracy without punctuation: 0.556
Task: antonym_negation 	 Accuracy with punctuation: 0.538 	 Accuracy without punctuation: 0.532
Task: always_never 	 Accuracy with punctuation: 0.13214285714285715 	 Accuracy without punctuation: 0.1
Task: taxonomy_conjunction 	 Accuracy with punctuation: 0.44073455759599334 	 Accuracy without punctuation: 0.4674457429048414
Task: multihop_composition 	 Accuracy with punctuation: 0.322 	 Accuracy without punctuation: 0.332
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/68.2M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/742k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1 token: pasta
Answer choice more than 1

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Using pad_token, but it is not set yet.


Defaulting pad token to EOS token.
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more than 1 token: primate
Answer choice more tha

Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer

Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer choice more than 1 token: mammal
Answer choice more than 1 token: deer
Answer choice more than 1 token: mammal
Answer choice more than 1 token: primate
Answer

In [7]:
# Display the collected results: model name -> task name -> [accuracy with punctuation, accuracy without].
all_accs_dict

{'bert-base-uncased': {'age_comparison': [0.494, 0.494],
  'size_comparison': [0.554, 0.556],
  'antonym_negation': [0.538, 0.532],
  'always_never': [0.13214285714285715, 0.1],
  'taxonomy_conjunction': [0.44073455759599334, 0.4674457429048414],
  'multihop_composition': [0.322, 0.332]},
 'roberta-large': {'age_comparison': [0.986, 0.978],
  'size_comparison': [0.874, 0.836],
  'antonym_negation': [0.746, 0.664],
  'always_never': [0.1357142857142857, 0.16071428571428573],
  'taxonomy_conjunction': [0.43906510851419034, 0.4607679465776294],
  'multihop_composition': [0.296, 0.29]},
 'albert-large-v1': {'age_comparison': [0.53, 0.514],
  'size_comparison': [0.492, 0.492],
  'antonym_negation': [0.502, 0.51],
  'always_never': [0.30714285714285716, 0.21785714285714286],
  'taxonomy_conjunction': [0.5091819699499165, 0.5191986644407346],
  'multihop_composition': [0.34, 0.34]},
 'gpt2-large': {'age_comparison': [0.696, 0.7],
  'size_comparison': [0.508, 0.506],
  'antonym_negation': [0.5

In [11]:
# Re-print the stored accuracies, one section per model, followed by a blank separator.
for model_name, task_accs in all_accs_dict.items():
    print(f"Accuracy for {model_name}:")
    for task_name in task_accs:
        acc_pair = task_accs[task_name]
        print(f"Task: {task_name} \t Accuracy with punctuation: {acc_pair[0]} \t Accuracy without punctuation: {acc_pair[1]}")
    print("\n")

Accuracy for bert-base-uncased:
Task: age_comparison 	 Accuracy with punctuation: 0.494 	 Accuracy without punctuation: 0.494
Task: size_comparison 	 Accuracy with punctuation: 0.554 	 Accuracy without punctuation: 0.556
Task: antonym_negation 	 Accuracy with punctuation: 0.538 	 Accuracy without punctuation: 0.532
Task: always_never 	 Accuracy with punctuation: 0.13214285714285715 	 Accuracy without punctuation: 0.1
Task: taxonomy_conjunction 	 Accuracy with punctuation: 0.44073455759599334 	 Accuracy without punctuation: 0.4674457429048414
Task: multihop_composition 	 Accuracy with punctuation: 0.322 	 Accuracy without punctuation: 0.332


Accuracy for roberta-large:
Task: age_comparison 	 Accuracy with punctuation: 0.986 	 Accuracy without punctuation: 0.978
Task: size_comparison 	 Accuracy with punctuation: 0.874 	 Accuracy without punctuation: 0.836
Task: antonym_negation 	 Accuracy with punctuation: 0.746 	 Accuracy without punctuation: 0.664
Task: always_never 	 Accuracy with pu