In [3]:
import numpy
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [5]:
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-large-uncased')
model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertIntermedia

In [73]:
# Load pre-trained model tokenizer (vocabulary)
CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'

mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([MASK])[0]
cls_id = tokenizer.convert_tokens_to_ids([MASK])[0]

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)

In [76]:
''' Generation modes as functions '''

def generate(out, gen_idx, temperature=None, top_k=0, sample=False):
    """ Generate a word from from out[gen_idx]
    
    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k 
    """
    logits = out[0, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx[dist.sample().item()].item()
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().item()
    else:
        idx = torch.max(logits,0)[1].item()
    return idx

def get_init_text(seed_text, max_len, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    init_text = seed_text + [MASK] * max_len + [SEP]
    if rand_init:
        for ii in range(max_len):
            init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))
    
    return tokenizer.convert_tokens_to_ids(init_text)

def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200, print_every=10):
    """ Generate for one random position at a timestep
    
    args:
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
    """
    seed_len = len(seed_text)
    init_idx = get_init_text(seed_text, max_len)
    
    for ii in range(max_iter):
        kk = numpy.random.randint(0, max_len)
        init_idx[seed_len+kk] = mask_id
        out = model(torch.tensor([init_idx]))
        init_idx[seed_len+kk] = generate(out, gen_idx=seed_len+kk, top_k=top_k, sample=(ii < burnin))

        if numpy.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(init_idx)
            for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
            print("iter", ii+1, " ".join(for_print))
    return

def parallel_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True, print_every=10):
    """ Generate for all positions at a time step """
    seed_len = len(seed_text)
    init_idx = get_init_text(seed_text, max_len)
    
    for ii in range(max_iter):
        out = model(torch.tensor([init_idx]))
        for kk in range(max_len):
            init_idx[seed_len+kk] = generate(out, gen_idx=seed_len+kk, top_k=top_k, sample=sample)
            
        if numpy.mod(ii, print_every) == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(init_idx)))
    return
            
def sequential_generation(seed_text, max_len=15, leed_out_len=15, top_k=0, temperature=None, sample=True):
    """ Generate one word at a time, in L->R order """
    seed_len = len(seed_text)
    init_idx = get_init_text(seed_text, max_len)

    for ii in range(max_len):
        inp = [init_idx[:seed_len+ii+leed_out_len]+[sep_id]]
        out = model(torch.tensor(inp))
        init_idx[seed_len+ii] = generate(out, gen_idx=seed_len+ii, top_k=top_k, sample=sample)

    print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)))
    return

In [79]:
n_sample = 1
max_len = 20
top_k = 0
temperature=5.

leed_out_len = 5 # max_len
burnin = 200
sample = True
max_iter = 300

seed_text = "[CLS]".split()
init_idx = get_init_text(seed_text, max_len)
for _ in range(n_sample):
    parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter)
    #sequential_generation(seed_text, max_len=max_len, top_k=top_k, temperature=temperature, leed_out_len=leed_out_len, sample=sample)
    #parallel_generation(seed_text, max_len=max_len, top_k=top_k, temperature=temperature, sample=sample, max_iter=max_iter)

iter 10 [CLS] [MASK] son the chinese [MASK] [MASK] [MASK] [MASK] [MASK] them 6 children (*) cai [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [SEP]
iter 20 [CLS] [MASK] married the heiress [MASK] [MASK] . [MASK] [MASK] had 6 children including [MASK] elisabeth (*) [MASK] , later empress [MASK] [SEP]
iter 30 [CLS] [MASK] married an heiress (*) , elizabeth , [MASK] [MASK] had various children including [MASK] mary [MASK] , later empress [MASK] [SEP]
iter 40 [CLS] [MASK] married an heiress , alice , too who had various children including [MASK] and elizabeth (*) , later queen [MASK] [SEP]
iter 50 [CLS] james married an (*) heiress named alice , too who had three children including [MASK] and elizabeth , later queen [MASK] [SEP]
iter 60 [CLS] james married (*) an heiress named anna , too who had three children , henry and margaret , future queen . [SEP]
iter 70 [CLS] james married an heiress , katherine , of who were three children , henry and mary , later (*) queen . [SEP]
iter 80 [CLS

In [7]:
import copy

original_sent = [CLS] + 'new york is the greatest city in the world . '.lower().split() + [SEP]

for ii_ in range(len(original_sent)-2):
    ii = ii_ + 1
    new_sent = copy.copy(original_sent)
    new_sent[ii] = '[MASK]'
#     new_sent[ii] = tokenizer.convert_ids_to_tokens([numpy.random.randint(0, len(tokenizer.vocab))])[0]
    out = model(torch.tensor([tokenizer.convert_tokens_to_ids(new_sent)]))
    pred = tokenizer.convert_ids_to_tokens([out[0][ii].max(0)[1].item()])[0]
    probs = out[0][ii].data.numpy()
    rank = len(tokenizer.vocab) - numpy.argsort(numpy.argsort(probs))[tokenizer.convert_tokens_to_ids([original_sent[ii]])[0]]
    print(" ".join(new_sent), "=>", pred, '|||', 'rank of', original_sent[ii], rank)
#     if pred == 'the':
#         break

[CLS] [MASK] york is the greatest city in the world . [SEP] => new ||| rank of new 1
[CLS] new [MASK] is the greatest city in the world . [SEP] => york ||| rank of york 1
[CLS] new york [MASK] the greatest city in the world . [SEP] => is ||| rank of is 1
[CLS] new york is [MASK] greatest city in the world . [SEP] => the ||| rank of the 1
[CLS] new york is the [MASK] city in the world . [SEP] => largest ||| rank of greatest 15
[CLS] new york is the greatest [MASK] in the world . [SEP] => city ||| rank of city 1
[CLS] new york is the greatest city [MASK] the world . [SEP] => in ||| rank of in 1
[CLS] new york is the greatest city in [MASK] world . [SEP] => the ||| rank of the 1
[CLS] new york is the greatest city in the [MASK] . [SEP] => world ||| rank of world 1
[CLS] new york is the greatest city in the world [MASK] [SEP] => . ||| rank of . 1


In [67]:
''' sequential generation: this one kinda works '''

sep_id = tokenizer.convert_tokens_to_ids([SEP])
sample = True
max_len = 20
leed_out_len = 5 #max_len
random_future = False
top_k = 100 # set it to 0 if you don't want top_k
n_samples = 10

seed_text = [CLS] # + "this is".split()
seed_len = len(seed_text)

for si in range(n_samples):
    init_text = seed_text + ['[MASK]'] * max_len
    init_idx = tokenizer.convert_tokens_to_ids(init_text)
    if random_future:
        for ii in range(max_len):
            init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

    for ii in range(max_len):
        out = model(torch.tensor([init_idx[:seed_len+ii+leed_out_len]+sep_id]))
        if top_k > 0:
            logits = out[0,seed_len+ii]
            kth_vals, kth_idx = logits.topk(top_k)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            init_idx[seed_len+ii] = kth_idx[dist.sample().item()].item()
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+ii])
                init_idx[seed_len+ii] = dist.sample().item()
            else:
                init_idx[seed_len+ii] = torch.max(out[0, seed_len+ii],0)[1].item()

#     print(init_idx)
    print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)))
# print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)).replace(" ##", ""))

[CLS] & von berg ##h . & von berg ##h . & von berg ##h . & von berg ##h .
[CLS] and a dozen hundred thousand people . and ten thousand more , for four thousand thousand thousand more cities .
[CLS] von der mann . von den man . von der mann . von die mann . von den man .
[CLS] and how low the price was . and how many girls had died for a couple of stupid lies .
[CLS] ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ;
[CLS] and the thing had been thrown along the ground in its own small , white , or ##c wake .
[CLS] and three dozen dozen of the great many thousand of the four hundred thousand of the many thousand thousand .
[CLS] 8 . 40 ##m . 4 . 39 ##m . 5 . 02 ##m . 7 . 74 ##m .
[CLS] ( in ( 3 . 13 . 11 . 17 . 17 ) ; # 1 ; 7 ) .
[CLS] ( in : int . social sciences and social issues . roman ##a : academia del soc ##orro ) .


In [362]:
''' parallel generation: this one doesn't work '''

sample = True
max_iter = 100
viz_int = 10
max_len = 20
top_k = 0

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * max_len + ['[SEP]']
init_idx = tokenizer.convert_tokens_to_ids(init_text)
# for ii in range(max_len):
#     init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    out = model(torch.tensor([init_idx]))
    for kk in range(max_len):
        if top_k > 0:
            logits = out[0,seed_len+kk]
            kth_vals, kth_idx = logits.topk(top_k)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
                init_idx[seed_len+kk] = dist.sample().item()
            else:
                init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
    if numpy.mod(ii, viz_int) == 0:
        print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(init_idx)))

iter 1 [CLS] philippine " ##hara ##id on mir by character sons five god with the , ; for a fatal ##in ; [SEP]
iter 11 [CLS] 2 m be ##h on ##r by the to . god with . aid ; for present definite h . [SEP]
iter 21 [CLS] 2 ##m be ##h on ##r by the to . god with . aid ; for present definite h . [SEP]
iter 31 [CLS] 2 ##m be ##h on ##s by the to . god with . aid ; for present definite h . [SEP]
iter 41 [CLS] 2 ##m be ##h on ze by the to . god with . aid ; for which definite h . [SEP]
iter 51 [CLS] 2 ##m be ##h on - by the to . god with . aid ; p or an h . [SEP]
iter 61 [CLS] 2 ##m be ##h on made by the to . god with . help the p or an h . [SEP]
iter 71 [CLS] 2 ##b be ##h on made by the to . god with . help the p or an h . [SEP]
iter 81 [CLS] 2 ##b be ##h on made by the to . god with . help the p or an h . [SEP]
iter 91 [CLS] 2 ##b be ##h is made by the to . god with . help the p or an h . [SEP]


In [27]:
''' parallel-sequential generation: this one definitely works '''

# sample = True
burnin = 200
max_iter = 300
viz_int = 10
max_len = 15
top_k = 0

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * (max_len) + ['[SEP]']
init_idx = tokenizer.convert_tokens_to_ids(init_text)
#for ii in range(max_len):
#    init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    kk = numpy.random.randint(0, max_len)
    init_idx[seed_len+kk] = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
    out = model(torch.tensor([init_idx]))
    if top_k > 0:
        logits = out[0,seed_len+kk]
        kth_vals, kth_idx = logits.topk(top_k)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
    else:
        if ii < burnin:
            dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
            init_idx[seed_len+kk] = dist.sample().item()
        else:
            init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
        
    if numpy.mod(ii+1, viz_int) == 0:
        for_print = tokenizer.convert_ids_to_tokens(init_idx)
        for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
        print("iter", ii+1, " ".join(for_print))

iter 10 [CLS] un (*) ##i [MASK] [MASK] [MASK] ; [MASK] . [MASK] ##i [MASK] : ; [MASK] [MASK] [SEP]
iter 20 [CLS] xx ##ix [MASK] ; [MASK] . [MASK] . (*) [MASK] . [MASK] 2 ; b [MASK] [SEP]
iter 30 [CLS] vi (*) . [MASK] ; [MASK] . [MASK] . 17 . § 2 : ii . [SEP]
iter 40 [CLS] iii . 11 ; norway . iii . 17 . § 87 (*) . 1 . [SEP]
iter 50 [CLS] iii . (*) sweden . norway § 11 . 17 & § 87 . 12 . [SEP]
iter 60 [CLS] iii . denmark (*) & norway § 11 . 17 . § 87 . 20 . [SEP]
iter 70 [CLS] iii . denmark & norway § 87 . 17 ; § 87 . (*) 20 ; [SEP]
iter 80 [CLS] 4 . denmark & norway § 87 (*) . 17 ; § 87 . 20 ; [SEP]
iter 90 [CLS] 4 . denmark & sweden § (*) 85 . 11 ; § 87 . 20 ; [SEP]
iter 100 [CLS] 4 - denmark & norway (*) § 86 . 6 ; § 86 . 2 ; [SEP]
iter 110 [CLS] cf . denmark - (*) norway § 86 . 5 ; § 86 . 7 ; [SEP]
iter 120 [CLS] cf . denmark - schleswig (*) § 1886 . 5 , § 86 . 1 ; [SEP]
iter 130 [CLS] cf . denmark v schleswig § 86 . (*) 5 , § 86 . 1 ; [SEP]
iter 140 [CLS] cf . (*) denmark ser . § 86