In [1]:
import datasets
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM

import mlm
from mlm.scorers import MLMScorerPT 
from mlm.models import get_pretrained

import mxnet as mx
ctxs = [mx.cpu()]

import pathlib
import os

from dataset_orm import *
from wordbank_tasks import *

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker



## NOTES

* **TODO**: verify that the changes I made to the `mlm.scorers` codebase so that it accepts RoBERTa are correct
* If we want other models, we'll have to add them there, too, perhaps with a bit more work if their output format is very different
* The function below implements only the most basic version of the test. Possible next steps:
    * Combine it with sentences from the real data
    * Check at least two alternative word-replacement strategies (within category, between categories)
    * Write more of a pipeline that samples words, sentences, replacement words for each sentence, and spits out scorers
* Open questions:
    * How do we measure how well the model did? By the rank of the correct sentence? By the NLL difference between the correct sentence and the best-scoring alternative sentence? Both?

In [2]:
def scorer_from_transformers_checkpoint(checkpotint_name, contexts):
    """Build an ``MLMScorerPT`` from a pretrained masked-LM checkpoint.

    The tokenizer and the masked-language-model weights are both loaded
    from the same checkpoint identifier via ``from_pretrained``.

    Args:
        checkpotint_name: Checkpoint identifier handed to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForMaskedLM.from_pretrained``. (The spelling of the
            parameter name is kept as-is so keyword callers keep working.)
        contexts: Passed through unchanged as the last argument of
            ``MLMScorerPT``.

    Returns:
        The constructed ``MLMScorerPT`` instance.
    """
    # Load order (tokenizer, then model) matches the original cell.
    hf_tokenizer = AutoTokenizer.from_pretrained(checkpotint_name)
    masked_lm = AutoModelForMaskedLM.from_pretrained(checkpotint_name)

    # The second positional argument is deliberately None.
    # NOTE(review): presumably this is the scorer's vocab slot -- confirm
    # against the mlm.scorers signature.
    scorer = MLMScorerPT(masked_lm, None, hf_tokenizer, contexts)
    return scorer

# One scorer per checkpoint, all sharing the same contexts (`ctxs`).
# The RoBERTa checkpoint is loaded first, then BERT, as before.
roberta_scorer, bert_scorer = (
    scorer_from_transformers_checkpoint(checkpoint, ctxs)
    for checkpoint in ('nyu-mll/roberta-base-100M-1', 'bert-base-uncased')
)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# The SQLite database is expected in the `data` directory that sits next to
# the current working directory (i.e. <cwd>/../data/wordbank.db).
DB_FILE = 'wordbank.db'
DATA_DIR = pathlib.Path.cwd().parent / 'data'
DB_PATH = DATA_DIR / DB_FILE

# `Session` is a session factory bound to the engine; call it to open a session.
engine = create_engine(f'sqlite:///{DB_PATH}')
Session = sessionmaker(bind=engine)

In [4]:
# Sampling sizes passed to the task below.
N_SENTENCES_PER_WORD = 5
N_ALTERNATIVE_WORDS = 5

# Note: this session object is not passed to the call below, which receives
# the `Session` factory itself via `session_maker`.
session = Session()

# Run the discriminative task with both scorers. The entries of
# `model_names` and `model_scorers` are listed in the same order
# (bert, roberta).
# When last run with CPU contexts, the progress bar showed ~58 s per item
# (592 items, ~9.5 h estimated) and the cell was interrupted by hand.
discriminative_task_all_words(
    session_maker=Session,
    n_sentences_per_word=N_SENTENCES_PER_WORD,
    n_alternative_words=N_ALTERNATIVE_WORDS,
    model_names=('bert', 'roberta'),
    model_scorers=[bert_scorer, roberta_scorer],
    criterion_func=smallest_nll_criterion,
)

  0%|          | 2/592 [01:55<9:28:50, 57.85s/it]


KeyboardInterrupt: 

In [9]:
# Open a fresh session and look up the single WordbankWord row whose `word`
# column equals 'table'. `.one()` is used, so exactly one match is expected.
session = Session()
table_query = session.query(WordbankWord).filter(WordbankWord.word == 'table')
table_word = table_query.one()

In [10]:
# Query selecting only the (id, word) columns of WordbankWord; it is built
# here but not executed in this cell.
word_columns = (WordbankWord.id, WordbankWord.word)
word_query = session.query(*word_columns)

In [40]:
# Scratch check: split a list of groups of (id, word) pairs into two parallel
# structures -- one tuple of ids per group and one tuple of words per group.
l = [
    [(0, 'a'), (1, 'b'), (2, 'c')],
    [(10, 'd'), (11, 'e'), (12, 'f')],
]

# The inner zip(*group) transposes one group into (ids, words); the outer zip
# then collects all the id tuples together and all the word tuples together.
ids, words = zip(*(zip(*group) for group in l))

for transposed in (ids, words):
    print(transposed)

((0, 1, 2), (10, 11, 12))
(('a', 'b', 'c'), ('d', 'e', 'f'))


In [41]:
# Scratch check: zip(*pairs) transposes a list of (id, word) pairs into one
# tuple of ids and one tuple of words.
pairs = [(0, 'a'), (1, 'b'), (2, 'c')]
list(zip(*pairs))

[(0, 1, 2), ('a', 'b', 'c')]

In [24]:
import pandas as pd

# Load the tab-separated word list into a DataFrame.
# NOTE(review): the file name is spelled `worbank` (not `wordbank`) -- confirm
# this matches the file on disk before changing it.
words_df = pd.read_csv('../data/worbank_with_category.tsv', sep='\t')

In [29]:
# Number of entries in the `value` column that contain at least one space,
# i.e. that split into more than one piece on ' '.
sum(1 for entry in words_df.value if len(entry.split(' ')) > 1)

42

In [31]:
# Show the `value` entries that split into more than one piece on ' ',
# selected with a boolean mask.
multiword_mask = [len(entry.split(' ')) > 1 for entry in words_df.value]
words_df.value[multiword_mask]

0                  a lot
6               all gone
22               baa baa
25     babysitter's name
47          belly button
111     child's own name
114            choo choo
199         french fries
207          gas station
213        give me five!
219             go potty
220       gonna get you!
221             going to
224               got to
230          green beans
233              have to
257           high chair
274            ice cream
299           lawn mower
301               let me
309          living room
350              need to
352              next to
355          night night
369            on top of
378              boo boo
388        peanut butter
398           pet's name
409           play dough
410             play pen
421          potato chip
436          quack quack
447        rocking chair
496              so big!
538            thank you
548    this little piggy
581          turn around
584                uh oh
599              want to
601      washing machine
