In [None]:
%load_ext autoreload
%autoreload 2

In [21]:
import math

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
from tqdm import tqdm_notebook

In [22]:
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [37]:
def get_score(sentence: str):
    tokenize_input = bert_tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([bert_tokenizer.convert_tokens_to_ids(tokenize_input)])
    predictions = bert_model(tensor_input)
    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(predictions.squeeze(), tensor_input.squeeze()).data
    return math.exp(loss)

def sorted_tuple_scores(tuple_data_path: str):
    scores = []
    with open(tuple_data_path, 'r') as f:
        for line in tqdm_notebook(f.readlines()):
            line = line.strip()
            sentence = " ".join(line.split('\t'))
            scores.append((line, get_score(sentence)))
    scores.sort(key=lambda x: x[1])
    return scores

def print_scores(scores: list):
    print("Lowest losses:")
    print(*scores[:5], sep="\n")
    print()

    print("Highest losses:")
    print(*scores[-5:], sep="\n")

In [38]:
# aristo_scores = sorted_tuple_scores('data/aristo.txt')
# print(aristo_scores)

In [39]:
openbook_scores = sorted_tuple_scores('data/openbook_tuples.tsv')
print_scores(openbook_scores)

HBox(children=(IntProgress(value=0, max=3793), HTML(value='')))

Lowest losses:
('most canyons\tflowing rivers over\tlong periods of time', 2.158749725759598)
('environment\tis cold at\tnorthern lattitude below 0 degrees celsius during most of year', 3.8701482849718687)
('example\tis\tchimpanzee digging for insects with stick', 4.059842366680984)
('tidal range\tis measure of\tvertical distance from high tide to low tide', 4.235079800987837)
('environment\tis cold at\tlattitude below 0 degrees celsius', 4.264977688931828)

Highest losses:
('magnet\tattracts\tmagnetic metals', 393452.34152887564)
('animal\trequires\twarmth', 401064.8271481339)
('plant stem\tcontains\tsystem', 607878.2838986012)
('animal\trequires\tnutrients', 684915.5461811321)
('rains\tcause\tflooding', 1846387.2594818242)


In [40]:
with open('data/sorted_openbook_tuples.tsv', 'w') as f:
    print(*(x[0] for x in openbook_scores), sep='\n', file=f)