In [1]:
# Start by installing required libraries (mainly Transformers)
# Only transformers is pinned to an exact version here.
# NOTE(review): the other three packages are unpinned, so a fresh run may
# pick up newer releases than the ones shown in the captured output below.
!pip install transformers==4.17.0
!pip install scikit-learn
!pip install hydra-core
!pip install pronouncing

Collecting transformers==4.17.0
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 11.4 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 14.4 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 8.3 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 58.0 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 54.3 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyya

Collecting pronouncing
  Downloading pronouncing-0.2.0.tar.gz (17 kB)
Collecting cmudict>=0.4.0
  Downloading cmudict-1.0.2-py2.py3-none-any.whl (939 kB)
[K     |████████████████████████████████| 939 kB 15.5 MB/s 
[?25hBuilding wheels for collected packages: pronouncing
  Building wheel for pronouncing (setup.py) ... [?25l[?25hdone
  Created wheel for pronouncing: filename=pronouncing-0.2.0-py2.py3-none-any.whl size=6252 sha256=81b6a2cd89c11a1185d01c8af40ce9ac0940192e48984e2c73a338aadbeae79e
  Stored in directory: /root/.cache/pip/wheels/09/e8/c0/3606d42fdbf5f3871564eb6a353591a8f5deeed013fdb73921
Successfully built pronouncing
Installing collected packages: cmudict, pronouncing
Successfully installed cmudict-1.0.2 pronouncing-0.2.0


In [2]:
# Only needed when running in colab
# Mounts Google Drive under /content/drive/ so the checkpoint paths used by
# later cells (/content/drive/MyDrive/...) resolve.
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

Mounted at /content/drive/


In [3]:
!git clone https://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@github.com/coderalo/11785-automatic-poetry-generation.git

Cloning into '11785-automatic-poetry-generation'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (101/101), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 101 (delta 36), reused 56 (delta 12), pack-reused 0[K
Receiving objects: 100% (101/101), 11.32 MiB | 5.17 MiB/s, done.
Resolving deltas: 100% (36/36), done.


In [4]:
import copy
import glob
import json
import math
import numpy as np
import os
import pronouncing
import random
import shutil
import string as string_utils
import sys
import tempfile
import torch
import torch.nn.functional as F
import torch.optim as optim
import tqdm.notebook as tqdm
import yaml

from hydra import compose
from hydra import initialize_config_dir
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM
from transformers import GPT2LMHeadModel
from transformers import GPT2Model
from transformers import GPT2Tokenizer

In [5]:
# Reload edited modules automatically so changes under src/ are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Make the `src` package importable.
# NOTE(review): presumably this is the checkout created by the git clone
# cell; that only holds if the clone ran with /content as working directory.
sys.path.append("/content/11785-automatic-poetry-generation/")

from src.dataset import merge_lines, reorder, reverse_line
from src.dataset import LimerickDataset
from src.utils import load_dataset, get_tokenizer

In [6]:
def get_input_ids(
        prompt,
        tokenizer,
        use_bos,
        reverse,
        add_line_token
):
    """
    Normalise a text prompt and tokenize it.

    Arguments:
        prompt: str, surrounding whitespace is stripped first
        tokenizer: callable tokenizer; its result must expose `.input_ids`
        use_bos: bool, prepend "<BOS> " unless the prompt already starts
            with "<BOS>"
        reverse: bool, when exactly True the token ids are passed through
            `reverse_line` before being returned
        add_line_token: bool, append " <LINE>" to a non-empty prompt that
            does not already end with "<LINE>"
    Return:
        input_ids: the tokenizer's `input_ids` (requested as "pt" tensors);
            on the reverse path a tensor reshaped to (1, seq_len)
    """
    text = prompt.strip()

    # close the current line only when there is something to close
    needs_line_token = (
        add_line_token and text != "" and not text.endswith("<LINE>"))
    if needs_line_token:
        text = text + " <LINE>"

    if use_bos and not text.startswith("<BOS>"):
        text = "<BOS> " + text

    if reverse is not True:
        return tokenizer(text, return_tensors="pt").input_ids

    # reverse path: tokenize to numpy, reorder, then restore the batch axis
    token_ids = tokenizer(text, return_tensors="np").input_ids[0]
    reordered = reverse_line(token_ids, use_bos, tokenizer)
    return torch.tensor(reordered).reshape(1, -1)

In [7]:
def batch_decode(outputs, tokenizer, use_bos, reverse):
    """
    Decode generated token sequences into strings.

    Arguments:
        outputs: List of torch.LongTensor (1-D); the sequences may have
            different lengths
        tokenizer: the tokenizer used to decode tokens to words
        use_bos: bool, whether the <BOS> token is used or not
        reverse: bool, whether the tokens are in reverse order or not
    Return:
        List of str, one per entry of `outputs`; special tokens are kept
    """
    if reverse is True:
        # put the tokens back into reading order before decoding
        sequences = [
            torch.tensor(
                reverse_line(output.cpu().numpy(), use_bos, tokenizer)
            ).reshape(-1).tolist()
            for output in outputs]
    else:
        sequences = [output.cpu().tolist() for output in outputs]

    # The sequences are decoded as a list of id lists instead of one stacked
    # tensor: stacking raises when generation batches produced different
    # lengths (and on an empty list), while decoding does not need a
    # rectangular input.
    return tokenizer.batch_decode(sequences, skip_special_tokens=False)

In [8]:
def count_lines(prompt):
    """Return the number of "<LINE>" markers in the stripped prompt."""
    return prompt.strip().count("<LINE>")

In [9]:
def generate_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size,
        add_line_token
):
    """
    Generate / finish one line of the limerick. The prompts should be in the
    correct word order (you don't need to revert the words before passing into
    the function)

    Arguments:
        model: model exposing `generate`
        tokenizer: tokenizer matching the model
        config: config providing data.use_bos, data.reverse and device
        prompts: List of str
        generate_params: dict of keyword arguments for `model.generate`;
            it is not modified ("max_length" is replaced by
            max_new_tokens=30 on a copy)
        num_generation: int, number of samples drawn per prompt
        batch_size: int, maximum number of sequences per `generate` call
        add_line_token: bool, forwarded to `get_input_ids`
    Return:
        List of str, each cut after the line that was generated / finished
        and terminated with " <LINE>". Samples that never produced the
        closing <LINE> are dropped, so the list can be shorter than
        len(prompts) * num_generation.
    """
    use_bos = config.data.use_bos
    reverse = config.data.reverse

    # Step 1: build one (1, seq_len) row per requested sample and remember,
    # per row, how many lines the finished text must contain
    rows, targets = [], []
    for prompt in prompts:
        stripped = prompt.strip()
        target = count_lines(prompt) + 1
        if add_line_token and stripped != "" \
                and not stripped.endswith("<LINE>"):
            # get_input_ids closes the unfinished line with an extra <LINE>,
            # so the newly generated line is one further down
            target += 1
        input_ids = get_input_ids(
            prompt, tokenizer,
            use_bos, reverse, add_line_token)
        rows.extend([input_ids] * num_generation)
        targets.extend([target] * num_generation)

    # Consecutive rows are batched together only while they have the same
    # length, so prompts of different lengths never need padding.
    batches, current = [], []
    for row in rows:
        if current and (
                len(current) >= batch_size
                or current[0].shape[1] != row.shape[1]):
            batches.append(torch.cat(current, dim=0))
            current = []
        current.append(row)
    if current:
        batches.append(torch.cat(current, dim=0))

    # assume that a line cannot be longer than 30 tokens
    tmp_params = copy.deepcopy(generate_params)
    tmp_params.pop("max_length", None)
    tmp_params["max_new_tokens"] = 30

    # Step 2: pass the batches into model to get generation output
    # Step 3: convert the generation result back to strings (done per batch,
    # where all sequences share one length)
    outputs = []
    for input_ids in tqdm.tqdm(batches, leave=False):
        input_ids = input_ids.to(device=config.device)
        with torch.no_grad():
            output = model.generate(
                input_ids, **tmp_params,
                pad_token_id=tokenizer.eos_token_id)
        outputs.extend(
            batch_decode(
                list(torch.unbind(output)), tokenizer, use_bos, reverse))

    clean_outputs = []
    for output, target in zip(outputs, targets):
        if count_lines(output) < target:
            continue
        kept = output.strip().split(" <LINE> ")[:target]
        clean_outputs.append(" <LINE> ".join(kept) + " <LINE>")

    return clean_outputs

In [10]:
def generate_new_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Thin wrapper around `generate_lines` that always passes
    add_line_token=True; every other argument is forwarded unchanged.
    """
    return generate_lines(
        model, tokenizer, config, prompts, generate_params,
        num_generation, batch_size,
        add_line_token=True)
    

def finish_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Thin wrapper around `generate_lines` that always passes
    add_line_token=False; every other argument is forwarded unchanged.
    """
    return generate_lines(
        model, tokenizer, config, prompts, generate_params,
        num_generation, batch_size,
        add_line_token=False)

In [11]:
def generate_limericks(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation=10,
        batch_size=1,
        add_line_token=True,
):
    """
    Generate complete five-line limericks that continue the given prompts.

    Arguments:
        model: model exposing `generate`
        tokenizer: tokenizer matching the model
        config: config providing data.use_bos, data.reverse and device
        prompts: List of str
        generate_params: dict of keyword arguments passed to `model.generate`
        num_generation: int, number of samples drawn per prompt
        batch_size: int, maximum number of sequences per `generate` call
        add_line_token: bool, forwarded to `get_input_ids`
    Return:
        List of str, each holding exactly five lines terminated with
        " <LINE>". Samples with fewer than five <LINE> markers are dropped.
    """
    use_bos = config.data.use_bos
    reverse = config.data.reverse
    num_limerick_lines = 5

    # Step 1: build one (1, seq_len) row per requested sample
    rows = []
    for prompt in prompts:
        input_ids = get_input_ids(
            prompt, tokenizer,
            use_bos, reverse, add_line_token)
        rows.extend([input_ids] * num_generation)

    # Consecutive rows are batched together only while they have the same
    # length, so prompts of different lengths never need padding.
    batches, current = [], []
    for row in rows:
        if current and (
                len(current) >= batch_size
                or current[0].shape[1] != row.shape[1]):
            batches.append(torch.cat(current, dim=0))
            current = []
        current.append(row)
    if current:
        batches.append(torch.cat(current, dim=0))

    # Step 2: pass the batches into model to get generation output
    # Step 3: convert the generation result back to strings (done per batch,
    # where all sequences share one length)
    outputs = []
    for input_ids in tqdm.tqdm(batches, leave=False):
        input_ids = input_ids.to(device=config.device)
        with torch.no_grad():
            output = model.generate(
                input_ids, **generate_params,
                pad_token_id=tokenizer.eos_token_id)
        outputs.extend(
            batch_decode(
                list(torch.unbind(output)), tokenizer, use_bos, reverse))

    clean_outputs = []
    for output in outputs:
        if count_lines(output) < num_limerick_lines:
            continue
        kept = output.strip().split(" <LINE> ")[:num_limerick_lines]
        clean_outputs.append(" <LINE> ".join(kept) + " <LINE>")

    return clean_outputs

In [12]:
def generate_limericks_two_stage(
        standard_lm,
        reverse_lm,
        standard_tokenizer,
        reverse_tokenizer,
        standard_config,
        reverse_config,
        prompts,
        generate_params,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=1,
):
    """
    Two-stage generation: for every prompt the standard LM finishes the first
    line (`num_generation_1` samples via `finish_lines`), then the reverse LM
    expands each finished first line into limericks (`num_generation_2`
    samples via `generate_limericks`).

    Return:
        List of str with the limericks of all prompts, in prompt order
    """
    collected = []
    for prompt in tqdm.tqdm(prompts, leave=False):
        # stage 1: complete the opening line with the standard model
        openings = finish_lines(
            standard_lm, standard_tokenizer, standard_config,
            [prompt], generate_params,
            num_generation_1, batch_size)

        # stage 2: let the reverse model write the remaining lines
        collected += generate_limericks(
            reverse_lm, reverse_tokenizer, reverse_config,
            openings, generate_params,
            num_generation=num_generation_2,
            batch_size=batch_size)

    return collected

In [79]:
def get_last_words(prompt):
    """
    Split the prompt on single spaces and return, for every "<LINE>" token,
    the token that directly precedes it.
    """
    tokens = prompt.split(' ')
    return [
        tokens[position - 1]
        for position, token in enumerate(tokens)
        if token == "<LINE>"]


def get_current_rhymes(prompt, tokenizer, allow_repetition=False):
    """
    Collect the words the next line may end with, following AABBA.

    Arguments:
        prompt: str, the lines written so far, each closed with <LINE>
        tokenizer: callable tokenizer returning a mapping with 'input_ids'
        allow_repetition: bool, keep the rhymed-with words themselves as
            candidates when True
    Return:
        rhyme_tokens: List of token-id lists, each in reversed order
        rhymes: List of str, the candidate words (same order)
        Both are empty when the next line starts a new rhyme (0 or 2 lines
        so far) or when no rhyme was found.
    """
    num_lines = count_lines(prompt)
    last_words = get_last_words(prompt)

    # first A or first B: nothing to rhyme with yet
    if num_lines == 0 or num_lines == 2:
        return [], []

    if num_lines == 1:      # 2nd A rhymes with the 1st line
        anchors = [last_words[0]]
    elif num_lines == 4:    # 3rd A rhymes with the 1st and 2nd line
        anchors = [last_words[0], last_words[1]]
    elif num_lines == 3:    # 2nd B rhymes with the 3rd line
        anchors = [last_words[2]]
    else:
        anchors = last_words

    candidates = set()
    for anchor in anchors:
        candidates.update(pronouncing.rhymes(anchor))
    if not allow_repetition:
        for anchor in anchors:
            candidates.discard(anchor)
    rhymes = list(candidates)

    if rhymes == []:
        return [], rhymes

    encoded = tokenizer(rhymes)['input_ids']
    return [token_ids[::-1] for token_ids in encoded], rhymes

In [57]:
def pad_tokens(tokens, tokenizer, max_len):
    """
    Right-pad token-id lists to a common length.

    Arguments:
        tokens: List of token-id lists
        tokenizer: object providing `pad_token_id`
        max_len: int, target length; must be at least the longest list
    Return:
        padded_words: torch.LongTensor of shape [len(tokens), max_len]
        attention_mask: torch.FloatTensor of the same shape, 1. on real
            tokens and 0. on padding
    """
    # The amount of padding depends on the length of each individual
    # sequence (tokens_), not on the number of sequences.
    padded_tokens = [
        list(tokens_) + [tokenizer.pad_token_id] * (max_len - len(tokens_))
        for tokens_ in tokens]
    attention_mask = [
        [1.] * len(tokens_) + [0.] * (max_len - len(tokens_))
        for tokens_ in tokens]

    padded_words = torch.tensor(padded_tokens, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.float)

    return padded_words, attention_mask


def lengths_to_mask(lengths, dtype, device):
    """
    Build a [len(lengths), max(lengths)] mask whose row i is non-zero on the
    first lengths[i] positions.

    Arguments:
        lengths: 1-D integer tensor
        dtype: dtype of the returned mask
        device: device of the returned mask
    """
    width = lengths.max().item()
    positions = torch.arange(
        width, dtype=lengths.dtype, device=lengths.device)

    # broadcasting [1, width] against [n, 1] compares every position with
    # every length
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)

    return valid.clone().detach().to(dtype=dtype, device=device)

In [None]:
def get_rhyming_word_score(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        rhymes,
        temperature,
        batch_size=64
):
    """
    Score the candidate rhyming words of every prompt with the reverse LM
    and turn the scores into one sampling distribution per prompt.

    A candidate's score is the sum of the log-probabilities of its tokens
    when they are appended to the (reversed) prompt.

    Arguments:
        reverse_lm: causal LM, called with `input_ids` / `attention_mask`
            keywords and returning a mapping with 'logits'
        tokenizer: tokenizer matching `reverse_lm`; `pad_token_id` is used
            for padding
        config: config providing data.use_bos and device
        prompts: List of str; each must tokenize to at least one token
        rhymes: List (one entry per prompt) of non-empty lists of token-id
            lists, the candidates for that prompt
        temperature: float, the candidate distribution is
            softmax(score / temperature); 1.0 leaves the scores unchanged
        batch_size: int, number of sequences per forward pass
    Return:
        probs_list: List of np.ndarray (float64); probs_list[i][j] is the
            probability of candidate j of prompt i
    """
    if len(prompts) == 0:
        return []

    pad_token_id = tokenizer.pad_token_id

    # Step 1:
    #     generate input ids for each prompt (not concatenated yet), repeated
    #     once per candidate; also collect the max rhyme length (in tokens)
    lengths, max_rhyme_len = [], 0
    input_ids_list = []
    for prompt, rhymes_ in zip(prompts, rhymes):
        input_ids = get_input_ids(
            prompt=prompt,
            tokenizer=tokenizer,
            use_bos=config.data.use_bos,
            reverse=True,
            add_line_token=False)

        # [l_0, ..., l_0, l_1, ..., l_1, ...]
        lengths.extend([input_ids.shape[1]] * len(rhymes_))
        input_ids_list.append(input_ids.repeat(len(rhymes_), 1))

        rhyme_len = max(len(rhyme) for rhyme in rhymes_)
        max_rhyme_len = max(max_rhyme_len, rhyme_len)

    # Step 2:
    #     right-pad the candidates of every prompt to max_rhyme_len; the
    #     mask marks real tokens so padding does not contribute to the score
    padded_rhymes_list, rhyme_masks = [], []
    for rhymes_ in rhymes:
        padded_rhymes_list.append(torch.tensor(
            [
                list(rhyme) + [pad_token_id] * (max_rhyme_len - len(rhyme))
                for rhyme in rhymes_
            ], dtype=torch.long))
        rhyme_masks.append(torch.tensor(
            [
                [1.] * len(rhyme) + [0.] * (max_rhyme_len - len(rhyme))
                for rhyme in rhymes_
            ], dtype=torch.float))

    padded_rhymes = torch.cat(padded_rhymes_list, dim=0)
    rhyme_masks = torch.cat(rhyme_masks, dim=0)

    # Step 3:
    #     append the candidates to their prompt and pad all sequences to the
    #     same length for batching
    input_ids_list = [
        torch.cat([input_ids, padded], dim=1)
        for input_ids, padded in zip(input_ids_list, padded_rhymes_list)]

    max_seq_len = max(input_ids.shape[1] for input_ids in input_ids_list)
    input_ids_list = [
        torch.cat(
            [
                input_ids,
                torch.full(
                    (input_ids.shape[0], max_seq_len - input_ids.shape[1]),
                    fill_value=pad_token_id,
                    dtype=torch.long, device="cpu")
            ], dim=1)
        for input_ids in input_ids_list]

    full_input_ids = torch.cat(input_ids_list, dim=0)
    num_examples = full_input_ids.shape[0]
    num_batches = math.ceil(num_examples / batch_size)

    lengths = torch.tensor(lengths, dtype=torch.long)
    total_lengths = lengths + max_rhyme_len
    attention_masks = lengths_to_mask(total_lengths, torch.float, "cpu")

    # Rhyme token j of an example sits at index length + j. In a causal LM
    # it is predicted by the logits one step earlier, at length + j - 1.
    # [num_examples, max_rhyme_len]
    positions = (
        lengths.unsqueeze(1) - 1
        + torch.arange(max_rhyme_len).unsqueeze(0))

    # Step 4:
    #     pass the batches into the model and reduce the logits to one score
    #     per example right away, so the full logits are never kept around
    scores = []
    for i in tqdm.trange(num_batches, leave=False):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        input_ids = full_input_ids[batch].to(device=config.device)
        attention_mask = attention_masks[batch].to(device=config.device)
        with torch.no_grad():
            # keywords are required: the second positional argument of the
            # GPT-2 forward is past_key_values, not attention_mask
            logits = reverse_lm(
                input_ids=input_ids,
                attention_mask=attention_mask)['logits']

            # [batch, max_rhyme_len, vocab_size]
            rows = torch.arange(
                logits.shape[0], device=logits.device).unsqueeze(1)
            rhyme_logits = logits[rows, positions[batch].to(logits.device)]
            log_probs = F.log_softmax(rhyme_logits, dim=-1)

            # [batch, max_rhyme_len]
            token_ids = padded_rhymes[batch].to(logits.device)
            token_scores = torch.gather(
                log_probs, 2, token_ids.unsqueeze(2)).squeeze(2)

            # [batch]
            mask = rhyme_masks[batch].to(logits.device)
            scores.append(torch.sum(token_scores * mask, dim=1).cpu())

    # float64 keeps the normalised probabilities within the tolerance that
    # np.random.choice requires of `p`
    scores = torch.cat(scores, dim=0).double().numpy()

    # Step 5:
    #     split the scores back into one array per prompt and normalise
    probs_list, anchor = [], 0
    for rhymes_ in rhymes:
        group = scores[anchor: anchor + len(rhymes_)] / temperature
        # subtracting the max avoids underflow in exp
        group = np.exp(group - np.max(group))
        probs_list.append(group / np.sum(group))
        anchor += len(rhymes_)

    return probs_list

In [None]:
def attach_next_rhyming_word(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        num_samples,
        weighted,
        temperature=None,
        batch_size=64
):
    """
    Append a sampled rhyming word to every prompt that has to rhyme.

    Arguments:
        reverse_lm: model used to score the candidates when `weighted`
        tokenizer: tokenizer matching `reverse_lm`
        config: config forwarded to `get_rhyming_word_score`
        prompts: List of str
        num_samples: int, number of output prompts produced per input prompt
        weighted: bool, sample by LM score when True, uniformly otherwise
        temperature: float or None (treated as 1.0), used when `weighted`
        batch_size: int, forwarded to `get_rhyming_word_score`
    Return:
        Flat List of str with num_samples entries per input prompt, in input
        order. Prompts without rhyme candidates are repeated unchanged.
    """
    prompts_with_next_word = [None for _ in prompts]
    prompts_with_rhymes, prompts_without_rhymes = [], []
    for idx, prompt in enumerate(prompts):
        tokens, words = get_current_rhymes(prompt, tokenizer)
        if tokens != []:
            prompts_with_rhymes.append([idx, prompt, tokens, words])
        else:
            prompts_without_rhymes.append([idx, prompt])

    # the scorer is only worth calling when there is something to score
    if weighted and prompts_with_rhymes:
        probs_list = get_rhyming_word_score(
            reverse_lm=reverse_lm,
            tokenizer=tokenizer,
            config=config,
            prompts=[p[1] for p in prompts_with_rhymes],
            rhymes=[p[2] for p in prompts_with_rhymes],
            temperature=(1.0 if temperature is None else temperature),
            batch_size=batch_size)
    else:
        probs_list = [
            np.ones(len(p[3])) / len(p[3])
            for p in prompts_with_rhymes]

    for prompt_info, probs in zip(prompts_with_rhymes, probs_list):
        idx, prompt, _, words = prompt_info
        samples = np.random.choice(len(words), num_samples, p=probs)
        prompts_with_next_word[idx] = \
            [f"{prompt} {words[s]}" for s in samples]

    for idx, prompt in prompts_without_rhymes:
        prompts_with_next_word[idx] = [prompt] * num_samples

    # flatten the per-prompt lists (sum() without a list start value would
    # try to add a list to the integer 0)
    return [
        candidate
        for group in prompts_with_next_word
        for candidate in group]

In [83]:
def generate_limericks_with_rhyming(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation=10,
        batch_size=10,
):
    """
    Write limericks line by line with the reverse LM: one call to
    `generate_new_lines` on the prompts, then four rounds of attaching a
    rhyming word (`attach_next_rhyming_word`, weighted sampling) and
    finishing the line (`finish_lines`). Every intermediate prompt is
    printed.

    Return:
        List of str, the prompts after the last round
    """
    current = generate_new_lines(
        reverse_lm, tokenizer, config, prompts, generate_params,
        num_generation, batch_size)

    for text in current:
        print(text)

    rounds_left = 4
    while rounds_left > 0:
        # pick how the next line has to end ...
        seeded = attach_next_rhyming_word(
            reverse_lm=reverse_lm,
            tokenizer=tokenizer,
            config=config,
            prompts=current,
            num_samples=1,
            weighted=True,
            temperature=1.0)
        # ... and let the model write the rest of that line
        current = finish_lines(
            reverse_lm, tokenizer, config, seeded, generate_params,
            num_generation=1,
            batch_size=1)

        for text in current:
            print(text)

        rounds_left -= 1

    return current

In [119]:
def generate_limericks_two_stage_with_rhyming(
        standard_lm,
        reverse_lm,
        standard_tokenizer,
        reverse_tokenizer,
        standard_config,
        reverse_config,
        prompts,
        generate_params,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=1,
):
    """
    Two-stage generation with explicit rhyming words: the standard LM
    finishes the first line of every prompt, then for each of the remaining
    four lines a rhyming word is attached (uniform sampling) and the reverse
    LM finishes the line.

    Arguments:
        standard_lm, standard_tokenizer, standard_config: model, tokenizer
            and config used for the first line
        reverse_lm, reverse_tokenizer, reverse_config: model, tokenizer and
            config used for all later lines
        prompts: List of str
        generate_params: dict of keyword arguments for generation
        num_generation_1: int, first-line samples per prompt
        num_generation_2: int, currently not used by this function
        batch_size: int, batch size for the first line
    Return:
        List of str, the texts that survived all rounds
    """
    limericks = []
    for prompt in tqdm.tqdm(prompts, leave=False):
        # generate first line
        lines = finish_lines(
            model=standard_lm,
            tokenizer=standard_tokenizer,
            config=standard_config,
            prompts=[prompt],
            generate_params=generate_params,
            num_generation=num_generation_1,
            batch_size=batch_size)

        for _ in range(4):
            if not lines:
                # every candidate was filtered out, nothing left to extend
                break
            # The reverse model must be paired with its own tokenizer and
            # config, not with whatever `tokenizer` / `config` globals exist.
            lines = attach_next_rhyming_word(
                reverse_lm=reverse_lm,
                tokenizer=reverse_tokenizer,
                config=reverse_config,
                prompts=lines,
                num_samples=1,
                weighted=False,
                temperature=1.0)
            print(lines[0])
            lines = finish_lines(
                model=reverse_lm,
                tokenizer=reverse_tokenizer,
                config=reverse_config,
                prompts=lines,
                generate_params=generate_params,
                num_generation=1,
                batch_size=1)

        limericks.extend(lines)

    return limericks

In [17]:
def load_model(exp_dir, tmp_root="/content/test/"):
    """
    Load the config, tokenizer and fine-tuned GPT-2 model of an experiment.

    The checkpoint weights are loaded into a GPT-2 whose embeddings are
    resized to the tokenizer, exported with `save_pretrained` and re-loaded
    through `AutoModelForCausalLM`.

    Arguments:
        exp_dir: str, directory containing config.yaml, tokenizer/ and
            best-model.ckpt
        tmp_root: str, directory (created if missing) in which the
            temporary export directory is placed
    Return:
        config: OmegaConf config built from config.yaml
        tokenizer: GPT2Tokenizer
        new_model: the re-loaded model, moved to CUDA
    """
    # close the config file deterministically instead of leaking the handle
    with open(exp_dir + "/config.yaml") as config_file:
        config = OmegaConf.create(yaml.safe_load(config_file))
    tokenizer = GPT2Tokenizer.from_pretrained(f"{exp_dir}/tokenizer")

    os.makedirs(tmp_root, exist_ok=True)
    states = torch.load(f"{exp_dir}/best-model.ckpt")

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))
    model = model.cuda()
    model.load_state_dict(states['model_state_dict'])

    # The export is only needed until the model has been read back, so the
    # directory is removed again instead of accumulating one copy per call.
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        model.save_pretrained(tmp_dir)
        new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
    new_model = new_model.cuda()

    return config, tokenizer, new_model

In [133]:
# Checkpoint directory of the model used by the single-model cells below.
exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"
(config, tokenizer, model) = load_model(exp_dir)

In [130]:
# Sampling settings passed through to model.generate.
generate_params = dict(do_sample=True, max_length=100)

# 50 samples continuing "once upon a time", generated 10 at a time.
results = generate_limericks(
    model=model,
    tokenizer=tokenizer,
    config=config,
    prompts=["once upon a time"],
    generate_params=generate_params,
    num_generation=50,
    batch_size=10,
    add_line_token=False)

for res in results:
    print(res)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

<BOS> in the time of the time a upononce <LINE> from a tree that came down to the snow <LINE> with the melting ice <LINE> when it looked pretty nice <LINE> now-boiled up, it became aglow <LINE>
<BOS> when there was a time a upononce <LINE> lived so well, didn't seem to bewhite <LINE> how things changed, then you'd know <LINE> you had no way to go <LINE> then you had the way that turned clean <LINE>
<BOS> he knew every time a upononce <LINE> clad in leather and leather and sheen <LINE> he would brighten his skin <LINE> thrusting gold from his chin <LINE> of a man, wearing leather and sheen <LINE>
<BOS> every time a upononce <LINE> where a blue people used to be white <LINE> that was yellow in hue <LINE> with a name that was blue <LINE> now it's blue was not red, black-and-white <LINE>
<BOS> in england, it's time a upononce <LINE> finding walks in the woods, by the fawn <LINE> there's a hole where that grasses <LINE> with no trees and no grasses <LINE> and its gnatiness always is gone <L

In [None]:
# Sampling settings passed through to model.generate.
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

# Start from an empty prompt so the model writes every line itself.
results = generate_limericks_with_rhyming(
    model,
    tokenizer,
    config,
    [""],
    generate_params)

for res in results:
    print(res)

In [None]:
# One limerick per paragraph: drop the <BOS> marker, put every non-empty
# line on its own row and separate limericks with a blank line.
with open("rhyming.txt", 'w') as file:
    for result in results:
        result = [
            line
            for line in
            result.replace("<BOS>", "").strip().split(" <LINE> ")
            if line != ""]
        file.write('\n'.join(result) + '\n\n')

In [None]:
# Tokenize a sample prompt and move every returned tensor to the GPU.
# NOTE(review): `inputs` is not read by any later cell visible in this
# notebook — this looks like leftover experimentation.
inputs = tokenizer("<BOS> once upon a", return_tensors="pt")
for key, value in inputs.items():
    inputs[key] = value.cuda()

In [121]:
# Checkpoint directories of the two models used by the two-stage cells.
standard_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/bos-gpt2"
reverse_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"

(standard_config, standard_tokenizer, standard_model) = load_model(
    standard_exp_dir)
(reverse_config, reverse_tokenizer, reverse_model) = load_model(
    reverse_exp_dir)

In [135]:
# Sampling settings passed through to model.generate.
generate_params = dict(do_sample=True, max_length=100)

results = generate_limericks_two_stage(
    standard_lm=standard_model,
    reverse_lm=reverse_model,
    standard_tokenizer=standard_tokenizer,
    reverse_tokenizer=reverse_tokenizer,
    standard_config=standard_config,
    reverse_config=reverse_config,
    prompts=["once upon a time"],
    generate_params=generate_params,
    num_generation_1=50,
    num_generation_2=10,
    batch_size=10)

# Turn every limerick into a list of its non-empty, stripped lines.
results = [
    [
        line for line in (
            piece.strip()
            for piece in result.replace("<BOS> ", "").split("<LINE>")
        ) if line != ""
    ] for result in results
]

print(results[0])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

['once upon a time some poor dame', 'though his name was deserving of fame', 'when he claimed that his name', 'not a matter of fame', 'when he died, he enjoyed such acclaim']


In [122]:
# Sampling settings passed through to model.generate.
generate_params = dict(do_sample=True, max_length=100)

results = generate_limericks_two_stage_with_rhyming(
    standard_lm=standard_model,
    reverse_lm=reverse_model,
    standard_tokenizer=standard_tokenizer,
    reverse_tokenizer=reverse_tokenizer,
    standard_config=standard_config,
    reverse_config=reverse_config,
    prompts=[""],
    generate_params=generate_params,
    num_generation_1=1,
    num_generation_2=1,
    batch_size=1)

# Turn every limerick into a list of its non-empty, stripped lines.
results = [
    [
        line for line in (
            piece.strip()
            for piece in result.replace("<BOS> ", "").split("<LINE>")
        ) if line != ""
    ] for result in results
]

print(results[0])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

<BOS> my son, a fine schoolteacher named dave <LINE> mave


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

<BOS> my son, a fine schoolteacher named dave <LINE> has left with the lessons. he, mah-mave <LINE> ened


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

<BOS> my son, a fine schoolteacher named dave <LINE> has left with the lessons. he, mah-mave <LINE> these lessons he careened <LINE> screened


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

<BOS> my son, a fine schoolteacher named dave <LINE> has left with the lessons. he, mah-mave <LINE> these lessons he careened <LINE> he lives so unscreened <LINE> nave


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

['my son, a fine schoolteacher named dave', 'has left with the lessons. he, mah-mave', 'these lessons he careened', 'he lives so unscreened', "but he's always-schooled-in-a-nave"]


In [None]:
# `results` holds one list of lines per limerick here; write each limerick
# as a paragraph followed by a blank line.
with open("free_form_5000.txt", 'w') as file:
    for result in results:
        file.write('\n'.join(result) + '\n\n')