In [2]:
!pip install pytorch_pretrained_bert

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |██▋                             | 10kB 26.2MB/s eta 0:00:01[K     |█████▎                          | 20kB 1.7MB/s eta 0:00:01[K     |████████                        | 30kB 2.5MB/s eta 0:00:01[K     |██████████▋                     | 40kB 1.7MB/s eta 0:00:01[K     |█████████████▎                  | 51kB 2.1MB/s eta 0:00:01[K     |███████████████▉                | 61kB 2.5MB/s eta 0:00:01[K     |██████████████████▌             | 71kB 2.9MB/s eta 0:00:01[K     |█████████████████████▏          | 81kB 3.3MB/s eta 0:00:01[K     |███████████████████████▉        | 92kB 3.7MB/s eta 0:00:01[K     |██████████████████████████▌     | 102kB 2.8MB/s eta 0:00:01[K     |█████████████████████████████▏  | 112kB 2.8MB/s eta 0:00:01[K     |██████████████████████

In [3]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet

import math
import time
import random


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [4]:
# Pre-trained masked-LM weights, switched to evaluation mode right after loading.
model_version = 'bert-base-uncased'
model = BertForMaskedLM.from_pretrained(model_version).eval()

# Move the model to the first GPU when one is available; `cuda` records that choice.
cuda = torch.cuda.is_available()
model = model.cuda(0) if cuda else model

# Matching tokenizer (vocabulary); lower-casing is enabled only for "uncased" checkpoints.
tokenizer = BertTokenizer.from_pretrained(
    model_version,
    do_lower_case=model_version.endswith("uncased"),
)

def tokenize_batch(batch):
    """Convert every token sequence in `batch` to a list of vocabulary ids.

    Uses the module-level `tokenizer`; returns a new list with one id list per sentence.
    """
    to_ids = tokenizer.convert_tokens_to_ids
    return list(map(to_ids, batch))

def untokenize_batch(batch):
    """Convert every id sequence in `batch` back to a list of token strings.

    Uses the module-level `tokenizer`; returns a new list with one token list per sentence.
    """
    to_tokens = tokenizer.convert_ids_to_tokens
    return list(map(to_tokens, batch))

def detokenize(sent):
    """Roughly detokenize a token list by undoing the "##" word-piece splits.

    args:
        - sent (list of str): tokens, where a token starting with "##" continues
          the token immediately before it

    returns:
        A new list in which every "##" continuation has been glued onto the
        preceding token. The input list is not modified.

    A continuation piece at the very start has no preceding token to attach to,
    so it is kept unchanged instead of raising IndexError.
    """
    new_sent = []
    for tok in sent:
        if tok.startswith("##") and new_sent:
            # Merge the continuation (minus its "##" marker) into the previous word.
            new_sent[-1] += tok[2:]
        else:
            new_sent.append(tok)
    return new_sent

# Special tokens used when building model inputs ...
CLS, SEP, MASK = '[CLS]', '[SEP]', '[MASK]'
# ... and their vocabulary ids, looked up in a single tokenizer call.
mask_id, sep_id, cls_id = tokenizer.convert_tokens_to_ids([MASK, SEP, CLS])

100%|██████████| 407873900/407873900 [00:32<00:00, 12735145.60B/s]
100%|██████████| 231508/231508 [00:00<00:00, 425642.17B/s]


In [0]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from out[gen_idx]

    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - temperature (float or None): if not None, logits are divided by this value
          before a word is chosen
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k
        - return_list (Bool): if True, return a Python list; otherwise return the tensor

    returns:
        One chosen vocabulary index per batch element (always one entry per row,
        including when batch_size is 1).
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        # topk() raises when k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        kth_vals, kth_idx = logits.topk(k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # Map the sampled rank (0..k-1) back to the real vocabulary index.
        idx = kth_idx.gather(
            dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        # sample() already has shape (batch_size,). Squeezing it would collapse a
        # batch of one into a 0-d tensor, and tolist() would then return a bare int.
        idx = dist.sample()
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx

In [0]:
def get_word_tokenized(w):
    """Return the vocabulary id of a randomly chosen word from `w`.

    args:
        - w (list of str): candidate words

    returns:
        The id of the first candidate, in random order, that the module-level
        `tokenizer` can convert, or None when `w` is empty or no candidate is
        in the vocabulary.

    The candidates are visited in a shuffled copy, so the caller's list is left
    untouched.
    """
    for word in random.sample(w, len(w)):
        try:
            return tokenizer.convert_tokens_to_ids([word])[0]
        except KeyError:
            # Out-of-vocabulary candidate (e.g. a multi-word lemma): try the next one.
            continue
    return None

def add_context(batch, word=None, seed_len=1):
    """Plant one synonym of `word` at a random generated position of every sentence.

    args:
        - batch (list of list of int): token-id rows, modified in place. Each row is
          expected to look like the rows built by get_init_text: `seed_len` seed ids,
          then the slots to generate, then one closing id that must not be overwritten.
        - word (str or None): word whose WordNet lemma names are used as candidates;
          defaults to the notebook-level global `WORD`
        - seed_len (int): number of leading ids that are never overwritten

    returns:
        The same `batch` object.

    A row is left unchanged when it has no free slot or when no candidate could be
    converted to a vocabulary id.
    """
    if word is None:
        word = WORD  # notebook-level configuration

    # Unique lemma names across all synsets of the word.
    candidates = list({
        lemma
        for synset in wordnet.synsets(word)
        for lemma in synset.lemma_names()
    })

    for row in batch:
        last = len(row) - 1  # index of the closing id
        if last <= seed_len:
            continue
        # Use the row's own length rather than a global max_len, so the position
        # is always valid for this row.
        pos = random.randrange(seed_len, last)
        token_id = get_word_tokenized(candidates)
        if token_id is not None:
            row[pos] = token_id
    return batch

In [0]:
# Generation modes as functions
def get_init_text(seed_text, max_len, batch_size=1, rand_init=False):
    """Build the initial batch of token ids used to start generation.

    Every row is `seed_text` followed by `max_len` [MASK] tokens and a closing [SEP].
    The rows are converted to ids with tokenize_batch and then passed through
    add_context before being returned.

    `rand_init` is accepted but currently has no effect.
    """
    template = seed_text + [MASK] * max_len + [SEP]
    rows = [list(template) for _ in range(batch_size)]
    return add_context(tokenize_batch(rows))


def _resolve_batch_size(batch_size):
    """Return `batch_size` if given, else the notebook-level global of the same name."""
    if batch_size is not None:
        return batch_size
    # Backward compatibility: these functions used to read the module-level
    # `batch_size` directly instead of taking it as an argument.
    try:
        return globals()["batch_size"]
    except KeyError:
        raise NameError("batch_size was not passed and no global `batch_size` is defined") from None


def _forward(rows, cuda):
    """Run the module-level `model` on a batch of id lists without tracking gradients."""
    inp = torch.tensor(rows).cuda() if cuda else torch.tensor(rows)
    # Inference only: skipping autograd bookkeeping saves memory on every iteration.
    with torch.no_grad():
        return model(inp)


def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True, batch_size=None):
    """ Generate for one random position at a timestep

    args:
        - seed_text (list of str): tokens that prefix every sentence
        - max_len (int): number of positions to generate after the seed
        - top_k (int): passed to generate_step once the burn-in period is over
        - temperature (float or None): passed to generate_step
        - max_iter (int): number of single-position resampling steps
        - burnin: during burn-in period, sample from full distribution; afterwards
          sample from the top_k words (or take argmax when top_k is 0)
        - cuda (Bool): move the model input to the GPU
        - print_every (int), verbose (Bool): progress printing of the first sentence
        - batch_size (int or None): sentences generated in parallel; None falls back
          to the notebook-level global `batch_size`

    returns:
        list of token lists, one per sentence
    """
    batch_size = _resolve_batch_size(batch_size)
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)

    for ii in range(max_iter):
        kk = np.random.randint(0, max_len)
        pos = seed_len + kk
        # Mask the chosen position in every sentence, then re-predict it.
        for jj in range(batch_size):
            batch[jj][pos] = mask_id
        out = _forward(batch, cuda)
        topk = top_k if (ii >= burnin) else 0
        idxs = generate_step(out, gen_idx=pos, top_k=topk,
                             temperature=temperature, sample=(ii < burnin))
        for jj in range(batch_size):
            batch[jj][pos] = idxs[jj]

        if verbose and np.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(batch[0])
            # '(*)' is inserted right after the token that was just regenerated.
            for_print = for_print[:pos+1] + ['(*)'] + for_print[pos+1:]
            print("iter", ii+1, " ".join(for_print))

    return untokenize_batch(batch)


def parallel_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True,
                        cuda=False, print_every=10, verbose=True, batch_size=None):
    """ Generate for all positions at a time step

    Each iteration runs the model once and re-draws every generated position from
    that single output. `batch_size=None` falls back to the notebook-level global
    `batch_size`; the remaining arguments are passed to generate_step or control
    progress printing.

    returns:
        list of token lists, one per sentence
    """
    batch_size = _resolve_batch_size(batch_size)
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)

    for ii in range(max_iter):
        out = _forward(batch, cuda)
        for kk in range(max_len):
            idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, temperature=temperature, sample=sample)
            for jj in range(batch_size):
                batch[jj][seed_len+kk] = idxs[jj]

        if verbose and np.mod(ii, print_every) == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(batch[0])))

    return untokenize_batch(batch)


def sequential_generation(seed_text, batch_size=2, max_len=15, leed_out_len=15,
                          top_k=0, temperature=None, sample=True, cuda=False):
    """ Generate one word at a time, in L->R order

    args:
        - leed_out_len (int): how many positions, counted from the one being
          generated, are shown to the model; the input is cut there and closed
          with [SEP]. Values below 1 are treated as 1.
        - the remaining arguments match the other generation functions

    returns:
        list of token lists, one per sentence
    """
    seed_len = len(seed_text)
    # `batch` stays a plain list of id lists; only the model input is moved to the GPU.
    batch = get_init_text(seed_text, max_len, batch_size)
    lookahead = max(leed_out_len, 1)

    for ii in range(max_len):
        pos = seed_len + ii
        inp = []
        for sent in batch:
            # Never cut past the row's own closing [SEP], so it is not duplicated.
            end = min(pos + lookahead, len(sent) - 1)
            inp.append(sent[:end] + [sep_id])
        out = _forward(inp, cuda)
        idxs = generate_step(out, gen_idx=pos, top_k=top_k, temperature=temperature, sample=sample)
        for jj in range(batch_size):
            batch[jj][pos] = idxs[jj]

    # Return only after every position has been generated.
    return untokenize_batch(batch)


def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25,
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    """Main generation function to call: produce `n_samples` sentences in batches.

    args:
        - n_samples (int): number of sentences to return
        - seed_text (list of str): tokens that prefix every sentence
        - batch_size (int): sentences generated per batch
        - max_len, top_k, temperature, burnin, max_iter, cuda: forwarded to
          parallel_sequential_generation
        - sample (Bool): accepted for interface compatibility; not used by the
          generation mode called here
        - print_every (int): print the elapsed time every this many batches

    returns:
        list of exactly `n_samples` token lists
    """
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        # Forward batch_size explicitly so the batch count above and the size of
        # each generated batch cannot disagree.
        batch = parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k,
                                               temperature=temperature, burnin=burnin, max_iter=max_iter,
                                               cuda=cuda, verbose=False, batch_size=batch_size)
        if (batch_n + 1) % print_every == 0:
            print(
                f"Finished batch {(batch_n + 1)} in {(time.time() - start_time)}s")
            start_time = time.time()

        sentences += batch
    # The last batch may overshoot when n_samples is not a multiple of batch_size.
    return sentences[:n_samples]


In [0]:
# Utility functions

def printer(sents, should_detokenize=True):
    """Print every sentence in `sents` on its own line, tokens joined by spaces.

    With `should_detokenize` set, the first and last tokens are dropped and the
    remaining tokens are passed through detokenize before printing.
    """
    for tokens in sents:
        if should_detokenize:
            tokens = detokenize(tokens[1:-1])
        print(" ".join(tokens))
    
    #f should_detokenize:
     #   sent = detokenize(sent)[1:-1]
    
def read_sents(in_file, should_detokenize=False):
    """Read whitespace-tokenized sentences from a text file, one sentence per line.

    args:
        - in_file (str): path of the file to read
        - should_detokenize (Bool): if True, pass every sentence through detokenize

    returns:
        list of token lists; a blank line yields an empty list
    """
    # The context manager closes the file even if reading fails part-way.
    with open(in_file) as in_fh:
        sents = [line.strip().split() for line in in_fh]
    if should_detokenize:
        sents = [detokenize(sent) for sent in sents]
    return sents

def write_sents(out_file, sents, should_detokenize=False):
    """Write every sentence in `sents` to `out_file`, one space-joined sentence per line.

    With `should_detokenize` set, the first and last tokens of each sentence are
    dropped and the rest is passed through detokenize before writing.
    """
    def as_line(tokens):
        words = detokenize(tokens[1:-1]) if should_detokenize else tokens
        return "%s\n" % " ".join(words)

    with open(out_file, "w") as out_fh:
        out_fh.writelines(as_line(tokens) for tokens in sents)

In [19]:
# --- Generation settings ---
# NOTE: `batch_size`, `max_len` and `WORD` are also read as globals by helper
# functions defined in earlier cells, so keep these names.
n_samples = 10
batch_size = 5
max_len = 20
top_k = 100
temperature = 0.7  # not used below: the loop passes its own `temp` values

leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# Word whose WordNet synonyms are planted into the initial text.
WORD = 'angry'

# Seed every random number generator used by the pipeline so reruns are comparable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Choose the prefix context
seed_text = "[CLS]".split()

for temp in [1.0]:
    # Use the `cuda` flag computed when the model was loaded: the model stays on
    # the CPU when no GPU is available, so forcing cuda=True would fail there.
    bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                          sample=sample, top_k=top_k, temperature=temp, burnin=burnin, max_iter=max_iter,
                          cuda=cuda)

Finished batch 1 in 8.51607370376587s
Finished batch 2 in 8.455313682556152s


In [20]:
# Show the generated sentences, one per line: the first and last token of each
# sentence are dropped and "##" word pieces are merged back into whole words.
printer(bert_sents, should_detokenize=True)

" hello " ( 2012 ) inspired from " the end of the world ! " for tlc .
sometimes i wonder , whether it was bess who set it down and opened the local book museum ?
the man backed up a few steps , and there roland found jake naked , standing sword against sword .
in the united states , the popularity of electronic music has as a whole grown from home to elsewhere .
" the old man has not reached you . he will rise . " but the voice was gone .
" hi , " the woman said . she had put on a light - red lipstick and jeans .
it was probably revised later by later members , and yet the original text has not yet been released .
after a couple days my ribs were looking pretty good and the pain in my arm was completely gone .
from 1959 , four bn - 56 examples were also manufactured in kirkuk and were exported to turkey .
poisson ; bertrand russell . " numerical analysis of multi - problem problems " . springer / verlag .
