In [None]:
# Start by installing required libraries (mainly Transformers)
!pip install transformers==4.17.0
!pip install scikit-learn
!pip install hydra-core

In [None]:
# Only needed when running in colab: mount Google Drive under /content/drive/.
# Outside Colab the `google.colab` package does not exist, so the cell is
# skipped instead of aborting a "Run All".
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab is not available; skipping the Drive mount.")

if drive is not None:
    drive.mount("/content/drive/", force_remount=True)

In [None]:
!git clone https://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@github.com/coderalo/11785-automatic-poetry-generation.git

In [None]:
import copy
import glob
import json
import math
import numpy as np
import os
import random
import shutil
import string as string_utils
import sys
import tempfile
import torch
import torch.optim as optim
import tqdm.notebook as tqdm
import yaml

from hydra import compose
from hydra import initialize_config_dir
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM
from transformers import GPT2LMHeadModel
from transformers import GPT2Model
from transformers import GPT2Tokenizer

In [None]:
%load_ext autoreload
%autoreload 2

sys.path.append("/content/11785-automatic-poetry-generation/")

from src.dataset import merge_lines, reorder, reverse_line
from src.dataset import LimerickDataset
from src.utils import load_dataset, get_tokenizer

In [None]:
def get_input_ids(
        prompt,
        tokenizer,
        use_bos,
        reverse,
        add_line_token
):
    """
    Turn a text prompt into the token ids handed to the model.

    Arguments:
        prompt: str, surrounding whitespace is ignored
        tokenizer: the tokenizer used to generate tokens
        use_bos: bool, put the <BOS> token in front of the prompt unless the
            prompt already starts with it
        reverse: bool, reverse the token order (via `reverse_line`) or not;
            only the value `True` itself enables it
        add_line_token: bool, put the <LINE> token at the end of a non-empty
            prompt unless the prompt already ends with it
    Return:
        input_ids: the ids produced by the tokenizer; in the reverse case a
            torch tensor reshaped to (1, -1)
    """
    text = prompt.strip()

    # Terminate the prompt with <LINE> (never for an empty prompt).
    if add_line_token and text != "" and not text.endswith("<LINE>"):
        text += " <LINE>"

    # Open the prompt with <BOS>.
    if use_bos and not text.startswith("<BOS>"):
        text = "<BOS> " + text

    # Forward order: the tokenizer output can be used as is.
    if reverse is not True:
        return tokenizer(text, return_tensors="pt").input_ids

    # Reverse order: tokenize to numpy, flip with the project helper, and
    # restore the leading batch dimension.
    token_ids = tokenizer(text, return_tensors="np").input_ids[0]
    flipped = reverse_line(token_ids, use_bos, tokenizer)
    return torch.tensor(flipped).reshape(1, -1)

In [None]:
def batch_decode(outputs, tokenizer, use_bos, reverse):
    """
    Decode generated token sequences back to text.

    The sequences are decoded one by one rather than stacked into a single
    tensor, so an empty list and sequences of different lengths (e.g. coming
    from separately padded `generate` batches) are both supported.

    Arguments:
        outputs: List of 1-D torch.LongTensor
        tokenizer: the tokenizer used to decode tokens to words
        use_bos: bool, whether the <BOS> token is used or not
        reverse: bool, whether the tokens are in reverse order or not; only
            the value `True` itself triggers the call to `reverse_line`
    Return:
        List of str, one per input sequence, special tokens included
    """
    if len(outputs) == 0:
        return []

    sequences = []
    for output in outputs:
        if reverse is True:
            # Bring the tokens back to reading order before decoding.
            token_ids = torch.tensor(
                reverse_line(output.cpu().numpy(), use_bos, tokenizer)
            ).reshape(-1)
        else:
            token_ids = output.cpu()
        sequences.append(token_ids.tolist())

    return tokenizer.batch_decode(sequences, skip_special_tokens=False)

In [None]:
def count_lines(prompt):
    """Return the number of "<LINE>" markers contained in `prompt`."""
    return prompt.count("<LINE>")

In [None]:
def generate_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size,
        add_line_token,
        max_line_tokens=30
):
    """
    Generate / finish one line of the limerick. The prompts should be in the
    correct word order (you don't need to reverse the words before passing
    them into the function).

    Arguments:
        model: the language model; its `generate` method is called
        tokenizer: the tokenizer matching `model`
        config: experiment config; `config.data.use_bos`,
            `config.data.reverse` and `config.device` are read
        prompts: list of str
        generate_params: dict of keyword arguments forwarded to
            `model.generate`; a copy is used and its `max_length` is replaced
        num_generation: int, number of samples drawn per prompt
        batch_size: int, maximum number of samples per `generate` call
        add_line_token: bool, forwarded to `get_input_ids`
        max_line_tokens: int, budget of newly generated tokens per sample
            (a line is assumed not to be longer than this)
    Return:
        list of str: every sample that contains at least one more <LINE> than
        its prompt, cut after that line and terminated with " <LINE>";
        samples that never finished the line are dropped
    """
    final_outputs = []

    use_bos = config.data.use_bos
    reverse = config.data.reverse

    for prompt in tqdm.tqdm(prompts, leave=False):
        num_lines = count_lines(prompt)
        input_ids = get_input_ids(
            prompt, tokenizer,
            use_bos, reverse, add_line_token)
        input_ids = input_ids.to(device=config.device)

        # `max_length` counts the prompt tokens too, so the budget for the new
        # line is added on top of the prompt length.
        tmp_params = copy.deepcopy(generate_params)
        tmp_params["max_length"] = input_ids.shape[-1] + max_line_tokens

        # Split `num_generation` into batches of at most `batch_size`; the
        # last batch is smaller when the numbers do not divide evenly, so
        # exactly `num_generation` samples are drawn.
        batch_sizes = [
            min(batch_size, num_generation - start)
            for start in range(0, num_generation, batch_size)
        ]

        outputs = []
        for cur_batch_size in tqdm.tqdm(batch_sizes, leave=False):
            output = model.generate(
                input_ids.repeat(cur_batch_size, 1), **tmp_params,
                pad_token_id=tokenizer.eos_token_id)
            # Decode batch by batch: each batch is rectangular on its own,
            # while different batches may have different sequence lengths.
            outputs.extend(batch_decode(
                list(torch.unbind(output)), tokenizer, use_bos, reverse))

        for output in outputs:
            if count_lines(output) < num_lines + 1:
                continue
            pieces = output.strip().split(" <LINE> ")[:num_lines + 1]
            # If the text stops right on a <LINE>, the last piece still
            # carries it; remove it so the terminator below is not doubled.
            if pieces[-1].endswith("<LINE>"):
                pieces[-1] = pieces[-1][:-len("<LINE>")].rstrip()
            final_outputs.append(" <LINE> ".join(pieces) + " <LINE>")

    return final_outputs

In [None]:
def generate_new_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Generate a new line after each prompt.

    Thin wrapper around `generate_lines` that always passes
    `add_line_token=True`; every other argument is forwarded unchanged.
    """
    forwarded = (
        model, tokenizer, config, prompts,
        generate_params, num_generation, batch_size,
    )
    return generate_lines(*forwarded, add_line_token=True)
    

def finish_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Finish the line each prompt is currently in.

    Thin wrapper around `generate_lines` that always passes
    `add_line_token=False`; every other argument is forwarded unchanged.
    """
    forwarded = (
        model, tokenizer, config, prompts,
        generate_params, num_generation, batch_size,
    )
    return generate_lines(*forwarded, add_line_token=False)

In [None]:
def generate_limericks(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation=10,
        batch_size=1,
        add_line_token=True,
        num_lines=5,
):
    """
    Generate complete limericks from the given prompts.

    Arguments:
        model: the language model; its `generate` method is called
        tokenizer: the tokenizer matching `model`
        config: experiment config; `config.data.use_bos`,
            `config.data.reverse` and `config.device` are read
        prompts: list of str, in the correct (non-reversed) word order
        generate_params: dict of keyword arguments forwarded to
            `model.generate`
        num_generation: int, number of samples drawn per prompt
        batch_size: int, maximum number of samples per `generate` call
        add_line_token: bool, forwarded to `get_input_ids`
        num_lines: int, number of lines a finished limerick must have
    Return:
        list of str: every sample containing at least `num_lines` <LINE>
        markers, cut after the `num_lines`-th line and terminated with
        " <LINE>"; shorter samples are dropped
    """
    final_outputs = []

    use_bos = config.data.use_bos
    reverse = config.data.reverse

    for prompt in tqdm.tqdm(prompts, leave=False):
        input_ids = get_input_ids(
            prompt, tokenizer,
            use_bos, reverse, add_line_token)
        input_ids = input_ids.to(device=config.device)

        # Split `num_generation` into batches of at most `batch_size`; the
        # last batch is smaller when the numbers do not divide evenly, so
        # exactly `num_generation` samples are drawn.
        batch_sizes = [
            min(batch_size, num_generation - start)
            for start in range(0, num_generation, batch_size)
        ]

        outputs = []
        for cur_batch_size in tqdm.tqdm(batch_sizes, leave=False):
            output = model.generate(
                input_ids.repeat(cur_batch_size, 1), **generate_params,
                pad_token_id=tokenizer.eos_token_id)
            # Decode batch by batch: each batch is rectangular on its own,
            # while different batches may have different sequence lengths.
            outputs.extend(batch_decode(
                list(torch.unbind(output)), tokenizer, use_bos, reverse))

        for output in outputs:
            if count_lines(output) < num_lines:
                continue
            pieces = output.strip().split(" <LINE> ")[:num_lines]
            # If the text stops right on a <LINE>, the last piece still
            # carries it; remove it so the terminator below is not doubled.
            if pieces[-1].endswith("<LINE>"):
                pieces[-1] = pieces[-1][:-len("<LINE>")].rstrip()
            final_outputs.append(" <LINE> ".join(pieces) + " <LINE>")

    return final_outputs

In [None]:
def generate_limericks_two_stage(
        standard_lm,
        reverse_lm,
        standard_config,
        reverse_config,
        standard_tokenizer,
        reverse_tokenizer,
        prompts,
        generate_params,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=1,
):
    """
    Two-stage limerick generation.

    Stage one calls `finish_lines` with the standard LM to produce first
    lines for a prompt (`num_generation_1` samples requested). Stage two
    passes those first lines as prompts to `generate_limericks` with the
    reverse LM (`num_generation_2` samples requested per first line). Both
    stages share `generate_params` and `batch_size`.

    Return:
        list of str, the limericks collected over all prompts
    """
    first_stage = dict(
        model=standard_lm,
        tokenizer=standard_tokenizer,
        config=standard_config,
        generate_params=generate_params,
        num_generation=num_generation_1,
        batch_size=batch_size)
    second_stage = dict(
        model=reverse_lm,
        tokenizer=reverse_tokenizer,
        config=reverse_config,
        generate_params=generate_params,
        num_generation=num_generation_2,
        batch_size=batch_size)

    collected = []
    for prompt in tqdm.tqdm(prompts, leave=False):
        # generate first line
        openings = finish_lines(prompts=[prompt], **first_stage)
        # let the second model complete each candidate
        collected += generate_limericks(prompts=openings, **second_stage)

    return collected

In [None]:
def load_model(exp_dir, tmp_root="/content/test/"):
    """
    Load the config, tokenizer and trained weights of one experiment.

    Arguments:
        exp_dir: str, directory containing `config.yaml`, `tokenizer/` and
            `best-model.ckpt`
        tmp_root: str, directory in which a temporary copy of the model is
            written (created if missing); the copy is deleted again before
            the function returns
    Return:
        config: OmegaConf object created from `config.yaml`
        tokenizer: GPT2Tokenizer loaded from `{exp_dir}/tokenizer`
        new_model: the model reloaded through AutoModelForCausalLM and moved
            to the GPU
    """
    # Read the config inside a `with` block so the file handle is closed.
    with open(exp_dir + "/config.yaml") as config_file:
        config = OmegaConf.create(yaml.safe_load(config_file))
    tokenizer = GPT2Tokenizer.from_pretrained(f"{exp_dir}/tokenizer")

    # NOTE(review): torch.load unpickles the file, only use trusted
    # checkpoints. map_location="cpu" keeps the raw checkpoint off the GPU and
    # makes loading independent of the device it was saved from.
    states = torch.load(f"{exp_dir}/best-model.ckpt", map_location="cpu")

    # The intermediate model stays on the CPU: it is only used to write a
    # `save_pretrained` snapshot, so putting it on the GPU would just hold
    # extra GPU memory.
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(states['model_state_dict'])

    # Round-trip through a temporary directory that is removed afterwards, so
    # repeated calls do not pile up model copies under `tmp_root`.
    os.makedirs(tmp_root, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        model.save_pretrained(tmp_dir)
        new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
    new_model = new_model.cuda()

    return config, tokenizer, new_model

## Example of one-stage generation

In [None]:
# Checkpoint directory of the model used for one-stage generation.
exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/bos-gpt2"
(config,
 tokenizer,
 model) = load_model(exp_dir)

In [None]:
# Keyword arguments forwarded to `model.generate`.
generate_params = dict(do_sample=True, max_length=100)

# Request 50 samples from an empty prompt, 10 per batch.
results = generate_limericks(
    model=model,
    tokenizer=tokenizer,
    config=config,
    prompts=[""],
    generate_params=generate_params,
    num_generation=50,
    batch_size=10,
    add_line_token=True)

# Show one limerick per output line.
for res in results:
    print(res)

## Example of two-stage generation

In [None]:
# Checkpoint directories of the two models used for two-stage generation.
standard_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/bos-gpt2"
reverse_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"

(standard_config,
 standard_tokenizer,
 standard_model) = load_model(standard_exp_dir)

(reverse_config,
 reverse_tokenizer,
 reverse_model) = load_model(reverse_exp_dir)

In [None]:
# Keyword arguments forwarded to `model.generate`.
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

# Keep batch_size no larger than the number of samples requested per stage
# (10 here), so that each stage runs exactly one full batch.
results = generate_limericks_two_stage(
    standard_model,
    reverse_model,
    standard_config,
    reverse_config,
    standard_tokenizer,
    reverse_tokenizer,
    [""],
    generate_params=generate_params,
    num_generation_1=10,
    num_generation_2=10,
    batch_size=10)

# Turn every limerick into a list of its lines: drop the "<BOS> " prefix, cut
# at the <LINE> markers, trim whitespace and discard empty pieces.
results = [
    [
        line for line in (
            piece.strip()
            for piece in result.replace("<BOS> ", "").split("<LINE>")
        ) if line != ""
    ]
    for result in results
]