In [10]:
!git clone https://github.com/ml-utils/bert-syntax-it.git
!mv ./bert-syntax-it/marvin_linzen_dataset.tsv ./marvin_linzen_dataset.tsv

Cloning into 'bert-syntax-it'...
remote: Enumerating objects: 206, done.[K
remote: Counting objects: 100% (206/206), done.[K
remote: Compressing objects: 100% (127/127), done.[K
remote: Total 206 (delta 93), reused 183 (delta 71), pack-reused 0[K
Receiving objects: 100% (206/206), 3.12 MiB | 8.21 MiB/s, done.
Resolving deltas: 100% (93/93), done.


In [12]:

!ls -la

total 20244
drwxr-xr-x 1 root root     4096 May  3 07:09 .
drwxr-xr-x 1 root root     4096 May  3 06:44 ..
drwxr-xr-x 6 root root     4096 May  3 07:09 bert-syntax-it
drwxr-xr-x 4 root root     4096 Apr 29 03:18 .config
-rw-r--r-- 1 root root 20705886 May  3 07:08 marvin_linzen_dataset.tsv
drwxr-xr-x 1 root root     4096 Apr 29 03:19 sample_data


In [4]:
!pip install folium==0.2.1
!pip install pytorch-pretrained-bert

Collecting folium==0.2.1
  Downloading folium-0.2.1.tar.gz (69 kB)
[K     |████████████████████████████████| 69 kB 2.9 MB/s 
Building wheels for collected packages: folium
  Building wheel for folium (setup.py) ... [?25l[?25hdone
  Created wheel for folium: filename=folium-0.2.1-py3-none-any.whl size=79808 sha256=6a7636c4cb740c89c93f55df6ccd52db9c3e144d4ea8ce63daf6b9e6e2e22920
  Stored in directory: /root/.cache/pip/wheels/9a/f0/3a/3f79a6914ff5affaf50cabad60c9f4d565283283c97f0bdccf
Successfully built folium
Installing collected packages: folium
  Attempting uninstall: folium
    Found existing installation: folium 0.8.3
    Uninstalling folium-0.8.3:
      Successfully uninstalled folium-0.8.3
Successfully installed folium-0.2.1
Collecting pytorch-pretrained-bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 4.2 MB/s 
[?25hCollecting boto3
  Downloading boto3-1.22.5-py3-none-any.whl (132 kB)
[K     |████████

In [5]:
import os.path

from pytorch_pretrained_bert import BertForMaskedLM,tokenization
import torch
import argparse, sys
import csv

In [7]:
def get_probs_for_words(bert,tokenizer,sent,w1,w2):
    print(f'sent: {sent}')
    pre,target,post=sent.split('***')
    print(f'pre: {pre}, target: {target}, post: {post}')
    if 'mask' in target.lower():
        target=['[MASK]']
    else:
        target=tokenizer.tokenize(target)

    # todo, fixme: the vocabulary of the pretrained model from Kaj does not have entries for CLS, UNK
    # fixme: tokenizer.tokenize(pre), does not recognize the words
    tokens=['[CLS]']+tokenizer.tokenize(pre)  # tokens = tokenizer.tokenize(pre)

    target_idx=len(tokens)

    #print(target_idx)
    tokens+=target+tokenizer.tokenize(post)+['[SEP]']
    print(f'tokens {tokens}')
    input_ids=tokenizer.convert_tokens_to_ids(tokens)
    try:
        word_ids=tokenizer.convert_tokens_to_ids([w1,w2])
    except KeyError:
        print("skipping",w1,w2,"bad wins")
        return None
    tens=torch.LongTensor(input_ids).unsqueeze(0)
    res=bert(tens)[0,target_idx]
    #res=torch.nn.functional.softmax(res,-1)
    scores = res[word_ids]
    return [float(x) for x in scores]

from collections import Counter
def load_marvin():
    cc = Counter()
    # note: I edited the LM_Syneval/src/make_templates.py script, and run "python LM_Syneval/src/make_templates.py LM_Syneval/data/templates/ > marvin_linzen_dataset.tsv"
    out = []
    for line in open("marvin_linzen_dataset.tsv"):
        case = line.strip().split("\t")
        cc[case[1]]+=1
        g,ug = case[-2],case[-1]
        g = g.split()
        ug = ug.split()
        assert(len(g)==len(ug)),(g,ug)
        diffs = [i for i,pair in enumerate(zip(g,ug)) if pair[0]!=pair[1]]
        if (len(diffs)!=1):
            #print(diffs)
            #print(g,ug)
            continue    
        assert(len(diffs)==1),diffs
        gv=g[diffs[0]]   # good
        ugv=ug[diffs[0]] # bad
        g[diffs[0]]="***mask***"
        g.append(".")
        out.append((case[0],case[1]," ".join(g),gv,ugv))
    return out

def eval_marvin(bert,tokenizer):
    o = load_marvin()
    print(len(o),file=sys.stderr)
    from collections import defaultdict
    import time
    rc = defaultdict(Counter)
    tc = Counter()
    start = time.time()
    for i,(case,tp,s,g,b) in enumerate(o):
        ps = get_probs_for_words(bert,tokenizer,s,g,b)
        if ps is None: ps = [0,1]
        gp = ps[0]
        bp = ps[1]
        print(gp>bp,case,tp,g,b,s)
        if i % 100==0:
            print(i,time.time()-start,file=sys.stderr)
            start=time.time()
            sys.stdout.flush()

def eval_lgd(bert,tokenizer):
    for i,line in enumerate(open("lgd_dataset_with_is_are.tsv",encoding="utf8")):
        na,_,masked,good,bad = line.strip().split("\t")
        ps = get_probs_for_words(bert,tokenizer,masked,good,bad)
        if ps is None: continue
        gp = ps[0]
        bp = ps[1]
        print(str(gp>bp),na,good,gp,bad,bp,masked.encode("utf8"),sep=u"\t")
        if i%100 == 0:
            print(i,file=sys.stderr)
            sys.stdout.flush()


def read_gulordava():
    rows = csv.DictReader(open("generated.tab",encoding="utf8"),delimiter="\t")
    data=[]
    for row in rows:
        row2=next(rows)
        assert(row['sent']==row2['sent'])
        assert(row['class']=='correct')
        assert(row2['class']=='wrong')
        sent = row['sent'].lower().split()[:-1] # dump the <eos> token.
        good_form = row['form']
        bad_form  = row2['form']
        sent[int(row['len_prefix'])]="***mask***"
        sent = " ".join(sent)
        data.append((sent,row['n_attr'],good_form,bad_form))
    return data

def eval_gulordava(bert,tokenizer):
    for i,(masked,natt,good,bad) in enumerate(read_gulordava()):
        if good in ["is","are"]:
            print("skipping is/are")
            continue
        ps = get_probs_for_words(bert,tokenizer,masked,good,bad)
        if ps is None: continue
        gp = ps[0]
        bp = ps[1]
        print(str(gp>bp),natt,good,gp,bad,bp,masked.encode("utf8"),sep=u"\t")
        if i%100 == 0:
            print(i,file=sys.stderr)
            sys.stdout.flush()

# choose_eval()


def init_bert_model(model_name):
    # model_name = 'bert-large-uncased'
    #if 'base' in sys.argv: model_name = 'bert-base-uncased'
    print(f'model_name: {model_name}')
    print("using model:", model_name, file=sys.stderr)
    bert = BertForMaskedLM.from_pretrained(model_name)
    print("bert model loaded, getting the tokenizer..")
    vocab_filepath = os.path.join(model_name, 'dict.txt')
    tokenizer = tokenization.BertTokenizer.from_pretrained(vocab_filepath)
    print("tokenizer ready.")

    bert.eval()
    return bert, tokenizer


def run_eval(eval_suite, bert, tokenizer):
    print('running eval..')
    if 'marvin' == eval_suite:
        eval_marvin(bert,tokenizer)
    elif 'gul' == eval_suite:
        eval_gulordava(bert,tokenizer)
    else:
        eval_lgd(bert,tokenizer)


def arg_parse():
    print('parsing args..')
    # Python program to demonstrate
    # command line arguments

    import getopt, sys

    # Remove 1st argument from the
    # list of command line arguments
    argumentList = sys.argv[1:]

    options = "be:"

    # Long options
    long_options = ["bert_model", "eval_suite"]

    DEFAULT_MODEL = 'bert-large-uncased'
    DEFAULT_EVAL_SUITE = 'lgd'
    model_name = DEFAULT_MODEL
    eval_suite = DEFAULT_EVAL_SUITE

    try:
        # Parsing argument
        print(f'argumentList: {argumentList}')

        # checking each argument
        for arg_idx, currentArgument  in enumerate(argumentList):
            print(f'persing currentArgument {currentArgument}')
            if currentArgument in ("-h", "--Help"):
                print("Displaying Help")

            elif currentArgument in ("-b", "--bert_model"):

                argValue = argumentList[arg_idx+1]
                print(f'currentArgument: {currentArgument}, argValue: {argValue}')
                if argValue == 'base':
                    model_name = 'bert-base-uncased'
                else:
                    model_name = argValue
                    print(f'set model_name: {model_name}')

            elif currentArgument in ("-e", "--eval_suite"):
                argValue = argumentList[arg_idx + 1]
                print(f'currentArgument: {currentArgument}, argValue: {argValue}')
                eval_suite = argValue

    except getopt.error as err:
        # output error, and return with an error code
        print(str(err))

    print(f'model_name {model_name}, eval_suite {eval_suite}')
    return model_name, eval_suite


def main():
    print('main')
    model_name, eval_suite = arg_parse()
    eval_suite = 'marvin'
    bert, tokenizer = init_bert_model(model_name)
    run_eval(eval_suite, bert, tokenizer)




In [13]:
main()

main
parsing args..
argumentList: ['-f', '/root/.local/share/jupyter/runtime/kernel-3dc272cd-54fa-4625-9756-0ba678ef1978.json']
persing currentArgument -f
persing currentArgument /root/.local/share/jupyter/runtime/kernel-3dc272cd-54fa-4625-9756-0ba678ef1978.json
model_name bert-large-uncased, eval_suite lgd
model_name: bert-large-uncased


using model: bert-large-uncased
Model name 'bert-large-uncased/dict.txt' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese). We assumed 'bert-large-uncased/dict.txt' was a path or url but couldn't find any file associated to this path or url.


bert model loaded, getting the tokenizer..
tokenizer ready.
running eval..
sent: the author that the guard likes ***mask*** .
pre: the author that the guard likes , target: mask, post:  .


147506


AttributeError: ignored