In [24]:
import os
import pickle
import torch

from constants import *
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [44]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
language_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
language_model.eval()

def predict_candidate(sentence, candidates):
    tokenized_text = tokenizer.tokenize(sentence)
    candidates = [tokenizer.tokenize(candidate)[0] for candidate in candidates]
    masked_index = tokenized_text.index('[MASK]')

    candidates = [x.lower() for x in candidates]
    candidates_ids = tokenizer.convert_tokens_to_ids(candidates)

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    segments_ids = [0] * len(tokenized_text)

    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    predictions = language_model(tokens_tensor, segments_tensors)
    predictions_candidates = predictions[0, masked_index, candidates_ids]
    answer_idx = torch.argmax(predictions_candidates).item()
    
    print(f'The most likely word is "{candidates[answer_idx]}".')
    
    return candidates[answer_idx]
    

In [45]:
def setup_model(model_type):
    if model_type == "BERT": 
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        language_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        language_model.eval()
    elif model_type == "GPT-2": 
        return 
    elif model_type == "Baevski": 
        return 
    return tokenizer, language_model 

In [46]:
def read_examples(): 
    correct_count = 0 
    total_count = 0 
    for data_split in DATA_SPLITS: 
        for data_type in DATA_TYPES:
            path = CLEANED_PATH + data_split + data_type 
            filenames = os.listdir(path)
            for filename in filenames: 
                filename_path = path + filename
                with open(filename_path, "rb") as pickle_file: 
                    examples = pickle.load(pickle_file)
                    for example in examples: 
                        sentence = example['sentence']
                        candidates = example['candidates']
                        answer = example['answer']
                        correct_answer = predict_candidate(sentence, candidates)
                        total_count += 1
                        if answer == correct_answer: 
                          correct_count += 1  
                        
    return correct_count, total_count
                    

In [None]:
# tokenizer, language_model = setup_model("BERT")
read_examples()

The most likely word is "help".
The most likely word is "visit".
The most likely word is "as".
The most likely word is "touched".
The most likely word is "listening".
The most likely word is "again".
The most likely word is "guests".
The most likely word is "leave".
The most likely word is "days".
The most likely word is "different".
The most likely word is "following".
The most likely word is "promising".
The most likely word is "something".
The most likely word is "passengers".
The most likely word is "sit".
The most likely word is "attention".
The most likely word is "laughed".
The most likely word is "silent".
The most likely word is "morning".
The most likely word is "surprised".
The most likely word is "talk".
The most likely word is "smiling".
The most likely word is "so".
The most likely word is "found".
The most likely word is "waiting".
The most likely word is "happened".
The most likely word is "asked".
The most likely word is "breath".
The most likely word is "what".
The mo

The most likely word is "taken".
The most likely word is "high".
The most likely word is "forms".
The most likely word is "making".
The most likely word is "as".
The most likely word is "but".
The most likely word is "suffer".
The most likely word is "horrible".
The most likely word is "learn".
The most likely word is "that".
The most likely word is "decision".
The most likely word is "when".
The most likely word is "affect".
The most likely word is "experiences".
The most likely word is "energy".
The most likely word is "important".
The most likely word is "way".
The most likely word is "aware".
The most likely word is "even".
The most likely word is "if".
The most likely word is "make".
The most likely word is "changes".
The most likely word is "fast".
The most likely word is "either".
The most likely word is "cross".
The most likely word is "where".
The most likely word is "before".
The most likely word is "goes".
The most likely word is "daring".
The most likely word is "with".
The

The most likely word is "thanked".
The most likely word is "office".
The most likely word is "politely".
The most likely word is "after".
The most likely word is "always".
The most likely word is "way".
The most likely word is "however".
The most likely word is "forced".
The most likely word is "mentioned".
The most likely word is "for".
The most likely word is "saying".
The most likely word is "challenge".
The most likely word is "trial".
The most likely word is "while".
The most likely word is "concerned".
The most likely word is "feel".
The most likely word is "society".
The most likely word is "quickly".
The most likely word is "passed".
The most likely word is "called".
The most likely word is "laughed".
The most likely word is "matter".
The most likely word is "stand".
The most likely word is "spend".
The most likely word is "could".
The most likely word is "park".
The most likely word is "while".
The most likely word is "view".
The most likely word is "watched".
The most likely 

The most likely word is "courage".
The most likely word is "disappointment".
The most likely word is "set".
The most likely word is "early".
The most likely word is "asked".
The most likely word is "good".
The most likely word is "meal".
The most likely word is "wrote".
The most likely word is "kitchen".
The most likely word is "table".
The most likely word is "own".
The most likely word is "way".
The most likely word is "wife".
The most likely word is "happiness".
The most likely word is "in".
The most likely word is "all".
The most likely word is "simple".
The most likely word is "walking".
The most likely word is "refused".
The most likely word is "around".
The most likely word is "stuck".
The most likely word is "first".
The most likely word is "sounds".
The most likely word is "good".
The most likely word is "but".
The most likely word is "advice".
The most likely word is "creates".
The most likely word is "smile".
The most likely word is "adventure".
The most likely word is "walk

The most likely word is "shocking".
The most likely word is "mind".
The most likely word is "sleeping".
The most likely word is "subject".
The most likely word is "explained".
The most likely word is "on".
The most likely word is "tears".
The most likely word is "touch".
The most likely word is "should".
The most likely word is "gift".
The most likely word is "deeply".
The most likely word is "upset".
The most likely word is "do".
The most likely word is "love".
The most likely word is "burden".
The most likely word is "everything".
The most likely word is "remember".
The most likely word is "in".
The most likely word is "sense".
The most likely word is "including".
The most likely word is "name".
The most likely word is "when".
The most likely word is "learning".
The most likely word is "promise".
The most likely word is "ring".
The most likely word is "sadness".
The most likely word is "again".
The most likely word is "smiling".
The most likely word is "courage".
The most likely word

In [None]:
with open("data/cleaned/train/high/high1495.pickle", "rb") as pickle_file: 
    examples = pickle.load(pickle_file)
    print(examples)