<a href="https://colab.research.google.com/github/coderalo/11785-automatic-poetry-generation/blob/main/src/GPT2_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Note about the dataset
Start by running the data-preprocessing notebook in the GitHub repo (`data/preprocessing/get_data.ipynb`), or simply clone the repo, to obtain a copy of `limericks.json`. That file is the dataset used to fine-tune the GPT-2 model in this notebook.

In [1]:
# Start by installing required libraries (mainly Transformers)
!pip install transformers==4.17.0
!pip install scikit-learn
!pip install hydra-core

Collecting transformers==4.17.0
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 8.6 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 6.2 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 47.8 MB/s 
[?25hCollecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 43.9 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 56.1 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    

In [2]:
# Only needed when running in colab: mount Google Drive at /content/drive/ so
# the config, data and checkpoint paths used by the cells below are reachable.
# force_remount=True is passed so the cell can be re-run on a mounted drive.
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

Mounted at /content/drive/


In [42]:
import copy
import glob
import json
import math
import numpy as np
import os
import random
import shutil
import string as string_utils
import tempfile
import torch
import torch.optim as optim
import tqdm.notebook as tqdm
import yaml

from hydra import compose
from hydra import initialize_config_dir
from omegaconf import OmegaConf
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2LMHeadModel
from transformers import GPT2Model
from transformers import GPT2Tokenizer
from transformers import AdamW, get_scheduler

In [4]:
# Directory that holds the Hydra configuration for this project.
config_path = "/content/drive/MyDrive/11-785-final/config/"

# On first use, create the directory together with an empty __init__.py.
if not os.path.exists(config_path):
    os.makedirs(config_path, exist_ok=True)
    with open(f"{config_path}/__init__.py", 'a'):
        pass

# Point Hydra at the directory so compose() can find the config later.
initialize_config_dir(config_path)

hydra.initialize_config_dir()

In [6]:
class LimerickDataset(Dataset):
    """Dataset of limericks, each stored as one merged training string.

    Every limerick in ``data`` (a list of lines) is passed through
    ``merge_lines`` once, at construction time; indexing returns the
    resulting string unchanged.
    """

    def __init__(self, data, use_bos, order):
        # Pre-compute every merged sample up front.
        merged = []
        for limerick in data:
            merged.append(merge_lines(limerick, use_bos, order))
        self.data = merged

    def __len__(self):
        # Number of merged samples.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the merged sample at position ``idx``.
        return self.data[idx]

In [19]:
def load_dataset(config):
    """Load ``limericks.json`` and return the cleaned limericks.

    Args:
        config: Configuration object; only ``config.data.data_dir`` is read.
            ``<data_dir>/limericks.json`` must contain a top-level
            ``"limericks"`` mapping whose values each hold a ``"lines"`` list
            of strings.

    Returns:
        A list with one entry per kept limerick; each entry is the list of
        its lines. A limerick is dropped when any of its lines is empty.
        For kept limericks, a single trailing punctuation character is
        removed from every line that ends with one (a special separator
        token is used instead).
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(
            f"{config.data.data_dir}/limericks.json",
            encoding="utf-8") as file:
        data = json.load(file)

    limericks = []

    for limerick in data['limericks'].values():
        lines = limerick['lines']

        # Skip limericks that contain an empty line.
        if any(len(line) == 0 for line in lines):
            continue

        # Remove the final punctuation of each line
        # (we'll use a special separator instead). A new list is built so
        # the loaded JSON structure is not modified in place.
        limericks.append([
            line[:-1] if line[-1] in string_utils.punctuation else line
            for line in lines])

    print(f"# of limericks before clean-up: {len(data['limericks'])}")
    print(f"# of limericks after clean-up: {len(limericks)}")

    return limericks

In [5]:
# We can construct a training sample of limericks by merging the lines
# with the separator attached at the end of each line
def merge_lines(lines, use_bos, order=None):
    """Merge the lines of a limerick into a single training string.

    Args:
        lines: Sequence of line strings.
        use_bos: When true, the result is prefixed with ``"<BOS> "``.
        order: Optional iterable of line indices; must be a permutation of
            ``0..4``. When given, the lines are emitted in that order.

    Returns:
        The lines joined by ``" <LINE> "`` with a trailing ``" <LINE>"``,
        optionally prefixed by ``"<BOS> "``, with every run of whitespace
        collapsed to a single space.

    Raises:
        TypeError: If ``order`` is given but is not iterable.
        AssertionError: If ``order`` is not a permutation of ``0..4``.
    """
    if order is not None:
        try:
            order = list(order)
        except TypeError as error:
            # Fail loudly: silently returning None here would put None
            # samples into the dataset and only fail later, in the tokenizer.
            raise TypeError(
                f"order must be an iterable of line indices, got {order!r}"
            ) from error
        assert sorted(order) == [0, 1, 2, 3, 4], (
            f"order must be a permutation of [0, 1, 2, 3, 4], got {order}")

        lines = [lines[o] for o in order]

    words = ' <LINE> '.join(lines) + ' <LINE>'
    if use_bos:
        words = '<BOS> ' + words

    # Normalise whitespace: split on any whitespace and re-join with one space.
    words = ' '.join(words.split())

    return words


def reorder(lines, order=None):
    """Rearrange ``lines`` by ascending value of ``order``.

    ``order[i]`` is the sort key of ``lines[i]``: the returned list contains
    the lines sorted by their key (ties keep their input order). With
    ``order=None`` the input object is returned unchanged.
    """
    if order is None:
        return lines

    # Pair each position with its key, then sort the pairs by key.
    ranked = sorted(enumerate(order), key=lambda pair: pair[1])
    return [lines[position] for position, _ in ranked]

In [7]:
def reverse_line(input_ids, use_bos, sep_token_id=None):
    """Reverse the token order inside every separator-terminated segment.

    Args:
        input_ids: 1-D sequence of token ids (e.g. a NumPy array).
        use_bos: When true, the first token is copied through unchanged and
            the first segment starts at index 1.
        sep_token_id: Id of the line-separator token. When ``None`` (the
            default), the notebook-level ``tokenizer.sep_token_id`` is used,
            which is what existing two-argument calls rely on.

    Returns:
        An array shaped like ``input_ids`` in which each segment that ends
        at a separator is reversed, the separators stay at their positions,
        and everything after the last separator is copied unchanged.
    """
    if sep_token_id is None:
        # Backward-compatible default: fall back to the global tokenizer.
        sep_token_id = tokenizer.sep_token_id

    new_input_ids = np.zeros_like(input_ids)

    start = 0
    # Guard against empty input so indexing position 0 cannot fail.
    if use_bos and len(input_ids) > 0:
        new_input_ids[0] = input_ids[0]
        start = 1

    for end in range(1, len(input_ids)):
        if input_ids[end] == sep_token_id:
            # Reverse the tokens of this segment; keep the separator in place.
            new_input_ids[start: end] = input_ids[start: end][::-1]
            new_input_ids[end] = sep_token_id
            start = end + 1

    # Tokens after the last separator are not reversed.
    new_input_ids[start:] = input_ids[start:]
    return new_input_ids

def gen_collate_fn(tokenizer, reverse=False, use_bos=False):
    """Build a DataLoader ``collate_fn`` that tokenizes a batch of strings.

    The returned function pads the batch to its longest sample, optionally
    reverses every line of every sample with ``reverse_line``, sets
    ``labels`` to a detached copy of ``input_ids`` (padding positions
    included), and moves every tensor in the batch to the GPU.
    """
    def collate_fn(batch):
        # The reversed path works on NumPy arrays, so tokenize to "np" there.
        tensor_type = "np" if reverse else "pt"
        encoded = tokenizer(
            batch, padding="longest", return_tensors=tensor_type)

        if reverse:
            ids = encoded['input_ids']
            for row in range(len(ids)):
                ids[row] = reverse_line(ids[row], use_bos=use_bos)
            encoded['input_ids'] = torch.tensor(ids)
            encoded['attention_mask'] = torch.tensor(
                encoded['attention_mask'])

        encoded['labels'] = torch.clone(encoded['input_ids']).detach()

        # Move every tensor of the batch to the GPU.
        for key in list(encoded.keys()):
            encoded[key] = encoded[key].cuda()
        return encoded

    return collate_fn

In [8]:
# finish configuration: compose the Hydra config named "config", then
# override the experiment-specific fields for this run
config = compose(config_name="config")
config.exp_name = "reverse-bos-reordered-gpt2"
# reverse=True makes the collate function reverse the tokens of every line
config.data.reverse = True
# line order passed to merge_lines when building each training sample
config.data.order = [0, 1, 4, 2, 3]

# the experiment name is used below to build the checkpoint directory
assert config.exp_name is not None
print(OmegaConf.to_yaml(config))

data:
  data_dir: /content/drive/MyDrive/11-785-final/data/
  ckpt_dir: /content/drive/MyDrive/11-785-final/ckpt/
  reverse: true
  use_bos: true
  order:
  - 0
  - 1
  - 4
  - 2
  - 3
training:
  learning_rate: 5.0e-05
  weight_decay: 0.0
  scheduler_type: linear
  num_warmup_steps: 0
  epochs: 20
  batch_size: 32
  gradient_accumulation_steps: 1
exp_name: reverse-bos-reordered-gpt2
debug: false
device: cuda



In [9]:
# Create the checkpoint root and this experiment's sub-directory, and
# remember where the training log is written.
exp_dir = f"{config.data.ckpt_dir}/{config.exp_name}"
for directory in (config.data.ckpt_dir, exp_dir):
    os.makedirs(directory, exist_ok=True)
log_file = f"{exp_dir}/log.txt"

# Snapshot the configuration next to the checkpoints; load_model reads it.
with open(f"{exp_dir}/config.yaml", 'w') as config_file:
    config_file.write(OmegaConf.to_yaml(config))

In [20]:
# Read and clean the limericks; each entry is the list of lines of one limerick.
limericks = load_dataset(config)

# of limericks before clean-up: 72432
# of limericks after clean-up: 72431


In [12]:
# Start from the pretrained GPT-2 tokenizer and register the special tokens:
# <LINE> separates lines, <PAD> is used to pad batches to a common length,
# and <BOS> (only when config.data.use_bos is set) marks the start of a sample.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

special_tokens = {"sep_token": "<LINE>", "pad_token": "<PAD>"}
if config.data.use_bos:
    special_tokens["bos_token"] = "<BOS>"

tokenizer.add_special_tokens(special_tokens)

# Show the id assigned to every newly added special token.
for key in special_tokens:
    print(key)
    print(
        f"New {key}: {getattr(tokenizer, key)} "
        f"({getattr(tokenizer, key + '_id')})")

# Save the extended tokenizer with the experiment so load_model can restore it.
tokenizer.save_pretrained(f"{exp_dir}/tokenizer")

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

sep_token
New sep_token: <LINE> (50257)
pad_token
New pad_token: <PAD> (50258)
bos_token
New bos_token: <BOS> (50259)


('/content/drive/MyDrive/11-785-final/ckpt//reverse-bos-reordered-gpt2/tokenizer/tokenizer_config.json',
 '/content/drive/MyDrive/11-785-final/ckpt//reverse-bos-reordered-gpt2/tokenizer/special_tokens_map.json',
 '/content/drive/MyDrive/11-785-final/ckpt//reverse-bos-reordered-gpt2/tokenizer/vocab.json',
 '/content/drive/MyDrive/11-785-final/ckpt//reverse-bos-reordered-gpt2/tokenizer/merges.txt',
 '/content/drive/MyDrive/11-785-final/ckpt//reverse-bos-reordered-gpt2/tokenizer/added_tokens.json')

In [13]:
# Sanity check: pick one random limerick, merge its lines into a single
# training string, tokenize it and decode it back to compare with the input.
sample = random.sample(limericks, 1)[0]
# NOTE: `string` here is a plain str; the stdlib module is imported as
# `string_utils` at the top of the notebook.
string = merge_lines(sample, config.data.use_bos, config.data.order)
print(f"Lines with separator: {string}")
input_ids = tokenizer(string)['input_ids']
print(f"Tokens: {input_ids}")
decoded_string = tokenizer.decode(input_ids)
print(f"Decoding result: {decoded_string}")

Lines with separator: <BOS> the bull thistle's prickly and wild <LINE> annoying to all but a child <LINE> in wonderment?thistle-beguiled <LINE> who'll gaze at its flowers <LINE> (they're purple) for hours <LINE>
Tokens: [50259, 1169, 6473, 294, 12535, 338, 41409, 306, 290, 4295, 50257, 1236, 726, 278, 284, 477, 475, 257, 1200, 50257, 259, 4240, 434, 30, 400, 12535, 12, 1350, 5162, 3902, 50257, 8727, 1183, 17841, 379, 663, 12734, 50257, 7, 9930, 821, 14032, 8, 329, 2250, 50257]
Decoding result: <BOS> the bull thistle's prickly and wild <LINE> annoying to all but a child <LINE> in wonderment?thistle-beguiled <LINE> who'll gaze at its flowers <LINE> (they're purple) for hours <LINE>


In [14]:
# Seed every random number generator the notebook uses (NumPy, `random`,
# and PyTorch, which drives DataLoader shuffling and sampling in generate).
SEED = 11785
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# 90/10 train/validation split. Passing random_state explicitly makes the
# split independent of whatever has consumed the global NumPy RNG before.
train_data, val_data = train_test_split(
    limericks, train_size=0.9, random_state=SEED)
print(f"# of training samples: {len(train_data)}")
print(f"# of validation samples: {len(val_data)}")

# of training samples: 65187
# of validation samples: 7244


In [None]:
# In debug mode only a few batches' worth of samples are used so a full
# train/validation cycle finishes quickly.
if config.debug:
    train_samples = train_data[:config.training.batch_size * 8]
    val_samples = val_data[:config.training.batch_size * 2]
else:
    train_samples = train_data
    val_samples = val_data

# LimerickDataset requires use_bos and order in both modes (the debug branch
# previously omitted them and raised a TypeError).
train_dataset = LimerickDataset(
    train_samples,
    config.data.use_bos,
    config.data.order)
val_dataset = LimerickDataset(
    val_samples,
    config.data.use_bos,
    config.data.order)

# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps every sample in a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.training.batch_size,
    drop_last=True,
    shuffle=True,
    collate_fn=gen_collate_fn(
        tokenizer,
        reverse=config.data.reverse,
        use_bos=config.data.use_bos))
val_loader = DataLoader(
    val_dataset,
    batch_size=config.training.batch_size,
    drop_last=False,
    shuffle=False,
    collate_fn=gen_collate_fn(
        tokenizer,
        reverse=config.data.reverse,
        use_bos=config.data.use_bos))

In [None]:
# initialize the model from the pretrained "gpt2" weights, and resize the
# embeddings to len(tokenizer) so the newly added special tokens get a row
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))
# move the model to the GPU
model = model.cuda()

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Parameters whose name contains one of these substrings get no weight decay.
# NOTE(review): "LayerNorm.weight" is copied from the reference script;
# confirm it matches the layer-norm parameter names of this model.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)],
        "weight_decay": config.training.weight_decay,
    },
    {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = optim.AdamW(
    optimizer_grouped_parameters,
    lr=config.training.learning_rate)

# Optimizer updates per epoch. True division is required here: flooring
# first (`//`) made the surrounding ceil a no-op and undercounted the steps
# whenever the number of batches is not a multiple of the accumulation steps.
T_epoch = math.ceil(
    len(train_loader) /
    config.training.gradient_accumulation_steps)

# The scheduler is stepped once per optimizer update for all epochs.
scheduler = get_scheduler(
    name=config.training.scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=config.training.num_warmup_steps,
    num_training_steps=config.training.epochs * T_epoch)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Resume from the most recent per-epoch checkpoint of this experiment, if any.
files = glob.glob(f"{exp_dir}/epoch-*.ckpt")
if len(files) != 0:
    # File names look like "epoch-<N>.ckpt"; sort by the numeric epoch <N>.
    files = sorted(files, key=lambda x: int(os.path.basename(x)[6:-5]))
    states = torch.load(files[-1])

    model.load_state_dict(states['model_state_dict'])
    optimizer.load_state_dict(states['optimizer_state_dict'])
    scheduler.load_state_dict(states['scheduler_state_dict'])
    scaler.load_state_dict(states['scaler_state_dict'])
    start_epoch = states['epoch'] + 1
    # Restore the best perplexity seen so far, not the last epoch's value;
    # otherwise a worse epoch after resuming could overwrite best-model.ckpt.
    # Fall back to 'perplexity' for checkpoints without 'best_perplexity'.
    best_perplexity = states.get('best_perplexity', states['perplexity'])
else:
    start_epoch = 0
    best_perplexity = 1e30

if start_epoch == 0:
    print("Start training from scratch")
else:
    print(f"Resume training from epoch {start_epoch + 1}")

Start training from scratch


In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
def train_epoch(model, train_loader, optimizer, scheduler, scaler):
    """Train ``model`` for one pass over ``train_loader``.

    Gradients are accumulated over
    ``config.training.gradient_accumulation_steps`` batches (read from the
    notebook-level ``config``) before each optimizer/scheduler step.

    Args:
        model: Model called as ``model(**batch)``; its output exposes ``.loss``.
        train_loader: Sized iterable of batches.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: Learning-rate scheduler, stepped after each optimizer step.
        scaler: Gradient scaler used for ``backward`` and ``step``.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    optimizer.zero_grad()

    accumulation_steps = config.training.gradient_accumulation_steps
    last_step = len(train_loader) - 1

    bar = tqdm.tqdm(train_loader, leave=False)
    loss_total = 0.

    for step, batch in enumerate(bar):
        outputs = model(**batch)
        loss = outputs.loss
        loss_total += loss.item()
        # Scale the loss so the accumulated gradient is an average.
        loss = loss / accumulation_steps
        scaler.scale(loss).backward()

        # Update after every `accumulation_steps` batches and on the final
        # batch. `(step + 1)` is required: `step % k == 0` would already
        # update on the very first batch, before anything was accumulated.
        if (step + 1) % accumulation_steps == 0 or step == last_step:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        bar.set_postfix({"Loss": f"{loss_total / (step + 1):.4f}"})

    return loss_total / len(train_loader)

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
def validation(model, val_loader):
    """Compute the validation perplexity of ``model`` on ``val_loader``.

    Each batch loss is weighted by the number of samples in the batch, and
    the perplexity is ``exp`` of the resulting mean loss.

    Args:
        model: Model called as ``model(**batch)``; its output exposes ``.loss``.
        val_loader: Iterable of batches containing an ``'input_ids'`` tensor.

    Returns:
        The perplexity as a float, or ``inf`` if the exponential overflows.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()

    bar = tqdm.tqdm(val_loader, leave=False)
    losses = []

    for batch in bar:
        with torch.no_grad():
            outputs = model(**batch)

        # Repeat the batch loss once per sample so that a smaller final
        # batch does not get the same weight as a full one.
        batch_size = batch['input_ids'].shape[0]
        losses.extend([outputs.loss.item()] * batch_size)

    if not losses:
        raise ValueError("validation requires a non-empty val_loader")

    # Computed once after the loop instead of on every step.
    try:
        perplexity = math.exp(np.mean(losses))
    except OverflowError:
        perplexity = float('inf')

    return perplexity

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Main loop: train one epoch, validate, log, checkpoint, and keep a copy of
# the checkpoint with the lowest validation perplexity seen so far.
epoch_bar = tqdm.trange(start_epoch, config.training.epochs, leave=False)

for epoch in epoch_bar:
    loss = train_epoch(model, train_loader, optimizer, scheduler, scaler)
    perplexity = validation(model, val_loader)

    # Report to the progress bar and append the same line to the log file.
    log = f"Epoch {epoch+1} Loss: {loss:.4f} Perplexity {perplexity:.4f}"
    epoch_bar.write(log)
    with open(log_file, 'a') as log_handle:
        log_handle.write(f"{log}\n")

    improved = perplexity < best_perplexity
    if improved:
        best_perplexity = perplexity

    # Checkpoint everything needed to resume after this epoch.
    epoch_bar.write(f"Save model at epoch {epoch+1}")
    current_ckpt = f"{exp_dir}/epoch-{epoch+1}.ckpt"
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict':
            None if scheduler is None else scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'epoch': epoch,
        'perplexity': perplexity,
        'best_perplexity': best_perplexity
    }
    torch.save(checkpoint, current_ckpt)

    # Only the newest per-epoch checkpoint is kept on disk.
    prev_ckpt = f"{exp_dir}/epoch-{epoch}.ckpt"
    if epoch != 0 and os.path.exists(prev_ckpt):
        os.remove(prev_ckpt)

    if improved:
        print(f"Save best model at epoch {epoch+1}")
        shutil.copyfile(current_ckpt, f"{exp_dir}/best-model.ckpt")

In [57]:
def generate_limericks(
        model,
        config,
        prompts,
        generate_params,
        num_generation=10,
        batch_size=1,
        add_line_token=True,
):
    """Sample limericks from ``model`` for every prompt.

    Tokenization and decoding use the notebook-level ``tokenizer``.

    Args:
        model: Model exposing ``generate``.
        config: Experiment config; ``config.data.use_bos``,
            ``config.data.reverse`` and ``config.data.order`` are read.
        prompts: Iterable of prompt strings (``""`` for free-form output).
        generate_params: Keyword arguments forwarded to ``model.generate``.
        num_generation: Number of sequences to sample per prompt.
        batch_size: Number of sequences sampled per ``generate`` call.
        add_line_token: When true, ``" <LINE>"`` is appended to a non-empty
            prompt that does not already end with ``<LINE>``.

    Returns:
        A list of limericks, each a list of five line strings. Generated
        sequences that do not split into exactly five lines are discarded,
        so fewer than ``num_generation`` limericks per prompt may come back.
    """
    limericks = []

    use_bos = config.data.use_bos
    reverse = config.data.reverse
    order = config.data.order

    # Round up so a num_generation that is not a multiple of batch_size is
    # still honoured; the surplus of the last batch is trimmed below.
    num_batches = max(0, math.ceil(num_generation / batch_size))

    for prompt in tqdm.tqdm(prompts, leave=False):
        prompt = prompt.strip()
        if add_line_token:
            if prompt != "" and prompt[-6:] != "<LINE>":
                prompt += " <LINE>"
        if use_bos and prompt[:5] != "<BOS>":
            prompt = "<BOS> " + prompt

        if reverse is True:
            # The model saw reversed lines in training, so reverse the prompt.
            input_ids = reverse_line(
                tokenizer(prompt, return_tensors="np").input_ids[0], use_bos)
            input_ids = torch.tensor(input_ids).reshape(1, -1)
        else:
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        input_ids = input_ids.to(device='cuda')
        input_ids = input_ids.repeat(batch_size, 1)

        sequences = []
        for _ in tqdm.trange(num_batches, leave=False):
            output = model.generate(
                input_ids, **generate_params,
                pad_token_id=tokenizer.eos_token_id)
            sequences.extend(torch.unbind(output.cpu()))
        sequences = sequences[:num_generation]

        if reverse is True:
            # Undo the per-line reversal before decoding.
            sequences = [
                reverse_line(sequence.numpy(), use_bos)
                for sequence in sequences]

        # Decode from plain lists rather than one stacked tensor: stacking
        # fails for an empty list and for batches of different lengths.
        decoded = tokenizer.batch_decode(
            [sequence.tolist() for sequence in sequences],
            skip_special_tokens=False)

        for text in decoded:
            # Everything after the last " <LINE> " is dropped.
            lines = text.strip().split(" <LINE> ")[:-1]
            if len(lines) != 5:
                continue
            if use_bos:
                lines = [
                    line.replace("<BOS> ", "").strip()
                    for line in lines]
            limericks.append(reorder(lines, order))

    return limericks

In [37]:
def load_model(exp_dir, tmp_root="/content/test/"):
    """Load the config, tokenizer and best checkpoint of an experiment.

    Args:
        exp_dir: Experiment directory containing ``config.yaml``, a
            ``tokenizer`` sub-directory and ``best-model.ckpt``.
        tmp_root: Directory in which a temporary folder is created for the
            intermediate ``save_pretrained`` / ``from_pretrained`` round trip.
            It is created if missing; the temporary folder is removed again.

    Returns:
        A ``(config, tokenizer, model)`` tuple; the model is on the GPU.
    """
    with open(exp_dir + "/config.yaml") as file:
        config = OmegaConf.create(yaml.safe_load(file))
    tokenizer = GPT2Tokenizer.from_pretrained(f"{exp_dir}/tokenizer")

    # Load the checkpoint on the CPU: only the final model needs GPU memory.
    states = torch.load(f"{exp_dir}/best-model.ckpt", map_location="cpu")

    # Rebuild the architecture with the resized embeddings, then restore the
    # fine-tuned weights.
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(states['model_state_dict'])

    # Round-trip through the pretrained format, as the original code did.
    # The temporary directory is deleted afterwards instead of accumulating
    # one model copy per call.
    os.makedirs(tmp_root, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        model.save_pretrained(tmp_dir)
        new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
    new_model = new_model.cuda()

    return config, tokenizer, new_model

In [40]:
# Load a trained experiment for generation. This rebinds the notebook-level
# `exp_dir`, `config`, `tokenizer` and `model` that the training cells set.
exp_dir = f"/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"
config, tokenizer, model = load_model(exp_dir)

In [41]:
# Sampling settings forwarded to model.generate.
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

# Free-form generation: a single empty prompt, 50 samples in batches of 10.
results = generate_limericks(
    model, config, [""], generate_params,
    num_generation=50, batch_size=10)

# Write one limerick per paragraph, with a blank line after each.
with open("free_form_2.txt", 'w') as out_file:
    for poem in results:
        out_file.write('\n'.join(poem))
        out_file.write('\n\n')

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

In [58]:
def generate_limericks_two_stage(
        standard_lm,
        reverse_lm,
        standard_config,
        reverse_config,
        prompts,
        generate_params,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=1,
):
    """Generate limericks in two stages.

    Stage 1 samples first lines from ``standard_lm`` for each prompt;
    stage 2 passes every first line as a prompt to ``generate_limericks``
    with ``reverse_lm``. Tokenization and decoding in stage 1 use the
    notebook-level ``tokenizer``.

    Args:
        standard_lm: Model used to sample the first lines.
        reverse_lm: Model passed to ``generate_limericks`` for stage 2.
        standard_config: Config of ``standard_lm``; ``data.use_bos`` is read.
        reverse_config: Config forwarded to ``generate_limericks``.
        prompts: Iterable of prompt strings.
        generate_params: Keyword arguments for ``generate``. Stage 1 uses a
            copy with ``max_length`` set to 20; stage 2 uses them unchanged.
        num_generation_1: Number of first lines sampled per prompt.
        num_generation_2: Number of limericks requested per first line.
        batch_size: Number of sequences sampled per ``generate`` call.

    Returns:
        The list of limericks returned by stage 2, over all first lines.
    """
    limericks = []

    # Stage 1 only needs one line, so it gets a shorter length budget. The
    # copy keeps the caller's dict untouched.
    first_line_params = copy.deepcopy(generate_params)
    first_line_params["max_length"] = 20

    # Round up so num_generation_1 need not be a multiple of batch_size.
    num_batches = max(0, math.ceil(num_generation_1 / batch_size))

    for prompt in tqdm.tqdm(prompts, leave=False):
        # generate first line
        prompt = prompt.strip()
        if standard_config.data.use_bos and prompt[:5] != "<BOS>":
            prompt = "<BOS> " + prompt

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device='cuda')
        input_ids = input_ids.repeat(batch_size, 1)

        sequences = []
        for _ in tqdm.trange(num_batches, leave=False):
            # Use the stage-1 parameters (the original built them but then
            # passed the unmodified generate_params).
            output = standard_lm.generate(
                input_ids, **first_line_params,
                pad_token_id=tokenizer.eos_token_id)
            sequences.extend(
                sequence.tolist() for sequence in torch.unbind(output.cpu()))
        sequences = sequences[:num_generation_1]

        decoded = tokenizer.batch_decode(
            sequences,
            skip_special_tokens=False)

        # Keep only the text before the first line separator.
        first_lines = [
            text.strip().split(" <LINE> ")[0]
            for text in decoded]

        print(first_lines)

        outputs_2 = generate_limericks(
            reverse_lm,
            reverse_config,
            first_lines,
            generate_params,
            num_generation=num_generation_2,
            batch_size=batch_size)

        limericks.extend(outputs_2)

    return limericks

In [59]:
# Experiment directories of the two models used for two-stage generation.
standard_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/bos-gpt2"
reverse_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"

# Load the standard model first, then the reverse model.
(standard_config, standard_tokenizer, standard_model) = load_model(
    standard_exp_dir)
(reverse_config, reverse_tokenizer, reverse_model) = load_model(
    reverse_exp_dir)

In [60]:
# Sampling settings forwarded to the generate calls.
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

# Two-stage generation for a single prompt: 10 first lines, then 10
# limericks requested for each first line.
results = generate_limericks_two_stage(
    standard_lm=standard_model,
    reverse_lm=reverse_model,
    standard_config=standard_config,
    reverse_config=reverse_config,
    prompts=["once upon a time"],
    generate_params=generate_params,
    num_generation_1=10,
    num_generation_2=10,
    batch_size=10)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

['<BOS> once upon a time, when a starburst', "<BOS> once upon a time, man's most prized possession", '<BOS> once upon a time a chinese man mused', '<BOS> once upon a time there was a man', '<BOS> once upon a time it was thought', '<BOS> once upon a time, there was a guy', '<BOS> once upon a time, when i was a bloke', '<BOS> once upon a time, it was known', '<BOS> once upon a time i was taught', '<BOS> once upon a time i was young and dave']


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [61]:
# Write one limerick per paragraph, with a blank line after each.
with open("once_upon_a_time_two_stage.txt", 'w') as out_file:
    for poem in results:
        out_file.write('\n'.join(poem))
        out_file.write('\n\n')

In [None]:
# Collect the first line of every clean limerick to use as prompts.
with open(
        f"{config.data.data_dir}/limericks.json",
        encoding="utf-8") as file:
    data = json.load(file)
first_lines = []

for limerick in data['limericks'].values():
    lines = limerick['lines']

    # Skip limericks that contain an empty line.
    if any(len(line) == 0 for line in lines):
        continue

    # Remove the final punctuation of the line
    # (we'll use a special separator instead), but keep a closing quote or
    # parenthesis. `string_utils` is the stdlib `string` module; the bare
    # name `string` is a str variable in this notebook.
    first_line = lines[0]
    if (
            first_line[-1] in string_utils.punctuation and
            first_line[-1] not in "\")"
    ):
        first_line = first_line[:-1]

    first_lines.append(first_line)

# Keep a random sample of 100 first lines.
random.shuffle(first_lines)
first_lines = first_lines[:100]

In [None]:
# Preview the first ten sampled first lines.
for line in first_lines[:10]:
    print(line)

beware! cayenne peppers are hot
a concentrator system's a trick
though we know that you kids were just curious
to keep infantry marching in line
i'm not asking for blood from a stone
 prepares me for lies
it is part of god's clever design
my volume of verse isn't slim
now fifteen, she still acts like a tot
in my lifetime i've had two epiphanies


In [None]:
# generate_limericks requires the model, its config and the generation
# parameters as leading arguments; the call previously omitted them.
# NOTE(review): this uses the notebook-level `model`, `config` and
# `generate_params` — confirm this is the checkpoint intended for this run.
results = generate_limericks(
    model,
    config,
    ["once upon a time"],
    generate_params,
    num_generation=500, batch_size=50)

In [None]:
# Write one limerick per paragraph, with a blank line after each.
with open("once_upon_a_time.txt", 'w') as out_file:
    for poem in results:
        out_file.write('\n'.join(poem))
        out_file.write('\n\n')

In [None]:
# Same prompt, but without appending <LINE> to it, so the model may continue
# the first line. The model, config and generation parameters are required
# arguments of generate_limericks; the call previously omitted them.
# NOTE(review): this uses the notebook-level `model`, `config` and
# `generate_params` — confirm this is the checkpoint intended for this run.
results = generate_limericks(
    model,
    config,
    ["once upon a time"],
    generate_params,
    num_generation=500,
    batch_size=50,
    add_line_token=False)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# Write one limerick per paragraph, with a blank line after each.
with open("once_upon_a_time_new.txt", 'w') as out_file:
    for poem in results:
        out_file.write('\n'.join(poem))
        out_file.write('\n\n')

In [None]:
# Generate limericks for every sampled first line.
num_generation = 50
batch_size = 10

# The model, config and generation parameters are required arguments of
# generate_limericks; the call previously omitted them.
# NOTE(review): this uses the notebook-level `model`, `config` and
# `generate_params` — confirm this is the checkpoint intended for this run.
results = generate_limericks(
    model,
    config,
    first_lines,
    generate_params,
    num_generation=num_generation,
    batch_size=batch_size)

In [None]:
# Write one limerick per paragraph, with a blank line after each.
with open("sample.txt", 'w') as out_file:
    for poem in results:
        out_file.write('\n'.join(poem))
        out_file.write('\n\n')

In [None]:
# Free-form generation from an empty prompt. The model, config and
# generation parameters are required arguments of generate_limericks; the
# call previously omitted them.
# NOTE(review): this uses the notebook-level `model`, `config` and
# `generate_params` — confirm this is the checkpoint intended for this run.
results = generate_limericks(
    model,
    config,
    [""],
    generate_params,
    num_generation=500,
    batch_size=50)

# Write one limerick per paragraph, with a blank line after each.
with open("free_form.txt", 'w') as file:
    for result in results:
        file.write('\n'.join(result))
        file.write('\n\n')

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/10 [00:00<?, ?it/s]
<BOS> on my car, i'd always have been <LINE> in bed under the weather i'm snaky <LINE> the air mattress's the slipperiest piece <LINE> but a thistle that blows <LINE> when it grows, makes me drools <LINE> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
<BOS> your dad loves him. you're distraught <LINE> with your thoughts of him ever gone astray <LINE> and your mom says it's <LINE> though it's clear that you fear <LINE> to continue to fear <LINE> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

In [None]:
# Generate from a single hand-written first line and print the results. The
# model, config and generation parameters are required arguments of
# generate_limericks; the call previously omitted them.
# NOTE(review): this uses the notebook-level `model`, `config` and
# `generate_params` — confirm this is the checkpoint intended for this run.
results = generate_limericks(
    model,
    config,
    ["if you're using a subsurface map"],
    generate_params)
for result in results:
    print("\n".join(result))
    print()