In [1]:
import json
import pickle
import random
import re

import numpy
from scipy import sparse
import torch
from transformers import BertTokenizer, BertForMaskedLM
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import GPT2Tokenizer, GPT2LMHeadModel

#import logging
#logging.basicConfig(level=logging.INFO)

tokenizer = None
model = None
loaded_model_name = None

from nltk.corpus import stopwords
from nltk.corpus import cmudict
dictionary = cmudict.dict()
from g2p_en import G2p
g2p = G2p()

m = torch.nn.Softmax(dim=0)

In [2]:
re_word = re.compile(r"[a-zA-Z' ]+")
re_vowel = re.compile(r"[aeiouy]")
def get_pron(tok):
    tok = tokenizer.convert_tokens_to_string([tok])
    if tok.startswith('##'):
        tok = tok[2:]
    if tok.startswith(' '):
        tok = tok[1:]
    if not re_word.match(tok):
        # Punctuation
        return []
    if tok in dictionary:
        pron = dictionary[tok][0]
    else:
        # Word not in CMU dict: guess using g2p_en
        pron = g2p(tok)
    return pron

def get_meter(pron):
    if pron == []:
        return 'p'
    meter = ''
    for ph in pron:
        # We ignore stress levels in favor of poetic scansion
        if ph[-1].isdigit():
            meter += 'u' if ph[-1] == '0' else '-'
    return meter

def get_rhyme(pron):
    if pron == []:
        return 'p'
    rhyme = ''
    for ph in reversed(pron):
        rhyme = ph.replace('1', '').replace('2', '') + rhyme
        if ph[-1].isdigit() and int(ph[-1]) > 0:
            break
    return rhyme

def is_word_piece(model, tok):
    if model.startswith('bert') or model.startswith('distilbert'):
        return tok.startswith('##')
    elif model.startswith('roberta') or model.startswith('gpt2'):
        tok = tokenizer.convert_tokens_to_string([tok])
        return re_word.match(tok) and not tok.startswith(' ')
def join_word_pieces(toks):
    word = ''
    for tok in toks:
        if tok.startswith('##'):
            tok = tok[2:]
        word += tok
    return word

def is_full_word(model_name, tok):
    if model_name.startswith('bert') or model_name.startswith('distilbert'):
        return re_word.match(tok) and not tok.startswith('##')
    elif model_name.startswith('roberta') or model.startswith('gpt2'):
        tok = tokenizer.convert_tokens_to_string([tok])
        return re_word.match(tok) and tok.startswith(' ')

def is_punctuation(tok):
    tok = tokenizer.convert_tokens_to_string([tok])
    return not re_word.match(tok)

In [3]:
def create_meter_dict(model_name):
    print("Generating " + model_name + '_meter_dict.pkl')
    vocab = tokenizer.get_vocab()
    vocab_size = len(vocab)
    meter_dict = {}
    word_pieces = torch.zeros([vocab_size])
    for tok in vocab:
        i = vocab[tok]
        pron = get_pron(tok)
        meter = get_meter(pron)
        if meter not in meter_dict:
            meter_dict[meter] = torch.zeros([vocab_size])
        meter_dict[meter][i] = 1.0
        if is_word_piece(model_name, tok):
            word_pieces[i] = 1.0

    pickle.dump((word_pieces, meter_dict),
                open(model_name + '_meter_dict.pkl', 'wb'))

In [4]:
def create_rhyme_matrix(model_name):
    print("Generating " + model_name + '_rhyme_matrix.pkl')
    vocab = tokenizer.get_vocab()
    vocab_size = len(vocab)
    rhyme_matrix = sparse.lil_matrix((vocab_size, vocab_size))
    rhymable_words = torch.zeros([vocab_size])
    rhyme_groups = {}
    for tok in vocab:
        i = vocab[tok]
        pron = get_pron(tok)
        rhyme = get_rhyme(pron)
        if rhyme not in rhyme_groups:
            rhyme_groups[rhyme] = []
        rhyme_groups[rhyme].append((i, pron))
    for rhyme in rhyme_groups:
        if len(rhyme_groups[rhyme]) < 2:
            continue
        for i, pron1 in rhyme_groups[rhyme]:
            rhymable = False
            for j, pron2 in rhyme_groups[rhyme]:
                # Words with identical pronunciations can't be used as rhymes
                if pron1 != pron2:
                    rhyme_matrix[i,j] = 1.0
                    rhymable = True
            if rhymable:
                rhymable_words[i] = 1.0

    rhyme_matrix = sparse.csc_matrix(rhyme_matrix)
    pickle.dump((rhymable_words, rhyme_matrix), open(model_name + '_rhyme_matrix.pkl', 'wb'))

In [25]:
vocab = None
vocab_size = None
meter_dict = {}
word_pieces = None
rhymable_words = None
rhyme_matrix = None
rhyme_tensors = {}
def initialize_rhyme_and_meter(model, meter=False, rhymes=False):
    global vocab, vocab_size, word_pieces, meter_dict, rhymable_words, rhyme_matrix
    vocab = tokenizer.get_vocab()
    vocab_size = len(vocab)
    if meter:
        try:
            f = open(model + '_meter_dict.pkl', 'rb')
        except FileNotFoundError:
            create_meter_dict(model)
            f = open(model + '_meter_dict.pkl', 'rb')
        word_pieces, meter_dict = pickle.load(f)
    else:
        try:
            f = open(model + '_meter_dict.pkl', 'rb')
        except FileNotFoundError:
            create_meter_dict(model)
            f = open(model + '_meter_dict.pkl', 'rb')
        word_pieces, _ = pickle.load(f)
    if rhymes:
        global rhyme_matrix
        try:
            f = open(model + '_rhyme_matrix.pkl', 'rb')
        except FileNotFoundError:
            create_rhyme_matrix(model)
            f = open(model + '_rhyme_matrix.pkl', 'rb')
        rhymable_words, rhyme_matrix = pickle.load(f)

# Computes the model's predictions for a text with a given set of ranges
# masked by single mask tokens.
def compute_replacement_probs_for_masked_tokens(model, tokenized_text,
                                                masked_indices):
    n = len(masked_indices)
    dim = [vocab_size] * n

    tokenized_text = tokenized_text.copy()
    shift = 0
    for i1, i2 in masked_indices:
        i1 -= shift
        i2 -= shift
        shift += (i2 - i1)
        tokenized_text[i1:i2+1] = ['[MASK]']
        
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])

    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    replacement_probs = [None] * n
    shift = 0
    for k, (i1, i2) in enumerate(masked_indices):
        i1 -= shift
        i2 -= shift
        shift += (i2 - i1)
        replacement_probs[k] = [predictions[0, i1, :]]

    return replacement_probs

# Computes the model's predictions for a text with a given set of ranges
# masked.
def compute_probs_for_masked_tokens(model, tokenized_text, masked_indices):
    n = len(masked_indices)
    dim = [vocab_size] * n

    multipart_words = False
    tokenized_text = tokenized_text.copy()
    for i1, i2 in masked_indices:
        if i2 > i1:
            multipart_words = True
        tokenized_text[i1:i2+1] = ['[MASK]'] * (i2 - i1 + 1)

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])

    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    probs = [None] * n
    for k, (i1, i2) in enumerate(masked_indices):
        word_probs = []
        for i in range(i1, i2+1):
            word_probs.append(predictions[0, i, :])
        probs[k] = word_probs

    if multipart_words:
        # We need to compute a separate probability with only one mask token
        # for each word, so that we can replace the multipart words with
        # single-part words.
        replacement_probs \
            = compute_replacement_probs_for_masked_tokens(model,
                                                          tokenized_text,
                                                          masked_indices)

    else:
        replacement_probs = probs

    return probs, replacement_probs

# Find words that could, if chosen for the masked indices, take us back to an
# arrangement that has already been tried. Because we are compiling independent
# lists of forbidden words for each index, this method can overcorrect.
def find_forbidden_words(tokenized_text, masked_indices, forbidden_texts):
    forbidden_words = [torch.ones((vocab_size,))
                       for i in range(len(masked_indices))]
    d = forbidden_texts
    def f(d, start):
        i = start
        for tok in tokenized_text[start:]:
            mask_num = None
            mask_len = 0
            for k, (i1, i2) in enumerate(masked_indices):
                if i == i1:
                    mask_num = k
                    mask_len = i2 - i1 + 1
                    break
            if mask_num is not None:
                reached_end = False
                for option_tok in d.keys():
                    if f(d[option_tok], i+mask_len):
                        option_idx = tokenizer \
                                     .convert_tokens_to_ids([option_tok])[0]
                        forbidden_words[mask_num][option_idx] = 0.0
                        reached_end = True
                return reached_end
            else:
                if tok in d:
                    d = d[tok]
                else:
                    return False
            i += 1
        return True
    if f(d, 0):
        return forbidden_words
    else:
        return None

# Function to adjust the output of the model based on various options.
def adjust_probs(model, probs, tokenized_text, start, end, masked_indices,
                 modifier=None, match_meter=None, forbidden_texts=None,
                 random_factor=False, discouraged_words=None,
                 rhyme_with=None, rhymable_only=False, rhymable_with_meters=False,
                 allow_punctuation=None, no_word_pieces=False,
                 strong_topic_bias=False, topicless_probs=None):
        
    if forbidden_texts is not None:
        forbidden_words = find_forbidden_words(tokenized_text,
                                               masked_indices,
                                               forbidden_texts)
    else:
        forbidden_words = None

    adj_probs = [[u.clone() for u in t] for t in probs]
    for k in range(len(adj_probs)):
        for j in range(len(adj_probs[k])):
            if random_factor:
                noise = torch.randn_like(adj_probs[k][j])
                noise = noise * random_factor + 1.0
                adj_probs[k][j] *= noise

            adj_probs[k][j] = m(adj_probs[k][j])

            # Do not produce word pieces. There is no way to keep the model
            # behaving reliably if we allow it to produce words that are not
            # actually in its vocabulary.
            if no_word_pieces:
                adj_probs[k][j] *= (1.0 - word_pieces)
                
            if rhymable_only:
                adj_probs[k][j] *= rhymable_words
            if rhymable_with_meters:
                for rhyme_meter in rhymable_with_meters:
                    test_meter = get_meter(get_pron(rhyme_meter))
                    meter_tensor = meter_dict[test_meter]
                    meter_matrix = sparse.dia_matrix((meter_tensor, [0]),
                                                     shape=(vocab_size, vocab_size))
                    # Take the dot product of meter and rhyme
                    mat = meter_matrix.dot(rhyme_matrix)
                    vec = torch.from_numpy(mat.sum(0)).squeeze().to(dtype=bool)
                    adj_probs[k][j] *= vec

            if forbidden_words is not None:
                adj_probs[k][j] *= forbidden_words[k]

            if allow_punctuation is False:
                adj_probs[k][j] *= (1.0 - meter_dict['p'])

            if match_meter is not None:
                test_meter = get_meter(get_pron(match_meter[k]))
                meter_tensor = meter_dict[test_meter]
                if allow_punctuation is True:
                    adj_probs[k][j] *= (meter_tensor + meter_dict['p'])
                else:
                    adj_probs[k][j] *= meter_tensor

            if modifier is not None:
                adj_probs[k][j] *= modifier
            if discouraged_words is not None:
                adj_probs[k][j] *= discouraged_words

            if rhyme_with is not None:
                for rhyme_word in rhyme_with:
                    rhyme_idx = tokenizer.convert_tokens_to_ids([rhyme_word])[0]
                    rhyme_tensor = rhyme_matrix[rhyme_idx, :].todense()
                    rhyme_tensor = torch.from_numpy(rhyme_tensor)
                    rhyme_tensor = rhyme_tensor.squeeze()
                    adj_probs[k][j] *= rhyme_tensor

            if strong_topic_bias:
                bias_factor = (m(probs[k][j]) / m(topicless_probs[k][j])) ** strong_topic_bias
                adj_probs[k][j] *= bias_factor

    return adj_probs

# Compute a score indicating how well the model's predictions improve the
# probability for certain words. If multiple words are chosen, it is
# assumed that they are supposed to rhyme.
def compute_score_for_tokens(probs1, probs2, tokenized_text,
                             indices, relative):
    n = len(indices)
    dim = [vocab_size] * n
    
    existing_token_ids = [None] * n
    for k, (i1, i2) in enumerate(indices):
        existing_token_ids[k] = []
        for i in range(i1, i2+1):
            token = tokenized_text[i]
            index = tokenizer.convert_tokens_to_ids([token])[0]
            existing_token_ids[k].append(index)
            
    existing_words_prob = 1.0
    if probs1:
        for k in range(n):
            existing_word_prob = 0.0
            for i, tok_id in enumerate(existing_token_ids[k]):
                prob_tensor = probs1[k][i]
                existing_word_prob += prob_tensor[tok_id]
            existing_word_prob /= len(existing_token_ids[k])
            existing_words_prob *= existing_word_prob
        
    if n == 1:
        prob_tensor = probs2[0][0]
        prediction_prob = torch.max(prob_tensor)
        idx = prob_tensor.argmax().item()
        predicted_token_ids = [idx]

    elif n == 2:
        # We compute scores for possible rhyme pairs using sparse matrix
        # arithmetic. We use scipy instead of torch because torch's sparse
        # tensors do not support the .max() function.
        left_mat = sparse.dia_matrix((probs2[0][0], [0]), shape=dim)
        mat = left_mat.dot(rhyme_matrix)
        right_mat = sparse.dia_matrix((probs2[1][0], [0]), shape=dim)
        mat = mat.dot(right_mat)
        prediction_prob = mat.max()
        idx = mat.argmax()
        predicted_token_ids = list(numpy.unravel_index(idx, dim))
    
    if probs1:
        if relative:
            score = existing_words_prob / prediction_prob
        else:
            score = existing_words_prob
    else:
        score = prediction_prob
        
    predicted_tokens = [None] * n
    for i in range(n):
        predicted_tokens[i] \
            = tokenizer.convert_ids_to_tokens([predicted_token_ids[i]])[0]

    return predicted_tokens, score

# Tokenize a text and figure out (as best we can) its rhyme scheme.
def process_text(model, text, start, end, match_rhyme, strip_punctuation=False):
    lines = text.split('\n')
    
    tok_index = start
    toks = []
    rhyme_types = {}
    multipart_words = {}
    fixed = False
    fixed_toks = set()
    line_ends = set()
    for line in lines:
        if model.startswith('roberta') or model.startswith('gpt2'):
            line = ' ' + line

        # Check for the special '[]' characters that indicate fixed text.
        line_new = ''
        shift = 0
        fixed_chars = set()
        for i, ch in enumerate(line):
            if (model.startswith('bert') or model.startswith('distilbert')) and ch == ' ':
                # BERT tokenizer strips spaces, so we must account for that.
                shift += 1
            if ch == '[':
                fixed = True
                shift += 1
            elif ch == ']':
                fixed = False
                shift += 1
            else:
                line_new += ch
                if fixed:
                    fixed_chars.add(i - shift)
        
        line_toks = tokenizer.tokenize(line_new)
        line_fixed_toks = set()
        i = 0
        for j, tok in enumerate(line_toks):
            tok = tokenizer.convert_tokens_to_string([tok])
            if tok.startswith('##'):
                tok = tok[2:]
            nchars = len(tok)
            for k in range(nchars):
                if i+k in fixed_chars:
                    line_fixed_toks.add(j + tok_index)
                    break
            i += nchars
        
        if strip_punctuation:
            stripped_line_toks = []
            stripped_fixed_toks = set()
            shift = 0
            for j, tok in enumerate(line_toks):
                if is_punctuation(tok):
                    shift += 1
                else:
                    stripped_line_toks.append(tok)
                    if j + tok_index in line_fixed_toks:
                        stripped_fixed_toks.add(j + tok_index - shift)
            line_toks = stripped_line_toks
            line_fixed_toks = stripped_fixed_toks
        
        toks += line_toks
        fixed_toks.update(line_fixed_toks)

        # Check for multipart words.
        word_bounds = []
        for i, tok in enumerate(line_toks):
            if is_word_piece(model, tok) or tok in ("'", "s", "st", "d",
                                                    "ve", "re", "nt", "ll",
                                                    "t", "m"):
                if not word_bounds:
                    word_bounds.append([i, i])
                else:
                    word_bounds[-1][1] = i
            else:
                word_bounds.append([i, i])
        for i1, i2 in word_bounds:
            if i1 == i2:
                continue
            for i in range(i1, i2+1):
                multipart_words[i + tok_index] = (i1 + tok_index,
                                                  i2 + tok_index)

        if match_rhyme:
            rhyme_type = None
            # Only check rhyme for the last non-punctuation word of a line.
            word = ''
            i = len(line_toks) - 1
            while i >= 0:
                if i + tok_index in multipart_words:
                    i1, i2 = multipart_words[i + tok_index]
                    word = join_word_pieces(line_toks[i1-tok_index:i2-tok_index+1])
                    i = multipart_words[i + tok_index][0] - tok_index
                else:
                    word = line_toks[i]

                pron = get_pron(word)
                if pron != []:
                    rhyme_type = get_rhyme(pron)
                    if rhyme_type is not None:
                        if not rhyme_type in rhyme_types:
                            rhyme_types[rhyme_type] = []
                        rhyme_types[rhyme_type].append(tok_index + i)
                        break
                
                i -= 1
            
        tok_index += len(line_toks)
        line_ends.add(tok_index)

    if match_rhyme:
        rhyme_groups = {}
        for rhyme in rhyme_types:
            tok_list = rhyme_types[rhyme]
            # Rhyme groups of more than two not currently supported, so we
            # split the groups up into pairs
            for i in range(0, len(tok_list), 2):
                group = tok_list[i:i+2]
                for index in group:
                    rhyme_groups[index] = group

        return toks, fixed_toks, multipart_words, rhyme_groups, line_ends
    
    else:
        return toks, fixed_toks, multipart_words, {}, line_ends

# Alters a text iteratively, word by word, using the model to pick
# replacements.
def depoeticize(text, max_iterations=100,
                match_meter=False, match_rhyme=False, topic=None,
                randomize=False, cooldown=0.01, modifier=None,
                forbid_reversions=True, preserve_punctuation=False,
                topic_prefix="The following poem is about",
                topic_postfix="The preceding poem was about",
                strong_topic_bias=False, stop_score=0.001,
                discourage_repetition=False, stopwords=stopwords.words('english'),
                model_name='bert-base-uncased',
                sequential=False, verbose=True):
    stopwords = set(stopwords)
    
    global tokenizer, model, loaded_model_name
    if loaded_model_name != model_name:
        if model_name.startswith('distilbert'):
            tokenizer = DistilBertTokenizer.from_pretrained(model_name)
            model = DistilBertForMaskedLM.from_pretrained(model_name)
            bos_token = '[CLS]'
            eos_token = '[SEP]'
        if model_name.startswith('bert'):
            tokenizer = BertTokenizer.from_pretrained(model_name)
            model = BertForMaskedLM.from_pretrained(model_name)
            bos_token = '[CLS]'
            eos_token = '[SEP]'
        if model_name.startswith('roberta'):
            tokenizer = RobertaTokenizer.from_pretrained(model_name)
            model = model = RobertaForMaskedLM.from_pretrained(model_name)
            bos_token = tokenizer.bos_token
            eos_token = tokenizer.eos_token
        model.eval()

    initialize_rhyme_and_meter(model_name, meter=match_meter,
                               rhymes=match_rhyme)

    if topic:
        toks1 = tokenizer.tokenize("{0} {1} {2}: "
                                   .format(bos_token, topic_prefix, topic))
        toks3 = tokenizer.tokenize(" {0} {1}. {2}"
                                   .format(topic_postfix, topic, eos_token))
    else:
        toks1 = tokenizer.tokenize("{0} ".format(bos_token))
        toks3 = tokenizer.tokenize(" {0}".format(eos_token))
    start = len(toks1)
    end = len(toks3)

    toks2, fixed_toks, multipart_words, rhyme_groups, line_ends \
        = process_text(model_name, text, start, end, match_rhyme)
    tokenized_text = toks1 + toks2 + toks3
    n = len(tokenized_text)

    forbidden_texts = {}

    if sequential:
        max_iterations = len(toks2)
    for k in range(max_iterations):
        last_score = 0.0
        
        if sequential and k >= len(tokenized_text) - start - end:
            break
            
        # Discourage the selection of words already in the text, save for stopwords.
        if discourage_repetition is not False:
            discouraged_words = torch.ones((vocab_size,))
            for i in range(start, n-end):
                tok = tokenized_text[i]
                if tok in stopwords:
                    continue
                idx = tokenizer.convert_tokens_to_ids([tok])[0]
                discouraged_words[idx] = discourage_repetition
        else:
            discouraged_words = None
        
        # Compute the scores used to choose which word to change
        outputs = [(None, None, float("inf"))] * n
        if sequential:
            test_range = [start + k]
        else:
            test_range = range(start, n-end)
        for i in test_range:
            if preserve_punctuation:
                if is_punctuation(tokenized_text[i]):
                    continue
            if i in fixed_toks:
                continue
            if i in multipart_words and i != multipart_words[i][0]:
                # Only try the first part of a multipart word
                continue
                
            if match_rhyme and i in rhyme_groups:
                if i != rhyme_groups[i][0]:
                    # Only try each rhyme group once
                    continue
                indices = rhyme_groups[i]
            else:
                indices = [i]
                
            indices = [multipart_words.get(idx, [idx, idx])
                       for idx in indices]
            if match_meter:
                meter = [join_word_pieces(tokenized_text[i1:i2+1])
                         for (i1, i2) in indices]
            else:
                meter = None
                
            if sequential:
                probs1 = None
                probs2 = compute_replacement_probs_for_masked_tokens(model,
                                                                     tokenized_text,
                                                                     indices)
            else:
                probs1, probs2 \
                    = compute_probs_for_masked_tokens(model,
                                                      tokenized_text,
                                                      indices)

            # The strong topic bias feature compares the probs with and
            # without the topic and biases the results in favor of words
            # that are more probable with it.
            if topic and strong_topic_bias:
                topicless_indices = [(i1-start, i2-start)
                                     for (i1, i2) in indices]
                if sequential:
                    topicless_probs1 = None
                    topicless_probs2 \
                        = compute_replacement_probs_for_masked_tokens(model,
                                                          tokenized_text[start:-end],
                                                          topicless_indices)
                else:
                    topicless_probs1, topicless_probs2 \
                        = compute_probs_for_masked_tokens(model,
                                                          tokenized_text[start:-end],
                                                          topicless_indices)
            else:
                topicless_probs1 = None
                topicless_probs2 = None
                
            if not sequential:
                probs1 = adjust_probs(model, probs1, tokenized_text, start,
                                      end, indices, modifier,
                                      random_factor=randomize,
                             strong_topic_bias=topic and strong_topic_bias,
                             topicless_probs=strong_topic_bias and topicless_probs1)
            probs2 = adjust_probs(model, probs2, tokenized_text, start,
                                  end, indices, modifier,
                                  meter, forbidden_texts,
                                  discouraged_words=discouraged_words,
                                  random_factor=randomize,
                                  no_word_pieces=True,
                         strong_topic_bias=topic and strong_topic_bias,
                         topicless_probs=strong_topic_bias and topicless_probs2)
            
            predicted_tokens, score \
                = compute_score_for_tokens(probs1, probs2,
                                           tokenized_text, indices,
                                           relative=True)
            outputs[i] = (indices, predicted_tokens, score)

        # Choose a word to change
        outputs.sort(key=lambda t: t[2])
        chosen_indices = None
        for (indices, predicted_tokens, score) in outputs:
            if score >= stop_score:
                break
            if predicted_tokens is None:
                continue
            chosen_indices = indices
            chosen_tokens = predicted_tokens
            last_score = score
            break

        if chosen_indices is None:
            if sequential:
                continue
            else:
                break

        # To prevent loops, we forbid the model from reverting to texts that it
        # has already tried. The texts are stored in a trie (prefix tree) for
        # efficient searchability.
        if forbid_reversions:
            d = forbidden_texts
            for tok in tokenized_text:
                if not tok in d:
                    d[tok] = {}
                d = d[tok]

        # Make the actual revision and make note of what we've done.
        change_made = False
        new_token_indices = []
        shift = 0
        for j, (i1, i2) in enumerate(chosen_indices):
            i1 -= shift
            i2 -= shift
            shift += (i2 - i1)
            n -= (i2 - i1)
            token = chosen_tokens[j]
            if i2 > i1:
                change_made = True
                tokenized_text[i1:i2+1] = [token]
                new_token_indices.append(i1)
            elif tokenized_text[i1] != token:
                change_made = True
                tokenized_text[i1] = token
                new_token_indices.append(i1)

            for i in range(i1, i2+1):
                if i in multipart_words:
                    del multipart_words[i]
            replacements = {}
            for i in list(multipart_words.keys()):
                if i > i2:
                    j1, j2 = multipart_words[i]
                    del multipart_words[i]
                    replacements[i - (i2 - i1)] = (j1 - (i2 - i1),
                                                   j2 - (i2 - i1))
            for i in replacements:
                multipart_words[i] = replacements[i]
                    
            replacements = {}
            for i_old in list(rhyme_groups.keys()):
               group = rhyme_groups[i_old].copy()
               if i_old > i1:
                   i_new = i_old - (i2 - i1)
               else:
                   i_new = i_old
               group = [(idx - (i2 - i1) if idx > i1 else idx)
                        for idx in group]
               replacements[i_new] = group
            rhyme_groups = replacements

        if not change_made:
            if sequential:
                continue
            else:
                break

        if verbose:
            sample = tokenized_text[start:-end].copy()
            for i in new_token_indices:
                sample[i-start] = '<' + sample[i-start] + '>'
            text = tokenizer.convert_tokens_to_string(sample)
            print('-----------------------')
            print('Iteration {0}, score = {1}'.format(k+1, last_score))
            print(tokenizer.clean_up_tokenization(text))

        if randomize and cooldown:
            randomize *= (1.0 - cooldown)
            
    text = tokenizer.convert_tokens_to_string(tokenized_text[start:-end])
    return tokenizer.clean_up_tokenization(text)

# Generates a wholly new text by running a decoder model forward with the specified
# constraints. This doesn't work very well.
def parody(text, match_meter=False, match_rhyme=False, topic=None,
           randomize=False, modifier=None, verbose=True,
           topic_prefix="", model='gpt2'):
    model_name = model
    
    global tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()
    eos_token = tokenizer.eos_token
    
    initialize_rhyme_and_meter(model_name, meter=True, rhymes=match_rhyme)

    if topic:
        toks1 = tokenizer.tokenize("{0} {1} {2}. "
                                   .format(eos_token, topic_prefix, topic))
    else:
        toks1 = [eos_token]
    start = len(toks1)

    # We strip punctuation because, not being able to look ahead, the GPT-2
    # model cannot reliably produce text that matches the punctuation of the
    # original; the only way to get coherent output is to let the model decide
    # on the punctuation.
    toks2, fixed_toks, multipart_words, rhyme_groups, line_ends \
        = process_text(model_name, text, start, 0, match_rhyme,
                       strip_punctuation=True)
    
    tokenized_text = toks1 + toks2
    n = len(tokenized_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    
    # As in Beckett's "The Unnamable," we force the model to keep writing
    # even when it wants to stop.
    discouraged_words = torch.ones((vocab_size,))
    eos_token_id = tokenizer.convert_tokens_to_ids([eos_token])[0]
    discouraged_words[eos_token_id] = 0.0
    newline_token_id = tokenizer.convert_tokens_to_ids(['\n'])[0]
    discouraged_words[newline_token_id] = 0.0

    out_toks = indexed_tokens[:start]
    i = start
    just_added_punctuation = False
    just_rhymed = False
    while i < n:
        if i in fixed_toks:
            tok = indexed_tokens[i]
            out_toks.append(tok)
            tok = tokenizer.convert_ids_to_tokens([tok])[0]
            just_added_punctuation = is_punctuation(tok)
            i += 1
            continue
            
        if match_rhyme and i in rhyme_groups:
            rhyming = True
            # We can't look ahead with this model, so the rhyming constraint
            # only looks at words already chosen.
            rhyme_words = [tokenizer.convert_ids_to_tokens([indexed_tokens[idx]])[0]
                           for idx in rhyme_groups[i] if idx < i]
            # ...but we do need to make sure to choose a word that can be rhymed
            # with at least one word of the meter of later rhyming words.
            rhyme_meters = [tokenizer.convert_ids_to_tokens([indexed_tokens[idx]])[0]
                            for idx in rhyme_groups[i] if idx > i]
        else:
            rhyming = False
            rhyme_words = None
            rhyme_meters = None
        
        i1, i2 = multipart_words.get(i, [i, i])
        if match_meter:
            meter = [join_word_pieces(tokenized_text[i1:i2+1])]
        else:
            meter = None
            
        with torch.no_grad():
            tokens_tensor = torch.tensor([out_toks])
            outputs = model(tokens_tensor)
            predictions = outputs[0]
            
        no_word_pieces = (i == start) or rhyme_words or just_added_punctuation or just_rhymed

        probs = [[predictions[0, -1, :]]]
        probs = adjust_probs(model, probs, None, 0, 0, None,
                             modifier, meter,
                             random_factor=randomize,
                             discouraged_words=discouraged_words,
                             allow_punctuation=not just_added_punctuation and not rhyme_words,
                             no_word_pieces=no_word_pieces,
                             rhyme_with=rhyme_words,
                             rhymable_only=not match_meter and rhyming,
                             rhymable_with_meters=match_meter and rhyme_meters)
        
        idx = probs[0][0].argmax().item()
        if idx == eos_token_id:
            break
        
        tok = tokenizer.convert_ids_to_tokens([idx])[0]
        out_toks.append(idx)
        
        just_rhymed = not not rhyme_words

        # Only proceed to the next input token if the output is a
        # non-punctuation token.
        if meter_dict['p'][idx] == 0.0:
            if verbose and i in line_ends:
                print('')
            # Record the chosen token for rhyming purposes.
            indexed_tokens[i] = idx
            i += i2 - i1 + 1
            just_added_punctuation = False
        else:
            # We don't allow multiple punctuation tokens in a row. This is
            # because the model can potentially get stuck in a loop where it
            # generates nothing but punctuation, in which case the process
            # would never end.
            just_added_punctuation = True
            
        if verbose:
            string = tokenizer.convert_tokens_to_string([tok])
            print(string, end='')

    out = tokenizer.convert_ids_to_tokens(out_toks[start:])
    text = tokenizer.convert_tokens_to_string(out)
    return tokenizer.clean_up_tokenization(text)

# Add modifier=metalness_modifier() to bias the results toward words that occur
# frequently in heavy metal lyrics. First you will need to download the data set
# available at https://github.com/ijmbarr/pythonic-metal.
def metalness_modifier():
    f = open('metalness.json', 'r')
    metalness = json.load(f)
    f.close()
    vocab = tokenizer.get_vocab()
    metalness_modifier = [0.0] * len(vocab)
    for i, tok in enumerate(vocab):
        if tok in metalness:
            metalness_modifier[i] = metalness[tok]
    return m(torch.tensor(metalness_modifier))

In [46]:
text = '''Shall I compare thee to a summer's day? 
Thou art more lovely and more temperate:
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date: 
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed; 
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade
Nor lose possession of that fair thou ow'st;
Nor shall Death brag thou wander'st in his shade,
When in eternal lines to time thou grow'st; 
So long as men can breathe or eyes can see,
So long lives this, and this gives life to thee.'''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=50,
            randomize=0.0,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic=None,
            strong_topic_bias=2.0,
            sequential=True,
            verbose=False,
            modifier=None)#metalness_modifier())

'may i invite thee on a sunny day? it is more sunny and less sunny : cold winds may blow the autumn leaves to may, and our day is not too late in may : always is where the light of nature comes, and never is thy fair complexion pale ; and different day is that mankind remains, or dies, on any other day, today ; but this eternal beauty shall none gain or take advantage of how much thou gives ; nor does thou as thou suffer in thy pain, but in unbroken time by which thou lives ; as long that breath can hear that breath can be, as thou knows it, and thou gives it to me.'

In [151]:
text = '''[Tyger Tyger], burning bright, 
In the forests of the night; 
What immortal hand or eye, 
Could frame thy fearful symmetry?'''

parody(text,
       model='gpt2',
       match_meter=True,
       match_rhyme=True,
       topic="a little, fluffy kitty cat",
       randomize=0.005,
       modifier=None)

' Tyger Tyger, also known in the English as "the lone wolf," created this cat, named "T-Rex," by writing poetry'

In [18]:
# This text is modified to ensure that the program can pick up on
# all the rhymes—it makes no difference because the parody model
# only looks at the rhyme and meter of the input.
text = '''Shall I compare thee to a summer's day?
Thou art more lovely and more fascinate:
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date: 
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed; 
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade
Nor lose possession of that fair thou ow'st;
Nor shall Death brag thou wander'st in his shade,
When in eternal lines to time thou grow'st; 
So long as men can breathe or eyes can see,
So long lives this, and this gives life to thee.'''

parody(text,
       model='gpt2',
       match_meter=True,
       match_rhyme=True,
       randomize=0.005,
       topic="Sonic the Hedgehog and his friend Tails",
       modifier=None)


 I've received some great and funny mail
 from fans of Sonic the Hedge Labyrinth,
 so I thought I'm going to write nail-
 and pencil- on some of them. I'm synth-
 snikt, and I've never done
 a Sonic game, so I decided to "
 in- character" some of them. Although begun
 with my first Sonic, Sonic is into
 it now. Remember, Sonic is an old
 school, old- Nintendo game, so it's not
 like I've got all Sonic in my mold.
 I'm assuming that's what you're
 up to, though. I'll just say that I love
 it, so I'll... and I won't say of

'\n I\'ve received some great and funny mail from fans of Sonic the Hedge Labyrinth, so I thought I\'m going to write nail- and pencil- on some of them. I\'m synth- snikt, and I\'ve never done a Sonic game, so I decided to " in- character" some of them. Although begun with my first Sonic, Sonic is into it now. Remember, Sonic is an old school, old- Nintendo game, so it\'s not like I\'ve got all Sonic in my mold. I\'m assuming that\'s what you\'re up to, though. I\'ll just say that I love it, so I\'ll... and I won\'t say of'

In [12]:
text = '''And did those feet in ancient time
Walk upon Englands mountains green:
And was the holy Lamb of God,
On Englands pleasant pastures seen!
 
And did the Countenance Divine,
Shine forth upon our clouded hills?
And was Jerusalem builded here,
Among these dark Satanic Mills?
 
Bring me my Bow of burning gold:
Bring me my arrows of desire:
Bring me my Spear: O clouds unfold!
Bring me my Chariot of fire!
 
I will not cease from Mental Fight,
Nor shall my sword sleep in my hand:
Till we have built Jerusalem,
In Englands green & pleasant Land.'''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=50,
            randomize=False,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic=None,
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 8.776559070611256e-07
and did those feet in ancient time walk upon englands mountains green : and was the holy lamb of god, on <our> pleasant pastures seen! and did the countenance divine, shine forth upon our clouded hills? and was jerusalem builded here, among these dark satanic mills? bring me my bow of burning gold : bring me my arrows of desire : bring me my spear : o clouds unfold! bring me my chariot of fire! i will not cease from mental fight, nor shall my sword sleep in my hand : till we have built jerusalem, in englands green & pleasant land.
-----------------------
Iteration 2, score = 1.4563041759174666e-06
and did those feet in ancient time walk upon <our> mountains green : and was the holy lamb of god, on our pleasant pastures seen! and did the countenance divine, shine forth upon our clouded hills? and was jerusalem builded here, among these dark satanic mills? bring me my bow of burning gold : bring me my arrows of desire : b

'and did he who in our time shone upon our pleasant land : and by the holy grace of god, let our pleasant country stand! and did the glorious divine, shine forth upon our pleasant land? and was jerusalem standing here, upon his great almighty hand? bring me my sword of shining light : bring me my weapon of desire : bring me my spear : o lord above! bring me my instrument of fire! i will not die from our fight, nor will my spear be in my hand : for we have reached jerusalem, in our time & pleasant land.'

In [39]:
text = '''[Tyger Tyger], burning bright, 
In the forests of the night; 
What immortal hand or eye, 
Could frame thy fearful symmetry?'''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=50,
            randomize=0.1,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic="a little, fluffy kitty cat",
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 4.096679333542852e-07
tyger tyger, burning bright, in the forests of the night ; what immortal hand or eye, could <break> thy fearful symmetry?
-----------------------
Iteration 2, score = 1.5753139450680465e-06
tyger tyger, burning bright, in the <middle> of the night ; what immortal hand or eye, could break thy fearful symmetry?
-----------------------
Iteration 3, score = 5.9419497119961306e-05
tyger tyger, <shining> bright, in the middle of the night ; what immortal hand or eye, could break thy fearful symmetry?
-----------------------
Iteration 4, score = 0.0006441521691158414
tyger tyger, shining bright, in the middle of the night ; what immortal hand or eye, could break thy <broken> symmetry?
-----------------------
Iteration 5, score = 0.0005108318291604519
tyger tyger, shining bright, in the middle of the night ; what immortal hand or eye, could break <this> broken symmetry?
-----------------------
Iteration 6, score = 0.01689990423

'tyger tyger, kitty cat, in the posture of a bat ; no protruding fangs or claws, what is this furry animal?'

In [42]:
text = '''O Rose thou art sick. 
The invisible worm, 
That flies in the night 
In the howling storm: 

Has found out thy bed 
Of crimson joy: 
And his dark secret love 
Does thy life destroy.'''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=50,
            randomize=False,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic=None,
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 0.00013018323807045817
o <god> thou art sick. the invisible worm, that flies in the night in the howling storm : has found out thy bed of crimson joy : and his dark secret love does thy life destroy.
-----------------------
Iteration 2, score = 0.004045261535793543
o god thou art sick. the invisible worm, that flies in the night in the howling storm : has found out thy <source> of crimson joy : and his dark secret love does thy life destroy.
-----------------------
Iteration 3, score = 0.005103247240185738
o god thou art <free>. the invisible worm, that flies in the night in the howling storm : has found out thy source of crimson joy : and his dark secret love does thy life destroy.
-----------------------
Iteration 4, score = 0.010653567500412464
o god thou art free. the invisible worm, that flies in the night in the howling storm : has found out thy source of <deepest> joy : and his dark secret love does thy life destroy.
-----------------

'by god thou art blessed. the invisible man, who walks in the night in a hooded cloak : has found both his source of body heat : and his own power that makes his life complete.'

In [45]:
text = '''O Rose thou art sick. 
The invisible worm, 
That flies in the night 
In the howling storm: 

Has found out thy bed 
Of crimson joy: 
And his dark secret love 
Does thy life destroy.'''

depoeticize(text,
            model_name='roberta-base',
            max_iterations=50,
            randomize=False,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic=None,
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 9.045047772815451e-05
 O Rose thou< his> sick. The invisible worm, That flies in the night In the howling storm: Has found out thy bed Of crimson joy: And his dark secret love Does thy life destroy.
-----------------------
Iteration 2, score = 9.544840577291325e-05
 O< God> thou his sick. The invisible worm, That flies in the night In the howling storm: Has found out thy bed Of crimson joy: And his dark secret love Does thy life destroy.
-----------------------
Iteration 3, score = 0.0001600701070856303
 O God thou his sick. The invisible< sweet>, That flies in the night In the howling storm: Has found out thy bed Of crimson joy: And his dark secret love Does thy life destroy.
-----------------------
Iteration 4, score = 0.0010359423467889428
 O God thou his sick. The invisible sweet, That flies in the night In the< winter> storm: Has found out thy bed Of crimson joy: And his dark secret love Does thy life destroy.
-----------------------
It

' They are his cold hands. the aluminum shards, his hands in the snow and the melting ice: that carve out his heart from molten clay: and his cold fingers that take his life away.'

In [11]:
text = '''Tyger Tyger, burning bright, 
In the forests of the night; 
What immortal hand or eye, 
Could frame thy fearful symmetry?'''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=200,
            randomize=0.0,
            cooldown=0.02,
            stop_score=1.0,
            match_meter=False,
            match_rhyme=False,
            preserve_punctuation=False,
            discourage_repetition=0.1,
            topic="computational language modeling with artificial neural networks",
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 1.6370143088306754e-11
tyger tyger, burning bright, in the forests of the night ; what immortal hand or eye, could frame <create> fearful symmetry?
-----------------------
Iteration 2, score = 2.7015327752621943e-09
tyger tyger, burning bright, in the forests of the night ; what immortal hand or eye, could <you> create fearful symmetry?
-----------------------
Iteration 3, score = 2.5341921627841657e-07
tyger tyger, burning bright, in the <middle> of the night ; what immortal hand or eye, could you create fearful symmetry?
-----------------------
Iteration 4, score = 4.249602625350235e-07
tyger tyger, burning bright, in the middle of the night ; what immortal hand or eye, could you create <such> symmetry?
-----------------------
Iteration 5, score = 5.688901794087542e-09
tyger tyger, burning bright, in the middle of the night ; what immortal hand or eye, could you create such <things>?
-----------------------
Iteration 6, score = 1.182900604

'this poem is about noticing that in the middle of the poem exists a computational neural network. how can humans do such things :'

In [167]:
text = '''[Tyger Tyger], burning bright, 
In the forests of the night; 
What immortal hand or eye, 
Could frame thy fearful symmetry?'''

depoeticize(text,
            model_name='roberta-base',
            max_iterations=100,
            randomize=0.0,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic="flowers",
            strong_topic_bias=2.0,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 1, score = 2.446381097698236e-09
 Tyger Tyger,< Standing> bright, In the forests of the night; What immortal hand or eye, Could frame thy fearful symmetry?
-----------------------
Iteration 2, score = 5.7247134463978e-06
 Tyger Tyger, Standing< high>, In the forests of the< sky>; What immortal hand or eye, Could frame thy fearful symmetry?
-----------------------
Iteration 3, score = 3.0911166959413094e-06
 Tyger Tyger, Standing high, In the< centre> of the sky; What immortal hand or eye, Could frame thy fearful symmetry?
-----------------------
Iteration 4, score = 2.5605545488360804e-06
 Tyger Tyger, Standing high, In the centre of the sky; What immortal hand or eye, Could frame< you> fearful symmetry?
-----------------------
Iteration 5, score = 0.00012871516810264438
 Tyger Tyger, Standing high, In the centre of the sky; What immortal hand or eye, Could< give> you fearful symmetry?
-----------------------
Iteration 6, score = 6.3166094150801655e-06

' Tyger Tyger, Flowers lay, In the middle of the day; Are poetic words just words, To give you extra energy?'

In [26]:
text = '''Shall I compare thee to a summer's day? 
Thou art more lovely and more fascinate:
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date: '''

depoeticize(text,
            model_name='bert-base-uncased',
            max_iterations=100,
            randomize=0.0,
            cooldown=0.05,
            stop_score=1.0,
            match_meter=True,
            match_rhyme=True,
            preserve_punctuation=True,
            topic="flowers",
            strong_topic_bias=2.0,
            sequential=True,
            modifier=None)#metalness_modifier())

-----------------------
Iteration 16, score = 0.003925323020666838
shall i compare thee to a summer's day? thou art more lovely <the> more fascinate : rough winds do shake the darling buds of may, and summer's lease hath all too short a date :
-----------------------
Iteration 18, score = 0.0008427134682307269
shall i compare thee to a summer's day? thou art more lovely the more <evergreen> : rough winds do shake the darling buds of may, and summer's lease hath all too short a <green> :
-----------------------
Iteration 20, score = 0.07224021852016449
shall i compare thee to a summer's day? thou art more lovely the more evergreen : <brisk> winds do shake the darling buds of may, and summer's lease hath all too short a green :
-----------------------
Iteration 25, score = 0.7040115594863892
shall i compare thee to a summer's day? thou art more lovely the more evergreen : brisk winds do shake the <berry> buds of may, and summer's lease hath all too short a green :
-----------------------

"shall i compare thee to a summer's day? thou art more lovely the more evergreen : brisk winds do shake the berry leaves of may, a summer lease hath all too short a green :"