In [1]:
import torch
import copy
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
def to_low(inp_list):
    return [x.lower() for x in inp_list]

In [3]:
def merge(inp_list, bert_list):
    dif = list(set(inp_list).difference(bert_list))
    for x in dif:
        for n, y in enumerate(inp_list):
            if x == y:
                inp_list[n] = '[UNK]'
    return inp_list

In [4]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = """The United States of America is a federal republic consisting of 50 states,
a federal district (Washington, D.C., the capital city of the United States),
five major territories, and various minor islands. The 48 contiguous states and Washington,
D.C., are in central North America between Canada and Mexico; the two other states, Alaska and Hawaii,
are in the northwestern part of North America and an archipelago in the mid-Pacific, respectively,
while the territories are scattered throughout the Pacific Ocean and the Caribbean Sea."""

tokenized_text_bert = tokenizer.tokenize(text)


tokenized_text_inp = ['The', 'United', 'States', 'of', 'America', 'is', 'a', 'federal', 'republic',
                      'consisting', 'of', '50', 'states', ',', 'a', 'federal', 'district', '(',
                      'Washington', ',', 'D.C.', ',', 'the', 'capital', 'city', 'of', 'the', 'United',
                      'States', ')', ',', 'five', 'major', 'territories', ',', 'and', 'various',
                      'minor', 'islands', '.', 'The', '48', 'contiguous', 'states', 'and', 'Washington',
                      ',', 'D.C.', ',', 'are', 'in', 'central', 'North', 'America', 'between', 'Canada',
                      'and', 'Mexico', ';', 'the', 'two', 'other', 'states', ',', 'Alaska', 'and', 'Hawaii',
                      ',', 'are', 'in', 'the', 'northwestern', 'part', 'of', 'North', 'America', 'and',
                      'an', 'archipelago', 'in', 'the', 'mid', '-', 'Pacific', ',', 'respectively', ',',
                      'while', 'the', 'territories', 'are', 'scattered', 'throughout', 'the', 'Pacific',
                      'Ocean', 'and', 'the', 'Caribbean', 'Sea', '.']

tokenized_text_inp = to_low(tokenized_text_inp)
print(tokenized_text_inp, '\n')
print(tokenized_text_bert, '\n')

tokenized_text = merge(tokenized_text_inp, tokenized_text_bert)
print(tokenized_text, '\n')

['the', 'united', 'states', 'of', 'america', 'is', 'a', 'federal', 'republic', 'consisting', 'of', '50', 'states', ',', 'a', 'federal', 'district', '(', 'washington', ',', 'd.c.', ',', 'the', 'capital', 'city', 'of', 'the', 'united', 'states', ')', ',', 'five', 'major', 'territories', ',', 'and', 'various', 'minor', 'islands', '.', 'the', '48', 'contiguous', 'states', 'and', 'washington', ',', 'd.c.', ',', 'are', 'in', 'central', 'north', 'america', 'between', 'canada', 'and', 'mexico', ';', 'the', 'two', 'other', 'states', ',', 'alaska', 'and', 'hawaii', ',', 'are', 'in', 'the', 'northwestern', 'part', 'of', 'north', 'america', 'and', 'an', 'archipelago', 'in', 'the', 'mid', '-', 'pacific', ',', 'respectively', ',', 'while', 'the', 'territories', 'are', 'scattered', 'throughout', 'the', 'pacific', 'ocean', 'and', 'the', 'caribbean', 'sea', '.'] 

['the', 'united', 'states', 'of', 'america', 'is', 'a', 'federal', 'republic', 'consisting', 'of', '50', 'states', ',', 'a', 'federal', 'dis

In [5]:
ner_list = [(0, 4), (11, 11), (18, 18), (20, 20), (26, 28), (31, 31), (41, 41), (45, 45),
            (47, 47), (52, 53), (55, 55), (57, 57), (60, 60), (64, 64), (66, 66), (74, 75),
            (80, 83), (93, 95), (97, 99)]

In [6]:
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

def bert_estimate(tokenized_text, ner_list):
    for n in ner_list:
        weight_factor_sum = 0
        for i in range(n[0], n[1]+1):
            weight_factor = 0
            tokenized_text_copy = copy.copy(tokenized_text)
            tokenized_text_copy[i] = '[MASK]'
            # print(tokenized_text_copy)
            # print(tokenized_text_copy)
            # Convert token to vocabulary indices
            indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text_copy)

            # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
            segments_ids = [0]*len(tokenized_text_copy)

            # Convert inputs to PyTorch tensors
            tokens_tensor = torch.tensor([indexed_tokens])
            segments_tensors = torch.tensor([segments_ids])

            # Predict all tokens
            predictions = model(tokens_tensor, segments_tensors)

            item, pred = torch.sort(predictions[0, i], descending=True)
            predicted_tokens = []
            for index in pred[:50]:
                predicted_tokens.append(tokenizer.convert_ids_to_tokens([index.item()])[0])
            for one_token in predicted_tokens:
                if one_token != tokenized_text[i]:
                    weight_factor += 0.02
                else:
                    break
            weight_factor_sum += weight_factor
        weight_factor_sum /= len(n)
        print(' '.join(tokenized_text[n[0]: n[1]+1]), ' ||| weight--->', weight_factor_sum)

bert_estimate(tokenized_text, ner_list)

the united states of america  ||| weight---> 0.01
50  ||| weight---> 0.01
washington  ||| weight---> 0.01
[UNK]  ||| weight---> 0.0
the united states  ||| weight---> 0.01
five  ||| weight---> 0.01
48  ||| weight---> 0.09999999999999999
washington  ||| weight---> 0.0
[UNK]  ||| weight---> 0.0
north america  ||| weight---> 0.0
canada  ||| weight---> 0.01
mexico  ||| weight---> 0.03
two  ||| weight---> 0.0
alaska  ||| weight---> 0.0
hawaii  ||| weight---> 0.0
north america  ||| weight---> 0.0
the mid - pacific  ||| weight---> 0.060000000000000005
the pacific ocean  ||| weight---> 0.01
the caribbean sea  ||| weight---> 0.0
