In [199]:
import numpy
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [None]:
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-large-uncased')
model.eval()

In [201]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)

In [314]:
import copy

original_sent = 'new york is the greatest city in the world .'.lower().split()

for ii in range(len(original_sent)):
    new_sent = copy.copy(original_sent)
    new_sent[ii] = '[MASK]'
#     new_sent[ii] = tokenizer.convert_ids_to_tokens([numpy.random.randint(0, len(tokenizer.vocab))])[0]
    out = model(torch.tensor([tokenizer.convert_tokens_to_ids(new_sent)]))
    pred = tokenizer.convert_ids_to_tokens([out[0][ii].max(0)[1].item()])[0]
    probs = out[0][ii].data.numpy()
    rank = len(tokenizer.vocab) - numpy.argsort(numpy.argsort(probs))[tokenizer.convert_tokens_to_ids([original_sent[ii]])[0]]
    print(" ".join(new_sent), "=>", pred, '|||', 'rank of', original_sent[ii], rank)
#     if pred == 'the':
#         break

[MASK] york is the greatest city in the world . => . ||| rank of new 458
new [MASK] is the greatest city in the world . => it ||| rank of york 741
new york [MASK] the greatest city in the world . => - ||| rank of is 2
new york is [MASK] greatest city in the world . => the ||| rank of the 1
new york is the [MASK] city in the world . => largest ||| rank of greatest 9
new york is the greatest [MASK] in the world . => city ||| rank of city 1
new york is the greatest city [MASK] the world . => in ||| rank of in 1
new york is the greatest city in [MASK] world . => the ||| rank of the 1
new york is the greatest city in the [MASK] . => . ||| rank of world 2
new york is the greatest city in the world [MASK] => . ||| rank of . 1


In [287]:
numpy.where(probs == probs.max())

(array([1996]),)

In [292]:
probs.argsort()[1996], probs.argmax(), numpy.argsort(numpy.argsort(probs))[1996]

(27842, 1996, 30521)

In [289]:
probs.shape

(30522,)

In [290]:
out[0][ii].max()

tensor(15.1242, grad_fn=<MaxBackward1>)

In [291]:
tokenizer.convert_ids_to_tokens([1996])

['the']

In [274]:
out[0][ii].min()

tensor(-11.9881, grad_fn=<MinBackward1>)

In [247]:
out[0][ii][1996]

tensor(15.1242, grad_fn=<SelectBackward>)

In [356]:
''' sequential generation '''

sample = True
max_len = 20
leed_out_len = 5 #max_len
random_future = False
top_k = 0 # set it to 0 if you don't want top_k
n_samples = 10

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

for si in range(n_samples):
    init_text = seed_text + ['[MASK]'] * max_len
    init_idx = tokenizer.convert_tokens_to_ids(init_text)
    if random_future:
        for ii in range(max_len):
            init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

    for ii in range(max_len):
        out = model(torch.tensor([init_idx[:seed_len+ii+leed_out_len]+tokenizer.convert_tokens_to_ids(['[SEP]'])]))
        if top_k > 0:
            logits = out[0,seed_len+ii]
            kth_vals, kth_idx = logits.topk(top_k)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            init_idx[seed_len+ii] = kth_idx[dist.sample().item()].item()
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+ii])
                init_idx[seed_len+ii] = dist.sample().item()
            else:
                init_idx[seed_len+ii] = torch.max(out[0, seed_len+ii],0)[1].item()

#     print(init_idx)
    print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)))
# print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)).replace(" ##", ""))

[CLS] . . . . . . . . . . . . . . . . . . . .
[CLS] . . . . . . . . . . . . . . . . . . . .


KeyboardInterrupt: 

In [354]:
tokenizer.convert_tokens_to_ids(['[SEP]'])

[102]

In [305]:
''' parallel generation '''

sample = True
max_iter = 100
viz_int = 10
max_len = 20
top_k = 5

seed_text = 'the meaning of life is'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * max_len
init_idx = tokenizer.convert_tokens_to_ids(init_text)
for ii in range(max_len):
    init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    out = model(torch.tensor([init_idx]))
    for kk in range(max_len):
        if top_k > 0:
            logits = out[0,seed_len+kk]
            kth_vals, kth_idx = logits.topk(top_k)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
                init_idx[seed_len+kk] = dist.sample().item()
            else:
                init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
    if numpy.mod(ii, viz_int) == 0:
        print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(init_idx)))

iter 1 the meaning of life is just in on and as , some one part got a all and to around from from or that or
iter 11 the meaning of life is just that that and like , like that and was a lot and all that from that was that or
iter 21 the meaning of life is . . . and and , and i . and . . . . . . . . . .
iter 31 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 41 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 51 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 61 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 71 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 81 the meaning of life is . . . . . . . . . . . . . . . . . . . .
iter 91 the meaning of life is . . . . . . . . . . . . . . . . . . . .


In [345]:
''' parallel-sequential generation '''

# sample = True
burnin = 200
max_iter = 300
viz_int = 10
max_len = 15
top_k = 0

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * (max_len) + ['[SEP]']
init_idx = tokenizer.convert_tokens_to_ids(init_text)
# for ii in range(max_len):
#     init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    kk = numpy.random.randint(0, max_len)
    init_idx[seed_len+kk] = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
    out = model(torch.tensor([init_idx]))
    if top_k > 0:
        logits = out[0,seed_len+kk]
        kth_vals, kth_idx = logits.topk(top_k)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
    else:
        if ii < burnin:
            dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
            init_idx[seed_len+kk] = dist.sample().item()
        else:
            init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
        
    if numpy.mod(ii+1, viz_int) == 0:
        for_print = tokenizer.convert_ids_to_tokens(init_idx)
        for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
        print("iter", ii+1, " ".join(for_print))

iter 10 [CLS] typhoon one [MASK] [MASK] [MASK] disperse myriad [MASK] , daily (*) to [MASK] berry [MASK] . [SEP]
iter 20 [CLS] typhoon pep [MASK] support (*) fishery disperse myriad seed , daily , all berry [MASK] . [SEP]
iter 30 [CLS] typhoon pep ##in has fishery for (*) myriad species , daily for all berry signed . [SEP]
iter 40 [CLS] typhoon mats ##in essentially (*) covered for the animals , and almost all the signed . [SEP]
iter 50 [CLS] the ca ##ul ##kers covered all the fee due and almost all (*) contracts signed . [SEP]
iter 60 [CLS] the ca ##ul ##kers had gotten their orders due and had all been (*) canceled . [SEP]
iter 70 [CLS] the ca ##ul ##kers had learned their (*) rent due ##s had all been canceled . [SEP]
iter 80 [CLS] the ca ##ul ##kers also claimed their rent due ##s had not been canceled . (*) [SEP]
iter 90 [CLS] the ca ##ul ##kers (*) also claimed their rent due ##s had not been paid . [SEP]
iter 100 [CLS] the ca ##ul ##kers also revealed the rent due ##s had not be

KeyboardInterrupt: 

In [None]:
tokenizer.vocab.keys()