## Initialize

In [None]:
%%capture
!pip3 install transformers

import json
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM
import matplotlib.pyplot as plt

In [None]:
%%capture
!git clone https://github.com/alexwarstadt/blimp

!gdown -O en_bert.tar.gz https://drive.google.com/uc?id=1-VJjnqLGKafSoiELTHmg-REVBz0a0QEC  # en BERT
!tar xzf en_bert.tar.gz

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model under evaluation, loaded from the local "en_bert" path
# (presumably the directory unpacked from en_bert.tar.gz above -- confirm).
our_tokenizer = AutoTokenizer.from_pretrained("en_bert")
our_model = AutoModelForMaskedLM.from_pretrained("en_bert").to(device)

# Reference model for comparison: the "bert-base-cased" checkpoint.
original_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
original_model = AutoModelForMaskedLM.from_pretrained("bert-base-cased").to(device)

In [3]:
def load(filename):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        filename: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's default locale encoding.
    with open(filename, encoding="utf-8") as f:
        # Iterate the handle directly (no readlines() copy) and skip
        # blank/whitespace-only lines, which json.loads would reject.
        return [json.loads(line) for line in f if line.strip()]

In [12]:
def get_ppl(sentence, model, tokenizer, device):
    """Score a sentence with a masked LM by masking one position at a time.

    NOTE: despite the name, the returned value is not a perplexity. It is
    the sum of the log-probabilities the model assigns to each original
    token when that token alone is replaced by the mask token, so a
    HIGHER value means the model finds the sentence more likely.

    Args:
        sentence: Text to score; passed to ``tokenizer`` as a single input
            (assumed to yield ``input_ids`` of shape (1, seq_len)).
        model: Callable accepting ``input_ids=`` and returning a mapping
            with a ``"logits"`` entry.
        tokenizer: Callable tokenizer returning ``"input_ids"`` as a tensor.
            Its ``mask_token_id`` is used when available; otherwise the id
            of the literal ``"[MASK]"`` token is looked up.
        device: Torch device the inputs are moved to.

    Returns:
        The summed log-probability as a Python float, or ``0.0`` when the
        encoded sentence has no positions between the first and last token.
    """
    # Prefer the tokenizer's own mask id; fall back to the BERT-style literal
    # for tokenizers that do not expose ``mask_token_id``.
    mask_id = getattr(tokenizer, "mask_token_id", None)
    if mask_id is None:
        mask_id = tokenizer.convert_tokens_to_ids("[MASK]")

    tokens = tokenizer(sentence, return_tensors='pt')["input_ids"].to(device)
    seq_len = tokens.size(1)

    # The first and last positions are never masked or scored (presumably
    # the [CLS]/[SEP] special tokens). With nothing in between there is
    # nothing to sum, so skip the model call on an empty batch.
    n_scored = seq_len - 2
    if n_scored <= 0:
        return 0.0

    # One copy of the sentence per scored position; row i masks position i+1.
    tokens = tokens.repeat(n_scored, 1)
    mask = torch.eye(seq_len, device=device).bool()[1:-1, :]
    bert_input = tokens.masked_fill(mask, value=mask_id)

    # Only a detached Python float is returned, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids=bert_input)["logits"]
        log_p = F.log_softmax(logits, dim=-1)
        # Pick the log-probability of the original token at every position.
        log_p = log_p.gather(index=tokens.unsqueeze(-1), dim=-1).squeeze(-1)

    # Keep only the masked position of each row and add them up.
    return log_p.masked_fill(~mask, 0.0).sum().item()

In [5]:
@torch.no_grad()
def run(model, tokenizer, pairs, device):
    """Return the fraction of minimal pairs on which the good sentence wins.

    Args:
        model: Masked LM forwarded to ``get_ppl``.
        tokenizer: Tokenizer forwarded to ``get_ppl``.
        pairs: Sized collection of mappings with ``"sentence_good"`` and
            ``"sentence_bad"`` entries.
        device: Torch device forwarded to ``get_ppl``.

    Returns:
        Number of pairs whose good sentence scores strictly higher than the
        bad one, divided by ``len(pairs)``.

    Raises:
        ZeroDivisionError: If ``pairs`` is empty.
    """

    def prefers_good(item):
        # Read both sentences up front, then score good before bad.
        good_text, bad_text = item["sentence_good"], item["sentence_bad"]
        good_score = get_ppl(good_text, model, tokenizer, device)
        bad_score = get_ppl(bad_text, model, tokenizer, device)
        # Ties count as failures: the good sentence must win strictly.
        return good_score > bad_score

    hits = sum(1 for item in pairs if prefers_good(item))
    return hits / len(pairs)

In [6]:
class Blimp:
    """Collects accuracies keyed by phenomenon and then by paradigm UID."""

    def __init__(self):
        # Maps phenomenon name -> {UID -> accuracy}.
        self.phenomena = {}

    def add_result(self, phenomenon, uid, accuracy):
        """Store ``accuracy`` under ``phenomenon``/``uid``, replacing any earlier value."""
        self.phenomena.setdefault(phenomenon, {})[uid] = accuracy

    def __str__(self):
        """Render one ``phenomenon,uid,accuracy`` line per result.

        Lines are ordered by phenomenon name, then by UID; an empty
        collection renders as the empty string.
        """
        return '\n'.join(
            f"{name},{uid},{self.phenomena[name][uid]}"
            for name in sorted(self.phenomena)
            for uid in sorted(self.phenomena[name])
        )

## BLiMP on our model

In [15]:
# Evaluate our model on every BLiMP paradigm file and collect the accuracies.
blimp = Blimp()

# glob yields paths in arbitrary order; sort them so runs are reproducible.
for path in sorted(glob.glob("blimp/data/*.jsonl")):
    dataset = load(path)
    accuracy = run(our_model, our_tokenizer, dataset, device)
    # The phenomenon and UID are read from the first record only
    # (assumes every record in a file shares them -- confirm).
    blimp.add_result(dataset[0]["linguistics_term"], dataset[0]["UID"], accuracy)

print(blimp)

binding principle_A_case_2 0.97
argument_structure passive_1 0.743
filler_gap_dependency wh_vs_that_with_gap 0.647
island_effects complex_NP_island 0.399
filler_gap_dependency wh_vs_that_with_gap_long_distance 0.35
determiner_noun_agreement determiner_noun_agreement_irregular_1 0.864
binding principle_A_c_command 0.558
filler_gap_dependency wh_questions_object_gap 0.788
control_raising tough_vs_raising_1 0.461
control_raising existential_there_subject_raising 0.877
filler_gap_dependency wh_vs_that_no_gap_long_distance 0.964
island_effects left_branch_island_echo_question 0.202
subject_verb_agreement distractor_agreement_relational_noun 0.842
argument_structure intransitive 0.725
quantifiers superlative_quantifiers_1 0.824
argument_structure passive_2 0.777
control_raising existential_there_object_raising 0.762
control_raising tough_vs_raising_2 0.759
npi_licensing matrix_question_npi_licensor_present 0.714
subject_verb_agreement regular_plural_subject_verb_agreement_2 0.898
island_effe

## BLiMP on the original BERT

In [16]:
# Evaluate the original BERT on every BLiMP paradigm file and collect the accuracies.
blimp = Blimp()

# glob yields paths in arbitrary order; sort them so runs are reproducible.
for path in sorted(glob.glob("blimp/data/*.jsonl")):
    dataset = load(path)
    accuracy = run(original_model, original_tokenizer, dataset, device)
    # The phenomenon and UID are read from the first record only
    # (assumes every record in a file shares them -- confirm).
    blimp.add_result(dataset[0]["linguistics_term"], dataset[0]["UID"], accuracy)

print(blimp)

binding principle_A_case_2 0.973
argument_structure passive_1 0.785
filler_gap_dependency wh_vs_that_with_gap 0.731
island_effects complex_NP_island 0.56
filler_gap_dependency wh_vs_that_with_gap_long_distance 0.584
determiner_noun_agreement determiner_noun_agreement_irregular_1 0.916
binding principle_A_c_command 0.675
filler_gap_dependency wh_questions_object_gap 0.892
control_raising tough_vs_raising_1 0.686
control_raising existential_there_subject_raising 0.921
filler_gap_dependency wh_vs_that_no_gap_long_distance 0.957
island_effects left_branch_island_echo_question 0.568
subject_verb_agreement distractor_agreement_relational_noun 0.951
argument_structure intransitive 0.846
quantifiers superlative_quantifiers_1 0.913
argument_structure passive_2 0.877
control_raising existential_there_object_raising 0.767
control_raising tough_vs_raising_2 0.89
npi_licensing matrix_question_npi_licensor_present 0.932
subject_verb_agreement regular_plural_subject_verb_agreement_2 0.961
island_effe