In [30]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [4]:
# Load pre-trained model (weights)
model_version = 'bert-large-cased'
model = BertForMaskedLM.from_pretrained(model_version)
model.eval()

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=model_version.endswith("uncased"))

def tokenize_batch(batch):
    return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]

def untokenize_batch(batch):
    return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece) """
    new_sent = []
    for i, tok in enumerate(sent):
        if tok.startswith("##"):
            new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:]
        else:
            new_sent.append(tok)
    return new_sent

CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'
mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([MASK])[0]
cls_id = tokenizer.convert_tokens_to_ids([MASK])[0]

# Generations

In [31]:
''' Generation modes as functions '''

def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False):
    """ Generate a word from from out[gen_idx]
    
    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k 
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1).tolist()
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().squeeze(-1).tolist()
    else:
        idx = torch.argmax(logits, dim=-1).tolist()
    return idx

def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]
    #if rand_init:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))
    
    return tokenize_batch(batch)

def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   print_every=10, verbose=True):
    """ Generate for one random position at a timestep
    
    args:
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
    """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        kk = np.random.randint(0, max_len)
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = mask_id
        out = model(torch.tensor(batch))
        idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, sample=(ii < burnin))
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(batch[0])
            for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
            print("iter", ii+1, " ".join(for_print))
            
    return untokenize_batch(batch)

def parallel_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True, 
                        print_every=10, verbose=True):
    """ Generate for all positions at a time step """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        out = model(torch.tensor(batch))
        for kk in range(max_len):
            idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, sample=sample)
            for jj in range(batch_size):
                batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii, print_every) == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(batch[0])))
    
    return untokenize_batch(batch)
            
def sequential_generation(seed_text, batch_size=2, max_len=15, leed_out_len=15, 
                          top_k=0, temperature=None, sample=True):
    """ Generate one word at a time, in L->R order """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)

    for ii in range(max_len):
        inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]
        out = model(torch.tensor(inp))
        idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, sample=sample)
        for jj in range(batch_size):
            batch[jj][seed_len+ii] = idxs[jj]
        
        return untokenize_batch(batch)

In [7]:
import math
import time

def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25, 
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             print_every=1):
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        batch = parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k,
                                               temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                               verbose=False)
        
        #batch = sequential_generation(seed_text, batch_size=20, max_len=max_len, top_k=top_k, temperature=temperature, leed_out_len=leed_out_len, sample=sample)
        #batch = parallel_generation(seed_text, max_len=max_len, top_k=top_k, temperature=temperature, sample=sample, max_iter=max_iter)
        
        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()
        
        sentences += batch
    return sentences

def printer(sent, should_detokenize=True):
    if should_detokenize:
        sent = detokenize(sent)
    print(" ".join(sent[1:-1]))

In [93]:
n_samples = 500
batch_size = 25
max_len = 20
top_k = 100
temperature= 1.0

leed_out_len = 5 # max_len
burnin = 200
sample = True
max_iter = 500

# Choose the prefix context
seed_text = "[CLS]".split()

sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                 sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter)

Finished batch 0 in 1035.656s


KeyboardInterrupt: 

In [93]:
sent_file = 'generations-len20-burnin200-temp0.700.txt'
sents = [detokenize(sent.strip().split()) for sent in open(sent_file).readlines()]

In [9]:
for i in range(50):
    printer(sents[i], should_detokenize=False)

for men . men with beards who ploughed up and down . who worked hard
towns include the large upper borre valley ( the val de borre ) and borre
of the cops are savvy , " he said . " they found the murder weapon
names , dates , holidays , birthdays . where to next ? " the response was immediate
. hoo . com . " women in business : a survey and annual report "
, my dad was the one guy who pleaded with him to cut me off all the time
video concert 1 - 16 / 2006 . tv video concert 2 - 16 / 2006 . dvd
kit , bass , snare , tuba , trombone , cornet , trumpet , french horn
flowers are distinctively yellow . subspecies heliopsis hookeri ( golden sunflower ) subsp
mother and her new husband filled the house with pillows and blankets and drunken one - night stands
which one ? ) esta agua nadie lo que no abrio el ano
third - and his final - two books , bone to bone ii , were a national bestseller
can see that they have hair and that they wear white clothes with red and black on them
the soil , bacter

# Evaluation

In [11]:
from nltk.translate import bleu_score as bleu

## Quality Measures

How similar are the generated sentences to the original training data (Toronto Book Corpus and Wikipedia dumps). We follow Yu et al., () and compute the BLEU between the generations and the test sets of both corpora by treating the test set as the references for each generation. The tests sets are large; we subsample 5000 examples from each.

In [12]:
def prepare_data(data_file, replacements={}, uncased=True):
    data = [d.strip().split() for d in open(data_file, 'r').readlines()]
    if uncased:
        data = [[t.lower() for t in sent] for sent in data]
        
    for k, v in replacements.items():
        data = [[t if t != k else v for t in sent] for sent in data]
 
    return data

def prepare_wiki(data_file, uncased=True):
    replacements = {"@@unknown@@": "[UNK]"}
    return prepare_data(data_file, replacements=replacements, uncased=uncased)

def prepare_tbc(data_file):        
    replacements = {"``": "\"", "\'\'": "\""}
    return prepare_data(data_file, replacements=replacements)

def corpus_bleu(generated, references):
    """ Compute similarity between two corpora as measured by
    comparing each sentence of `generated` against all sentences in `references` 
    
    args:
        - generated (List[List[str]]): list of sentences (split into tokens)
        - references (List[List[str]]): list of sentences (split into tokens)
        
    returns:
        - bleu (float)
    """    
    return bleu.corpus_bleu([references for _ in range(len(generated))], generated)

In [13]:
wiki103_file = 'data/wiki103.5k.txt'
tbc_file = 'data/tbc.5k.txt'

wiki_data = prepare_wiki(wiki103_file)
tbc_data = prepare_tbc(tbc_file)
#sents = [detokenize(sent) for sent in sents]

In [122]:
print("BERT-TBC BLEU: %.2f" % (100 * corpus_bleu(sents, tbc_data)))
print("BERT-Wiki103 BLEU: %.2f" % (100 * corpus_bleu(sents, wiki_data)))
print("BERT-{TBC + Wiki103} BLEU: %.2f" % (100 * corpus_bleu(sents, tbc_data[:2500] + wiki_data[:2500])))

BERT-TBC BLEU: 8.30
BERT-Wiki103 BLEU: 11.00
BERT-{TBC + Wiki103} BLEU: 11.29


## Perplexity of a Trained LM

A proxy for measuring the fluency of the generated sentences is computing the perplexity of a well-trained language model. We'll use the Gated Convolution Language Model from Dauphin et al. (2016). The model we're using was trained on WikiText-103 (Merity et al., 2016), so it's biased in favor of BERT generations.

## Comparing to existing models

The OpenAI Generative Pretraining Transformer is another pretrained model successfully used for transfer learning. Since the model is a unidirectional language model, we can straightforwardly generate from the model. See [this repo](https://github.com/huggingface/pytorch-openai-transformer-lm) by Thomas Wolf at Huggingface for instructions for setting up the model.

In [114]:
import os
import sys
sys.path.insert(1, os.path.join(".", "pytorch-openai-transformer-lm"))

from model_pytorch import LMModel, load_openai_pretrained_model, DEFAULT_CONFIG
from text_utils import TextEncoder

def load_openai_gpt(n_special=1, n_ctx=512):
    text_encoder = TextEncoder("pytorch-openai-transformer-lm/model/encoder_bpe_40000.json", 
                               "pytorch-openai-transformer-lm/model/vocab_40000.bpe")
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)
    vocab = n_vocab + n_special + n_ctx

    args = DEFAULT_CONFIG
    lm_model = LMModel(args, vocab, n_ctx, return_probs=True)
    load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special,
                                 path="pytorch-openai-transformer-lm/model/",
                                 path_names="pytorch-openai-transformer-lm/")
    #lm_model.to(device)
    lm_model.eval()
    return lm_model, text_encoder

def make_batch(X, n_vocab, n_special):
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        X = np.expand_dims(X, axis=0)
    pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
    pos_enc = np.expand_dims(pos_enc, axis=0)
    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long)#.to(device)
    return batch

def append_batch(X, next_idx):
    next_pos = X[:, -1:, 1] + 1
    next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
    return torch.cat((X, next_x), 1)

def generate_sentence_openai(model, text_encoder, seed_text, gen_len=20, topk=100, 
                             n_vocab=40478, n_special=0, verbose=False):
    X = [[n_vocab - 1]]
    if seed_text:
        seed_ids = text_encoder.encode([seed_text,])
        X = [X[0] + seed_ids[0]]
        
    n_vocab = len(text_encoder.encoder)
    XMB = make_batch(X, n_vocab, n_special)
    sent = [seed_text]

    for _ in range(gen_len):
        lm_probs = model(XMB)
        if topk == 0:
            next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
        else:
            values, indices = lm_probs[:, -1, :].topk(topk)
            next_idx = indices.gather(-1, torch.multinomial(values, 1))
        next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
        sent.append(next_token)
        if verbose:
            print(next_token, end=' ')
        XMB = append_batch(XMB, next_idx)
        
    return [tok for tok in sent if tok != '\n']

def generate_openai(model, text_encoder, n_samples, seed_text, gen_len=20, topk=100, 
                    n_vocab=40478, n_special=0, print_every=10):
    sents = []
    start_time = time.time()
    for sample_n in range(n_samples):
        sent = generate_sentence_openai(model, text_encoder, seed_text,
                                        gen_len=gen_len, topk=topk, 
                                        n_vocab=n_vocab, n_special=n_special,
                                        verbose=False)
        sents.append(sent)
        if (sample_n + 1) % print_every == 0:
            print("Generated %d (#%d) of %d sentences in %.3fs" % (print_every, sample_n + 1, n_samples, time.time() - start_time))
            start_time = time.time()
    return sents

In [24]:
gpt_model, gpt_text_encoder = load_openai_gpt()

Loading weights...


In [115]:
openai_sents = generate_openai(gpt_model, gpt_text_encoder, 100, "", print_every=10)

Generated 10 (#10) of 100 sentences in 12.369s
Generated 10 (#20) of 100 sentences in 11.622s
Generated 10 (#30) of 100 sentences in 11.162s
Generated 10 (#40) of 100 sentences in 10.635s
Generated 10 (#50) of 100 sentences in 11.085s
Generated 10 (#60) of 100 sentences in 10.553s
Generated 10 (#70) of 100 sentences in 10.563s
Generated 10 (#80) of 100 sentences in 10.687s
Generated 10 (#90) of 100 sentences in 11.073s
Generated 10 (#100) of 100 sentences in 11.521s


In [118]:
for i in range(20):
    print(" ".join(openai_sents[i]))

 " because ? " " just tell me . " " just tell you how to get in
 " we will ... be ... prepared here . " the queen 's voice rang out , so loud and
 he tried to pull away but she would n't let him go . she wanted him now . she 'd
 " have i been sent here for some reason , brother ? " he walked to a spot and sat
 " yes , ' tis right , " she agreed . " though you - " she looked from one
 " hey ! " the girl shrieked . " do n't you even think about it . you know what
 " what 's the matter ? " she asked , genuinely concerned . " been a bad day ? "
 she shook her head . " i think i 'm going to get some sleep . " " you
 " are you sure ? " he held out his hand . she shook her head , suddenly exhausted
 " you are the most interesting person i 've ever met . i do n't know how you manage to
 " we know the names of everyone involved , " said marante . " we know their families , histories
 " i might have figured that one out myself , " ty muttered . he pushed himself upright and rested
 " fine , " he

In [124]:
print("GPT-TBC BLEU: %.2f" % (100 * corpus_bleu(openai_sents, tbc_data)))
print("GPT-Wiki103 BLEU: %.2f" % (100 * corpus_bleu(openai_sents, wiki_data)))
print("GPT-{TBC + Wiki103} BLEU: %.2f" % (100 * corpus_bleu(openai_sents, tbc_data[:2500] + wiki_data[:2500])))

GPT-TBC BLEU: 30.93
GPT-Wiki103 BLEU: 12.55
GPT-{TBC + Wiki103} BLEU: 26.64


## Diversity Measures

Self-BLEU: treat each sentence as a hypothesis and treat rest of corpus as reference. Lower is better.

In [67]:
def self_bleu(sents):
    return bleu.corpus_bleu([[s for (j, s) in enumerate(sents) if j != i] for i in range(len(sents))], sents)

def count_ngrams(max_n=4):
    raise NotImplementedError

In [123]:
print("BERT self-BLEU: %.2f" % (100 * self_bleu(sents)))
print("OpenAI self-BLEU: %.2f" % (100 * self_bleu(openai_sents)))

Self-BLEU: 9.59
Self-BLEU: 19.96



### Scratch ###

In [16]:
# Quality measure via outside language models

# KN5 (KenLM)
# pip install https://github.com/kpu/kenlm/archive/master.zip

# Gated Convolutional LM (Fairseq)
# https://github.com/pytorch/fairseq/blob/master/examples/language_model/README.md

# OpenAI Generative Pretraining LM
# https://github.com/huggingface/pytorch-openai-transformer-lm

In [31]:
STR = "A man of wordly wealth, Sansom was primarily a business man but was also a politician."
" ".join(detokenize(tokenizer.tokenize(STR)))

'A man of wordly wealth , Sansom was primarily a business man but was also a politician .'

In [5]:
""" Get some generations """
import time

n_sample = 500
max_len = 20
top_k = 0
temperature = 1.
burnin = 200
max_iter = 400
print_every = 25

for top_k in [100]:
    for temp in [.1, .5, .7, 2.]:
        if top_k:
            out_file = "generations-len%d-topk%d-temp%.3f.txt" % (max_len, top_k, temp)
        else:
            out_file = "generations-len%d-burnin%d-temp%.3f.txt" % (max_len, burnin, temp)

        times = []
        with open(out_file, "w") as out_fh:
            start_time = time.time()
            for step_n in range(n_sample):
                seed_text = "[CLS]".split()
                sent = parallel_sequential_generation(seed_text, max_len=max_len, 
                                                      top_k=top_k, temperature=temp, 
                                                      burnin=burnin, max_iter=max_iter,
                                                      verbose=False)
                out_fh.write("%s\n" % " ".join(sent[1:-1]))
                times.append(time.time() - start_time)
                start_time = time.time()
                if (step_n + 1) % print_every == 0:
                    print("Generated sentence %d in %.3fs" % (step_n + 1, times[-1]))

        print("Generated %d sentences in %.3fm (~%.3fs/sentence)" % (n_sample, sum(times) / 60, sum(times) / len(times)))

Generated sentence 25 in 60.228s
Generated sentence 50 in 57.656s
Generated sentence 75 in 57.172s
Generated sentence 100 in 57.339s
Generated sentence 125 in 58.277s
Generated sentence 150 in 58.121s
Generated sentence 175 in 58.085s
Generated sentence 200 in 57.711s
Generated sentence 225 in 57.998s
Generated sentence 250 in 57.194s
Generated sentence 275 in 57.329s
Generated sentence 300 in 58.291s
Generated sentence 325 in 57.264s
Generated sentence 350 in 57.242s
Generated sentence 375 in 57.198s
Generated sentence 400 in 57.487s
Generated sentence 425 in 62.306s
Generated sentence 450 in 57.114s
Generated sentence 475 in 57.273s
Generated sentence 500 in 57.203s
Generated 500 sentences in 482.908m (~57.949s/sentence)
Generated sentence 25 in 58.081s
Generated sentence 50 in 58.023s
Generated sentence 75 in 57.322s
Generated sentence 100 in 58.552s
Generated sentence 125 in 58.004s
Generated sentence 150 in 57.925s
Generated sentence 175 in 57.623s
Generated sentence 200 in 58.549

In [7]:
import copy

original_sent = [CLS] + 'new york is the greatest city in the world . '.lower().split() + [SEP]

for ii_ in range(len(original_sent)-2):
    ii = ii_ + 1
    new_sent = copy.copy(original_sent)
    new_sent[ii] = '[MASK]'
#     new_sent[ii] = tokenizer.convert_ids_to_tokens([numpy.random.randint(0, len(tokenizer.vocab))])[0]
    out = model(torch.tensor([tokenizer.convert_tokens_to_ids(new_sent)]))
    pred = tokenizer.convert_ids_to_tokens([out[0][ii].max(0)[1].item()])[0]
    probs = out[0][ii].data.numpy()
    rank = len(tokenizer.vocab) - numpy.argsort(numpy.argsort(probs))[tokenizer.convert_tokens_to_ids([original_sent[ii]])[0]]
    print(" ".join(new_sent), "=>", pred, '|||', 'rank of', original_sent[ii], rank)
#     if pred == 'the':
#         break

[CLS] [MASK] york is the greatest city in the world . [SEP] => new ||| rank of new 1
[CLS] new [MASK] is the greatest city in the world . [SEP] => york ||| rank of york 1
[CLS] new york [MASK] the greatest city in the world . [SEP] => is ||| rank of is 1
[CLS] new york is [MASK] greatest city in the world . [SEP] => the ||| rank of the 1
[CLS] new york is the [MASK] city in the world . [SEP] => largest ||| rank of greatest 15
[CLS] new york is the greatest [MASK] in the world . [SEP] => city ||| rank of city 1
[CLS] new york is the greatest city [MASK] the world . [SEP] => in ||| rank of in 1
[CLS] new york is the greatest city in [MASK] world . [SEP] => the ||| rank of the 1
[CLS] new york is the greatest city in the [MASK] . [SEP] => world ||| rank of world 1
[CLS] new york is the greatest city in the world [MASK] [SEP] => . ||| rank of . 1


In [33]:
batch_size = 10

In [78]:
''' sequential generation: this one kinda works '''


sep_id = tokenizer.convert_tokens_to_ids([SEP])
sample = True
max_len = 20
leed_out_len = 5 #max_len
random_future = False
top_k = 100 # set it to 0 if you don't want top_k
n_samples = 1

seed_text = [[CLS] for _ in range(batch_size)]
seed_len = len(seed_text[0])

for si in range(n_samples):
    #init_text = seed_text + ['[MASK]'] * max_len
    init_text = [seed + ['[MASK]'] * max_len for seed in seed_text]
    init_idx = tokenize_batch(init_text) #tokenizer.convert_tokens_to_ids(init_text)
    #if random_future:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

    for ii in range(max_len):
        out = model(torch.tensor([i[:seed_len+ii+leed_out_len]+sep_id for i in init_idx]))
        if top_k > 0:
            logits = out[:,seed_len+ii]
            kth_vals, kth_idx = logits.topk(top_k, dim=1)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            new_idxs = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1).tolist()
            for jj in range(len(init_idx)):
                init_idx[jj][ii] = new_idxs[jj]
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+ii])
                init_idx[seed_len+ii] = dist.sample().item()
            else:
                init_idx[seed_len+ii] = torch.max(out[0, seed_len+ii],0)[1].item()

#     print(init_idx)
    for sent in init_idx:
        print(" ".join(tokenizer.convert_ids_to_tokens(sent)))
# print(" ".join(tokenizer.convert_ids_to_tokens(init_idx)).replace(" ##", ""))

" . . . . . , . . , . . . . . , . . . . [MASK]
" and Felix - ( - ) = + . - = . he = . - = . . [MASK]
the and formula ##s ; , and - algebra ; ; . . . . . ; . . . [MASK]
/ was . by . . . . . . and , from , gave . . . . . [MASK]
* by army use of and as the of were applied as ( ( ) , and = ) . [MASK]
. and ##i . . , . . . and ... , . . , to the part , . [MASK]
. you ; ; ; ; ; ' Mr . Scott to the and ##ra of the . " . [MASK]
. . ##1 . : and . . . : . . : . . . . . . . [MASK]
king - as of 2014 . | . / _ . / < < - | . / > | [MASK]
. ##2 = . . = = = = - . = = = = ( = = ) | [MASK]


In [76]:
new_idxs

[[119], [119], [1103], [170], [168], [1110], [119], [176], [119], [119]]

In [362]:
''' parallel generation: this one doesn't work '''

sample = True
max_iter = 100
viz_int = 10
max_len = 20
top_k = 0

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * max_len + ['[SEP]']
init_idx = tokenizer.convert_tokens_to_ids(init_text)
# for ii in range(max_len):
#     init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    out = model(torch.tensor([init_idx]))
    for kk in range(max_len):
        if top_k > 0:
            logits = out[0,seed_len+kk]
            kth_vals, kth_idx = logits.topk(top_k)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
        else:
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
                init_idx[seed_len+kk] = dist.sample().item()
            else:
                init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
    if numpy.mod(ii, viz_int) == 0:
        print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(init_idx)))

iter 1 [CLS] philippine " ##hara ##id on mir by character sons five god with the , ; for a fatal ##in ; [SEP]
iter 11 [CLS] 2 m be ##h on ##r by the to . god with . aid ; for present definite h . [SEP]
iter 21 [CLS] 2 ##m be ##h on ##r by the to . god with . aid ; for present definite h . [SEP]
iter 31 [CLS] 2 ##m be ##h on ##s by the to . god with . aid ; for present definite h . [SEP]
iter 41 [CLS] 2 ##m be ##h on ze by the to . god with . aid ; for which definite h . [SEP]
iter 51 [CLS] 2 ##m be ##h on - by the to . god with . aid ; p or an h . [SEP]
iter 61 [CLS] 2 ##m be ##h on made by the to . god with . help the p or an h . [SEP]
iter 71 [CLS] 2 ##b be ##h on made by the to . god with . help the p or an h . [SEP]
iter 81 [CLS] 2 ##b be ##h on made by the to . god with . help the p or an h . [SEP]
iter 91 [CLS] 2 ##b be ##h is made by the to . god with . help the p or an h . [SEP]


In [27]:
''' parallel-sequential generation: this one definitely works '''

# sample = True
burnin = 200
max_iter = 300
viz_int = 10
max_len = 15
top_k = 0

seed_text = '[CLS]'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * (max_len) + ['[SEP]']
init_idx = tokenizer.convert_tokens_to_ids(init_text)
#for ii in range(max_len):
#    init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

for ii in range(max_iter):
    kk = numpy.random.randint(0, max_len)
    init_idx[seed_len+kk] = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
    out = model(torch.tensor([init_idx]))
    if top_k > 0:
        logits = out[0,seed_len+kk]
        kth_vals, kth_idx = logits.topk(top_k)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        init_idx[seed_len+kk] = kth_idx[dist.sample().item()].item()
    else:
        if ii < burnin:
            dist = torch.distributions.categorical.Categorical(logits=out[0,seed_len+kk])
            init_idx[seed_len+kk] = dist.sample().item()
        else:
            init_idx[seed_len+kk] = torch.max(out[0, seed_len+kk],0)[1].item()
        
    if numpy.mod(ii+1, viz_int) == 0:
        for_print = tokenizer.convert_ids_to_tokens(init_idx)
        for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
        print("iter", ii+1, " ".join(for_print))

iter 10 [CLS] un (*) ##i [MASK] [MASK] [MASK] ; [MASK] . [MASK] ##i [MASK] : ; [MASK] [MASK] [SEP]
iter 20 [CLS] xx ##ix [MASK] ; [MASK] . [MASK] . (*) [MASK] . [MASK] 2 ; b [MASK] [SEP]
iter 30 [CLS] vi (*) . [MASK] ; [MASK] . [MASK] . 17 . § 2 : ii . [SEP]
iter 40 [CLS] iii . 11 ; norway . iii . 17 . § 87 (*) . 1 . [SEP]
iter 50 [CLS] iii . (*) sweden . norway § 11 . 17 & § 87 . 12 . [SEP]
iter 60 [CLS] iii . denmark (*) & norway § 11 . 17 . § 87 . 20 . [SEP]
iter 70 [CLS] iii . denmark & norway § 87 . 17 ; § 87 . (*) 20 ; [SEP]
iter 80 [CLS] 4 . denmark & norway § 87 (*) . 17 ; § 87 . 20 ; [SEP]
iter 90 [CLS] 4 . denmark & sweden § (*) 85 . 11 ; § 87 . 20 ; [SEP]
iter 100 [CLS] 4 - denmark & norway (*) § 86 . 6 ; § 86 . 2 ; [SEP]
iter 110 [CLS] cf . denmark - (*) norway § 86 . 5 ; § 86 . 7 ; [SEP]
iter 120 [CLS] cf . denmark - schleswig (*) § 1886 . 5 , § 86 . 1 ; [SEP]
iter 130 [CLS] cf . denmark v schleswig § 86 . (*) 5 , § 86 . 1 ; [SEP]
iter 140 [CLS] cf . (*) denmark ser . § 86